"""Supporting functions for the text/tabular-reporting commands.
Namely: breaks, genemetrics.
"""
from __future__ import annotations
import collections
import math
from typing import TYPE_CHECKING, Optional
import numpy as np
import pandas as pd
from . import params
from .segmetrics import segment_mean
if TYPE_CHECKING:
from collections.abc import Iterator
from cnvlib.cnary import CopyNumArray
from numpy import float64
# _____________________________________________________________________________
# breaks
[docs]
def do_breaks(
    probes: CopyNumArray, segments: CopyNumArray, min_probes: int = 1
) -> pd.DataFrame:
    """Tabulate targeted genes that contain a segment breakpoint.

    Parameters
    ----------
    probes : CopyNumArray
        Bin-level copy number data, used to locate each targeted gene.
    segments : CopyNumArray
        Segmented copy number data whose segment ends are the breakpoints.
    min_probes : int, optional
        Minimum number of the gene's probes required on each side of the
        breakpoint for it to be reported. Default is 1.

    Returns
    -------
    pd.DataFrame
        One row per (gene, breakpoint) with columns: gene, chromosome,
        location, change, probes_left, probes_right.
    """
    column_names = [
        "gene",
        "chromosome",
        "location",
        "change",
        "probes_left",
        "probes_right",
    ]
    gene_spans = get_gene_intervals(probes)
    breakpoint_records = get_breakpoints(gene_spans, segments, min_probes)
    return pd.DataFrame.from_records(breakpoint_records, columns=column_names)
[docs]
def get_gene_intervals(
    all_probes: CopyNumArray,
    ignore: tuple[str, ...] = params.IGNORE_GENE_NAMES,
) -> collections.defaultdict[str, list[tuple[str, list[int], int]]]:
    """Tally genomic locations of each targeted gene.

    Return a dict of chromosomes to a list of tuples: (gene name, starts, end),
    where gene name is a string, starts is a sorted list of probe start
    positions, and end is the last probe's end position as an integer. (The
    endpoints are redundant since probes are adjacent.)

    Parameters
    ----------
    all_probes : CopyNumArray
        Bin-level data. Each iterated row must provide ``chromosome``,
        ``start``, ``end`` and ``gene``.
    ignore : iterable of str, optional
        Gene names to skip, in addition to ``params.ANTITARGET_ALIASES``.
        The argument itself is never modified.

    Returns
    -------
    collections.defaultdict
        Chromosome name -> list of ``(gene, sorted_starts, max_end)``, with
        each chromosome's list sorted by its starts.
    """
    # Build a fresh set instead of `ignore += ...`. The augmented assignment
    # would extend a caller-supplied list in place, and it fails for a set.
    # A set also makes the per-row membership test O(1).
    skip_names = {*ignore, *params.ANTITARGET_ALIASES}

    # Tally the start & end points for each targeted gene; group by chromosome
    gene_probes = collections.defaultdict(lambda: collections.defaultdict(list))
    for row in all_probes:
        gname = str(row.gene)
        if gname not in skip_names:
            gene_probes[row.chromosome][gname].append(row)

    # Condense into a single interval for each gene
    intervals = collections.defaultdict(list)
    for chrom, gp in gene_probes.items():
        for gene, probes in gp.items():
            starts = sorted(row.start for row in probes)
            end = max(row.end for row in probes)
            intervals[chrom].append((gene, starts, end))
        # Order genes along the chromosome by their (sorted) start positions
        intervals[chrom].sort(key=lambda gse: gse[1])
    return intervals
[docs]
def get_breakpoints(
    intervals: collections.defaultdict[str, list[tuple[str, list[int], int]]],
    segments: CopyNumArray,
    min_probes: int,
) -> list[tuple[str, str, int, float64, int, int]]:
    """Identify segment breaks within the targeted intervals.

    A breakpoint is the end of a segment that is followed by another segment
    on the same chromosome. It is reported for a gene when it falls strictly
    inside the gene's span (first probe start < break < gene end) and at
    least `min_probes` of the gene's probe starts lie on each side.

    Parameters
    ----------
    intervals : mapping
        Chromosome -> list of ``(gene, sorted_starts, end)``, as built by
        ``get_gene_intervals``. It is only read, never modified.
    segments : CopyNumArray
        Segments in genomic order. Rows must provide ``chromosome``, ``end``
        and ``log2``.
    min_probes : int
        Minimum number of gene probes required on each side of the break.

    Returns
    -------
    list of tuple
        ``(gene, chromosome, location, change, probes_left, probes_right)``.
        ``location`` is the break position rounded up to an int, and
        ``change`` is next segment log2 minus current segment log2. The list
        is sorted by the smaller side's probe count, then by ``abs(change)``,
        both descending.
    """
    # TODO use segments.by_ranges(intervals)
    breakpoints = []
    # Walk consecutive segment pairs instead of indexing segments[i + 1] per row
    for curr_row, next_row in zip(segments[:-1], segments[1:]):
        curr_chrom = curr_row.chromosome
        # Skip if this segment is the last (or only) one on this chromosome
        if next_row.chromosome != curr_chrom:
            continue
        curr_end = curr_row.end
        # Use .get() rather than []: indexing a defaultdict would add empty
        # entries for chromosomes without genes, and a plain dict would raise.
        for gname, gstarts, gend in intervals.get(curr_chrom, ()):
            if not gstarts[0] < curr_end < gend:
                continue
            probes_left = sum(s < curr_end for s in gstarts)
            probes_right = sum(s >= curr_end for s in gstarts)
            if probes_left >= min_probes and probes_right >= min_probes:
                breakpoints.append(
                    (
                        gname,
                        curr_chrom,
                        int(math.ceil(curr_end)),
                        next_row.log2 - curr_row.log2,
                        probes_left,
                        probes_right,
                    )
                )
    breakpoints.sort(key=lambda row: (min(row[4], row[5]), abs(row[3])), reverse=True)
    return breakpoints
# _____________________________________________________________________________
# genemetrics
[docs]
def do_genemetrics(
    cnarr: CopyNumArray,
    segments: Optional[CopyNumArray] = None,
    threshold: float = 0.2,
    min_probes: int = 3,
    skip_low: bool = False,
    is_haploid_x_reference: bool = False,
    is_sample_female: Optional[bool] = None,
    diploid_parx_genome: Optional[str] = None,
) -> pd.DataFrame:
    """Identify targeted genes with copy number gain or loss.

    Parameters
    ----------
    cnarr : CopyNumArray
        Bin-level copy number data.
    segments : CopyNumArray, optional
        Segmented copy number data. If provided (and non-empty), each gene
        takes the log2 value of the segment it falls in. Otherwise the
        gene-level summary of `cnarr` is used.
    threshold : float, optional
        Minimum absolute log2 ratio to consider a gene altered. Default is 0.2.
    min_probes : int, optional
        Minimum number of probes required to report a gene. The segment-level
        count (``segment_probes``) is used when available. Default is 3.
    skip_low : bool, optional
        Passed through to ``segment_mean``; presumably excludes low-coverage
        bins. Default is False.
    is_haploid_x_reference : bool, optional
        Whether reference is male (haploid X). Default is False.
    is_sample_female : bool, optional
        Whether sample is female. If None, inferred with ``cnarr.guess_xx``.
    diploid_parx_genome : str, optional
        Reference genome name for pseudo-autosomal region handling
        (e.g., 'hg19', 'hg38', 'mm10').

    Returns
    -------
    pd.DataFrame
        One row per reported gene, with "gene" as the first column followed
        by the remaining per-gene (and, with segments, per-segment) fields.
    """
    if is_sample_female is None:
        is_sample_female = cnarr.guess_xx(
            is_haploid_x_reference=is_haploid_x_reference,
            diploid_parx_genome=diploid_parx_genome,
        )
    cnarr = cnarr.shift_xx(
        is_haploid_x_reference, is_sample_female, diploid_parx_genome
    )
    # NOTE(review): this is a truthiness test, not `is not None`. Presumably an
    # empty segment array (length 0) falls back to gene-level metrics; confirm.
    if segments:
        segments = segments.shift_xx(
            is_haploid_x_reference, is_sample_female, diploid_parx_genome
        )
        rows = list(gene_metrics_by_segment(cnarr, segments, threshold, skip_low))
    else:
        rows = list(gene_metrics_by_gene(cnarr, threshold, skip_low))

    # Take column names from the first result row; if nothing was reported,
    # use the array's required columns so the empty table still has a schema.
    columns = rows[0].index if rows else cnarr._required_columns
    columns = ["gene", *(col for col in columns if col != "gene")]
    table = pd.DataFrame.from_records(rows).reindex(columns=columns)
    if min_probes and len(table):
        # Prefer the segment's probe count when segment metrics were computed
        n_probes = (
            table.segment_probes if "segment_probes" in table.columns else table.probes
        )
        table = table[n_probes >= min_probes]
    return table
[docs]
def gene_metrics_by_gene(
    cnarr: CopyNumArray, threshold: float, skip_low: bool = False
) -> Iterator[pd.Series]:
    """Yield per-gene summary rows whose log2 value reaches `threshold`.

    Rows come from ``group_by_genes(cnarr, skip_low)``. A row is kept when
    ``abs(row.log2) >= threshold`` and its gene name is truthy.

    NB: Adjust the sample's sex-chromosome log2 values beforehand with shift_xx,
    otherwise all chrX/chrY genes may be reported gained/lost.
    """
    yield from (
        gene_row
        for gene_row in group_by_genes(cnarr, skip_low)
        if abs(gene_row.log2) >= threshold and gene_row.gene
    )
[docs]
def gene_metrics_by_segment(
    cnarr: CopyNumArray,
    segments: CopyNumArray,
    threshold: float,
    skip_low: bool = False,
) -> Iterator[pd.Series]:
    """Identify genes where segmented copy ratio exceeds `threshold`.

    In the output table, show each segment's weight and probes as segment_weight
    and segment_probes, alongside the gene-level weight and probes. Any other
    segment column absent from `cnarr` (other than depth/probes/weight) is
    copied onto each gene row as well. Each gene row's log2 is replaced by its
    segment's log2.

    The caller's `cnarr` is not modified; placeholder columns are added to a
    copy.

    NB: Adjust the sample's sex-chromosome log2 values beforehand with shift_xx,
    otherwise all chrX/chrY genes may be reported gained/lost.
    """
    extra_cols = [
        col
        for col in segments.data.columns
        if col not in cnarr.data.columns and col not in ("depth", "probes", "weight")
    ]
    if extra_cols:
        # Add the placeholder columns to a copy, not to the caller's array.
        cnarr = cnarr.copy()
        for colname in extra_cols:
            # NaN placeholders give every gene row a slot for these fields (in
            # this column position) before they are filled from the segment.
            cnarr[colname] = np.nan
    for segment, subprobes in cnarr.by_ranges(segments):
        if abs(segment.log2) >= threshold:
            for row in group_by_genes(subprobes, skip_low):
                row["log2"] = segment.log2
                if hasattr(segment, "weight"):
                    row["segment_weight"] = segment.weight
                if hasattr(segment, "probes"):
                    row["segment_probes"] = segment.probes
                for colname in extra_cols:
                    row[colname] = getattr(segment, colname)
                yield row
# ENH consolidate with CNA.squash_genes
[docs]
def group_by_genes(cnarr: CopyNumArray, skip_low: bool) -> Iterator[pd.Series]:
    """Summarize bin-level data into one row per gene.

    Yields one pandas Series per gene, in the order produced by
    ``cnarr.by_gene()``. Each row is a copy of the gene's first bin, updated
    with:

    - ``end``: end of the gene's last bin;
    - ``gene``: the gene name;
    - ``log2``: ``segment_mean(rows, skip_low)``;
    - ``probes``: number of bins in the gene;
    - ``weight``: sum of bin weights (only if a weight column exists);
    - ``depth``: weight-averaged depth when weights exist (plain mean if the
      weights sum to zero), else the plain mean (only if a depth column exists).

    Genes are skipped when they have no bins, when the name is empty, NaN, or
    an antitarget alias, or when ``segment_mean`` returns None.
    """
    ignore = ("", np.nan, *params.ANTITARGET_ALIASES)
    for gene, rows in cnarr.by_gene():
        # A tuple `in` test matches NaN only by identity with np.nan, so also
        # check explicitly for any float NaN gene name.
        if (
            not rows
            or gene in ignore
            or (isinstance(gene, float) and math.isnan(gene))
        ):
            continue
        segmean = segment_mean(rows, skip_low)
        if segmean is None:
            continue
        outrow = rows[0].copy()
        outrow["end"] = rows.end.iat[-1]
        outrow["gene"] = gene
        outrow["log2"] = segmean
        outrow["probes"] = len(rows)
        if "weight" in rows:
            outrow["weight"] = rows["weight"].sum()
            if "depth" in rows:
                try:
                    outrow["depth"] = np.average(rows["depth"], weights=rows["weight"])
                except ZeroDivisionError:
                    # np.average raises when the weights sum to zero; fall back
                    # to the unweighted mean rather than aborting the report.
                    outrow["depth"] = rows["depth"].mean()
        elif "depth" in rows:
            outrow["depth"] = rows["depth"].mean()
        yield outrow