diff --git a/.gitignore b/.gitignore index 99cdd97..d49e677 100644 --- a/.gitignore +++ b/.gitignore @@ -37,4 +37,19 @@ .Trashes Icon? ehthumbs.db -Thumbs.db \ No newline at end of file +Thumbs.db + +# Python compiled files # +######################### +**__pycache__ +**.pyc +.pytest_cache + +# Eggs # +######## +.eggs +svtyper.egg-info + +# Idea # +######## +.idea diff --git a/scripts/sv_classifier.py b/scripts/sv_classifier.py index afb8f17..292bb3e 100755 --- a/scripts/sv_classifier.py +++ b/scripts/sv_classifier.py @@ -1,7 +1,13 @@ #!/usr/bin/env python -import argparse, sys, copy, gzip, os -import math, time, re +import argparse +import sys +import copy +import gzip +import os +import math +import time +import re import numpy as np from scipy import stats from collections import Counter @@ -15,19 +21,31 @@ # -------------------------------------- # define functions + def get_args(): parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\ sv_classifier.py\n\ author: " + __author__ + "\n\ version: " + __version__ + "\n\ description: classify structural variants") - parser.add_argument('-i', '--input', metavar='VCF', dest='vcf_in', type=argparse.FileType('r'), default=None, help='VCF input [stdin]') - parser.add_argument('-g', '--gender', metavar='FILE', dest='gender', type=argparse.FileType('r'), required=True, default=None, help='tab delimited file of sample genders (male=1, female=2)\nex: SAMPLE_A\t2') - parser.add_argument('-e', '--exclude', metavar='FILE', dest='exclude', type=argparse.FileType('r'), required=False, default=None, help='list of samples to exclude from classification algorithms') - parser.add_argument('-a', '--annotation', metavar='BED', dest='ae_path', type=str, default=None, help='BED file of annotated elements') - parser.add_argument('-f', '--fraction', metavar='FLOAT', dest='f_overlap', type=float, default=0.9, help='fraction of reciprocal overlap to apply annotation to variant [0.9]') - parser.add_argument('-s', '--slope_threshold', metavar='FLOAT', dest='slope_threshold', type=float, default=1.0, help='minimum slope absolute value of regression line to classify as DEL or DUP[1.0]') - parser.add_argument('-r', '--rsquared_threshold', metavar='FLOAT', dest='rsquared_threshold', type=float, default=0.2, help='minimum R^2 correlation value of regression line to classify as DEL or DUP [0.2]') + parser.add_argument('-i', '--input', metavar='VCF', dest='vcf_in', + type=argparse.FileType('r'), default=None, help='VCF input [stdin]') + parser.add_argument( + '-g', '--gender', metavar='FILE', dest='gender', type=argparse.FileType('r'), + required=True, default=None, help='tab delimited file of sample genders (male=1, female=2)\nex: SAMPLE_A\t2') + parser.add_argument('-e', '--exclude', metavar='FILE', dest='exclude', type=argparse.FileType( + 'r'), required=False, default=None, help='list of samples to exclude from classification algorithms') + parser.add_argument('-a', '--annotation', metavar='BED', dest='ae_path', + type=str, default=None, help='BED file of annotated elements') + parser.add_argument( + '-f', '--fraction', metavar='FLOAT', dest='f_overlap', type=float, + default=0.9, help='fraction of reciprocal overlap to apply annotation to variant [0.9]') + parser.add_argument( + '-s', '--slope_threshold', metavar='FLOAT', dest='slope_threshold', type=float, + default=1.0, help='minimum slope absolute value of regression line to classify as DEL or DUP[1.0]') + parser.add_argument( + '-r', '--rsquared_threshold', metavar='FLOAT', 
dest='rsquared_threshold', type=float, + default=0.2, help='minimum R^2 correlation value of regression line to classify as DEL or DUP [0.2]') # parse the arguments args = parser.parse_args() @@ -42,7 +60,9 @@ def get_args(): # send back the user input return args + class Vcf(object): + def __init__(self): self.file_format = 'VCFv4.2' # self.fasta = fasta @@ -60,15 +80,15 @@ def add_header(self, header): elif line.split('=')[0] == '##reference': self.reference = line.rstrip().split('=')[1] elif line.split('=')[0] == '##INFO': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_info(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##ALT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_alt(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FORMAT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_format(*[b.split('=')[1] for b in r.findall(a)]) elif line[0] == '#' and line[1] != '#': @@ -78,10 +98,10 @@ def add_header(self, header): def get_header(self): header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -91,7 +111,7 @@ def get_header(self): 'QUAL', 'FILTER', 'INFO', - 'FORMAT'] + \ + 'FORMAT'] + self.sample_list )]) return header @@ -120,6 +140,7 @@ def sample_to_col(self, sample): return self.sample_list.index(sample) + 9 class Info(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -128,18 +149,22 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##INFO=' + self.hstring = '##INFO=' class Alt(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##ALT=' + self.hstring = '##ALT=' class Format(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -148,9 +173,12 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##FORMAT=' + self.hstring = '##FORMAT=' + class Variant(object): + def __init__(self, var_list, vcf): self.chrom = var_list[0] self.pos = int(var_list[1]) @@ -168,10 +196,11 @@ def __init__(self, var_list, vcf): self.format_list = vcf.format_list self.active_formats = list() self.gts = dict() - + # fill in empty sample genotypes if len(var_list) < 8: - sys.stderr.write('\nError: VCF file must have at least 8 columns\n') + sys.stderr.write( + '\nError: VCF file must have at least 8 columns\n') exit(1) if len(var_list) < 9: var_list.append("GT") @@ -188,7 +217,8 @@ def __init__(self, var_list, 
vcf): self.gts[s] = Genotype(self, s, './.') self.info = dict() - i_split = [a.split('=') for a in var_list[7].split(';')] # temp list of split info column + i_split = [a.split('=') + for a in var_list[7].split(';')] # temp list of split info column for i in i_split: if len(i) == 1: i.append(True) @@ -198,7 +228,8 @@ def set_info(self, field, value): if field in [i.id for i in self.info_list]: self.info[field] = value else: - sys.stderr.write('\nError: invalid INFO field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid INFO field, \"' + field + '\"\n') exit(1) def get_info(self, field): @@ -207,11 +238,12 @@ def get_info(self, field): def get_info_string(self): i_list = list() for info_field in self.info_list: - if info_field.id in self.info.keys(): + if info_field.id in list(self.info.keys()): if info_field.type == 'Flag': i_list.append(info_field.id) else: - i_list.append('%s=%s' % (info_field.id, self.info[info_field.id])) + i_list.append( + '%s=%s' % (info_field.id, self.info[info_field.id])) return ';'.join(i_list) def get_format_string(self): @@ -225,10 +257,11 @@ def genotype(self, sample_name): if sample_name in self.sample_list: return self.gts[sample_name] else: - sys.stderr.write('\nError: invalid sample name, \"' + sample_name + '\"\n') + sys.stderr.write( + '\nError: invalid sample name, \"' + sample_name + '\"\n') def get_var_string(self): - s = '\t'.join(map(str,[ + s = '\t'.join(map(str, [ self.chrom, self.pos, self.var_id, @@ -238,11 +271,14 @@ def get_var_string(self): self.filter, self.get_info_string(), self.get_format_string(), - '\t'.join(self.genotype(s).get_gt_string() for s in self.sample_list) + '\t'.join(self.genotype(s).get_gt_string() + for s in self.sample_list) ])) return s + class Genotype(object): + def __init__(self, variant, sample_name, gt): self.format = dict() self.variant = variant @@ -254,9 +290,11 @@ def set_format(self, field, value): if field not in self.variant.active_formats: self.variant.active_formats.append(field) # sort it to be in the same order as the format_list in header - self.variant.active_formats.sort(key=lambda x: [f.id for f in self.variant.format_list].index(x)) + self.variant.active_formats.sort( + key=lambda x: [f.id for f in self.variant.format_list].index(x)) else: - sys.stderr.write('\nError: invalid FORMAT field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid FORMAT field, \"' + field + '\"\n') exit(1) def get_format(self, field): @@ -272,23 +310,28 @@ def get_gt_string(self): g_list.append(self.format[f]) else: g_list.append('.') - return ':'.join(map(str,g_list)) + return ':'.join(map(str, g_list)) # http://stackoverflow.com/questions/8930370/where-can-i-find-mad-mean-absolute-deviation-in-scipy + + def mad(arr): """ Median Absolute Deviation: a "Robust" version of standard deviation. Indices variabililty of the sample. - https://en.wikipedia.org/wiki/Median_absolute_deviation + https://en.wikipedia.org/wiki/Median_absolute_deviation """ - arr = np.ma.array(arr).compressed() # should be faster to not use masked arrays. + arr = np.ma.array(arr).compressed() + # should be faster to not use masked arrays. 
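# --------------------------------------
# Illustrative aside, not part of the patch: the mad() helper being
# reformatted here is the median absolute deviation, a robust spread
# estimate (median of each value's absolute distance from the sample
# median). A self-contained sketch with hypothetical names:

import numpy as np

def mad_sketch(values):
    arr = np.asarray(values, dtype=float)
    med = np.median(arr)                  # sample median
    return np.median(np.abs(arr - med))   # median of absolute deviations

# mad_sketch([2.0, 2.1, 1.9, 2.0, 10.0]) -> 0.1, where the standard
# deviation would be blown up by the outlier 10.0.
# --------------------------------------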
med = np.median(arr) return np.median(np.abs(arr - med)) # test whether variant has read depth support by regression + + def has_high_freq_depth_support(var, gender, exclude, slope_threshold, rsquared_threshold, writedir=None): # slope_threshold = 0.1 # rsquared_threshold = 0.1 - + if 'CN' in var.active_formats: # allele balance list ab_list = [] @@ -315,25 +358,27 @@ def has_high_freq_depth_support(var, gender, exclude, slope_threshold, rsquared_ rd = np.array([ab_list, rd_list]) # remove missing genotypes - rd = rd[:, rd[0]!=-1] + rd = rd[:, rd[0] != -1] # ensure non-uniformity in genotype and read depth - if len(np.unique(rd[0,:])) > 1 and len(np.unique(rd[1,:])) > 1: + if len(np.unique(rd[0, :])) > 1 and len(np.unique(rd[1,:])) > 1: # calculate regression - (slope, intercept, r_value, p_value, std_err) = stats.linregress(rd) + (slope, intercept, r_value, p_value, + std_err) = stats.linregress(rd) # print slope, intercept, r_value, var.info['SVTYPE'], var.var_id - # write the scatterplot to a file if writedir is not None: try: os.makedirs(writedir) - except OSError as exc: # Python >2.5 + except OSError as exc: # Python >2.5 if os.path.isdir(writedir): pass - else: raise + else: + raise - f = open('%s/reg_%s_%s_%sbp.txt' % (writedir, var.info['SVTYPE'], var.var_id, var.info['SVLEN']), 'w') + f = open('%s/reg_%s_%s_%sbp.txt' % + (writedir, var.info['SVTYPE'], var.var_id, var.info['SVLEN']), 'w') np.savetxt(f, np.transpose(rd), delimiter='\t') f.close() @@ -350,16 +395,18 @@ def has_high_freq_depth_support(var, gender, exclude, slope_threshold, rsquared_ return False # test for read depth support of low frequency variants + + def has_low_freq_depth_support(var, gender, exclude, writedir=None): mad_threshold = 2 - mad_quorum = 0.5 # this fraction of the pos. genotyped results must meet the mad_threshold + mad_quorum = 0.5 # this fraction of the pos. 
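# --------------------------------------
# Illustrative aside, not part of the patch: has_high_freq_depth_support()
# regresses per-sample copy number (CN) on allele balance (AB) and keeps
# the call when the fit is both steep (|slope| >= -s/--slope_threshold)
# and tight (R^2 >= -r/--rsquared_threshold). A simplified sketch — the
# script additionally checks the slope's sign against the SV type, and
# these names are hypothetical:

import numpy as np
from scipy import stats

def depth_supports_cnv(ab_list, cn_list,
                       slope_threshold=1.0, rsquared_threshold=0.2):
    rd = np.array([ab_list, cn_list], dtype=float)
    rd = rd[:, rd[0] != -1]  # drop samples with missing genotypes (AB == -1)
    slope, intercept, r_value, p_value, std_err = stats.linregress(rd)
    return abs(slope) >= slope_threshold and r_value ** 2 >= rsquared_threshold
# --------------------------------------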
genotyped results must meet the mad_threshold absolute_cn_diff = 0.5 - + hom_ref_cn = [] het_cn = [] hom_alt_cn = [] - # determine whether majority of + # determine whether majority of # if on the sex chromosomes, only compare against the majority sex if (var.chrom == 'X' or var.chrom == 'Y'): @@ -431,21 +478,27 @@ def has_low_freq_depth_support(var, gender, exclude, writedir=None): if writedir is not None: try: os.makedirs(writedir) - except OSError as exc: # Python >2.5 + except OSError as exc: # Python >2.5 if os.path.isdir(writedir): pass - else: raise + else: + raise - f = open('%s/mad_%s_%s_%sbp.txt' % (writedir, var.info['SVTYPE'], var.var_id, var.info['SVLEN']), 'w') + f = open('%s/mad_%s_%s_%sbp.txt' % + (writedir, var.info['SVTYPE'], var.var_id, var.info['SVLEN']), 'w') for cn in hom_ref_cn: - f.write('\t'.join(map(str, [0, cn, cn_mean, cn_stdev, cn_median, cn_mad])) + '\n') + f.write( + '\t'.join(map(str, [0, cn, cn_mean, cn_stdev, cn_median, cn_mad])) + '\n') for cn in het_cn: - f.write('\t'.join(map(str, [1, cn, cn_mean, cn_stdev, cn_median, cn_mad])) + '\n') + f.write( + '\t'.join(map(str, [1, cn, cn_mean, cn_stdev, cn_median, cn_mad])) + '\n') for cn in hom_alt_cn: - f.write('\t'.join(map(str, [2, cn, cn_mean, cn_stdev, cn_median, cn_mad])) + '\n') + f.write( + '\t'.join(map(str, [2, cn, cn_mean, cn_stdev, cn_median, cn_mad])) + '\n') f.close() - # bail after writing out diagnostic info, if no ref samples or all ref samples + # bail after writing out diagnostic info, if no ref samples or all ref + # samples if (len(hom_ref_cn) == 0 or len(het_cn + hom_alt_cn) == 0): return False @@ -460,11 +513,12 @@ def has_low_freq_depth_support(var, gender, exclude, writedir=None): resid > absolute_cn_diff): q += 1 # check if meets quorum - if float(q)/len(het_cn + hom_alt_cn) > mad_quorum: + if float(q) / len(het_cn + hom_alt_cn) > mad_quorum: return True else: return False + def to_bnd(var): var1 = copy.deepcopy(var) var2 = copy.deepcopy(var) @@ -480,7 +534,7 @@ def to_bnd(var): var2.var_id = var.var_id + "_2" var1.info['MATEID'] = var2.var_id var2.info['MATEID'] = var1.var_id - + # update position var2.pos = var.info['END'] @@ -508,6 +562,7 @@ def to_bnd(var): var2.alt = 'N[%s:%s[' % (var.chrom, var.pos) return var1, var2 + def reciprocal_overlap(a, b_list): overlap = 0 b_aggregate = 0 @@ -526,11 +581,12 @@ def reciprocal_overlap(a, b_list): return min(overlap / (a[1] - a[0]), overlap / b_aggregate) + def collapse_bed_records(bed_list): bed_list_sorted = sorted(bed_list, key=itemgetter(1)) collapsed_bed_list = [] - + i = 0 curr_rec = bed_list_sorted[i] while i < len(bed_list_sorted): @@ -556,6 +612,7 @@ def collapse_bed_records(bed_list): # print 'collapsed:', collapsed_bed_list return collapsed_bed_list + def annotation_intersect(var, ae_dict, threshold): best_frac_overlap = 0 best_feature = '' @@ -563,7 +620,7 @@ def annotation_intersect(var, ae_dict, threshold): # dictionary with number of bases of overlap for each class class_overlap = {} - + # first check for reciprocal overlap if var.chrom in ae_dict: var_start = var.pos @@ -586,22 +643,26 @@ def annotation_intersect(var, ae_dict, threshold): # print class_overlap for me_class in class_overlap: - class_overlap[me_class] = collapse_bed_records(class_overlap[me_class]) + class_overlap[me_class] = collapse_bed_records( + class_overlap[me_class]) # print 'class_overlap[me_class]:', class_overlap[me_class] - # print 'recip:', reciprocal_overlap([var_start, var_end], class_overlap[me_class]) + # print 'recip:', 
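# --------------------------------------
# Illustrative aside, not part of the patch: reciprocal_overlap() above
# measures the overlap as a fraction of *both* the variant span and the
# (collapsed) annotation footprint, keeping the smaller fraction. A
# single-interval sketch of the same idea (hypothetical names):

def reciprocal_overlap_sketch(a, b):
    ov = max(0, min(a[1], b[1]) - max(a[0], b[0]))
    return min(ov / float(a[1] - a[0]), ov / float(b[1] - b[0]))

# reciprocal_overlap_sketch((100, 200), (150, 250)) -> 0.5; the
# -f/--fraction option (default 0.9) is the threshold this value must
# reach before the annotation is applied to the variant.
# --------------------------------------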
reciprocal_overlap([var_start, var_end], + # class_overlap[me_class]) - frac_overlap = reciprocal_overlap([var_start, var_end], class_overlap[me_class]) + frac_overlap = reciprocal_overlap( + [var_start, var_end], class_overlap[me_class]) if frac_overlap > best_frac_overlap: best_frac_overlap = frac_overlap best_feature = me_class - if best_frac_overlap >= threshold: return best_feature return None # primary function + + def sv_classify(vcf_in, gender_file, exclude_file, ae_dict, f_overlap, slope_threshold, rsquared_threshold): vcf_out = sys.stdout vcf = Vcf() @@ -631,7 +692,8 @@ def sv_classify(vcf_in, gender_file, exclude_file, ae_dict, f_overlap, slope_thr # write the output header vcf_out.write(vcf.get_header() + '\n') - # split variant line, quick pre-check if the SVTYPE is BND, and skip if so + # split variant line, quick pre-check if the SVTYPE is BND, and skip if + # so v = line.rstrip().split('\t') info = v[7].split(';') @@ -660,7 +722,7 @@ def sv_classify(vcf_in, gender_file, exclude_file, ae_dict, f_overlap, slope_thr vcf_out.write(var.get_var_string() + '\n') continue - # # write to directory + # write to directory # writedir = 'data/r11.100kb.dup' # annotate based on read depth @@ -681,7 +743,9 @@ def sv_classify(vcf_in, gender_file, exclude_file, ae_dict, f_overlap, slope_thr vcf_out.write(var.get_var_string() + '\n') else: # has_low_freq_depth_support(var, gender, exclude, writedir + '/low_freq_no_rd') - # has_high_freq_depth_support(var, gender, exclude, slope_threshold, rsquared_threshold, writedir + '/low_freq_no_rd') + # has_high_freq_depth_support(var, gender, exclude, + # slope_threshold, rsquared_threshold, writedir + + # '/low_freq_no_rd') for m_var in to_bnd(var): vcf_out.write(m_var.get_var_string() + '\n') else: @@ -692,7 +756,8 @@ def sv_classify(vcf_in, gender_file, exclude_file, ae_dict, f_overlap, slope_thr vcf_out.write(var.get_var_string() + '\n') else: # has_high_freq_depth_support(var, gender, exclude, slope_threshold, rsquared_threshold, writedir + '/high_freq_no_rd') - # has_low_freq_depth_support(var, gender, exclude, writedir + '/high_freq_no_rd') + # has_low_freq_depth_support(var, gender, exclude, writedir + # + '/high_freq_no_rd') for m_var in to_bnd(var): vcf_out.write(m_var.get_var_string() + '\n') vcf_out.close() @@ -701,6 +766,7 @@ def sv_classify(vcf_in, gender_file, exclude_file, ae_dict, f_overlap, slope_thr # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() @@ -745,6 +811,6 @@ def main(): if __name__ == '__main__': try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE - raise + raise diff --git a/scripts/update_info.py b/scripts/update_info.py index a5b77ba..f83d01c 100755 --- a/scripts/update_info.py +++ b/scripts/update_info.py @@ -1,8 +1,11 @@ #!/usr/bin/env python import pysam -import argparse, sys -import math, time, re +import argparse +import sys +import math +import time +import re from collections import Counter from argparse import RawTextHelpFormatter @@ -13,13 +16,15 @@ # -------------------------------------- # define functions + def get_args(): parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\ update_info.py\n\ author: " + __author__ + "\n\ version: " + __version__ + "\n\ description: Update the PE SR fields after SVTyper") - parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') + 
parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', + type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') # parse the arguments args = parser.parse_args() @@ -34,7 +39,9 @@ def get_args(): # send back the user input return args + class Vcf(object): + def __init__(self): self.file_format = 'VCFv4.2' # self.fasta = fasta @@ -52,15 +59,15 @@ def add_header(self, header): elif line.split('=')[0] == '##reference': self.reference = line.rstrip().split('=')[1] elif line.split('=')[0] == '##INFO': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_info(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##ALT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_alt(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FORMAT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_format(*[b.split('=')[1] for b in r.findall(a)]) elif line[0] == '#' and line[1] != '#': @@ -71,10 +78,10 @@ def get_header(self, include_samples=True): if include_samples: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -84,16 +91,16 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO', - 'FORMAT'] + \ - self.sample_list - )]) + 'FORMAT'] + + self.sample_list + )]) else: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -103,7 +110,7 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO'] - )]) + )]) return header def add_info(self, id, number, type, desc): @@ -136,6 +143,7 @@ def sample_to_col(self, sample): return self.sample_list.index(sample) + 9 class Info(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -144,18 +152,22 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##INFO=' + self.hstring = '##INFO=' class Alt(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##ALT=' + self.hstring = '##ALT=' class Format(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -164,9 +176,12 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and 
self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##FORMAT=' + self.hstring = '##FORMAT=' + class Variant(object): + def __init__(self, var_list, vcf): self.chrom = var_list[0] self.pos = int(var_list[1]) @@ -184,10 +199,11 @@ def __init__(self, var_list, vcf): self.format_list = vcf.format_list self.active_formats = list() self.gts = dict() - + # fill in empty sample genotypes if len(var_list) < 8: - sys.stderr.write('\nError: VCF file must have at least 8 columns\n') + sys.stderr.write( + '\nError: VCF file must have at least 8 columns\n') exit(1) if len(var_list) < 9: var_list.append("GT") @@ -204,7 +220,8 @@ def __init__(self, var_list, vcf): self.gts[s] = Genotype(self, s, './.') self.info = dict() - i_split = [a.split('=') for a in var_list[7].split(';')] # temp list of split info column + i_split = [a.split('=') + for a in var_list[7].split(';')] # temp list of split info column for i in i_split: if len(i) == 1: i.append(True) @@ -214,7 +231,8 @@ def set_info(self, field, value): if field in [i.id for i in self.info_list]: self.info[field] = value else: - sys.stderr.write('\nError: invalid INFO field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid INFO field, \"' + field + '\"\n') exit(1) def get_info(self, field): @@ -223,11 +241,12 @@ def get_info(self, field): def get_info_string(self): i_list = list() for info_field in self.info_list: - if info_field.id in self.info.keys(): + if info_field.id in list(self.info.keys()): if info_field.type == 'Flag': i_list.append(info_field.id) else: - i_list.append('%s=%s' % (info_field.id, self.info[info_field.id])) + i_list.append( + '%s=%s' % (info_field.id, self.info[info_field.id])) return ';'.join(i_list) def get_format_string(self): @@ -241,10 +260,11 @@ def genotype(self, sample_name): if sample_name in self.sample_list: return self.gts[sample_name] else: - sys.stderr.write('\nError: invalid sample name, \"' + sample_name + '\"\n') + sys.stderr.write( + '\nError: invalid sample name, \"' + sample_name + '\"\n') def get_var_string(self): - s = '\t'.join(map(str,[ + s = '\t'.join(map(str, [ self.chrom, self.pos, self.var_id, @@ -254,11 +274,14 @@ def get_var_string(self): self.filter, self.get_info_string(), self.get_format_string(), - '\t'.join(self.genotype(s).get_gt_string() for s in self.sample_list) + '\t'.join(self.genotype(s).get_gt_string() + for s in self.sample_list) ])) return s + class Genotype(object): + def __init__(self, variant, sample_name, gt): self.format = dict() self.variant = variant @@ -270,9 +293,11 @@ def set_format(self, field, value): if field not in self.variant.active_formats: self.variant.active_formats.append(field) # sort it to be in the same order as the format_list in header - self.variant.active_formats.sort(key=lambda x: [f.id for f in self.variant.format_list].index(x)) + self.variant.active_formats.sort( + key=lambda x: [f.id for f in self.variant.format_list].index(x)) else: - sys.stderr.write('\nError: invalid FORMAT field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid FORMAT field, \"' + field + '\"\n') exit(1) def get_format(self, field): @@ -288,13 +313,16 @@ def get_gt_string(self): g_list.append(self.format[f]) else: g_list.append('.') - return ':'.join(map(str,g_list)) + return ':'.join(map(str, g_list)) # primary function + + def update_info(vcf_file): in_header = True header = [] - breakend_dict = {} # cache to hold unmatched generic breakends for genotyping + breakend_dict = {} + # cache to hold unmatched generic breakends for genotyping 
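# --------------------------------------
# Illustrative aside, not part of the patch: this hunk only *declares*
# the MSQ INFO field ("Mean sample quality of positively genotyped
# samples"); the arithmetic itself lies outside the diff context. A
# plausible sketch of such a mean, assuming per-sample GT and SQ FORMAT
# values (all names hypothetical):

def mean_sample_quality(gt_list, sq_list):
    quals = [sq for gt, sq in zip(gt_list, sq_list)
             if any(a not in ('0', '.')
                    for a in gt.replace('|', '/').split('/'))]
    return sum(quals) / len(quals) if quals else 0.0

# mean_sample_quality(['0/0', '0/1', '1/1'], [3.0, 20.0, 40.0]) -> 30.0
# --------------------------------------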
vcf = Vcf() vcf_out = sys.stdout @@ -302,14 +330,15 @@ def update_info(vcf_file): for line in vcf_file: if in_header: if line[0] == '#': - header.append(line) + header.append(line) if line[1] != '#': vcf_samples = line.rstrip().split('\t')[9:] continue else: in_header = False vcf.add_header(header) - vcf.add_info('MSQ', '1', 'Float', 'Mean sample quality of positively genotyped samples') + vcf.add_info( + 'MSQ', '1', 'Float', 'Mean sample quality of positively genotyped samples') vcf.remove_info('SNAME') # write the output header @@ -350,12 +379,13 @@ def update_info(vcf_file): vcf_out.write(var.get_var_string() + '\n') vcf_out.close() - + return # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() @@ -370,6 +400,6 @@ def main(): if __name__ == '__main__': try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE - raise + raise diff --git a/scripts/vcf_allele_freq.py b/scripts/vcf_allele_freq.py index 9ed89e2..768de23 100755 --- a/scripts/vcf_allele_freq.py +++ b/scripts/vcf_allele_freq.py @@ -1,8 +1,11 @@ #!/usr/bin/env python import pysam -import argparse, sys -import math, time, re +import argparse +import sys +import math +import time +import re from collections import Counter from argparse import RawTextHelpFormatter @@ -13,13 +16,15 @@ # -------------------------------------- # define functions + def get_args(): parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\ vcf_allele_freq.py\n\ author: " + __author__ + "\n\ version: " + __version__ + "\n\ description: Add allele frequency information to a VCF file") - parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') + parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', + type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') # parse the arguments args = parser.parse_args() @@ -34,7 +39,9 @@ def get_args(): # send back the user input return args + class Vcf(object): + def __init__(self): self.file_format = 'VCFv4.2' # self.fasta = fasta @@ -53,19 +60,19 @@ def add_header(self, header): elif line.split('=')[0] == '##reference': self.reference = line.rstrip().split('=')[1] elif line.split('=')[0] == '##INFO': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_info(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##ALT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_alt(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FORMAT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_format(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FILTER': - a = line[line.find('<')+1:-2] + a = line[line.find('<') + 1:-2] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_filter(*[b.split('=')[1] for b in r.findall(a)]) elif line[0] == '#' and line[1] != '#': @@ -76,11 +83,11 @@ def get_header(self, include_samples=True): if include_samples: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [f.hstring for f in self.filter_list] + \ - [i.hstring for i in self.info_list] + \ - 
[a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [f.hstring for f in self.filter_list] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -90,17 +97,17 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO', - 'FORMAT'] + \ - self.sample_list - )]) + 'FORMAT'] + + self.sample_list + )]) else: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [f.hstring for f in self.filter_list] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [f.hstring for f in self.filter_list] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -110,7 +117,7 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO'] - )]) + )]) return header def add_info(self, id, number, type, desc): @@ -142,6 +149,7 @@ def sample_to_col(self, sample): return self.sample_list.index(sample) + 9 class Info(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -150,18 +158,22 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##INFO=' + self.hstring = '##INFO=' class Alt(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##ALT=' + self.hstring = '##ALT=' class Format(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -170,18 +182,23 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##FORMAT=' + self.hstring = '##FORMAT=' class Filter(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##FILTER=' + self.hstring = '##FILTER=' + class Variant(object): + def __init__(self, var_list, vcf): self.chrom = var_list[0] self.pos = int(var_list[1]) @@ -199,10 +216,11 @@ def __init__(self, var_list, vcf): self.format_list = vcf.format_list self.active_formats = list() self.gts = dict() - + # fill in empty sample genotypes if len(var_list) < 8: - sys.stderr.write('\nError: VCF file must have at least 8 columns\n') + sys.stderr.write( + '\nError: VCF file must have at least 8 columns\n') exit(1) if len(var_list) < 9: var_list.append("GT") @@ -219,7 +237,8 @@ def __init__(self, var_list, vcf): self.gts[s] = Genotype(self, s, './.') self.info = dict() - i_split = [a.split('=') for a in var_list[7].split(';')] # temp list of split info column + i_split = [a.split('=') + for a in var_list[7].split(';')] # temp list of split info column for i in i_split: if len(i) == 1: i.append(True) @@ -229,7 +248,8 @@ def set_info(self, field, value): if field in [i.id for i in 
self.info_list]: self.info[field] = value else: - sys.stderr.write('\nError: invalid INFO field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid INFO field, \"' + field + '\"\n') exit(1) def get_info(self, field): @@ -238,11 +258,12 @@ def get_info(self, field): def get_info_string(self): i_list = list() for info_field in self.info_list: - if info_field.id in self.info.keys(): + if info_field.id in list(self.info.keys()): if info_field.type == 'Flag': i_list.append(info_field.id) else: - i_list.append('%s=%s' % (info_field.id, self.info[info_field.id])) + i_list.append( + '%s=%s' % (info_field.id, self.info[info_field.id])) return ';'.join(i_list) def get_format_string(self): @@ -256,11 +277,12 @@ def genotype(self, sample_name): if sample_name in self.sample_list: return self.gts[sample_name] else: - sys.stderr.write('\nError: invalid sample name, \"' + sample_name + '\"\n') + sys.stderr.write( + '\nError: invalid sample name, \"' + sample_name + '\"\n') def get_var_string(self): if len(self.active_formats) == 0: - s = '\t'.join(map(str,[ + s = '\t'.join(map(str, [ self.chrom, self.pos, self.var_id, @@ -271,7 +293,7 @@ def get_var_string(self): self.get_info_string() ])) else: - s = '\t'.join(map(str,[ + s = '\t'.join(map(str, [ self.chrom, self.pos, self.var_id, @@ -281,11 +303,14 @@ def get_var_string(self): self.filter, self.get_info_string(), self.get_format_string(), - '\t'.join(self.genotype(s).get_gt_string() for s in self.sample_list) + '\t'.join(self.genotype(s).get_gt_string() + for s in self.sample_list) ])) return s + class Genotype(object): + def __init__(self, variant, sample_name, gt): self.format = dict() self.variant = variant @@ -297,9 +322,11 @@ def set_format(self, field, value): if field not in self.variant.active_formats: self.variant.active_formats.append(field) # sort it to be in the same order as the format_list in header - self.variant.active_formats.sort(key=lambda x: [f.id for f in self.variant.format_list].index(x)) + self.variant.active_formats.sort( + key=lambda x: [f.id for f in self.variant.format_list].index(x)) else: - sys.stderr.write('\nError: invalid FORMAT field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid FORMAT field, \"' + field + '\"\n') exit(1) def get_format(self, field): @@ -315,13 +342,16 @@ def get_gt_string(self): g_list.append(self.format[f]) else: g_list.append('.') - return ':'.join(map(str,g_list)) + return ':'.join(map(str, g_list)) # primary function + + def add_af(vcf_file): in_header = True header = [] - breakend_dict = {} # cache to hold unmatched generic breakends for genotyping + breakend_dict = {} + # cache to hold unmatched generic breakends for genotyping vcf = Vcf() vcf_out = sys.stdout @@ -329,7 +359,7 @@ def add_af(vcf_file): for line in vcf_file: if in_header: if line.startswith('##'): - header.append(line) + header.append(line) continue elif line.startswith('#CHROM'): v = line.rstrip().split('\t') @@ -337,9 +367,11 @@ def add_af(vcf_file): in_header = False vcf.add_header(header) - - vcf.add_info('AF', 'A', 'Float', 'Allele Frequency, for each ALT allele, in the same order as listed') - vcf.add_info('NSAMP', '1', 'Integer', 'Number of samples with non-reference genotypes') + + vcf.add_info( + 'AF', 'A', 'Float', 'Allele Frequency, for each ALT allele, in the same order as listed') + vcf.add_info( + 'NSAMP', '1', 'Integer', 'Number of samples with non-reference genotypes') # write header vcf_out.write(vcf.get_header(include_samples=False)) @@ -354,17 +386,17 @@ def add_af(vcf_file): alleles 
= [0] * (num_alt + 1) num_samp = 0 - for i in xrange(9,len(v)): + for i in range(9, len(v)): gt_string = v[i].split(':')[0] - if '.' in gt_string: + if '.' in gt_string: continue gt = gt_string.split('/') if len(gt) == 1: gt = gt_string.split('|') - gt = map(int, gt) + gt = list(map(int, gt)) - for i in xrange(len(gt)): + for i in range(len(gt)): alleles[gt[i]] += 1 # iterate the number of non-reference samples @@ -376,12 +408,13 @@ def add_af(vcf_file): # populate AF if allele_sum > 0: - for i in xrange(len(alleles)): + for i in range(len(alleles)): allele_freq[i] = alleles[i] / allele_sum - var.info['AF'] = ','.join(map(str, ['%.4g' % a for a in allele_freq[1:]])) + var.info['AF'] = ','.join( + map(str, ['%.4g' % a for a in allele_freq[1:]])) else: var.info['AF'] = ','.join(map(str, allele_freq[1:])) - + # populate NSAMP var.info['NSAMP'] = num_samp @@ -391,12 +424,13 @@ def add_af(vcf_file): + '\t'.join(v[8:]) + '\n') vcf_out.close() - + return # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() @@ -411,6 +445,6 @@ def main(): if __name__ == '__main__': try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE - raise + raise diff --git a/scripts/vcf_group_multiline.py b/scripts/vcf_group_multiline.py index 6b8e3ca..945d8bf 100755 --- a/scripts/vcf_group_multiline.py +++ b/scripts/vcf_group_multiline.py @@ -1,8 +1,11 @@ #!/usr/bin/env python import pysam -import argparse, sys -import math, time, re +import argparse +import sys +import math +import time +import re from collections import Counter from argparse import RawTextHelpFormatter @@ -13,13 +16,15 @@ # -------------------------------------- # define functions + def get_args(): parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\ vcf_group_multiline.py\n\ author: " + __author__ + "\n\ version: " + __version__ + "\n\ description: Group multiline variants prior vcf_paste.py") - parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') + parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', + type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') # parse the arguments args = parser.parse_args() @@ -34,7 +39,9 @@ def get_args(): # send back the user input return args + class Vcf(object): + def __init__(self): self.file_format = 'VCFv4.2' # self.fasta = fasta @@ -52,15 +59,15 @@ def add_header(self, header): elif line.split('=')[0] == '##reference': self.reference = line.rstrip().split('=')[1] elif line.split('=')[0] == '##INFO': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_info(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##ALT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_alt(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FORMAT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_format(*[b.split('=')[1] for b in r.findall(a)]) elif line[0] == '#' and line[1] != '#': @@ -71,10 +78,10 @@ def get_header(self, include_samples=True): if include_samples: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - 
'##reference=' + self.reference] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -84,16 +91,16 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO', - 'FORMAT'] + \ - self.sample_list - )]) + 'FORMAT'] + + self.sample_list + )]) else: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -103,7 +110,7 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO'] - )]) + )]) return header def add_info(self, id, number, type, desc): @@ -130,6 +137,7 @@ def sample_to_col(self, sample): return self.sample_list.index(sample) + 9 class Info(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -138,18 +146,22 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##INFO=' + self.hstring = '##INFO=' class Alt(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##ALT=' + self.hstring = '##ALT=' class Format(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -158,9 +170,12 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##FORMAT=' + self.hstring = '##FORMAT=' + class Variant(object): + def __init__(self, var_list, vcf): self.chrom = var_list[0] self.pos = int(var_list[1]) @@ -178,10 +193,11 @@ def __init__(self, var_list, vcf): self.format_list = vcf.format_list self.active_formats = list() self.gts = dict() - + # fill in empty sample genotypes if len(var_list) < 8: - sys.stderr.write('\nError: VCF file must have at least 8 columns\n') + sys.stderr.write( + '\nError: VCF file must have at least 8 columns\n') exit(1) if len(var_list) < 9: var_list.append("GT") @@ -198,7 +214,8 @@ def __init__(self, var_list, vcf): self.gts[s] = Genotype(self, s, './.') self.info = dict() - i_split = [a.split('=') for a in var_list[7].split(';')] # temp list of split info column + i_split = [a.split('=') + for a in var_list[7].split(';')] # temp list of split info column for i in i_split: if len(i) == 1: i.append(True) @@ -208,7 +225,8 @@ def set_info(self, field, value): if field in [i.id for i in self.info_list]: self.info[field] = value else: - sys.stderr.write('\nError: invalid INFO field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid INFO field, \"' + field + '\"\n') exit(1) def get_info(self, field): @@ -217,11 +235,12 @@ def get_info(self, field): def get_info_string(self): i_list = list() for info_field in 
self.info_list: - if info_field.id in self.info.keys(): + if info_field.id in list(self.info.keys()): if info_field.type == 'Flag': i_list.append(info_field.id) else: - i_list.append('%s=%s' % (info_field.id, self.info[info_field.id])) + i_list.append( + '%s=%s' % (info_field.id, self.info[info_field.id])) return ';'.join(i_list) def get_format_string(self): @@ -235,10 +254,11 @@ def genotype(self, sample_name): if sample_name in self.sample_list: return self.gts[sample_name] else: - sys.stderr.write('\nError: invalid sample name, \"' + sample_name + '\"\n') + sys.stderr.write( + '\nError: invalid sample name, \"' + sample_name + '\"\n') def get_var_string(self): - s = '\t'.join(map(str,[ + s = '\t'.join(map(str, [ self.chrom, self.pos, self.var_id, @@ -248,11 +268,14 @@ def get_var_string(self): self.filter, self.get_info_string(), self.get_format_string(), - '\t'.join(self.genotype(s).get_gt_string() for s in self.sample_list) + '\t'.join(self.genotype(s).get_gt_string() + for s in self.sample_list) ])) return s + class Genotype(object): + def __init__(self, variant, sample_name, gt): self.format = dict() self.variant = variant @@ -264,9 +287,11 @@ def set_format(self, field, value): if field not in self.variant.active_formats: self.variant.active_formats.append(field) # sort it to be in the same order as the format_list in header - self.variant.active_formats.sort(key=lambda x: [f.id for f in self.variant.format_list].index(x)) + self.variant.active_formats.sort( + key=lambda x: [f.id for f in self.variant.format_list].index(x)) else: - sys.stderr.write('\nError: invalid FORMAT field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid FORMAT field, \"' + field + '\"\n') exit(1) def get_format(self, field): @@ -282,13 +307,16 @@ def get_gt_string(self): g_list.append(self.format[f]) else: g_list.append('.') - return ':'.join(map(str,g_list)) + return ':'.join(map(str, g_list)) # primary function + + def sv_genotype(vcf_file): in_header = True header = [] - breakend_dict = {} # cache to hold unmatched generic breakends for genotyping + breakend_dict = {} + # cache to hold unmatched generic breakends for genotyping vcf = Vcf() vcf_out = sys.stdout @@ -296,7 +324,7 @@ def sv_genotype(vcf_file): for line in vcf_file: if in_header: if line[0] == '#': - header.append(line) + header.append(line) if line[1] != '#': vcf_samples = line.rstrip().split('\t')[9:] continue @@ -305,19 +333,29 @@ def sv_genotype(vcf_file): vcf.add_header(header) # if detailed: vcf.add_format('GQ', 1, 'Float', 'Genotype quality') - vcf.add_format('SQ', 1, 'Float', 'Phred-scaled probability that this site is variant (non-reference in this sample') + vcf.add_format( + 'SQ', 1, 'Float', 'Phred-scaled probability that this site is variant (non-reference in this sample') vcf.add_format('GL', 'G', 'Float', 'Genotype Likelihood, log10-scaled likelihoods of the data given the called genotype for each possible gen\ otype generated from the reference and alternate alleles given the sample ploidy') vcf.add_format('DP', 1, 'Integer', 'Read depth') - vcf.add_format('RO', 1, 'Integer', 'Reference allele observation count, with partial observations recorded fractionally') - vcf.add_format('AO', 'A', 'Integer', 'Alternate allele observations, with partial observations recorded fractionally') - vcf.add_format('QR', 1, 'Integer', 'Sum of quality of reference observations') - vcf.add_format('QA', 'A', 'Integer', 'Sum of quality of alternate observations') - vcf.add_format('RS', 1, 'Integer', 'Reference allele split-read 
observation count, with partial observations recorded fractionally') - vcf.add_format('AS', 'A', 'Integer', 'Alternate allele split-read observation count, with partial observations recorded fractionally') - vcf.add_format('RP', 1, 'Integer', 'Reference allele paired-end observation count, with partial observations recorded fractionally') - vcf.add_format('AP', 'A', 'Integer', 'Alternate allele paired-end observation count, with partial observations recorded fractionally') - vcf.add_format('AB', 'A', 'Float', 'Allele balance, fraction of observations from alternate allele, QA/(QR+QA)') + vcf.add_format( + 'RO', 1, 'Integer', 'Reference allele observation count, with partial observations recorded fractionally') + vcf.add_format( + 'AO', 'A', 'Integer', 'Alternate allele observations, with partial observations recorded fractionally') + vcf.add_format( + 'QR', 1, 'Integer', 'Sum of quality of reference observations') + vcf.add_format( + 'QA', 'A', 'Integer', 'Sum of quality of alternate observations') + vcf.add_format( + 'RS', 1, 'Integer', 'Reference allele split-read observation count, with partial observations recorded fractionally') + vcf.add_format( + 'AS', 'A', 'Integer', 'Alternate allele split-read observation count, with partial observations recorded fractionally') + vcf.add_format( + 'RP', 1, 'Integer', 'Reference allele paired-end observation count, with partial observations recorded fractionally') + vcf.add_format( + 'AP', 'A', 'Integer', 'Alternate allele paired-end observation count, with partial observations recorded fractionally') + vcf.add_format( + 'AB', 'A', 'Float', 'Allele balance, fraction of observations from alternate allele, QA/(QR+QA)') # write the output header if len(vcf_samples) > 0: @@ -329,7 +367,7 @@ def sv_genotype(vcf_file): var = Variant(v, vcf) # genotype generic breakends - if var.info['SVTYPE']=='BND': + if var.info['SVTYPE'] == 'BND': if var.info['MATEID'] in breakend_dict: var2 = var var = breakend_dict[var.info['MATEID']] @@ -338,16 +376,20 @@ def sv_genotype(vcf_file): posA = var.pos posB = var2.pos # confidence intervals - ciA = [posA + ci for ci in map(int, var.info['CIPOS'].split(','))] - ciB = [posB + ci for ci in map(int, var2.info['CIPOS'].split(','))] + ciA = [posA + ci for ci in map( + int, var.info['CIPOS'].split(','))] + ciB = [posB + ci for ci in map( + int, var2.info['CIPOS'].split(','))] # infer the strands from the alt allele if var.alt[-1] == '[' or var.alt[-1] == ']': o1 = '+' - else: o1 = '-' + else: + o1 = '-' if var2.alt[-1] == '[' or var2.alt[-1] == ']': o2 = '+' - else: o2 = '-' + else: + o2 = '-' else: breakend_dict[var.var_id] = var continue @@ -360,23 +402,23 @@ def sv_genotype(vcf_file): ciA = [posA + ci for ci in map(int, var.info['CIPOS'].split(','))] ciB = [posB + ci for ci in map(int, var.info['CIEND'].split(','))] if var.get_info('SVTYPE') == 'DEL': - o1, o2 = '+', '-' + o1, o2 = '+', '-' elif var.get_info('SVTYPE') == 'DUP': - o1, o2 = '-', '+' + o1, o2 = '-', '+' elif var.get_info('SVTYPE') == 'INV': - o1, o2 = '+', '+' + o1, o2 = '+', '+' - # # increment the negative strand values (note position in VCF should be the base immediately left of the breakpoint junction) + # increment the negative strand values (note position in VCF should be the base immediately left of the breakpoint junction) # if o1 == '-': posA += 1 # if o2 == '-': posB += 1 - # # if debug: print posA, posB + # if debug: print posA, posB - # # for i in xrange(len(bam_list)): + # for i in xrange(len(bam_list)): # for sample in sample_list: # ''' # 
Breakend A # ''' - # # Count splitters + # Count splitters # ref_counter_a = Counter() # spl_counter_a = Counter() # ref_scaled_counter_a = Counter() @@ -391,15 +433,16 @@ def sv_genotype(vcf_file): # for spl_read in sample.spl_bam.fetch(chromA, max(posA - padding, 0), posA + padding + 1): # if not spl_read.is_duplicate and not spl_read.is_unmapped: # if o1 == '+' and spl_read.cigar[0][0] == 0: - # # if debug: print 'o1+', spl_read.aend + # if debug: print 'o1+', spl_read.aend # spl_counter_a[spl_read.aend] += 1 # spl_scaled_counter_a[spl_read.aend] += (1-10**(-spl_read.mapq/10.0)) # elif o1 == '-' and spl_read.cigar[-1][0] == 0: - # # if debug: print 'o1-', spl_read.pos + 1 + # if debug: print 'o1-', spl_read.pos + 1 # spl_counter_a[spl_read.pos + 1] += 1 - # spl_scaled_counter_a[spl_read.pos + 1] += (1-10**(-spl_read.mapq/10.0)) + # spl_scaled_counter_a[spl_read.pos + 1] += + # (1-10**(-spl_read.mapq/10.0)) - # # Count paired-end discordant and concordants + # Count paired-end discordant and concordants # (conc_counter_a, # disc_counter_a, # conc_scaled_counter_a, @@ -412,7 +455,7 @@ def sv_genotype(vcf_file): # ''' # Breakend B # ''' - # # Count splitters + # Count splitters # ref_counter_b = Counter() # spl_counter_b = Counter() # ref_scaled_counter_b = Counter() @@ -428,25 +471,28 @@ def sv_genotype(vcf_file): # if not spl_read.is_duplicate and not spl_read.is_unmapped: # if o2 == '+' and spl_read.cigar[0][0] == 0: # spl_counter_b[spl_read.aend] += 1 - # # if debug: print 'o2+', spl_read.aend + # if debug: print 'o2+', spl_read.aend # spl_scaled_counter_b[spl_read.aend] += (1-10**(-spl_read.mapq/10.0)) # elif o2 == '-' and spl_read.cigar[-1][0] == 0: - # # if debug: print 'o2-', spl_read.pos + 1 + # if debug: print 'o2-', spl_read.pos + 1 # spl_counter_b[spl_read.pos + 1] += 1 - # spl_scaled_counter_b[spl_read.pos + 1] += (1-10**(-spl_read.mapq/10.0)) - - # # tally up the splitters + # spl_scaled_counter_b[spl_read.pos + 1] += + # (1-10**(-spl_read.mapq/10.0)) + + # tally up the splitters # sr_ref_a = int(round(sum(ref_counter_a[p] for p in xrange(posA - split_slop, posA + split_slop + 1)) / float(2 * split_slop + 1))) # sr_spl_a = sum(spl_counter_a[p] for p in xrange(posA-split_slop, posA+split_slop + 1)) # sr_ref_b = int(round(sum(ref_counter_b[p] for p in xrange(posB - split_slop, posB + split_slop + 1)) / float(2 * split_slop + 1))) - # sr_spl_b = sum(spl_counter_b[p] for p in xrange(posB - split_slop, posB + split_slop + 1)) + # sr_spl_b = sum(spl_counter_b[p] for p in xrange(posB - split_slop, + # posB + split_slop + 1)) # sr_ref_scaled_a = sum(ref_scaled_counter_a[p] for p in xrange(posA - split_slop, posA + split_slop + 1)) / float(2 * split_slop + 1) # sr_spl_scaled_a = sum(spl_scaled_counter_a[p] for p in xrange(posA-split_slop, posA+split_slop + 1)) # sr_ref_scaled_b = sum(ref_scaled_counter_b[p] for p in xrange(posB - split_slop, posB + split_slop + 1)) / float(2 * split_slop + 1) - # sr_spl_scaled_b = sum(spl_scaled_counter_b[p] for p in xrange(posB - split_slop, posB + split_slop + 1)) + # sr_spl_scaled_b = sum(spl_scaled_counter_b[p] for p in xrange(posB - + # split_slop, posB + split_slop + 1)) - # # Count paired-end discordants and concordants + # Count paired-end discordants and concordants # (conc_counter_b, # disc_counter_b, # conc_scaled_counter_b, @@ -466,10 +512,11 @@ def sv_genotype(vcf_file): # print 'sr_a_scaled', '(ref, alt)', sr_ref_scaled_a, sr_spl_scaled_a # print 'pe_a_scaled', '(ref, alt)', conc_scaled_counter_a, disc_scaled_counter_a # print 
'sr_b_scaled', '(ref, alt)', sr_ref_scaled_b, sr_spl_scaled_b - # print 'pe_b_scaled', '(ref, alt)', conc_scaled_counter_b, disc_scaled_counter_b + # print 'pe_b_scaled', '(ref, alt)', conc_scaled_counter_b, + # disc_scaled_counter_b - # # merge the breakend support - # split_ref = 0 # set these to zero unless there are informative alt bases for the ev type + # merge the breakend support + # split_ref = 0 # set these to zero unless there are informative alt bases for the ev type # disc_ref = 0 # split_alt = sr_spl_a + sr_spl_b # if split_alt > 0: @@ -481,7 +528,7 @@ def sv_genotype(vcf_file): # split_ref = sr_ref_a + sr_ref_b # disc_ref = conc_counter_a + conc_counter_b - # split_scaled_ref = 0 # set these to zero unless there are informative alt bases for the ev type + # split_scaled_ref = 0 # set these to zero unless there are informative alt bases for the ev type # disc_scaled_ref = 0 # split_scaled_alt = sr_spl_scaled_a + sr_spl_scaled_b # if int(split_scaled_alt) > 0: @@ -489,33 +536,33 @@ def sv_genotype(vcf_file): # disc_scaled_alt = disc_scaled_counter_a + disc_scaled_counter_b # if int(disc_scaled_alt) > 0: # disc_scaled_ref = conc_scaled_counter_a + conc_scaled_counter_b - # if int(split_scaled_alt) == 0 and int(disc_scaled_alt) == 0: # if no alt alleles, set reference + # if int(split_scaled_alt) == 0 and int(disc_scaled_alt) == 0: # if no alt alleles, set reference # split_scaled_ref = sr_ref_scaled_a + sr_ref_scaled_b - # disc_scaled_ref = conc_scaled_counter_a + conc_scaled_counter_b + # disc_scaled_ref = conc_scaled_counter_a + conc_scaled_counter_b # if split_scaled_alt + split_scaled_ref + disc_scaled_alt + disc_scaled_ref > 0: - # # get bayesian classifier + # get bayesian classifier # if var.info['SVTYPE'] == "DUP": is_dup = True # else: is_dup = False # gt_lplist = bayes_gt(int(split_weight * split_scaled_ref) + int(disc_weight * disc_scaled_ref), int(split_weight * split_scaled_alt) + int(disc_weight * disc_scaled_alt), is_dup) # gt_idx = gt_lplist.index(max(gt_lplist)) - # # print log probabilities of homref, het, homalt + # print log probabilities of homref, het, homalt # if debug: # print gt_lplist - # # set the overall variant QUAL score and sample specific fields + # set the overall variant QUAL score and sample specific fields # var.genotype(sample.name).set_format('GL', ','.join(['%.0f' % x for x in gt_lplist])) # var.genotype(sample.name).set_format('DP', int(split_scaled_ref + split_scaled_alt + disc_scaled_ref + disc_scaled_alt)) # var.genotype(sample.name).set_format('AO', int(split_scaled_alt + disc_scaled_alt)) # var.genotype(sample.name).set_format('RO', int(split_scaled_ref + disc_scaled_ref)) - # # if detailed: + # if detailed: # var.genotype(sample.name).set_format('AS', int(split_scaled_alt)) # var.genotype(sample.name).set_format('RS', int(split_scaled_ref)) # var.genotype(sample.name).set_format('AP', int(disc_scaled_alt)) - # var.genotype(sample.name).set_format('RP', int(disc_scaled_ref)) + # var.genotype(sample.name).set_format('RP', int(disc_scaled_ref)) - # # assign genotypes + # assign genotypes # gt_sum = 0 # for gt in gt_lplist: # try: @@ -524,9 +571,9 @@ def sv_genotype(vcf_file): # gt_sum += 0 # if gt_sum > 0: # gt_sum_log = math.log(gt_sum, 10) - # sample_qual = abs(-10 * (gt_lplist[0] - gt_sum_log)) # phred-scaled probability site is non-reference in this sample + # sample_qual = abs(-10 * (gt_lplist[0] - gt_sum_log)) # phred-scaled probability site is non-reference in this sample # if 1 - (10**gt_lplist[gt_idx] / 10**gt_sum_log) == 0: - 
# phred_gq = 200 + # phred_gq = 200 # else: # phred_gq = abs(-10 * math.log(1 - (10**gt_lplist[gt_idx] / 10**gt_sum_log), 10)) # var.genotype(sample.name).set_format('GQ', phred_gq) @@ -550,7 +597,7 @@ def sv_genotype(vcf_file): # var.genotype(sample.name).set_format('DP', 0) # var.genotype(sample.name).set_format('AO', 0) # var.genotype(sample.name).set_format('RO', 0) - # # if detailed: + # if detailed: # var.genotype(sample.name).set_format('AS', 0) # var.genotype(sample.name).set_format('RS', 0) # var.genotype(sample.name).set_format('AP', 0) @@ -564,12 +611,13 @@ def sv_genotype(vcf_file): var2.genotype = var.genotype vcf_out.write(var2.get_var_string() + '\n') vcf_out.close() - + return # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() @@ -584,6 +632,6 @@ def main(): if __name__ == '__main__': try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE - raise + raise diff --git a/scripts/vcf_modify_header.py b/scripts/vcf_modify_header.py index 888ed22..72a1c91 100755 --- a/scripts/vcf_modify_header.py +++ b/scripts/vcf_modify_header.py @@ -1,8 +1,11 @@ #!/usr/bin/env python import pysam -import argparse, sys -import math, time, re +import argparse +import sys +import math +import time +import re from collections import Counter from argparse import RawTextHelpFormatter @@ -13,18 +16,25 @@ # -------------------------------------- # define functions + def get_args(): parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\ vcf_modify_header.py\n\ author: " + __author__ + "\n\ version: " + __version__ + "\n\ description: Add or remove lines from the header") - parser.add_argument('-i', '--id', dest='vcf_id', metavar='STR', type=str, required=False, help='field id') - parser.add_argument('-c', '--category', dest='category', metavar='STR', type=str, help='INFO, FORMAT, FILTER') - parser.add_argument('-t', '--type', dest='type', metavar='STR', type=str, required=False, help='Number, String, Float, Integer') - parser.add_argument('-n', '--number', dest='number', metavar='STR', type=str, required=False, help='integer, A, R, G, or .') - parser.add_argument('-d', '--description', dest='description', metavar='STR', type=str, required=False, help='description') - parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') + parser.add_argument('-i', '--id', dest='vcf_id', + metavar='STR', type=str, required=False, help='field id') + parser.add_argument('-c', '--category', dest='category', + metavar='STR', type=str, help='INFO, FORMAT, FILTER') + parser.add_argument('-t', '--type', dest='type', metavar='STR', + type=str, required=False, help='Number, String, Float, Integer') + parser.add_argument('-n', '--number', dest='number', metavar='STR', + type=str, required=False, help='integer, A, R, G, or .') + parser.add_argument('-d', '--description', dest='description', + metavar='STR', type=str, required=False, help='description') + parser.add_argument(metavar='vcf', dest='input_vcf', nargs='?', + type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') # parse the arguments args = parser.parse_args() @@ -39,7 +49,9 @@ def get_args(): # send back the user input return args + class Vcf(object): + def __init__(self): self.file_format = 'VCFv4.2' # self.fasta = fasta @@ -58,19 +70,19 @@ def add_header(self, header): elif line.split('=')[0] == '##reference': 
self.reference = line.rstrip().split('=')[1] elif line.split('=')[0] == '##INFO': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_info(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##ALT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_alt(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FORMAT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_format(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FILTER': - a = line[line.find('<')+1:-2] + a = line[line.find('<') + 1:-2] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_filter(*[b.split('=')[1] for b in r.findall(a)]) elif line[0] == '#' and line[1] != '#': @@ -81,11 +93,11 @@ def get_header(self, include_samples=True): if include_samples: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [f.hstring for f in self.filter_list] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [f.hstring for f in self.filter_list] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -95,17 +107,17 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO', - 'FORMAT'] + \ - self.sample_list - )]) + 'FORMAT'] + + self.sample_list + )]) else: header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [f.hstring for f in self.filter_list] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ + '##reference=' + self.reference] + + [f.hstring for f in self.filter_list] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + ['\t'.join([ '#CHROM', 'POS', @@ -115,7 +127,7 @@ def get_header(self, include_samples=True): 'QUAL', 'FILTER', 'INFO'] - )]) + )]) return header def add_info(self, id, number, type, desc): @@ -153,6 +165,7 @@ def sample_to_col(self, sample): return self.sample_list.index(sample) + 9 class Info(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -161,18 +174,22 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##INFO=' + self.hstring = '##INFO=' class Alt(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##ALT=' + self.hstring = '##ALT=' class Format(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -181,18 +198,23 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = 
'##FORMAT=' + self.hstring = '##FORMAT=' class Filter(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##FILTER=' + self.hstring = '##FILTER=' + class Variant(object): + def __init__(self, var_list, vcf): self.chrom = var_list[0] self.pos = int(var_list[1]) @@ -210,10 +232,11 @@ def __init__(self, var_list, vcf): self.format_list = vcf.format_list self.active_formats = list() self.gts = dict() - + # fill in empty sample genotypes if len(var_list) < 8: - sys.stderr.write('\nError: VCF file must have at least 8 columns\n') + sys.stderr.write( + '\nError: VCF file must have at least 8 columns\n') exit(1) if len(var_list) < 9: var_list.append("GT") @@ -230,7 +253,8 @@ def __init__(self, var_list, vcf): self.gts[s] = Genotype(self, s, './.') self.info = dict() - i_split = [a.split('=') for a in var_list[7].split(';')] # temp list of split info column + i_split = [a.split('=') + for a in var_list[7].split(';')] # temp list of split info column for i in i_split: if len(i) == 1: i.append(True) @@ -240,7 +264,8 @@ def set_info(self, field, value): if field in [i.id for i in self.info_list]: self.info[field] = value else: - sys.stderr.write('\nError: invalid INFO field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid INFO field, \"' + field + '\"\n') exit(1) def get_info(self, field): @@ -249,11 +274,12 @@ def get_info(self, field): def get_info_string(self): i_list = list() for info_field in self.info_list: - if info_field.id in self.info.keys(): + if info_field.id in list(self.info.keys()): if info_field.type == 'Flag': i_list.append(info_field.id) else: - i_list.append('%s=%s' % (info_field.id, self.info[info_field.id])) + i_list.append( + '%s=%s' % (info_field.id, self.info[info_field.id])) return ';'.join(i_list) def get_format_string(self): @@ -267,10 +293,11 @@ def genotype(self, sample_name): if sample_name in self.sample_list: return self.gts[sample_name] else: - sys.stderr.write('\nError: invalid sample name, \"' + sample_name + '\"\n') + sys.stderr.write( + '\nError: invalid sample name, \"' + sample_name + '\"\n') def get_var_string(self): - s = '\t'.join(map(str,[ + s = '\t'.join(map(str, [ self.chrom, self.pos, self.var_id, @@ -280,11 +307,14 @@ def get_var_string(self): self.filter, self.get_info_string(), self.get_format_string(), - '\t'.join(self.genotype(s).get_gt_string() for s in self.sample_list) + '\t'.join(self.genotype(s).get_gt_string() + for s in self.sample_list) ])) return s + class Genotype(object): + def __init__(self, variant, sample_name, gt): self.format = dict() self.variant = variant @@ -296,9 +326,11 @@ def set_format(self, field, value): if field not in self.variant.active_formats: self.variant.active_formats.append(field) # sort it to be in the same order as the format_list in header - self.variant.active_formats.sort(key=lambda x: [f.id for f in self.variant.format_list].index(x)) + self.variant.active_formats.sort( + key=lambda x: [f.id for f in self.variant.format_list].index(x)) else: - sys.stderr.write('\nError: invalid FORMAT field, \"' + field + '\"\n') + sys.stderr.write( + '\nError: invalid FORMAT field, \"' + field + '\"\n') exit(1) def get_format(self, field): @@ -314,9 +346,11 @@ def get_gt_string(self): g_list.append(self.format[f]) else: g_list.append('.') - return ':'.join(map(str,g_list)) + return ':'.join(map(str, g_list)) # primary function + 
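A note on the entrypoint pattern that recurs at the bottom of every script touched by this patch: Python 3 dropped the comma form of the except clause, so each handler changes from 'except IOError, e:' to 'except IOError as e:'. The following standalone sketch (placeholder main body, not repository code) illustrates the shared SIGPIPE-tolerant entrypoint these scripts converge on:

    import sys

    def main():
        # placeholder work; the real scripts parse arguments and stream a VCF
        sys.stdout.write('ok\n')
        return 0

    if __name__ == '__main__':
        try:
            sys.exit(main())
        except IOError as e:
            if e.errno != 32:  # ignore SIGPIPE so piping into tools like head exits quietly
                raise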
+ def mod_header(vcf_id, category, f_type, @@ -325,7 +359,8 @@ def mod_header(vcf_id, vcf_file): in_header = True header = [] - breakend_dict = {} # cache to hold unmatched generic breakends for genotyping + breakend_dict = {} + # cache to hold unmatched generic breakends for genotyping vcf = Vcf() vcf_out = sys.stdout @@ -333,7 +368,7 @@ def mod_header(vcf_id, for line in vcf_file: if in_header: if line[0] == '#': - header.append(line) + header.append(line) if line[1] != '#': vcf_samples = line.rstrip().split('\t')[9:] continue @@ -355,12 +390,13 @@ def mod_header(vcf_id, vcf_out.write(vcf.get_header(include_samples=False) + '\n') vcf_out.write(line) vcf_out.close() - + return # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() @@ -380,6 +416,6 @@ def main(): if __name__ == '__main__': try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE - raise + raise diff --git a/scripts/vcf_paste.py b/scripts/vcf_paste.py index 0558b8d..1c780c1 100755 --- a/scripts/vcf_paste.py +++ b/scripts/vcf_paste.py @@ -1,6 +1,7 @@ #!/usr/bin/env python -import argparse, sys +import argparse +import sys from argparse import RawTextHelpFormatter import gzip @@ -11,6 +12,7 @@ # -------------------------------------- # define functions + def get_args(): parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\ vcf_paste.py\n\ @@ -19,13 +21,21 @@ def get_args(): description: Paste VCFs from multiple samples") # parser.add_argument('-a', '--argA', metavar='argA', type=str, required=True, help='description of argument') # parser.add_argument('-b', '--argB', metavar='argB', required=False, help='description of argument B') - # parser.add_argument('-c', '--flagC', required=False, action='store_true', help='sets flagC to true') - parser.add_argument('-m', '--master', type=argparse.FileType('r'), default=None, help='VCF file to set first 8 columns of variant info [first file in vcf_list]') - parser.add_argument('-q', '--sum_quals', required=False, action='store_true', help='Sum QUAL scores of input VCFs as output QUAL score') - # parser.add_argument('-s', '--safe', action='store_true', help='Check to ensure the variant positions match the master. Safe, but slower.') - parser.add_argument('-f', '--vcf_list', required=True, help='Line-delimited list of VCF files to paste') - - # parser.add_argument('vcf_list', metavar='vcf', nargs='*', type=argparse.FileType('r'), default=None, help='VCF file(s) to join') + # parser.add_argument('-c', '--flagC', required=False, + # action='store_true', help='sets flagC to true') + parser.add_argument( + '-m', '--master', type=argparse.FileType('r'), default=None, + help='VCF file to set first 8 columns of variant info [first file in vcf_list]') + parser.add_argument( + '-q', '--sum_quals', required=False, action='store_true', + help='Sum QUAL scores of input VCFs as output QUAL score') + # parser.add_argument('-s', '--safe', action='store_true', help='Check to + # ensure the variant positions match the master. 
Safe, but slower.') + parser.add_argument('-f', '--vcf_list', required=True, + help='Line-delimited list of VCF files to paste') + + # parser.add_argument('vcf_list', metavar='vcf', nargs='*', + # type=argparse.FileType('r'), default=None, help='VCF file(s) to join') # parse the arguments args = parser.parse_args() @@ -38,8 +48,10 @@ def get_args(): return args # primary function + + def svt_join(master, sum_quals, vcf_list): - max_split=9 + max_split = 9 # if master not provided, set as first VCF if master is None: @@ -59,7 +71,7 @@ def svt_join(master, sum_quals, vcf_list): break if master_line[:2] != "##": break - print (master_line.rstrip()) + print(master_line.rstrip()) # get sample names out_v = master_line.rstrip().split('\t', max_split)[:9] @@ -74,17 +86,17 @@ def svt_join(master, sum_quals, vcf_list): line_v = line.rstrip().split('\t', max_split) out_v = out_v + line_v[9:] break - sys.stdout.write( '\t'.join(map(str, out_v)) + '\n') - + sys.stdout.write('\t'.join(map(str, out_v)) + '\n') + # iterate through VCF body while 1: master_line = master.readline() if not master_line: break master_v = master_line.rstrip().split('\t', max_split) - out_v = master_v[:8] # output array of fields + out_v = master_v[:8] # output array of fields qual = float(out_v[5]) - format = None # column 9, VCF format field. + format = None # column 9, VCF format field. for vcf in vcf_list: line = vcf.readline() @@ -104,18 +116,19 @@ def svt_join(master, sum_quals, vcf_list): out_v = out_v + line_v[9:] if sum_quals: out_v[5] = qual - sys.stdout.write( '\t'.join(map(str, out_v)) + '\n') - + sys.stdout.write('\t'.join(map(str, out_v)) + '\n') + # close files master.close() for vcf in vcf_list: vcf.close() - + return # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() @@ -128,16 +141,18 @@ def main(): vcf_list.append(gzip.open(path, 'rb')) else: vcf_list.append(open(path, 'r')) - - # vcf_list = [open(line.rstrip('\n'), 'r') for line in open(args.vcf_list, 'r')] + + # vcf_list = [open(line.rstrip('\n'), 'r') for line in open(args.vcf_list, + # 'r')] # call primary function svt_join(args.master, args.sum_quals, vcf_list) + # initialize the script if __name__ == '__main__': try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE raise diff --git a/setup.py b/setup.py index cd07472..2314902 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ setup_requires=['pytest-runner'], tests_require=['pytest'], install_requires=[ - 'pysam>=0.12.0', + 'pysam>=0.12', 'numpy', 'scipy', 'cytoolz>=0.8.2', @@ -41,6 +41,7 @@ [console_scripts] svtyper=svtyper.classic:cli svtyper-sso=svtyper.singlesample:cli + svtyper-pms=svtyper.parallelsample:cli ''', packages=find_packages(exclude=('tests', 'etc')), include_package_data=True, diff --git a/svtyper/classic.py b/svtyper/classic.py index 8c79da4..0b2ebfb 100755 --- a/svtyper/classic.py +++ b/svtyper/classic.py @@ -11,6 +11,8 @@ from svtyper.parsers import Vcf, Variant, Sample from svtyper.utils import * +from svtyper.utils import SamFragment, write_sample_json, prob_mapq +from svtyper.utils import write_alignment from svtyper.statistics import bayes_gt # -------------------------------------- @@ -22,47 +24,79 @@ def get_args(): author: " + svtyper.version.__author__ + "\n\ version: " + svtyper.version.__version__ + "\n\ description: Compute genotype of structural variants based on breakpoint depth") - parser.add_argument('-i', '--input_vcf', metavar='FILE', 
type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') - parser.add_argument('-o', '--output_vcf', metavar='FILE', type=argparse.FileType('w'), default=sys.stdout, help='output VCF to write (default: stdout)') - parser.add_argument('-B', '--bam', metavar='FILE', type=str, required=True, help='BAM or CRAM file(s), comma-separated if genotyping multiple samples') - parser.add_argument('-T', '--ref_fasta', metavar='FILE', type=str, required=False, default=None, help='Indexed reference FASTA file (recommended for reading CRAM files)') - parser.add_argument('-S', '--split_bam', type=str, required=False, help=argparse.SUPPRESS) - parser.add_argument('-l', '--lib_info', metavar='FILE', dest='lib_info_path', type=str, required=False, default=None, help='create/read JSON file of library information') - parser.add_argument('-m', '--min_aligned', metavar='INT', type=int, required=False, default=20, help='minimum number of aligned bases to consider read as evidence [20]') - parser.add_argument('-n', dest='num_samp', metavar='INT', type=int, required=False, default=1000000, help='number of reads to sample from BAM file for building insert size distribution [1000000]') - parser.add_argument('-q', '--sum_quals', action='store_true', required=False, help='add genotyping quality to existing QUAL (default: overwrite QUAL field)') - parser.add_argument('--max_reads', metavar='INT', type=int, default=None, required=False, help='maximum number of reads to assess at any variant (reduces processing time in high-depth regions, default: unlimited)') - parser.add_argument('--split_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for split reads [1]') - parser.add_argument('--disc_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for discordant paired-end reads [1]') - parser.add_argument('-w', '--write_alignment', metavar='FILE', dest='alignment_outpath', type=str, required=False, default=None, help='write relevant reads to BAM file') - parser.add_argument('--debug', action='store_true', help=argparse.SUPPRESS) - parser.add_argument('--verbose', action='store_true', default=False, help='Report status updates') + parser.add_argument('-i', '--input_vcf', metavar='FILE', + type=argparse.FileType('r'), default=None, + help='VCF input (default: stdin)') + parser.add_argument('-o', '--output_vcf', metavar='FILE', + type=argparse.FileType('w'), default=sys.stdout, + help='output VCF to write (default: stdout)') + parser.add_argument('-B', '--bam', metavar='FILE', + type=str, required=True, + help='BAM or CRAM file(s), comma-separated if genotyping multiple samples') + parser.add_argument('-T', '--ref_fasta', metavar='FILE', + type=str, required=False, default=None, + help='Indexed reference FASTA file (recommended for reading CRAM files)') + parser.add_argument('-S', '--split_bam', + type=str, required=False, help=argparse.SUPPRESS) + parser.add_argument('-l', '--lib_info', metavar='FILE', dest='lib_info_path', + type=str, required=False, default=None, + help='create/read JSON file of library information') + parser.add_argument('-m', '--min_aligned', metavar='INT', + type=int, required=False, default=20, + help='minimum number of aligned bases to consider read as evidence [20]') + parser.add_argument('-n', dest='num_samp', metavar='INT', + type=int, required=False, default=1000000, + help='number of reads to sample from BAM file for building insert size distribution [1000000]') + parser.add_argument('-q', '--sum_quals', action='store_true', 
required=False, + help='add genotyping quality to existing QUAL (default: overwrite QUAL field)') + parser.add_argument('--max_reads', metavar='INT', + type=int, default=None, required=False, + help='maximum number of reads to assess at any variant (reduces processing time in high-depth regions, default: unlimited)') + parser.add_argument('--split_weight', metavar='FLOAT', + type=float, required=False, default=1, + help='weight for split reads [1]') + parser.add_argument('--disc_weight', metavar='FLOAT', + type=float, required=False, default=1, + help='weight for discordant paired-end reads [1]') + parser.add_argument('-w', '--write_alignment', metavar='FILE', + dest='alignment_outpath', + type=str, required=False, default=None, + help='write relevant reads to BAM file') + parser.add_argument('--debug', action='store_true', + help=argparse.SUPPRESS) + parser.add_argument('--verbose', action='store_true', default=False, + help='Report status updates') # parse the arguments args = parser.parse_args() # if no input, check if part of pipe and if so, read stdin. - if args.input_vcf == None: + if args.input_vcf is None: if not sys.stdin.isatty(): args.input_vcf = sys.stdin # send back the user input return args + # methods to grab reads from region of interest in BAM file -def gather_all_reads(sample, chromA, posA, ciA, chromB, posB, ciB, z, max_reads): +def gather_all_reads(sample, chromA, posA, ciA, + chromB, posB, ciB, z, max_reads): # grab batch of reads from both sides of breakpoint read_batch = {} - read_batch, many = gather_reads(sample, chromA, posA, ciA, z, read_batch, max_reads) + read_batch, many = gather_reads(sample, chromA, posA, ciA, + z, read_batch, max_reads) if many: return {}, True - read_batch, many = gather_reads(sample, chromB, posB, ciB, z, read_batch, max_reads) + read_batch, many = gather_reads(sample, chromB, posB, ciB, + z, read_batch, max_reads) if many: return {}, True return read_batch, many + def gather_reads(sample, chrom, pos, ci, z, @@ -76,8 +110,8 @@ def gather_reads(sample, many = False for i, read in enumerate(sample.bam.fetch(chrom, - max(pos + ci[0] - fetch_flank, 0), - min(pos + ci[1] + fetch_flank, chrom_length))): + max(pos + ci[0] - fetch_flank, 0), + min(pos + ci[1] + fetch_flank, chrom_length))): if read.is_unmapped or read.is_duplicate: continue @@ -123,12 +157,13 @@ def sv_genotype(bam_string, if b.endswith('.bam'): bam_list.append(pysam.AlignmentFile(b, mode='rb')) elif b.endswith('.cram'): - bam_list.append(pysam.AlignmentFile(b, mode='rc', reference_filename=ref_fasta)) + bam_list.append(pysam.AlignmentFile(b, mode='rc', + reference_filename=ref_fasta)) else: sys.stderr.write('Error: %s is not a valid alignment file (*.bam or *.cram)\n' % b) exit(1) - - min_lib_prevalence = 1e-3 # only consider libraries that constitute at least this fraction of the BAM + + min_lib_prevalence = 1e-3 # only consider libraries that constitute at least this fraction of the BAM # parse lib_info_path JSON lib_info = None @@ -141,7 +176,7 @@ def sv_genotype(bam_string, # build the sample libraries, either from the lib_info JSON or empirically from the BAMs sample_list = list() - for i in xrange(len(bam_list)): + for i in range(len(bam_list)): if lib_info is None: logging.info('Calculating library metrics from %s...' 
% bam_list[i].filename) sample = Sample.from_bam(bam_list[i], num_samp, min_lib_prevalence) @@ -178,10 +213,10 @@ def sv_genotype(bam_string, # set variables for genotyping z = 3 - split_slop = 3 # amount of slop around breakpoint to count splitters + split_slop = 3 # amount of slop around breakpoint to count splitters in_header = True header = [] - breakend_dict = {} # cache to hold unmatched generic breakends for genotyping + breakend_dict = {} # cache to hold unmatched generic breakends for genotyping vcf = Vcf() # read input VCF @@ -206,10 +241,9 @@ def sv_genotype(bam_string, # write the output header vcf_out.write(vcf.get_header() + '\n') - v = line.rstrip().split('\t') var = Variant(v, vcf) - var_length = None # var_length should be None except for deletions + var_length = None # var_length should be None except for deletions if not sum_quals: var.qual = 0 @@ -220,7 +254,7 @@ def sv_genotype(bam_string, sys.stderr.write('Warning: SVTYPE missing at variant %s. Skipping.\n' % (var.var_id)) vcf_out.write(var.get_var_string() + '\n') continue - + # print original line if unsupported svtype if svtype not in ('BND', 'DEL', 'DUP', 'INV'): sys.stderr.write('Warning: Unsupported SVTYPE at variant %s (%s). Skipping.\n' % (var.var_id, svtype)) @@ -236,16 +270,18 @@ def sv_genotype(bam_string, posA = var.pos posB = var2.pos # confidence intervals - ciA = map(int, var.info['CIPOS'].split(',')) - ciB = map(int, var2.info['CIPOS'].split(',')) + ciA = list(map(int, var.info['CIPOS'].split(','))) + ciB = list(map(int, var2.info['CIPOS'].split(','))) # infer the strands from the alt allele if var.alt[-1] == '[' or var.alt[-1] == ']': o1_is_reverse = False - else: o1_is_reverse = True + else: + o1_is_reverse = True if var2.alt[-1] == '[' or var2.alt[-1] == ']': o2_is_reverse = False - else: o2_is_reverse = True + else: + o2_is_reverse = True # remove the BND from the breakend_dict # to free up memory @@ -259,23 +295,26 @@ def sv_genotype(bam_string, posA = var.pos posB = int(var.get_info('END')) # confidence intervals - ciA = map(int, var.info['CIPOS'].split(',')) - ciB = map(int, var.info['CIEND'].split(',')) + ciA = list(map(int, var.info['CIPOS'].split(','))) + ciB = list(map(int, var.info['CIEND'].split(','))) if svtype == 'DEL': var_length = posB - posA - o1_is_reverse, o2_is_reverse = False, True + o1_is_reverse, o2_is_reverse = False, True elif svtype == 'DUP': - o1_is_reverse, o2_is_reverse = True, False + o1_is_reverse, o2_is_reverse = True, False elif svtype == 'INV': o1_is_reverse, o2_is_reverse = False, False # increment the negative strand values (note position in VCF should be the base immediately left of the breakpoint junction) - if o1_is_reverse: posA += 1 - if o2_is_reverse: posB += 1 + if o1_is_reverse: + posA += 1 + if o2_is_reverse: + posB += 1 for sample in sample_list: # grab reads from both sides of breakpoint - read_batch, many = gather_all_reads(sample, chromA, posA, ciA, chromB, posB, ciB, z, max_reads) + read_batch, many = gather_all_reads(sample, chromA, posA, ciA, + chromB, posB, ciB, z, max_reads) if many: var.genotype(sample.name).set_format('GT', './.') continue @@ -287,8 +326,8 @@ def sv_genotype(bam_string, # ref_ciA = ciA # ref_ciB = ciB - ref_ciA = [0,0] - ref_ciB = [0,0] + ref_ciA = [0, 0] + ref_ciB = [0, 0] for query_name in sorted(read_batch.keys()): fragment = read_batch[query_name] @@ -301,8 +340,10 @@ def sv_genotype(bam_string, # get reference sequences for read in fragment.primary_reads: - is_ref_seq_A = fragment.is_ref_seq(read, var, chromA, posA, ciA, 
min_aligned) - is_ref_seq_B = fragment.is_ref_seq(read, var, chromB, posB, ciB, min_aligned) + is_ref_seq_A = fragment.is_ref_seq(read, var, + chromA, posA, ciA, min_aligned) + is_ref_seq_B = fragment.is_ref_seq(read, var, + chromB, posB, ciB, min_aligned) if (is_ref_seq_A or is_ref_seq_B): p_reference = prob_mapq(read) ref_seq += p_reference @@ -318,7 +359,8 @@ def sv_genotype(bam_string, o1_is_reverse, o2_is_reverse, svtype, split_slop) # p_alt = prob_mapq(split.query_left) * prob_mapq(split.query_right) - p_alt = (prob_mapq(split.query_left) * split_lr[0] + prob_mapq(split.query_right) * split_lr[1]) / 2.0 + p_alt = (prob_mapq(split.query_left) * split_lr[0] + + prob_mapq(split.query_right) * split_lr[1]) / 2.0 if split.is_soft_clip: alt_clip += p_alt else: @@ -338,7 +380,8 @@ def sv_genotype(bam_string, else: alt_straddle = fragment.is_pair_straddle(chromA, posA, ciA, chromB, posB, ciB, - o1_is_reverse, o2_is_reverse, + o1_is_reverse, + o2_is_reverse, min_aligned, fragment.lib) @@ -359,7 +402,7 @@ def sv_genotype(bam_string, if p_conc is not None: p_alt = (1 - p_conc) * prob_mapq(fragment.readA) * prob_mapq(fragment.readB) alt_span += p_alt - + # # since an alt straddler is by definition also a reference straddler, # # we can bail out early here to save some time # p_reference = p_conc * prob_mapq(fragment.readA) * prob_mapq(fragment.readB) @@ -405,17 +448,17 @@ def sv_genotype(bam_string, write_fragment = True # write to BAM if requested - if alignment_outpath is not None and write_fragment: + if alignment_outpath is not None and write_fragment: for read in fragment.primary_reads + [split.read for split in fragment.split_reads]: out_bam_written_reads = write_alignment(read, out_bam, out_bam_written_reads) if debug: - print '--------------------------' - print 'ref_span:', ref_span - print 'alt_span:', alt_span - print 'ref_seq:', ref_seq - print 'alt_seq:', alt_seq - print 'alt_clip:', alt_clip + print('--------------------------') + print('ref_span:', ref_span) + print('alt_span:', alt_span) + print('ref_seq:', ref_seq) + print('alt_seq:', alt_seq) + print('alt_clip:', alt_clip) # in the absence of evidence for a particular type, ignore the reference # support for that type as well @@ -440,12 +483,12 @@ def sv_genotype(bam_string, QR = int(split_weight * ref_seq) + int(disc_weight * ref_span) QA = int(split_weight * alt_splitters) + int(disc_weight * alt_span) gt_lplist = bayes_gt(QR, QA, is_dup) - best, second_best = sorted([ (i, e) for i, e in enumerate(gt_lplist) ], key=lambda(x): x[1], reverse=True)[0:2] + best, second_best = sorted([ (i, e) for i, e in enumerate(gt_lplist) ], key=lambda x: x[1], reverse=True)[0:2] gt_idx = best[0] # print log probabilities of homref, het, homalt if debug: - print gt_lplist + print(gt_lplist) # set the overall variant QUAL score and sample specific fields var.genotype(sample.name).set_format('GL', ','.join(['%.0f' % x for x in gt_lplist])) @@ -465,7 +508,6 @@ def sv_genotype(bam_string, except ZeroDivisionError: var.genotype(sample.name).set_format('AB', '.') - # assign genotypes gt_sum = 0 for gt in gt_lplist: @@ -529,6 +571,7 @@ def sv_genotype(bam_string, return + def set_up_logging(verbose): level = logging.WARNING if verbose: @@ -539,6 +582,7 @@ def set_up_logging(verbose): # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() @@ -563,16 +607,17 @@ def main(): args.sum_quals, args.max_reads) + # -------------------------------------- # command-line/console entrypoint - def 
cli(): try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE raise + + # initialize the script if __name__ == '__main__': cli() diff --git a/svtyper/parallelsample.py b/svtyper/parallelsample.py new file mode 100755 index 0000000..0ddc0f8 --- /dev/null +++ b/svtyper/parallelsample.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 + +import sys +import gzip +import argparse +import shutil +import tempfile +from argparse import RawTextHelpFormatter +import os +import multiprocessing +from functools import partial +from subprocess import call +import random +import string + + +from pysam import AlignmentFile, tabix_compress, tabix_index +from svtyper import singlesample as single +from svtyper import version as svtyper_version + + +def get_args(): + parser = argparse.ArgumentParser(formatter_class=RawTextHelpFormatter, description="\ + svtyper\n\ + author: " + svtyper_version.__author__ + "\n\ + version: " + svtyper_version.__version__ + "\n\ + description: Compute genotype of structural variants based on breakpoint depth") + + parser.add_argument('-B', '--bam', + type=str, required=True, + help='text file listing BAM files, one per line') + parser.add_argument('-i', '--input_vcf', + type=str, required=False, default=None, + help='VCF input (default: stdin)') + parser.add_argument('-o', '--output_vcf', + type=str, required=False, default=None, + help='output VCF to write (default: stdout)') + parser.add_argument('-t', '--threads', type=int, default=1, + help='number of threads to use (at most the number of available cores) [1]') + + args = parser.parse_args() + + if args.input_vcf is None: + args.input_vcf = sys.stdin + if args.output_vcf is None: + args.output_vcf = sys.stdout + + return args + +def fetchId(bamfile): + """ + Fetch the sample id from a bam file + :param bamfile: the bam file + :type bamfile: file + :return: sample name + :rtype: str + """ + bamfile_fin = AlignmentFile(bamfile, 'rb') + name = bamfile_fin.header['RG'][0]['SM'] + bamfile_fin.close() + return name + + +def get_bamfiles(bamlist): + with open(bamlist, "r") as fin: + bamfiles = [os.path.abspath(f.strip()) for f in fin] + return bamfiles + + +def random_string(length=10): + return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(length)) + + +def genotype_multiple_samples(bamlist, vcf_in, vcf_out, cores=1): + + bamfiles = get_bamfiles(bamlist) + + if vcf_out == sys.stdout: + tmp_dir = tempfile.mkdtemp() + else: + tmp_dir = os.path.join(os.path.dirname(vcf_out), "tmp" + random_string()) + os.mkdir(tmp_dir) + + if vcf_in == sys.stdin: + vcf_in = single.dump_piped_vcf_to_file(vcf_in, tmp_dir) + elif vcf_in.endswith(".gz"): + vcf_in_gz = vcf_in + vcf_in = os.path.join(tmp_dir, os.path.basename(vcf_in[:-3])) + with gzip.open(vcf_in_gz, 'rb') as f_in, open(vcf_in, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + + pool = multiprocessing.Pool(processes=cores) + launch_ind = partial(genotype_single_sample, + vcf_in=vcf_in, out_dir=tmp_dir) + + vcf_files = pool.map(launch_ind, bamfiles) + + merge_cmd = "bcftools merge -m id " + + if vcf_out != sys.stdout: + merge_cmd += "-O z " + " -o " + vcf_out + " " + + merge_cmd += " ".join(vcf_files) + + exit_code = call(merge_cmd, shell=True) + + if exit_code == 0: + if vcf_out != sys.stdout: + tabix_index(vcf_out, force=True, preset="vcf") + + shutil.rmtree(tmp_dir) + else: + print("Failed: bcftools merge exited with status %d" % exit_code) + exit(1) + + +def genotype_single_sample(bam, vcf_in, out_dir): + lib_info_json = bam + ".json" + sample = fetchId(bam) + out_vcf = 
os.path.join(out_dir, sample + ".gt.vcf") + with open(vcf_in, "r") as inf, open(out_vcf, "w") as outf: + single.sso_genotype(bam_string=bam, + vcf_in=inf, + vcf_out=outf, + min_aligned=20, + split_weight=1, + disc_weight=1, + num_samp=1000000, + lib_info_path=lib_info_json, + debug=False, + ref_fasta=None, + sum_quals=False, + max_reads=1000, + cores=None, + batch_size=None) + out_gz = out_vcf + ".gz" + tabix_compress(out_vcf, out_gz, force=True) + tabix_index(out_gz, force=True, preset="vcf") + return out_gz + + +def main(): + args = get_args() + + genotype_multiple_samples(args.bam, args.input_vcf, args.output_vcf, cores=args.threads) + + +def cli(): + try: + sys.exit(main()) + except IOError as e: + if e.errno != 32: # ignore SIGPIPE + raise + + +if __name__ == '__main__': + cli() diff --git a/svtyper/parsers.py b/svtyper/parsers.py index 208b45a..124d715 100644 --- a/svtyper/parsers.py +++ b/svtyper/parsers.py @@ -1,6 +1,9 @@ -from __future__ import print_function -import time, re, json, sys + +import time +import re +import json +import sys from collections import Counter from svtyper.statistics import mean, stdev, median, upper_mad @@ -9,7 +12,9 @@ # VCF parsing tools # ================================================== + class Vcf(object): + def __init__(self): self.file_format = 'VCFv4.2' # self.fasta = fasta @@ -26,19 +31,31 @@ def __init__(self): def add_custom_svtyper_headers(self): self.add_info('SVTYPE', 1, 'String', 'Type of structural variant') self.add_format('GQ', 1, 'Integer', 'Genotype quality') - self.add_format('SQ', 1, 'Float', 'Phred-scaled probability that this site is variant (non-reference in this sample') - self.add_format('GL', 'G', 'Float', 'Genotype Likelihood, log10-scaled likelihoods of the data given the called genotype for each possible genotype generated from the reference and alternate alleles given the sample ploidy') self.add_format('DP', 1, 'Integer', 'Read depth') - self.add_format('RO', 1, 'Integer', 'Reference allele observation count, with partial observations recorded fractionally') - self.add_format('AO', 'A', 'Integer', 'Alternate allele observations, with partial observations recorded fractionally') - self.add_format('QR', 1, 'Integer', 'Sum of quality of reference observations') - self.add_format('QA', 'A', 'Integer', 'Sum of quality of alternate observations') - self.add_format('RS', 1, 'Integer', 'Reference allele split-read observation count, with partial observations recorded fractionally') - self.add_format('AS', 'A', 'Integer', 'Alternate allele split-read observation count, with partial observations recorded fractionally') - self.add_format('ASC', 'A', 'Integer', 'Alternate allele clipped-read observation count, with partial observations recorded fractionally') - self.add_format('RP', 1, 'Integer', 'Reference allele paired-end observation count, with partial observations recorded fractionally') - self.add_format('AP', 'A', 'Integer', 'Alternate allele paired-end observation count, with partial observations recorded fractionally') - self.add_format('AB', 'A', 'Float', 'Allele balance, fraction of observations from alternate allele, QA/(QR+QA)') + self.add_format( + 'SQ', 1, 'Float', 'Phred-scaled probability that this site is variant (non-reference in this sample)') + self.add_format( + 'GL', 'G', 'Float', 'Genotype Likelihood, log10-scaled likelihoods of the data given the called genotype for each possible genotype generated from the reference and alternate alleles given the sample ploidy') self.add_format('DP', 1, 'Integer', 'Read depth') + self.add_format( + 'RO', 1, 'Integer', 
'Reference allele observation count, with partial observations recorded fractionally') + self.add_format( + 'AO', 'A', 'Integer', 'Alternate allele observations, with partial observations recorded fractionally') + self.add_format( + 'QR', 1, 'Integer', 'Sum of quality of reference observations') + self.add_format( + 'QA', 'A', 'Integer', 'Sum of quality of alternate observations') + self.add_format( + 'RS', 1, 'Integer', 'Reference allele split-read observation count, with partial observations recorded fractionally') + self.add_format( + 'AS', 'A', 'Integer', 'Alternate allele split-read observation count, with partial observations recorded fractionally') + self.add_format( + 'ASC', 'A', 'Integer', 'Alternate allele clipped-read observation count, with partial observations recorded fractionally') + self.add_format( + 'RP', 1, 'Integer', 'Reference allele paired-end observation count, with partial observations recorded fractionally') + self.add_format( + 'AP', 'A', 'Integer', 'Alternate allele paired-end observation count, with partial observations recorded fractionally') + self.add_format( + 'AB', 'A', 'Float', 'Allele balance, fraction of observations from alternate allele, QA/(QR+QA)') def add_header(self, header): for line in header: @@ -47,15 +64,15 @@ def add_header(self, header): elif line.split('=')[0] == '##reference': self.reference = line.rstrip().split('=')[1] elif line.split('=')[0] == '##INFO': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_info(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##ALT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_alt(*[b.split('=')[1] for b in r.findall(a)]) elif line.split('=')[0] == '##FORMAT': - a = line[line.find('<')+1:line.find('>')] + a = line[line.find('<') + 1:line.find('>')] r = re.compile(r'(?:[^,\"]|\"[^\"]*\")+') self.add_format(*[b.split('=')[1] for b in r.findall(a)]) elif line[0] == '#' and line[1] != '#': @@ -69,11 +86,11 @@ def add_header(self, header): def get_header(self): header = '\n'.join(['##fileformat=' + self.file_format, '##fileDate=' + time.strftime('%Y%m%d'), - '##reference=' + self.reference] + \ - [i.hstring for i in self.info_list] + \ - [a.hstring for a in self.alt_list] + \ - [f.hstring for f in self.format_list] + \ - self.header_misc + \ + '##reference=' + self.reference] + + [i.hstring for i in self.info_list] + + [a.hstring for a in self.alt_list] + + [f.hstring for f in self.format_list] + + self.header_misc + ['\t'.join([ '#CHROM', 'POS', @@ -83,9 +100,9 @@ def get_header(self): 'QUAL', 'FILTER', 'INFO', - 'FORMAT'] + \ - self.sample_list - )]) + 'FORMAT'] + + self.sample_list + )]) return header def write_header(self, fd=None): @@ -129,22 +146,24 @@ def _get_bnd_breakpoints(variant): posA = var.pos posB = var2.pos # confidence intervals - ciA = map(int, var.info['CIPOS'].split(',')) - ciB = map(int, var2.info['CIPOS'].split(',')) + ciA = list(map(int, var.info['CIPOS'].split(','))) + ciB = list(map(int, var2.info['CIPOS'].split(','))) # infer the strands from the alt allele if var.alt[-1] == '[' or var.alt[-1] == ']': o1_is_reverse = False - else: o1_is_reverse = True + else: + o1_is_reverse = True if var2.alt[-1] == '[' or var2.alt[-1] == ']': o2_is_reverse = False - else: o2_is_reverse = True + else: + o2_is_reverse = True breakpoints = { - 'id' : var.var_id, - 'svtype' : 'BND', - 'A' : 
{'chrom': chromA, 'pos' : posA, 'ci': ciA, 'is_reverse': o1_is_reverse}, - 'B' : {'chrom': chromB, 'pos' : posB, 'ci': ciB, 'is_reverse': o2_is_reverse}, + 'id': var.var_id, + 'svtype': 'BND', + 'A': {'chrom': chromA, 'pos': posA, 'ci': ciA, 'is_reverse': o1_is_reverse}, + 'B': {'chrom': chromB, 'pos': posB, 'ci': ciB, 'is_reverse': o2_is_reverse}, } for k in ('A', 'B'): @@ -159,7 +178,7 @@ def _get_bnd_breakpoints(variant): else: bnd_cache[variant.var_id] = variant return None - + return _get_bnd_breakpoints @staticmethod @@ -169,31 +188,31 @@ def _default_get_breakpoints(variant): posA = variant.pos posB = int(variant.get_info('END')) # confidence intervals - ciA = map(int, variant.info['CIPOS'].split(',')) - ciB = map(int, variant.info['CIEND'].split(',')) + ciA = list(map(int, variant.info['CIPOS'].split(','))) + ciB = list(map(int, variant.info['CIEND'].split(','))) svtype = variant.get_svtype() if svtype == 'DEL': var_length = posB - posA - o1_is_reverse, o2_is_reverse = False, True + o1_is_reverse, o2_is_reverse = False, True elif svtype == 'DUP': - o1_is_reverse, o2_is_reverse = True, False + o1_is_reverse, o2_is_reverse = True, False elif svtype == 'INV': o1_is_reverse, o2_is_reverse = False, False if svtype != 'DEL': breakpoints = { - 'id' : variant.var_id, - 'svtype' : svtype, - 'A' : {'chrom': chromA, 'pos' : posA, 'ci': ciA, 'is_reverse': o1_is_reverse}, - 'B' : {'chrom': chromB, 'pos' : posB, 'ci': ciB, 'is_reverse': o2_is_reverse}, + 'id': variant.var_id, + 'svtype': svtype, + 'A': {'chrom': chromA, 'pos': posA, 'ci': ciA, 'is_reverse': o1_is_reverse}, + 'B': {'chrom': chromB, 'pos': posB, 'ci': ciB, 'is_reverse': o2_is_reverse}, } else: breakpoints = { - 'id' : variant.var_id, - 'svtype' : svtype, - 'var_length' : var_length, - 'A' : {'chrom': chromA, 'pos' : posA, 'ci': ciA, 'is_reverse': o1_is_reverse}, - 'B' : {'chrom': chromB, 'pos' : posB, 'ci': ciB, 'is_reverse': o2_is_reverse}, + 'id': variant.var_id, + 'svtype': svtype, + 'var_length': var_length, + 'A': {'chrom': chromA, 'pos': posA, 'ci': ciA, 'is_reverse': o1_is_reverse}, + 'B': {'chrom': chromB, 'pos': posB, 'ci': ciB, 'is_reverse': o2_is_reverse}, } for k in ('A', 'B'): @@ -217,6 +236,7 @@ def get_variant_breakpoints(self, variant): return breakpoints class Info(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -225,18 +245,22 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##INFO=' + self.hstring = '##INFO=' class Alt(object): + def __init__(self, id, desc): self.id = str(id) self.desc = str(desc) # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##ALT=' + self.hstring = '##ALT=' class Format(object): + def __init__(self, id, number, type, desc): self.id = str(id) self.number = str(number) @@ -245,9 +269,12 @@ def __init__(self, id, number, type, desc): # strip the double quotes around the string if present if self.desc.startswith('"') and self.desc.endswith('"'): self.desc = self.desc[1:-1] - self.hstring = '##FORMAT=' + self.hstring = '##FORMAT=' + class Variant(object): + def __init__(self, var_list, vcf): self.chrom = var_list[0] self.pos = int(var_list[1]) @@ -279,13 +306,15 @@ def __init__(self, var_list, vcf): s_gt = var_list[vcf.sample_to_col(s)].split(':')[0] self.gts[s] = Genotype(self, 
s, s_gt) # import the existing fmt fields - for j in zip(var_list[8].split(':'), var_list[vcf.sample_to_col(s)].split(':')): + for j in zip(var_list[8].split(':'), + var_list[vcf.sample_to_col(s)].split(':')): self.gts[s].set_format(j[0], j[1]) except IndexError: self.gts[s] = Genotype(self, s, './.') self.info = dict() - i_split = [a.split('=') for a in var_list[7].split(';')] # temp list of split info column + i_split = [a.split('=') + for a in var_list[7].split(';')] # temp list of split info column for i in i_split: if len(i) == 1: i.append(True) @@ -304,11 +333,12 @@ def get_info(self, field): def get_info_string(self): i_list = list() for info_field in self.info_list: - if info_field.id in self.info.keys(): + if info_field.id in list(self.info.keys()): if info_field.type == 'Flag': i_list.append(info_field.id) else: - i_list.append('%s=%s' % (info_field.id, self.info[info_field.id])) + i_list.append('%s=%s' % (info_field.id, + self.info[info_field.id])) return ';'.join(i_list) def get_format_string(self): @@ -322,10 +352,11 @@ def genotype(self, sample_name): if sample_name in self.sample_list: return self.gts[sample_name] else: - sys.stderr.write('Error: invalid sample name, \"' + sample_name + '\"\n') + sys.stderr.write( + 'Error: invalid sample name, \"' + sample_name + '\"\n') def get_var_string(self): - s = '\t'.join(map(str,[ + s = '\t'.join(map(str, [ self.chrom, self.pos, self.var_id, @@ -335,7 +366,8 @@ self.filter, self.get_info_string(), self.get_format_string(), - '\t'.join(self.genotype(s).get_gt_string() for s in self.sample_list) + '\t'.join(self.genotype(s).get_gt_string() + for s in self.sample_list) ])) return s @@ -360,7 +392,9 @@ def is_valid_svtype(self): flag = svtype in ('BND', 'DEL', 'DUP', 'INV') return flag + class Genotype(object): + def __init__(self, variant, sample_name, gt): self.format = dict() self.variant = variant @@ -372,9 +406,11 @@ def set_format(self, field, value): if field not in self.variant.active_formats: self.variant.active_formats.append(field) # sort it to be in the same order as the format_list in header - self.variant.active_formats.sort(key=lambda x: [f.id for f in self.variant.format_list].index(x)) + self.variant.active_formats.sort( + key=lambda x: [f.id for f in self.variant.format_list].index(x)) else: - sys.stderr.write('Error: invalid FORMAT field, \"' + field + '\"\n') + sys.stderr.write( + 'Error: invalid FORMAT field, \"' + field + '\"\n') exit(1) def get_format(self, field): @@ -390,14 +426,16 @@ def get_gt_string(self): g_list.append(self.format[f]) else: g_list.append('.') - return ':'.join(map(str,g_list)) + return ':'.join(map(str, g_list)) # ================================================== # Library parsing # ================================================== + # holds a library's insert size and read length information class Library(object): + def __init__(self, name, bam, @@ -432,7 +470,8 @@ def __init__(self, if self.prevalence is None: self.calc_lib_prevalence() - # remove the uneeded attribute (only needed during object initialization) + # remove the unneeded attribute (only needed during object + # initialization) del self.bam def cleanup(self): @@ -450,7 +489,7 @@ def from_lib_info(cls, lib = lib_info[sample_name]['libraryArray'][lib_index] # convert the histogram keys to integers (from strings in JSON) - lib_hist = {int(k):int(v) for k,v in lib['histogram'].items()} + lib_hist = {int(k): int(v) for k, v in list(lib['histogram'].items())} return cls(lib['library_name'], bam, @@ -474,7 
+513,7 @@ def from_bam(cls, for r in bam.header['RG']: try: in_lib = r['LB'] == lib_name - except KeyError, e: + except KeyError as e: in_lib = lib_name == '' if in_lib: @@ -514,6 +553,8 @@ def calc_read_length(self): for read in self.bam.fetch(): if read.get_tag('RG') not in self.readgroups: continue + if read.cigarstring is None: + continue if read.infer_query_length() > max_rl: max_rl = read.infer_query_length() if counter == num_samp: @@ -541,12 +582,12 @@ def calc_insert_hist(self): skip_counter += 1 continue if (read.is_reverse - or not read.mate_is_reverse - or read.is_unmapped - or read.mate_is_unmapped - or not self.is_primary(read) - or read.template_length <= 0 - or read.get_tag('RG') not in self.readgroups): + or not read.mate_is_reverse + or read.is_unmapped + or read.mate_is_unmapped + or not self.is_primary(read) + or read.template_length <= 0 + or read.get_tag('RG') not in self.readgroups): continue else: valueCounts[read.template_length] += 1 @@ -573,7 +614,7 @@ def calc_insert_hist(self): def calc_insert_density(self): dens = Counter() for i in list(self.hist): - dens[i] = float(self.hist[i])/self.countRecords(self.hist) + dens[i] = float(self.hist[i]) / self.countRecords(self.hist) self.dens = dens def countRecords(self, myCounter): @@ -584,9 +625,11 @@ def countRecords(self, myCounter): # Sample parsing # ================================================== + # holds each sample's BAM and library information class Sample(object): # general constructor + def __init__(self, name, bam, @@ -606,7 +649,7 @@ def __init__(self, # get active libraries self.active_libs = [] - for lib in lib_dict.values(): + for lib in list(lib_dict.values()): if lib.prevalence >= min_lib_prevalence: self.active_libs.append(lib.name) @@ -623,7 +666,7 @@ def from_lib_info(cls, lib_dict = {} try: - for i in xrange(len(lib_info[name]['libraryArray'])): + for i in range(len(lib_info[name]['libraryArray'])): lib = lib_info[name]['libraryArray'][i] lib_name = lib['library_name'] @@ -637,7 +680,8 @@ def from_lib_info(cls, for rg in lib['readgroups']: rg_to_lib[rg] = lib_dict[lib_name] except KeyError: - sys.stderr.write('Error: sample %s not found in JSON library file.\n' % name) + sys.stderr.write( + 'Error: sample %s not found in JSON library file.\n' % name) exit(1) return cls(name, @@ -662,9 +706,9 @@ def from_bam(cls, for r in bam.header['RG']: try: - lib_name=r['LB'] - except KeyError, e: - lib_name='' + lib_name = r['LB'] + except KeyError as e: + lib_name = '' # add the new library if lib_name not in lib_dict: @@ -683,7 +727,8 @@ def from_bam(cls, # get the maximum fetch flank for reading the BAM file def get_fetch_flank(self, z): - return max([lib.mean + (lib.sd * z) for lib in self.lib_dict.values()]) + return max([lib.mean + (lib.sd * z) + for lib in list(self.lib_dict.values())]) # return the library object for a specified read group def get_lib(self, readgroup): @@ -692,8 +737,10 @@ def get_lib(self, readgroup): # return the expected spanning coverage at any given base def set_exp_spanning_depth(self, min_aligned): genome_size = float(sum(self.bam.lengths)) - weighted_mean_span = sum([(lib.mean - 2 * lib.read_length + 2 * min_aligned) * lib.prevalence for lib in self.lib_dict.values()]) - exp_spanning_depth = (weighted_mean_span * self.bam_mapped) / genome_size + weighted_mean_span = sum( + [(lib.mean - 2 * lib.read_length + 2 * min_aligned) * lib.prevalence for lib in list(self.lib_dict.values())]) + exp_spanning_depth = ( + weighted_mean_span * self.bam_mapped) / genome_size 
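Most of the remaining hunks in this file wrap map(), dict.keys(), dict.values(), and dict.items() in list(...). Under Python 3 these return lazy iterators or views rather than lists, so any call site that indexes the result (as with the CIPOS/CIEND confidence intervals) or needs a stable snapshot must materialize a list explicitly. A standalone illustration of the behavior change, with made-up values (not repository code):

    # Python 2: map() returned a list; Python 3: an iterator, so indexing
    # requires an explicit list()
    ci = list(map(int, '-10,10'.split(',')))
    assert ci[0] == -10 and ci[1] == 10

    # dict.values() is a view in Python 3; wrap it in list() to index it or
    # to keep a snapshot while the dict is being modified
    hist = {300: 10, 350: 5}
    counts = list(hist.values())
    assert sum(counts) == 15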
self.exp_spanning_depth = exp_spanning_depth @@ -702,8 +749,10 @@ def set_exp_spanning_depth(self, min_aligned): # return the expected sequence coverage at any given base def set_exp_seq_depth(self, min_aligned): genome_size = float(sum(self.bam.lengths)) - weighted_mean_read_length = sum([(lib.read_length - 2 * min_aligned) * lib.prevalence for lib in self.lib_dict.values()]) - exp_seq_depth = (weighted_mean_read_length * self.bam_mapped) / genome_size + weighted_mean_read_length = sum( + [(lib.read_length - 2 * min_aligned) * lib.prevalence for lib in list(self.lib_dict.values())]) + exp_seq_depth = ( + weighted_mean_read_length * self.bam_mapped) / genome_size self.exp_seq_depth = exp_seq_depth @@ -712,12 +761,14 @@ def set_exp_seq_depth(self, min_aligned): def close(self): self.bam.close() + # ================================================== # Class for SAM fragment, containing all alignments # from a single molecule # ================================================== class SamFragment(object): + def __init__(self, read, lib): self.lib = lib self.primary_reads = [] @@ -806,7 +857,6 @@ def is_ref_seq(self, return True - # returns boolean of whether the primary pair of a # fragment straddles a genomic point def is_pair_straddle(self, @@ -866,11 +916,13 @@ def p_concordant(self, var_length=None): var_length = self.lib.mean + self.lib.sd * z try: - p = float(self.lib.dens[ospan_length]) * conc_prior / (conc_prior * self.lib.dens[ospan_length] + disc_prior * (self.lib.dens[ospan_length - var_length])) + p = float(self.lib.dens[ospan_length]) * conc_prior / ( + conc_prior * self.lib.dens[ospan_length] + disc_prior * (self.lib.dens[ospan_length - var_length])) except ZeroDivisionError: p = None - return p > 0.5 + return None if p is None else p > 0.5 + # ================================================== # Class for a split-read, containing all alignments @@ -880,6 +932,7 @@ def p_concordant(self, var_length=None): # each SplitRead object has a left and a right SplitPiece # (reads with more than 2 split alignments are discarded) class SplitRead(object): + def __init__(self, read, lib): self.query_name = read.query_name self.read = read @@ -891,6 +944,7 @@ def __init__(self, read, lib): # the piece of each split alignemnt class SplitPiece(object): + def __init__(self, chrom, reference_start, @@ -907,7 +961,8 @@ def __init__(self, self.right_query = None # get query positions - self.query_pos = self.get_query_pos_from_cigar(self.cigar, self.is_reverse) + self.query_pos = self.get_query_pos_from_cigar( + self.cigar, self.is_reverse) # get the positions of the query that are aligned @staticmethod @@ -921,20 +976,20 @@ def get_query_pos_from_cigar(cigar, is_reverse): cigar = cigar[::-1] # iterate through cigartuple - for i in xrange(len(cigar)): + for i in range(len(cigar)): k, n = cigar[i] - if k in (4,5): # H, S + if k in (4, 5): # H, S if i == 0: query_start += n query_end += n query_length += n else: query_length += n - elif k in (0,1,7,8): # M, I, =, X + elif k in (0, 1, 7, 8): # M, I, =, X query_end += n - query_length +=n + query_length += n - d = QueryPos(query_start, query_end, query_length); + d = QueryPos(query_start, query_end, query_length) return d def set_reference_end(self, reference_end): @@ -948,15 +1003,18 @@ def is_clip_op(self, op): # check if passes QC, and populate with necessary info def is_valid(self, - min_non_overlap = 20, - min_indel = 50, - max_unmapped_bases = 50): + min_non_overlap=20, + min_indel=50, + max_unmapped_bases=50): # check for SA tag if not 
self.read.has_tag('SA'): - # Include soft-clipped reads that didn't generate split-read alignments + # Include soft-clipped reads that didn't generate split-read + # alignments if self.is_clip_op(self.read.cigar[0][0]) or self.is_clip_op(self.read.cigar[-1][0]): - clip_length = max(self.read.cigar[0][1] * self.is_clip_op(self.read.cigar[0][0]), self.read.cigar[-1][1] * self.is_clip_op(self.read.cigar[-1][0])) - # Only count if the longest clipping event is greater than the cutoff and we have mapped a reasonable number of bases + clip_length = max(self.read.cigar[0][1] * self.is_clip_op( + self.read.cigar[0][0]), self.read.cigar[-1][1] * self.is_clip_op(self.read.cigar[-1][0])) + # Only count if the longest clipping event is greater than the + # cutoff and we have mapped a reasonable number of bases if clip_length > 0 and (self.read.query_length - self.read.query_alignment_length) <= max_unmapped_bases: a = self.SplitPiece(self.read.reference_name, self.read.reference_start, @@ -985,7 +1043,8 @@ def is_valid(self, else: self.sa = sa_list[0].split(',') mate_chrom = self.sa[0] - mate_pos = int(self.sa[1]) - 1 # SA tag is one-based, while SAM is zero-based + mate_pos = int( + self.sa[1]) - 1 # SA tag is one-based, while SAM is zero-based mate_is_reverse = self.sa[2] == '-' mate_cigar = self.cigarstring_to_tuple(self.sa[3]) mate_mapq = int(self.sa[4]) @@ -1003,7 +1062,8 @@ def is_valid(self, mate_is_reverse, mate_cigar, mate_mapq) - b.set_reference_end(self.get_reference_end_from_cigar(b.reference_start, b.cigar)) + b.set_reference_end( + self.get_reference_end_from_cigar(b.reference_start, b.cigar)) # set query_left and query_right splitter by alignment position on the reference # this is used for non-overlap and off-diagonal filtering @@ -1025,7 +1085,7 @@ def is_valid(self, # check off-diagonal distance and desert # only relevant when split pieces are on the same chromosome and strand if (self.query_left.chrom == self.query_right.chrom - and self.query_left.is_reverse == self.query_right.is_reverse): + and self.query_left.is_reverse == self.query_right.is_reverse): # use end diagonal on left and start diagonal on right since # the start and end diags might be different if there is an # indel in the alignments @@ -1041,7 +1101,8 @@ def is_valid(self, return False # check for desert gap of indels - desert = self.query_right.query_pos.query_start - self.query_left.query_pos.query_end - 1 + desert = self.query_right.query_pos.query_start - \ + self.query_left.query_pos.query_end - 1 if desert > 0 and desert - max(0, ins_size) > max_unmapped_bases: return False @@ -1054,7 +1115,8 @@ def is_valid(self, def get_start_diagonal(split_piece): sclip = split_piece.query_pos.query_start if split_piece.is_reverse: - sclip = split_piece.query_pos.query_length - split_piece.query_pos.query_end + sclip = split_piece.query_pos.query_length - \ + split_piece.query_pos.query_end return split_piece.reference_start - sclip # reference position where the alignment would have ended @@ -1063,17 +1125,21 @@ def get_start_diagonal(split_piece): def get_end_diagonal(split_piece): query_aligned = split_piece.query_pos.query_end if split_piece.is_reverse: - query_aligned = split_piece.query_pos.query_length - split_piece.query_pos.query_start + query_aligned = split_piece.query_pos.query_length - \ + split_piece.query_pos.query_start return split_piece.reference_end - query_aligned - - # adapted from Matt Shirley (http://coderscrowd.com/app/public/codes/view/171) + # adapted from Matt Shirley + # 
(http://coderscrowd.com/app/public/codes/view/171) @staticmethod def cigarstring_to_tuple(cigarstring): - cigar_dict = {'M':0, 'I':1,'D':2,'N':3, 'S':4, 'H':5, 'P':6, '=':7, 'X':8} + cigar_dict = {'M': 0, 'I': 1, 'D': 2, 'N': 3, 'S': 4, + 'H': 5, 'P': 6, '=': 7, 'X': 8} pattern = re.compile('([MIDNSHPX=])') - values = pattern.split(cigarstring)[:-1] ## turn cigar into tuple of values - paired = (values[n:n+2] for n in xrange(0, len(values), 2)) ## pair values by twos + values = pattern.split(cigarstring)[ + :-1] # turn cigar into tuple of values + paired = (values[n:n + 2] + for n in range(0, len(values), 2)) # pair values by twos return [(cigar_dict[pair[1]], int(pair[0])) for pair in paired] @staticmethod @@ -1083,11 +1149,11 @@ def get_reference_end_from_cigar(reference_start, cigar): This matches the behavior of pysam's reference_end method ''' reference_end = reference_start - + # iterate through cigartuple - for i in xrange(len(cigar)): + for i in range(len(cigar)): k, n = cigar[i] - if k in (0,2,3,7,8): # M, D, N, =, X + if k in (0, 2, 3, 7, 8): # M, D, N, =, X reference_end += n return reference_end @@ -1099,8 +1165,10 @@ def non_overlap(self): self.query_right.query_pos.query_end) # get minimum non-overlap - left_non_overlap = 1 + self.query_left.query_pos.query_end - self.query_left.query_pos.query_start - overlap - right_non_overlap = 1 + self.query_right.query_pos.query_end - self.query_right.query_pos.query_start - overlap + left_non_overlap = 1 + self.query_left.query_pos.query_end - \ + self.query_left.query_pos.query_start - overlap + right_non_overlap = 1 + self.query_right.query_pos.query_end - \ + self.query_right.query_pos.query_start - overlap non_overlap = min(left_non_overlap, right_non_overlap) return non_overlap @@ -1132,7 +1200,7 @@ def is_split_straddle(self, # arrange the SV breakends from left to right if (chromA != chromB - or (chromA == chromB and posA > posB)): + or (chromA == chromB and posA > posB)): chrom_left = chromB pos_left = posB ci_left = ciB @@ -1151,56 +1219,55 @@ def is_split_straddle(self, ci_right = ciB is_reverse_right = o2_is_reverse - # check split chromosomes against variant left_split = False right_split = False if (not self.is_soft_clip) or svtype == 'DEL' or svtype == 'INS': left_split = self.check_split_support(self.query_left, - chrom_left, - pos_left, - is_reverse_left, - split_slop) + chrom_left, + pos_left, + is_reverse_left, + split_slop) right_split = self.check_split_support(self.query_right, - chrom_right, - pos_right, - is_reverse_right, - split_slop) + chrom_right, + pos_right, + is_reverse_right, + split_slop) elif svtype == 'DUP': left_split = self.check_split_support(self.query_left, - chrom_right, - pos_right, - is_reverse_right, - split_slop) + chrom_right, + pos_right, + is_reverse_right, + split_slop) right_split = self.check_split_support(self.query_right, - chrom_left, - pos_left, - is_reverse_left, - split_slop) + chrom_left, + pos_left, + is_reverse_left, + split_slop) elif svtype == 'INV': # check all possible sides left_split_left = self.check_split_support(self.query_left, - chrom_left, - pos_left, - is_reverse_left, - split_slop) + chrom_left, + pos_left, + is_reverse_left, + split_slop) left_split_right = self.check_split_support(self.query_left, - chrom_right, - pos_right, - is_reverse_right, - split_slop) + chrom_right, + pos_right, + is_reverse_right, + split_slop) left_split = left_split_left or left_split_right right_split_left = self.check_split_support(self.query_right, - chrom_left, - pos_left, - 
is_reverse_left, - split_slop) + chrom_left, + pos_left, + is_reverse_left, + split_slop) right_split_right = self.check_split_support(self.query_right, - chrom_right, - pos_right, - is_reverse_right, - split_slop) + chrom_right, + pos_right, + is_reverse_right, + split_slop) right_split = right_split_left or right_split_right return (left_split, right_split) @@ -1246,11 +1313,12 @@ def is_left_clip(self, cigar): # structure to hold query position information class QueryPos (object): + """ struct to store the start and end positions of query CIGAR operations """ + def __init__(self, query_start, query_end, query_length): self.query_start = int(query_start) self.query_end = int(query_end) - self.query_length = int(query_length) - + self.query_length = int(query_length) diff --git a/svtyper/singlesample.py b/svtyper/singlesample.py index c5d23e5..399487c 100644 --- a/svtyper/singlesample.py +++ b/svtyper/singlesample.py @@ -1,5 +1,10 @@ -from __future__ import print_function -import json, sys, os, math, argparse, time + +import json +import sys +import os +import math +import argparse +import time import multiprocessing as mp from cytoolz.itertoolz import partition_all @@ -18,48 +23,75 @@ def get_args(): author: " + 'Indraniel Das (idas@wustl.edu)' + "\n\ version: " + svtyper.version.__version__ + "\n\ description: Compute genotype of structural variants based on breakpoint depth on a SINGLE sample") - parser.add_argument('-i', '--input_vcf', metavar='FILE', type=argparse.FileType('r'), default=None, help='VCF input (default: stdin)') - parser.add_argument('-o', '--output_vcf', metavar='FILE', type=argparse.FileType('w'), default=sys.stdout, help='output VCF to write (default: stdout)') - parser.add_argument('-B', '--bam', metavar='FILE', type=str, required=True, help='BAM or CRAM file(s), comma-separated if genotyping multiple samples') - parser.add_argument('-T', '--ref_fasta', metavar='FILE', type=str, required=False, default=None, help='Indexed reference FASTA file (recommended for reading CRAM files)') - parser.add_argument('-S', '--split_bam', type=str, required=False, help=argparse.SUPPRESS) - parser.add_argument('-l', '--lib_info', metavar='FILE', dest='lib_info_path', type=str, required=False, default=None, help='create/read JSON file of library information') - parser.add_argument('-m', '--min_aligned', metavar='INT', type=int, required=False, default=20, help='minimum number of aligned bases to consider read as evidence [20]') - parser.add_argument('-n', dest='num_samp', metavar='INT', type=int, required=False, default=1000000, help='number of reads to sample from BAM file for building insert size distribution [1000000]') - parser.add_argument('-q', '--sum_quals', action='store_true', required=False, help='add genotyping quality to existing QUAL (default: overwrite QUAL field)') - parser.add_argument('--max_reads', metavar='INT', type=int, default=1000, required=False, help='maximum number of reads to assess at any variant (reduces processing time in high-depth regions, default: 1000)') - parser.add_argument('--split_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for split reads [1]') - parser.add_argument('--disc_weight', metavar='FLOAT', type=float, required=False, default=1, help='weight for discordant paired-end reads [1]') + parser.add_argument('-i', '--input_vcf', metavar='FILE', type=argparse.FileType( + 'r'), default=None, help='VCF input (default: stdin)') + parser.add_argument('-o', '--output_vcf', metavar='FILE', type=argparse.FileType( + 
'w'), default=sys.stdout, help='output VCF to write (default: stdout)') + parser.add_argument('-B', '--bam', metavar='FILE', type=str, required=True, + help='BAM or CRAM file(s), comma-separated if genotyping multiple samples') + parser.add_argument( + '-T', '--ref_fasta', metavar='FILE', type=str, required=False, + default=None, help='Indexed reference FASTA file (recommended for reading CRAM files)') + parser.add_argument( + '-S', '--split_bam', type=str, required=False, help=argparse.SUPPRESS) + parser.add_argument( + '-l', '--lib_info', metavar='FILE', dest='lib_info_path', type=str, + required=False, default=None, help='create/read JSON file of library information') + parser.add_argument( + '-m', '--min_aligned', metavar='INT', type=int, required=False, + default=20, help='minimum number of aligned bases to consider read as evidence [20]') + parser.add_argument( + '-n', dest='num_samp', metavar='INT', type=int, required=False, default=1000000, + help='number of reads to sample from BAM file for building insert size distribution [1000000]') + parser.add_argument( + '-q', '--sum_quals', action='store_true', required=False, + help='add genotyping quality to existing QUAL (default: overwrite QUAL field)') + parser.add_argument( + '--max_reads', metavar='INT', type=int, default=1000, required=False, + help='maximum number of reads to assess at any variant (reduces processing time in high-depth regions, default: 1000)') + parser.add_argument('--split_weight', metavar='FLOAT', type=float, + required=False, default=1, help='weight for split reads [1]') + parser.add_argument('--disc_weight', metavar='FLOAT', type=float, + required=False, default=1, help='weight for discordant paired-end reads [1]') parser.add_argument('--debug', action='store_true', help=argparse.SUPPRESS) - parser.add_argument('--cores', type=int, metavar='INT', required=False, default=None, help='number of workers to use for parallel processing') - parser.add_argument('--batch_size', type=int, metavar='INT', required=False, default=1000, help='number of breakpoints to batch for a parallel processing worker job') + parser.add_argument('--cores', type=int, metavar='INT', required=False, + default=None, help='number of workers to use for parallel processing') + parser.add_argument( + '--batch_size', type=int, metavar='INT', required=False, + default=1000, help='number of breakpoints to batch for a parallel processing worker job') # parse the arguments args = parser.parse_args() # if no input, check if part of pipe and if so, read stdin. 
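# For example (hypothetical file names, and assuming the single-sample
# console script for this module is installed as `svtyper-sso`), these
# two invocations are equivalent:
#
#   svtyper-sso -i sample.vcf -B sample.bam > sample.gt.vcf
#   cat sample.vcf | svtyper-sso -B sample.bam > sample.gt.vcf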
- if args.input_vcf == None: + if args.input_vcf is None: if not sys.stdin.isatty(): args.input_vcf = sys.stdin # send back the user input return args + def ensure_valid_alignment_file(afile): if not (afile.endswith('.bam') or afile.endswith('.cram')): - die('Error: %s is not a valid alignment file (*.bam or *.cram)\n' % afile) + die('Error: %s is not a valid alignment file (*.bam or *.cram)\n' % + afile) + def open_alignment_file(afile, reference_fasta): fd = None if afile.endswith('.bam'): fd = pysam.AlignmentFile(afile, mode='rb') elif afile.endswith('.cram'): - fd = pysam.AlignmentFile(afile, mode='rc', reference_filename=reference_fasta) + fd = pysam.AlignmentFile( + afile, mode='rc', reference_filename=reference_fasta) else: - die('Error: %s is not a valid alignment file (*.bam or *.cram)\n' % afile) + die('Error: %s is not a valid alignment file (*.bam or *.cram)\n' % + afile) return fd + def setup_sample(bam, lib_info_path, reference_fasta, sampling_number, min_aligned): fd = open_alignment_file(bam, reference_fasta) @@ -82,6 +114,7 @@ def setup_sample(bam, lib_info_path, reference_fasta, sampling_number, min_align return sample + def dump_library_metrics(lib_info_path, sample): sample_list = [sample] if (lib_info_path is not None) and (not os.path.exists(lib_info_path)): @@ -91,12 +124,14 @@ def dump_library_metrics(lib_info_path, sample): lib_info_file.close() logit('Finished writing library metrics') + def setup_src_vcf_file(fobj, invcf, rootdir): src_vcf = invcf if os.path.basename(invcf) == '': src_vcf = dump_piped_vcf_to_file(fobj, rootdir) return src_vcf + def dump_piped_vcf_to_file(stdin, basedir): vcf = os.path.join(basedir, 'input.vcf') logit('dumping vcf inputs into a temporary file: {}'.format(vcf)) @@ -108,6 +143,7 @@ def dump_piped_vcf_to_file(stdin, basedir): logit('finished temporary vcf dump -- {} lines'.format(line_count)) return vcf + def init_vcf(vcffile, sample, scratchdir): v = Vcf() v.filename = vcffile @@ -123,18 +159,23 @@ def init_vcf(vcffile, sample, scratchdir): v.add_sample(sample.name) return v + def collect_breakpoints(vcf): breakpoints = [] for vline in vcf_variants(vcf.filename): v = vline.rstrip().split('\t') variant = Variant(v, vcf) - if not variant.has_svtype(): continue - if not variant.is_valid_svtype(): continue + if not variant.has_svtype(): + continue + if not variant.is_valid_svtype(): + continue brkpts = vcf.get_variant_breakpoints(variant) - if brkpts is None: continue + if brkpts is None: + continue breakpoints.append(brkpts) return breakpoints + def get_breakpoint_regions(breakpoint, sample, z): # the distance to the left and right of the breakpoint to scan # (max of mean + z standard devs over all of a sample's libraries) @@ -154,35 +195,43 @@ def get_breakpoint_regions(breakpoint, sample, z): return tuple(regions) + def count_reads_in_region(region, bam): (sample_name, chrom, pos, left_pos, right_pos) = region - count = bam.count(chrom, start=left_pos, stop=right_pos, read_callback='all') + count = bam.count( + chrom, start=left_pos, stop=right_pos, read_callback='all') return count + def get_reads_iterator(region, bam): (sample_name, chrom, pos, left_pos, right_pos) = region iterator = bam.fetch(chrom, start=left_pos, stop=right_pos) return iterator + def is_over_threshold(bam, variant_id, regions, max_reads): over_threshold = False (regionA, regionB) = regions - (countA, countB) = ( count_reads_in_region(regionA, bam), count_reads_in_region(regionB, bam) ) + (countA, countB) = (count_reads_in_region(regionA, bam), + 
count_reads_in_region(regionB, bam)) if countA > max_reads or countB > max_reads: over_threshold = True msg = ("SKIPPING -- Variant '{}' has a region with too many reads (> {})\n" - "\t\t A: (sample={} chrom={} center={} leftflank={} rightflank={}) : {}\n" - "\t\t B: (sample={} chrom={} center={} leftflank={} rightflank={}) : {}").format( - variant_id, + "\t\t A: (sample={} chrom={} center={} leftflank={} rightflank={}) : {}\n" + "\t\t B: (sample={} chrom={} center={} leftflank={} rightflank={}) : {}").format( + variant_id, max_reads, - regionA[0], regionA[1], regionA[2], regionA[3], regionA[4], + regionA[0], regionA[1], regionA[ + 2], regionA[3], regionA[4], countA, - regionB[0], regionB[1], regionB[2], regionB[3], regionB[4], + regionB[0], regionB[1], regionB[ + 2], regionB[3], regionB[4], countB, - ) + ) logit(msg) return over_threshold + def gather_reads(bam, variant_id, regions, library_data, active_libs, max_reads): fragment_dict = {} over_threshold = is_over_threshold(bam, variant_id, regions, max_reads) @@ -192,9 +241,11 @@ def gather_reads(bam, variant_id, regions, library_data, active_libs, max_reads) for region in regions: for read in get_reads_iterator(region, bam): - if read.is_unmapped or read.is_duplicate: continue + if read.is_unmapped or read.is_duplicate: + continue lib = library_data[read.get_tag('RG')] - if lib.name not in active_libs: continue + if lib.name not in active_libs: + continue if read.query_name in fragment_dict: fragment_dict[read.query_name].add_read(read) else: @@ -203,42 +254,45 @@ def gather_reads(bam, variant_id, regions, library_data, active_libs, max_reads) return (fragment_dict, over_threshold) + def blank_genotype_result(): return { - 'qual' : 0, + 'qual': 0, 'formats': { - 'GT' : './.', - 'GQ' : '.', - 'SQ' : '.', - 'GL' : '.', - 'DP' : 0, - 'AO' : 0, - 'RO' : 0, - 'AS' : 0, - 'ASC' : 0, - 'RS' : 0, - 'AP' : 0, - 'RP' : 0, - 'QR' : 0, - 'QA' : 0, - 'AB' : '.', + 'GT': './.', + 'GQ': '.', + 'SQ': '.', + 'GL': '.', + 'DP': 0, + 'AO': 0, + 'RO': 0, + 'AS': 0, + 'ASC': 0, + 'RS': 0, + 'AP': 0, + 'RP': 0, + 'QR': 0, + 'QA': 0, + 'AB': '.', } } + def make_empty_genotype_result(variant_id, sample_name): gt = blank_genotype_result() gt['DP'] = '.' 
return { - 'variant.id' : variant_id, - 'sample.name' : sample_name, - 'genotype' : gt + 'variant.id': variant_id, + 'sample.name': sample_name, + 'genotype': gt } + def make_detailed_empty_genotype_result(variant_id, sample_name): return { - 'variant.id' : variant_id, - 'sample.name' : sample_name, - 'genotype' : blank_genotype_result() + 'variant.id': variant_id, + 'sample.name': sample_name, + 'genotype': blank_genotype_result() } @@ -246,13 +300,15 @@ def gather_split_read_evidence(sam_fragment, breakpoint, split_slop, min_aligned (ref_seq, alt_seq, alt_clip) = (0, 0, 0) elems = ('chrom', 'pos', 'ci', 'is_reverse') - (chromA, posA, ciA, o1_is_reverse) = [ breakpoint['A'][i] for i in elems ] - (chromB, posB, ciB, o2_is_reverse) = [ breakpoint['B'][i] for i in elems ] + (chromA, posA, ciA, o1_is_reverse) = [breakpoint['A'][i] for i in elems] + (chromB, posB, ciB, o2_is_reverse) = [breakpoint['B'][i] for i in elems] # get reference sequences for read in sam_fragment.primary_reads: - is_ref_seq_A = sam_fragment.is_ref_seq(read, None, chromA, posA, ciA, min_aligned) - is_ref_seq_B = sam_fragment.is_ref_seq(read, None, chromB, posB, ciB, min_aligned) + is_ref_seq_A = sam_fragment.is_ref_seq( + read, None, chromA, posA, ciA, min_aligned) + is_ref_seq_B = sam_fragment.is_ref_seq( + read, None, chromB, posB, ciB, min_aligned) if (is_ref_seq_A or is_ref_seq_B): p_reference = prob_mapq(read) ref_seq += p_reference @@ -266,7 +322,8 @@ def gather_split_read_evidence(sam_fragment, breakpoint, split_slop, min_aligned o1_is_reverse, o2_is_reverse, svtype, split_slop) # p_alt = prob_mapq(split.query_left) * prob_mapq(split.query_right) - p_alt = (prob_mapq(split.query_left) * split_lr[0] + prob_mapq(split.query_right) * split_lr[1]) / 2.0 + p_alt = (prob_mapq(split.query_left) * split_lr[ + 0] + prob_mapq(split.query_right) * split_lr[1]) / 2.0 if split.is_soft_clip: alt_clip += p_alt else: @@ -274,14 +331,15 @@ def gather_split_read_evidence(sam_fragment, breakpoint, split_slop, min_aligned return (ref_seq, alt_seq, alt_clip) + def gather_paired_end_evidence(fragment, breakpoint, min_aligned): (ref_span, alt_span) = (0, 0) - ref_ciA = [0,0] - ref_ciB = [0,0] + ref_ciA = [0, 0] + ref_ciB = [0, 0] elems = ('chrom', 'pos', 'ci', 'is_reverse') - (chromA, posA, ciA, o1_is_reverse) = [ breakpoint['A'][i] for i in elems ] - (chromB, posB, ciB, o2_is_reverse) = [ breakpoint['B'][i] for i in elems ] + (chromA, posA, ciA, o1_is_reverse) = [breakpoint['A'][i] for i in elems] + (chromB, posB, ciB, o2_is_reverse) = [breakpoint['B'][i] for i in elems] svtype = breakpoint['svtype'] # tally spanning alternate pairs @@ -310,11 +368,12 @@ def gather_paired_end_evidence(fragment, breakpoint, min_aligned): var_length = breakpoint['var_length'] p_conc = fragment.p_concordant(var_length) if p_conc is not None: - p_alt = (1 - p_conc) * prob_mapq(fragment.readA) * prob_mapq(fragment.readB) + p_alt = (1 - p_conc) * prob_mapq( + fragment.readA) * prob_mapq(fragment.readB) alt_span += p_alt - # # since an alt straddler is by definition also a reference straddler, - # # we can bail out early here to save some time + # since an alt straddler is by definition also a reference straddler, + # we can bail out early here to save some time # p_reference = p_conc * prob_mapq(fragment.readA) * prob_mapq(fragment.readB) # ref_span += p_reference # continue @@ -323,7 +382,7 @@ def gather_paired_end_evidence(fragment, breakpoint, min_aligned): p_alt = prob_mapq(fragment.readA) * prob_mapq(fragment.readB) alt_span += p_alt - # # tally 
spanning reference pairs + # tally spanning reference pairs if svtype == 'DEL' and posB - posA < 2 * fragment.lib.sd: ref_straddle_A = False ref_straddle_B = False @@ -346,37 +405,40 @@ def gather_paired_end_evidence(fragment, breakpoint, min_aligned): var_length = breakpoint.get('var_length', None) p_conc = fragment.p_concordant(var_length) if p_conc is not None: - p_reference = p_conc * prob_mapq(fragment.readA) * prob_mapq(fragment.readB) + p_reference = p_conc * \ + prob_mapq(fragment.readA) * prob_mapq(fragment.readB) ref_span += (ref_straddle_A + ref_straddle_B) * p_reference / 2 return (ref_span, alt_span, ref_ciA, ref_ciB) + def tally_variant_read_fragments(split_slop, min_aligned, breakpoint, sam_fragments, debug): # initialize counts to zero ref_span, alt_span = 0, 0 ref_seq, alt_seq = 0, 0 alt_clip = 0 - ref_ciA = [0,0] - ref_ciB = [0,0] + ref_ciA = [0, 0] + ref_ciB = [0, 0] for query_name in sorted(sam_fragments.keys()): fragment = sam_fragments[query_name] (ref_seq_calc, alt_seq_calc, alt_clip_calc) = \ - gather_split_read_evidence(fragment, breakpoint, split_slop, min_aligned) + gather_split_read_evidence( + fragment, breakpoint, split_slop, min_aligned) ref_seq += ref_seq_calc alt_seq += alt_seq_calc alt_clip += alt_clip_calc (ref_span_calc, alt_span_calc, ref_ciA_calc, ref_ciB_calc) = \ - gather_paired_end_evidence(fragment, breakpoint, min_aligned) + gather_paired_end_evidence(fragment, breakpoint, min_aligned) ref_span += ref_span_calc alt_span += alt_span_calc - ref_ciA = [ x + y for x,y in zip(ref_ciA, ref_ciA_calc)] - ref_ciB = [ x + y for x,y in zip(ref_ciB, ref_ciB_calc)] + ref_ciA = [x + y for x, y in zip(ref_ciA, ref_ciA_calc)] + ref_ciB = [x + y for x, y in zip(ref_ciB, ref_ciB_calc)] # in the absence of evidence for a particular type, ignore the reference # support for that type as well @@ -391,17 +453,19 @@ def tally_variant_read_fragments(split_slop, min_aligned, breakpoint, sam_fragme # discount any SV that's only supported by clips. 
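# (a fragment whose only alternate support is soft-clipping has neither a
# split alignment nor a discordant pair behind it; such clip-only evidence
# is treated as too weak, so it is zeroed here instead of contributing to
# the alternate counts)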
alt_clip = 0 - counts = { 'ref_seq' : ref_seq, 'alt_seq' : alt_seq, - 'ref_span' : ref_span, 'alt_span' : alt_span, - 'alt_clip' : alt_clip } + counts = {'ref_seq': ref_seq, 'alt_seq': alt_seq, + 'ref_span': ref_span, 'alt_span': alt_span, + 'alt_clip': alt_clip} if debug: items = ('ref_span', 'alt_span', 'ref_seq', 'alt_seq', 'alt_clip') cmsg = "\n".join(['{}: {}'.format(i, counts[i]) for i in items]) - logit("{} -- read fragment tally counts:\n{}".format(breakpoint['id'], cmsg)) + logit( + "{} -- read fragment tally counts:\n{}".format(breakpoint['id'], cmsg)) return counts + def bayesian_genotype(breakpoint, counts, split_weight, disc_weight, debug): is_dup = breakpoint['svtype'] == 'DUP' @@ -416,7 +480,8 @@ def bayesian_genotype(breakpoint, counts, split_weight, disc_weight, debug): # the actual bayesian calculation and decision gt_lplist = bayes_gt(QR, QA, is_dup) - best, second_best = sorted([ (i, e) for i, e in enumerate(gt_lplist) ], key=lambda(x): x[1], reverse=True)[0:2] + best, second_best = sorted([(i, e) + for i, e in enumerate(gt_lplist)], key=lambda x: x[1], reverse=True)[0:2] gt_idx = best[0] # print log probabilities of homref, het, homalt @@ -428,7 +493,8 @@ def bayesian_genotype(breakpoint, counts, split_weight, disc_weight, debug): result = blank_genotype_result() result['formats']['GL'] = ','.join(['%.0f' % x for x in gt_lplist]) - result['formats']['DP'] = int(ref_seq + alt_seq + alt_clip + ref_span + alt_span) + result['formats']['DP'] = int( + ref_seq + alt_seq + alt_clip + ref_span + alt_span) result['formats']['RO'] = int(ref_seq + ref_span) result['formats']['AO'] = int(alt_seq + alt_clip + alt_span) result['formats']['QR'] = QR @@ -453,7 +519,9 @@ def bayesian_genotype(breakpoint, counts, split_weight, disc_weight, debug): gt_sum += 0 if gt_sum > 0: gt_sum_log = math.log(gt_sum, 10) - sample_qual = abs(-10 * (gt_lplist[0] - gt_sum_log)) # phred-scaled probability site is non-reference in this sample + # phred-scaled probability site is non-reference + # in this sample + sample_qual = abs(-10 * (gt_lplist[0] - gt_sum_log)) phred_gq = min(-10 * (second_best[1] - best[1]), 200) result['formats']['GQ'] = int(phred_gq) result['formats']['SQ'] = sample_qual @@ -468,11 +536,13 @@ def bayesian_genotype(breakpoint, counts, split_weight, disc_weight, debug): result['formats']['GQ'] = '.' result['formats']['SQ'] = '.' result['formats']['GT'] = './.'
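# A minimal numeric sketch of the phred-scale arithmetic above, with
# made-up log10 genotype likelihoods (illustrative values only):
#
#   gt_lplist  = [-12.0, -1.2, -9.0]             # log10 L for 0/0, 0/1, 1/1
#   gt_sum_log = log10(10**-12.0 + 10**-1.2 + 10**-9.0)   # ~= -1.2
#   SQ = abs(-10 * (gt_lplist[0] - gt_sum_log))           # ~= 108
#   GQ = min(-10 * (-9.0 - (-1.2)), 200)                  # = 78
#
# i.e. the site is almost certainly non-reference (SQ), and the best
# genotype (0/1) beats the second-best (1/1) by 78 phred units (GQ).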
- + return result + def serial_calculate_genotype(bam, regions, library_data, active_libs, sample_name, split_slop, min_aligned, split_weight, disc_weight, breakpoint, max_reads, debug): - (read_batches, many) = gather_reads(bam, breakpoint['id'], regions, library_data, active_libs, max_reads) + (read_batches, many) = gather_reads(bam, + breakpoint['id'], regions, library_data, active_libs, max_reads) # if there are too many reads around the breakpoint if many is True: @@ -490,33 +560,40 @@ def serial_calculate_genotype(bam, regions, library_data, active_libs, sample_na debug ) - total = sum([ counts[k] for k in counts.keys() ]) + total = sum([counts[k] for k in list(counts.keys())]) if total == 0: return make_detailed_empty_genotype_result(breakpoint['id'], sample_name) - result = bayesian_genotype(breakpoint, counts, split_weight, disc_weight, debug) - return { 'variant.id' : breakpoint['id'], 'sample.name' : sample_name, 'genotype' : result } + result = bayesian_genotype( + breakpoint, counts, split_weight, disc_weight, debug) + return {'variant.id': breakpoint['id'], 'sample.name': sample_name, 'genotype': result} + def parallel_calculate_genotype(alignment_file, reference_fasta, library_data, active_libs, sample_name, split_slop, min_aligned, split_weight, disc_weight, max_reads, debug, batch_breakpoints, batch_regions, batch_number): + logit("Starting batch: {}".format(batch_number)) + logit("alignment file: {}".format(alignment_file)) bam = open_alignment_file(alignment_file, reference_fasta) genotype_results = [] (skip_count, no_read_count) = (0, 0) t0 = time.time() for breakpoint, regions in zip(batch_breakpoints, batch_regions): - (read_batches, many) = gather_reads(bam, breakpoint['id'], regions, library_data, active_libs, max_reads) + (read_batches, many) = gather_reads( + bam, breakpoint['id'], regions, library_data, active_libs, max_reads) # if there are too many reads around the breakpoint if many is True: skip_count += 1 - genotype_results.append(make_empty_genotype_result(breakpoint['id'], sample_name)) + genotype_results.append( + make_empty_genotype_result(breakpoint['id'], sample_name)) continue # if there are no reads around the breakpoint if bool(read_batches) is False: no_read_count += 1 - genotype_results.append(make_detailed_empty_genotype_result(breakpoint['id'], sample_name)) + genotype_results.append( + make_detailed_empty_genotype_result(breakpoint['id'], sample_name)) continue counts = tally_variant_read_fragments( @@ -527,18 +604,23 @@ def parallel_calculate_genotype(alignment_file, reference_fasta, library_data, a debug ) - total = sum([ counts[k] for k in counts.keys() ]) + total = sum([counts[k] for k in list(counts.keys())]) if total == 0: - genotype_results.append(make_detailed_empty_genotype_result(breakpoint['id'], sample_name)) + genotype_results.append( + make_detailed_empty_genotype_result(breakpoint['id'], sample_name)) continue - result = bayesian_genotype(breakpoint, counts, split_weight, disc_weight, debug) - genotype_results.append({ 'variant.id' : breakpoint['id'], 'sample.name' : sample_name, 'genotype' : result }) + result = bayesian_genotype( + breakpoint, counts, split_weight, disc_weight, debug) + genotype_results.append( + {'variant.id': breakpoint['id'], 'sample.name': sample_name, 'genotype': result}) t1 = time.time() - logit("Batch {} Processing Elapsed Time: {:.4f} secs".format(batch_number, t1 - t0)) + logit( + "Batch {} Processing Elapsed Time: {:.4f} secs".format(batch_number, t1 - t0)) bam.close() - return { 'genotypes' : 
genotype_results, 'skip-count' : skip_count, 'no-read-count' : no_read_count } + return {'genotypes': genotype_results, 'skip-count': skip_count, 'no-read-count': no_read_count} + def assign_genotype_to_variant(variant, sample, genotype_result): variant_id = genotype_result['variant.id'] @@ -555,24 +637,40 @@ def assign_genotype_to_variant(variant, sample, genotype_result): variant.genotype(sample.name).set_format('GT', './.') else: variant.qual += outcome['qual'] - variant.genotype(sample.name).set_format('GT', outcome['formats']['GT']) - variant.genotype(sample.name).set_format('GQ', outcome['formats']['GQ']) - variant.genotype(sample.name).set_format('SQ', outcome['formats']['SQ']) - variant.genotype(sample.name).set_format('GL', outcome['formats']['GL']) - variant.genotype(sample.name).set_format('DP', outcome['formats']['DP']) - variant.genotype(sample.name).set_format('AO', outcome['formats']['AO']) - variant.genotype(sample.name).set_format('RO', outcome['formats']['RO']) + variant.genotype(sample.name).set_format( + 'GT', outcome['formats']['GT']) + variant.genotype(sample.name).set_format( + 'GQ', outcome['formats']['GQ']) + variant.genotype(sample.name).set_format( + 'SQ', outcome['formats']['SQ']) + variant.genotype(sample.name).set_format( + 'GL', outcome['formats']['GL']) + variant.genotype(sample.name).set_format( + 'DP', outcome['formats']['DP']) + variant.genotype(sample.name).set_format( + 'AO', outcome['formats']['AO']) + variant.genotype(sample.name).set_format( + 'RO', outcome['formats']['RO']) # if detailed: - variant.genotype(sample.name).set_format('AS', outcome['formats']['AS']) - variant.genotype(sample.name).set_format('ASC', outcome['formats']['ASC']) - variant.genotype(sample.name).set_format('RS', outcome['formats']['RS']) - variant.genotype(sample.name).set_format('AP', outcome['formats']['AP']) - variant.genotype(sample.name).set_format('RP', outcome['formats']['RP']) - variant.genotype(sample.name).set_format('QR', outcome['formats']['QR']) - variant.genotype(sample.name).set_format('QA', outcome['formats']['QA']) - variant.genotype(sample.name).set_format('AB', outcome['formats']['AB']) + variant.genotype(sample.name).set_format( + 'AS', outcome['formats']['AS']) + variant.genotype(sample.name).set_format( + 'ASC', outcome['formats']['ASC']) + variant.genotype(sample.name).set_format( + 'RS', outcome['formats']['RS']) + variant.genotype(sample.name).set_format( + 'AP', outcome['formats']['AP']) + variant.genotype(sample.name).set_format( + 'RP', outcome['formats']['RP']) + variant.genotype(sample.name).set_format( + 'QR', outcome['formats']['QR']) + variant.genotype(sample.name).set_format( + 'QA', outcome['formats']['QA']) + variant.genotype(sample.name).set_format( + 'AB', outcome['formats']['AB']) return variant + def genotype_serial(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug): # initializations bnd_cache = {} @@ -587,7 +685,8 @@ def genotype_serial(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_qu v = vline.rstrip().split('\t') variant = Variant(v, src_vcf) if i % 1000 == 0: - logit("[ {} | {} ] Processing variant {}".format(i, total_variants, variant.var_id)) + logit("[ {} | {} ] Processing variant {}".format( + i, total_variants, variant.var_id)) if not sum_quals: variant.qual = 0 @@ -626,7 +725,7 @@ def genotype_serial(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_qu continue result = serial_calculate_genotype( - sample.bam, + sample.bam, 
get_breakpoint_regions(breakpoints, sample, z), sample.rg_to_lib, sample.active_libs, @@ -650,6 +749,7 @@ def genotype_serial(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_qu variant2.genotype = variant.genotype variant2.write(out_vcf) + def apply_genotypes_to_vcf(src_vcf, out_vcf, genotypes, sample, sum_quals): # initializations bnd_cache = {} @@ -706,6 +806,7 @@ variant2.genotype = variant.genotype variant2.write(out_vcf) + def genotype_parallel(src_vcf, out_vcf, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug, cores, breakpoint_batch_size, ref_fasta): # cleanup unused library attributes @@ -717,19 +818,22 @@ breakpoints = collect_breakpoints(src_vcf) logit("Number of breakpoints/SVs to process: {}".format(len(breakpoints))) logit("Collecting regions") - regions = [ get_breakpoint_regions(b, sample, z) for b in breakpoints ] + regions = [get_breakpoint_regions(b, sample, z) for b in breakpoints] logit("Batch breakpoints into groups of {}".format(breakpoint_batch_size)) - breakpoints_batches = list(partition_all(breakpoint_batch_size, breakpoints)) + breakpoints_batches = list( + partition_all(breakpoint_batch_size, breakpoints)) logit("Batch regions into groups of {}".format(breakpoint_batch_size)) regions_batches = list(partition_all(breakpoint_batch_size, regions)) if len(breakpoints_batches) != len(regions_batches): - raise RuntimeError("Batch error: breakpoint batches ({}) != region batches ({})".format(breakpoints_batches, regions_batches)) + raise RuntimeError( + "Batch error: breakpoint batches ({}) != region batches ({})".format(breakpoints_batches, regions_batches)) - logit("Number of batches to parallel process: {}".format(len(breakpoints_batches))) + logit( + "Number of batches to parallel process: {}".format(len(breakpoints_batches))) std_args = ( - sample.bam.filename, + os.fsdecode(sample.bam.filename), ref_fasta, sample.rg_to_lib, sample.active_libs, @@ -743,23 +847,32 @@ pool = mp.Pool(processes=cores) - results = [pool.apply_async(parallel_calculate_genotype, args=std_args + (b, r, i)) for i, (b, r) in enumerate(zip(breakpoints_batches, regions_batches))] + results = [pool.apply_async(parallel_calculate_genotype, args=std_args + (b, r, i)) + for i, (b, r) in enumerate(zip(breakpoints_batches, regions_batches))] + logit('Collecting ordered genotyping results from worker processes') results = [p.get() for p in results] logit("Finished parallel breakpoint processing") logit("Merging genotype results") - merged_genotypes = { g['variant.id'] : g for batch in results for g in batch['genotypes'] } + merged_genotypes = { + g['variant.id']: g for batch in results for g in batch['genotypes']} - total_variants_skipped = sum([ batch['skip-count'] for batch in results ]) - total_variants_with_no_reads = sum([ batch['no-read-count'] for batch in results ]) + total_variants_skipped = sum([batch['skip-count'] for batch in results]) + total_variants_with_no_reads = sum( + [batch['no-read-count'] for batch in results]) - logit("Number of variants skipped (surpassed max-reads threshold): {}".format(total_variants_skipped)) - logit("Number of variants with no reads: {}".format(total_variants_with_no_reads)) + logit( + "Number of variants skipped (surpassed max-reads threshold): {}".format(total_variants_skipped)) +
logit("Number of variants with no reads: {}".format( + total_variants_with_no_reads)) - # 2nd pass through input vcf -- apply the calculated genotypes to the variants + # 2nd pass through input vcf -- apply the calculated genotypes to the + # variants logit("Applying genotype results to vcf") - apply_genotypes_to_vcf(src_vcf, out_vcf, merged_genotypes, sample, sum_quals) + apply_genotypes_to_vcf( + src_vcf, out_vcf, merged_genotypes, sample, sum_quals) logit("All Done!") + def sso_genotype(bam_string, vcf_in, vcf_out, @@ -783,12 +896,13 @@ def sso_genotype(bam_string, full_bam_path = os.path.abspath(bam_string) ensure_valid_alignment_file(full_bam_path) - sample = setup_sample(full_bam_path, lib_info_path, ref_fasta, num_samp, min_aligned) + sample = setup_sample( + full_bam_path, lib_info_path, ref_fasta, num_samp, min_aligned) dump_library_metrics(lib_info_path, sample) # set variables for genotyping z = 3 - split_slop = 3 # amount of slop around breakpoint to count splitters + split_slop = 3 # amount of slop around breakpoint to count splitters with tempdir() as scratchdir: logit("Temporary scratch directory: {}".format(scratchdir)) @@ -802,24 +916,28 @@ def sso_genotype(bam_string, if cores is None: logit("Genotyping Input VCF (Serial Mode)") # pass through input vcf -- perform actual genotyping - genotype_serial(src_vcf, vcf_out, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug) + genotype_serial(src_vcf, vcf_out, sample, z, split_slop, + min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug) else: logit("Genotyping Input VCF (Parallel Mode)") - genotype_parallel(src_vcf, vcf_out, sample, z, split_slop, min_aligned, sum_quals, split_weight, disc_weight, max_reads, debug, cores, batch_size, ref_fasta) - + genotype_parallel( + src_vcf, vcf_out, sample, z, split_slop, min_aligned, sum_quals, + split_weight, disc_weight, max_reads, debug, cores, batch_size, ref_fasta) sample.close() # -------------------------------------- # main function + def main(): # parse the command line args args = get_args() if args.split_bam is not None: - sys.stderr.write('Warning: --split_bam (-S) is deprecated. Ignoring %s.\n' % args.split_bam) + sys.stderr.write( + 'Warning: --split_bam (-S) is deprecated. 
Ignoring %s.\n' % args.split_bam) # call primary function sso_genotype(args.bam, @@ -840,10 +958,11 @@ def main(): # -------------------------------------- # command-line/console entrypoint + def cli(): try: sys.exit(main()) - except IOError, e: + except IOError as e: if e.errno != 32: # ignore SIGPIPE raise diff --git a/svtyper/statistics.py b/svtyper/statistics.py index 51cc3cd..06542cc 100644 --- a/svtyper/statistics.py +++ b/svtyper/statistics.py @@ -6,13 +6,15 @@ # ================================================== # efficient combinatorial function to handle extremely large numbers + + def log_choose(n, k): r = 0.0 # swap for efficiency if k is more than half of n if k * 2 > n: k = n - k - for d in xrange(1,k+1): + for d in range(1, k + 1): r += math.log(n, 10) r -= math.log(d, 10) n -= 1 @@ -20,31 +22,41 @@ return r # return the genotype and log10 p-value + + def bayes_gt(ref, alt, is_dup): - # probability of seeing an alt read with true genotype of of hom_ref, het, hom_alt respectively - if is_dup: # specialized logic to handle non-destructive events such as duplications - p_alt = [1e-2, 0.2, 1/3.0] + # probability of seeing an alt read with true genotype of hom_ref, het, + # hom_alt respectively + if is_dup: # specialized logic to handle non-destructive events such as duplications + p_alt = [1e-2, 0.2, 1 / 3.0] else: p_alt = [1e-3, 0.5, 0.9] total = ref + alt log_combo = log_choose(total, alt) - lp_homref = log_combo + alt * math.log(p_alt[0], 10) + ref * math.log(1 - p_alt[0], 10) - lp_het = log_combo + alt * math.log(p_alt[1], 10) + ref * math.log(1 - p_alt[1], 10) - lp_homalt = log_combo + alt * math.log(p_alt[2], 10) + ref * math.log(1 - p_alt[2], 10) + lp_homref = log_combo + alt * \ math.log(p_alt[0], 10) + ref * math.log(1 - p_alt[0], 10) + lp_het = log_combo + alt * \ math.log(p_alt[1], 10) + ref * math.log(1 - p_alt[1], 10) + lp_homalt = log_combo + alt * \ math.log(p_alt[2], 10) + ref * math.log(1 - p_alt[2], 10) return (lp_homref, lp_het, lp_homalt) # get the number of entries in the set + + def countRecords(myCounter): numRecords = sum(myCounter.values()) return numRecords # median is approx 50th percentile, except when it is between # two values in which case it's the mean of them.
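# Worked example for bayes_gt() above, with hypothetical counts of
# ref=10 and alt=8 weighted supporting reads at a non-duplication site:
#
#   log_combo = log_choose(18, 8)                      # log10(43758) ~= 4.64
#   lp_homref = 4.64 + 8*log10(1e-3) + 10*log10(0.999) # ~= -19.36
#   lp_het    = 4.64 + 8*log10(0.5)  + 10*log10(0.5)   # ~= -0.78
#   lp_homalt = 4.64 + 8*log10(0.9)  + 10*log10(0.1)   # ~= -5.73
#
# so the heterozygous genotype has by far the largest log10 likelihood,
# as expected for a roughly balanced ref/alt split.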
+ + def median(myCounter): - #length is the number of bases we're looking at + # length is the number of bases we're looking at numEntries = countRecords(myCounter) # the ordinal value of the middle element @@ -75,6 +87,8 @@ def median(myCounter): return v # calculate upper median absolute deviation + + def upper_mad(myCounter, myMedian): residCounter = Counter() for x in myCounter: @@ -83,6 +97,8 @@ def upper_mad(myCounter, myMedian): return median(residCounter) # sum of the entries + + def sumRecords(myCounter): mySum = 0.0 for c in myCounter: @@ -93,6 +109,8 @@ def sumRecords(myCounter): # length of the feature (chromosome or genome) # for x percentile, x% of the elements in the set are # <= the output value + + def mean(myCounter): # the number of total entries in the set is the # sum of the occurrences for each value @@ -104,6 +122,7 @@ def mean(myCounter): u = sumRecords(myCounter) / numRecords return u + def stdev(myCounter): # the number of total entries in the set is the # sum of the occurrences for each value @@ -119,4 +138,3 @@ def stdev(myCounter): myVariance = float(sumVar) / numRecords stdev = myVariance**(0.5) return stdev - diff --git a/svtyper/utils.py b/svtyper/utils.py index e80793e..9f2bae3 100644 --- a/svtyper/utils.py +++ b/svtyper/utils.py @@ -1,6 +1,14 @@ -from __future__ import print_function -import sys, time, datetime, os, contextlib, tempfile, shutil, json, re + +import sys +import time +import datetime +import os +import contextlib +import tempfile +import shutil +import json +import re from functools import wraps from svtyper.parsers import SamFragment, Vcf @@ -10,6 +18,8 @@ # ================================================== # write read to BAM file, checking whether read is already written + + def write_alignment(read, bam, written_reads, is_alt=None): read.query_sequence = None read_hash = (read.query_name, read.flag) @@ -22,17 +32,19 @@ def write_alignment(read, bam, written_reads, is_alt=None): return written_reads # dump the sample and library info to a file + + def write_sample_json(sample_list, lib_info_file): lib_info = {} for sample in sample_list: s = {} s['sample_name'] = sample.name - s['bam'] = sample.bam.filename + s['bam'] = os.fsdecode(sample.bam.filename) s['libraryArray'] = [] s['mapped'] = sample.bam.mapped s['unmapped'] = sample.bam.unmapped - for lib in sample.lib_dict.values(): + for lib in list(sample.lib_dict.values()): l = {} l['library_name'] = lib.name l['readgroups'] = lib.readgroups @@ -53,16 +65,21 @@ def write_sample_json(sample_list, lib_info_file): # ================================================== # Sorting Functions # ================================================== + + def sort_regions(region): - str_region = [ str(i) for i in region ] + str_region = [str(i) for i in region] key = '--'.join([str_region[i] for i in (0, 1, 3, 4)]) return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', key)] + def sort_reads(read): - attrs = ('reference_name', 'reference_start', 'reference_end', 'query_name') + attrs = ('reference_name', 'reference_start', + 'reference_end', 'query_name') key = '--'.join([str(getattr(read, i)) for i in attrs]) return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', key)] + def sort_chroms(chrom): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', chrom)] @@ -71,6 +88,8 @@ def sort_chroms(chrom): # ================================================== # get the non-phred-scaled mapq of a read + + def prob_mapq(read): return 1 - 10 ** (-read.mapping_quality / 10.0) @@ -78,6 
+97,7 @@ def prob_mapq(read): # logging # ================================================== + def logit(msg): ts = time.strftime("[ %Y-%m-%d %T ]", datetime.datetime.now().timetuple()) fullmsg = "{} {}".format(ts, msg) @@ -88,10 +108,12 @@ # temporary directory handling # ================================================== + def is_lsf_job(): rv = 'LSB_JOBID' in os.environ return rv + @contextlib.contextmanager def cd(newdir, cleanup=lambda: True): prevdir = os.getcwd() @@ -102,6 +124,7 @@ def cd(newdir, cleanup=lambda: True): os.chdir(prevdir) cleanup() + @contextlib.contextmanager def tempdir(): root = '/tmp' @@ -127,6 +150,7 @@ def vcf_variants(vcf_file): if not line.startswith('#'): yield line + def vcf_headers(vcf_file): with open(vcf_file, 'r') as f: for line in f: @@ -135,6 +159,7 @@ else: break + def vcf_samples(vcf_file): samples = [] with open(vcf_file, 'r') as f: @@ -147,5 +172,7 @@ # ================================================== # system helpers # ================================================== + + def die(msg): sys.exit(msg) diff --git a/tests/test_singlesample.py b/tests/test_singlesample.py index e56d042..e67948f 100644 --- a/tests/test_singlesample.py +++ b/tests/test_singlesample.py @@ -1,6 +1,8 @@ from .context import singlesample as s -import unittest, os, subprocess +import unittest +import os +import subprocess HERE = os.path.dirname(__file__) in_vcf = os.path.join(HERE, "data/example.vcf") @@ -9,7 +11,9 @@ out_vcf = os.path.join(HERE, "data/out.vcf") expected_out_vcf = os.path.join(HERE, "data/example.gt.vcf") + class TestIntegration(unittest.TestCase): + def setUp(self): pass @@ -34,13 +38,16 @@ def test_serial_integration(self): cores=None, batch_size=1000) - fail_msg = "did not find output vcf '{}' after running sv_genotype".format(out_vcf) + fail_msg = "did not find output vcf '{}' after running sv_genotype".format( + out_vcf) self.assertTrue(os.path.exists(out_vcf), fail_msg) fail_msg = ("output vcf '{}' " "did not match expected " "output vcf '{}'").format(out_vcf, expected_out_vcf) - self.assertTrue(self.diff(), fail_msg) + # compare the genotyped output to the expected vcf + if not self.diff(): + self.fail(fail_msg) def test_parallel_integration(self): with open(in_vcf, "r") as inf, open(out_vcf, "w") as outf: @@ -59,7 +66,8 @@ cores=1, batch_size=1000) - fail_msg = "did not find output vcf '{}' after running sv_genotype".format(out_vcf) + fail_msg = "did not find output vcf '{}' after running sv_genotype".format( + out_vcf) self.assertTrue(os.path.exists(out_vcf), fail_msg) fail_msg = ("output vcf '{}' " diff --git a/tests/test_svtyper.py b/tests/test_svtyper.py index 5463562..70bd977 100644 --- a/tests/test_svtyper.py +++ b/tests/test_svtyper.py @@ -1,7 +1,9 @@ from .context import parsers as p from .context import classic -import unittest, os, subprocess +import unittest +import os +import subprocess HERE = os.path.dirname(__file__) in_vcf = os.path.join(HERE, "data/example.vcf") @@ -10,26 +12,30 @@ out_vcf = os.path.join(HERE, "data/out.vcf") expected_out_vcf = os.path.join(HERE, "data/example.gt.vcf") + class TestCigarParsing(unittest.TestCase): + def test_cigarstring_to_tuple(self): string1 = '5H3S2D1N5M3I2P2X1=' self.assertEqual(p.SplitRead.cigarstring_to_tuple(string1), - [(5, 5), (4, 3), (2, 2), (3, 1), - (0, 5), (1, 3), (6, 2), (8, 2), - (7, 1)]) + [(5, 5), (4, 3), (2, 2), (3, 1), + (0, 5), (1, 3), (6, 2), (8, 2), + (7, 1)]) def
test_get_query_pos_from_cigar(self): # forward cigar_string = '2S3M1D2M2I3M3S' cigar = p.SplitRead.cigarstring_to_tuple(cigar_string) - query_pos = p.SplitRead.SplitPiece.get_query_pos_from_cigar(cigar, True) + query_pos = p.SplitRead.SplitPiece.get_query_pos_from_cigar( + cigar, True) self.assertEqual(query_pos.query_start, 3) self.assertEqual(query_pos.query_end, 13) self.assertEqual(query_pos.query_length, 15) # get_query_pos_from_cigar currently modifies the cigar list in place. # that's why the code below doesn't work as intended. - query_pos = p.SplitRead.SplitPiece.get_query_pos_from_cigar(cigar, False) + query_pos = p.SplitRead.SplitPiece.get_query_pos_from_cigar( + cigar, False) self.assertEqual(query_pos.query_start, 2) self.assertEqual(query_pos.query_end, 12) self.assertEqual(query_pos.query_length, 15) @@ -37,25 +43,34 @@ def test_get_query_pos_from_cigar(self): def test_get_reference_end_from_cigar(self): cigar_string = '2S5M3D2M3S' cigar = p.SplitRead.cigarstring_to_tuple(cigar_string) - self.assertEqual(p.SplitRead.get_reference_end_from_cigar(1, cigar), 11) + self.assertEqual( + p.SplitRead.get_reference_end_from_cigar(1, cigar), 11) def test_get_start_diagonal(self): cigar_string = '2S5M3D1I1M3S' - split_piece = p.SplitRead.SplitPiece(1, 25, True, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) + split_piece = p.SplitRead.SplitPiece( + 1, 25, True, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) self.assertEqual(p.SplitRead.get_start_diagonal(split_piece), 23) - split_piece2 = p.SplitRead.SplitPiece(1, 25, False, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) + split_piece2 = p.SplitRead.SplitPiece( + 1, 25, False, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) self.assertEqual(p.SplitRead.get_start_diagonal(split_piece2), 23) def test_get_end_diagonal(self): cigar_string = '2S5M3D2I1M3S' - split_piece = p.SplitRead.SplitPiece(1, 25, True, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) + split_piece = p.SplitRead.SplitPiece( + 1, 25, True, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) split_piece.set_reference_end(34) - self.assertEqual(p.SplitRead.get_end_diagonal(split_piece), 34 - (2 + 8)) + self.assertEqual( + p.SplitRead.get_end_diagonal(split_piece), 34 - (2 + 8)) - split_piece2 = p.SplitRead.SplitPiece(1, 25, False, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) + split_piece2 = p.SplitRead.SplitPiece( + 1, 25, False, p.SplitRead.cigarstring_to_tuple(cigar_string), 60) split_piece2.set_reference_end(34) - self.assertEqual(p.SplitRead.get_end_diagonal(split_piece2), 34 - (2 + 8)) + self.assertEqual( + p.SplitRead.get_end_diagonal(split_piece2), 34 - (2 + 8)) + class TestIntegration(unittest.TestCase): + def setUp(self): pass @@ -66,20 +81,21 @@ def tearDown(self): def test_integration(self): with open(in_vcf, "r") as inf, open(out_vcf, "w") as outf: classic.sv_genotype(bam_string=in_bam, - vcf_in=inf, - vcf_out=outf, - min_aligned=20, - split_weight=1, - disc_weight=1, - num_samp=1000000, - lib_info_path=lib_info_json, - debug=False, - alignment_outpath=None, - ref_fasta=None, - sum_quals=False, - max_reads=None) - - fail_msg = "did not file output vcf '{}' after running sv_genotype".format(out_vcf) + vcf_in=inf, + vcf_out=outf, + min_aligned=20, + split_weight=1, + disc_weight=1, + num_samp=1000000, + lib_info_path=lib_info_json, + debug=False, + alignment_outpath=None, + ref_fasta=None, + sum_quals=False, + max_reads=None) + + fail_msg = "did not find output vcf '{}' after running sv_genotype".format( + out_vcf)
self.assertTrue(os.path.exists(out_vcf), fail_msg) fail_msg = ("output vcf '{}' "