Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,24 +1,36 @@
altair==5.5.0
annotated-types==0.7.0
attrs==25.3.0
blinker==1.9.0
click==8.2.1
contourpy==1.3.3
coverage==7.10.1
cycler==0.12.1
Cython==3.1.2
exceptiongroup==1.3.0
Flask==3.1.1
fonttools==4.59.2
importlib_metadata==7.1.0
iniconfig==2.1.0
itsdangerous==2.2.0
Jinja2==3.1.6
joblib==1.3.2
jsonschema==4.25.1
jsonschema-specifications==2025.9.1
kiwisolver==1.4.9
line_profiler==5.0.0
MarkupSafe==2.1.5
matplotlib==3.10.6
narwhals==2.5.0
numpy==1.26.4
orjson==3.11.1
packaging==25.0
pandas==2.3.1
parasail==1.3.4
patsy==0.5.6
pillow==11.3.0
pluggy==1.6.0
polars==1.33.1
psutil==6.1.0
pydantic==2.11.7
pydantic_core==2.33.2
Expand All @@ -29,8 +41,11 @@ pytest==8.4.1
pytest-cov==6.2.1
python-dateutil==2.9.0.post0
pytz==2025.2
referencing==0.36.2
rpds-py==0.27.1
scikit-learn==1.7.2
scipy==1.16.1
seaborn==0.13.2
six==1.17.0
statsmodels==0.14.5
strkit_rust_ext==0.24.2
Expand All @@ -39,5 +54,6 @@ tomli==2.2.1
typing-inspection==0.4.1
typing_extensions==4.14.1
tzdata==2025.2
vl-convert-python==1.8.0
Werkzeug==3.1.3
zipp==3.23.0
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
"statsmodels>=0.14.0,<0.15",
"strkit_rust_ext==0.24.2",
],
extras_require={
"plot": ["pandas>=2.3.2,<2.4", "seaborn>=0.13.2,<0.14"],
},

description="A toolkit for analyzing variation in short(ish) tandem repeats.",
long_description=long_description,
Expand Down
41 changes: 41 additions & 0 deletions strkit/entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,18 @@ def add_cv_parser_args(al_parser):
al_parser.add_argument("--out-format", type=str, choices=CONVERTER_OUTPUT_FORMATS, help="Format to convert to.")


def add_pl_parser_args(pl_parser):
pl_parser.add_argument("in_file", type=str, help="Input STRkit JSON report to plot a locus from.")
pl_parser.add_argument("idx", type=int) # TODO: HELP
pl_parser.add_argument("out_file", type=str) # TODO: HELP
pl_parser.add_argument("--allele", type=str, default="all") # TODO: HELP
pl_parser.add_argument("--allele-label", type=str, default="Allele") # TODO: HELP
pl_parser.add_argument("--n-bins", type=int, default=None) # TODO: HELP
pl_parser.add_argument("--x-label", type=str, default="Copy number") # TODO: HELP
pl_parser.add_argument("--y-label", type=str, default="# reads") # TODO: HELP
pl_parser.add_argument("--annotation", "--annot", type=str, action="append") # TODO: HELP


def add_vs_parser_args(vs_parser):
vs_parser.add_argument("align_file", type=str, help="Alignment file to visualize.")
vs_parser.add_argument(
Expand Down Expand Up @@ -607,6 +619,26 @@ def _exec_convert(p_args):
return convert(p_args.in_file, p_args.in_format, p_args.out_format, _main_logger(p_args))


def _exec_plot(p_args):
from strkit.json import json
from strkit.plot.locus import plot_locus

with open(p_args.in_file, "rb") as fh:
data = json.loads(fh.read())

plot_locus(
data["results"][p_args.idx],
p_args.allele,
p_args.allele_label,
p_args.n_bins,
p_args.x_label,
p_args.y_label,
p_args.annotation,
p_args.out_file,
_main_logger(p_args),
)


def _exec_viz_server(p_args):
from strkit.json import json
from strkit.viz.server import run_server as viz_run_server
Expand Down Expand Up @@ -706,6 +738,15 @@ def _make_subparser(*names: str, help_text: str, exec_func: Callable, arg_func:
exec_func=_exec_convert,
arg_func=add_cv_parser_args)

_make_subparser(
"plot",
help_text=(
"Generate a plot of a specific locus (and, optionally, a specific allele) with a copy number histogram "
"and/or a k-mer histogram."
),
exec_func=_exec_plot,
arg_func=add_pl_parser_args)

_make_subparser(
"visualize", "vis", "viz",
help_text="Start a web server to visualize results from an STR genotyping report.",
Expand Down
Empty file added strkit/plot/__init__.py
Empty file.
211 changes: 211 additions & 0 deletions strkit/plot/locus.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import re
from collections import deque, defaultdict
from logging import Logger
from pathlib import Path
from strkit.call.types import LocusResult
from strkit.utils import cat_strs

__all__ = ["plot_locus"]


def normalize_motif(normalized_motifs_cache: set, motif: str):
motif_deque = deque(motif)

if len(motif) == 3 and "C" in motif:
if motif.count("C") == 2:
# special case: if len(motif) == 3 and we have two C, rotate until we start with two Cs
while not (motif_deque[0] == "C" and motif_deque[1] == "C"):
motif_deque.rotate(1)
return cat_strs(motif_deque)
else:
# special case: if len(motif) == 3 and we have a C, rotate until we start with C
while motif_deque[0] != "C":
motif_deque.rotate(1)
return cat_strs(motif_deque)
else:
for _ in range(len(motif)):
if (c := cat_strs(motif_deque)) in normalized_motifs_cache:
return c
motif_deque.rotate(1)
normalized_motifs_cache.add(motif)
return motif


def plot_locus(
locus_data: LocusResult,
allele: str,
allele_label: str,
n_bins: int | None,
x_label: str,
y_label: str,
annotations: list[str] | None,
out_file: str | Path,
logger: Logger,
) -> None:
"""
TODO
:param locus_data: TODO
:param allele: TODO
:param allele_label: TODO
:param n_bins: TODO
:param x_label: TODO
:param y_label: TODO
:param annotations: TODO
:param out_file: TODO
:param logger: TODO
:return: TODO
"""

try:
import seaborn.objects as so
except ImportError:
so = None

if so is None:
logger.error("Could not import seaborn. Make sure to install 'strkit[plot]'!")
exit(1)

import matplotlib as mpl
import pandas as pd
import seaborn as sns

from matplotlib import figure as mpl_figure

font_size = 10
font_rc = {
"font.family": "Arial",
"font.size": font_size,
"axes.titlesize": font_size,
"axes.labelsize": font_size,
"xtick.labelsize": font_size,
"ytick.labelsize": font_size,
"legend.title_fontsize": font_size,
"legend.fontsize": font_size,
}
mpl.rcParams.update(font_rc)

def add_text(plot_, x: int, y: int, text: str):
return plot_.add(
so.Text(halign="right"),
data=pd.DataFrame.from_records([{"x": x, "y": y, "text": text}]),
x="x", y="y", text="text"
)

def themed(plot_):
return plot_.theme({
**sns.axes_style("white"),
"axes.spines.top": False,
"axes.spines.right": False,
**font_rc,
"patch.linewidth": 0,
})

# ------------------------------------------------------------------------------------------------------------------

if "reads" not in locus_data or "peaks" not in locus_data:
logger.error("No reads and/or peaks in specified locus. This locus may not have been called.")
exit(1)

records = []

for read in locus_data["reads"].values():
records.append({
allele_label: str(read.get("p", -1)),
"Copy Number": read["cn"],
# "Sequence Length": # TODO, in future STRKit, include read sequence length!
})

df = pd.DataFrame.from_records(records)

if df[df[allele_label] == -1].shape[0] == df.shape[0]:
logger.warning("No reads with peaks in specified locus; ignoring specified allele and plotting everything.")

# TODO: add k-mer plot

if re.match(r"\d+", allele):
allele_idx = int(allele)
df = df[df[allele_label] == allele_idx]

ppi = 300

n_bins = n_bins or 90 # TODO: normal default

f = mpl_figure.Figure(figsize=(6.5, 8)) # tight_layout=True
subfigs = f.subfigures(2, 1)

plot = (
themed(so.Plot(df, x="Copy Number"))
.scale(color=sns.color_palette("muted", 2))
.add(so.Bars(edgewidth=0), so.Hist(bins=n_bins), color=allele_label)
.on(subfigs[0])
)

def bad_annot(a: str):
logger.error("Bad annotation (format: x y message): %s", a)

for annot in (annotations or []):
ad = annot.split(" ", maxsplit=2)
if len(ad) != 3:
bad_annot(annot)
exit(1)

try:
ad_x = int(ad[0])
ad_y = int(ad[1])
except ValueError:
bad_annot(annot)
exit(1)

ad_msg = ad[2]

plot = plot.add(so.Line(color="#666666"), data=pd.DataFrame({"y": [0, ad_y]}).assign(x=ad_x), x="x", y="y")
plot = add_text(plot, ad_x, ad_y, ad_msg)

plot = plot.label(x=x_label, y=y_label or "# reads")
ptr = plot.plot()
ptr._figure.legends[0].set_bbox_to_anchor((0.8, 0.95))

normalize: bool = True

if "kmers" in locus_data["peaks"]:
normalized_motifs_cache = set()
kmers_dict: dict[tuple[str, str], float] = defaultdict(lambda: 0.0)
for pi, p in enumerate(locus_data["peaks"]["kmers"]):
for k, v in p.items():
vnorm = v / locus_data["peaks"]["n_reads"][pi] / (len(k) if normalize else 1)
kmers_dict[str(pi), normalize_motif(normalized_motifs_cache, k) if normalize else k] += vnorm
kmers_df = pd.DataFrame.from_records([
{allele_label: pi, "k-mer": kmer, "Count": count} for (pi, kmer), count in kmers_dict.items()
])
kmers_df_top5 = kmers_df.nlargest(5, "Count", keep="all")
kmers_plot = (
themed(so.Plot(kmers_df_top5, x="k-mer", y="Count"))
.facet(col=allele_label)
.add(so.Bar(), so.Dodge())
.on(subfigs[1])
)
kmers_ptr = kmers_plot.plot()
kmers_ptr._figure.axes[1].xaxis.set_tick_params(rotation=90)
kmers_ptr._figure.axes[2].xaxis.set_tick_params(rotation=90)

# plot = add_text(plot, 107, 80, "Main expansion peak (De Luca et al.)")
#
# plot = plot.add(so.Line(color="#666666"), data=pd.DataFrame({"y": [0, 95]}).assign(x=134), x="x", y="y")
# plot = add_text(plot, 134, 95, "First mosaic peak (De Luca et al.)")
#
# plot = plot.add(so.Line(color="#666666"), data=pd.DataFrame({"y": [0, 110]}).assign(x=175), x="x", y="y")
# plot = add_text(plot, 175, 110, "Second mosaic peak (De Luca et al.)")

# ptr = plot.plot()
# # noinspection PyProtectedMember
# ptr._figure.legends[0].set_bbox_to_anchor((0.8, 0.95))
#
# if kmers_plot:
# kmers_plot.plot()

f.savefig(out_file, dpi=ppi)

# ptr.save(out_file, dpi=ppi)

# TODO: panel two: overall k-mer distribution
# TODO: panel three?: comparison with other tools?
Loading