From 8fe1dd140e2b86fbb9455412fbff6816d5b7332a Mon Sep 17 00:00:00 2001 From: David Lougheed Date: Wed, 18 Feb 2026 16:14:26 -0500 Subject: [PATCH] feat: add plot subcommand for CLI plotting --- requirements.txt | 16 +++ setup.py | 3 + strkit/entry.py | 41 ++++++++ strkit/plot/__init__.py | 0 strkit/plot/locus.py | 211 ++++++++++++++++++++++++++++++++++++++++ 5 files changed, 271 insertions(+) create mode 100644 strkit/plot/__init__.py create mode 100644 strkit/plot/locus.py diff --git a/requirements.txt b/requirements.txt index f287f8c..194a06a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,24 +1,36 @@ +altair==5.5.0 annotated-types==0.7.0 +attrs==25.3.0 blinker==1.9.0 click==8.2.1 +contourpy==1.3.3 coverage==7.10.1 +cycler==0.12.1 Cython==3.1.2 exceptiongroup==1.3.0 Flask==3.1.1 +fonttools==4.59.2 importlib_metadata==7.1.0 iniconfig==2.1.0 itsdangerous==2.2.0 Jinja2==3.1.6 joblib==1.3.2 +jsonschema==4.25.1 +jsonschema-specifications==2025.9.1 +kiwisolver==1.4.9 line_profiler==5.0.0 MarkupSafe==2.1.5 +matplotlib==3.10.6 +narwhals==2.5.0 numpy==1.26.4 orjson==3.11.1 packaging==25.0 pandas==2.3.1 parasail==1.3.4 patsy==0.5.6 +pillow==11.3.0 pluggy==1.6.0 +polars==1.33.1 psutil==6.1.0 pydantic==2.11.7 pydantic_core==2.33.2 @@ -29,8 +41,11 @@ pytest==8.4.1 pytest-cov==6.2.1 python-dateutil==2.9.0.post0 pytz==2025.2 +referencing==0.36.2 +rpds-py==0.27.1 scikit-learn==1.7.2 scipy==1.16.1 +seaborn==0.13.2 six==1.17.0 statsmodels==0.14.5 strkit_rust_ext==0.24.2 @@ -39,5 +54,6 @@ tomli==2.2.1 typing-inspection==0.4.1 typing_extensions==4.14.1 tzdata==2025.2 +vl-convert-python==1.8.0 Werkzeug==3.1.3 zipp==3.23.0 diff --git a/setup.py b/setup.py index 549eaca..22df672 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,9 @@ "statsmodels>=0.14.0,<0.15", "strkit_rust_ext==0.24.2", ], + extras_require={ + "plot": ["pandas>=2.3.2,<2.4", "seaborn>=0.13.2,<0.14"], + }, description="A toolkit for analyzing variation in short(ish) tandem repeats.", long_description=long_description, diff --git a/strkit/entry.py b/strkit/entry.py index e018458..31a84d8 100644 --- a/strkit/entry.py +++ b/strkit/entry.py @@ -446,6 +446,18 @@ def add_cv_parser_args(al_parser): al_parser.add_argument("--out-format", type=str, choices=CONVERTER_OUTPUT_FORMATS, help="Format to convert to.") +def add_pl_parser_args(pl_parser): + pl_parser.add_argument("in_file", type=str, help="Input STRkit JSON report to plot a locus from.") + pl_parser.add_argument("idx", type=int) # TODO: HELP + pl_parser.add_argument("out_file", type=str) # TODO: HELP + pl_parser.add_argument("--allele", type=str, default="all") # TODO: HELP + pl_parser.add_argument("--allele-label", type=str, default="Allele") # TODO: HELP + pl_parser.add_argument("--n-bins", type=int, default=None) # TODO: HELP + pl_parser.add_argument("--x-label", type=str, default="Copy number") # TODO: HELP + pl_parser.add_argument("--y-label", type=str, default="# reads") # TODO: HELP + pl_parser.add_argument("--annotation", "--annot", type=str, action="append") # TODO: HELP + + def add_vs_parser_args(vs_parser): vs_parser.add_argument("align_file", type=str, help="Alignment file to visualize.") vs_parser.add_argument( @@ -607,6 +619,26 @@ def _exec_convert(p_args): return convert(p_args.in_file, p_args.in_format, p_args.out_format, _main_logger(p_args)) +def _exec_plot(p_args): + from strkit.json import json + from strkit.plot.locus import plot_locus + + with open(p_args.in_file, "rb") as fh: + data = json.loads(fh.read()) + + plot_locus( + data["results"][p_args.idx], + p_args.allele, + p_args.allele_label, + p_args.n_bins, + p_args.x_label, + p_args.y_label, + p_args.annotation, + p_args.out_file, + _main_logger(p_args), + ) + + def _exec_viz_server(p_args): from strkit.json import json from strkit.viz.server import run_server as viz_run_server @@ -706,6 +738,15 @@ def _make_subparser(*names: str, help_text: str, exec_func: Callable, arg_func: exec_func=_exec_convert, arg_func=add_cv_parser_args) + _make_subparser( + "plot", + help_text=( + "Generate a plot of a specific locus (and, optionally, a specific allele) with a copy number histogram " + "and/or a k-mer histogram." + ), + exec_func=_exec_plot, + arg_func=add_pl_parser_args) + _make_subparser( "visualize", "vis", "viz", help_text="Start a web server to visualize results from an STR genotyping report.", diff --git a/strkit/plot/__init__.py b/strkit/plot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strkit/plot/locus.py b/strkit/plot/locus.py new file mode 100644 index 0000000..6e95d6c --- /dev/null +++ b/strkit/plot/locus.py @@ -0,0 +1,211 @@ +import re +from collections import deque, defaultdict +from logging import Logger +from pathlib import Path +from strkit.call.types import LocusResult +from strkit.utils import cat_strs + +__all__ = ["plot_locus"] + + +def normalize_motif(normalized_motifs_cache: set, motif: str): + motif_deque = deque(motif) + + if len(motif) == 3 and "C" in motif: + if motif.count("C") == 2: + # special case: if len(motif) == 3 and we have two C, rotate until we start with two Cs + while not (motif_deque[0] == "C" and motif_deque[1] == "C"): + motif_deque.rotate(1) + return cat_strs(motif_deque) + else: + # special case: if len(motif) == 3 and we have a C, rotate until we start with C + while motif_deque[0] != "C": + motif_deque.rotate(1) + return cat_strs(motif_deque) + else: + for _ in range(len(motif)): + if (c := cat_strs(motif_deque)) in normalized_motifs_cache: + return c + motif_deque.rotate(1) + normalized_motifs_cache.add(motif) + return motif + + +def plot_locus( + locus_data: LocusResult, + allele: str, + allele_label: str, + n_bins: int | None, + x_label: str, + y_label: str, + annotations: list[str] | None, + out_file: str | Path, + logger: Logger, +) -> None: + """ + TODO + :param locus_data: TODO + :param allele: TODO + :param allele_label: TODO + :param n_bins: TODO + :param x_label: TODO + :param y_label: TODO + :param annotations: TODO + :param out_file: TODO + :param logger: TODO + :return: TODO + """ + + try: + import seaborn.objects as so + except ImportError: + so = None + + if so is None: + logger.error("Could not import seaborn. Make sure to install 'strkit[plot]'!") + exit(1) + + import matplotlib as mpl + import pandas as pd + import seaborn as sns + + from matplotlib import figure as mpl_figure + + font_size = 10 + font_rc = { + "font.family": "Arial", + "font.size": font_size, + "axes.titlesize": font_size, + "axes.labelsize": font_size, + "xtick.labelsize": font_size, + "ytick.labelsize": font_size, + "legend.title_fontsize": font_size, + "legend.fontsize": font_size, + } + mpl.rcParams.update(font_rc) + + def add_text(plot_, x: int, y: int, text: str): + return plot_.add( + so.Text(halign="right"), + data=pd.DataFrame.from_records([{"x": x, "y": y, "text": text}]), + x="x", y="y", text="text" + ) + + def themed(plot_): + return plot_.theme({ + **sns.axes_style("white"), + "axes.spines.top": False, + "axes.spines.right": False, + **font_rc, + "patch.linewidth": 0, + }) + + # ------------------------------------------------------------------------------------------------------------------ + + if "reads" not in locus_data or "peaks" not in locus_data: + logger.error("No reads and/or peaks in specified locus. This locus may not have been called.") + exit(1) + + records = [] + + for read in locus_data["reads"].values(): + records.append({ + allele_label: str(read.get("p", -1)), + "Copy Number": read["cn"], + # "Sequence Length": # TODO, in future STRKit, include read sequence length! + }) + + df = pd.DataFrame.from_records(records) + + if df[df[allele_label] == -1].shape[0] == df.shape[0]: + logger.warning("No reads with peaks in specified locus; ignoring specified allele and plotting everything.") + + # TODO: add k-mer plot + + if re.match(r"\d+", allele): + allele_idx = int(allele) + df = df[df[allele_label] == allele_idx] + + ppi = 300 + + n_bins = n_bins or 90 # TODO: normal default + + f = mpl_figure.Figure(figsize=(6.5, 8)) # tight_layout=True + subfigs = f.subfigures(2, 1) + + plot = ( + themed(so.Plot(df, x="Copy Number")) + .scale(color=sns.color_palette("muted", 2)) + .add(so.Bars(edgewidth=0), so.Hist(bins=n_bins), color=allele_label) + .on(subfigs[0]) + ) + + def bad_annot(a: str): + logger.error("Bad annotation (format: x y message): %s", a) + + for annot in (annotations or []): + ad = annot.split(" ", maxsplit=2) + if len(ad) != 3: + bad_annot(annot) + exit(1) + + try: + ad_x = int(ad[0]) + ad_y = int(ad[1]) + except ValueError: + bad_annot(annot) + exit(1) + + ad_msg = ad[2] + + plot = plot.add(so.Line(color="#666666"), data=pd.DataFrame({"y": [0, ad_y]}).assign(x=ad_x), x="x", y="y") + plot = add_text(plot, ad_x, ad_y, ad_msg) + + plot = plot.label(x=x_label, y=y_label or "# reads") + ptr = plot.plot() + ptr._figure.legends[0].set_bbox_to_anchor((0.8, 0.95)) + + normalize: bool = True + + if "kmers" in locus_data["peaks"]: + normalized_motifs_cache = set() + kmers_dict: dict[tuple[str, str], float] = defaultdict(lambda: 0.0) + for pi, p in enumerate(locus_data["peaks"]["kmers"]): + for k, v in p.items(): + vnorm = v / locus_data["peaks"]["n_reads"][pi] / (len(k) if normalize else 1) + kmers_dict[str(pi), normalize_motif(normalized_motifs_cache, k) if normalize else k] += vnorm + kmers_df = pd.DataFrame.from_records([ + {allele_label: pi, "k-mer": kmer, "Count": count} for (pi, kmer), count in kmers_dict.items() + ]) + kmers_df_top5 = kmers_df.nlargest(5, "Count", keep="all") + kmers_plot = ( + themed(so.Plot(kmers_df_top5, x="k-mer", y="Count")) + .facet(col=allele_label) + .add(so.Bar(), so.Dodge()) + .on(subfigs[1]) + ) + kmers_ptr = kmers_plot.plot() + kmers_ptr._figure.axes[1].xaxis.set_tick_params(rotation=90) + kmers_ptr._figure.axes[2].xaxis.set_tick_params(rotation=90) + + # plot = add_text(plot, 107, 80, "Main expansion peak (De Luca et al.)") + # + # plot = plot.add(so.Line(color="#666666"), data=pd.DataFrame({"y": [0, 95]}).assign(x=134), x="x", y="y") + # plot = add_text(plot, 134, 95, "First mosaic peak (De Luca et al.)") + # + # plot = plot.add(so.Line(color="#666666"), data=pd.DataFrame({"y": [0, 110]}).assign(x=175), x="x", y="y") + # plot = add_text(plot, 175, 110, "Second mosaic peak (De Luca et al.)") + + # ptr = plot.plot() + # # noinspection PyProtectedMember + # ptr._figure.legends[0].set_bbox_to_anchor((0.8, 0.95)) + # + # if kmers_plot: + # kmers_plot.plot() + + f.savefig(out_file, dpi=ppi) + + # ptr.save(out_file, dpi=ppi) + + # TODO: panel two: overall k-mer distribution + # TODO: panel three?: comparison with other tools?