diff --git a/Makefile b/Makefile index 24b10f3..312a90a 100644 --- a/Makefile +++ b/Makefile @@ -24,6 +24,7 @@ all: cd benchmarks/pileup; $(MAKE) CC=$(CC) arch=$(ARCH) VTUNE_HOME=$(VTUNE_HOME) cd benchmarks/kmer-cnt; $(MAKE) CXX=$(CXX) arch=$(ARCH) VTUNE_HOME=$(VTUNE_HOME) cd benchmarks/grm/2.0/build_dynamic; $(MAKE) CC=$(CC) CXX=$(CXX) arch=$(ARCH) VTUNE_HOME=$(VTUNE_HOME) MKLROOT=$(MKLROOT) MKL_IOMP5_DIR=$(MKL_IOMP5_DIR) #needs MKL + cd benchmarks/wfa; $(MAKE) CXX=$(CXX) arch=$(ARCH) VTUNE_HOME=$(VTUNE_HOME) gpu: cd benchmarks/abea; $(MAKE) CUDA_LIB=$(CUDA_LIB) @@ -41,3 +42,4 @@ clean: cd benchmarks/pileup; $(MAKE) clean cd benchmarks/kmer-cnt; $(MAKE) clean cd benchmarks/grm/2.0/build_dynamic; $(MAKE) clean + cd benchmarks/wfa; $(MAKE) clean diff --git a/benchmarks/bsw/Makefile b/benchmarks/bsw/Makefile index fbda33f..e6e763f 100644 --- a/benchmarks/bsw/Makefile +++ b/benchmarks/bsw/Makefile @@ -56,7 +56,8 @@ else ifneq ($(arch),) ARCH_FLAGS=$(arch) endif -CXXFLAGS= -DSORT_PAIRS -DENABLE_PREFETCH -DBWA_OTHER_ELE=0 -O3 -std=c++11 -fopenmp $(ARCH_FLAGS) #-mtune=native -march=native +# fno-strict-aliasing is required to prevent errors in vector code. +CXXFLAGS= -DSORT_PAIRS -DENABLE_PREFETCH -DBWA_OTHER_ELE=0 -O3 -std=c++11 -fno-strict-aliasing -fopenmp $(ARCH_FLAGS) #-mtune=native -march=native #VTUNE_HOME= /opt/intel/oneapi/vtune/2021.1.1 INCLUDES= LIBS= -fopenmp -lz -ldl @@ -83,5 +84,5 @@ clean: # DO NOT DELETE -main_banded.o: bandedSWA.h macro.h utils.h -bandedSWA.o: bandedSWA.h macro.h utils.h +main_banded.o: bandedSWA.h macro.h utils.h sse2sve.h +bandedSWA.o: bandedSWA.h macro.h utils.h sse2sve.h diff --git a/benchmarks/bsw/bandedSWA.cpp b/benchmarks/bsw/bandedSWA.cpp index eb7648b..a2c8f80 100644 --- a/benchmarks/bsw/bandedSWA.cpp +++ b/benchmarks/bsw/bandedSWA.cpp @@ -4864,3 +4864,1502 @@ void BandedPairWiseSW::smithWaterman128_8(uint8_t seq1SoA[], } #endif + +#if (__ARM_FEATURE_SVE) + +inline void sortPairsLen(SeqPair *pairArray, int32_t count, SeqPair *tempArray, + int16_t *hist) +{ + int32_t i; + for (i = 0; i <= MAX_SEQ_LEN16; ++i) { + hist[i] = 0; + } + + // __m128i zero128 = _mm_setzero_si128(); + // for(i = 0; i <= MAX_SEQ_LEN16; i+=8) + // { + // _mm_store_si128((__m128i *)(hist + i), zero128); + // } + + for(i = 0; i < count; i++) + { + SeqPair sp = pairArray[i]; + hist[sp.len1]++; + } + + int32_t cumulSum = 0; + for(i = 0; i <= MAX_SEQ_LEN16; i++) + { + int32_t cur = hist[i]; + hist[i] = cumulSum; + cumulSum += cur; + } + + for(i = 0; i < count; i++) + { + SeqPair sp = pairArray[i]; + int32_t pos = hist[sp.len1]; + tempArray[pos] = sp; + hist[sp.len1]++; + } + + for(i = 0; i < count; i++) { + pairArray[i] = tempArray[i]; + } +} + +inline void sortPairsId(SeqPair *pairArray, int32_t first, + int32_t count, SeqPair *tempArray) +{ + + int32_t i; + + for(i = 0; i < count; i++) + { + SeqPair sp = pairArray[i]; + int32_t pos = sp.id - first; + tempArray[pos] = sp; + } + + for(i = 0; i < count; i++) + pairArray[i] = tempArray[i]; +} + + + +// SSE2 +#define PFD 2 +void BandedPairWiseSW::getScores16(SeqPair *pairArray, + uint8_t *seqBufRef, + uint8_t *seqBufQer, + int32_t numPairs, + uint16_t numThreads, + int32_t w) +{ + smithWatermanBatchWrapper16(pairArray, seqBufRef, + seqBufQer, numPairs, + numThreads, w); + +#if MAXI + for (int l=0; l roundNumPairs) last = roundNumPairs; + sortPairsLen(pairArray + first, last - first, myTempArray, myHist); + } + } + _mm_free(hist); +#endif + +#if RDT + st3 = ___rdtsc(); +#endif + + int eb = end_bonus; +// #pragma omp parallel num_threads(numThreads) + { + int32_t i; + uint16_t tid = 0; + uint16_t *mySeq1SoA = seq1SoA + tid * MAX_SEQ_LEN16 * SIMD_WIDTH16; + uint16_t *mySeq2SoA = seq2SoA + tid * MAX_SEQ_LEN16 * SIMD_WIDTH16; + assert(mySeq1SoA != NULL && mySeq2SoA != NULL); + + uint8_t *seq1; + uint8_t *seq2; + uint16_t h0[SIMD_WIDTH16] __attribute__((aligned(256))); + uint16_t qlen[SIMD_WIDTH16] __attribute__((aligned(256))); + int32_t bsize = 0; + + int16_t *H1 = H16_ + tid * SIMD_WIDTH16 * MAX_SEQ_LEN16; + int16_t *H2 = H16__ + tid * SIMD_WIDTH16 * MAX_SEQ_LEN16; + + __m128i zero128 = _mm_setzero_si128(); + __m128i e_ins128 = _mm_set1_epi16(e_ins); + __m128i oe_ins128 = _mm_set1_epi16(o_ins + e_ins); + __m128i o_del128 = _mm_set1_epi16(o_del); + __m128i e_del128 = _mm_set1_epi16(e_del); + __m128i eb_ins128 = _mm_set1_epi16(eb - o_ins); + __m128i eb_del128 = _mm_set1_epi16(eb - o_del); + + int16_t max = 0; + if (max < w_match) max = w_match; + if (max < w_mismatch) max = w_mismatch; + if (max < w_ambig) max = w_ambig; + + int nstart = 0, nend = numPairs; + +// #pragma omp for schedule(dynamic, 128) + for(i = nstart; i < nend; i+=SIMD_WIDTH16) + { + int32_t j, k; + uint16_t maxLen1 = 0; + uint16_t maxLen2 = 0; + bsize = w; + + for(j = 0; j < SIMD_WIDTH16; j++) + { + { // prefetch block + SeqPair spf = pairArray[i + j + PFD]; + _mm_prefetch((const char*) seqBufRef + (int64_t)spf.idr, _MM_HINT_NTA); + _mm_prefetch((const char*) seqBufRef + (int64_t)spf.idr + 64, _MM_HINT_NTA); + } + SeqPair sp = pairArray[i + j]; + h0[j] = sp.h0; + seq1 = seqBufRef + (int64_t)sp.idr; + + for(k = 0; k < sp.len1; k++) + { + mySeq1SoA[k * SIMD_WIDTH16 + j] = (seq1[k] == AMBIG?0xFFFF:seq1[k]); + H2[k * SIMD_WIDTH16 + j] = 0; + } + qlen[j] = sp.len2 * max; + if(maxLen1 < sp.len1) maxLen1 = sp.len1; + } + + for(j = 0; j < SIMD_WIDTH16; j++) + { + SeqPair sp = pairArray[i + j]; + for(k = sp.len1; k <= maxLen1; k++) //removed "=" + { + mySeq1SoA[k * SIMD_WIDTH16 + j] = DUMMY1; + H2[k * SIMD_WIDTH16 + j] = DUMMY1; + } + } +//-------------------- + __m128i h0_128 = _mm_load_si128((__m128i*) h0); + _mm_store_si128((__m128i *) H2, h0_128); + __m128i tmp128 = _mm_sub_epi16(h0_128, o_del128); + + for(k = 1; k < maxLen1; k++) + { + tmp128 = _mm_sub_epi16(tmp128, e_del128); + __m128i tmp128_ = _mm_max_epi16(tmp128, zero128); + _mm_store_si128((__m128i *)(H2 + k* SIMD_WIDTH16), tmp128_); + } +//------------------- + for(j = 0; j < SIMD_WIDTH16; j++) + { + { // prefetch block + SeqPair spf = pairArray[i + j + PFD]; + _mm_prefetch((const char*) seqBufQer + (int64_t)spf.idq, _MM_HINT_NTA); + _mm_prefetch((const char*) seqBufQer + (int64_t)spf.idq + 64, _MM_HINT_NTA); + } + + SeqPair sp = pairArray[i + j]; + seq2 = seqBufQer + (int64_t)sp.idq; + for(k = 0; k < sp.len2; k++) + { + mySeq2SoA[k * SIMD_WIDTH16 + j] = (seq2[k]==AMBIG?0xFFFF:seq2[k]); + H1[k * SIMD_WIDTH16 + j] = 0; + } + if(maxLen2 < sp.len2) maxLen2 = sp.len2; + } + + for(j = 0; j < SIMD_WIDTH16; j++) + { + SeqPair sp = pairArray[i + j]; + for(k = sp.len2; k <= maxLen2; k++) + { + mySeq2SoA[k * SIMD_WIDTH16 + j] = DUMMY2; + H1[k * SIMD_WIDTH16 + j] = 0; + } + } +//------------------------ + _mm_store_si128((__m128i *) H1, h0_128); + svbool_t cmp128 = _mm_cmpgt_epi16(h0_128, oe_ins128); + tmp128 = _mm_sub_epi16(h0_128, oe_ins128); + + tmp128 = _mm_blend_epi16(zero128, tmp128, cmp128); + _mm_store_si128((__m128i *) (H1 + SIMD_WIDTH16), tmp128); + for(k = 2; k < maxLen2; k++) + { + __m128i h1_128 = tmp128; + tmp128 = _mm_sub_epi16(h1_128, e_ins128); + tmp128 = _mm_max_epi16(tmp128, zero128); + _mm_store_si128((__m128i *)(H1 + k*SIMD_WIDTH16), tmp128); + } +//------------------------ + uint16_t myband[SIMD_WIDTH16] __attribute__((aligned(256))); + uint16_t temp[SIMD_WIDTH16] __attribute__((aligned(256))); + { + __m128i qlen128 = _mm_load_si128((__m128i *) qlen); + __m128i sum128 = _mm_add_epi16(qlen128, eb_ins128); + _mm_store_si128((__m128i *) temp, sum128); + for (int l=0; l 1? max_ins : 1; + myband[l] = min_(bsize, max_ins); + } + sum128 = _mm_add_epi16(qlen128, eb_del128); + _mm_store_si128((__m128i *) temp, sum128); + for (int l=0; l 1? max_ins : 1; + myband[l] = min_(myband[l], max_ins); + bsize = bsize < myband[l] ? myband[l] : bsize; + } + } + + smithWaterman128_16(mySeq1SoA, + mySeq2SoA, + maxLen1, + maxLen2, + pairArray + i, + h0, + tid, + numPairs, + zdrop, + bsize, + qlen, + myband); + } + } + +#if RDT + st4 = ___rdtsc(); +#endif + +#if SORT_PAIRS // disbaled in bwa-mem2 (only used in separate benchmark bsw code) + { + // Sort the sequences according to increasing order of id +#pragma omp parallel num_threads(numThreads) + { + int32_t tid = omp_get_thread_num(); + SeqPair *myTempArray = tempArray + tid * SORT_BLOCK_SIZE; + +#pragma omp for + for(ii = 0; ii < roundNumPairs; ii+=SORT_BLOCK_SIZE) + { + int32_t first, last; + first = ii; + last = ii + SORT_BLOCK_SIZE; + if(last > roundNumPairs) last = roundNumPairs; + sortPairsId(pairArray + first, first, last - first, myTempArray); + } + } + _mm_free(tempArray); + } +#endif + +#if RDT + st5 = ___rdtsc(); + setupTicks += st2 - st1; + sort1Ticks += st3 - st2; + swTicks += st4 - st3; + sort2Ticks += st5 - st4; +#endif + // free mem + _mm_free(seq1SoA); + _mm_free(seq2SoA); + + return; +} + +void BandedPairWiseSW::smithWaterman128_16(uint16_t seq1SoA[], + uint16_t seq2SoA[], + uint16_t nrow, + uint16_t ncol, + SeqPair *p, + uint16_t h0[], + uint16_t tid, + int32_t numPairs, + int zdrop, + int32_t w, + uint16_t qlen[], + uint16_t myband[]) +{ + + __m128i match128 = _mm_set1_epi16(this->w_match); + __m128i mismatch128 = _mm_set1_epi16(this->w_mismatch); + __m128i w_ambig_128 = _mm_set1_epi16(this->w_ambig); // ambig penalty + + __m128i e_del128 = _mm_set1_epi16(this->e_del); + __m128i oe_del128 = _mm_set1_epi16(this->o_del + this->e_del); + __m128i e_ins128 = _mm_set1_epi16(this->e_ins); + __m128i oe_ins128 = _mm_set1_epi16(this->o_ins + this->e_ins); + + int16_t *F = F16_ + tid * SIMD_WIDTH16 * MAX_SEQ_LEN16; + int16_t *H_h = H16_ + tid * SIMD_WIDTH16 * MAX_SEQ_LEN16; + int16_t *H_v = H16__ + tid * SIMD_WIDTH16 * MAX_SEQ_LEN16; + + int16_t i, j; + + uint16_t tlen[SIMD_WIDTH16] __attribute((aligned(256))); + uint16_t tail[SIMD_WIDTH16] __attribute((aligned(256))); + uint16_t head[SIMD_WIDTH16] __attribute((aligned(256))); + + int32_t minq = 10000000; + for (int l=0; l i + w + 1) end = i + w + 1; + if (end > ncol) end = ncol; + + h10 = zero128; + if (beg == 0) + h10 = _mm_load_si128((__m128i *)(H_v + (i+1) * SIMD_WIDTH16)); + + __m128i j128 = zero128; + __m128i maxRS1 = zero128; + + __m128i i1_128 = _mm_set1_epi16(i+1); + __m128i y1_128 = zero128; + +#if RDT + uint64_t tim1 = __rdtsc(); +#endif + + __m128i i128, cache128; + __m128i phead128 = head128, ptail128 = tail128; + i128 = _mm_set1_epi16(i); + cache128 = _mm_sub_epi16(i128, myband128); + head128 = _mm_max_epi16(head128, cache128); + cache128 = _mm_add_epi16(i1_128, myband128); + tail128 = _mm_min_epu16(tail128, cache128); + tail128 = _mm_min_epu16(tail128, qlen128); + + // NEW, trimming. + svbool_t cmph = _mm_cmpeq_epi16(head128, phead128); + svbool_t cmpt = _mm_cmpeq_epi16(tail128, ptail128); + cmph = _mm_and_si128(cmph, cmpt); + + //for (int l=beg; l= minq) + { + svbool_t cmp = _mm_cmpeq_epi16(j128, qlen128); + __m128i max_gh = _mm_max_epi16(gscore, h11); + svbool_t cmp_gh = _mm_cmpgt_epi16(gscore, h11); + __m128i tmp128_1 = _mm_blend_epi16(i1_128, max_ie128, cmp_gh); + + __m128i tmp128_t = _mm_blend_epi16(max_ie128, tmp128_1, cmp); + tmp128_1 = _mm_blend_epi16(max_ie128, tmp128_t, exit0); + + max_gh = _mm_blend_epi16(gscore, max_gh, exit0); + max_gh = _mm_blend_epi16(gscore, max_gh, cmp); + + cmp = _mm_cmpgt_epi16(j128, tail128); + max_gh = _mm_blend_epi16(max_gh, gscore, cmp); + max_ie128 = _mm_blend_epi16(tmp128_1, max_ie128, cmp); + gscore = max_gh; + } + } + _mm_store_si128((__m128i *)(H_h + j * SIMD_WIDTH16), h10); + _mm_store_si128((__m128i *)(F + j * SIMD_WIDTH16), zero128); + + /* exit due to zero score by a row */ + __m128i bmaxScore128 = maxScore128; + svbool_t tmp_1 = _mm_cmpeq_epi16(maxRS1, zero128); + //uint16_t cval = _mm_movemask_epi8(tmp) & dmask16; + //if (cval == dmask16) break; + if (!svptest_any(svptrue_b16(),svnot_z(svptrue_b16(),tmp_1))) break; + + //exit0 = _mm_blend_epi16(exit0, zero128, tmp_1); + exit0 = _mm_andnot_si128(tmp_1, exit0); + + __m128i score128 = _mm_max_epi16(maxScore128, maxRS1); + maxScore128 = _mm_blend_epi16(maxScore128, score128, exit0); + + svbool_t cmp = _mm_cmpgt_epi16(maxScore128, bmaxScore128); + y128 = _mm_blend_epi16(y128, y1_128, cmp); + x128 = _mm_blend_epi16(x128, i1_128, cmp); + // max_off calculations + __m128i ab = _mm_subs_epu16(y1_128, i1_128); + __m128i ba = _mm_subs_epu16(i1_128, y1_128); + __m128i tmp = _mm_or_si128(ab, ba); + __m128i bmax_off128 = max_off128; + tmp = _mm_max_epi16(max_off128, tmp); + max_off128 = _mm_blend_epi16(bmax_off128, tmp, cmp); + + // Z-score + __m128i tmpi = _mm_sub_epi16(i1_128, x128); + __m128i tmpj = _mm_sub_epi16(y1_128, y128); + cmp = _mm_cmpgt_epi16(tmpi, tmpj); + score128 = _mm_sub_epi16(maxScore128, maxRS1); + __m128i insdel = _mm_blend_epi16(e_ins128, e_del128, cmp); + __m128i sub_a128 = _mm_sub_epi16(tmpi, tmpj); + __m128i sub_b128 = _mm_sub_epi16(tmpj, tmpi); + tmp = _mm_blend_epi16(sub_b128, sub_a128, cmp); + tmp = _mm_sub_epi16(score128, tmp); + cmp = _mm_cmpgt_epi16(tmp, zdrop128); + exit0 = _mm_andnot_si128(cmp, exit0); + + +#if RDT + prof[DP1][0] += __rdtsc() - tim1; + tim1 = __rdtsc(); +#endif + + /* Narrowing of the band */ + /* From beg */ + int l; + for (l = beg; l < end; l++) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH16)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH16)); + __m128i tmp = _mm_or_si128(f128, h128); + svbool_t tmp_1 = _mm_cmpeq_epi16(tmp, zero128); + //uint16_t val = _mm_movemask_epi8(tmp) & dmask16; + //if (val == dmask16) nbeg = l; + //else break; + //if (!svptest_any(svptrue_b16(),svnot_z(svptrue_b16(),tmp_1))) nbeg = l; + if (svcntp_b16(svptrue_b16(),tmp_1) == SIMD_WIDTH16) nbeg = l; + else break; + } + + /* From end */ + for (l = end; l >= beg; l--) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH16)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH16)); + __m128i tmp = _mm_or_si128(f128, h128); + tmp_1 = _mm_cmpeq_epi16(tmp, zero128); + //uint16_t val = _mm_movemask_epi8(tmp_1) & dmask16; + //if (val != dmask16) break; + //if (svptest_any(svptrue_b16(),svnot_z(svptrue_b16(),tmp_1))) break; + if (svcntp_b16(svptrue_b16(),tmp_1) != SIMD_WIDTH16) break; + } + nend = l + 2 < ncol? l + 2: ncol; + + svbool_t tmpb = ff128_b; + + //__m128i exit1 = _mm_xor_si128(exit0, ff128); + __m128i exit1 = svreinterpret_s64(svdup_s16_z(svnot_z(svptrue_b16(),exit0),0xFFFF)); + __m128i l128 = _mm_set1_epi16(beg); + for (l = beg; l < end; l++) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH16)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH16)); + + __m128i tmp = _mm_or_si128(f128, h128); + tmp = _mm_or_si128(tmp, exit1); + //tmp = _mm_cmpeq_epi16(tmp, zero128); + svbool_t tmp_1 = _mm_cmpeq_epi8(tmp, zero128); + //uint16_t val = _mm_movemask_epi8(tmp_1) & dmask16; + //if (val == 0x00) break; + if (!svptest_any(svptrue_b16(),tmp_1)) break; + + tmp_1 = _mm_and_si128(tmp_1,tmpb); + l128 = _mm_add_epi16(l128, one128); + head128 = _mm_blend_epi16(head128, l128, tmp_1); + + tmpb = tmp_1; + } + // _mm_store_si128((__m128i *) head, head128); + + __m128i index128 = tail128; + tmpb = ff128_b; + + l128 = _mm_set1_epi16(end); + for (l = end; l >= beg; l--) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH16)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH16)); + + __m128i tmp = _mm_or_si128(f128, h128); + tmp = _mm_or_si128(tmp, exit1); + svbool_t tmp_1 = _mm_cmpeq_epi16(tmp, zero128); + //uint16_t val = _mm_movemask_epi8(tmp_1) & dmask16; + //if (val == 0x00) break; + if (!svptest_any(svptrue_b16(),tmp_1)) break; + tmp_1 = _mm_and_si128(tmp_1,tmpb); + l128 = _mm_sub_epi16(l128, one128); + // NEW + index128 = _mm_blend_epi8(index128, l128, tmp_1); + + tmpb = tmp_1; + } + index128 = _mm_add_epi16(index128, two128); + tail128 = _mm_min_epi16(index128, qlen128); + +#if RDT + prof[DP2][0] += __rdtsc() - tim1; +#endif + } + +#if RDT + prof[DP][0] += __rdtsc() - tim; +#endif + + int16_t score[SIMD_WIDTH16] __attribute((aligned(256))); + _mm_store_si128((__m128i *) score, maxScore128); + + int16_t maxi[SIMD_WIDTH16] __attribute((aligned(256))); + _mm_store_si128((__m128i *) maxi, x128); + + int16_t maxj[SIMD_WIDTH16] __attribute((aligned(256))); + _mm_store_si128((__m128i *) maxj, y128); + + int16_t max_off_ar[SIMD_WIDTH16] __attribute((aligned(256))); + _mm_store_si128((__m128i *) max_off_ar, max_off128); + + int16_t gscore_ar[SIMD_WIDTH16] __attribute((aligned(256))); + _mm_store_si128((__m128i *) gscore_ar, gscore); + + int16_t maxie_ar[SIMD_WIDTH16] __attribute((aligned(256))); + _mm_store_si128((__m128i *) maxie_ar, max_ie128); + + for(i = 0; i < SIMD_WIDTH16; i++) + { + p[i].score = score[i]; + p[i].tle = maxi[i]; + p[i].qle = maxj[i]; + p[i].max_off = max_off_ar[i]; + p[i].gscore = gscore_ar[i]; + p[i].gtle = maxie_ar[i]; + } + + return; +} + +/********************************************************************************/ + +// #define PFD 2 // SSE2 +void BandedPairWiseSW::getScores8(SeqPair *pairArray, + uint8_t *seqBufRef, + uint8_t *seqBufQer, + int32_t numPairs, + uint16_t numThreads, + int32_t w) +{ + //assert(SIMD_WIDTH8 == 16 && SIMD_WIDTH16 == 8); + smithWatermanBatchWrapper8(pairArray, seqBufRef, seqBufQer, numPairs, numThreads, w); + + +#if MAXI + printf("Vecor code: Writing output..\n"); + for (int l=0; l roundNumPairs) last = roundNumPairs; + sortPairsLen(pairArray + first, last - first, myTempArray, myHist); + } + } + _mm_free(hist); +#endif + +#if RDT + st3 = ___rdtsc(); +#endif + + int eb = end_bonus; +// #pragma omp parallel num_threads(numThreads) + { + int32_t i; + uint16_t tid = 0; + uint8_t *mySeq1SoA = seq1SoA + tid * MAX_SEQ_LEN8 * SIMD_WIDTH8; + uint8_t *mySeq2SoA = seq2SoA + tid * MAX_SEQ_LEN8 * SIMD_WIDTH8; + assert(mySeq1SoA != NULL && mySeq2SoA != NULL); + uint8_t *seq1; + uint8_t *seq2; + uint8_t h0[SIMD_WIDTH8] __attribute__((aligned(256))); + uint8_t qlen[SIMD_WIDTH8] __attribute__((aligned(256))); + int32_t bsize = 0; + + int8_t *H1 = H8_ + tid * SIMD_WIDTH8 * MAX_SEQ_LEN8; + int8_t *H2 = H8__ + tid * SIMD_WIDTH8 * MAX_SEQ_LEN8; + + __m128i zero128 = _mm_setzero_si128(); + __m128i e_ins128 = _mm_set1_epi8(e_ins); + __m128i oe_ins128 = _mm_set1_epi8(o_ins + e_ins); + __m128i o_del128 = _mm_set1_epi8(o_del); + __m128i e_del128 = _mm_set1_epi8(e_del); + __m128i eb_ins128 = _mm_set1_epi8(eb - o_ins); + __m128i eb_del128 = _mm_set1_epi8(eb - o_del); + + int8_t max = 0; + if (max < w_match) max = w_match; + if (max < w_mismatch) max = w_mismatch; + if (max < w_ambig) max = w_ambig; + + int nstart = 0, nend = numPairs; + +// #pragma omp for schedule(dynamic, 128) + for(i = nstart; i < nend; i+=SIMD_WIDTH8) + { + int32_t j, k; + uint8_t maxLen1 = 0; + uint8_t maxLen2 = 0; + //bsize = 100; + bsize = w; + + for(j = 0; j < SIMD_WIDTH8; j++) + { + SeqPair sp = pairArray[i + j]; + h0[j] = sp.h0; + seq1 = seqBufRef + (int64_t)sp.idr; + + for(k = 0; k < sp.len1; k++) + { + mySeq1SoA[k * SIMD_WIDTH8 + j] = (seq1[k] == AMBIG?0xFF:seq1[k]); + H2[k * SIMD_WIDTH8 + j] = 0; + } + qlen[j] = sp.len2 * max; + if(maxLen1 < sp.len1) maxLen1 = sp.len1; + } + + for(j = 0; j < SIMD_WIDTH8; j++) + { + SeqPair sp = pairArray[i + j]; + for(k = sp.len1; k <= maxLen1; k++) //removed "=" + { + mySeq1SoA[k * SIMD_WIDTH8 + j] = DUMMY1; + H2[k * SIMD_WIDTH8 + j] = DUMMY1; + } + } +//-------------------- + __m128i h0_128 = _mm_load_si128((__m128i*) h0); + _mm_store_si128((__m128i *) H2, h0_128); + __m128i tmp128 = _mm_subs_epu8(h0_128, o_del128); + + for(k = 1; k < maxLen1; k++) + { + tmp128 = _mm_subs_epu8(tmp128, e_del128); + //__m128i tmp128_ = _mm_max_epi8(tmp128, zero128); //epi is not present in SSE2 + _mm_store_si128((__m128i *)(H2 + k* SIMD_WIDTH8), tmp128); + } +//------------------- + + for(j = 0; j < SIMD_WIDTH8; j++) + { + SeqPair sp = pairArray[i + j]; + // seq2 = seqBuf + (2 * (int64_t)sp.id + 1) * MAX_SEQ_LEN; + seq2 = seqBufQer + (int64_t)sp.idq; + + for(k = 0; k < sp.len2; k++) + { + mySeq2SoA[k * SIMD_WIDTH8 + j] = (seq2[k]==AMBIG?0xFF:seq2[k]); + H1[k * SIMD_WIDTH8 + j] = 0; + } + if(maxLen2 < sp.len2) maxLen2 = sp.len2; + } + + //maxLen2 = ((maxLen2 + 3) >> 2) * 4; + + for(j = 0; j < SIMD_WIDTH8; j++) + { + SeqPair sp = pairArray[i + j]; + for(k = sp.len2; k <= maxLen2; k++) + { + mySeq2SoA[k * SIMD_WIDTH8 + j] = DUMMY2; + H1[k * SIMD_WIDTH8 + j] = 0; + } + } +//------------------------ + _mm_store_si128((__m128i *) H1, h0_128); + svbool_t cmp128 = _mm_cmpgt_epi8(h0_128, oe_ins128); + tmp128 = _mm_sub_epi8(h0_128, oe_ins128); + + tmp128 = _mm_blend_epi8(zero128, tmp128, cmp128); + _mm_store_si128((__m128i *) (H1 + SIMD_WIDTH8), tmp128); + for(k = 2; k < maxLen2; k++) + { + // __m128i h1_128 = _mm_load_si128((__m128i *) (H1 + (k-1) * SIMD_WIDTH8)); + __m128i h1_128 = tmp128; + tmp128 = _mm_subs_epu8(h1_128, e_ins128); // modif + // tmp128 = _mm_max_epi8(tmp128, zero128); + _mm_store_si128((__m128i *)(H1 + k*SIMD_WIDTH8), tmp128); + } +//------------------------ + uint8_t myband[SIMD_WIDTH8] __attribute__((aligned(256))); + uint8_t temp[SIMD_WIDTH8] __attribute__((aligned(256))); + { + __m128i qlen128 = _mm_load_si128((__m128i *) qlen); + __m128i sum128 = _mm_add_epi8(qlen128, eb_ins128); + _mm_store_si128((__m128i *) temp, sum128); + for (int l=0; l 1? max_ins : 1; + myband[l] = min_(bsize, max_ins); + } + sum128 = _mm_add_epi8(qlen128, eb_del128); + _mm_store_si128((__m128i *) temp, sum128); + for (int l=0; l 1? max_ins : 1; + myband[l] = min_(myband[l], max_ins); + bsize = bsize < myband[l] ? myband[l] : bsize; + } + } + + smithWaterman128_8(mySeq1SoA, + mySeq2SoA, + maxLen1, + maxLen2, + pairArray + i, + h0, + tid, + numPairs, + zdrop, + bsize, + qlen, + myband); + } + } +#if RDT + st4 = ___rdtsc(); +#endif + +#if SORT_PAIRS // disbaled in bwa-mem2 (only used in separate benchmark bsw code) + { + // Sort the sequences according to increasing order of id +#pragma omp parallel num_threads(numThreads) + { + int32_t tid = omp_get_thread_num(); + SeqPair *myTempArray = tempArray + tid * SORT_BLOCK_SIZE; + +#pragma omp for + for(ii = 0; ii < roundNumPairs; ii+=SORT_BLOCK_SIZE) + { + int32_t first, last; + first = ii; + last = ii + SORT_BLOCK_SIZE; + if(last > roundNumPairs) last = roundNumPairs; + sortPairsId(pairArray + first, first, last - first, myTempArray); + } + } + _mm_free(tempArray); + } +#endif + +#if RDT + st5 = ___rdtsc(); + setupTicks = st2 - st1; + sort1Ticks = st3 - st2; + swTicks = st4 - st3; + sort2Ticks = st5 - st4; +#endif + + // free mem + _mm_free(seq1SoA); + _mm_free(seq2SoA); + + return; +} + +void BandedPairWiseSW::smithWaterman128_8(uint8_t seq1SoA[], + uint8_t seq2SoA[], + uint8_t nrow, + uint8_t ncol, + SeqPair *p, + uint8_t h0[], + uint16_t tid, + int32_t numPairs, + int zdrop, + int32_t w, + uint8_t qlen[], + uint8_t myband[]) +{ + + __m128i match128 = _mm_set1_epi8(this->w_match); + __m128i mismatch128 = _mm_set1_epi8(this->w_mismatch); + __m128i w_ambig_128 = _mm_set1_epi8(this->w_ambig); // ambig penalty + + __m128i e_del128 = _mm_set1_epi8(this->e_del); + __m128i oe_del128 = _mm_set1_epi8(this->o_del + this->e_del); + __m128i e_ins128 = _mm_set1_epi8(this->e_ins); + __m128i oe_ins128 = _mm_set1_epi8(this->o_ins + this->e_ins); + + int8_t *F = F8_ + tid * SIMD_WIDTH8 * MAX_SEQ_LEN8; + int8_t *H_h = H8_ + tid * SIMD_WIDTH8 * MAX_SEQ_LEN8; + int8_t *H_v = H8__ + tid * SIMD_WIDTH8 * MAX_SEQ_LEN8; + + int8_t i, j; + + uint8_t tlen[SIMD_WIDTH8]; + uint8_t tail[SIMD_WIDTH8] __attribute((aligned(256))); + uint8_t head[SIMD_WIDTH8] __attribute((aligned(256))); + + int32_t minq = 10000000; + for (int l=0; l i + w + 1) end = i + w + 1; + if (end > ncol) end = ncol; + + h10 = zero128; + if (beg == 0) + h10 = _mm_load_si128((__m128i *)(H_v + (i+1) * SIMD_WIDTH8)); + + __m128i j128 = zero128; + __m128i maxRS1 = zero128; + + __m128i i1_128 = _mm_set1_epi8(i+1); + __m128i y1_128 = zero128; + +#if RDT + uint64_t tim1 = __rdtsc(); +#endif + + // Banding + __m128i i128, cache128; + __m128i phead128 = head128, ptail128 = tail128; + i128 = _mm_set1_epi8(i); + cache128 = _mm_subs_epu8(i128, myband128); // modif + head128 = _mm_max_epu8(head128, cache128); // epi8 not present + cache128 = _mm_add_epi8(i1_128, myband128); + tail128 = _mm_min_epu8(tail128, cache128); + tail128 = _mm_min_epu8(tail128, qlen128); + + // NEW, trimming. + svbool_t cmph = _mm_cmpeq_epi8(head128, phead128); + svbool_t cmpt = _mm_cmpeq_epi8(tail128, ptail128); + cmph = _mm_and_si128(cmph, cmpt); + + //for (int l=beg; l= minq) + { + svbool_t cmp = _mm_cmpeq_epi8(j128, qlen128); + svbool_t cmp_gh = _mm_cmpgt_epi8(gscore, h11); + __m128i tmp128_1 = _mm_blend_epi8(i1_128, max_ie128, cmp_gh); + __m128i max_gh = _mm_blend_epi8(h11, gscore, cmp_gh); + + tmp128_1 = _mm_blend_epi8(max_ie128, tmp128_1, cmp); + tmp128_1 = _mm_blend_epi8(max_ie128, tmp128_1, exit0); + + max_gh = _mm_blend_epi8(gscore, max_gh, exit0); + max_gh = _mm_blend_epi8(gscore, max_gh, cmp); + + cmp = _mm_cmpgt_epi8(j128, tail128); + max_gh = _mm_blend_epi8(max_gh, gscore, cmp); + max_ie128 = _mm_blend_epi8(tmp128_1, max_ie128, cmp); + gscore = max_gh; + } + } + _mm_store_si128((__m128i *)(H_h + j * SIMD_WIDTH8), h10); + _mm_store_si128((__m128i *)(F + j * SIMD_WIDTH8), zero128); + + + /* exit due to zero score by a row */ + //uint16_t cval = 0; + __m128i bmaxScore128 = maxScore128; + svbool_t tmp_1 = _mm_cmpeq_epi8(maxRS1, zero128); + //cval = _mm_movemask_epi8(tmp_1); + //if (cval == 0xFFFF) break; + if (!svptest_any(svptrue_b8(),svnot_z(svptrue_b8(),tmp_1))) break; + + //exit0 = _mm_blend_epi8(exit0, zero128, tmp_1); + exit0 = _mm_andnot_si128(tmp_1, exit0); + + __m128i score128 = _mm_max_epu8(maxScore128, maxRS1); // epi8 not present, modif + maxScore128 = _mm_blend_epi8(maxScore128, score128, exit0); + + svbool_t cmp = _mm_cmpgt_epi8(maxScore128, bmaxScore128); + y128 = _mm_blend_epi8(y128, y1_128, cmp); + x128 = _mm_blend_epi8(x128, i1_128, cmp); + + // max_off calculations + __m128i ab = _mm_subs_epu8(y1_128, i1_128); + __m128i ba = _mm_subs_epu8(i1_128, y1_128); + __m128i tmp = _mm_or_si128(ab, ba); + + __m128i bmax_off128 = max_off128; + tmp = _mm_max_epu8(max_off128, tmp); // modif + max_off128 = _mm_blend_epi8(bmax_off128, tmp, cmp); + + // Z-score + __m128i tmpi = _mm_sub_epi8(i1_128, x128); + __m128i tmpj = _mm_sub_epi8(y1_128, y128); + cmp = _mm_cmpgt_epi8(tmpi, tmpj); + score128 = _mm_sub_epi8(maxScore128, maxRS1); + __m128i insdel = _mm_blend_epi8(e_ins128, e_del128, cmp); + __m128i sub_a128 = _mm_sub_epi8(tmpi, tmpj); + __m128i sub_b128 = _mm_sub_epi8(tmpj, tmpi); + tmp = _mm_blend_epi8(sub_b128, sub_a128, cmp); + tmp = _mm_sub_epi8(score128, tmp); + cmp = _mm_cmpgt_epi8(tmp, zdrop128); + exit0 = _mm_andnot_si128(cmp, exit0); + + +#if RDT + prof[DP1][0] += __rdtsc() - tim1; + tim1 = __rdtsc(); +#endif + + /* Narrowing of the band */ + /* From beg */ + int l; + for (l = beg; l < end; l++) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH8)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH8)); + __m128i tmp = _mm_or_si128(f128, h128); + svbool_t tmp_1 = _mm_cmpeq_epi8(tmp, zero128); + //uint16_t val = _mm_movemask_epi8(tmp_1); + //if (val == 0xFFFF) nbeg = l; + //else break; + //if (!svptest_any(svptrue_b8(),svnot_z(svptrue_b8(),tmp_1))) nbeg = l; + if (svcntp_b8(svptrue_b8(),tmp_1) == SIMD_WIDTH8) nbeg = l; + else break; + } + + /* From end */ + for (l = end; l >= beg; l--) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH8)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH8)); + __m128i tmp = _mm_or_si128(f128, h128); + svbool_t tmp_1 = _mm_cmpeq_epi8(tmp, zero128); + //uint16_t val = _mm_movemask_epi8(tmp_1); + //if (val != 0xFFFF) break; + //if (svptest_any(svptrue_b8(),svnot_z(svptrue_b8(),tmp_1))) break; + if (svcntp_b8(svptrue_b8(),tmp_1) != SIMD_WIDTH8) break; + } + // int pnend =nend; + nend = l + 2 < ncol? l + 2: ncol; + svbool_t tmpb = ff128_b; + + //__m128i exit1 = _mm_xor_si128(exit0, ff128); + __m128i exit1 = svreinterpret_s64(svdup_u8_z(svnot_z(svptrue_b8(),exit0),0xFF)); + __m128i l128 = _mm_set1_epi8(beg); + + for (l = beg; l < end; l++) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH8)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH8)); + + __m128i tmp = _mm_or_si128(f128, h128); + //tmp = _mm_or_si128(tmp, _mm_xor_si128(exit0, ff128)); + tmp = _mm_or_si128(tmp, exit1); + svbool_t tmp_1 = _mm_cmpeq_epi8(tmp, zero128); + //uint32_t val = _mm_movemask_epi8(tmp_1); + //if (val == 0x00) break; + if (!svptest_any(svptrue_b8(),tmp_1)) break; + + tmp_1 = _mm_and_si128(tmp_1,tmpb); + + l128 = _mm_add_epi8(l128, one128); + head128 = _mm_blend_epi8(head128, l128, tmp_1); + + tmpb = tmp_1; + } + + __m128i index128 = tail128; + tmpb = ff128_b; + + l128 = _mm_set1_epi8(end); + for (l = end; l >= beg; l--) + { + __m128i f128 = _mm_load_si128((__m128i *)(F + l * SIMD_WIDTH8)); + __m128i h128 = _mm_load_si128((__m128i *)(H_h + l * SIMD_WIDTH8)); + + __m128i tmp = _mm_or_si128(f128, h128); + tmp = _mm_or_si128(tmp, exit1); + svbool_t tmp_1 = _mm_cmpeq_epi8(tmp, zero128); + //uint32_t val = _mm_movemask_epi8(tmp_1); + //if (val == 0x00) break; + if (!svptest_any(svptrue_b8(),tmp_1)) break; + + tmp_1 = _mm_and_si128(tmp_1,tmpb); + l128 = _mm_sub_epi8(l128, one128); + // NEW + index128 = _mm_blend_epi8(index128, l128, tmp_1); + + tmpb = tmp_1; + } + index128 = _mm_add_epi8(index128, two128); + tail128 = _mm_min_epu8(index128, qlen128); // epi8 not present, modif + +#if RDT + prof[DP2][0] += __rdtsc() - tim1; +#endif + } + +#if RDT + prof[DP][0] += __rdtsc() - tim; +#endif + + int8_t score[SIMD_WIDTH8] __attribute((aligned(256))); + _mm_store_si128((__m128i *) score, maxScore128); + + int8_t maxi[SIMD_WIDTH8] __attribute((aligned(256))); + _mm_store_si128((__m128i *) maxi, x128); + + int8_t maxj[SIMD_WIDTH8] __attribute((aligned(256))); + _mm_store_si128((__m128i *) maxj, y128); + + int8_t max_off_ar[SIMD_WIDTH8] __attribute((aligned(256))); + _mm_store_si128((__m128i *) max_off_ar, max_off128); + + int8_t gscore_ar[SIMD_WIDTH8] __attribute((aligned(256))); + _mm_store_si128((__m128i *) gscore_ar, gscore); + + int8_t maxie_ar[SIMD_WIDTH8] __attribute((aligned(256))); + _mm_store_si128((__m128i *) maxie_ar, max_ie128); + + for(i = 0; i < SIMD_WIDTH8; i++) + { + p[i].score = score[i]; + p[i].tle = maxi[i]; + p[i].qle = maxj[i]; + p[i].max_off = max_off_ar[i]; + p[i].gscore = gscore_ar[i]; + p[i].gtle = maxie_ar[i]; + } + + return; +} + +#endif diff --git a/benchmarks/bsw/bandedSWA.h b/benchmarks/bsw/bandedSWA.h index 9678196..1823ba3 100644 --- a/benchmarks/bsw/bandedSWA.h +++ b/benchmarks/bsw/bandedSWA.h @@ -39,10 +39,13 @@ Authors: Vasimuddin Md ; Sanchit Misra -#else +#elif ((!__AVX512BW__) & (!__AVX2__) & (__SSE2__)) #include // for SSE4.1 #define __mmask8 uint8_t #define __mmask16 uint16_t +#elif (__ARM_FEATURE_SVE) +#include "sse2sve.h" +#include #endif #define MAX_SEQ_LEN_REF 256 @@ -73,8 +76,13 @@ Authors: Vasimuddin Md ; Sanchit Misra + +#if defined(__GNUC__) || defined(__clang__) + +#pragma push_macro("FORCE_INLINE") +#pragma push_macro("ALIGN_STRUCT") +#define FORCE_INLINE static inline __attribute__((always_inline)) +#define ALIGN_STRUCT(x) __attribute__((aligned(x))) + +#else + +#error "Macro name collisions may happens with unknown compiler" +#ifdef FORCE_INLINE +#undef FORCE_INLINE +#endif +#define FORCE_INLINE static inline +#ifndef ALIGN_STRUCT +#define ALIGN_STRUCT(x) __declspec(align(x)) +#endif + +#endif + +#include +#include +#include + +typedef svint64_t __m128i; + +#define _MM_HINT_NTA 0 +#define _MM_HINT_T0 1 + +FORCE_INLINE void* _mm_malloc(size_t size, size_t align) +{ + //return aligned_alloc(align, size); + return aligned_alloc(256, size); +} + +FORCE_INLINE void _mm_free(void *ptr) +{ + free(ptr); +} + +FORCE_INLINE uint64_t __rdtsc() +{ + uint64_t virtual_timer_value; + asm volatile ("mrs %0, cntvct_el0":"=r" (virtual_timer_value)); + //virtual_timer_value = (unsigned long) time(0); + return virtual_timer_value; +} + +FORCE_INLINE void _mm_prefetch(const char* rseq, int i) { + __builtin_prefetch((void*) rseq); +} + +FORCE_INLINE void _mm_prefetch(const int8_t* rseq, int i) { + __builtin_prefetch((void*) rseq); +} + +FORCE_INLINE void _mm_prefetch(const uint32_t* rseq, int i) { + __builtin_prefetch((void*) rseq); +} + +FORCE_INLINE svint64_t _mm_setzero_si128() { + return svdup_s64(0); +} + +FORCE_INLINE svint64_t _mm_set1_epi8(int8_t data) { + return svreinterpret_s64(svdup_s8(data)); +} + +FORCE_INLINE svint64_t _mm_set1_epi16(int16_t data) { + return svreinterpret_s64(svdup_s16(data)); +} + +FORCE_INLINE svint64_t _mm_blend_epi8(svint64_t a, svint64_t b, svbool_t mask) { + svint8_t a_aux = svreinterpret_s8(a); + svint8_t b_aux = svreinterpret_s8(b); + svint8_t r_aux = svsel(mask,b_aux,a_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_blend_epi16(svint64_t a, svint64_t b, svbool_t mask) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + svint16_t r_aux = svsel(mask,b_aux,a_aux); + return svreinterpret_s64(r_aux); +} + +// ----- ARITHMETIC OPS ----- + +FORCE_INLINE svint64_t _mm_add_epi8(svint64_t a, svint64_t b) { + svint8_t a_aux = svreinterpret_s8(a); + svint8_t b_aux = svreinterpret_s8(b); + svint8_t r_aux = svadd_x(svptrue_b8(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_add_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + svint16_t r_aux = svadd_x(svptrue_b16(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_adds_epu8(svint64_t a, svint64_t b) { + svuint8_t a_aux = svreinterpret_u8(a); + svuint8_t b_aux = svreinterpret_u8(b); + svuint8_t r_aux = svqadd(a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_adds_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + svint16_t r_aux = svqadd(a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_sub_epi8(svint64_t a, svint64_t b) { + svint8_t a_aux = svreinterpret_s8(a); + svint8_t b_aux = svreinterpret_s8(b); + svint8_t r_aux = svsub_x(svptrue_b8(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_sub_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + svint16_t r_aux = svsub_x(svptrue_b16(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_subs_epu8(svint64_t a, svint64_t b) { + svuint8_t a_aux = svreinterpret_u8(a); + svuint8_t b_aux = svreinterpret_u8(b); + svuint8_t r_aux = svqsub(a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_subs_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + svint16_t r_aux = svqsub(a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_subs_epu16(svint64_t a, svint64_t b) { + svuint16_t a_aux = svreinterpret_u16(a); + svuint16_t b_aux = svreinterpret_u16(b); + svuint16_t r_aux = svqsub(a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_max_epu8(svint64_t a, svint64_t b) { + svuint8_t a_aux = svreinterpret_u8(a); + svuint8_t b_aux = svreinterpret_u8(b); + svuint8_t r_aux = svmax_x(svptrue_b8(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_max_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + svint16_t r_aux = svmax_x(svptrue_b16(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_max_epu16(svint64_t a, svint64_t b) { + svuint16_t a_aux = svreinterpret_u16(a); + svuint16_t b_aux = svreinterpret_u16(b); + svuint16_t r_aux = svmax_x(svptrue_b16(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_min_epu8(svint64_t a, svint64_t b) { + svuint8_t a_aux = svreinterpret_u8(a); + svuint8_t b_aux = svreinterpret_u8(b); + svuint8_t r_aux = svmin_x(svptrue_b8(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_min_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + svint16_t r_aux = svmin_x(svptrue_b16(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +FORCE_INLINE svint64_t _mm_min_epu16(svint64_t a, svint64_t b) { + svuint16_t a_aux = svreinterpret_u16(a); + svuint16_t b_aux = svreinterpret_u16(b); + svuint16_t r_aux = svmin_x(svptrue_b16(),a_aux,b_aux); + return svreinterpret_s64(r_aux); +} + +// ----- BIT WISE OPS ----- + +FORCE_INLINE svbool_t _mm_and_si128(svbool_t a, svbool_t b) { + return svand_z(svptrue_b8(),a,b); +} + +FORCE_INLINE svint64_t _mm_and_si128(svint64_t a, svint64_t b) { + return svand_x(svptrue_b64(),a,b); +} + +FORCE_INLINE svbool_t _mm_or_si128(svbool_t a, svbool_t b) { + return svorr_z(svptrue_b8(),a,b); +} + +FORCE_INLINE svint64_t _mm_or_si128(svint64_t a, svint64_t b) { + return svorr_x(svptrue_b64(),a,b); +} + +FORCE_INLINE svint64_t _mm_xor_si128(svint64_t a, svint64_t b) { + return sveor_x(svptrue_b64(),a,b); +} + +FORCE_INLINE svint64_t _mm_andnot_si128(svint64_t a, svint64_t b) { + return svbic_x(svptrue_b64(),b,a); +} + +FORCE_INLINE svbool_t _mm_andnot_si128(svbool_t a, svbool_t b) { + return svbic_z(svptrue_b8(),b,a); +} + +// ----- CMP OPS ----- + +FORCE_INLINE svbool_t _mm_cmpeq_epi8(svint64_t a, svint64_t b) { + svint8_t a_aux = svreinterpret_s8(a); + svint8_t b_aux = svreinterpret_s8(b); + return svcmpeq(svptrue_b8(),a_aux,b_aux); +} + +FORCE_INLINE svbool_t _mm_cmpeq_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + return svcmpeq(svptrue_b16(),a_aux,b_aux); +} + +FORCE_INLINE svbool_t _mm_cmpgt_epi8(svint64_t a, svint64_t b) { + svint8_t a_aux = svreinterpret_s8(a); + svint8_t b_aux = svreinterpret_s8(b); + return svcmpgt(svptrue_b8(),a_aux,b_aux); +} + +FORCE_INLINE svbool_t _mm_cmpgt_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + return svcmpgt(svptrue_b16(),a_aux,b_aux); +} + +FORCE_INLINE svbool_t _mm_cmpge_epi16(svint64_t a, svint64_t b) { + svint16_t a_aux = svreinterpret_s16(a); + svint16_t b_aux = svreinterpret_s16(b); + return svcmpge(svptrue_b16(),a_aux,b_aux); +} + +// ----- MEM OPS ----- + +FORCE_INLINE svint64_t _mm_load_si128(svint64_t * dir) { + return svld1(svptrue_b64(),(int64_t*)dir); +} + +FORCE_INLINE void _mm_store_si128(svint64_t * dir, svint64_t reg) { + svst1(svptrue_b64(),(int64_t*)dir,reg); +} + +#if defined(__GNUC__) || defined(__clang__) +#pragma pop_macro("ALIGN_STRUCT") +#pragma pop_macro("FORCE_INLINE") +#endif + +#endif + +#endif diff --git a/benchmarks/chain/Makefile b/benchmarks/chain/Makefile index aa50562..bf7a985 100644 --- a/benchmarks/chain/Makefile +++ b/benchmarks/chain/Makefile @@ -50,9 +50,9 @@ endif #COMPILE_FLAGS = -std=c++11 -Wall -Wextra -g -O3 -fopenmp -xAVX2 -axAVX2 #VTUNE_HOME= /opt/intel/vtune_profiler COMPILE_FLAGS = -std=c++11 -Wall -Wextra -g -O3 -fopenmp $(ARCH_FLAGS) -INCLUDES = -I../../tools/minimap2 +# INCLUDES = -I../../tools/minimap2 # Space-separated pkg-config libraries used by this project -LIBS = -L../../tools/minimap2 -lminimap2 -ldl +# LIBS = -L../../tools/minimap2 -lminimap2 -ldl ifneq ($(VTUNE_HOME),) diff --git a/benchmarks/chain/src/host_kernel.cpp b/benchmarks/chain/src/host_kernel.cpp index 654ba85..01850c8 100644 --- a/benchmarks/chain/src/host_kernel.cpp +++ b/benchmarks/chain/src/host_kernel.cpp @@ -5,11 +5,13 @@ #include "omp.h" #include "host_kernel.h" #include "common.h" +#if 0 #include "minimap.h" #include "mmpriv.h" #include "kalloc.h" +#endif -static const char LogTable256[256] = { +static const signed char LogTable256[256] = { #define LT(n) n, n, n, n, n, n, n, n, n, n, n, n, n, n, n, n -1, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, LT(4), LT(5), LT(5), LT(6), LT(6), LT(6), LT(6), diff --git a/benchmarks/fast-chain/Makefile b/benchmarks/fast-chain/Makefile new file mode 100644 index 0000000..15416cd --- /dev/null +++ b/benchmarks/fast-chain/Makefile @@ -0,0 +1,98 @@ +CC = gcc +CXX = g++ + +# path # +SRC_PATH = src +BUILD_PATH = build +BIN_PATH = $(BUILD_PATH)/bin + +# executable # +BIN_NAME = chain + +# extensions # +SRC_EXT = cpp + +# code lists # +# Find all source files in the source directory, sorted by +# most recently modified +SOURCES = $(shell find $(SRC_PATH) -name '*.$(SRC_EXT)' | sort -k 1nr | cut -f2-) +# Set the object file names, with the source directory stripped +# from the path, and the build path prepended in its place +OBJECTS = $(SOURCES:$(SRC_PATH)/%.$(SRC_EXT)=$(BUILD_PATH)/%.o) +# Set the dependency files that will be used to add header dependencies +DEPS = $(OBJECTS:.o=.d) + +ifeq ($(arch),sse41) + ARCH_FLAGS=-msse4.1 +else ifeq ($(arch),avx2) + ifeq ($(CXX), icpc) + ARCH_FLAGS=-march=core-avx2 #-xCORE-AVX2 + else + ARCH_FLAGS=-mavx2 + endif +else ifeq ($(arch),avx512) + ifeq ($(CXX), icpc) + ARCH_FLAGS=-xCORE-AVX512 + else + ARCH_FLAGS=-mavx512bw + endif +else ifeq ($(arch),native) + ARCH_FLAGS=-march=native +else ifneq ($(arch),) + ## To provide a different architecture flag like -march=core-avx2. + ARCH_FLAGS=$(arch) +endif + +# Intel VTune Profiler +VTUNE_ANALYSIS=0 + +ifeq ($(VTUNE_ANALYSIS),1) + INCLUDES+=$(VTUNE_INCLUDES) + LIBS+=$(VTUNE_LDFLAGS) +endif + +COMPILE_FLAGS=-std=c++11 -Wall -Wextra -O3 -fopenmp -fno-strict-aliasing $(ARCH_FLAGS) -DVTUNE_ANALYSIS=$(VTUNE_ANALYSIS) # -g + +.PHONY: default_target +default_target: release + +.PHONY: release +release: export CXXFLAGS := $(CXXFLAGS) $(COMPILE_FLAGS) +release: dirs all + +.PHONY: dirs +dirs: + @echo "Creating directories" + @mkdir -p $(dir $(OBJECTS)) + @mkdir -p $(BIN_PATH) + +.PHONY: clean +clean: + @echo "Deleting $(BIN_NAME) symlink" + @$(RM) $(BIN_NAME) + @echo "Deleting directories" + @$(RM) -r $(BUILD_PATH) + @$(RM) -r $(BIN_PATH) + +# checks the executable and symlinks to the output +.PHONY: all +all: $(BIN_PATH)/$(BIN_NAME) + @echo "Making symlink: $(BIN_NAME) -> $<" + @$(RM) $(BIN_NAME) + @ln -s $(BIN_PATH)/$(BIN_NAME) $(BIN_NAME) + +# Creation of the executable +$(BIN_PATH)/$(BIN_NAME): $(OBJECTS) + @echo "Linking: $@" + #$(CXX) -O3 -fopenmp -xAVX2 -axAVX2 $(OBJECTS) -o $@ + $(CXX) -O3 -fopenmp $(ARCH_FLAGS) $(OBJECTS) $(INCLUDES) $(LIBS) -o $@ + +# Add dependency files, if they exist +-include $(DEPS) + +# Source file rules +# After the first compilation they will be joined with the rules from the +# dependency files to provide header dependencies +$(BUILD_PATH)/%.o: $(SRC_PATH)/%.$(SRC_EXT) + @echo "Compiling: $< -> $@" + $(CXX) $(CXXFLAGS) $(INCLUDES) $(LIBS) -MP -MMD -c $< -o $@ diff --git a/benchmarks/fast-chain/README.md b/benchmarks/fast-chain/README.md new file mode 100644 index 0000000..695a3db --- /dev/null +++ b/benchmarks/fast-chain/README.md @@ -0,0 +1,15 @@ +`fast-chain` uses the same license as [Minimap2](https://github.com/lh3/minimap2/tree/fast-contrib). + +If you find `fast-chain` useful, please cite: + +``` +@article{Kalikar2021, + doi = {10.1101/2021.07.21.453294}, + url = {https://doi.org/10.1101/2021.07.21.453294}, + year = {2021}, + month = jul, + publisher = {Cold Spring Harbor Laboratory}, + author = {Saurabh Kalikar and Chirag Jain and Vasimuddin Md and Sanchit Misra}, + title = {Accelerating long-read analysis on modern {CPUs}} +} +``` diff --git a/benchmarks/fast-chain/src/common.cpp b/benchmarks/fast-chain/src/common.cpp new file mode 100644 index 0000000..0b1a81d --- /dev/null +++ b/benchmarks/fast-chain/src/common.cpp @@ -0,0 +1,4 @@ +#include "common.h" + +const score_t NEG_INF_SCORE = -0x3FFFFFFF; + diff --git a/benchmarks/fast-chain/src/common.h b/benchmarks/fast-chain/src/common.h new file mode 100644 index 0000000..91d95f6 --- /dev/null +++ b/benchmarks/fast-chain/src/common.h @@ -0,0 +1,11 @@ +#ifndef COMMON_H +#define COMMON_H + +#include +#include "host_data.h" + +#define BACK_SEARCH_COUNT 65 +extern const score_t NEG_INF_SCORE; + + +#endif // COMMON_H diff --git a/benchmarks/fast-chain/src/host_data.h b/benchmarks/fast-chain/src/host_data.h new file mode 100644 index 0000000..0ebdca1 --- /dev/null +++ b/benchmarks/fast-chain/src/host_data.h @@ -0,0 +1,51 @@ +#ifndef HOST_INPUT_H +#define HOST_INPUT_H + +#include +#include + +typedef int64_t anchor_idx_t; +typedef uint32_t tag_t; +typedef int32_t loc_t; +typedef int32_t loc_dist_t; +typedef int32_t score_t; +typedef int32_t parent_t; +typedef int32_t target_t; +typedef int32_t peak_score_t; + +#define ANCHOR_NULL (anchor_idx_t)(-1) + +// struct anchor_t { +// uint64_t x; +// uint64_t y; +// }; + +// struct call_t { +// anchor_idx_t n; +// float avg_qspan; +// int max_dist_x, max_dist_y, bw, n_segs; +// std::vector anchors; +// }; + +struct call_t { + anchor_idx_t n; + float avg_qspan; + int max_dist_x, max_dist_y, bw, n_segs; + std::vector anchors_x; + std::vector anchors_x32; // WRONG in some cases. + + std::vector anchors_y; + std::vector anchors_y32; // WRONG in some cases. + + std::vector q_spans; +}; + +struct return_t { + anchor_idx_t n; + std::vector scores; + std::vector parents; + std::vector targets; + std::vector peak_scores; +}; + +#endif // HOST_INPUT_H diff --git a/benchmarks/fast-chain/src/host_data_io.cpp b/benchmarks/fast-chain/src/host_data_io.cpp new file mode 100644 index 0000000..6ad4b74 --- /dev/null +++ b/benchmarks/fast-chain/src/host_data_io.cpp @@ -0,0 +1,84 @@ +#include "host_data_io.h" +#include "host_data.h" + +void skip_to_EOR(FILE *fp) { + const char *loc = "EOR"; + while (*loc != '\0') { + if (fgetc(fp) == *loc) { + loc++; + } + } +} + +call_t read_call(FILE *fp) { + call_t call; + + long long n; + float avg_qspan; + int max_dist_x, max_dist_y, bw, n_segs; + + int t = fscanf(fp, "%lld%f%d%d%d%d", + &n, &avg_qspan, &max_dist_x, &max_dist_y, &bw, &n_segs); + // fprintf(stderr, "read %d arguments\n", t); + if (t != 6) { + call.n = ANCHOR_NULL; + call.avg_qspan = .0; + return call; + } + + call.n = n; + call.avg_qspan = avg_qspan; + call.max_dist_x = max_dist_x; + call.max_dist_y = max_dist_y; + call.bw = bw; + call.n_segs = n_segs; + // fprintf(stderr, "%lld\t%f\t%d\t%d\t%d\t%d\n", n, avg_qspan, max_dist_x, max_dist_y, bw, n_segs); + + // call.anchors.resize(call.n); + + // for (anchor_idx_t i = 0; i < call.n; i++) { + // uint64_t x, y; + // fscanf(fp, "%llu%llu", &x, &y); + + // anchor_t t; + // t.x = x; t.y = y; + + // call.anchors[i] = t; + // } + + // Some extra space for vectorization with intrinsics. + call.anchors_x.resize(call.n + 64); + call.anchors_x32.resize(call.n + 64); + + call.anchors_y.resize(call.n + 64); + call.anchors_y32.resize(call.n + 64); + + call.q_spans.resize(call.n + 64); + + // Add padding before and after the data. + for (anchor_idx_t i = 32; i < call.n + 32; i++) { + uint64_t x, y; + fscanf(fp, "%llu%llu", &x, &y); + + call.anchors_x[i] = x; + call.anchors_x32[i] = x; + + call.anchors_y[i] = y; + call.anchors_y32[i] = y; + + call.q_spans[i] = y >> 32 & 0xff; + } + + skip_to_EOR(fp); + return call; +} + +void print_return(FILE *fp, const return_t &data) +{ + fprintf(fp, "%lld\n", (long long)data.n); + // Keep padding in mind (32 before and 32 after) + for (anchor_idx_t i = 32; i < data.n + 32; i++) { + fprintf(fp, "%d\t%d\n", (int)data.scores[i], (int)data.parents[i]); + } + fprintf(fp, "EOR\n"); +} diff --git a/benchmarks/fast-chain/src/host_data_io.h b/benchmarks/fast-chain/src/host_data_io.h new file mode 100644 index 0000000..4768007 --- /dev/null +++ b/benchmarks/fast-chain/src/host_data_io.h @@ -0,0 +1,10 @@ +#ifndef HOST_KERNEL_IO_H +#define HOST_KERNEL_IO_H + +#include +#include "host_data.h" + +call_t read_call(FILE *fp); +void print_return(FILE *fp, const return_t &data); + +#endif // HOST_KERNEL_IO_H diff --git a/benchmarks/fast-chain/src/host_kernel.cpp b/benchmarks/fast-chain/src/host_kernel.cpp new file mode 100644 index 0000000..e824d85 --- /dev/null +++ b/benchmarks/fast-chain/src/host_kernel.cpp @@ -0,0 +1,878 @@ +#include +#include +#include +#include +#include +#include +#include "omp.h" +#include "host_kernel.h" +#include "common.h" +// #include "minimap.h" +// #include "mmpriv.h" +// #include "kalloc.h" + +#ifdef __AVX2__ + #include +#endif +#ifdef __AVX512BW__ + // #include +#endif +#ifdef __ARM_FEATURE_SVE + #include + #define VL svcntw() +#endif + +static const signed char LogTable256_dp_lib[256] = { +#define LT(n) n, n, n, n, n, n, n, n, n, n, n, n, n, n, n, n + -1, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + LT(4), LT(5), LT(5), LT(6), LT(6), LT(6), LT(6), + LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), LT(7) +}; + +static inline int ilog2_32_dp_lib(uint32_t v) { + uint32_t t, tt; + if ((tt = v >> 16)) { + return (t = tt >> 8) ? 24 + LogTable256_dp_lib[t] : 16 + LogTable256_dp_lib[tt]; + } + return (t = v >> 8) ? 8 + LogTable256_dp_lib[t] : LogTable256_dp_lib[v]; +} + +static inline int32_t ilog2_32(uint32_t v) { + constexpr uint32_t base = 31; + const uint32_t leading_zeros = __builtin_clz(v); + + return base - leading_zeros; +} + +const int BACKSEARCH = 65; +#define MM_SEED_SEG_SHIFT 48 +#define MM_SEED_SEG_MASK (0xffULL<<(MM_SEED_SEG_SHIFT)) + +#if 0 +void print_vector(svint64_t v) { + int64_t _v[VL]; + svst1_s64(svptrue_b64(),_v,v); + for (uint64_t i = 0; i < VL; i++) { + printf("%ld,",_v[i]); + } + printf("\n"); +} + +void print_vector_f(svfloat64_t v) { + float64_t _v[VL]; + svst1_f64(svptrue_b64(),_v,v); + for (uint64_t i = 0; i < VL; i++) { + printf("%f,",_v[i]); + } + printf("\n"); +} +#endif + +#ifdef __AVX512BW__ + inline __m512i get_gap_cost_vectorized_int32(__m512i dd_v, float avg_qspan, float gap_scale) { + //Vectorized log2 + uint32_t base = 31; + __m512i vbase = _mm512_set1_epi32(base); + __mmask16 msk = 0xFFFF; + __m512i vout = _mm512_mask_lzcnt_epi32(dd_v, msk, dd_v); + //int res = base - temp[0]; + __m512i r_v = _mm512_sub_epi32(vbase, vout); + + // log_dd = dd?ilog2:0; log_dd>>1 + __m512i zero_v = _mm512_setzero_si512(); + __mmask16 neg_mask = _mm512_cmpneq_epi32_mask(dd_v, zero_v); + __m512i log_dd_v = _mm512_srli_epi32(_mm512_maskz_or_epi32(neg_mask, r_v, zero_v), 1); + + //dd * 0.01*avg_qspan + float avg_qspan_val = 0.01 * avg_qspan; + __m512 avg_qspan_v = _mm512_set1_ps(avg_qspan_val); + __m512i cost = _mm512_cvt_roundps_epi32(_mm512_mul_ps(_mm512_cvtepi32_ps(dd_v), avg_qspan_v), _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC); + + // gap_cost = (dd * 0.01*avg_qspan) + (log_dd>>1 + 0.499) + __m512i gap_cost_v = (_mm512_add_epi32(cost, log_dd_v)); + + return gap_cost_v; + } +#elif __AVX2__ +//#endif + inline __m256i get_gap_cost_vectorized_int32(__m256i dd_v, float avg_qspan, float gap_scale) { + + /* + //Vectorized log2 + uint32_t base = 31; + __m256i vbase = _mm256_set1_epi32(base); + __mmask8 msk = 0xFF; + __m256i vout = _mm256_mask_lzcnt_epi32(dd_v, msk, dd_v); + //int res = base - temp[0]; + __m256i r_v = _mm256_sub_epi32(vbase, vout); + */ + uint32_t lg[8], dd_array[8]; + _mm256_storeu_si256((__m256i *)dd_array, dd_v); + for (int i = 0; i < 8; i++) { + lg[i] = ilog2_32_dp_lib(dd_array[i]); + } + __m256i r_v = _mm256_loadu_si256((__m256i *)lg); + + // log_dd = dd?ilog2:0; log_dd>>1 + //__mmask8 neg_mask = _mm256_cmpneq_epi32_mask(dd_v, zero_avx2_v); + //__m256i log_dd_v = _mm256_srli_epi32(_mm256_maskz_or_epi32(neg_mask, r_v, zero_avx2_v), 1); + __m256i zero_avx2_v = _mm256_setzero_si256(); + __m256i neg_mask = _mm256_cmpeq_epi32(dd_v, zero_avx2_v); + __m256i log_dd_v = _mm256_srli_epi32(_mm256_andnot_si256(neg_mask, r_v), 1); + + //dd * 0.01*avg_qspan + float avg_qspan_val = 0.01 * avg_qspan; + __m256 avg_qspan_v = _mm256_set1_ps(avg_qspan_val); + __m256i cost = _mm256_cvtps_epi32(_mm256_round_ps((_mm256_mul_ps(_mm256_cvtepi32_ps(dd_v), avg_qspan_v)), _MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC)); + + // gap_cost = (dd * 0.01*avg_qspan) + (log_dd>>1 + 0.499) + __m256i gap_cost_v = (_mm256_add_epi32(cost, log_dd_v)); + + return gap_cost_v; + } +#endif + +static void chain_dp(call_t *a, return_t *ret) { + constexpr float gap_scale = 1.0f; + // constexpr int max_skip = 25; + // constexpr int is_cdna = 0; + constexpr int max_iter = 5000; + + const auto max_dist_x = a->max_dist_x; + const auto max_dist_y = a->max_dist_y; + const auto bw = a->bw; + + const auto avg_qspan = a->avg_qspan; + const float avg_qspan001 = 0.01f * avg_qspan; + + // const auto n_segs = a->n_segs; + const auto n = a->n; + + auto *anchors_x = a->anchors_x.data() + 32; + auto *anchors_x32 = a->anchors_x32.data() + 32; + + auto *anchors_y = a->anchors_y.data() + 32; + auto *anchors_y32 = a->anchors_y32.data() + 32; + + auto *q_spans = a->q_spans.data() + 32; + + ret->n = n; + + // Some extra space for vectorization with intrinsics. + ret->scores.resize(n + 64); + ret->parents.resize(n + 64); + ret->targets.resize(n + 64); + ret->peak_scores.resize(n + 64); + + // Add padding before and after the data. + auto *scores = ret->scores.data() + 32; + auto *parents = ret->parents.data() + 32; + auto *targets = ret->targets.data() + 32; + auto *peak_scores = ret->peak_scores.data() + 32; + + int32_t st = 0; + +#ifdef __AVX512BW__ + #pragma message("Using AVX512 version") + //Vector code with SoA function parameters 32-bit number representation - avx512 + + __m512i zero_v = _mm512_setzero_si512(); + + int32_t dr = max_dist_x; + int32_t dq = max_dist_y; + + __m512i dr_v = _mm512_set1_epi32((int64_t)dr); + __m512i dq_v = _mm512_set1_epi32((int64_t)dq); + __m512i j_idx_base = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + int32_t maxfVector_v[16]; + int32_t maxjVector_v[16]; + __m512i neg_one_v = _mm512_set1_epi32((int32_t) -1); + + for (int i = 0; i < n; i++) { + + + int32_t max_j = -1; + int32_t max_f = q_spans[i]; + + + + //uint64_t ri = anchors_x[i]; + while (st < i && !(anchors_x[i] - anchors_x[st] <= (uint32_t)dr)) { + ++st; //predecessor's position is too far + } + + // TODO: Minimap specific max_iter parameter + if (i - st > max_iter) { + st = i - max_iter; //predecessor's index is too far + } + + + int j = i - 1; + if (!(j - st <= 5)) { + //broadcast ri and qi + __m512i ri_v = _mm512_set1_epi32(anchors_x32[i]); + __m512i qi_v = _mm512_set1_epi32(anchors_y32[i]); + + + __m512i maxj_v = neg_one_v; + __m512i maxf_v = _mm512_set1_epi32((int32_t)q_spans[i]); + __m512i li_v = maxf_v; + + // 16-way vectorized + //_mm_prefetch((const char *)(&anchors_x32[j - 30]), _MM_HINT_T2); + //_mm_prefetch((const char *)(&anchors_y32[j - 30]), _MM_HINT_T2); + + for (j = i - 1; (j - 15) >= st; j = j - 16) { + + _mm_prefetch((const char *)(&anchors_x32[j - 60]), _MM_HINT_T0); + _mm_prefetch((const char *)(&anchors_y32[j - 60]), _MM_HINT_T0); + + uint32_t *rj, *qj; + rj = &anchors_x32[j - 15]; + qj = &anchors_y32[j - 15]; + + + // Load rj and qj + __m512i rj_v = _mm512_loadu_si512(rj); + __m512i qj_v = _mm512_loadu_si512(qj); + + + __m512i ddr_v = _mm512_sub_epi32(ri_v, rj_v); + __m512i ddq_v = _mm512_sub_epi32(qi_v, qj_v); + + //TODO: Minimap2 specific continue condition + __m512i dd_v = _mm512_abs_epi32(_mm512_sub_epi32(ddr_v, ddq_v)); + __m512i bw_v = _mm512_set1_epi32((int32_t) bw); + __mmask16 bw_gt = _mm512_cmpgt_epi32_mask(dd_v, bw_v); + __mmask16 mask_eq = _mm512_cmpeq_epi32_mask(ddr_v, zero_v); + __mmask16 mask_leq = _mm512_cmple_epi32_mask(ddq_v, zero_v); + __mmask16 mask_gt1 = _mm512_cmpgt_epi32_mask(ddq_v, dq_v); + __mmask16 mask_gt2 = _mm512_cmpgt_epi32_mask(ddq_v, dr_v); + + __mmask16 loopContinueMask = ~(bw_gt | mask_eq | mask_leq | mask_gt1 | mask_gt2); + + // Load scores[j-8, j] + __m512i fj_v = _mm512_loadu_si512(&scores[j - 15]); + + //Vectorized gap cost function + __m512i gc_v = get_gap_cost_vectorized_int32(dd_v, avg_qspan, gap_scale); + + //---------------- Inline get_overlap_cost function ------------------- + __m512i min1 = _mm512_min_epi32(ddr_v, ddq_v); + __m512i oc_v = _mm512_min_epi32(li_v, min1); + //---------------------------------------------------------------------- + __m512i f_plus_oc_v = _mm512_add_epi32(fj_v, oc_v); + __m512i sc_v = _mm512_maskz_sub_epi32(loopContinueMask, f_plus_oc_v, gc_v); + + // Update Maxf and Maxj + __mmask16 mask_max_sc = _mm512_cmpgt_epi32_mask(sc_v, maxf_v); + __m512i j_idx_v = _mm512_add_epi32(j_idx_base, _mm512_set1_epi32(j - 15)); + + maxf_v = _mm512_max_epi32(sc_v, maxf_v); + maxj_v = _mm512_mask_blend_epi32(mask_max_sc, maxj_v, j_idx_v); + } + + + if (j >= st) { + uint32_t *rj, *qj; + rj = &anchors_x32[j - 15]; + qj = &anchors_y32[j - 15]; + + + // Load rj and qj + __m512i rj_v = _mm512_loadu_si512(rj); + __m512i qj_v = _mm512_loadu_si512(qj); + + + __m512i ddr_v = _mm512_sub_epi32(ri_v, rj_v); + __m512i ddq_v = _mm512_sub_epi32(qi_v, qj_v); + + //TODO: Minimap2 specific continue condition + __m512i dd_v = _mm512_abs_epi32(_mm512_sub_epi32(ddr_v, ddq_v)); + __m512i bw_v = _mm512_set1_epi32((int32_t) bw); + __mmask16 bw_gt = _mm512_cmpgt_epi32_mask(dd_v, bw_v); + __mmask16 mask_eq = _mm512_cmpeq_epi32_mask(ddr_v, zero_v); + __mmask16 mask_leq = _mm512_cmple_epi32_mask(ddq_v, zero_v); + __mmask16 mask_gt1 = _mm512_cmpgt_epi32_mask(ddq_v, dq_v); + __mmask16 mask_gt2 = _mm512_cmpgt_epi32_mask(ddq_v, dr_v); + + + __mmask16 loopContinueMask = ~(bw_gt | mask_eq | mask_leq | mask_gt1 | mask_gt2); + + + //Last partial vector processing mask - To enable, change loop condition to -> j >= st + + int shift = st - (j - 15); + + loopContinueMask = loopContinueMask >> (shift); + loopContinueMask = loopContinueMask << (shift); + if (loopContinueMask != 0x00) { + + + + // Load scores[j-8, j] + __m512i fj_v = _mm512_loadu_si512(&scores[j - 15]); + + //Vectorized gap cost function + __m512i gc_v = get_gap_cost_vectorized_int32(dd_v, avg_qspan, gap_scale); + + //---------------- Inline get_overlap_cost function ------------------- + __m512i min1 = _mm512_min_epi32(ddr_v, ddq_v); + __m512i oc_v = _mm512_min_epi32(li_v, min1); + //---------------------------------------------------------------------- + __m512i f_plus_oc_v = _mm512_add_epi32(fj_v, oc_v); + __m512i sc_v = _mm512_maskz_sub_epi32(loopContinueMask, f_plus_oc_v, gc_v); + + + // Update Maxf and Maxj + __mmask16 mask_max_sc = _mm512_cmpgt_epi32_mask(sc_v, maxf_v); + __m512i j_idx_v = _mm512_add_epi32(j_idx_base, _mm512_set1_epi32(j - 15)); + + maxf_v = _mm512_max_epi32(sc_v, maxf_v); + maxj_v = _mm512_mask_blend_epi32(mask_max_sc, maxj_v, j_idx_v); + + } + } + + + _mm512_store_epi32(maxfVector_v, maxf_v); + _mm512_store_epi32(maxjVector_v, maxj_v); + + for (int iter = 15; iter >= 0; iter--) { + if (maxfVector_v[iter] > max_f) { + max_f = maxfVector_v[iter]; + max_j = maxjVector_v[iter]; + } + else if (maxfVector_v[iter] == max_f) { + max_j = std::max(max_j, maxjVector_v[iter]); + if ((uint32_t)max_f == q_spans[i]) { + max_j = -1; + } + } + } + + + } + else { + int32_t ri = anchors_x32[i], qi = anchors_y32[i]; + for (; j >= st; j--) { + + int32_t ddr, ddq; + + int32_t rj = anchors_x32[j], + qj = anchors_y32[j]; + ddr = ri - rj; + ddq = qi - qj; + + if (abs(ddr - ddq) > bw) { + continue; + } + if (ddr == 0 || ddq <= 0) { + continue; + } + + if (ddq > dq || ddq > dr) { + continue; + } + + + int32_t oc = 0; + + int32_t score = scores[j];//q_spans[i]; + oc = ddr < ddq ? ddr : ddq; + oc = oc < (int32_t)q_spans[i] ? oc : q_spans[i]; + score += oc; + + int32_t dr = ddr; + int32_t dq = ddq; + int32_t dd = abs(dr - dq);//dr > dq? dr - dq : dq - dr; //dd = |dr-dq|; + int32_t log_dd = dd ? ilog2_32_dp_lib(dd) : 0; + int32_t gap_cost = 0; + + gap_cost = (int)(dd * .01 * avg_qspan) + (log_dd >> 1) + 0.00; //TODO: Only multiplication should be casted to (int) + score -= gap_cost; + + if (score > max_f) { + max_f = score; + max_j = j; + } + + } + + } + + scores[i] = max_f; + parents[i] = max_j; + peak_scores[i] = max_j >= 0 && peak_scores[max_j] > max_f ? peak_scores[max_j] : max_f; // v[] keeps the peak score up to i; scores[] is the score ending at i, not always the peak + } +#elif __AVX2__ + #pragma message("Using AVX2 version") + + __m256i zero_avx2_v = _mm256_setzero_si256(); + + int32_t dr = max_dist_x; + int32_t dq = max_dist_y; + + __m256i dr_v = _mm256_set1_epi32(dr); + __m256i dq_v = _mm256_set1_epi32(dq); + __m256i j_idx_base = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0); + int32_t maxfVector_v[8]; + int32_t maxjVector_v[8]; + __m256i neg_one_v = _mm256_set1_epi32((int32_t) -1); + + for (int i = 0; i < n; i++) { + + + int32_t max_j = -1; + int32_t max_f = q_spans[i]; + + + + //uint64_t ri = anchors_x[i]; + while (st < i && !(anchors_x[i] - anchors_x[st] <= dr)) { + ++st; //predecessor's position is too far + } + + // TODO: Minimap specific max_iter parameter + if (i - st > max_iter) { + st = i - max_iter; //predecessor's index is too far + } + + + int j = i - 1; + if (!(j - st <= 5)) { + //broadcast ri and qi + __m256i ri_v = _mm256_set1_epi32(anchors_x32[i]); + __m256i qi_v = _mm256_set1_epi32(anchors_y32[i]); + + + __m256i maxj_v = neg_one_v; + __m256i maxf_v = _mm256_set1_epi32((int32_t)q_spans[i]); + __m256i li_v = maxf_v; + + // 8-way vectorized + for (j = i - 1; (j - 7) >= st; j = j - 8) { + + _mm_prefetch((const char *)(&anchors_x32[j - 60]), _MM_HINT_T0); + _mm_prefetch((const char *)(&anchors_y32[j - 60]), _MM_HINT_T0); + + uint32_t *rj, *qj; + int j_stride = j - 7; + rj = &anchors_x32[j_stride]; + qj = &anchors_y32[j_stride]; + + + // Load rj and qj + __m256i rj_v = _mm256_loadu_si256((__m256i *) rj); + __m256i qj_v = _mm256_loadu_si256((__m256i *) qj); + + + __m256i ddr_v = _mm256_sub_epi32(ri_v, rj_v); + __m256i ddq_v = _mm256_sub_epi32(qi_v, qj_v); + + //TODO: Minimap2 specific continue condition + __m256i dd_v = _mm256_abs_epi32(_mm256_sub_epi32(ddr_v, ddq_v)); + __m256i bw_v = _mm256_set1_epi32((int32_t) bw); + + /* + __mmask8 bw_gt = _mm256_cmpgt_epi32_mask(dd_v, bw_v); + __mmask8 mask_eq = _mm256_cmpeq_epi32_mask(ddr_v, zero_avx2_v); + __mmask8 mask_leq1 = _mm256_cmpgt_epi32_mask(zero_avx2_v, ddq_v); + __mmask8 mask_leq2 = _mm256_cmpeq_epi32_mask(ddq_v, zero_avx2_v); + __mmask8 mask_gt1 = _mm256_cmpgt_epi32_mask(ddq_v, dq_v); + __mmask8 mask_gt2 = _mm256_cmpgt_epi32_mask(ddq_v, dr_v); + + + __mmask8 loopContinueMask = ~(bw_gt | mask_eq | (mask_leq1 | mask_leq2 ) | mask_gt1 | mask_gt2); + */ + + __m256i bw_gt = _mm256_cmpgt_epi32(dd_v, bw_v); + __m256i mask_eq = _mm256_cmpeq_epi32(ddr_v, zero_avx2_v); + __m256i mask_leq1 = _mm256_cmpgt_epi32(zero_avx2_v, ddq_v); + __m256i mask_leq2 = _mm256_cmpeq_epi32(ddq_v, zero_avx2_v); + __m256i mask_gt1 = _mm256_cmpgt_epi32(ddq_v, dq_v); + __m256i mask_gt2 = _mm256_cmpgt_epi32(ddq_v, dr_v); + + __m256i tmp1 = _mm256_or_si256(bw_gt, mask_eq); + __m256i tmp2 = _mm256_or_si256(mask_leq1, mask_leq2); + __m256i tmp3 = _mm256_or_si256(mask_gt1, mask_gt2); + + __m256i loopContinueMask = _mm256_or_si256(_mm256_or_si256(tmp1, tmp2), tmp3); + + // Load scores[j-8, j] + __m256i fj_v = _mm256_loadu_si256((__m256i *) &scores[j_stride]); + + //Vectorized gap cost function + __m256i gc_v = get_gap_cost_vectorized_int32(dd_v, avg_qspan, gap_scale); + + //---------------- Inline get_overlap_cost function ------------------- + __m256i min1 = _mm256_min_epi32(ddr_v, ddq_v); + __m256i oc_v = _mm256_min_epi32(li_v, min1); + //---------------------------------------------------------------------- + __m256i f_plus_oc_v = _mm256_add_epi32(fj_v, oc_v); + //__m256i sc_v = _mm256_maskz_sub_epi32(loopContinueMask,f_plus_oc_v, gc_v); + __m256i sc_v = _mm256_andnot_si256(loopContinueMask, _mm256_sub_epi32(f_plus_oc_v, gc_v)); + + // Update Maxf and Maxj + __m256i mask_max_sc = _mm256_cmpgt_epi32(sc_v, maxf_v); + __m256i j_idx_v = _mm256_add_epi32(j_idx_base, _mm256_set1_epi32(j_stride)); + + maxf_v = _mm256_max_epi32(sc_v, maxf_v); + maxj_v = _mm256_or_si256(_mm256_andnot_si256(mask_max_sc, maxj_v), _mm256_and_si256(mask_max_sc, j_idx_v)); + } + + + if (j >= st) { + uint32_t *rj, *qj; + int j_stride = j - 7; + rj = &anchors_x32[j_stride]; + qj = &anchors_y32[j_stride]; + + + // Load rj and qj + __m256i rj_v = _mm256_loadu_si256((__m256i *) rj); + __m256i qj_v = _mm256_loadu_si256((__m256i *) qj); + + + __m256i ddr_v = _mm256_sub_epi32(ri_v, rj_v); + __m256i ddq_v = _mm256_sub_epi32(qi_v, qj_v); + + //TODO: Minimap2 specific continue condition + __m256i dd_v = _mm256_abs_epi32(_mm256_sub_epi32(ddr_v, ddq_v)); + __m256i bw_v = _mm256_set1_epi32((int32_t) bw); + /* + __mmask8 bw_gt = _mm256_cmpgt_epi32_mask(dd_v, bw_v); + __mmask8 mask_eq = _mm256_cmpeq_epi32_mask(ddr_v, zero_avx2_v); + __mmask8 mask_leq = _mm256_cmple_epi32_mask(ddq_v, zero_avx2_v); + __mmask8 mask_gt1 = _mm256_cmpgt_epi32_mask(ddq_v, dq_v); + __mmask8 mask_gt2 = _mm256_cmpgt_epi32_mask(ddq_v, dr_v); + + + __mmask8 loopContinueMask = ~(bw_gt | mask_eq | mask_leq | mask_gt1 | mask_gt2); + */ + __m256i bw_gt = _mm256_cmpgt_epi32(dd_v, bw_v); + __m256i mask_eq = _mm256_cmpeq_epi32(ddr_v, zero_avx2_v); + __m256i mask_leq1 = _mm256_cmpgt_epi32(zero_avx2_v, ddq_v); + __m256i mask_leq2 = _mm256_cmpeq_epi32(ddq_v, zero_avx2_v); + __m256i mask_gt1 = _mm256_cmpgt_epi32(ddq_v, dq_v); + __m256i mask_gt2 = _mm256_cmpgt_epi32(ddq_v, dr_v); + + __m256i tmp1 = _mm256_or_si256(bw_gt, mask_eq); + __m256i tmp2 = _mm256_or_si256(mask_leq1, mask_leq2); + __m256i tmp3 = _mm256_or_si256(mask_gt1, mask_gt2); + + __m256i loopContinueMask = _mm256_or_si256(_mm256_or_si256(tmp1, tmp2), tmp3); + + //Last partial vector processing mask - To enable, change loop condition to -> j >= st + + int shift = st - (j_stride); + + int32_t msk_ar[8]; + for (int it = 0; it < 8; it++) { + msk_ar[it] = (it < (shift)) ? 0xFFFFFFFF : 0; + } + loopContinueMask = _mm256_or_si256(loopContinueMask, _mm256_loadu_si256((__m256i *)msk_ar)); + //loopContinueMask = loopContinueMask>>(shift); + //loopContinueMask = loopContinueMask<<(shift); + //if(loopContinueMask != 0x0) + { + + + + // Load scores[j-8, j] + __m256i fj_v = _mm256_loadu_si256((__m256i *) &scores[j_stride]); + + //Vectorized gap cost function + __m256i gc_v = get_gap_cost_vectorized_int32(dd_v, avg_qspan, gap_scale); + + //---------------- Inline get_overlap_cost function ------------------- + __m256i min1 = _mm256_min_epi32(ddr_v, ddq_v); + __m256i oc_v = _mm256_min_epi32(li_v, min1); + //---------------------------------------------------------------------- + __m256i f_plus_oc_v = _mm256_add_epi32(fj_v, oc_v); + //__m256i sc_v = _mm256_maskz_sub_epi32(loopContinueMask,f_plus_oc_v, gc_v); + __m256i sc_v = _mm256_andnot_si256(loopContinueMask, _mm256_sub_epi32(f_plus_oc_v, gc_v)); + + + // Update Maxf and Maxj + __m256i mask_max_sc = _mm256_cmpgt_epi32(sc_v, maxf_v); + __m256i j_idx_v = _mm256_add_epi32(j_idx_base, _mm256_set1_epi32(j_stride)); + + maxf_v = _mm256_max_epi32(sc_v, maxf_v); + maxj_v = _mm256_or_si256(_mm256_andnot_si256(mask_max_sc, maxj_v), _mm256_and_si256(mask_max_sc, j_idx_v)); + + } + } + + + //_mm256_store_epi32(maxfVector_v, maxf_v); + //_mm256_store_epi32(maxjVector_v, maxj_v); + _mm256_store_si256((__m256i *) maxfVector_v, maxf_v); + _mm256_store_si256((__m256i *) maxjVector_v, maxj_v); + + for (int iter = 7; iter >= 0; iter--) { + if (maxfVector_v[iter] > max_f) { + max_f = maxfVector_v[iter]; + max_j = maxjVector_v[iter]; + } + else if (maxfVector_v[iter] == max_f) { + max_j = std::max(max_j, maxjVector_v[iter]); + if (max_f == q_spans[i]) { + max_j = -1; + } + } + } + + + } + else { + int32_t ri = anchors_x32[i], qi = anchors_y32[i]; + for (; j >= st; j--) { + + int32_t ddr, ddq; + + int32_t rj = anchors_x32[j], + qj = anchors_y32[j]; + ddr = ri - rj; + ddq = qi - qj; + + if (abs(ddr - ddq) > bw) { + continue; + } + if (ddr == 0 || ddq <= 0) { + continue; + } + + if (ddq > dq || ddq > dr) { + continue; + } + + + int32_t oc = 0; + int32_t lj = q_spans[j]; + int32_t ref_overlap = rj + lj - ri; + int32_t query_overlap = qj + lj - qi; + + int32_t score = scores[j];//q_spans[i]; + oc = ddr < ddq ? ddr : ddq; + oc = oc < q_spans[i] ? oc : q_spans[i]; + score += oc; + + int32_t dr = ddr; + int32_t dq = ddq; + int32_t dd = abs(dr - dq);//dr > dq? dr - dq : dq - dr; //dd = |dr-dq|; + int32_t log_dd = dd ? ilog2_32_dp_lib(dd) : 0; + int32_t gap_cost = 0; + + gap_cost = (int)(dd * .01 * avg_qspan) + (log_dd >> 1) + 0.00; //TODO: Only multiplication should be casted to (int) + score -= gap_cost; + + bool check = score > max_f; + if (score > max_f) { + max_f = score; + max_j = j; + } + + } + + } + + scores[i] = max_f; + parents[i] = max_j; + peak_scores[i] = max_j >= 0 && peak_scores[max_j] > max_f ? peak_scores[max_j] : max_f; // v[] keeps the peak score up to i; scores[] is the score ending at i, not always the peak + } +#elif __ARM_FEATURE_SVE + #pragma message("Using SVE version") + + // fill the score and backtrack arrays + for (int32_t i = 0; i < n; ++i) { + const int32_t ri_scalar = anchors_x32[i]; + svint32_t ri = svdup_n_s32(ri_scalar); + const int32_t qi_scalar = static_cast(anchors_y32[i]); + svint32_t qi = svdup_n_s32(qi_scalar); + const int32_t q_spani = q_spans[i]; + + int32_t max_j = -1; + int32_t max_f = q_spani; + + while (st < i && !(anchors_x[i] - anchors_x[st] <= max_dist_x)) { + ++st; //predecessor's position is too far + } + if (i - st > max_iter) { + st = i - max_iter; //predecessor's index is too far + } + + //for (int64_t j = i - 1; j >= st; --j) { + svbool_t ptrue = svptrue_b32(); + for (int32_t j = i - 1; j >= st; j-=VL) { + int32_t real_j = j-VL+1; + svbool_t valid_elements = svnot_b_z(ptrue,svwhilelt_b32_s32(real_j,st)); + //const auto rj = anchors_x[j]; + svint32_t rj = svld1_s32(valid_elements,(int32_t*)&anchors_x32[real_j]); + //const int32_t qj = static_cast(anchors_y[j]); + svint32_t qj = svld1_s32(valid_elements,(int32_t*)&anchors_y32[real_j]); + // qj = svextw_s64_x(valid_elements,qj); + + //const int64_t dr = ri - rj; + svint32_t dr = svsub_s32_x(valid_elements,ri,rj); + //const int32_t dq = qi - qj; + svint32_t dq = svsub_s32_x(valid_elements,qi,qj); + + //const int32_t dd = std::abs(dr - dq); + svint32_t dd = svabd_s32_x(valid_elements,dr,dq); + + //if ((dr == 0 || dq <= 0) || + // (dq > max_dist_y || dq > max_dist_x) || + // (dd > bw)) { + // continue; + //} + svbool_t skip_anchor = svcmpeq_n_s32(valid_elements,dr,0); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmple_n_s32(valid_elements,dq,0); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmpgt_n_s32(valid_elements,dq,max_dist_y); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmpgt_n_s32(valid_elements,dq,max_dist_x); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmpgt_n_s32(valid_elements,dd,bw); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + if (!svptest_any(ptrue,valid_elements)) continue; + + //const int64_t dr_dq_min = (dr < dq) ? dr : dq; + svint32_t dr_dq_min = svmin_s32_x(valid_elements,dr,dq); + //const int32_t oc = (dr_dq_min < q_spani) ? dr_dq_min : q_spani; + svint32_t oc = svmin_n_s32_x(valid_elements,dr_dq_min,q_spani); + + //const int32_t log_dd = (dd) ? ilog2_32(dd) : 0; + svbool_t valid_log = svcmpne_n_s32(valid_elements,dd,0); + svint32_t log_dd = svreinterpret_s32(svclz_s32_x(valid_elements,dd)); + log_dd = svsubr_n_s32_z(valid_log,log_dd,31); + + //const int32_t gap_cost = static_cast(dd * 0.01f * avg_qspan) + (log_dd >> 1); + /* + svint64_t gap_cost = svmul_n_s64_x(valid_elements,dd,avg_qspan001); + log_dd = svlsr_n_s64_x(valid_elements,log_dd,1); + gap_cost = svadd_s64_x(valid_elements,gap_cost,log_dd); + */ + svfloat32_t gap_cost_f = svcvt_f32_s32_x(valid_elements,dd); + //gap_cost_f = svmul_n_f64_x(valid_elements,gap_cost_f,0.01f); + gap_cost_f = svmul_n_f32_x(valid_elements,gap_cost_f,avg_qspan001); + svint32_t gap_cost = svcvt_s32_f32_x(valid_elements,gap_cost_f); + log_dd = svreinterpret_s32(svlsr_n_u32_x(valid_elements,svreinterpret_u32(log_dd),1)); + gap_cost = svadd_s32_x(valid_elements,gap_cost,log_dd); + + //const int32_t score = scores[j] + oc - gap_cost; + svint32_t score = svld1_s32(valid_elements,&scores[real_j]); + score = svadd_s32_x(valid_elements,score,oc); + score = svsub_s32_x(valid_elements,score,gap_cost); + + // TODO: CAN'T VECTORIZE THIS. THE COMPILER IS ONLY ABLE TO PERFORM + // ONE REDUCTION. + // Ideas: + // 1. j to int32_t and: uint64_t max = (j << 32) & score. + // max = ((int32_t)(max & 0xffffffff) > score) ? max = (j << 32) & score : max; + // Reduction + + // TODO: reduction on vector register, no scalar register + //if (score > max_f) { + // max_f = score; + // max_j = j; + //} + + int32_t max_local = svmaxv_s32(valid_elements,score); + if (max_local > max_f) { + max_f = max_local; + // WARNING + svint32_t index = svindex_s32(real_j,1); + svbool_t max_index = svcmpeq_n_s32(valid_elements,score,max_local); + max_j = svlastb_s32(max_index,index); + } + } + scores[i] = max_f; + parents[i] = max_j; + //if (max_f == 36 && max_j == 18821) {printf("ee\n");exit(0);} + peak_scores[i] = max_j >= 0 && peak_scores[max_j] > max_f ? peak_scores[max_j] : max_f; + } +#else // SCALAR VERSION + #pragma message("Using SCALAR version") + for (int32_t i = 0; i < n; ++i) { + const uint32_t ri = anchors_x32[i]; + const int32_t qi = static_cast(anchors_y32[i]); + const int32_t q_spani = q_spans[i]; + + int32_t max_j = -1; + int32_t max_f = q_spani; + + while (st < i && !(anchors_x[i] - anchors_x[st] <= max_dist_x)) { + ++st; //predecessor's position is too far + } + if (i - st > max_iter) { + st = i - max_iter; //predecessor's index is too far + } + + // TODO: Iterate forward to vectorize the loop. + // for (int64_t j_inv = st; j_inv < i; ++j_inv) { + // const int64_t j = (i - 1) - j_inv + st; + for (int32_t j = i - 1; j >= st; --j) { + const uint32_t rj = anchors_x32[j]; + const uint32_t qj = anchors_y32[j]; + + const int32_t dr = ri - rj; + const int32_t dq = qi - qj; + + const int32_t dd = std::abs(dr - dq); + + if ((dr == 0 || dq <= 0) || + (dq > max_dist_y || dq > max_dist_x) || + (dd > bw)) { + continue; + } + // Can not vectorize "continue". Use a mask instead. + // const bool skip = ((dr == 0 || dq <= 0) || + // (dq > max_dist_y || dq > max_dist_x) || + // (dd > bw)); + + const int32_t dr_dq_min = (dr < dq) ? dr : dq; + const int32_t oc = (dr_dq_min < q_spani) ? dr_dq_min : q_spani; + + // TODO: CAN'T VECTORIZE __builtin_clz + const int32_t log_dd = (dd) ? ilog2_32(dd) : 0; + const int32_t gap_cost = static_cast(dd * 0.01f * avg_qspan) + (log_dd >> 1); + + const int32_t score = scores[j] + oc - gap_cost; + + // TODO: CAN'T VECTORIZE THIS. THE COMPILER IS ONLY ABLE TO PERFORM + // ONE REDUCTION. + // Ideas: + // 1. j to int32_t and: uint64_t max = (j << 32) & score. + // max = ((int32_t)(max & 0xffffffff) > score) ? max = (j << 32) & score : max; + // Reduction + if (score > max_f) { + max_f = score; + max_j = j; + } + } + scores[i] = max_f; + parents[i] = max_j; + peak_scores[i] = max_j >= 0 && peak_scores[max_j] > max_f ? peak_scores[max_j] : max_f; + } +#endif +} + +void host_chain_kernel(std::vector &args, std::vector &rets, int numThreads) { + #pragma omp parallel num_threads(numThreads) + { + #pragma omp for schedule(dynamic) + for (size_t batch = 0; batch < args.size(); batch++) { + call_t *arg = &args[batch]; + return_t *ret = &rets[batch]; + // fprintf(stderr, "%lld\t%f\t%d\t%d\t%d\t%d\n", arg->n, arg->avg_qspan, arg->max_dist_x, arg->max_dist_y, arg->bw, arg->n_segs); + chain_dp(arg, ret); + } + } +} diff --git a/benchmarks/fast-chain/src/host_kernel.cpp.uint64 b/benchmarks/fast-chain/src/host_kernel.cpp.uint64 new file mode 100644 index 0000000..7298d21 --- /dev/null +++ b/benchmarks/fast-chain/src/host_kernel.cpp.uint64 @@ -0,0 +1,292 @@ +#include +#include +#include +#include +#include +#include +#include "omp.h" +#include "host_kernel.h" +#include "common.h" +// #include "minimap.h" +// #include "mmpriv.h" +// #include "kalloc.h" + +#ifdef __AVX2__ + #include +#endif +#ifdef __AVX512BW__ + #include +#endif +#ifdef __ARM_FEATURE_SVE + #include + #define VL svcntd() +#endif + +static const char LogTable256_dp_lib[256] = { +#define LT(n) n, n, n, n, n, n, n, n, n, n, n, n, n, n, n, n + -1, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, + LT(4), LT(5), LT(5), LT(6), LT(6), LT(6), LT(6), + LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), LT(7), LT(7) + }; + +static inline int ilog2_32_dp_lib(uint32_t v) { + uint32_t t, tt; + if ((tt = v >> 16)) { + return (t = tt >> 8) ? 24 + LogTable256_dp_lib[t] : 16 + LogTable256_dp_lib[tt]; + } + return (t = v >> 8) ? 8 + LogTable256_dp_lib[t] : LogTable256_dp_lib[v]; +} + +static inline int32_t ilog2_32(uint32_t v) { + constexpr uint32_t base = 31; + const uint32_t leading_zeros = __builtin_clz(v); + + return base - leading_zeros; +} + +const int BACKSEARCH = 65; +#define MM_SEED_SEG_SHIFT 48 +#define MM_SEED_SEG_MASK (0xffULL<<(MM_SEED_SEG_SHIFT)) + +#if 0 +void print_vector(svint64_t v) { + int64_t _v[VL]; + svst1_s64(svptrue_b64(),_v,v); + for (uint64_t i = 0; i < VL; i++) { + printf("%ld,",_v[i]); + } + printf("\n"); +} + +void print_vector_f(svfloat64_t v) { + float64_t _v[VL]; + svst1_f64(svptrue_b64(),_v,v); + for (uint64_t i = 0; i < VL; i++) { + printf("%f,",_v[i]); + } + printf("\n"); +} +#endif + +static void chain_dp(call_t *a, return_t *ret) { + // constexpr float gap_scale = 1.0f; + // constexpr int max_skip = 25; + // constexpr int is_cdna = 0; + constexpr int max_iter = 5000; + + const auto max_dist_x = a->max_dist_x; + const auto max_dist_y = a->max_dist_y; + const auto bw = a->bw; + + const auto avg_qspan = a->avg_qspan; + const auto avg_qspan001 = 0.01f * avg_qspan; + + // const auto n_segs = a->n_segs; + const auto n = a->n; + + auto *anchors_x = a->anchors_x.data(); + auto *anchors_y = a->anchors_y.data(); + auto *anchors_l = a->anchors_l.data(); + + ret->n = n; + ret->scores.resize(n); + ret->parents.resize(n); + ret->targets.resize(n); + ret->peak_scores.resize(n); + + int64_t st = 0; + +#if __ARM_FEATURE_SVE + printf("Executing SVE version"); + + // fill the score and backtrack arrays + for (int64_t i = 0; i < n; ++i) { + const auto ri_scalar = anchors_x[i]; + svint64_t ri = svdup_n_s64(ri_scalar); + const int32_t qi_scalar = static_cast(anchors_y[i]); + svint64_t qi = svdup_n_s64(qi_scalar); + const int32_t q_spani = qspans[i]; + + int64_t max_j = -1; + int32_t max_f = q_spani; + + while (st < i && ri_scalar > anchors_x[st] + max_dist_x) { + ++st; //predecessor's position is too far + } + if (i - st > max_iter) { + st = i - max_iter; //predecessor's index is too far + } + + //for (int64_t j = i - 1; j >= st; --j) { + svbool_t ptrue = svptrue_b64(); + for (int64_t j = i - 1; j >= st; j-=VL) { + int64_t real_j = j-VL+1; + svbool_t valid_elements = svnot_b_z(ptrue,svwhilelt_b64_s64(real_j,st)); + //const auto rj = anchors_x[j]; + svint64_t rj = svld1_s64(valid_elements,(int64_t*)&anchors_x[real_j]); + //const int32_t qj = static_cast(anchors_y[j]); + svint64_t qj = svld1_s64(valid_elements,(int64_t*)&anchors_y[real_j]); + qj = svextw_s64_x(valid_elements,qj); + + //const int64_t dr = ri - rj; + svint64_t dr = svsub_s64_x(valid_elements,ri,rj); + //const int32_t dq = qi - qj; + svint64_t dq = svsub_s64_x(valid_elements,qi,qj); + + //const int32_t dd = std::abs(dr - dq); + svint64_t dd = svabd_s64_x(valid_elements,dr,dq); + + //if ((dr == 0 || dq <= 0) || + // (dq > max_dist_y || dq > max_dist_x) || + // (dd > bw)) { + // continue; + //} + svbool_t skip_anchor = svcmpeq_n_s64(valid_elements,dr,0); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmple_n_s64(valid_elements,dq,0); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmpgt_n_s64(valid_elements,dq,max_dist_y); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmpgt_n_s64(valid_elements,dq,max_dist_x); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + skip_anchor = svcmpgt_n_s64(valid_elements,dd,bw); + valid_elements = svbic_b_z(valid_elements,valid_elements,skip_anchor); + + if (!svptest_any(ptrue,valid_elements)) continue; + + //const int64_t dr_dq_min = (dr < dq) ? dr : dq; + svint64_t dr_dq_min = svmin_s64_x(valid_elements,dr,dq); + //const int32_t oc = (dr_dq_min < q_spani) ? dr_dq_min : q_spani; + svint64_t oc = svmin_n_s64_x(valid_elements,dr_dq_min,q_spani); + + //const int32_t log_dd = (dd) ? ilog2_32(dd) : 0; + svbool_t valid_log = svcmpne_n_s64(valid_elements,dd,0); + svint64_t log_dd = svreinterpret_s64(svclz_s64_x(valid_elements,dd)); + log_dd = svsubr_n_s64_z(valid_log,log_dd,63); + + //const int32_t gap_cost = static_cast(dd * 0.01f * avg_qspan) + (log_dd >> 1); + /* + svint64_t gap_cost = svmul_n_s64_x(valid_elements,dd,avg_qspan001); + log_dd = svlsr_n_s64_x(valid_elements,log_dd,1); + gap_cost = svadd_s64_x(valid_elements,gap_cost,log_dd); + */ + svfloat64_t gap_cost_f = svcvt_f64_s64_x(valid_elements,dd); + //gap_cost_f = svmul_n_f64_x(valid_elements,gap_cost_f,0.01f); + gap_cost_f = svmul_n_f64_x(valid_elements,gap_cost_f,avg_qspan001); + svint64_t gap_cost = svcvt_s64_f64_x(valid_elements,gap_cost_f); + log_dd = svlsr_n_s64_x(valid_elements,log_dd,1); + gap_cost = svadd_s64_x(valid_elements,gap_cost,log_dd); + + //const int32_t score = ret->scores[j] + oc - gap_cost; + svint64_t score = svld1sw_s64(valid_elements,&ret->scores[real_j]); + score = svadd_s64_x(valid_elements,score,oc); + score = svsub_s64_x(valid_elements,score,gap_cost); + + // TODO: CAN'T VECTORIZE THIS. THE COMPILER IS ONLY ABLE TO PERFORM + // ONE REDUCTION. + // Ideas: + // 1. j to int32_t and: uint64_t max = (j << 32) & score. + // max = ((int32_t)(max & 0xffffffff) > score) ? max = (j << 32) & score : max; + // Reduction + + // TODO: reduction on vector register, no scalar register + //if (score > max_f) { + // max_f = score; + // max_j = j; + //} + + int32_t max_local = svmaxv_s64(valid_elements,score); + if (max_local > max_f) { + max_f = max_local; + // WARNING + svint64_t index = svindex_s64(real_j,1); + svbool_t max_index = svcmpeq_n_s64(valid_elements,score,max_local); + max_j = svlastb_s64(max_index,index); + } + } + ret->scores[i] = max_f; + ret->parents[i] = max_j; + //if (max_f == 36 && max_j == 18821) {printf("ee\n");exit(0);} + ret->peak_scores[i] = max_j >= 0 && ret->peak_scores[max_j] > max_f ? ret->peak_scores[max_j] : max_f; + } +#else // SCALAR VERSION + printf("Executing SCALAR version"); + for (int64_t i = 0; i < n; ++i) { + const auto ri = anchors_x[i]; + const int32_t qi = static_cast(anchors_y[i]); + const int32_t q_spani = qspans[i]; + + int64_t max_j = -1; + int32_t max_f = q_spani; + + while (st < i && ri > anchors_x[st] + max_dist_x) { + ++st; //predecessor's position is too far + } + if (i - st > max_iter) { + st = i - max_iter; //predecessor's index is too far + } + + // TODO: Iterate forward to vectorize the loop. + // for (int64_t j_inv = st; j_inv < i; ++j_inv) { + // const int64_t j = (i - 1) - j_inv + st; + for (int64_t j = i - 1; j >= st; --j) { + const auto rj = anchors_x[j]; + const int32_t qj = static_cast(anchors_y[j]); + + const int64_t dr = ri - rj; + const int32_t dq = qi - qj; + + const int32_t dd = std::abs(dr - dq); + + if ((dr == 0 || dq <= 0) || + (dq > max_dist_y || dq > max_dist_x) || + (dd > bw)) { + continue; + } + // Can not vectorize "continue". Use a mask instead. + // const bool skip = ((dr == 0 || dq <= 0) || + // (dq > max_dist_y || dq > max_dist_x) || + // (dd > bw)); + + const int64_t dr_dq_min = (dr < dq) ? dr : dq; + const int32_t oc = (dr_dq_min < q_spani) ? dr_dq_min : q_spani; + + // TODO: CAN'T VECTORIZE __builtin_clz + const int32_t log_dd = (dd) ? ilog2_32(dd) : 0; + const int32_t gap_cost = static_cast(dd * 0.01f * avg_qspan) + (log_dd >> 1); + + const int32_t score = ret->scores[j] + oc - gap_cost; + + // TODO: CAN'T VECTORIZE THIS. THE COMPILER IS ONLY ABLE TO PERFORM + // ONE REDUCTION. + // Ideas: + // 1. j to int32_t and: uint64_t max = (j << 32) & score. + // max = ((int32_t)(max & 0xffffffff) > score) ? max = (j << 32) & score : max; + // Reduction + if (score > max_f) { + max_f = score; + max_j = j; + } + } + ret->scores[i] = max_f; + ret->parents[i] = max_j; + ret->peak_scores[i] = max_j >= 0 && ret->peak_scores[max_j] > max_f ? ret->peak_scores[max_j] : max_f; + } +#endif + +void host_chain_kernel(std::vector &args, std::vector &rets, int numThreads) { + #pragma omp parallel num_threads(numThreads) + { + #pragma omp for schedule(dynamic) + for (size_t batch = 0; batch < args.size(); batch++) { + call_t *arg = &args[batch]; + return_t *ret = &rets[batch]; + // fprintf(stderr, "%lld\t%f\t%d\t%d\t%d\t%d\n", arg->n, arg->avg_qspan, arg->max_dist_x, arg->max_dist_y, arg->bw, arg->n_segs); + chain_dp(arg, ret); + } + } +} diff --git a/benchmarks/fast-chain/src/host_kernel.h b/benchmarks/fast-chain/src/host_kernel.h new file mode 100644 index 0000000..1f011d4 --- /dev/null +++ b/benchmarks/fast-chain/src/host_kernel.h @@ -0,0 +1,8 @@ +#ifndef HOST_KERNEL_H +#define HOST_KERNEL_H + +#include "host_data.h" + +void host_chain_kernel(std::vector &arg, std::vector &ret, int numThreads); + +#endif // HOST_KERNEL_H diff --git a/benchmarks/fast-chain/src/main.cpp b/benchmarks/fast-chain/src/main.cpp new file mode 100644 index 0000000..f1103d4 --- /dev/null +++ b/benchmarks/fast-chain/src/main.cpp @@ -0,0 +1,116 @@ +#include +#include +#include +#include +#include +#include +#include +#include "omp.h" +#include "host_data_io.h" +#include "host_data.h" +#include "host_kernel.h" + +// #define PRINT_OUTPUT 1 + +// #define VTUNE_ANALYSIS 1 + +#if VTUNE_ANALYSIS + #include +#endif + +void help() { + std::cout << + "\n" + "usage: ./chain [options ...]\n" + "\n" + " options:\n" + " -i \n" + " default: NULL\n" + " the input anchor set\n" + " -o \n" + " default: NULL\n" + " the output scores, best predecessor set\n" + " -t \n" + " default: 1\n" + " number of CPU threads\n" + " -h \n" + " prints the usage\n"; +} + + +int main(int argc, char **argv) { +#if VTUNE_ANALYSIS + __itt_pause(); +#endif + FILE *in, *out; + std::string inputFileName, outputFileName; + + int opt, numThreads = 1; + while ((opt = getopt(argc, argv, ":i:o:t:h")) != -1) { + switch (opt) { + case 'i': inputFileName = optarg; break; + case 'o': outputFileName = optarg; break; + case 't': numThreads = atoi(optarg); break; + case 'h': help(); return 0; + default: help(); return 1; + } + } + + if (argc == 1 || argc != optind) { + help(); + exit(EXIT_FAILURE); + } + + fprintf(stderr, "Input file: %s\n", inputFileName.c_str()); + fprintf(stderr, "Output file: %s\n", outputFileName.c_str()); + + in = fopen(inputFileName.c_str(), "r"); + out = fopen(outputFileName.c_str(), "w"); + + std::vector calls; + std::vector rets; + + for (call_t call = read_call(in); + call.n != ANCHOR_NULL; + call = read_call(in)) { + calls.push_back(call); + } + + rets.resize(calls.size()); + +#pragma omp parallel num_threads(numThreads) +{ + int tid = omp_get_thread_num(); + if (tid == 0) { + fprintf(stderr, "Running with threads: %d\n", numThreads); + } +} + + struct timeval start_time, end_time; + double runtime = 0; + + gettimeofday(&start_time, NULL); +#if VTUNE_ANALYSIS + __itt_resume(); +#endif + host_chain_kernel(calls, rets, numThreads); +#if VTUNE_ANALYSIS + __itt_pause(); +#endif + gettimeofday(&end_time, NULL); + + runtime += (end_time.tv_sec - start_time.tv_sec) * 1e6 + (end_time.tv_usec - start_time.tv_usec); + +#if PRINT_OUTPUT + for (auto it = rets.begin(); it != rets.end(); it++) { + print_return(out, *it); + } +#endif + + fprintf(stderr, "Time in kernel: %.2f sec\n", runtime * 1e-6); + + fclose(in); + fclose(out); + + return 0; +} diff --git a/benchmarks/kmer-cnt/kmer_cnt.cpp b/benchmarks/kmer-cnt/kmer_cnt.cpp index 1f48be2..dda7cf2 100644 --- a/benchmarks/kmer-cnt/kmer_cnt.cpp +++ b/benchmarks/kmer-cnt/kmer_cnt.cpp @@ -29,7 +29,8 @@ bool parseArgs(int argc, char** argv, std::string& readsFasta, std::string& logFile, int& kmerSize, bool& debug, size_t& numThreads, int& minOverlap, - std::string& configPath, int& minReadLength, bool& unevenCov) + std::string& configPath, int& minReadLength, bool& unevenCov, + bool& highmem) { auto printUsage = []() { @@ -49,7 +50,9 @@ bool parseArgs(int argc, char** argv, std::string& readsFasta, << " --log log_file\toutput log to file " << "[default = not set] \n" << " --threads num_threads\tnumber of parallel threads " - << "[default = 1] \n"; + << "[default = 1] \n" + << " --highmem \t\tuse more memory for faster kmer counting " + << "[default = false] \n"; }; int optionIndex = 0; @@ -63,6 +66,7 @@ bool parseArgs(int argc, char** argv, std::string& readsFasta, {"kmer", required_argument, 0, 0}, {"min-ovlp", required_argument, 0, 0}, {"debug", no_argument, 0, 0}, + {"highmem", no_argument, 0, 0}, {0, 0, 0, 0} }; @@ -90,6 +94,8 @@ bool parseArgs(int argc, char** argv, std::string& readsFasta, readsFasta = optarg; else if (!strcmp(longOptions[optionIndex].name, "config")) configPath = optarg; + else if (!strcmp(longOptions[optionIndex].name, "highmem")) + highmem = true; break; case 'h': @@ -145,10 +151,11 @@ int main(int argc, char** argv) std::string readsFasta; std::string logFile; std::string configPath; + bool highmem = false; if (!parseArgs(argc, argv, readsFasta, logFile, kmerSize, debugging, numThreads, minOverlap, configPath, - minReadLength, unevenCov)) return 1; + minReadLength, unevenCov, highmem)) return 1; Logger::get().setDebugging(debugging); if (!logFile.empty()) Logger::get().setOutputFile(logFile); @@ -198,8 +205,9 @@ int main(int argc, char** argv) return 1; } readsContainer.buildPositionIndex(); - VertexIndex vertexIndex(readsContainer, - (int)Config::get("assemble_kmer_sample")); + VertexIndex vertexIndex(readsContainer, + (int)Config::get("assemble_kmer_sample"), + highmem); vertexIndex.outputProgress(true); /*int64_t sumLength = 0; diff --git a/benchmarks/kmer-cnt/parallel.h b/benchmarks/kmer-cnt/parallel.h index a06f52c..7d139fe 100644 --- a/benchmarks/kmer-cnt/parallel.h +++ b/benchmarks/kmer-cnt/parallel.h @@ -22,10 +22,11 @@ void processInParallel(const std::vector& scheduledTasks, ProgressPercent progress(scheduledTasks.size()); if (progressBar) progress.advance(0); - auto threadWorker = [&jobId, &scheduledTasks, &updateFun, - &progress, progressBar]() + #pragma omp parallel for + for (size_t i = 0; i < std::min(maxThreads, scheduledTasks.size()); ++i) { - while (true) + bool finished = false; + while (!finished) { size_t expected = 0; while(true) @@ -33,27 +34,19 @@ void processInParallel(const std::vector& scheduledTasks, expected = jobId; if (jobId == scheduledTasks.size()) { - return; + finished = true; + break; } if (jobId.compare_exchange_weak(expected, expected + 1)) { break; } } - updateFun(scheduledTasks[expected]); - if (progressBar) progress.advance(); + if (!finished) { + updateFun(scheduledTasks[expected]); + if (progressBar) progress.advance(); + } } - }; - - std::vector threads(std::min(maxThreads, - scheduledTasks.size())); - for (size_t i = 0; i < threads.size(); ++i) - { - threads[i] = std::thread(threadWorker); - } - for (size_t i = 0; i < threads.size(); ++i) - { - threads[i].join(); } } diff --git a/benchmarks/kmer-cnt/vertex_index.cpp b/benchmarks/kmer-cnt/vertex_index.cpp index 29159fb..ac70793 100644 --- a/benchmarks/kmer-cnt/vertex_index.cpp +++ b/benchmarks/kmer-cnt/vertex_index.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "vertex_index.h" #include "logger.h" @@ -15,7 +16,6 @@ #include "config.h" #include "memory_info.h" - void VertexIndex::countKmers() { _kmerCounter.count(/*use flat counter*/ true); @@ -509,21 +509,24 @@ void VertexIndex::clear() //_kmerCounts.reserve(0); } - void KmerCounter::count(bool useFlatCounter) { //Logger::get().debug() << "Before counter: " // << getPeakRSS() / 1024 / 1024 / 1024 << " Gb"; + _numKmers = 0; + if (useFlatCounter && Parameters::get().kmerSize > 17) { throw std::runtime_error("Can't use flat counter for k-mer size > 17"); } _useFlatCounter = useFlatCounter; - //flat array for all possible k-mers, 4 bits for each - //in case of k=17, takes 8Gb - static const size_t COUNTER_LEN = std::pow(4, Parameters::get().kmerSize) / 2; + //flat array for all possible k-mers, 8 bits for each + //in case of k=17, takes 16Gb + + static const size_t COUNTER_LEN = _highmem ? std::pow(4, Parameters::get().kmerSize) : + std::pow(4, Parameters::get().kmerSize) / 2; if (useFlatCounter) { _flatCounter = new std::atomic[COUNTER_LEN]; @@ -535,31 +538,52 @@ void KmerCounter::count(bool useFlatCounter) [this] (const FastaRecord::Id& readId) { if (!readId.strand()) return; + + size_t pnumKmers = 0; for (auto kmerPos : IterKmers(_seqContainer.getSeq(readId))) { kmerPos.kmer.standardForm(); bool addOne = true; + if (_useFlatCounter) { - size_t arrayPos = kmerPos.kmer.numRepr() / 2; - bool highBits = kmerPos.kmer.numRepr() % 2; - - while (true) + if (_highmem) { - uint8_t expected = _flatCounter[arrayPos]; - uint8_t count = highBits ? (expected >> 4) : (expected & 15); - if (count == 15) + addOne = false; + + size_t arrayPos = kmerPos.kmer.numRepr(); + uint8_t old = _flatCounter[arrayPos].fetch_add(1, std::memory_order_relaxed); + + if (old == 0 && !_hashCounter.contains(kmerPos.kmer)) + { + ++pnumKmers; + } + else if (old == 255) { - break; + addOne = true; } + } + else { + size_t arrayPos = kmerPos.kmer.numRepr() / 2; + bool highBits = kmerPos.kmer.numRepr() % 2; - uint8_t updated = highBits ? (expected + 16) : (expected + 1); - if (_flatCounter[arrayPos].compare_exchange_weak(expected, updated)) + while (true) { - if (count == 0) ++_numKmers; - addOne = false; //not saturated yet, don't update hash counter - break; + uint8_t expected = _flatCounter[arrayPos]; + uint8_t count = highBits ? (expected >> 4) : (expected & 15); + if (count == 15) + { + break; + } + + uint8_t updated = highBits ? (expected + 16) : (expected + 1); + if (_flatCounter[arrayPos].compare_exchange_weak(expected, updated)) + { + if (count == 0) ++pnumKmers; + addOne = false; //not saturated yet, don't update hash counter + break; + } } } } @@ -569,6 +593,8 @@ void KmerCounter::count(bool useFlatCounter) _hashCounter.upsert(kmerPos.kmer, [](size_t& num){++num;}, 1); } } + + _numKmers += pnumKmers; }; std::vector allReads; for (const auto& seq : _seqContainer.iterSeqs()) @@ -583,58 +609,48 @@ void KmerCounter::count(bool useFlatCounter) else { processInParallel(allReads, readUpdate, Parameters::get().numThreads, _outputProgress); } - /* - Logger::get().debug() << "Updating k-mer histogram"; - if (_useFlatCounter) - { - for (size_t kmerId = 0; kmerId < COUNTER_LEN * 2; ++kmerId) - { - Kmer kmer(kmerId); - size_t freq = this->getFreq(kmer); - if (freq > 0) _kmerDistribution[freq] += 1; - // if (kmerId % 1000000 == 0) Logger::get().debug() << kmerId << " " << freq; - } - } - else - { - for (const auto& kmer : _hashCounter.lock_table()) - { - _kmerDistribution[kmer.second] += 1; - } - } - */ - //Logger::get().debug() << "After counter: " - // << getPeakRSS() / 1024 / 1024 / 1024 << " Gb"; + if (!_useFlatCounter) { + _numKmers = _hashCounter.size(); + } Logger::get().debug() << "Hash size: " << _hashCounter.size(); Logger::get().debug() << "Total k-mers " << _numKmers; } - size_t KmerCounter::getFreq(Kmer kmer) const { - //kmer.standardForm(); + size_t freq; + if (!_hashCounter.find(kmer, freq)) + { + freq = 0; + } - size_t addCount = 0; if (_useFlatCounter) { - size_t arrayPos = kmer.numRepr() / 2; - bool highBits = kmer.numRepr() % 2; - uint8_t count = highBits ? (_flatCounter[arrayPos]) >> 4 : (_flatCounter[arrayPos] & 15); - if (count < 15) + if (_highmem) { - return count; + size_t arrayPos = kmer.numRepr(); + freq = freq * 255 + _flatCounter[arrayPos]; } else { - addCount = count; + size_t arrayPos = kmer.numRepr() / 2; + bool highBits = kmer.numRepr() % 2; + uint8_t count = highBits ? (_flatCounter[arrayPos]) >> 4 : (_flatCounter[arrayPos] & 15); + + if (count < 15) + { + freq = count; + } + else + { + freq += count; + } } } - size_t freq = 0; - _hashCounter.find(kmer, freq); - return freq + addCount; + return freq; } void KmerCounter::clear() @@ -650,6 +666,5 @@ void KmerCounter::clear() size_t KmerCounter::getKmerNum() const { - if (!_useFlatCounter) return _hashCounter.size(); return _numKmers; } diff --git a/benchmarks/kmer-cnt/vertex_index.h b/benchmarks/kmer-cnt/vertex_index.h index 1209308..ec0d324 100644 --- a/benchmarks/kmer-cnt/vertex_index.h +++ b/benchmarks/kmer-cnt/vertex_index.h @@ -25,8 +25,8 @@ typedef std::map KmerDistribution; class KmerCounter { public: - KmerCounter(const SequenceContainer& seqContainer): - _seqContainer(seqContainer), + KmerCounter(const SequenceContainer& seqContainer, bool highmem): + _seqContainer(seqContainer), _highmem(highmem), _flatCounter(nullptr), _numKmers(0) {} @@ -54,6 +54,7 @@ class KmerCounter const SequenceContainer& _seqContainer; bool _outputProgress; bool _useFlatCounter; + bool _highmem; std::atomic* _flatCounter; //std::vector> _flatCounter; @@ -70,10 +71,10 @@ class VertexIndex { this->clear(); } - VertexIndex(const SequenceContainer& seqContainer, float sampleRate): + VertexIndex(const SequenceContainer& seqContainer, float sampleRate, bool highmem): _seqContainer(seqContainer), _outputProgress(false), _sampleRate(sampleRate), _repetitiveFrequency(0), - _kmerCounter(seqContainer) + _kmerCounter(seqContainer, highmem) //_solidMultiplier(1) //_flankRepeatSize(flankRepeatSize) {} diff --git a/benchmarks/nn-variant/Clair3/Dockerfile b/benchmarks/nn-variant/Clair3/Dockerfile new file mode 100644 index 0000000..c3e2447 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/Dockerfile @@ -0,0 +1,51 @@ +FROM ubuntu:16.04 + +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 PATH=/opt/bin:/opt/conda/bin:$PATH + +# update ubuntu packages +RUN apt-get update --fix-missing && \ + yes|apt-get upgrade && \ + apt-get install -y \ + wget \ + bzip2 \ + make \ + g++ \ + libboost-graph-dev && \ + rm -rf /bar/lib/apt/lists/* + +WORKDIR /opt/bin + +# install anaconda +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ + bash Miniconda3-latest-Linux-x86_64.sh -b -p /opt/conda && \ + rm Miniconda3-latest-Linux-x86_64.sh && \ + conda config --add channels defaults && \ + conda config --add channels bioconda && \ + conda config --add channels conda-forge && \ + conda create -n clair3 python=3.6.10 -y + +ENV PATH /opt/conda/envs/clair3/bin:$PATH +ENV CONDA_DEFAULT_ENV clair3 + +RUN /bin/bash -c "source activate clair3" && \ + conda install -c conda-forge pypy3.6 -y && \ + pypy3 -m ensurepip && \ + pypy3 -m pip install mpmath==1.2.1 && \ + pip install tensorflow-cpu==2.2.0 && \ + pip install tensorflow-addons==0.11.2 tables==3.6.1 && \ + conda install -c anaconda pigz==2.4 -y && \ + conda install -c conda-forge parallel=20191122 zstd=1.4.4 -y && \ + conda install -c conda-forge -c bioconda samtools=1.10 -y && \ + conda install -c conda-forge -c bioconda whatshap=1.0 -y && \ + rm -rf /opt/conda/pkgs/* && \ + rm -rf /root/.cache/pip + +COPY . . + +RUN cd /opt/bin/preprocess/realign && \ + g++ -std=c++14 -O1 -shared -fPIC -o realigner ssw_cpp.cpp ssw.c realigner.cpp && \ + g++ -std=c++11 -shared -fPIC -o debruijn_graph -O3 debruijn_graph.cpp && \ + wget http://www.bio8.cs.hku.hk/clair3/clair3_models/clair3_models.tar.gz -P /opt/models && \ + tar -zxvf /opt/models/clair3_models.tar.gz -C /opt/models && \ + rm /opt/models/clair3_models.tar.gz && \ + echo "source activate clair3" > ~/.bashrc \ No newline at end of file diff --git a/benchmarks/nn-variant/Clair3/LICENSE.md b/benchmarks/nn-variant/Clair3/LICENSE.md new file mode 100644 index 0000000..0c41a5a --- /dev/null +++ b/benchmarks/nn-variant/Clair3/LICENSE.md @@ -0,0 +1,26 @@ +Copyright 2021 The University of Hong Kong, Department of Computer Science + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/benchmarks/nn-variant/Clair3/README.md b/benchmarks/nn-variant/Clair3/README.md new file mode 100644 index 0000000..e27da3b --- /dev/null +++ b/benchmarks/nn-variant/Clair3/README.md @@ -0,0 +1,542 @@ + + +# Clair3 - Symphonizing pileup and full-alignment for high-performance long-read variant calling + +[![License](https://img.shields.io/badge/License-BSD%203--Clause-blue.svg)](https://opensource.org/licenses/BSD-3-Clause) [![install with bioconda](https://img.shields.io/badge/install%20with-bioconda-brightgreen.svg?style=flat)](http://bioconda.github.io/recipes/clair3/README.html) + +Contact: Ruibang Luo, Zhenxian Zheng + +Email: rbluo@cs.hku.hk, zxzheng@cs.hku.hk + +---- + +## Introduction + +Clair3 is a germline small variant caller for long-reads. Clair3 makes the best of two major method categories: pileup calling handles most variant candidates with speed, and full-alignment tackles complicated candidates to maximize precision and recall. Clair3 runs fast and has superior performance, especially at lower coverage. Clair3 is simple and modular for easy deployment and integration. + +Clair3 is the 3rd generation of [Clair](https://github.com/HKU-BAL/Clair) (the 2nd) and [Clairvoyante](https://github.com/aquaskyline/Clairvoyante) (the 1st). + +A short preprint describing Clair3's algorithms and results is at [bioRxiv](https://www.biorxiv.org/content/10.1101/2021.12.29.474431v1). + +---- + +## Contents + +* [Introduction](#introduction) +* [Latest Updates](#latest-updates) +* [Pre-trained Models](#pre-trained-models) + * [Guppy5,6 Model](docs/guppy5_20220113.md) + * [R10.4 with the Kit 12 chemistry (Q20) Models](#ont-provided-models) + * [Guppy3,4 Model](#pre-trained-models) + * [Guppy2 Model](docs/guppy2.md) +* [What's New in Clair3](#whats-new-in-clair3) +* [Installation](#installation) + + [Option 1. Docker pre-built image](#option-1--docker-pre-built-image) + + [Option 2. Singularity](#option-2-singularity) + + [Option 3. Bioconda](#option-3--bioconda) + + [Option 4. Build an anaconda virtual environment](#option-4-build-an-anaconda-virtual-environment) + + [Option 5. Docker Dockerfile](#option-5-docker-dockerfile) +* [Quick Demo](#quick-demo) +* [Usage](#usage) +* [Folder Structure and Submodule Descriptions](#folder-structure-and-submodule-descriptions) +* [Training Data](docs/training_data.md) +* [VCF/GVCF Output Formats](#vcfgvcf-output-formats) +* [Pileup Model Training](docs/pileup_training.md) +* [Full-Alignment Model Training](docs/full_alignment_training_r1.md) +* [Representation Unification](docs/representation_unification.md) +* [Visualization](docs) + * [Model Input](docs/model_input_visualization.md) + * [Representation Unification](docs/representation_unification_visualization.md) + +---- + +## Latest Updates + +*v0.1-r10 (Jan 13)* : 1. Added a new ONT Guppy5 model (`r941_prom_sup_g5014`). Click [here](docs/guppy5_20220113.md) for some benchmarking results. This `sup` model is also applicable to reads called using the `hac` and `fast` mode. The old `r941_prom_sup_g506` model that was fine-tuned from the Guppy3,4 model is obsoleted. 2. Added `--var_pct_phasing` option to control the percentage of top ranked heterozygous pile-up variants used for WhatsHap phasing. + +*v0.1-r9 (Dec 1)* : Added the `--enable_long_indel` option to output indel variant calls >50bp ([#64](https://github.com/HKU-BAL/Clair3/issues/64)), Click [here](https://github.com/HKU-BAL/Clair3/blob/main/docs/indel_gt50_performance.md) to see more benchmarking results. + +*v0.1-r8 (Nov 11)* : 1. Added the `--enable_phasing` option that adds a step after Clair3 calling to output variants phased by WhatsHap ([#63](https://github.com/HKU-BAL/Clair3/issues/63)). 2. Fixed unexpected program termination on successful runs. + +*v0.1-r7 (Oct 18)* : 1. Increased `var_pct_full` in ONT mode from 0.3 to 0.7. Indel F1-score increased ~0.2%, but took ~30 minutes longer to finish calling a ~50x ONT dataset. 2. Expand fall through to next most likely variant if network prediction has insufficient read coverage ([#53](https://github.com/HKU-BAL/Clair3/pull/53) commit 09a7d185, contributor @[ftostevin-ont](https://github.com/ftostevin-ont)), accuracy improved on complex Indels. 3. Streamized pileup and full-alignment training workflows. Reduce diskspace demand in model training ([#55](https://github.com/HKU-BAL/Clair3/pull/55) commit 09a7d185, contributor @[ftostevin-ont](https://github.com/ftostevin-ont)). 4. Added `mini_epochs` option in Train.py, performance slightly improved in training a model for ONT Q20 data using mini-epochs([#60](https://github.com/HKU-BAL/Clair3/pull/60), contributor @[ftostevin-ont](https://github.com/ftostevin-ont)). 5. Massively reduced disk space demand when outputting GVCF. Now compressing GVCF intermediate files with lz4, five times smaller with little speed penalty. 6. Added `--remove_intermediate_dir`to remove intermediate files as soon as no longer needed ([#48](https://github.com/HKU-BAL/Clair3/issues/48)). 7. Renamed ONT pre-trained models with [Medaka](https://github.com/nanoporetech/medaka/blob/master/medaka/options.py#L22)'s naming convention. 8. Fixed training data spilling over to validation data ([#57](https://github.com/HKU-BAL/Clair3/issues/57)). + +*ONT-provided Models (Sep 23)*: ONT also provides Clair3 models for specific chemistries and basecallers through [Rerio](https://github.com/nanoporetech/rerio). + +*v0.1-r6 (Sep 4)* : 1. Reduced memory footprint at the `SortVcf` stage([#45](https://github.com/HKU-BAL/Clair3/issues/45)). 2. Reduced `ulimit -n` (number of files simultaneously opened) requirement ([#45](https://github.com/HKU-BAL/Clair3/issues/45), [#47](https://github.com/HKU-BAL/Clair3/issues/47)). 3. Added Clair3-Illumina package in bioconda([#42](https://github.com/HKU-BAL/Clair3/issues/42)). + +*v0.1-r5 (July 19)* : 1. Modified data generator in model training to avoid memory exhaustion and unexpected segmentation fault by Tensorflow (contributor @[ftostevin-ont](https://github.com/ftostevin-ont) ). 2. Simplified dockerfile workflow to reuse container caching (contributor @[amblina](https://github.com/amblina)). 3. Fixed ALT output for reference calls (contributor @[wdecoster](https://github.com/wdecoster)). 4. Fixed a bug in multi-allelic AF computation (AF of [ACGT]Del variants was wrong before r5). 5. Added AD tag to the GVCF output. 6. Added the `--call_snp_only` option to only call SNP only ([#40](https://github.com/HKU-BAL/Clair3/issues/40)). 7. Added pileup and full-alignment output validity check to avoid workflow crashing ([#32](https://github.com/HKU-BAL/Clair3/issues/32), [#38](https://github.com/HKU-BAL/Clair3/issues/38)). + +*v0.1-r4 (June 28)* : 1. Install via [bioconda](https://github.com/HKU-BAL/Clair3#option-3--bioconda). 2. Added an ONT Guppy2 model to the images (`ont_guppy2`). Click [here](https://github.com/HKU-BAL/Clair3/blob/main/docs/guppy2.md) for more benchmarking results. **The results show you have to use the Guppy2 model for Guppy2 or earlier data**. 3. Added [google colab notebooks](https://github.com/HKU-BAL/Clair3/blob/main/colab) for quick demo. 4. Fixed a bug when there are too few variant candidates ([#28](https://github.com/HKU-BAL/Clair3/issues/28)). + +*v0.1-r3 (June 9)* : 1. Added `ulimit -u` (max user processes) check (lowers the `THREADS` if the resource is insufficient) and automatic retries on failed jobs ([#20](https://github.com/HKU-BAL/Clair3/issues/20), [#23](https://github.com/HKU-BAL/Clair3/issues/23), [#24](https://github.com/HKU-BAL/Clair3/issues/24)). 2. Added an ONT Guppy5 model to the images (`ont_guppy5`). Click [here](docs/guppy5.md) for more benchmarks on the Guppy5 model and data. + +*v0.1-r2 (May 23)* : 1. Fixed BED file out of range error ([#12](https://github.com/HKU-BAL/Clair3/issues/12)). 2. Added support for both `.bam.bai` and `.bai` BAM index filename ([#10](https://github.com/HKU-BAL/Clair3/issues/10)). 3. Added some boundary checks on inputs. 4. Added version checks on required packages and utilities. 5. Increased pipeline robusity. + +*v0.1-r1 (May 18)* : 1. Support relative path in Conda, but Docker and Singularity still require absolute path ([#5](https://github.com/HKU-BAL/Clair3/issues/5)). 2. Fix `taskset` CPU-core visibility and provide a Singularity image ([#6](https://github.com/HKU-BAL/Clair3/issues/6)). + +*v0.1 (May 17)*: Initial release. + +--- + +## Pre-trained Models + +### HKU-provided Models + +Download models from [here](http://www.bio8.cs.hku.hk/clair3/clair3_models/) or click on the links below. + +In a docker installation, models are in `/opt/models/`. In a bioconda installation, models are in `{CONDA_PREFIX}/bin/models/`. + +| Model name | Platform | Training samples | Included in the bioconda package | Included in the docker image | Date | Basecaller | File | Link | +| :----------------------------: | :---------: | :----------------------------------------------------------: | -------------------------------- | :--------------------------: | :------: | :----------: | ----------------------------------- | :----------------------------------------------------------: | +| r941_prom_sup_g5014 | ONT | HG002,4,5 (Guppy5_sup) | Yes | Yes | 20220112 | Guppy5 sup | r941_prom_sup_g5014.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_sup_g5014.tar.gz) | +| r941_prom_hac_g360+g422 | ONT | HG001,2,4,5 | Yes | Yes | 20210517 | Guppy3,4 hac | r941_prom_hac_g360+g422.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g360+g422.tar.gz) | +| r941_prom_hac_g360+g422_1235 | ONT | HG001,2,3,5 | | | 20210517 | Guppy3,4 hac | r941_prom_hac_g360+g422_1235.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g360+g422_1235.tar.gz) | +| r941_prom_hac_g238 | ONT | HG001,2,3,4 | | Yes | 20210627 | Guppy2 | r941_prom_hac_g238.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_hac_g238.tar.gz) | +| ~~r941_prom_sup_g506~~ | ONT | Base model: HG001,2,4,5 (Guppy3,4)
Fine-tuning data: HG002 (Guppy5_sup) | | | 20210609 | Guppy5 sup | r941_prom_sup_g506.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/r941_prom_sup_g506.tar.gz) | +| hifi | PacBio HiFi | HG001,2,4,5 | Yes | Yes | 20210517 | NA | hifi.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/hifi.tar.gz) | +| ilmn | Illumina | HG001,2,4,5 | Yes | Yes | 20210517 | NA | ilmn.tar.gz | [Download](http://www.bio8.cs.hku.hk/clair3/clair3_models/ilmn.tar.gz) | + +### ONT-provided Models + +ONT provides models for some latest or specific chemistries and basecallers through [Rerio](https://github.com/nanoporetech/rerio). These models are tested and supported by the ONT developers. Avaiable model in Rerio including: + +| Config | Chemistry | Guppy basecaller | +| :----------------: | :-------: | :--------------: | +| r104_e81_sup_g5015 | R10.4 E8.1 | v5.0.15 SUP | +| r104_e81_hac_g5015 | R10.4 E8.1 | v5.0.15 HAC | + +---- + +## What's New in Clair3 + +* **New Architecture.** Clair3 integrates both pileup (summarized alignment statistics) model and full-alignment model for variant calling. While a pileup model determines the result of a majority of variant candidates, candidates with uncertain results are further processed with a more computational-intensive haplotype-resolved full-alignment model. +* **Improved Performance.** Using HG003 85-fold coverage ONT data from PrecisionFDA for benchmarking, Clair3 achieved 99.69% SNP F1-score and 80.58% Indel F1-score. Compare to Clair, Clair3 reduced SNP errors by **~78%**, and Indel errors by **~48%**. +* **High Efficiency.** Using 36 CPU cores, + * Clair3 takes ~8 hours to process 50-fold WGS ONT data (~4x faster than PEPPER (r0.4) and ~14x faster than Medaka (v1.3.2)). Memory consumption of Clair3 is capped at 1 GB per CPU thread, which is roughly five times lower than Clair. + * Clair3 takes ~2 hours to process 35-fold WGS PacBio HiFi data (13x faster than DeepVariant (v1.1.0)). +* **Using data from newer basecallers.** Clair3 models were trained using data from Guppy version 3.6.0 and 4.2.2, please check [Training Data](docs/training_data.md) for details and links. +* **GVCF Support.** Clair3 can output GVCF using the ```--gvcf``` option, enabling downstream joint-sample genotyping and cohort merging. + +---- + +## Quick Demo + +* Oxford Nanopore (ONT) data, see [ONT Quick Demo](docs/quick_demo/ont_quick_demo.md). +* PacBio HiFi data, see [PaBio HiFi Quick Demo](docs/quick_demo/pacbio_hifi_quick_demo.md). +* Illumina NGS data, see [Illumina Quick Demo](docs/quick_demo/illumina_quick_demo.md). + +**Run Clair3 ONT quick demo**: + +- **(Option 1) using Google Colab notebook:** + + [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HKU-BAL/Clair3/blob/main/colab/clair3_ont_quick_demo.ipynb) + +- **(Option 2) using pre-built docker image:** + +```bash +cd ${HOME} +wget "http://www.bio8.cs.hku.hk/clair3/demo/clair3_ont_quick_demo.sh" +chmod +x clair3_ont_quick_demo.sh +./clair3_ont_quick_demo.sh +``` + +Check the results using `less ${HOME}/clair3_ont_quickDemo/output/merge_output.vcf.gz` + +---- + +## Installation + +### Option 1. Docker pre-built image + +A pre-built docker image is available [here](https://hub.docker.com/r/hkubal/clair3). With it you can run Clair3 using a single command. + +**Caution**: Absolute path is needed for both `INPUT_DIR` and `OUTPUT_DIR`. + +```bash +INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) +THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here + --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path="/opt/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} ## absolute output path prefix +``` + +Check [Usage](#Usage) for more options. + +### Option 2. Singularity + +**Caution**: Absolute path is needed for both `INPUT_DIR` and `OUTPUT_DIR`. + +```bash +INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) +THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + +conda config --add channels defaults +conda create -n singularity-env -c conda-forge singularity -y +conda activate singularity-env + +# singularity pull docker pre-built image +singularity pull docker://hkubal/clair3:latest + +# run clair3 like this afterward +singularity exec clair3_latest.sif \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here + --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path="/opt/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} ## absolute output path prefix +``` + +### Option 3. Bioconda + +*For using Clair3 with Illumina data, install [clair3-illumina](https://anaconda.org/bioconda/clair3-illumina) package in bioconda channel instead.* + +```bash +# make sure channels are added in conda +conda config --add channels defaults +conda config --add channels bioconda +conda config --add channels conda-forge + +# create conda environment named "clair3" +# replace clair3 by clair3-illumina for using illumina data +conda create -n clair3 -c bioconda clair3 python=3.6.10 -y +conda activate clair3 + +# run clair3 like this afterward +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + +run_clair3.sh \ + --bam_fn=input.bam \ ## change your bam file name here + --ref_fn=ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path="${CONDA_PREFIX}/bin/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} ## output path prefix +``` + +Check [Usage](#Usage) for more options. [Pre-trained models](#pre-trained-models) are already included in the bioconda package. + +### Option 4. Build an anaconda virtual environment + +**Anaconda install**: + +Please install anaconda using the official [guide](https://docs.anaconda.com/anaconda/install) or using the commands below: + +```bash +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +chmod +x ./Miniconda3-latest-Linux-x86_64.sh +./Miniconda3-latest-Linux-x86_64.sh +``` + +**Install Clair3 using anaconda step by step:** + +*For using Clair3 on Illumina data, additional installation steps after the following steps are mandatory. Please follow this [guide](https://github.com/HKU-BAL/Clair3/blob/main/docs/quick_demo/illumina_quick_demo.md#step-2-install-boost-graph-library-for-illumina-realignment-process) for the additional steps.* + +```bash +INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. ./input +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. ./output +THREADS="[MAXIMUM_THREADS]" # e.g. 8 + +# create and activate an environment named clair3 +conda create -n clair3 python=3.6.10 -y +source activate clair3 + +# install pypy and packages in the environemnt +conda install -c conda-forge pypy3.6 -y +pypy3 -m ensurepip +pypy3 -m pip install mpmath==1.2.1 + +# install python packages in environment +pip3 install tensorflow==2.2.0 +pip3 install tensorflow-addons==0.11.2 tables==3.6.1 +conda install -c anaconda pigz==2.4 -y +conda install -c conda-forge parallel=20191122 zstd=1.4.4 -y +conda install -c conda-forge -c bioconda samtools=1.10 -y +conda install -c conda-forge -c bioconda whatshap=1.0 -y + +# clone Clair3 +git clone https://github.com/HKU-BAL/Clair3.git +cd Clair3 + +# download pre-trained models +mkdir models +wget http://www.bio8.cs.hku.hk/clair3/clair3_models/clair3_models.tar.gz +tar -zxvf clair3_models.tar.gz -C ./models + +# run clair3 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 +./run_clair3.sh \ + --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here + --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path=`pwd`"/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} ## output path prefix +``` + +### Option 5. Docker Dockerfile + +This is the same as option 1 except that you are building a docker image yourself. Please refer to option 1 for usage. + +```bash +# clone Clair3 +git clone https://github.com/hku-bal/Clair3.git +cd Clair3 + +# build a docker image named hkubal/clair3:latest +# might require docker authentication to build docker image +docker build -f ./Dockerfile -t hkubal/clair3:latest . + +# run clair3 docker image like option 1 +docker run -it hkubal/clair3:latest /opt/bin/run_clair3.sh --help +``` + +---- + +## Usage + +### General Usage + +**Caution**: Use `=value` for optional parameters, e.g. `--bed_fn=fn.bed` instead of `--bed_fn fn.bed`. + +```bash +./run_clair3.sh \ + --bam_fn=${BAM} \ + --ref_fn=${REF} \ + --threads=${THREADS} \ + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path=${MODEL_PREFIX} \ ## absolute model path prefix + --output=${OUTPUT_DIR} ## absolute output path prefix +## pileup output file: ${OUTPUT_DIR}/pileup.vcf.gz +## full-alignment output file: ${OUTPUT_DIR}/full_alignment.vcf.gz +## Clair3 final output file: ${OUTPUT_DIR}/merge_output.vcf.gz +``` + +### Options + +**Required parameters:** + +```bash + -b, --bam_fn=FILE BAM file input. The input file must be samtools indexed. + -f, --ref_fn=FILE FASTA reference file input. The input file must be samtools indexed. + -m, --model_path=STR The folder path containing a Clair3 model (requiring six files in the folder, including pileup.data-00000-of-00002, pileup.data-00001-of-00002 pileup.index, full_alignment.data-00000-of-00002, full_alignment.data-00001-of-00002 and full_alignment.index). + -t, --threads=INT Max threads to be used. The full genome will be divided into small chunks for parallel processing. Each chunk will use 4 threads. The chunks being processed simultaneously is ceil($threads/4)*3. 3 is the overloading factor. + -p, --platform=STR Select the sequencing platform of the input. Possible options: {ont,hifi,ilmn}. + -o, --output=PATH VCF/GVCF output directory. +``` + +**Other parameters:** + + **Caution**: Use `=value` for optional parameters, e.g., `--bed_fn=fn.bed` instead of `--bed_fn fn.bed` + +```bash + --bed_fn=FILE Call variants only in the provided bed regions. + --vcf_fn=FILE Candidate sites VCF file input, variants will only be called at the sites in the VCF file if provided. + --ctg_name=STR The name of the sequence to be processed. + --sample_name=STR Define the sample name to be shown in the VCF file. + --qual=INT If set, variants with >$qual will be marked PASS, or LowQual otherwise. + --samtools=STR Path of samtools, samtools version >= 1.10 is required. + --python=STR Path of python, python3 >= 3.6 is required. + --pypy=STR Path of pypy3, pypy3 >= 3.6 is required. + --parallel=STR Path of parallel, parallel >= 20191122 is required. + --whatshap=STR Path of whatshap, whatshap >= 1.0 is required. + --chunk_size=INT The size of each chuck for parallel processing, default: 5Mbp. + --pileup_only Use the pileup model only when calling, default: disable. + --print_ref_calls Show reference calls (0/0) in vcf file, default: disable. + --include_all_ctgs Call variants on all contigs, otherwise call in chr{1..22,X,Y} and {1..22,X,Y}, default: disable. + --gvcf Enable GVCF output, default: disable. + --enable_phasing Output phased variants using whatshap, default: disable. + --remove_intermediate_dir Remove intermediate directory, including intermediate phased BAM, pileup and full-alignment results. default: disable. + --snp_min_af=FLOAT Minimum SNP AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.08,hifi:0.08,ilmn:0.08. + --indel_min_af=FLOAT Minimum INDEL AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.15,hifi:0.08,ilmn:0.08. + --var_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/1 and 1/1 variants called in the pileup mode for full-alignment mode calling, default: 0.3. + --ref_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/0 variants called in the pileup mode for full-alignment mode calling, default: 0.3 for ilmn and hifi, 0.1 for ont. + --var_pct_phasing=FLOAT EXPERIMENTAL: Specify an expected percentage of high quality 0/1 variants used in WhatsHap phasing, default: 0.8 for ont guppy5 and 0.7 for other platforms. + --pileup_model_prefix=STR EXPERIMENTAL: Model prefix in pileup calling, including $prefix.data-00000-of-00002, $prefix.data-00001-of-00002 $prefix.index. default: pileup. + --fa_model_prefix=STR EXPERIMENTAL: Model prefix in full-alignment calling, including $prefix.data-00000-of-00002, $prefix.data-00001-of-00002 $prefix.index, default: full_alignment. + --fast_mode EXPERIMENTAL: Skip variant candidates with AF <= 0.15, default: disable. + --haploid_precise EXPERIMENTAL: Enable haploid calling mode. Only 1/1 is considered as a variant, default: disable. + --haploid_sensitive EXPERIMENTAL: Enable haploid calling mode. 0/1 and 1/1 are considered as a variant, default: disable. + --no_phasing_for_fa EXPERIMENTAL: Call variants without whatshap phasing in full alignment calling, default: disable. + --call_snp_only EXPERIMENTAL: Call candidates pass SNP minimum AF only, ignore Indel candidates, default: disable. + --enable_long_indel EXPERIMENTAL: Call long Indel variants(>50 bp), default: disable. +``` + +#### Call variants in a chromosome + +```bash +CONTIGS_LIST="[YOUR_CONTIGS_LIST]" # e.g "chr21" or "chr21,chr22" +INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) +THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here + --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path="/opt/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} \ ## absolute output path prefix + --ctg_name=${CONTIGS_LIST} +``` + +#### Call variants at known variant sites + +```bash +KNOWN_VARIANTS_VCF="[YOUR_VCF_PATH]" # e.g. /home/user1/known_variants.vcf.gz (absolute path needed) +INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) +THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here + --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path="/opt/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} \ ## absolute output path prefix + --vcf_fn=${KNOWN_VARIANTS_VCF} +``` + +#### Call variants at specific sites or bed regions + +We highly recommended using BED file to define the regions of interest like: + +```shell +# define 0-based "ctg start end" if at specific sites +CONTIGS="[YOUR_CONTIGS_NAME]" # e.g. chr22 +START_POS="[YOUR_START_POS]" # e.g. 0 +END_POS="[YOUR_END_POS]" # e.g 10000 +echo -e "${CONTIGS}\t${START_POS}\t${END_POS}" > /home/user1/tmp.bed ## change directory accordingly +``` + +Then run Clair3 like this: + +```bash +BED_FILE_PATH="[YOUR_BED_FILE]" # e.g. /home/user1/tmp.bed (absolute path needed) +INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) +THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here + --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path="/opt/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} \ ## absolute output path prefix + --bed_fn=${BED_FILE_PATH} +``` + +#### Call variants in non-diploid organisms (Haploid calling) + +```bash +INPUT_DIR="[YOUR_INPUT_FOLDER]" # e.g. /home/user1/input (absolute path needed) +OUTPUT_DIR="[YOUR_OUTPUT_FOLDER]" # e.g. /home/user1/output (absolute path needed) +THREADS="[MAXIMUM_THREADS]" # e.g. 8 +MODEL_NAME="[YOUR_MODEL_NAME]" # e.g. r941_prom_hac_g360+g422 + +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/input.bam \ ## change your bam file name here + --ref_fn=${INPUT_DIR}/ref.fa \ ## change your reference file name here + --threads=${THREADS} \ ## maximum threads to be used + --platform="ont" \ ## options: {ont,hifi,ilmn} + --model_path="/opt/models/${MODEL_NAME}" \ + --output=${OUTPUT_DIR} \ + --no_phasing_for_fa \ ## disable phasing for full-alignment + --include_all_ctgs \ ## call variants on all contigs in the reference fasta + --haploid_precise ## optional(enable --haploid_precise or --haploid_sensitive) for haploid calling +``` + +---- + +## Folder Structure and Submodule Descriptions + +Submodules in __`clair3/`__ are for variant calling and model training. Submodules in __`preprocess`__ are for data preparation. + +*For all the submodules listed below, you can use `-h` or `--help` for available options.* + +`clair3/` | Note: submodules under this folder are pypy incompatible, please run using python +---: | --- +`CallVariants` | Call variants using a trained model and tensors of candidate variants. +`CallVarBam` | Call variants using a trained model and a BAM file. +`Train` | Training a model using the `RectifiedAdam` optimizer. We also use the `Lookahead` optimizer to adjust the `RectifiedAdam` parameters dynamically. The initial learning rate is `1e-3` with `0.1` learning rate warm-up. Input a binary containing tensors created by `Tensor2Bin`. + + + +`preprocess/` | Note: submodules under this folder is Pypy compatible unless specified. +---: | --- +`CheckEnvs`| Check the environment and validity of the input variables, preprocess the BED input if necessary, `--chunk_size` sets the chuck size to be processed per parallel job. +`CreateTensorPileup`| Generate variant candidate tensors in pileup format for training or calling. +`CreateTensorFullAlignment`| Generate variant candidate tensors in phased full-alignment format for training or calling. +`GetTruth`| Extract the variants from a truth VCF. Input: VCF; Reference FASTA if the VCF contains asterisks in ALT field. +`MergeVcf` | Merge pileup and full-alignment VCF/GVCF. +`RealignReads` | Reads local realignment for Illumina platform. +`SelectCandidates`| Select pileup candidates for full-alignment calling. +`SelectHetSnp` | Select heterozygous SNP candidates for whatshap phasing. +`SelectQual` | Select a quality cutoff using the pileup calling results. Variants below the cutoff are included in phasing and full-alignment calling. +`SortVcf` | Sort VCF file. +`SplitExtendBed` | Split BED file regions according to the contig names and extend bed region by 33bp by default for variant calling. +`UnifyRepresentation` | Representation unification between candidate sites and true variants. +`MergeBin` | Combine tensor binaries into a single file. +`CreateTrainingTensor` | Create tensor binaries for pileup or full-alignment training. +`Tensor2Bin` | Combine the variant and non-variant tensors and convert them to a binary, using `blosc:lz4hc` meta-compressor, the overall training memory is 10~15G (pypy incompatible). + +---- + +## Training Data + +Clair3 trained both its pileup and full-alignment models using four GIAB samples (HG001, HG002, HG004 and HG005), excluded HG003. On ONT, we also trained a model using HG001, 2, 3, and 5, excluded HG004. All models were trained with chr20 excluded (including only chr1-19, 21, 22). + +| Platform | Reference | Aligner | Training samples | +| :---------: | :-----------: | :---------------: | :--------------: | +| ONT | GRCh38_no_alt | minimap2 | HG001,2,(3\|4),5 | +| PacBio HiFi | GRCh38_no_alt | pbmm2 | HG001,2,4,5 | +| Illumina | GRCh38 | BWA-MEM/NovoAlign | HG001,2,4,5 | + +Please find more details about the training data and links at [Training Data](docs/training_data.md). + +---- + +## VCF/GVCF Output Formats + +Clair3 supports both VCF and GVCF output formats. Clair3 uses VCF version 4.2 specifications. Specifically, Clair3 adds a `P` INFO tag to the results called using a pileup model, and a `F` INFO tag to the results called using a full-alignment model. + +Clair3 outputs a GATK-compatible GVCF format that passes GATK's `ValidateVariants` module. Different from DeepVariant that uses `<*>` to represent any possible alternative allele, Clair3 uses ``, the same as GATK. diff --git a/benchmarks/nn-variant/clair/__init__.py b/benchmarks/nn-variant/Clair3/__init__.py similarity index 100% rename from benchmarks/nn-variant/clair/__init__.py rename to benchmarks/nn-variant/Clair3/__init__.py diff --git a/benchmarks/nn-variant/Clair3/callVar.sh b/benchmarks/nn-variant/Clair3/callVar.sh new file mode 100755 index 0000000..95dceb6 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/callVar.sh @@ -0,0 +1,335 @@ +#!/bin/bash +SCRIPT_NAME=$(basename "$0") +SCRIPT_PATH=`dirname "$0"` +VERSION='v0.1-r10' +Usage="Usage: ./${SCRIPT_NAME} --bam_fn=BAM --ref_fn=REF --output=OUTPUT_DIR --threads=THREADS --platform=PLATFORM --model_path=MODEL_PREFIX [--bed_fn=BED] [options]" + +set -e +#./run_clair3.sh -b tmp.bam -f ref.fasta -t 32 -o tmp -p ont -m model_path +print_help_messages() +{ + echo $'' + echo ${Usage} + echo $'' + echo $'Required parameters:' + echo $' -b, --bam_fn=FILE BAM file input. The input file must be samtools indexed.' + echo $' -f, --ref_fn=FILE FASTA reference file input. The input file must be samtools indexed.' + echo $' -m, --model_path=STR The folder path containing a Clair3 model (requiring six files in the folder, including pileup.data-00000-of-00002, pileup.data-00001-of-00002 pileup.index, full_alignment.data-00000-of-00002, full_alignment.data-00001-of-00002 and full_alignment.index).' + echo $' -t, --threads=INT Max #threads to be used. The full genome will be divided into small chunks for parallel processing. Each chunk will use 4 threads. The #chunks being processed simultaneously is ceil(#threads/4)*3. 3 is the overloading factor.' + echo $' -p, --platform=STR Select the sequencing platform of the input. Possible options: {ont,hifi,ilmn}.' + echo $' -o, --output=PATH VCF/GVCF output directory.' + echo $'' + echo $'' + echo $"Optional parameters (Use \"=value\" instead of \" value\". E.g., \"--bed_fn=fn.bed\" instead of \"--bed_fn fn.bed\".):" + echo $' --bed_fn=FILE Call variants only in the provided bed regions.' + echo $' --vcf_fn=FILE Candidate sites VCF file input, variants will only be called at the sites in the VCF file if provided.' + echo $' --ctg_name=STR The name of the sequence to be processed.' + echo $' --sample_name=STR Define the sample name to be shown in the VCF file.' + echo $' --qual=INT If set, variants with >$qual will be marked PASS, or LowQual otherwise.' + echo $' --samtools=STR Path of samtools, samtools version >= 1.10 is required.' + echo $' --python=STR Path of python, python3 >= 3.6 is required.' + echo $' --pypy=STR Path of pypy3, pypy3 >= 3.6 is required.' + echo $' --parallel=STR Path of parallel, parallel >= 20191122 is required.' + echo $' --whatshap=STR Path of whatshap, whatshap >= 1.0 is required.' + echo $' --chunk_size=INT The size of each chuck for parallel processing, default: 5000000.' + echo $' --pileup_only Use the pileup model only when calling, default: disable.' + echo $' --print_ref_calls Show reference calls (0/0) in VCF file, default: disable.' + echo $' --include_all_ctgs Call variants on all contigs, otherwise call in chr{1..22,X,Y} and {1..22,X,Y}, default: disable.' + echo $' --gvcf Enable GVCF output, default: disable.' + echo $' --enable_phasing Output phased variants using whatshap, default: disable.' + echo $' --remove_intermediate_dir Remove intermediate directory, including intermediate phased BAM, pileup and full-alignment results. default: disable.' + echo $' --snp_min_af=FLOAT Minimum SNP AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.08,hifi:0.08,ilmn:0.08.' + echo $' --indel_min_af=FLOAT Minimum Indel AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.15,hifi:0.08,ilmn:0.08.' + echo $' --var_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/1 and 1/1 variants called in the pileup mode for full-alignment mode calling, default: 0.3.' + echo $' --ref_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/0 variants called in the pileup mode for full-alignment mode calling, default: 0.3 for ilmn and hifi, 0.1 for ont.' + echo $' --var_pct_phasing=FLOAT EXPERIMENTAL: Specify an expected percentage of high quality 0/1 variants used in WhatsHap phasing, default: 0.8 for ont guppy5 and 0.7 for other platforms.' + echo $' --pileup_model_prefix=STR EXPERIMENTAL: Model prefix in pileup calling, including $prefix.data-00000-of-00002, $prefix.data-00001-of-00002 $prefix.index. default: pileup.' + echo $' --fa_model_prefix=STR EXPERIMENTAL: Model prefix in full-alignment calling, including $prefix.data-00000-of-00002, $prefix.data-00001-of-00002 $prefix.index, default: full_alignment.' + echo $' --fast_mode EXPERIMENTAL: Skip variant candidates with AF <= 0.15, default: disable.' + echo $' --haploid_precise EXPERIMENTAL: Enable haploid calling mode. Only 1/1 is considered as a variant, default: disable.' + echo $' --haploid_sensitive EXPERIMENTAL: Enable haploid calling mode. 0/1 and 1/1 are considered as a variant, default: disable.' + echo $' --no_phasing_for_fa EXPERIMENTAL: Call variants without whatshap phasing in full alignment calling, default: disable.' + echo $' --call_snp_only EXPERIMENTAL: Call candidates pass SNP minimum AF only, ignore Indel candidates, default: disable.' + echo $' --enable_long_indel EXPERIMENTAL: Call long Indel variants(>50 bp), default: disable.' + echo $'' +} + +print_version() +{ + echo "Clair3 ${VERSION}" + exit 0 +} + +ERROR="\\033[31m[ERROR]" +WARNING="\\033[33m[WARNING]" +NC="\\033[0m" + +ARGS=`getopt -o b:f:t:m:p:o:hv \ +-l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\ +bed_fn::,vcf_fn::,ctg_name::,sample_name::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,ref_pct_full::,var_pct_phasing::,\ +snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_prefix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,\ +remove_intermediate_dir,no_phasing_for_fa,call_snp_only,enable_phasing,enable_long_indel,help,version -n 'run_clair3.sh' -- "$@"` + +if [ $? != 0 ] ; then echo"No input. Terminating...">&2 ; exit 1 ; fi +eval set -- "${ARGS}" + +# default options +SAMPLE="SAMPLE" +BED_FILE_PATH="EMPTY" +VCF_FILE_PATH='EMPTY' +CONTIGS="EMPTY" +SAMTOOLS="samtools" +PYPY="pypy3" +PYTHON='python3' +PARALLEL='parallel' +WHATSHAP='whatshap' +CHUNK_NUM=0 +CHUNK_SIZE=5000000 +QUAL=2 +PHASING_PCT="0" +PRO="0" +REF_PRO="0" +GVCF=False +PILEUP_ONLY=False +FAST_MODE=False +SHOW_REF=False +SNP_AF="0" +INDEL_AF="0" +HAP_PRE=False +HAP_SEN=False +SNP_ONLY=False +INCLUDE_ALL_CTGS=False +NO_PHASING=False +RM_TMP_DIR=False +ENABLE_PHASING=False +ENABLE_LONG_INDEL=False +PILEUP_PREFIX="pileup" +FA_PREFIX="full_alignment" + +while true; do + case "$1" in + -b|--bam_fn ) BAM_FILE_PATH="$2"; shift 2 ;; + -f|--ref_fn ) REFERENCE_FILE_PATH="$2"; shift 2 ;; + -t|--threads ) THREADS="$2"; shift 2 ;; + -m|--model_path ) MODEL_PATH="$2"; shift 2 ;; + -p|--platform ) PLATFORM="$2"; shift 2 ;; + -o|--output ) OUTPUT_FOLDER="$2"; shift 2 ;; + --bed_fn ) BED_FILE_PATH="$2"; shift 2 ;; + --vcf_fn ) VCF_FILE_PATH="$2"; shift 2 ;; + --ctg_name ) CONTIGS="$2"; shift 2 ;; + --sample_name ) SAMPLE="$2"; shift 2 ;; + --chunk_num ) CHUNK_NUM="$2"; shift 2 ;; + --chunk_size ) CHUNK_SIZE="$2"; shift 2 ;; + --qual ) QUAL="$2"; shift 2 ;; + --samtools ) SAMTOOLS="$2"; shift 2 ;; + --python ) PYTHON="$2"; shift 2 ;; + --pypy ) PYPY="$2"; shift 2 ;; + --parallel ) PARALLEL="$2"; shift 2 ;; + --whatshap ) WHATSHAP="$2"; shift 2 ;; + --var_pct_full ) PRO="$2"; shift 2 ;; + --ref_pct_full ) REF_PRO="$2"; shift 2 ;; + --var_pct_phasing ) PHASING_PCT="$2"; shift 2 ;; + --snp_min_af ) SNP_AF="$2"; shift 2 ;; + --indel_min_af ) INDEL_AF="$2"; shift 2 ;; + --pileup_model_prefix ) PILEUP_PREFIX="$2"; shift 2 ;; + --fa_model_prefix ) FA_PREFIX="$2"; shift 2 ;; + --gvcf ) GVCF=True; shift 1 ;; + --pileup_only ) PILEUP_ONLY=True; shift 1 ;; + --fast_mode ) FAST_MODE=True; shift 1 ;; + --call_snp_only ) SNP_ONLY=True; shift 1 ;; + --print_ref_calls ) SHOW_REF=True; shift 1 ;; + --haploid_precise ) HAP_PRE=True; shift 1 ;; + --haploid_sensitive ) HAP_SEN=True; shift 1 ;; + --include_all_ctgs ) INCLUDE_ALL_CTGS=True; shift 1 ;; + --no_phasing_for_fa ) NO_PHASING=True; shift 1 ;; + --remove_intermediate_dir ) RM_TMP_DIR=True; shift 1 ;; + --enable_phasing ) ENABLE_PHASING=True; shift 1 ;; + --enable_long_indel ) ENABLE_LONG_INDEL=True; shift 1 ;; + + -- ) shift; break; ;; + -h|--help ) print_help_messages; exit 0 ;; + -v|--version ) print_version; exit 0 ;; + * ) print_help_messages; break ;; + esac +done + +if [ -z ${BAM_FILE_PATH} ] || [ -z ${REFERENCE_FILE_PATH} ] || [ -z ${THREADS} ] || [ -z ${OUTPUT_FOLDER} ] || [ -z ${PLATFORM} ] || [ -z ${MODEL_PATH} ]; then + if [ -z ${BAM_FILE_PATH} ] && [ -z ${REFERENCE_FILE_PATH} ] && [ -z ${THREADS} ] && [ -z ${OUTPUT_FOLDER} ] && [ -z ${PLATFORM} ] && [ -z ${MODEL_PATH} ]; then print_help_messages; exit 0; fi + if [ -z ${BAM_FILE_PATH} ]; then echo -e "${ERROR} Require to define index BAM input by --bam_fn=BAM${NC}"; fi + if [ -z ${REFERENCE_FILE_PATH} ]; then echo -e "${ERROR} Require to define FASTA reference file input by --ref_fn=REF${NC}"; fi + if [ -z ${THREADS} ]; then echo -e "${ERROR} Require to define max threads to be used by --threads=THREADS${NC}"; fi + if [ -z ${OUTPUT_FOLDER} ]; then echo -e "${ERROR} Require to define output folder by --output=OUTPUT_DIR${NC}"; fi + if [ -z ${PLATFORM} ]; then echo -e "${ERROR} Require to define platform by --platform={ont,hifi,ilmn}${NC}"; fi + if [ -z ${MODEL_PATH} ]; then echo -e "${ERROR} Require to define model path by --model_path=MODEL_PREFIX${NC}"; fi + exit 1; +fi + +# force to use absolute path when in docker or singularity environment +if [ `pwd` = "/opt/bin" ]; then + if [[ ! "${BAM_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --bam_fn=FILE${NC}"; exit 1; fi + if [[ ! "${REFERENCE_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --ref_fn=FILE${NC}"; exit 1; fi + if [[ ! "${MODEL_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --model_path=PATH${NC}"; exit 1; fi + if [[ ! "${OUTPUT_FOLDER}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --output=PATH${NC}"; exit 1; fi + if [ "${BED_FILE_PATH}" != "EMPTY" ] && [ ! -z ${BED_FILE_PATH} ] && [[ ! "${BED_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --bef_fn=FILE${NC}"; exit 1; fi + if [ "${VCF_FILE_PATH}" != "EMPTY" ] && [ ! -z ${VCF_FILE_PATH} ] && [[ ! "${VCF_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --vcf_fn=FILE${NC}"; exit 1; fi +fi + +# relative path support +if [[ ! "${BAM_FILE_PATH}" = /* ]] && [ -f ${BAM_FILE_PATH} ]; then BAM_FILE_PATH=`pwd`/${BAM_FILE_PATH}; fi +if [[ ! "${REFERENCE_FILE_PATH}" = /* ]] && [ -f ${REFERENCE_FILE_PATH} ]; then REFERENCE_FILE_PATH=`pwd`/${REFERENCE_FILE_PATH}; fi +if [[ ! "${MODEL_PATH}" = /* ]] && [ -d ${MODEL_PATH} ]; then MODEL_PATH=`pwd`/${MODEL_PATH}; fi +if [ "${BED_FILE_PATH}" != "EMPTY" ] && [ ! -z ${BED_FILE_PATH} ] && [[ ! "${BED_FILE_PATH}" = /* ]] && [ -f ${BED_FILE_PATH} ]; then BED_FILE_PATH=`pwd`/${BED_FILE_PATH}; fi +if [ "${VCF_FILE_PATH}" != "EMPTY" ] && [ ! -z ${VCF_FILE_PATH} ] && [[ ! "${VCF_FILE_PATH}" = /* ]] && [ -f ${VCF_FILE_PATH} ]; then VCF_FILE_PATH=`pwd`/${VCF_FILE_PATH}; fi +if [[ ! "${OUTPUT_FOLDER}" = /* ]]; then echo -e "${WARNING} No absolute output path provided, using current directory as prefix${NC}"; OUTPUT_FOLDER=`pwd`/${OUTPUT_FOLDER}; fi + +mkdir -p ${OUTPUT_FOLDER} +if [ ! -d ${OUTPUT_FOLDER} ]; then echo -e "${ERROR} Cannot create output folder ${OUTPUT_FOLDER}${NC}"; exit 1; fi + +# show default reference proportion 0.3 for ilmn and hifi, 0.1 for ont +if [ "${PLATFORM}" = "ont" ] && [ "${REF_PRO}" = "0" ]; then REF_PRO=0.1; fi +if [ "${PLATFORM}" != "ont" ] && [ "${REF_PRO}" = "0" ]; then REF_PRO=0.3; fi + +# show default variant proportion 0.3 for ilmn and hifi, 0.7 for ont +if [ "${PLATFORM}" = "ont" ] && [ "${PRO}" = "0" ]; then PRO=0.7; fi +if [ "${PLATFORM}" != "ont" ] && [ "${PRO}" = "0" ]; then PRO=0.3; fi + +# show default high quality hete variant proportion for whatshap phasing, 0.8 for ont guppy5 and 0.7 for others +if [ "${PHASING_PCT}" = "0" ]; then PHASING_PCT=0.7; fi +BASE_MODEL=$(basename ${MODEL_PATH}) +if [ "${BASE_MODEL}" = "r941_prom_sup_g5014" ] || [ "${BASE_MODEL}" = "r941_prom_hac_g5014" ] || [ "${BASE_MODEL}" = "ont_guppy5" ]; then PHASING_PCT=0.8; fi + +# remove the last '/' character in directory input +OUTPUT_FOLDER=$(echo ${OUTPUT_FOLDER%*/}) +MODEL_PATH=$(echo ${MODEL_PATH%*/}) + +# optional parameters should use "=" +(time ( +echo "[INFO] CLAIR3 VERSION: ${VERSION}" +echo "[INFO] BAM FILE PATH: ${BAM_FILE_PATH}" +echo "[INFO] REFERENCE FILE PATH: ${REFERENCE_FILE_PATH}" +echo "[INFO] MODEL PATH: ${MODEL_PATH}" +echo "[INFO] OUTPUT FOLDER: ${OUTPUT_FOLDER}" +echo "[INFO] PLATFORM: ${PLATFORM}" +echo "[INFO] THREADS: ${THREADS}" +echo "[INFO] BED FILE PATH: ${BED_FILE_PATH}" +echo "[INFO] VCF FILE PATH: ${VCF_FILE_PATH}" +echo "[INFO] CONTIGS: ${CONTIGS}" +echo "[INFO] CONDA PREFIX: ${CONDA_PREFIX}" +echo "[INFO] SAMTOOLS PATH: ${SAMTOOLS}" +echo "[INFO] PYTHON PATH: ${PYTHON}" +echo "[INFO] PYPY PATH: ${PYPY}" +echo "[INFO] PARALLEL PATH: ${PARALLEL}" +echo "[INFO] WHATSHAP PATH: ${WHATSHAP}" +echo "[INFO] CHUNK SIZE: ${CHUNK_SIZE}" +if [ ${CHUNK_NUM} -gt 0 ]; then echo "[INFO] CHUNK NUM: ${CHUNK_NUM}"; fi +echo "[INFO] FULL ALIGN PROPORTION: ${PRO}" +echo "[INFO] FULL ALIGN REFERENCE PROPORTION: ${REF_PRO}" +echo "[INFO] PHASING PROPORTION: ${PHASING_PCT}" +if [ "${SNP_AF}" != "0" ]; then echo "[INFO] USER DEFINED SNP THRESHOLD: ${SNP_AF}"; fi +if [ "${INDEL_AF}" != "0" ]; then echo "[INFO] USER DEFINED INDEL THRESHOLD: ${INDEL_AF}"; fi +echo "[INFO] ENABLE FILEUP ONLY CALLING: ${PILEUP_ONLY}" +echo "[INFO] ENABLE FAST MODE CALLING: ${FAST_MODE}" +echo "[INFO] ENABLE CALLING SNP CANDIDATES ONLY: ${SNP_ONLY}" +echo "[INFO] ENABLE PRINTING REFERENCE CALLS: ${SHOW_REF}" +echo "[INFO] ENABLE OUTPUT GVCF: ${GVCF}" +echo "[INFO] ENABLE HAPLOID PRECISE MODE: ${HAP_PRE}" +echo "[INFO] ENABLE HAPLOID SENSITIVE MODE: ${HAP_SEN}" +echo "[INFO] ENABLE INCLUDE ALL CTGS CALLING: ${INCLUDE_ALL_CTGS}" +echo "[INFO] ENABLE NO PHASING FOR FULL ALIGNMENT: ${NO_PHASING}" +echo "[INFO] ENABLE REMOVING INTERMEDIATE FILES: ${RM_TMP_DIR}" +echo "[INFO] ENABLE PHASING VCF OUTPUT: ${ENABLE_PHASING}" +echo "[INFO] ENABLE LONG INDEL CALLING: ${ENABLE_LONG_INDEL}" +echo $'' + +# file check +if [ ! -f ${BAM_FILE_PATH} ]; then echo -e "${ERROR} BAM file ${BAM_FILE_PATH} not found${NC}"; exit 1; fi +if [ ! -f ${BAM_FILE_PATH}.bai ] && [ ! -f ${BAM_FILE_PATH%.*}.bai ]; then echo -e "${ERROR} BAM index bai file not found, please use 'samtools index \$BAM' first${NC}"; exit 1; fi +if [ ! -f ${REFERENCE_FILE_PATH} ]; then echo -e "${ERROR} Reference file ${REFERENCE_FILE_PATH} not found${NC}"; exit 1; fi +if [ ! -f ${REFERENCE_FILE_PATH}.fai ] && [ ! -f ${REFERENCE_FILE_PATH%.*}.fai ]; then echo -e "${ERROR} Reference index fai file not found, please use 'samtools faidx \$REF' first${NC}"; exit 1; fi + +if [ "${BED_FILE_PATH}" != "EMPTY" ] && [ ! -z ${BED_FILE_PATH} ] && [ ! -f ${BED_FILE_PATH} ]; then echo -e "${ERROR} BED file ${BED_FILE_PATH} provides but not found${NC}"; exit 1; fi +if [ "${VCF_FILE_PATH}" != "EMPTY" ] && [ ! -z ${VCF_FILE_PATH} ] && [ ! -f ${VCF_FILE_PATH} ]; then echo -e "${ERROR} VCF file ${VCF_FILE_PATH} provides but not found${NC}"; exit 1; fi +if [ ! -d ${MODEL_PATH} ] && [ -z ${CONDA_PREFIX} ]; then echo -e "${ERROR} Conda prefix not found, please activate clair3 conda environment first, model path: ${MODEL_PATH}${NC}"; exit 1; fi +if [ ! -d ${MODEL_PATH} ]; then echo -e "${ERROR} Model path not found${NC}"; exit 1; fi + +# max threads detection +MAX_THREADS=$(nproc) +if [[ ! ${THREADS} =~ ^[\-0-9]+$ ]] || (( ${THREADS} <= 0)); then echo -e "${ERROR} Invalid threads input --threads=INT ${NC}"; exit 1; fi +if [[ ${THREADS} -gt ${MAX_THREADS} ]]; then echo -e "${WARNING} Threads setting exceeds maximum available threads ${MAX_THREADS}, set threads=${MAX_THREADS}${NC}"; THREADS=${MAX_THREADS}; fi + +# max user ulimit threads detection +MAX_ULIMIT_THREADS=`ulimit -u` +if [ ! -z ${MAX_ULIMIT_THREADS} ]; then PER_ULIMIT_THREADS=$((${MAX_ULIMIT_THREADS}/30)); else MAX_ULIMIT_THREADS="unlimited"; PER_ULIMIT_THREADS=${THREADS}; fi +if [[ ${PER_ULIMIT_THREADS} < 1 ]]; then PER_ULIMIT_THREADS=1; fi +if [ "${MAX_ULIMIT_THREADS}" != "unlimited" ] && [[ ${THREADS} -gt ${PER_ULIMIT_THREADS} ]]; then echo -e "${WARNING} Threads setting exceeds maximum ulimit threads ${THREADS} * 30 > ${MAX_ULIMIT_THREADS} (ulimit -u), set threads=${PER_ULIMIT_THREADS}${NC}"; THREADS=${PER_ULIMIT_THREADS}; fi + +# platform check +if [ ! ${PLATFORM} = "ont" ] && [ ! ${PLATFORM} = "hifi" ] && [ ! ${PLATFORM} = "ilmn" ]; then echo -e "${ERROR} Invalid platform input, optional: {ont, hifi, ilmn}${NC}"; exit 1; fi + +# optional parameter detection +if [ -z ${BED_FILE_PATH} ]; then echo -e "${ERROR} Use '--bed_fn=FILE' instead of '--bed_fn FILE' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${VCF_FILE_PATH} ]; then echo -e "${ERROR} Use '--vcf_fn=FILE' instead of '--vcf_fn =FILE' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${CONTIGS} ]; then echo -e "${ERROR} Use '--ctg_name=STR' instead of '--ctg_name STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${SAMPLE} ]; then echo -e "${ERROR} Use '--sample_name=STR' instead of '--sample_name STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${QUAL} ]; then echo -e "${ERROR} Use '--qual=INT' instead of '--qual INT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${SAMTOOLS} ]; then echo -e "${ERROR} Use '--samtools=STR' instead of '--samtools STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PYTHON} ]; then echo -e "${ERROR} Use '--python=STR' instead of '--python STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PYPY} ]; then echo -e "${ERROR} Use '--pypy=STR' instead of '--pypy STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PARALLEL} ]; then echo -e "${ERROR} Use '--parallel=STR' instead of '--parallel STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${WHATSHAP} ]; then echo -e "${ERROR} Use '--whatshap=STR' instead of '--whatshap STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${CHUNK_SIZE} ]; then echo -e "${ERROR} Use '--chunk_size=INT' instead of '--chunk_size INT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${SNP_AF} ]; then echo -e "${ERROR} Use '--snp_min_af=FLOAT' instead of '--snp_min_af FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${INDEL_AF} ]; then echo -e "${ERROR} Use '--indel_min_af=FLOAT' instead of '--indel_min_af FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PRO} ]; then echo -e "${ERROR} Use '--var_pct_full=FLOAT' instead of '--var_pct_full FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${REF_PRO} ]; then echo -e "${ERROR} Use '--ref_pct_full=FLOAT' instead of '--ref_pct_full FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PHASING_PCT} ]; then echo -e "${ERROR} Use '--var_pct_phasing=FLOAT' instead of '--var_pct_phasing FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PILEUP_PREFIX} ]; then echo -e "${ERROR} Use '--pileup_model_prefix=STR' instead of '--pileup_model_prefix STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${FA_PREFIX} ]; then echo -e "${ERROR} Use '--fa_model_prefix=STR' instead of '--fa_model_prefix STR' for optional parameters${NC}"; exit 1 ; fi + +# model prefix detection +if [ ! -f ${MODEL_PATH}/${PILEUP_PREFIX}.index ]; then echo -e "${ERROR} No pileup model found in provided model path and model prefix ${MODEL_PATH}/${PILEUP_PREFIX} ${NC}"; exit 1; fi +if [ ! -f ${MODEL_PATH}/${FA_PREFIX}.index ]; then echo -e "${ERROR} No full-alignment model found in provided model path and model prefix ${MODEL_PATH}/${FA_PREFIX} ${NC}"; exit 1; fi + + +set -x +${SCRIPT_PATH}/scripts/clair3_CallVar.sh \ + --bam_fn ${BAM_FILE_PATH} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --threads ${THREADS} \ + --model_path ${MODEL_PATH} \ + --platform ${PLATFORM} \ + --output ${OUTPUT_FOLDER} \ + --bed_fn=${BED_FILE_PATH} \ + --vcf_fn=${VCF_FILE_PATH} \ + --ctg_name=${CONTIGS} \ + --sample_name=${SAMPLE} \ + --chunk_num=${CHUNK_NUM} \ + --chunk_size=${CHUNK_SIZE} \ + --samtools=${SAMTOOLS} \ + --python=${PYTHON} \ + --pypy=${PYPY} \ + --parallel=${PARALLEL} \ + --whatshap=${WHATSHAP} \ + --qual=${QUAL} \ + --var_pct_full=${PRO} \ + --ref_pct_full=${REF_PRO} \ + --var_pct_phasing=${PHASING_PCT} \ + --snp_min_af=${SNP_AF} \ + --indel_min_af=${INDEL_AF} \ + --pileup_only=${PILEUP_ONLY} \ + --gvcf=${GVCF} \ + --fast_mode=${FAST_MODE} \ + --call_snp_only=${SNP_ONLY} \ + --print_ref_calls=${SHOW_REF} \ + --haploid_precise=${HAP_PRE} \ + --haploid_sensitive=${HAP_SEN} \ + --include_all_ctgs=${INCLUDE_ALL_CTGS} \ + --no_phasing_for_fa=${NO_PHASING} \ + --pileup_model_prefix=${PILEUP_PREFIX} \ + --fa_model_prefix=${FA_PREFIX} \ + --remove_intermediate_dir=${RM_TMP_DIR} \ + --enable_phasing=${ENABLE_PHASING} \ + --enable_long_indel=${ENABLE_LONG_INDEL} + + +)) |& tee ${OUTPUT_FOLDER}/run_clair3.log diff --git a/benchmarks/nn-variant/Clair3/clair3.py b/benchmarks/nn-variant/Clair3/clair3.py new file mode 100644 index 0000000..b8a2a4c --- /dev/null +++ b/benchmarks/nn-variant/Clair3/clair3.py @@ -0,0 +1,91 @@ +import sys +from importlib import import_module +from shared.param_p import REPO_NAME + +DATA_PREP_SCRIPTS_FOLDER="preprocess" +DEEP_LEARNING_FOLDER="clair3" +POST_PROCESS_SCRIPTS_FOLDER="clair3.metrics" + +deep_learning_folder = [ + "CallVarBam", + "CallVariants", + "Train", +] + +data_preprocess_folder = [ + "GetTruth", + "Tensor2Bin", + 'RealignReads', + 'CreateTensorPileup', + "CreateTensorFullAlignment", + 'CreateTrainingTensor', + 'SplitExtendBed', + 'MergeBin', + 'MergeVcf', + 'SelectHetSnp', + 'SelectCandidates', + 'UnifyRepresentation', + 'CheckEnvs', + 'SortVcf', + 'SelectQual' +] + +post_process_scripts_folder = [ + 'GetOverallMetrics', +] + +def directory_for(submodule_name): + if submodule_name in deep_learning_folder: + return DEEP_LEARNING_FOLDER + if submodule_name in data_preprocess_folder: + return DATA_PREP_SCRIPTS_FOLDER + if submodule_name in post_process_scripts_folder: + return POST_PROCESS_SCRIPTS_FOLDER + return "" + + +def print_help_messages(): + from textwrap import dedent + print(dedent("""\ + {0} submodule invocator: + Usage: python clair3.py [submodule] [options of the submodule] + Available data preparation submodules:\n{1} + Available clair submodules:\n{2} + Available post processing submodules:\n{3} + """.format( + REPO_NAME, + "\n".join(" - %s" % submodule_name for submodule_name in data_preprocess_folder), + "\n".join(" - %s" % submodule_name for submodule_name in deep_learning_folder), + "\n".join(" - %s" % submodule_name for submodule_name in post_process_scripts_folder), + ) + )) + + +def main(): + if len(sys.argv) <= 1 or sys.argv[1] == "-h" or sys.argv[1] == "--help": + print_help_messages() + sys.exit(0) + + submodule_name = sys.argv[1] + if ( + submodule_name not in deep_learning_folder and + submodule_name not in data_preprocess_folder and + submodule_name not in post_process_scripts_folder + ): + sys.exit("[ERROR] Submodule %s not found." % (submodule_name)) + + directory = directory_for(submodule_name) + submodule = import_module("%s.%s" % (directory, submodule_name)) + + # filter arguments (i.e. filter clair3.py) and add ".py" for that submodule + sys.argv = sys.argv[1:] + sys.argv[0] += (".py") + + # Note: need to make sure every submodule contains main() method + submodule.main() + + sys.exit(0) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/clair3/CallVarBam.py b/benchmarks/nn-variant/Clair3/clair3/CallVarBam.py new file mode 100644 index 0000000..edc5db4 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/clair3/CallVarBam.py @@ -0,0 +1,453 @@ +import sys +import shlex +import subprocess +import multiprocessing +import signal +import random +import os +from os.path import dirname +from time import sleep +from argparse import ArgumentParser, SUPPRESS +import logging + +logging.getLogger().setLevel(logging.INFO) + + +from shared.command_options import ( + CommandOption, + CommandOptionWithNoValue, + ExecuteCommand, + command_string_from, + command_option_from +) +from shared.utils import file_path_from, executable_command_string_from, subprocess_popen, str2bool, log_warning +import shared.param_p as param + + +class InstancesClass(object): + def __init__(self): + self.create_tensor = None + self.call_variant = None + + def poll(self): + self.create_tensor.poll() + self.call_variant.poll() + + +c = InstancesClass() + + +def check_return_code(signum, frame): + c.poll() + if c.create_tensor.returncode != None and c.create_tensor.returncode != 0: + c.call_variant.kill() + sys.exit("CreateTensor.py exited with exceptions. Exiting...") + + if c.call_variant.returncode != None and c.call_variant.returncode != 0: + c.create_tensor.kill() + sys.exit("call_variant.py exited with exceptions. Exiting...") + + if ( + c.create_tensor.returncode == None or + c.call_variant.returncode == None + ): + signal.alarm(5) + + +def Run(args): + basedir = dirname(__file__) + + CTP_Bin = basedir + "/../clair3.py CreateTensorPileup" + CTFA_Bin = basedir + "/../clair3.py CreateTensorFullAlignment" + RR_Bin = basedir + "/../clair3.py RealignReads" + CVBin = basedir + "/../clair3.py CallVariants" + + if args.delay > 0: + delay = random.randrange(0, args.delay) + print("[INFO] Delay %d seconds before starting variant calling ..." % (delay)) + sleep(delay) + + pypyBin = executable_command_string_from(args.pypy, exit_on_not_found=True) + pythonBin = executable_command_string_from(args.python, exit_on_not_found=True) + samtoolsBin = executable_command_string_from(args.samtools, exit_on_not_found=True) + + chkpnt_fn = args.chkpnt_fn + if args.pileup: + bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True) + else: + bam_fn = file_path_from(args.bam_fn) + if bam_fn is None or bam_fn == "": + print(log_warning( + "[WARNING] Skip full-alignment variant calling for empty full-alignment regions")) + return + ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True) + bed_fn = file_path_from(args.bed_fn) + vcf_fn = file_path_from(args.vcf_fn) + extend_bed = file_path_from(args.extend_bed) + full_aln_regions = file_path_from(args.full_aln_regions) + + platform = args.platform + if not platform or platform not in param.support_platform: + sys.exit("[ERROR] Provided platform are not in support platform list [ont, hifi, ilmn]") + + pileup = args.pileup + call_fn = args.call_fn + sampleName = args.sampleName + ctgName = args.ctgName + need_realignment = args.need_realignment and platform == 'ilmn' and not pileup + min_af = args.min_af if args.min_af else param.min_af_dict[platform] + snp_min_af = args.snp_min_af + indel_min_af = args.indel_min_af + + if ctgName is None: + sys.exit("--ctgName must be specified. You can call variants on multiple chromosomes simultaneously.") + + haploid_precise_mode = command_option_from(args.haploid_precise, 'haploid_precise') + haploid_sensitive_mode = command_option_from(args.haploid_sensitive, 'haploid_sensitive') + output_for_ensemble = command_option_from(args.output_for_ensemble, 'output_for_ensemble') + showRef_mode = command_option_from(args.showRef, 'showRef') + qual = command_option_from(args.qual, 'qual', option_value=args.qual) + + add_indel_length_mode = CommandOption('add_indel_length', args.add_indel_length) + phasing_info_in_bam_mode = command_option_from(args.phasing_info_in_bam, 'phasing_info_in_bam') + need_phasing_mode = command_option_from(args.need_phasing, 'need_phasing') + is_from_tables_mode = command_option_from(args.is_from_tables, 'is_from_tables') + pileup_mode = command_option_from(args.pileup, 'pileup') + gvcf_mode = CommandOption('gvcf', args.gvcf) + fast_mode = CommandOption('fast_mode', args.fast_mode) + call_snp_only_mode = CommandOption('call_snp_only', args.call_snp_only) + enable_long_indel_mode = CommandOption('enable_long_indel', args.enable_long_indel) + + ctgStart = None + ctgEnd = None + chunk_id = None + chunk_num = None + if args.ctgStart is not None and args.ctgEnd is not None and int(args.ctgStart) <= int(args.ctgEnd): + ctgStart = CommandOption('ctgStart', args.ctgStart) + ctgEnd = CommandOption('ctgEnd', args.ctgEnd) + + if args.chunk_id is not None and args.chunk_num is not None and int(args.chunk_id) <= int(args.chunk_num): + chunk_id = CommandOption('chunk_id', args.chunk_id) + chunk_num = CommandOption('chunk_num', args.chunk_num) + + sched_getaffinity_list = list(os.sched_getaffinity(0)) + maxCpus = len(sched_getaffinity_list) + if args.tensorflow_threads is None: + numCpus = maxCpus + else: + numCpus = args.tensorflow_threads if args.tensorflow_threads < maxCpus else maxCpus + + _cpuSet = ",".join(str(x) for x in random.sample(sched_getaffinity_list, numCpus)) + + taskSet = "taskset -c %s" % (_cpuSet) + try: + subprocess.check_output("which %s" % ("taskset"), shell=True) + except: + taskSet = "" + + if need_realignment: + realign_reads_command_options = [ + pypyBin, + RR_Bin, + CommandOption('bam_fn', bam_fn), + CommandOption('ref_fn', ref_fn), + CommandOption('ctgName', ctgName), + ctgStart, + ctgEnd, + chunk_id, + chunk_num, + CommandOption('samtools', samtoolsBin), + CommandOption('extend_bed', extend_bed), + CommandOption('full_aln_regions', full_aln_regions), + ] + bam_fn = "PIPE" + CT_Bin = CTP_Bin if pileup else CTFA_Bin + + create_tensor_command_options = [ + pypyBin, + CT_Bin, + CommandOption('bam_fn', bam_fn), + CommandOption('ref_fn', ref_fn), + CommandOption('vcf_fn', vcf_fn), + CommandOption('ctgName', ctgName), + CommandOption('min_af', min_af), + CommandOption('platform', platform), + CommandOption('samtools', samtoolsBin), + CommandOption('bed_fn', bed_fn), + CommandOption('extend_bed', extend_bed), + CommandOption('sampleName', args.sampleName), + ctgStart, + ctgEnd, + chunk_id, + chunk_num, + gvcf_mode, + ] + + if not pileup: + create_tensor_command_options.append(phasing_info_in_bam_mode) + create_tensor_command_options.append(need_phasing_mode) + create_tensor_command_options.append(CommandOption('full_aln_regions', full_aln_regions)) + else: + create_tensor_command_options.append(CommandOption('snp_min_af', snp_min_af)) + create_tensor_command_options.append(CommandOption('indel_min_af', indel_min_af)) + create_tensor_command_options.append(fast_mode) + create_tensor_command_options.append(call_snp_only_mode) + + if (args.gvcf): + create_tensor_command_options.append(CommandOption('base_err', args.base_err)) + create_tensor_command_options.append(CommandOption('gq_bin_size', args.gq_bin_size)) + create_tensor_command_options.append(CommandOption('temp_file_dir', args.temp_file_dir)) + if args.bp_resolution: + create_tensor_command_options.append(CommandOptionWithNoValue('bp_resolution')) + + call_variant_command_options = [ + taskSet, + pythonBin, + CVBin, + CommandOption('chkpnt_fn', chkpnt_fn), + CommandOption('call_fn', call_fn), + CommandOption('sampleName', sampleName), + CommandOption('ref_fn', ref_fn), + CommandOption('platform', platform), + CommandOption('ctgName', ctgName), + CommandOption('temp_file_dir', args.temp_file_dir), + haploid_precise_mode, + haploid_sensitive_mode, + output_for_ensemble, + qual, + add_indel_length_mode, + showRef_mode, + is_from_tables_mode, + pileup_mode, + chunk_id, + chunk_num, + gvcf_mode, + enable_long_indel_mode + ] + + try: + if need_realignment: + c.realign_reads = subprocess_popen( + shlex.split(command_string_from(realign_reads_command_options)), + ) + c.create_tensor = subprocess_popen( + shlex.split(command_string_from(create_tensor_command_options)), + stdin=c.realign_reads.stdout) + else: + c.create_tensor = subprocess_popen( + shlex.split(command_string_from(create_tensor_command_options)), + ) + + c.call_variant = subprocess_popen( + shlex.split(command_string_from(call_variant_command_options)), + stdin=c.create_tensor.stdout, stdout=sys.stderr + ) + except Exception as e: + print(e, file=sys.stderr) + sys.exit("Failed to start required processes. Exiting...") + + signal.signal(signal.SIGALRM, check_return_code) + signal.alarm(2) + + try: + c.call_variant.wait() + signal.alarm(0) + c.create_tensor.stdout.close() + c.create_tensor.wait() + if need_realignment: + c.realign_reads.stdout.close() + c.realign_reads.wait() + except KeyboardInterrupt as e: + print("KeyboardInterrupt received when waiting at CallVarBam, terminating all scripts.") + try: + c.call_variant.terminate() + c.create_tensor.terminate() + if need_realignment: + c.realign_reads.terminate() + except Exception as e: + print(e) + + raise KeyboardInterrupt + except Exception as e: + print("Exception received when waiting at CallVarBam, terminating all scripts.") + print(e) + try: + c.call_variant.terminate() + c.create_tensor.terminate() + if need_realignment: + c.realign_reads.terminate() + except Exception as e: + print(e) + + raise e + + +def main(): + parser = ArgumentParser(description="Call variants using a trained model and a BAM file") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--bam_fn', type=str, default="bam.bam", required=True, + help="BAM file input, required") + + parser.add_argument('--chkpnt_fn', type=str, default=None, required=True, + help="Input a trained model for variant calling, required") + + parser.add_argument('--ref_fn', type=str, default="ref.fa", required=True, + help="Reference fasta file input, required") + + parser.add_argument('--call_fn', type=str, default=None, + help="VCF output filename, or stdout if not set") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed, required if --bed_fn is not defined") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--bed_fn', type=str, nargs='?', action="store", default=None, + help="Call variant only in the provided regions. Will take an intersection if --ctgName and/or (--ctgStart, --ctgEnd) are set") + + parser.add_argument('--sampleName', type=str, nargs='?', action="store", default="SAMPLE", + help="Define the sample name to be shown in the VCF file, optional") + + parser.add_argument('--min_af', type=float, default=None, + help="Minimum allele frequency for both SNP and Indel for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--snp_min_af', type=float, default=0.08, + help="Minimum SNP allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--indel_min_af', type=float, default=0.08, + help="Minimum Indel allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--gvcf', type=str2bool, default=False, + help="Enable GVCF output, default: disabled") + + parser.add_argument('--qual', type=int, default=None, + help="If set, variants with >=$qual will be marked 'PASS', or 'LowQual' otherwise, optional") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required, default: %(default)s") + + parser.add_argument('--pypy', type=str, default="pypy3", + help="Path to the 'pypy', pypy3 version >= 3.6 is required, default: %(default)s") + + parser.add_argument('--python', type=str, default="python3", + help="Path to the 'python3', default: %(default)s") + + # options for advanced users + parser.add_argument('--fast_mode', type=str2bool, default=False, + help="EXPERIMENTAL: Skip variant candidates with AF <= 0.15, default: %(default)s") + + parser.add_argument('--minCoverage', type=float, default=param.min_coverage, + help="EXPERIMENTAL: Minimum coverage required to call a variant, default: %(default)f") + + parser.add_argument('--minMQ', type=int, default=param.min_mq, + help="EXPERIMENTAL: If set, reads with mapping quality with <$minMQ are filtered, default: %(default)d") + + parser.add_argument('--minBQ', type=int, default=param.min_bq, + help="EXPERIMENTAL: If set, bases with base quality with <$minBQ are filtered, default: %(default)d") + + parser.add_argument('--bp_resolution', action='store_true', + help="EXPERIMENTAL: Enable bp resolution GVCF output, default: disabled") + + parser.add_argument('--haploid_precise', action='store_true', + help="EXPERIMENTAL: Enable haploid calling mode. Only 1/1 is considered as a variant") + + parser.add_argument('--haploid_sensitive', action='store_true', + help="EXPERIMENTAL: Enable haploid calling mode. 0/1 and 1/1 are considered as a variant") + + parser.add_argument('--call_snp_only', type=str2bool, default=False, + help="EXPERIMENTAL: Call candidates pass snp minimum AF only, ignore Indel candidates") + + parser.add_argument('--enable_long_indel', type=str2bool, default=False, + help="EXPERIMENTAL: Enable long Indel variants(>50 bp) calling") + + # options for debug purpose + parser.add_argument('--phasing_info_in_bam', action='store_true', + help="DEBUG: Skip phasing and use the phasing info provided in the input BAM (HP tag), default: False") + + parser.add_argument('--base_err', default=param.base_err, type=float, + help='DEBUG: Base error rate prior for GVCF output, default: %(default)f') + + parser.add_argument('--gq_bin_size', default=param.gq_bin_size, type=int, + help='DEBUG: Default GQ bin size for merging non-variant block for GVCF output, default: %(default)d') + + parser.add_argument('--temp_file_dir', type=str, default='./', + help="DEBUG: The cache directory for storing temporary non-variant information if --gvcf is enabled, default: %(default)s") + + parser.add_argument('--use_gpu', type=str2bool, default=False, + help="DEBUG: Use GPU for calling. Speed up is mostly insignificant. Only use this for building your own pipeline") + + parser.add_argument('--tensorflow_threads', type=int, default=param.tensorflow_threads, + help="DEBUG: Number of threads per tensorflow job. Tune if you are building your own pipeline") + + parser.add_argument('--extend_bed', nargs='?', action="store", type=str, default=None, + help="DEBUG: Extend the regions in the --bed_fn by a few bp for tensor creation, default extend 16bp") + + # options for internal process control, don't use any of them unless you are sure about the consequences + ## In pileup mode or not + parser.add_argument('--pileup', action='store_true', + help=SUPPRESS) + + ## Output for ensemble model calling + parser.add_argument('--output_for_ensemble', action='store_true', + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + ## Use Clair3's own phasing module for read level phasing when creating tensor, compared to using Whatshap, speed is faster but has higher memory footprint, default: False + parser.add_argument('--need_phasing', action='store_true', + help=SUPPRESS) + + ## Apply read realignment for illumina platform. Greatly boost indel performance in trade of running time, default true for illumina platform + parser.add_argument('--need_realignment', action='store_false', + help=SUPPRESS) + + ## Use bin file from pytables to speed up calling. + parser.add_argument('--is_from_tables', action='store_true', + help=SUPPRESS) + + ## Wait a short while for no more than a few seconds to start the job. This is to avoid starting multiple jobs simultaneously + ## that might use up the maximum number of threads allowed, because Tensorflow will create more threads than needed at the beginning of running the program + ## Obseleted after adding --tensorflow_threads defaulted at 4 + parser.add_argument('--delay', type=int, default=5, + help=SUPPRESS) + + ## Provide the regions to be included in full-alignment based calling + parser.add_argument('--full_aln_regions', type=str, nargs='?', action="store", default=None, + help=SUPPRESS) + + ## Include indel length in training and calling, false for pileup and true for raw alignment + parser.add_argument('--add_indel_length', action='store_true', + help=SUPPRESS) + + ## Output reference calls + parser.add_argument('--showRef', action='store_false', + help=SUPPRESS) + + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + Run(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/clair3/CallVariants.py b/benchmarks/nn-variant/Clair3/clair3/CallVariants.py new file mode 100644 index 0000000..c919880 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/clair3/CallVariants.py @@ -0,0 +1,1866 @@ +import sys +import os +import math +import tables +import tensorflow as tf +import numpy as np +import logging +from time import time +from argparse import ArgumentParser, SUPPRESS +from threading import Thread +from math import log, e +from collections import namedtuple +from imp import reload + +from clair3.task.gt21 import ( + GT21_Type, gt21_enum_from_label, + HOMO_SNP_GT21, HOMO_SNP_LABELS, + HETERO_SNP_GT21, HETERO_SNP_LABELS, GT21_LABELS, partial_label_from, mix_two_partial_labels +) +import clair3.utils as utils +from clair3.task.genotype import Genotype, genotype_string_from, genotype_enum_from, genotype_enum_for_task +from shared.utils import IUPAC_base_to_ACGT_base_dict as BASE2ACGT, BASIC_BASES, str2bool, file_path_from, log_error, log_warning +from clair3.task.variant_length import VariantLength +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" +reload(logging) +logging.basicConfig(format='%(message)s', level=logging.INFO) +minimum_variant_length_that_need_infer = VariantLength.max +ACGT = 'ACGT' +Phred_Trans = (-10 * log(e, 10)) + +OutputConfig = namedtuple('OutputConfig', [ + 'is_show_reference', + 'is_debug', + 'is_haploid_precise_mode_enabled', + 'is_haploid_sensitive_mode_enabled', + 'is_output_for_ensemble', + 'quality_score_for_pass', + 'tensor_fn', + 'input_probabilities', + 'add_indel_length', + 'gvcf', + 'pileup', + 'enable_long_indel', + 'maximum_variant_length_that_need_infer' +]) +OutputUtilities = namedtuple('OutputUtilities', [ + 'print_debug_message', + 'output', + 'output_header', + 'close_opened_files', + 'gen_output_file' +]) + + +def homo_SNP_bases_from(homo_SNP_probabilities): + output_bases = HOMO_SNP_LABELS[np.argmax(homo_SNP_probabilities)] + return output_bases[0], output_bases[1] + + +def hetero_SNP_bases_from(hetero_SNP_probabilities): + output_bases = HETERO_SNP_LABELS[np.argmax(hetero_SNP_probabilities)] + return output_bases[0], output_bases[1] + + +def filtration_value_from(quality_score_for_pass, quality_score, is_reference=False): + """ + filter qual if set quality score, variant quliaty lower than specific quality socre will be + marked ass LowQual otherwise PASS. Default there is no quality score cut off. + + """ + if is_reference: + return 'RefCall' + + if quality_score_for_pass is None: + return "PASS" + if quality_score >= quality_score_for_pass: + return "PASS" + + return "LowQual" + + +def insertion_bases_using_alt_info_from( + alt_info_dict, + propose_insertion_length=None, + minimum_insertion_length=1, + maximum_insertion_length=50, + insertion_bases_to_ignore="", + return_multi=False +): + """ + get insertion base using altnertive information in bam alignment file. + alt_info_dict: dictionary (XID: count), include snp, insertion, deletion type and read count. + propose_insertion_length: if set, only return insertion length which match propose insertion length. + minimum_insertion_length: if set, only return insertion length which is larger than specific insertion length. + maximum_insertion_length: if set, only return insertion which is shorter than specific insertion length, + we will always only return insertion length shorter than 50bp by default. + insertion_bases_to_ignore: for multi alleic insertion variants, set the insertion bases to be ignored. + """ + + if propose_insertion_length: + propose_insertion_length += 1 # include reference base + if not len(alt_info_dict): return "" + insertion_bases_dict = {} + propose_insertion_bases_dict = {} + for raw_key, items in alt_info_dict.items(): + if raw_key[0] != 'I': continue + key = raw_key[1:] # remove first cigar +-X and reference_base + if propose_insertion_length and len(key) == propose_insertion_length and key != insertion_bases_to_ignore: + propose_insertion_bases_dict[key] = items + elif minimum_insertion_length <= len(key) <= maximum_insertion_length and key != insertion_bases_to_ignore: + insertion_bases_dict[key] = items + + if propose_insertion_length and len(propose_insertion_bases_dict): + return max(propose_insertion_bases_dict, key=propose_insertion_bases_dict.get) if len( + propose_insertion_bases_dict) > 0 else "" + if return_multi: + insertion_bases_list = list(insertion_bases_dict.items()) + insertion_bases_list = [item[0] for item in sorted(insertion_bases_list, key=lambda x: x[1])[::-1]] + return insertion_bases_list[:2] if len(insertion_bases_list) else "" + + return max(insertion_bases_dict, key=insertion_bases_dict.get) if len(insertion_bases_dict) > 0 else "" + + +def deletion_bases_using_alt_info_from( + alt_info_dict, + propose_deletion_length=None, + minimum_deletion_length=1, + maximum_deletion_length=50, + deletion_bases_to_ignore="", + return_multi=False, + +): + """ + get deletion base using altnertive information in bam alignment file. + alt_info_dict: dictionary (XID: count), include snp, insertion, deletion type and read count. + propose_deletion_length: if set, only return deletion length which match propose deletion length. + minimum_deletion_length: if set, only return deletion length which is larger than specific deletion length. + maximum_deletion_length: if set, only return deletion which is shorter than specific deletion length, + we will always only return deletion length shorter than 50bp by default. + deletion_bases_to_ignore: for multi alleic variants, set the deletion bases to be ignored. + """ + + if not len(alt_info_dict): return "" + + deletion_bases_dict = {} + propose_deletion_bases_dict = {} + for raw_key, items in alt_info_dict.items(): + if raw_key[0] != 'D': continue + key = raw_key[1:] # remove first cigar +-X + if propose_deletion_length and len(key) == propose_deletion_length and key != deletion_bases_to_ignore: + propose_deletion_bases_dict[key] = items + + elif minimum_deletion_length <= len(key) <= maximum_deletion_length and key != deletion_bases_to_ignore: + deletion_bases_dict[key] = items + + if propose_deletion_length and len(propose_deletion_bases_dict): + return max(propose_deletion_bases_dict, key=propose_deletion_bases_dict.get) if len( + propose_deletion_bases_dict) > 0 else "" + + if return_multi: + deletion_bases_list = list(deletion_bases_dict.items()) + deletion_bases_list = [item[0] for item in sorted(deletion_bases_list, key=lambda x: x[1])[::-1]] + if len(deletion_bases_list) <= 1: return "" + return [deletion_bases_list[0], deletion_bases_list[1]] if len(deletion_bases_list[0]) > len( + deletion_bases_list[1]) else [deletion_bases_list[1], deletion_bases_list[0]] + return max(deletion_bases_dict, key=deletion_bases_dict.get) if len(deletion_bases_dict) > 0 else "" + + +def Run(args): + os.environ["OMP_NUM_THREADS"] = "1" + os.environ["OPENBLAS_NUM_THREADS"] = "1" + os.environ["MKL_NUM_THREADS"] = "1" + os.environ["NUMEXPR_NUM_THREADS"] = "1" + + tf.config.threading.set_intra_op_parallelism_threads(1) + tf.config.threading.set_inter_op_parallelism_threads(1) + + global test_pos + test_pos = None + global param + if args.pileup: + import shared.param_p as param + else: + import shared.param_f as param + + if args.enable_long_indel: + maximum_variant_length_that_need_infer = param.maximum_variant_length_that_need_infer_include_long_indel + else: + maximum_variant_length_that_need_infer = param.maximum_variant_length_that_need_infer + + output_config = OutputConfig( + is_show_reference=args.showRef, + is_debug=args.debug, + is_haploid_precise_mode_enabled=args.haploid_precise, + is_haploid_sensitive_mode_enabled=args.haploid_sensitive, + is_output_for_ensemble=args.output_for_ensemble, + quality_score_for_pass=args.qual, + tensor_fn=args.tensor_fn, + input_probabilities=args.input_probabilities, + add_indel_length=args.add_indel_length, + gvcf=args.gvcf, + pileup=args.pileup, + enable_long_indel=args.enable_long_indel, + maximum_variant_length_that_need_infer=maximum_variant_length_that_need_infer + ) + output_utilities = output_utilties_from( + sample_name=args.sampleName, + is_debug=args.debug, + is_output_for_ensemble=args.output_for_ensemble, + reference_file_path=args.ref_fn, + output_file_path=args.call_fn, + output_probabilities=args.output_probabilities + ) + if args.input_probabilities: + call_variants_with_probabilities_input(args=args, output_config=output_config, + output_utilities=output_utilities) + elif args.output_probabilities: + predict(args=args, output_config=output_config, output_utilities=output_utilities) + else: + call_variants(args=args, output_config=output_config, output_utilities=output_utilities) + + +def output_utilties_from( + sample_name, + is_debug, + is_output_for_ensemble, + reference_file_path, + output_file_path, + output_probabilities +): + def gen_output_file(): + global output_file + if not output_probabilities: + output_file = open(output_file_path, "w") + + def output(string_value): + global output_file + print(string_value, file=output_file) + + def print_debug_message( + chromosome, + position, + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + extra_infomation_string="" + ): + output("{}\t{}\t{}\t{}\t{}\t{}\t{}".format( + chromosome, + position, + ["{:0.8f}".format(x) for x in gt21_probabilities], + ["{:0.8f}".format(x) for x in genotype_probabilities], + ["{:0.8f}".format(x) for x in variant_length_probabilities_1], + ["{:0.8f}".format(x) for x in variant_length_probabilities_2], + extra_infomation_string + )) + + def close_opened_files(): + output_file.close() + + def output_header(): + if is_output_for_ensemble: + return + + from textwrap import dedent + output(dedent("""\ + ##fileformat=VCFv4.2 + ##FILTER= + ##FILTER= + ##FILTER= + ##INFO= + ##INFO= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT=""" + )) + + if reference_file_path is not None: + reference_index_file_path = file_path_from(reference_file_path, suffix=".fai", exit_on_not_found=True, sep='.') + with open(reference_index_file_path, "r") as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + contig_name, contig_size = columns[0], columns[1] + output("##contig=" % (contig_name, contig_size)) + + output('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s' % (sample_name)) + + return OutputUtilities( + print_debug_message, + output, + output_header, + close_opened_files, + gen_output_file + ) + + +def homo_Ins_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): + return [( + i, + variant_length_probabilities_1[i + VariantLength.index_offset] * + variant_length_probabilities_2[i + VariantLength.index_offset] * extra_probability + ) for i in range(1, VariantLength.max + 1)] + + +def hetero_Ins_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): + return [( + i, + variant_length_probabilities_1[0 + VariantLength.index_offset] * + variant_length_probabilities_2[i + VariantLength.index_offset], + + ) for i in range(1, VariantLength.max + 1)] + + +def hetero_InsIns_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): + probabilities = [] + for i in range(1, VariantLength.max + 1): + for j in range(i, VariantLength.max + 1): + # note: one kind of InsIns is same # of insertion bases but different kind of ACGT + probabilities.append(( + (i, j), + variant_length_probabilities_1[i + VariantLength.index_offset] * + variant_length_probabilities_2[j + VariantLength.index_offset] * extra_probability + )) + return probabilities + + +def homo_Del_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): + return [( + i, + variant_length_probabilities_1[-i + VariantLength.index_offset] * + variant_length_probabilities_2[-i + VariantLength.index_offset] * extra_probability + ) for i in range(1, VariantLength.max + 1)] + + +def hetero_Del_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): + return [( + i, + variant_length_probabilities_1[-i + VariantLength.index_offset] * + variant_length_probabilities_2[0 + VariantLength.index_offset], + ) for i in range(1, VariantLength.max + 1)] + + +def hetero_DelDel_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): + probabilities = [] + for i in range(1, VariantLength.max + 1): + for j in range(1, VariantLength.max + 1): + if i == j and i != VariantLength.index_offset and j != VariantLength.index_offset: + continue + probabilities.append(( + (i, j) if i < j else (j, i), + variant_length_probabilities_1[-i + VariantLength.index_offset] * + variant_length_probabilities_2[-j + VariantLength.index_offset] * extra_probability + )) + return probabilities + + +def hetero_InsDel_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): + probabilities = [] + for i in range(1, VariantLength.max + 1): + for j in range(1, VariantLength.max + 1): + probabilities.append(( + (i, j), + variant_length_probabilities_1[-i + VariantLength.index_offset] * + variant_length_probabilities_2[j + VariantLength.index_offset] * extra_probability + )) + return probabilities + + +def quality_score_from(probability): + """ + make a modification for quality score calculation. did not apply quality square for computation. + """ + p = probability + tmp = max(Phred_Trans * log(((1.0 - p) + 1e-10) / (p + 1e-10)) + 10, 0) + return float(round(tmp, 2)) + + +def get_long_indel_read_count(alt_info, proposed_ins_base="", propose_del_base_length=0, is_del=False): + """ + https://github.com/HKU-BAL/Clair3/blob/main/docs/indel_gt50_performance.md + for long indel variant calls, we also calculate all flanking indel signals with proposed indel + alternative bases (default under 10% flanking distance) + """ + long_indel_read_count = 0 + maximum_variant_length_that_need_infer = param.maximum_variant_length_that_need_infer + if not param.cal_precise_long_indel_af and (len(proposed_ins_base) > maximum_variant_length_that_need_infer or propose_del_base_length > maximum_variant_length_that_need_infer): + propose_indel_base_length = propose_del_base_length if is_del else len(proposed_ins_base) - 1 + min_long_indel_length_considered = max(propose_indel_base_length * (1.0 - param.long_indel_distance_proportion), maximum_variant_length_that_need_infer) + max_long_indel_length_considered = propose_indel_base_length * (1.0 + param.long_indel_distance_proportion) + for alt_base, count in alt_info.items(): + if is_del and len(alt_base) == propose_del_base_length: # del + continue + if alt_base == proposed_ins_base: # ins + continue + if len(alt_base) >= min_long_indel_length_considered and len(alt_base) <= max_long_indel_length_considered: + long_indel_read_count += count + return long_indel_read_count + + +def possible_outcome_probabilites_with_indel_length_from( + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + reference_base, +): + homo_reference_probability = genotype_probabilities[Genotype.homo_reference] + homo_variant_probability = genotype_probabilities[Genotype.homo_variant] + hetero_variant_probability = genotype_probabilities[Genotype.hetero_variant] + variant_length_0_probability = ( + variant_length_probabilities_1[0 + VariantLength.index_offset] * + variant_length_probabilities_2[0 + VariantLength.index_offset] + ) + + reference_gt21 = gt21_enum_from_label(reference_base + reference_base) + homo_Ref_probability = ( + variant_length_0_probability * homo_reference_probability * gt21_probabilities[reference_gt21] + ) + + homo_SNP_probabilities = [( + variant_length_0_probability * homo_variant_probability * gt21_probabilities[gt21] + ) for gt21 in HOMO_SNP_GT21] + hetero_SNP_probabilities = [( + variant_length_0_probability * hetero_variant_probability * gt21_probabilities[gt21] + ) for gt21 in HETERO_SNP_GT21] + + # Insertion + homo_Ins_lengths, homo_Ins_probabilities = zip(*homo_Ins_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + homo_variant_probability * gt21_probabilities[GT21_Type.InsIns] + )) + homo_Ins_lengths, homo_Ins_probabilities = list(homo_Ins_lengths), list(homo_Ins_probabilities) + hetero_InsIns_length_tuples, hetero_InsIns_probabilities = zip(*hetero_InsIns_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + hetero_variant_probability * gt21_probabilities[GT21_Type.InsIns] + )) + hetero_InsIns_length_tuples, hetero_InsIns_probabilities = ( + list(hetero_InsIns_length_tuples), list(hetero_InsIns_probabilities) + ) + hetero_ACGT_Ins_tuples = [] + gt21_base_tuples = [(GT21_Type.AIns, "A"), (GT21_Type.CIns, "C"), (GT21_Type.GIns, "G"), (GT21_Type.TIns, "T")] + for length_tuples, p in hetero_Ins_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): + for gt21, hetero_base in gt21_base_tuples: + hetero_ACGT_Ins_tuples.append(( + hetero_base, + length_tuples, + p * gt21_probabilities[gt21] * hetero_variant_probability + )) + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities = zip(*hetero_ACGT_Ins_tuples) + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities = ( + list(hetero_ACGT_Ins_bases), list(hetero_ACGT_Ins_lengths), list(hetero_ACGT_Ins_probabilities) + ) + + # Deletion + homo_Del_lengths, homo_Del_probabilities = zip(*homo_Del_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + homo_variant_probability * gt21_probabilities[GT21_Type.DelDel] + )) + homo_Del_lengths, homo_Del_probabilities = list(homo_Del_lengths), list(homo_Del_probabilities) + hetero_DelDel_length_tuples, hetero_DelDel_probabilities = zip(*hetero_DelDel_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + hetero_variant_probability * gt21_probabilities[GT21_Type.DelDel] + )) + hetero_DelDel_length_tuples, hetero_DelDel_probabilities = ( + list(hetero_DelDel_length_tuples), list(hetero_DelDel_probabilities) + ) + hetero_ACGT_Del_tuples = [] + gt21_base_tuples = [(GT21_Type.ADel, "A"), (GT21_Type.CDel, "C"), (GT21_Type.GDel, "G"), (GT21_Type.TDel, "T")] + for length_tuples, p in hetero_Del_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): + for gt21, hetero_base in gt21_base_tuples: + hetero_ACGT_Del_tuples.append(( + hetero_base, + length_tuples, + p * gt21_probabilities[gt21] * hetero_variant_probability + )) + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities = zip(*hetero_ACGT_Del_tuples) + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities = ( + list(hetero_ACGT_Del_bases), list(hetero_ACGT_Del_lengths), list(hetero_ACGT_Del_probabilities) + ) + + # InsDel + hetero_InsDel_length_tuples, hetero_InsDel_probabilities = zip(*hetero_InsDel_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + hetero_variant_probability * gt21_probabilities[GT21_Type.InsDel] + )) + hetero_InsDel_length_tuples, hetero_InsDel_probabilities = ( + list(hetero_InsDel_length_tuples), list(hetero_InsDel_probabilities) + ) + + return ( + homo_Ref_probability, + homo_SNP_probabilities, + hetero_SNP_probabilities, + homo_Ins_lengths, homo_Ins_probabilities, + hetero_InsIns_length_tuples, hetero_InsIns_probabilities, + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities, + homo_Del_lengths, homo_Del_probabilities, + hetero_DelDel_length_tuples, hetero_DelDel_probabilities, + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities, + hetero_InsDel_length_tuples, hetero_InsDel_probabilities, + ) + + +def possible_outcome_probabilites_from( + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + reference_base, + alt_info_dict, + add_indel_length=False, +): + homo_reference_probability = genotype_probabilities[Genotype.homo_reference] + homo_variant_probability = genotype_probabilities[Genotype.homo_variant] + hetero_variant_probability = genotype_probabilities[Genotype.hetero_variant] + + reference_gt21 = gt21_enum_from_label(reference_base + reference_base) + + if not add_indel_length: + homo_Ref_probability = (homo_reference_probability * gt21_probabilities[reference_gt21] + ) + homo_SNP_probabilities = [homo_variant_probability * gt21_probabilities[gt21] + for gt21 in HOMO_SNP_GT21] + hetero_SNP_probabilities = [hetero_variant_probability * gt21_probabilities[gt21] + for gt21 in HETERO_SNP_GT21] + if homo_reference_probability >= 0.5 and gt21_probabilities[ + reference_gt21] >= 0.5: + return [homo_Ref_probability] + # Insertion + homo_Ins_probabilities = [homo_variant_probability * gt21_probabilities[GT21_Type.InsIns]] + homo_Ins_lengths = [] + hetero_InsIns_probabilities = [hetero_variant_probability * gt21_probabilities[GT21_Type.InsIns]] + hetero_InsIns_length_tuples = [] + hetero_ACGT_Ins_probabilities = [] + gt21_base_tuples = [(GT21_Type.AIns, "A"), (GT21_Type.CIns, "C"), (GT21_Type.GIns, "G"), (GT21_Type.TIns, "T")] + for gt21, hetero_base in gt21_base_tuples: + hetero_ACGT_Ins_probabilities.append(gt21_probabilities[gt21] * hetero_variant_probability) + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths = [], [] + + # Deletion + homo_Del_probabilities = [homo_variant_probability * gt21_probabilities[GT21_Type.DelDel]] + homo_Del_lengths = [] + + hetero_DelDel_probabilities = [hetero_variant_probability * gt21_probabilities[GT21_Type.DelDel]] + hetero_DelDel_length_tuples = [] + + hetero_ACGT_Del_probabilities = [] + gt21_base_tuples = [(GT21_Type.ADel, "A"), (GT21_Type.CDel, "C"), (GT21_Type.GDel, "G"), (GT21_Type.TDel, "T")] + for gt21, hetero_base in gt21_base_tuples: + hetero_ACGT_Del_probabilities.append(gt21_probabilities[gt21] * hetero_variant_probability + ) + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths = [], [] + + # InsDel + hetero_InsDel_probabilities = [hetero_variant_probability * gt21_probabilities[GT21_Type.InsDel]] + hetero_InsDel_length_tuples = [] + + else: + variant_length_0_probability_1 = variant_length_probabilities_1[0 + VariantLength.index_offset] + variant_length_0_probability_2 = variant_length_probabilities_2[0 + VariantLength.index_offset] + variant_length_0_probability = (variant_length_0_probability_1 * variant_length_0_probability_2) + + reference_gt21 = gt21_enum_from_label(reference_base + reference_base) + homo_Ref_probability = ( + variant_length_0_probability * homo_reference_probability * gt21_probabilities[reference_gt21] + ) + if variant_length_0_probability_1 >= 0.5 and variant_length_0_probability_2 >= 0.5 and homo_reference_probability >= 0.5 and \ + gt21_probabilities[ + reference_gt21] >= 0.5: + return [homo_Ref_probability] + + homo_SNP_probabilities = [( + variant_length_0_probability * homo_variant_probability * gt21_probabilities[gt21] + ) for gt21 in HOMO_SNP_GT21] + hetero_SNP_probabilities = [( + variant_length_0_probability * hetero_variant_probability * gt21_probabilities[gt21] + ) for gt21 in HETERO_SNP_GT21] + + # Insertion + homo_Ins_lengths, homo_Ins_probabilities = zip(*homo_Ins_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + homo_variant_probability * gt21_probabilities[GT21_Type.InsIns] + )) + homo_Ins_lengths, homo_Ins_probabilities = list(homo_Ins_lengths), list(homo_Ins_probabilities) + hetero_InsIns_length_tuples, hetero_InsIns_probabilities = zip(*hetero_InsIns_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + hetero_variant_probability * gt21_probabilities[GT21_Type.InsIns] + )) + hetero_InsIns_length_tuples, hetero_InsIns_probabilities = ( + list(hetero_InsIns_length_tuples), list(hetero_InsIns_probabilities) + ) + hetero_ACGT_Ins_tuples = [] + gt21_base_tuples = [(GT21_Type.AIns, "A"), (GT21_Type.CIns, "C"), (GT21_Type.GIns, "G"), (GT21_Type.TIns, "T")] + for length_tuples, p in hetero_Ins_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): + for gt21, hetero_base in gt21_base_tuples: + hetero_ACGT_Ins_tuples.append(( + hetero_base, + length_tuples, + p * gt21_probabilities[gt21] * hetero_variant_probability + )) + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities = zip(*hetero_ACGT_Ins_tuples) + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities = ( + list(hetero_ACGT_Ins_bases), list(hetero_ACGT_Ins_lengths), list(hetero_ACGT_Ins_probabilities) + ) + + # Deletion + homo_Del_lengths, homo_Del_probabilities = zip(*homo_Del_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + homo_variant_probability * gt21_probabilities[GT21_Type.DelDel] + )) + homo_Del_lengths, homo_Del_probabilities = list(homo_Del_lengths), list(homo_Del_probabilities) + hetero_DelDel_length_tuples, hetero_DelDel_probabilities = zip(*hetero_DelDel_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + hetero_variant_probability * gt21_probabilities[GT21_Type.DelDel] + )) + hetero_DelDel_length_tuples, hetero_DelDel_probabilities = ( + list(hetero_DelDel_length_tuples), list(hetero_DelDel_probabilities) + ) + hetero_ACGT_Del_tuples = [] + gt21_base_tuples = [(GT21_Type.ADel, "A"), (GT21_Type.CDel, "C"), (GT21_Type.GDel, "G"), (GT21_Type.TDel, "T")] + for length_tuples, p in hetero_Del_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): + for gt21, hetero_base in gt21_base_tuples: + hetero_ACGT_Del_tuples.append(( + hetero_base, + length_tuples, + p * gt21_probabilities[gt21] * hetero_variant_probability + )) + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities = zip(*hetero_ACGT_Del_tuples) + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities = ( + list(hetero_ACGT_Del_bases), list(hetero_ACGT_Del_lengths), list(hetero_ACGT_Del_probabilities) + ) + + # InsDel + hetero_InsDel_length_tuples, hetero_InsDel_probabilities = zip(*hetero_InsDel_tuples_from( + variant_length_probabilities_1, variant_length_probabilities_2, + hetero_variant_probability * gt21_probabilities[GT21_Type.InsDel] + )) + hetero_InsDel_length_tuples, hetero_InsDel_probabilities = ( + list(hetero_InsDel_length_tuples), list(hetero_InsDel_probabilities) + ) + + return ( + homo_Ref_probability, + homo_SNP_probabilities, + hetero_SNP_probabilities, + homo_Ins_lengths, homo_Ins_probabilities, + hetero_InsIns_length_tuples, hetero_InsIns_probabilities, + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities, + homo_Del_lengths, homo_Del_probabilities, + hetero_DelDel_length_tuples, hetero_DelDel_probabilities, + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities, + hetero_InsDel_length_tuples, hetero_InsDel_probabilities, + ) + + +def find_alt_base(alt_info_dict, alternate_base=None): + # double check whether alternate base exists, depth gap may happen when candidate depth is extreme high to infer. + max_depth_gap = 9 + sorted_alt_bases = sorted([(alt_base[1], count) for alt_base, count in alt_info_dict.items() if alt_base[0] == 'X'], + key=lambda x: x[1], reverse=True) + alt_count = [item[1] for item in sorted_alt_bases if item[0] == alternate_base] + if not len(sorted_alt_bases): + return [], None + if not len(alt_count) or sorted_alt_bases[0][1] - alt_count[0] >= max_depth_gap: + alternate_base = sorted_alt_bases[0][0] # + sorted_alt_bases = [item[0] for item in sorted_alt_bases] + return sorted_alt_bases, alternate_base + + +def output_from( + reference_sequence, + contig, + position, + tensor_position_center, + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + output_config, + output_utilities, + alt_info_dict +): + add_indel_length = output_config.add_indel_length + reference_base_ACGT = BASE2ACGT[reference_sequence[tensor_position_center]] + + all_pro = possible_outcome_probabilites_from( + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + reference_base=reference_base_ACGT, + alt_info_dict=alt_info_dict, + add_indel_length=add_indel_length, + ) + + # return if a candidate is homo reference + if len(all_pro) == 1: + return ( + (True, False, False, False, False, False, False, False, False, False), + (reference_base_ACGT, reference_base_ACGT), (all_pro[0]) + ) + ( + homo_Ref_probability, + homo_SNP_probabilities, + hetero_SNP_probabilities, + homo_Ins_lengths, homo_Ins_probabilities, + hetero_InsIns_length_tuples, hetero_InsIns_probabilities, + hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities, + homo_Del_lengths, homo_Del_probabilities, + hetero_DelDel_length_tuples, hetero_DelDel_probabilities, + hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities, + hetero_InsDel_length_tuples, hetero_InsDel_probabilities, + ) = all_pro + maximum_probability = 0.0 + reference_base, alternate_base = None, None + while (reference_base is None or alternate_base is None): + maximum_probability = max( + homo_Ref_probability, + max(homo_SNP_probabilities), + max(hetero_SNP_probabilities), + max(homo_Ins_probabilities) if len(homo_Ins_probabilities) else 0, + max(homo_Del_probabilities) if len(homo_Del_probabilities) else 0, + max(hetero_ACGT_Ins_probabilities) if len(hetero_ACGT_Ins_probabilities) else 0, + max(hetero_InsIns_probabilities) if len(hetero_InsIns_probabilities) else 0, + max(hetero_ACGT_Del_probabilities) if len(hetero_ACGT_Del_probabilities) else 0, + max(hetero_DelDel_probabilities) if len(hetero_DelDel_probabilities) else 0, + max(hetero_InsDel_probabilities) if len(hetero_InsDel_probabilities) else 0, + ) + + is_reference = maximum_probability == homo_Ref_probability + if is_reference: + return ( + (True, False, False, False, False, False, False, False, False, False), + (reference_base_ACGT, reference_base_ACGT), (maximum_probability) + ) + + is_homo_SNP = maximum_probability in homo_SNP_probabilities + is_hetero_SNP = maximum_probability in hetero_SNP_probabilities + is_homo_insertion = maximum_probability in homo_Ins_probabilities + is_hetero_ACGT_Ins = maximum_probability in hetero_ACGT_Ins_probabilities + is_hetero_InsIns = maximum_probability in hetero_InsIns_probabilities + is_homo_deletion = maximum_probability in homo_Del_probabilities + is_hetero_ACGT_Del = maximum_probability in hetero_ACGT_Del_probabilities + is_hetero_DelDel = maximum_probability in hetero_DelDel_probabilities + is_insertion_and_deletion = maximum_probability in hetero_InsDel_probabilities + + if is_homo_SNP: + reference_base = reference_sequence[tensor_position_center] + idx = homo_SNP_probabilities.index(maximum_probability) + base1, base2 = homo_SNP_bases_from(homo_SNP_probabilities) + alternate_base = base1 if base1 != reference_base else base2 + sorted_alt_bases, alternate_base = find_alt_base(alt_info_dict, alternate_base) + if alternate_base is None or alternate_base == reference_base: + homo_SNP_probabilities[idx] = 0 + continue + + elif is_hetero_SNP: + base1, base2 = hetero_SNP_bases_from(hetero_SNP_probabilities) + idx = hetero_SNP_probabilities.index(maximum_probability) + reference_base = reference_sequence[tensor_position_center] + is_multi = base1 != reference_base and base2 != reference_base + if is_multi: + sorted_alt_bases, _ = find_alt_base(alt_info_dict) + if len(sorted_alt_bases) < 2: + hetero_SNP_probabilities[idx] = 0 + continue + alternate_base = ','.join(sorted_alt_bases[:2]) + else: + alternate_base = base1 if base1 != reference_base else base2 + sorted_alt_bases, alternate_base = find_alt_base(alt_info_dict, alternate_base) + if alternate_base is None or alternate_base == reference_base: + hetero_SNP_probabilities[idx] = 0 + continue + + + elif is_homo_insertion: + variant_length = None + idx = homo_Ins_probabilities.index(maximum_probability) + if add_indel_length: + variant_length = homo_Ins_lengths[idx] + insertion_bases = insertion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_insertion_length=variant_length if variant_length and variant_length < VariantLength.max else None, + maximum_insertion_length=output_config.maximum_variant_length_that_need_infer + ) + + insertion_length = len(insertion_bases) + if insertion_length == 0: + homo_Ins_probabilities[idx] = 0 + continue + reference_base = reference_sequence[tensor_position_center] + alternate_base = insertion_bases + + elif is_hetero_ACGT_Ins: + idx = hetero_ACGT_Ins_probabilities.index(maximum_probability) + variant_length = None + if add_indel_length: + hetero_Ins_base = hetero_ACGT_Ins_bases[idx] + variant_length = hetero_ACGT_Ins_lengths[idx] + else: + hetero_Ins_base = ACGT[idx] + insertion_bases = insertion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_insertion_length=variant_length if variant_length and variant_length < VariantLength.max else None, + maximum_insertion_length=output_config.maximum_variant_length_that_need_infer + + ) + insertion_length = len(insertion_bases) + if insertion_length == 0: + hetero_ACGT_Ins_probabilities[idx] = 0 + continue + reference_base = reference_sequence[tensor_position_center] + alternate_base = insertion_bases + + is_SNP_Ins_multi = hetero_Ins_base != reference_base + if is_SNP_Ins_multi: + sorted_alt_bases, _ = find_alt_base(alt_info_dict) + if len(sorted_alt_bases) == 0: + hetero_ACGT_Ins_probabilities[idx] = 0 + continue + else: + alternate_base = "{},{}".format(sorted_alt_bases[0], alternate_base) + + elif is_hetero_InsIns: + insertion_bases_list = [] + idx = hetero_InsIns_probabilities.index(maximum_probability) + if add_indel_length: + variant_length_1, variant_length_2 = hetero_InsIns_length_tuples[idx] + # del hetero_InsIns_probabilities[idx] + # del hetero_InsIns_length_tuples[idx] + + insertion_bases1 = insertion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_insertion_length=variant_length_1 if variant_length_1 and variant_length_1 < VariantLength.max else None, + maximum_insertion_length=output_config.maximum_variant_length_that_need_infer + ) + if len(insertion_bases1): + insertion_bases2 = insertion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_insertion_length=variant_length_2 if variant_length_2 and variant_length_2 < VariantLength.max else None, + insertion_bases_to_ignore=insertion_bases1, + maximum_insertion_length=output_config.maximum_variant_length_that_need_infer + ) + if len(insertion_bases2): + insertion_bases_list = [insertion_bases1, insertion_bases2] + if len(insertion_bases_list) < 2: + insertion_bases_list = insertion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + return_multi=True, + maximum_insertion_length=output_config.maximum_variant_length_that_need_infer + ) + else: + insertion_bases_list = insertion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + return_multi=True, + maximum_insertion_length=output_config.maximum_variant_length_that_need_infer + ) + if len(insertion_bases_list) < 2: + hetero_InsIns_probabilities[idx] = 0 + continue + insertion_bases, another_insertion_bases = insertion_bases_list + + reference_base = reference_sequence[tensor_position_center] + alternate_base = insertion_bases + + alternate_base_1 = another_insertion_bases + alternate_base_2 = alternate_base + if alternate_base_1 != alternate_base_2: + alternate_base = "{},{}".format(alternate_base_1, alternate_base_2) + else: + hetero_InsIns_probabilities[idx] = 0 + continue + + elif is_homo_deletion: + variant_length = None + idx = homo_Del_probabilities.index(maximum_probability) + if add_indel_length: + variant_length = homo_Del_lengths[idx] + + deletion_bases = deletion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_deletion_length=variant_length if variant_length and variant_length < VariantLength.max else None, + maximum_deletion_length=output_config.maximum_variant_length_that_need_infer + ) + deletion_length = len(deletion_bases) + if deletion_length == 0: + homo_Del_probabilities[idx] = 0 + continue + reference_base = reference_sequence[tensor_position_center] + deletion_bases + alternate_base = reference_base[0] + + elif is_hetero_ACGT_Del: + variant_length = None + idx = hetero_ACGT_Del_probabilities.index(maximum_probability) + if add_indel_length: + variant_length = hetero_ACGT_Del_lengths[idx] + hetero_Del_base = hetero_ACGT_Del_bases[idx] + else: + hetero_Del_base = ACGT[idx] + deletion_bases = deletion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_deletion_length=variant_length if variant_length and variant_length < VariantLength.max else None, + maximum_deletion_length=output_config.maximum_variant_length_that_need_infer + ) + deletion_length = len(deletion_bases) + if deletion_length == 0: + hetero_ACGT_Del_probabilities[idx] = 0 + continue + reference_base = reference_sequence[tensor_position_center] + deletion_bases + alternate_base = reference_base[0] + + is_SNP_Del_multi = hetero_Del_base != reference_base[0] + if is_SNP_Del_multi: + alternate_base_1 = alternate_base + alternate_base_2 = hetero_Del_base + reference_base[1:] + alternate_base = "{},{}".format(alternate_base_1, alternate_base_2) + + elif is_hetero_DelDel: + deletion_bases_list = [] + idx = hetero_DelDel_probabilities.index(maximum_probability) + if add_indel_length: + variant_length_1, variant_length_2 = sorted(hetero_DelDel_length_tuples[idx], + reverse=True) # longer deletion should be in first position + deletion_base1 = deletion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_deletion_length=variant_length_1 if variant_length_1 and variant_length_1 < VariantLength.max else None, + maximum_deletion_length=output_config.maximum_variant_length_that_need_infer + ) + if len(deletion_base1) > 0: + deletion_base2 = deletion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_deletion_length=variant_length_2 if variant_length_2 and variant_length_2 < VariantLength.max else None, + deletion_bases_to_ignore=deletion_base1, + maximum_deletion_length=output_config.maximum_variant_length_that_need_infer + ) + if len(deletion_base2) > 0: + deletion_bases_list = [deletion_base1, deletion_base2] if len(deletion_base1) > len( + deletion_base2) else [deletion_base2, deletion_base1] + if len(deletion_bases_list) < 2: + deletion_bases_list = deletion_bases_using_alt_info_from( + return_multi=True, + alt_info_dict=alt_info_dict, + maximum_deletion_length=output_config.maximum_variant_length_that_need_infer + ) + else: + deletion_bases_list = deletion_bases_using_alt_info_from( + return_multi=True, + alt_info_dict=alt_info_dict, + maximum_deletion_length=output_config.maximum_variant_length_that_need_infer + ) + + if len(deletion_bases_list) < 2: + hetero_DelDel_probabilities[idx] = 0 + continue + + deletion_bases, deletion_bases1 = deletion_bases_list + + reference_base = reference_sequence[tensor_position_center] + deletion_bases + alternate_base = reference_base[0] + + alternate_base_1 = alternate_base + alternate_base_2 = reference_base[0] + reference_base[len(deletion_bases1) + 1:] + if ( + alternate_base_1 != alternate_base_2 and + reference_base != alternate_base_1 and reference_base != alternate_base_2 + ): + alternate_base = "{},{}".format(alternate_base_1, alternate_base_2) + else: + hetero_DelDel_probabilities[idx] = 0 + continue + + elif is_insertion_and_deletion: + variant_length_1, variant_length_2 = None, None + idx = hetero_InsDel_probabilities.index(maximum_probability) + if add_indel_length: + variant_length_1, variant_length_2 = hetero_InsDel_length_tuples[idx] + + insertion_bases = insertion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_insertion_length=variant_length_2 if variant_length_2 and variant_length_2 < VariantLength.max else None, + maximum_insertion_length=output_config.maximum_variant_length_that_need_infer + ) + insertion_length = len(insertion_bases) + + deletion_bases = deletion_bases_using_alt_info_from( + alt_info_dict=alt_info_dict, + propose_deletion_length=variant_length_1 if variant_length_1 and variant_length_1 < VariantLength.max else None, + maximum_deletion_length=output_config.maximum_variant_length_that_need_infer + ) + deletion_length = len(deletion_bases) + + if insertion_length == 0 or deletion_length == 0: + hetero_InsDel_probabilities[idx] = 0 + continue + reference_base = reference_sequence[tensor_position_center] + deletion_bases + alternate_base = "{},{}".format( + reference_base[0], + insertion_bases + reference_base[1:] + ) + + return ( + is_reference, is_homo_SNP, is_hetero_SNP, + is_homo_insertion, is_hetero_ACGT_Ins, is_hetero_InsIns, + is_homo_deletion, is_hetero_ACGT_Del, is_hetero_DelDel, + is_insertion_and_deletion + ), (reference_base, alternate_base), (maximum_probability) + + +def batch_output_for_ensemble(X, batch_chr_pos_seq, alt_info_list, batch_Y, output_config, output_utilities): + batch_size = len(batch_chr_pos_seq) + batch_gt21_probabilities, batch_genotype_probabilities, = batch_Y + + if len(batch_gt21_probabilities) != batch_size: + sys.exit( + "Inconsistent shape between input tensor and output predictions %d/%d" % + (batch_size, len(batch_gt21_probabilities)) + ) + + tensor_position_center = param.flankingBaseNum + + for ( + x, + chr_pos_seq, + gt21_probabilities, + genotype_probabilities, + alt_info + ) in zip( + X, + batch_chr_pos_seq, + batch_gt21_probabilities, + batch_genotype_probabilities, + alt_info_list + ): + if output_config.tensor_fn != 'PIPE': + chromosome, position, reference_sequence = chr_pos_seq.decode().rstrip().split(":") + else: + chromosome, position, reference_sequence = chr_pos_seq + + position = int(position) + + if reference_sequence[tensor_position_center] not in BASIC_BASES: + continue + + output_utilities.output( + "\t".join( + [ + chromosome, + str(position), + reference_sequence, + alt_info.decode(), + ' '.join(["{:0.6f}".format(p) for p in list(gt21_probabilities)]), + ' '.join(["{:0.6f}".format(p) for p in list(genotype_probabilities)])] + ) + ) + + +def batch_output(batch_chr_pos_seq, alt_info_list, batch_Y, output_config, output_utilities): + batch_size = len(batch_chr_pos_seq) + + batch_gt21_probabilities, batch_genotype_probabilities = batch_Y[:,:param.label_shape_cum[0]], batch_Y[:,param.label_shape_cum[0]:param.label_shape_cum[1]] + if len(batch_gt21_probabilities) != batch_size: + sys.exit( + "Inconsistent shape between input tensor and output predictions %d/%d" % + (batch_size, len(batch_gt21_probabilities)) + ) + batch_variant_length_probabilities_1, batch_variant_length_probabilities_2 = [0] * batch_size, [0] * batch_size + + if output_config.add_indel_length: + batch_variant_length_probabilities_1, batch_variant_length_probabilities_2 = batch_Y[:,param.label_shape_cum[1]:param.label_shape_cum[2]], batch_Y[:,param.label_shape_cum[2]:param.label_shape_cum[3]] + for ( + chr_pos_seq, + alt_info, + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2 + ) in zip( + batch_chr_pos_seq, + alt_info_list, + batch_gt21_probabilities, + batch_genotype_probabilities, + batch_variant_length_probabilities_1, + batch_variant_length_probabilities_2 + ): + output_with( + chr_pos_seq, + alt_info, + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + output_config, + output_utilities, + ) + + +def output_with( + chr_pos_seq, + alt_info, + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + output_config, + output_utilities, +): + if type(chr_pos_seq) == np.memmap: + chr_pos_seq = chr_pos_seq[0].decode() + elif type(chr_pos_seq) == np.bytes_ or type(chr_pos_seq) == bytes: + chr_pos_seq = chr_pos_seq.decode() + + chromosome, position, reference_sequence = chr_pos_seq.rstrip().split(':') + position = int(position) + + tensor_position_center = param.flankingBaseNum + information_string = "P" if output_config.pileup else 'F' + + if type(alt_info) == np.memmap: + alt_info = alt_info[0].decode() + elif type(alt_info) == np.bytes_ or type(alt_info) == bytes: + alt_info = alt_info.decode() + + alt_info = alt_info.rstrip().split('-') + read_depth = int(alt_info[0]) # alt_info + indel_str = alt_info[1] if len(alt_info) > 1 else '' + seqs = indel_str.split(' ') + alt_info_dict = dict(zip(seqs[::2], [int(item) for item in seqs[1::2]])) if len(seqs) else {} + + output_info = output_from( + reference_sequence, + chromosome, + position, + tensor_position_center, + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + output_config, + output_utilities, + alt_info_dict + ) + if output_info is None: + return + + ( + is_reference, is_homo_SNP, is_hetero_SNP, + is_homo_insertion, is_hetero_ACGT_Ins, is_hetero_InsIns, + is_homo_deletion, is_hetero_ACGT_Del, is_hetero_DelDel, + is_insertion_and_deletion + ), (reference_base, alternate_base), (maximum_probability) = output_info + + if not output_config.is_debug and ( + (not output_config.is_show_reference and is_reference) or + (not is_reference and reference_base == alternate_base) + ): + return + + if reference_base is None or alternate_base is None: + return + + is_multi = "," in str(alternate_base) + + # haploid (precise mode) + if output_config.is_haploid_precise_mode_enabled and ( + is_hetero_SNP or is_hetero_ACGT_Ins or is_hetero_InsIns or + is_hetero_ACGT_Del or is_hetero_DelDel or is_insertion_and_deletion + ): + return + # haploid (sensitive mode) + elif output_config.is_haploid_sensitive_mode_enabled and is_multi: + return + + # geno type string + if is_reference: + genotype_string = genotype_string_from(Genotype.homo_reference) + elif is_homo_SNP or is_homo_insertion or is_homo_deletion: + genotype_string = genotype_string_from(Genotype.homo_variant) + elif is_hetero_SNP or is_hetero_ACGT_Ins or is_hetero_InsIns or is_hetero_ACGT_Del or is_hetero_DelDel: + genotype_string = genotype_string_from(Genotype.hetero_variant) + if is_multi: + genotype_string = genotype_string_from(Genotype.hetero_variant_multi) + + # allele frequency / supported reads + + def decode_alt_info(alt_info_dict): + alt_type_list = [{}, {}, {}] # SNP I D + for alt_type, count in alt_info_dict.items(): + count = int(count) + if alt_type[0] == 'X': + alt_type_list[0][alt_type[1]] = count + elif alt_type[0] == 'I': + alt_type_list[1][alt_type[1:]] = count + elif alt_type[0] == 'D': + alt_type_list[2][alt_type[1:]] = count + return alt_type_list + + alt_type_list = decode_alt_info(alt_info_dict) + supported_reads_count = 0 + ref_count, alt_list_count = 0, [] + snp_num = sum([item for item in alt_type_list[0].values()]) if len(alt_type_list[0]) else 0 + insert_num = sum([item for item in alt_type_list[1].values()]) if len(alt_type_list[1]) else 0 + del_num = sum([item for item in alt_type_list[2].values()]) if len(alt_type_list[2]) else 0 + ref_count = max(0, read_depth - snp_num - insert_num - del_num) + if is_reference: + supported_reads_count = ref_count + alternate_base = "." + + elif is_homo_SNP or is_hetero_SNP: + for base in str(alternate_base): + if base == ',': + continue + supported_reads_count += alt_type_list[0][base] if base in alt_type_list[0] else 0 + alt_list_count.append(supported_reads_count) + elif is_homo_insertion or is_hetero_InsIns: + base_list = alternate_base.split(',') + for ins_bases in base_list: + supported_reads_for_long_ins = get_long_indel_read_count(alt_info=alt_type_list[1], + proposed_ins_base=ins_bases, + is_del=False) if output_config.enable_long_indel else 0 + insertion_type_reads_count = (alt_type_list[1][ins_bases] if ins_bases in alt_type_list[1] else 0) + supported_reads_for_long_ins + supported_reads_count += insertion_type_reads_count + alt_list_count.append(insertion_type_reads_count) + elif is_hetero_ACGT_Ins: + is_SNP_Ins_multi = is_multi + SNP_base = alternate_base.split(",")[0][0] if is_SNP_Ins_multi else None + ins_bases = alternate_base.split(",")[1] if is_SNP_Ins_multi else alternate_base + + supported_reads_for_SNP = ( + alt_type_list[0][SNP_base] if SNP_base in alt_type_list[0] else 0) if is_SNP_Ins_multi else 0 + + supported_reads_for_long_ins = get_long_indel_read_count(alt_info=alt_type_list[1], + proposed_ins_base=ins_bases, is_del=False) if output_config.enable_long_indel else 0 + + supported_reads_for_ins = (alt_type_list[1][ins_bases] if ins_bases in alt_type_list[1] else 0) + supported_reads_for_long_ins + supported_reads_count = supported_reads_for_ins + supported_reads_for_SNP + if SNP_base: + alt_list_count.append(supported_reads_for_SNP) + alt_list_count.append(supported_reads_for_ins) + elif is_homo_deletion or is_hetero_DelDel: + if len(alt_type_list[2]) > 0: + if is_homo_deletion: + del_bases = reference_base[1:] if len(reference_base) > 1 else None + supported_reads_for_long_del = get_long_indel_read_count(alt_info=alt_type_list[2], + propose_del_base_length=len(del_bases)) if output_config.enable_long_indel else 0 + supported_reads_count = (alt_type_list[2][del_bases] if del_bases in alt_type_list[2] else 0) + supported_reads_for_long_del + alt_list_count.append(supported_reads_count) + elif is_hetero_DelDel and len(alt_type_list[2]) > 1: + for _bases in alternate_base.split(','): + _alt_len = len(reference_base) - len(_bases) + _tmp_cnt = [alt_type_list[2][_i] for _i in alt_type_list[2] if len(_i) == _alt_len] + supported_reads_for_long_del = get_long_indel_read_count(alt_info=alt_type_list[2], + propose_del_base_length=_alt_len) if output_config.enable_long_indel else 0 + _read_count = (_tmp_cnt[0] if len(_tmp_cnt) > 0 else 0) + supported_reads_for_long_del + alt_list_count.append(_read_count) + supported_reads_count += _read_count + elif is_hetero_ACGT_Del: + alt_list = alternate_base.split(",") + is_SNP_Del_multi = False if len(alt_list) == 0 else is_multi + SNP_base = (alt_list[1][0] if len(alt_list) > 1 else None) if is_SNP_Del_multi else None + supported_reads_for_SNP = ( + alt_type_list[0][SNP_base] if SNP_base in alt_type_list[0] else 0) if is_SNP_Del_multi else 0 + + del_bases = reference_base[1:] if len(reference_base) > 1 else None + supported_reads_for_long_del = get_long_indel_read_count(alt_info=alt_type_list[2], + propose_del_base_length=len( + del_bases)) if output_config.enable_long_indel else 0 + supported_reads_for_del = (alt_type_list[2][del_bases] if del_bases in alt_type_list[2] else 0) + supported_reads_for_long_del + supported_reads_count = supported_reads_for_del + supported_reads_for_SNP + if SNP_base: + alt_list_count.append(supported_reads_for_SNP) + alt_list_count.append(supported_reads_for_del) + elif is_insertion_and_deletion: + for _bases in alternate_base.split(','): + _alt_len = len(reference_base) - len(_bases) + if _alt_len < 0: #ins + ins_bases = _bases[:-(len(reference_base) - 1)] if len(reference_base) > 1 else _bases + supported_reads_for_long_ins = get_long_indel_read_count(alt_info=alt_type_list[1], + proposed_ins_base=ins_bases, + is_del=False) if output_config.enable_long_indel else 0 + _read_count = (alt_type_list[1][ins_bases] if ins_bases in alt_type_list[1] else 0) + supported_reads_for_long_ins + else: # del + _tmp_cnt = [alt_type_list[2][_i] for _i in alt_type_list[2] if len(_i) == _alt_len] + supported_reads_for_long_del = get_long_indel_read_count(alt_info=alt_type_list[2], + propose_del_base_length=_alt_len) if output_config.enable_long_indel else 0 + _read_count = (_tmp_cnt[0] if len(_tmp_cnt) > 0 else 0) + supported_reads_for_long_del + alt_list_count.append(_read_count) + supported_reads_count += _read_count + + allele_frequency = ((supported_reads_count + 0.0) / read_depth) if read_depth != 0 else 0.0 + if allele_frequency > 1: + allele_frequency = 1 + + # quality score + quality_score = quality_score_from(maximum_probability) + + # replace genotype string if any haploid mode enabled + if output_config.is_haploid_precise_mode_enabled or output_config.is_haploid_sensitive_mode_enabled: + genotype_string = "1" if "1" in genotype_string else "0" + + # filtration value + filtration_value = filtration_value_from( + quality_score_for_pass=output_config.quality_score_for_pass, + quality_score=quality_score, + is_reference=is_reference + ) + + if output_config.is_debug: + output_utilities.print_debug_message( + chromosome, + position, + gt21_probabilities, + genotype_probabilities, + variant_length_probabilities_1, + variant_length_probabilities_2, + "Normal output" if not is_reference else "Reference" + ) + else: + if output_config.gvcf: + + # allele depth + ad_alt = ',' + ','.join([str(item) for item in alt_list_count]) + allele_depth = str(ref_count) + (ad_alt if len(alt_list_count) else "") + + PLs = compute_PL(genotype_string, genotype_probabilities, gt21_probabilities, reference_base, + alternate_base) + + PLs = ','.join([str(x) for x in PLs]) + + output_utilities.output("%s\t%d\t.\t%s\t%s\t%.2f\t%s\t%s\tGT:GQ:DP:AD:AF:PL\t%s:%d:%d:%s:%.4f:%s" % ( + chromosome, + position, + reference_base, + alternate_base, + quality_score, + filtration_value, + information_string, + genotype_string, + quality_score, + read_depth, + allele_depth, + allele_frequency, + PLs + )) + else: + output_utilities.output("%s\t%d\t.\t%s\t%s\t%.2f\t%s\t%s\tGT:GQ:DP:AF\t%s:%d:%d:%.4f" % ( + chromosome, + position, + reference_base, + alternate_base, + quality_score, + filtration_value, + information_string, + genotype_string, + quality_score, + read_depth, + allele_frequency + )) + + +def compute_PL(genotype_string, genotype_probabilities, gt21_probabilities, reference_base, alternate_base): + ''' + PL computation + for bi-allelic: AA(00), AB(01), BB(11) + for tri-allielic: AA(00),AB(01), BB(11), AC(02), BC(12), CC(22) + ''' + alt_array = alternate_base.split(',') + alt_num = len(alt_array) + + genotypes = {1: [[0, 0], [0, 1], [1, 1]], 2: [[0, 0], [0, 1], [1, 1], [0, 2], [1, 2], [2, 2]]} + likelihoods = [] + reference_base = BASE2ACGT[reference_base] if len(reference_base) == 1 else reference_base + all_base = [reference_base] + all_base.extend(alt_array) + for encoded_genotype in genotypes[alt_num]: + # obtain the genotype probability from the 21 gt + + partial_label_1 = partial_label_from(reference_base, all_base[encoded_genotype[0]]) + partial_label_2 = partial_label_from(reference_base, all_base[encoded_genotype[1]]) + gt21_label = mix_two_partial_labels(partial_label_1, partial_label_2) + try: + gt21_prob_index = gt21_enum_from_label(gt21_label) + except: + #skip N positions + return [990 * len(genotypes[alt_num])] + genotype_prob_21 = gt21_probabilities[gt21_prob_index] + + # obtain the genotype probability from 3 zygosity + _genotype = genotype_enum_for_task(genotype_enum_from(encoded_genotype[0], encoded_genotype[1])) + genotype_prob_zygosity = genotype_probabilities[_genotype] + + # chain probability + _p = genotype_prob_21 * genotype_prob_zygosity + # _p = genotype_prob_21 + likelihoods.append(_p) + pass + + # genotype likelihood normalization + # p/sum(p) + + sum_p = sum(likelihoods) + LOG_10 = math.log(10.0) + likelihoods = [x / sum_p for x in likelihoods] + # phred transformation + + # avoid domain error + add_val = 1e-8 + likelihoods = [x+add_val for x in likelihoods] + # -10*log10(x/sum_p) = -10*(log10(x) - log10(sum_p)) + + PLs = [-10 * (log(x) / LOG_10) for x in likelihoods] + min_PL = min(PLs) + + PLs = [int(math.ceil(x - min_PL)) for x in PLs] + return PLs + +def call_variants(args, output_config, output_utilities): + use_gpu = args.use_gpu + if use_gpu: + gpus = tf.config.experimental.list_physical_devices('GPU') + tf.config.experimental.set_virtual_device_configuration(gpus[0], [ + tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)]) + else: + os.environ["CUDA_VISIBLE_DEVICES"] = "" + global param + if args.pileup: + import shared.param_p as param + from clair3.model import Clair3_P + m = Clair3_P(add_indel_length=args.add_indel_length, predict=True) + else: + import shared.param_f as param + from clair3.model import Clair3_F + m = Clair3_F(add_indel_length=args.add_indel_length, predict=True) + + m.load_weights(args.chkpnt_fn) + + output_utilities.gen_output_file() + output_utilities.output_header() + chunk_id = args.chunk_id - 1 if args.chunk_id else None # 1-base to 0-base + chunk_num = args.chunk_num + full_alignment_mode = not args.pileup + + tensor_generator = utils.tensor_generator_from( + args.tensor_fn, param.predictBatchSize, args.pileup, args.platform) + logging.info("Calling variants ...") + variant_call_start_time = time() + + is_finish_loaded_all_mini_batches = False + batch_output_method = batch_output_for_ensemble if output_config.is_output_for_ensemble else batch_output + mini_batches_loaded = [] + mini_batches_to_output = [] + + def load_mini_batch(): + try: + mini_batches_loaded.append(next(tensor_generator)) + except StopIteration: + return + + total = 0 #start, end + if not args.is_from_tables: + apply_threading = False + if apply_threading: + while True: + thread_pool = [] + + if len(mini_batches_to_output) > 0: + mini_batch = mini_batches_to_output.pop(0) + X, position, alt_info_list = mini_batch + prediction = m.predict_on_batch(X) + total += len(X) + thread_pool.append(Thread( + target=batch_output_method, + args=(position, alt_info_list, prediction, output_config, output_utilities) + )) + + if not is_finish_loaded_all_mini_batches: + thread_pool.append(Thread(target=load_mini_batch)) + + for t in thread_pool: + t.start() + for t in thread_pool: + t.join() + + is_finish_loaded_all_mini_batches = len(mini_batches_loaded) == 0 + while len(mini_batches_loaded) > 0: + mini_batch = mini_batches_loaded.pop(0) + mini_batches_to_output.append(mini_batch) + + is_nothing_to_predict_and_output = ( + len(thread_pool) <= 0 and len(mini_batches_to_output) <= 0 + ) + if is_finish_loaded_all_mini_batches and is_nothing_to_predict_and_output: + break + else: + while True: + if len(mini_batches_to_output) > 0: + mini_batch = mini_batches_to_output.pop(0) + X, position, alt_info_list = mini_batch + prediction = m.predict_on_batch(X) + total += len(X) + batch_output_method(position, alt_info_list, prediction, output_config, output_utilities) + + if not is_finish_loaded_all_mini_batches: + load_mini_batch() + + is_finish_loaded_all_mini_batches = len(mini_batches_loaded) == 0 + while len(mini_batches_loaded) > 0: + mini_batch = mini_batches_loaded.pop(0) + mini_batches_to_output.append(mini_batch) + + is_nothing_to_predict_and_output = len(mini_batches_to_output) <= 0 + if is_finish_loaded_all_mini_batches and is_nothing_to_predict_and_output: + break + + if chunk_id is not None: + logging.info("Total processed positions in {} (chunk {}/{}) : {}".format(args.ctgName, chunk_id+1, chunk_num, total)) + elif full_alignment_mode: + try: + chunk_infos = args.call_fn.split('.')[-2] + c_id, c_num = chunk_infos.split('_') + c_id = int(c_id) + 1 # 0-index to 1-index + logging.info("Total processed positions in {} (chunk {}/{}) : {}".format(args.ctgName, c_id, c_num, total)) + except: + logging.info("Total processed positions in {} : {}".format(args.ctgName, total)) + else: + logging.info("Total processed positions in {} : {}".format(args.ctgName, total)) + if full_alignment_mode and total == 0: + logging.info(log_error("[ERROR] No full-alignment output for file {}/{}".format(args.ctgName, args.call_fn))) + else: + dataset = tables.open_file(args.tensor_fn, 'r').root + batch_size = param.predictBatchSize + dataset_size = len(dataset.label) + chunk_start_pos = 0 + # process by chunk windows + if chunk_id is not None and chunk_num is not None: + chunk_dataset_size = dataset_size // chunk_num if dataset_size % chunk_num == 0 else dataset_size // chunk_num + 1 + chunk_start_pos = chunk_id * chunk_dataset_size + dataset_size = chunk_dataset_size + num_epoch = dataset_size // batch_size if dataset_size % batch_size == 0 else dataset_size // batch_size + 1 + + for idx in range(num_epoch): + position_matrix = dataset.position_matrix[ + chunk_start_pos + idx * batch_size:chunk_start_pos + (idx + 1) * batch_size] + position = list( + dataset.position[chunk_start_pos + idx * batch_size:chunk_start_pos + (idx + 1) * batch_size].flatten()) + alt_info_list = list( + dataset.alt_info[chunk_start_pos + idx * batch_size:chunk_start_pos + (idx + 1) * batch_size].flatten()) + + prediction = m.predict_on_batch(position_matrix) + batch_output_method(position, alt_info_list, prediction, output_config, output_utilities) + total += len(position_matrix) + + logging.info("Total time elapsed: %.2f s" % (time() - variant_call_start_time)) + + output_utilities.close_opened_files() + # remove file if on variant in output + if os.path.exists(args.call_fn): + for row in open(args.call_fn, 'r'): + if row[0] != '#': + return + logging.info("[INFO] No vcf output for file {}, remove empty file".format(args.call_fn)) + os.remove(args.call_fn) + + +def call_variants_with_probabilities_input(args, output_config, output_utilities): + chunk_id = args.chunk_id - 1 if args.chunk_id else None # 1-base to 0-base + chunk_num = args.chunk_num + logging.info("Calling variants ...") + variant_call_start_time = time() + + batch_output_method = batch_output_for_ensemble if output_config.is_output_for_ensemble else batch_output + + prediction_path = args.tensor_fn + '.prediction' + if not os.path.exists(prediction_path): + return + + output_utilities.gen_output_file() + prediction = np.load(prediction_path, mmap_mode='r') + position = np.load(args.tensor_fn + '.position', mmap_mode='r') + alt_info = np.load(args.tensor_fn + '.alt_info', mmap_mode='r') + + global param + if args.pileup: + import shared.param_p as param + else: + import shared.param_f as param + + output_utilities.output_header() + batch_size = param.predictBatchSize + dataset_size = len(prediction) + chunk_start_pos = 0 + # process by chunk windows + if chunk_id is not None and chunk_num is not None: + chunk_dataset_size = dataset_size // chunk_num if dataset_size % chunk_num == 0 else dataset_size // chunk_num + 1 + chunk_start_pos = chunk_id * chunk_dataset_size + dataset_size = chunk_dataset_size + num_epoch = dataset_size // batch_size if dataset_size % batch_size == 0 else dataset_size // batch_size + 1 + + for idx in range(num_epoch): + start_pos = chunk_start_pos + idx * batch_size + end_pos = min(chunk_start_pos + (idx + 1) * batch_size, len(prediction)) + batch_prediction = prediction[start_pos:end_pos] + batch_position = position[start_pos:end_pos] + batch_alt_info = alt_info[start_pos:end_pos] + if test_pos: + find_test_pos = False + for item in batch_position: + if str(test_pos) == item[0].decode().split(':')[1]: + find_test_pos = True + if not find_test_pos: + continue + batch_output_method(batch_position, batch_alt_info, batch_prediction, output_config, output_utilities) + + logging.info("Total time elapsed: %.2f s" % (time() - variant_call_start_time)) + + output_utilities.close_opened_files() + # remove file if on variant in output + if os.path.exists(args.call_fn): + vcf_file = open(args.call_fn, 'r').readlines() + if not len(vcf_file): + os.remove(args.call_fn) + for row in vcf_file: + if row[0] != '#': + return + logging.info("[INFO] No vcf output for file {}, remove empty file".format(args.call_fn)) + os.remove(args.call_fn) + + +def DataGenerator(dataset, num_epoch, batch_size, chunk_start_pos, chunk_end_pos): + for idx in range(num_epoch): + start_pos = chunk_start_pos + idx * batch_size + end_pos = min(chunk_start_pos + (idx + 1) * batch_size, chunk_end_pos) + position_matrix = dataset.position_matrix[start_pos:end_pos] + position = dataset.position[start_pos:end_pos] # .flatten() + alt_info_list = dataset.alt_info[start_pos:end_pos] # .flatten() + yield position_matrix, position, alt_info_list + + +def predict(args, output_config, output_utilities): + chunk_id = args.chunk_id - 1 if args.chunk_id else None # 1-base to 0-base + chunk_num = args.chunk_num + predict_fn = args.predict_fn + use_gpu = args.use_gpu + logging.info("[INFO] Make prediction ...") + variant_call_start_time = time() + add_indel_length = args.add_indel_length + + if use_gpu: + gpus = tf.config.experimental.list_physical_devices('GPU') + tf.config.experimental.set_virtual_device_configuration(gpus[0], [ + tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)]) + else: + os.environ["CUDA_VISIBLE_DEVICES"] = "" + + if args.pileup: + from clair3.model import Clair3_P + m = Clair3_P(add_indel_length=args.add_indel_length, predict=True) + else: + from clair3.model import Clair3_F + m = Clair3_F(add_indel_length=args.add_indel_length, predict=True) + + batch_output_method = batch_output_for_ensemble if output_config.is_output_for_ensemble else batch_output + m.load_weights(args.chkpnt_fn) + + total = 0 + if not args.is_from_tables: + output_utilities.output_header() + is_finish_loaded_all_mini_batches = False + mini_batches_loaded = [] + mini_batches_to_output = [] + + def load_mini_batch(): + try: + mini_batches_loaded.append(next(tensor_generator)) + except StopIteration: + return + + tensor_generator = utils.tensor_generator_from(args.tensor_fn, param.predictBatchSize, args.pileup, + args.platform) + while True: + thread_pool = [] + if len(mini_batches_to_output) > 0: + mini_batch = mini_batches_to_output.pop(0) + X, position, alt_info_list = mini_batch + prediction = m.predict_on_batch(X) + total += len(X) + thread_pool.append(Thread( + target=batch_output_method, + args=(position, alt_info_list, prediction, output_config, output_utilities) + )) + + if not is_finish_loaded_all_mini_batches: + thread_pool.append(Thread(target=load_mini_batch)) + + for t in thread_pool: + t.start() + for t in thread_pool: + t.join() + + is_finish_loaded_all_mini_batches = len(mini_batches_loaded) == 0 + while len(mini_batches_loaded) > 0: + mini_batch = mini_batches_loaded.pop(0) + mini_batches_to_output.append(mini_batch) + + is_nothing_to_predict_and_output = ( + len(thread_pool) <= 0 and len(mini_batches_to_output) <= 0 + ) + if is_finish_loaded_all_mini_batches and is_nothing_to_predict_and_output: + break + logging.info("Total process positions: {}".format(total)) + + else: + if not os.path.exists(args.tensor_fn): + logging.info("skip {}, not existing chunk_id".format(args.tensor_fn)) + return + dataset = tables.open_file(args.tensor_fn, 'r').root + batch_size = param.predictBatchSize + dataset_size = len(dataset.label) + chunk_start_pos, chunk_end_pos = 0, dataset_size + tensor_shape = param.ont_input_shape if args.platform == 'ont' else param.input_shape + # process by chunk windows + if chunk_id is not None and chunk_num is not None: + chunk_dataset_size = dataset_size // chunk_num if dataset_size % chunk_num == 0 else dataset_size // chunk_num + 1 + chunk_start_pos = chunk_id * chunk_dataset_size + dataset_size = min(chunk_dataset_size, dataset_size - chunk_start_pos) + chunk_end_pos = min(chunk_start_pos + dataset_size, chunk_end_pos) + num_epoch = dataset_size // batch_size if dataset_size % batch_size == 0 else dataset_size // batch_size + 1 + label_size = sum(param.label_shape) if add_indel_length else sum(param.label_shape[:2]) + prediction_memmap = np.lib.format.open_memmap(predict_fn + '.prediction', dtype=np.float, mode='w+', + shape=(dataset_size, label_size)) + position_memmap = np.lib.format.open_memmap(predict_fn + '.position', dtype='S100', mode='w+', + shape=(dataset_size, 1)) + alt_info_memmap = np.lib.format.open_memmap(predict_fn + '.alt_info', dtype='S2000', mode='w+', + shape=(dataset_size, 1)) + TensorShape = (tf.TensorShape([None] + tensor_shape), tf.TensorShape([None, 1]), tf.TensorShape([None, 1])) + TensorDtype = (tf.int32, tf.string, tf.string) + + predict_dataset = tf.data.Dataset.from_generator( + lambda: DataGenerator(dataset, num_epoch, batch_size, chunk_start_pos, chunk_end_pos), TensorDtype, + TensorShape).prefetch(buffer_size=tf.data.experimental.AUTOTUNE) + dataset_iter = iter(predict_dataset) + for idx in range(num_epoch): + position_matrix, position, alt_info_list = next(dataset_iter) + prediction = m.predict_on_batch(position_matrix) + start_pos = idx * batch_size + end_pos = min((idx + 1) * batch_size, dataset_size) + prediction_memmap[start_pos:end_pos] = prediction + position_memmap[start_pos:end_pos] = position.numpy() + alt_info_memmap[start_pos:end_pos] = alt_info_list.numpy() + + total += len(position_matrix) + logging.info("Total processed positions/bin file size: {}/{}".format(total, len(dataset.label))) + logging.info("Total time elapsed: %.2f s" % (time() - variant_call_start_time)) + + +def main(): + parser = ArgumentParser(description="Call variants using a trained model and tensors of candidate variants") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--tensor_fn', type=str, default="PIPE", + help="Tensor input filename, or stdin if not set") + + parser.add_argument('--chkpnt_fn', type=str, default=None, required=True, + help="Input a trained model for variant calling, required") + + parser.add_argument('--call_fn', type=str, default="clair3", + help="VCF output filename, or stdout if not set") + + parser.add_argument('--gvcf', type=str2bool, default=False, + help="Enable GVCF output, default: disabled") + + parser.add_argument('--ref_fn', type=str, default=None, + help="Reference fasta file input, required if --gvcf is enabled") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of the sequence to be processed") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--sampleName', type=str, default="SAMPLE", + help="Define the sample name to be shown in the VCF file, optional") + + parser.add_argument('--qual', type=int, default=2, + help="If set, variants with >=$qual will be marked 'PASS', or 'LowQual' otherwise, optional") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required, default: %(default)s") + + # options for advanced users + parser.add_argument('--temp_file_dir', type=str, default='./', + help="EXPERIMENTAL: The cache directory for storing temporary non-variant information if --gvcf is enabled, default: %(default)s") + + parser.add_argument('--haploid_precise', action='store_true', + help="EXPERIMENTAL: Enable haploid calling mode. Only 1/1 is considered as a variant") + + parser.add_argument('--haploid_sensitive', action='store_true', + help="EXPERIMENTAL: Enable haploid calling mode. 0/1 and 1/1 are considered as a variant") + + parser.add_argument('--enable_long_indel', type=str2bool, default=False, + help="EXPERIMENTAL: Enable long Indel variants(>50 bp) calling") + + # options for debug purpose + parser.add_argument('--use_gpu', type=str2bool, default=False, + help="DEBUG: Use GPU for calling. Speed up is mostly insignficiant. Only use this for building your own pipeline") + + parser.add_argument('--predict_fn', type=str, default=None, + help="DEBUG: Output network output probabilities for further analysis") + + parser.add_argument('--input_probabilities', action='store_true', + help="DEBUG: Use network probability outputs as input and generate variants from them") + + parser.add_argument('--output_probabilities', action='store_true', + help="DEBUG: Output the network probabilities of gt21, genotype, indel_length_1 and indel_length_2") + + # options for internal process control + ## In pileup mode or not (full alignment mode), default: False + parser.add_argument('--pileup', action='store_true', + help=SUPPRESS) + + ## Include indel length in training and calling, false for pileup and true for raw alignment + parser.add_argument('--add_indel_length', type=str2bool, default=False, + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + ## Enable debug mode, default: False, optional + parser.add_argument('--debug', action='store_true', + help=SUPPRESS) + + ## Generating outputs for ensemble model calling + parser.add_argument('--output_for_ensemble', action='store_true', + help=SUPPRESS) + + ## Use bin file from pytables to speed up calling. + parser.add_argument('--is_from_tables', action='store_true', + help=SUPPRESS) + + ## Output reference calls + parser.add_argument('--showRef', action='store_true', + help=SUPPRESS) + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + Run(args) + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/clair3/Train.py b/benchmarks/nn-variant/Clair3/clair3/Train.py new file mode 100644 index 0000000..01f596f --- /dev/null +++ b/benchmarks/nn-variant/Clair3/clair3/Train.py @@ -0,0 +1,352 @@ +import logging +import random +import numpy as np +from argparse import ArgumentParser, SUPPRESS +import tensorflow_addons as tfa +import tensorflow as tf +import tables +import os +import sys +from itertools import accumulate + +import clair3.model as model_path +from shared.utils import str2bool + +logging.basicConfig(format='%(message)s', level=logging.INFO) +tables.set_blosc_max_threads(512) +os.environ['NUMEXPR_MAX_THREADS'] = '64' +os.environ['NUMEXPR_NUM_THREADS'] = '8' + + +def get_label_task(label, label_shape_cum, task): + if task == 0: + return label[:label_shape_cum[task]] + elif task == len(label_shape_cum) - 1: + return label[label_shape_cum[task - 1]:] + else: + return label[label_shape_cum[task - 1]:label_shape_cum[task]] + + +def cal_class_weight(samples_per_cls, no_of_classes, beta=0.999): + effective_num = 1.0 - np.power(beta, samples_per_cls) + cls_weights = (1.0 - beta) / np.array(effective_num) + cls_weights = cls_weights / np.sum(cls_weights) * no_of_classes + return cls_weights + + +class FocalLoss(tf.keras.losses.Loss): + """ + updated version of focal loss function, for multi class classification, we remove alpha parameter, which the loss + more stable, and add gradient clipping to avoid gradient explosion and precision overflow. + """ + + def __init__(self, label_shape_cum, task, effective_label_num=None, gamma=2): + super(FocalLoss, self).__init__() + self.gamma = gamma + self.cls_weights = None + if effective_label_num is not None: + task_label_num = get_label_task(effective_label_num, label_shape_cum, task) + cls_weights = cal_class_weight(task_label_num, len(task_label_num)) + cls_weights = tf.constant(cls_weights, dtype=tf.float32) + cls_weights = tf.expand_dims(cls_weights, axis=0) + self.cls_weights = cls_weights + + def call(self, y_true, y_pred): + y_pred = tf.clip_by_value(y_pred, clip_value_min=1e-9, clip_value_max=1 - 1e-9) + cross_entropy = -y_true * tf.math.log(y_pred) + weight = ((1 - y_pred) ** self.gamma) * y_true + FCLoss = cross_entropy * weight + if self.cls_weights is not None: + FCLoss = FCLoss * self.cls_weights + reduce_fl = tf.reduce_sum(FCLoss, axis=-1) + return reduce_fl + + +class DataSequence(tf.keras.utils.Sequence): + def __init__(self, data, chunk_list, param, tensor_shape, mini_epochs=1, add_indel_length=False, validation=False): + self.data = data + self.chunk_list = chunk_list + self.batch_size = param.trainBatchSize + self.chunk_size = param.chunk_size + self.chunks_per_batch = self.batch_size // self.chunk_size + self.label_shape_cum = param.label_shape_cum[0:4 if add_indel_length else 2] + self.mini_epochs = mini_epochs + self.mini_epochs_count = -1 + self.validation = validation + self.position_matrix = np.empty([self.batch_size] + tensor_shape, np.int32) + self.label = np.empty((self.batch_size, param.label_size), np.float32) + self.random_offset = 0 + self.on_epoch_end() + + def __len__(self): + return int((len(self.chunk_list) // self.chunks_per_batch) // self.mini_epochs) + + def __getitem__(self, index): + mini_epoch_offset = self.mini_epochs_count * self.__len__() + chunk_batch_list = self.chunk_list[(mini_epoch_offset + index) * self.chunks_per_batch:(mini_epoch_offset + index + 1) * self.chunks_per_batch] + for chunk_idx, (bin_id, chunk_id) in enumerate(chunk_batch_list): + start_pos = self.random_offset + chunk_id * self.chunk_size + self.position_matrix[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size] = \ + self.data[bin_id].root.position_matrix[start_pos:start_pos + self.chunk_size] + self.label[chunk_idx * self.chunk_size:(chunk_idx + 1) * self.chunk_size] = \ + self.data[bin_id].root.label[start_pos:start_pos + self.chunk_size] + + return self.position_matrix, tuple( + np.split(self.label, self.label_shape_cum, axis=1)[:len(self.label_shape_cum)] + ) + + def on_epoch_end(self): + self.mini_epochs_count += 1 + if (self.mini_epochs_count % self.mini_epochs) == 0: + self.mini_epochs_count = 0 + if not self.validation: + self.random_offset = np.random.randint(0, self.chunk_size) + np.random.shuffle(self.chunk_list) + + +def get_chunk_list(chunk_offset, train_chunk_num, chunks_per_batch=10, training_dataset_percentage=None): + """ + get chunk list for training and validation data. we will randomly split training and validation dataset, + all training data is directly acquired from various tensor bin files. + + """ + need_split_validation_data = training_dataset_percentage is not None + all_shuffle_chunk_list = [] + training_chunk_list, validation_chunk_list = [], [] + for bin_idx, chunk_num in enumerate(chunk_offset): + current_chunk_list = [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)] + all_shuffle_chunk_list += current_chunk_list + if need_split_validation_data: + buffer_chunk_num = chunks_per_batch + if chunk_num < buffer_chunk_num: + training_chunk_list += [(bin_idx, chunk_idx) for chunk_idx in range(chunk_num)] + continue + + training_chunk_num = int((chunk_num - buffer_chunk_num) * training_dataset_percentage) + validation_chunk_num = int(chunk_num - buffer_chunk_num - training_chunk_num) + if training_chunk_num > 0: + training_chunk_list += current_chunk_list[:training_chunk_num] + if validation_chunk_num > 0: + validation_chunk_list += current_chunk_list[-validation_chunk_num:] + + if need_split_validation_data: + return np.array(training_chunk_list), np.array(validation_chunk_list) + + return np.array(all_shuffle_chunk_list[:train_chunk_num]), np.array(all_shuffle_chunk_list[train_chunk_num:]) + + +def exist_file_prefix(exclude_training_samples, f): + for prefix in exclude_training_samples: + if prefix in f: + return True + return False + + +def train_model(args): + platform = args.platform + pileup = args.pileup + add_indel_length = args.add_indel_length + exclude_training_samples = args.exclude_training_samples + exclude_training_samples = set(exclude_training_samples.split(',')) if exclude_training_samples else set() + add_validation_dataset = args.random_validation or (args.validation_fn is not None) + validation_fn = args.validation_fn + ochk_prefix = args.ochk_prefix if args.ochk_prefix is not None else "" + if pileup: + import shared.param_p as param + model = model_path.Clair3_P() + else: + import shared.param_f as param + model = model_path.Clair3_F(add_indel_length=add_indel_length) + + tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape + label_shape = param.label_shape + label_shape_cum = param.label_shape_cum + batch_size, chunk_size = param.trainBatchSize, param.chunk_size + assert batch_size % chunk_size == 0 + chunks_per_batch = batch_size // chunk_size + random.seed(param.RANDOM_SEED) + np.random.seed(param.RANDOM_SEED) + learning_rate = args.learning_rate if args.learning_rate else param.initialLearningRate + max_epoch = args.maxEpoch if args.maxEpoch else param.maxEpoch + task_num = 4 if add_indel_length else 2 + mini_epochs = args.mini_epochs + + def populate_dataset_table(file_list, file_path): + chunk_offset = np.zeros(len(file_list), dtype=int) + table_dataset_list = [] + for bin_idx, bin_file in enumerate(file_list): + table_dataset = tables.open_file(os.path.join(file_path, bin_file), 'r') + table_dataset_list.append(table_dataset) + chunk_num = (len(table_dataset.root.label) - batch_size) // chunk_size + chunk_offset[bin_idx] = chunk_num + return table_dataset_list, chunk_offset + + bin_list = os.listdir(args.bin_fn) + # default we exclude sample hg003 and all chr20 for training + bin_list = [f for f in bin_list if '_20_' not in f and not exist_file_prefix(exclude_training_samples, f)] + logging.info("[INFO] total {} training bin files: {}".format(len(bin_list), ','.join(bin_list))) + + effective_label_num = None + + table_dataset_list, chunk_offset = populate_dataset_table(bin_list, args.bin_fn) + + if validation_fn: + val_list = os.listdir(validation_fn) + logging.info("[INFO] total {} validation bin files: {}".format(len(val_list), ','.join(val_list))) + validate_table_dataset_list, validate_chunk_offset = populate_dataset_table(val_list, args.validation_fn) + + train_chunk_num = int(sum(chunk_offset)) + train_shuffle_chunk_list, _ = get_chunk_list(chunk_offset, train_chunk_num) + + validate_chunk_num = int(sum(validate_chunk_offset)) + validate_shuffle_chunk_list, _ = get_chunk_list(validate_chunk_offset, validate_chunk_num) + total_chunks = train_chunk_num + validate_chunk_num + else: + total_chunks = int(sum(chunk_offset)) + training_dataset_percentage = param.trainingDatasetPercentage if add_validation_dataset else None + if add_validation_dataset: + total_batches = total_chunks // chunks_per_batch + validate_chunk_num = int(max(1., np.floor(total_batches * (1 - training_dataset_percentage))) * chunks_per_batch) + # +++++++++++++**---- + # +:training *:buffer -:validation + # distribute one batch data as buffer for each bin file, avoiding shifting training data to validation data + train_chunk_num = int(total_chunks - validate_chunk_num) + else: + train_chunk_num = total_chunks + train_shuffle_chunk_list, validate_shuffle_chunk_list = get_chunk_list(chunk_offset, train_chunk_num, chunks_per_batch, training_dataset_percentage) + train_chunk_num = len(train_shuffle_chunk_list) + validate_chunk_num = len(validate_shuffle_chunk_list) + + train_seq = DataSequence(table_dataset_list, train_shuffle_chunk_list, param, tensor_shape, + mini_epochs=mini_epochs, add_indel_length=add_indel_length) + if add_validation_dataset: + val_seq = DataSequence(validate_table_dataset_list if validation_fn else table_dataset_list, validate_shuffle_chunk_list, param, tensor_shape, + mini_epochs=1, add_indel_length=add_indel_length, validation=True) + else: + val_seq = None + + total_steps = max_epoch * (train_chunk_num // chunks_per_batch) + + #RectifiedAdam with warmup start + optimizer = tfa.optimizers.Lookahead(tfa.optimizers.RectifiedAdam( + lr=learning_rate, + total_steps=total_steps, + warmup_proportion=0.1, + min_lr=learning_rate*0.75, + )) + + loss_func = [FocalLoss(label_shape_cum, task, effective_label_num) for task in range(task_num)] + loss_task = {"output_{}".format(task + 1): loss_func[task] for task in range(task_num)} + metrics = {"output_{}".format(task + 1): tfa.metrics.F1Score(num_classes=label_shape[task], average='micro') for + task in range(task_num)} + + model.compile( + loss=loss_task, + metrics=metrics, + optimizer=optimizer + ) + early_stop_callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10*mini_epochs, mode="min") + model_save_callback = tf.keras.callbacks.ModelCheckpoint(ochk_prefix + ".{epoch:02d}", period=1, save_weights_only=False) + model_best_callback = tf.keras.callbacks.ModelCheckpoint("best_val_loss", monitor='val_loss', save_best_only=True, mode="min") + train_log_callback = tf.keras.callbacks.CSVLogger("training.log", separator='\t') + + # Use first 20 element to initialize tensorflow model using graph mode + output = model(np.array(table_dataset_list[0].root.position_matrix[:20])) + logging.info(model.summary(print_fn=logging.info)) + + logging.info("[INFO] The size of dataset: {}".format(total_chunks * chunk_size)) + logging.info("[INFO] The training batch size: {}".format(batch_size)) + logging.info("[INFO] The training learning_rate: {}".format(learning_rate)) + logging.info("[INFO] Total training steps: {}".format(total_steps)) + logging.info("[INFO] Maximum training epoch: {}".format(max_epoch)) + logging.info("[INFO] Mini-epochs per epoch: {}".format(mini_epochs)) + logging.info("[INFO] Start training...") + + if args.chkpnt_fn is not None: + model.load_weights(args.chkpnt_fn) + logging.info("[INFO] Starting from model {}".format(args.chkpnt_fn)) + + train_history = model.fit(x=train_seq, + epochs=max_epoch * mini_epochs, + validation_data=val_seq, + callbacks=[early_stop_callback, + model_save_callback, + model_best_callback, + train_log_callback], + verbose=1, + shuffle=False) + + for table_dataset in table_dataset_list: + table_dataset.close() + + if validation_fn: + for table_dataset in validate_table_dataset_list: + table_dataset.close() + + # show the parameter set with the smallest validation loss + if 'val_loss' in train_history.history: + best_validation_epoch = np.argmin(np.array(train_history.history["val_loss"])) + 1 + logging.info("[INFO] Best validation loss at epoch: %d" % best_validation_epoch) + else: + best_train_epoch = np.argmin(np.array(train_history.history["loss"])) + 1 + logging.info("[INFO] Best train loss at epoch: %d" % best_train_epoch) + + +def main(): + parser = ArgumentParser(description="Train a Clair3 model") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--bin_fn', type=str, default="", required=True, + help="Binary tensor input generated by Tensor2Bin.py, support multiple bin readers using pytables") + + parser.add_argument('--chkpnt_fn', type=str, default=None, + help="Input a model to resume training or for fine-tuning") + + parser.add_argument('--ochk_prefix', type=str, default=None, required=True, + help="Prefix for model output after each epoch") + + # options for advanced users + parser.add_argument('--maxEpoch', type=int, default=None, + help="Maximum number of training epochs") + + parser.add_argument('--learning_rate', type=float, default=1e-3, + help="Set the initial learning rate, default: %(default)s") + + + parser.add_argument('--exclude_training_samples', type=str, default=None, + help="Define training samples to be excluded") + + parser.add_argument('--mini_epochs', type=int, default=1, + help="Number of mini-epochs per epoch") + + # Internal process control + ## In pileup training mode or not + parser.add_argument('--pileup', action='store_true', + help=SUPPRESS) + + ## Add indel length for training and calling, default true for full alignment + parser.add_argument('--add_indel_length', type=str2bool, default=False, + help=SUPPRESS) + + # mutually-incompatible validation options + vgrp = parser.add_mutually_exclusive_group() + vgrp.add_argument('--random_validation', action='store_true', + help="Use random sample of dataset for validation, default: %(default)s") + + vgrp.add_argument('--validation_fn', type=str, default=None, + help="Binary tensor input for use in validation: %(default)s") + + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + train_model(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/clair/post_processing/__init__.py b/benchmarks/nn-variant/Clair3/clair3/__init__.py similarity index 100% rename from benchmarks/nn-variant/clair/post_processing/__init__.py rename to benchmarks/nn-variant/Clair3/clair3/__init__.py diff --git a/benchmarks/nn-variant/Clair3/clair3/metrics/GetOverallMetrics.py b/benchmarks/nn-variant/Clair3/clair3/metrics/GetOverallMetrics.py new file mode 100644 index 0000000..0351429 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/clair3/metrics/GetOverallMetrics.py @@ -0,0 +1,214 @@ +import os +import sys +import shlex +import logging + +from sys import stderr +from subprocess import Popen +from argparse import ArgumentParser +from subprocess import PIPE + +logging.basicConfig(format='%(message)s', level=logging.INFO) + + +def subprocess_popen(args, stdin=None, stdout=PIPE, stderr=stderr, bufsize=8388608): + return Popen(args, stdin=stdin, stdout=stdout, stderr=stderr, bufsize=bufsize, universal_newlines=True) + + +def metrics(query_fp, query_tp, truth_fn, truth_tp): + # https://github.com/Illumina/hap.py/blob/master/doc/happy.md + precision = query_tp / (query_tp + query_fp) + recall = truth_tp / (truth_tp + truth_fn) + f1_score = 2 * precision * recall / (precision + recall) + return round(precision, 6), round(recall, 6), round(f1_score, 6) + + +def Cal(args): + happy_vcf_fn = args.happy_vcf_fn + contig_name = args.ctgName + output_fn = args.output_fn + + if output_fn: + output_file = open(output_fn, 'w') + else: + output_file = None + happy_vcf_unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (happy_vcf_fn))) + + truth_all_tp, query_all_tp, query_all_fp, truth_all_fn = 0, 0, 0, 0 + truth_snp_tp, query_snp_tp, query_snp_fp, truth_snp_fn = 0, 0, 0, 0 + truth_indel_tp, query_indel_tp, query_indel_fp, truth_indel_fn = 0, 0, 0, 0 + truth_ins_tp, query_ins_tp, query_ins_fp, truth_ins_fn = 0, 0, 0, 0 + truth_del_tp, query_del_tp, query_del_fp, truth_del_fn = 0, 0, 0, 0 + + for row in happy_vcf_unzip_process.stdout: + if row[0] == '#': + continue + columns = row.strip().split() + + ctg_name, pos = columns[0], int(columns[1]) + if contig_name is not None and ctg_name != contig_name: + continue + + FORMAT, TRUTH, QUERY = columns[8], columns[9], columns[10] + FORMAT = FORMAT.split(':') + TRUTH = TRUTH.split(':') + QUERY = QUERY.split(':') + + ft_dict = dict(zip(FORMAT, TRUTH)) + fq_dict = dict(zip(FORMAT, QUERY)) + + # hap.py vcf header + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##INFO= + ##FORMAT= + ##FORMAT= + + t_BD = ft_dict['BD'] if 'BD' in ft_dict else None + t_BI = ft_dict['BI'] if 'BI' in ft_dict else None + t_BVT = ft_dict['BVT'] if 'BVT' in ft_dict else None + q_BD = fq_dict['BD'] if 'BD' in fq_dict else None + q_BI = fq_dict['BI'] if 'BI' in fq_dict else None + q_BVT = fq_dict['BVT'] if 'BVT' in fq_dict else None + if not t_BD or not t_BI or not t_BVT or not q_BD or not q_BI or not q_BVT: + sys.exit("[ERROR] Happy format not match, exit!") + + query_fp = q_BD == 'FP' + query_tp = q_BD == 'TP' + truth_fn = t_BD == 'FN' + truth_tp = t_BD == 'TP' + + is_query_snp_fp = (q_BVT == 'SNP') and query_fp + is_query_snp_tp = (q_BVT == 'SNP') and query_tp + is_truth_snp_fn = (t_BVT == 'SNP') and truth_fn + is_truth_snp_tp = (t_BVT == 'SNP') and truth_tp + + is_query_indel_fp = (q_BVT == 'INDEL') and query_fp + is_query_indel_tp = (q_BVT == 'INDEL') and query_tp + is_truth_indel_fn = (t_BVT == 'INDEL') and truth_fn + is_truth_indel_tp = (t_BVT == 'INDEL') and truth_tp + + query_snp_fp = query_snp_fp + 1 if is_query_snp_fp else query_snp_fp + query_snp_tp = query_snp_tp + 1 if is_query_snp_tp else query_snp_tp + truth_snp_fn = truth_snp_fn + 1 if is_truth_snp_fn else truth_snp_fn + truth_snp_tp = truth_snp_tp + 1 if is_truth_snp_tp else truth_snp_tp + + query_indel_fp = query_indel_fp + 1 if is_query_indel_fp else query_indel_fp + query_indel_tp = query_indel_tp + 1 if is_query_indel_tp else query_indel_tp + truth_indel_fn = truth_indel_fn + 1 if is_truth_indel_fn else truth_indel_fn + truth_indel_tp = truth_indel_tp + 1 if is_truth_indel_tp else truth_indel_tp + + is_query_ins_fp = q_BI[0] == 'i' and is_query_indel_fp + is_query_ins_tp = q_BI[0] == 'i' and is_query_indel_tp + is_truth_ins_fn = t_BI[0] == 'i' and is_truth_indel_fn + is_truth_ins_tp = t_BI[0] == 'i' and is_truth_indel_tp + + is_query_del_fp = q_BI[0] == 'd' and is_query_indel_fp + is_query_del_tp = q_BI[0] == 'd' and is_query_indel_tp + is_truth_del_fn = t_BI[0] == 'd' and is_truth_indel_fn + is_truth_del_tp = t_BI[0] == 'd' and is_truth_indel_tp + + query_ins_fp = query_ins_fp + 1 if is_query_ins_fp else query_ins_fp + query_ins_tp = query_ins_tp + 1 if is_query_ins_tp else query_ins_tp + truth_ins_fn = truth_ins_fn + 1 if is_truth_ins_fn else truth_ins_fn + truth_ins_tp = truth_ins_tp + 1 if is_truth_ins_tp else truth_ins_tp + + query_del_fp = query_del_fp + 1 if is_query_del_fp else query_del_fp + query_del_tp = query_del_tp + 1 if is_query_del_tp else query_del_tp + truth_del_fn = truth_del_fn + 1 if is_truth_del_fn else truth_del_fn + truth_del_tp = truth_del_tp + 1 if is_truth_del_tp else truth_del_tp + + truth_all_tp = truth_snp_tp + truth_indel_tp + truth_all_fn = truth_snp_fn + truth_indel_fn + query_all_fp = query_snp_fp + query_indel_fp + query_all_tp = query_snp_tp + query_indel_tp + + # p->precision, r->recall, f1->f1_score + # a->overall, s->snp, id->indel, i->insertion, d->deletion + ap, ar, af1 = metrics(query_fp=query_all_fp, query_tp=query_all_tp, truth_fn=truth_all_fn, truth_tp=truth_all_tp) + sp, sr, sf1 = metrics(query_fp=query_snp_fp, query_tp=query_snp_tp, truth_fn=truth_snp_fn, truth_tp=truth_snp_tp) + idp, idr, idf1 = metrics(query_fp=query_indel_fp, query_tp=query_indel_tp, truth_fn=truth_indel_fn, truth_tp=truth_indel_tp) + ip, ir, if1 = metrics(query_fp=query_ins_fp, query_tp=query_ins_tp, truth_fn=truth_ins_fn, truth_tp=truth_ins_tp) + dp, dr, df1 = metrics(query_fp=query_del_fp, query_tp=query_del_tp, truth_fn=truth_del_fn, truth_tp=truth_del_tp) + + print (''.join([item.ljust(20) for item in ["VariantType", 'TRUTH.FP', 'TRUTH.FN', 'TRUTH.TP','QUERY.TP', 'METRIC.Precision', 'METRIC.Recall', 'METRIC.F1_Score']]), file=output_file) + print (''.join([str(item).ljust(20) for item in ["Overall", query_all_fp, truth_all_fn, truth_all_tp, query_all_tp, ap, ar, af1]]), file=output_file) + print (''.join([str(item).ljust(20) for item in ["SNP", query_snp_fp, truth_snp_fn, truth_snp_tp, query_snp_tp, sp, sr, sf1]]),file=output_file) + print (''.join([str(item).ljust(20) for item in ["INDEL", query_indel_fp, truth_indel_fn, truth_indel_tp, query_indel_tp, idp, idr, idf1]]), file=output_file) + print (''.join([str(item).ljust(20) for item in ["INS", query_ins_fp, truth_ins_fn, truth_ins_tp, query_ins_tp, ip, ir, if1]]), file=output_file) + print (''.join([str(item).ljust(20) for item in ["DEL", query_del_fp, truth_del_fn, truth_del_tp, query_del_tp, dp, dr, df1]]), file=output_file) + print('\n', file=output_file) + + # print log_happy output + pass_row = [] + snp_row = [] + indel_row = [] + if args.log_happy and os.path.exists(args.log_happy): + log_happy = open(args.log_happy) + for row in log_happy.readlines(): + if 'PASS' not in row: + continue + pass_row.append(row) + + for row in pass_row: + + if 'INDEL' in row: + row = row.split() + tp, fn, fp = row[3], row[4], row[6] + precision, recall, f1 = row[11], row[10], row[13] + indel_row = [fp, fn, tp, precision, recall, f1] + if 'SNP' in row: + row = row.split() + tp, fn, fp = row[3], row[4], row[6] + precision, recall, f1 = row[11], row[10], row[13] + snp_row = [fp, fn, tp, precision, recall, f1] + print('Double check with happy log:', file=output_file) + print(' '.join(['%.6f' % item for item in [sp, sr, sf1, idp, idr, idf1]] + [str(item) for item in + [query_snp_fp, truth_snp_fn, + truth_snp_tp, query_indel_fp, + truth_indel_fn, truth_indel_tp]]), + file=output_file) + print(' '.join(snp_row[3:]) + ' ' + ' '.join(indel_row[3:]) + ' ' + ' '.join(snp_row[:3]) + ' ' + ' '.join( + indel_row[:3]), file=output_file) + print('\n', file=output_file) + + print(' '.join([str(item) for item in [ap, ar, af1, sp, sr, sf1, idp, idr, idf1, ip, ir, if1, dp, dr, df1]]), + file=output_file) + print(' '.join([str(item) for item in + [query_all_fp, truth_all_fn, truth_all_tp, query_snp_fp, truth_snp_fn, truth_snp_tp, query_indel_fp, + truth_indel_fn, truth_indel_tp, query_ins_tp, truth_ins_tp, truth_ins_tp, query_del_fp, + truth_del_fn, truth_del_tp]]), file=output_file) + + if output_fn: + output_file.close() + + +def main(): + parser = ArgumentParser(description="Overall Metrics of hap.py output") + + parser.add_argument('--happy_vcf_fn', type=str, default=None, + help="Path to the happy vcf output file") + + parser.add_argument('--log_happy', type=str, default=None, + help="Path to the happy vcf output file") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed") + + parser.add_argument('--output_fn', type=str, default=None, + help="Filename of the metrics output") + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + Cal(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/clair3/model.py b/benchmarks/nn-variant/Clair3/clair3/model.py new file mode 100644 index 0000000..0419304 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/clair3/model.py @@ -0,0 +1,431 @@ +import warnings +with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + warnings.filterwarnings("ignore", category=FutureWarning) + from tensorflow.python.util import deprecation + deprecation._PRINT_DEPRECATION_WARNINGS = False + import tensorflow as tf +import logging +import numpy as np +logging.basicConfig(format='%(message)s', level=logging.INFO) +tf.get_logger().setLevel(logging.ERROR) + +from clair3.task.main import GT21, GENOTYPE, VARIANT_LENGTH_1, VARIANT_LENGTH_2 +import shared.param_f as param +params = dict( + float_type=tf.float32, + task_loss_weights=[ + 1, # gt21 + 1, # genotype + 1, # variant/indel length 0 + 1, # variant/indel length 1 + 1 # l2 loss + ], + output_shape=GT21.output_label_count + \ + GENOTYPE.output_label_count + \ + VARIANT_LENGTH_1.output_label_count + \ + VARIANT_LENGTH_2.output_label_count, + output_gt21_shape=GT21.output_label_count, + output_genotype_shape=GENOTYPE.output_label_count, + output_indel_length_shape_1=VARIANT_LENGTH_1.output_label_count, + output_indel_length_shape_2=VARIANT_LENGTH_2.output_label_count, + output_gt21_entropy_weights=[1] * GT21.output_label_count, + output_genotype_entropy_weights=[1] * GENOTYPE.output_label_count, + output_indel_length_entropy_weights_1=[1] * VARIANT_LENGTH_1.output_label_count, + output_indel_length_entropy_weights_2=[1] * VARIANT_LENGTH_2.output_label_count, + L3_dropout_rate=0.2, + L4_num_units=256, + L4_pileup_num_units=128, + L4_dropout_rate=0.5, + L5_1_num_units=128, + L5_1_dropout_rate=0.2, + L5_2_num_units=128, + L5_2_dropout_rate=0.2, + L5_3_num_units=128, + L5_3_dropout_rate=0.2, + L5_4_num_units=128, + L5_4_dropout_rate=0.2, + LSTM1_num_units=128, + LSTM2_num_units=160, + LSTM1_dropout_rate=0, + LSTM2_dropout_rate=0.5, + l2_regularization_lambda=param.l2RegularizationLambda, + ) + +add_l2_regulation = True +L2_regularizers = tf.keras.regularizers.l2(params['l2_regularization_lambda']) if add_l2_regulation else None + +class Clair3_P(tf.keras.Model): + # Bi-lstm model for clair3 pileup input + def __init__(self, add_indel_length=False, predict=False): + super(Clair3_P, self).__init__() + + # output + self.output_gt21_shape = params['output_gt21_shape'] + self.output_genotype_shape = params['output_genotype_shape'] + self.output_indel_length_shape_1 = params['output_indel_length_shape_1'] + self.output_indel_length_shape_2 = params['output_indel_length_shape_2'] + + self.L3_dropout_rate = params['L3_dropout_rate'] + self.L4_num_units = params['L4_num_units'] + self.L4_pileup_num_units = params['L4_pileup_num_units'] + self.L4_dropout_rate = params['L4_dropout_rate'] + self.L5_1_num_units = params['L5_1_num_units'] + self.L5_1_dropout_rate = params['L5_1_dropout_rate'] + self.L5_2_num_units = params['L5_2_num_units'] + self.L5_2_dropout_rate = params['L5_2_dropout_rate'] + self.L5_3_num_units = params['L5_3_num_units'] + self.L5_3_dropout_rate = params['L5_3_dropout_rate'] + self.L5_4_num_units = params['L5_4_num_units'] + self.L5_4_dropout_rate = params['L5_4_dropout_rate'] + self.LSTM1_num_units = params['LSTM1_num_units'] + self.LSTM2_num_units = params['LSTM2_num_units'] + self.LSTM1_dropout_rate = params['LSTM1_dropout_rate'] + self.LSTM2_dropout_rate = params['LSTM2_dropout_rate'] + + self.output_label_split = [ + self.output_gt21_shape, + self.output_genotype_shape, + self.output_indel_length_shape_1, + self.output_indel_length_shape_2 + ] + + self.add_indel_length = add_indel_length + self.predict = predict + + self.LSTM1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM( + units=self.LSTM1_num_units, + return_sequences=True, + kernel_regularizer=L2_regularizers + )) + + self.LSTM2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM( + units=self.LSTM2_num_units, + return_sequences=True, + kernel_regularizer=L2_regularizers + )) + + self.L3_dropout = tf.keras.layers.Dropout(rate=self.L3_dropout_rate) + + self.L3_dropout_flatten = tf.keras.layers.Flatten() + + self.L4 = tf.keras.layers.Dense(units=self.L4_pileup_num_units, activation='selu',kernel_regularizer=L2_regularizers) + + self.L4_dropout = tf.keras.layers.Dropout(rate=self.LSTM2_dropout_rate, seed=param.OPERATION_SEED) + + self.L5_1 = tf.keras.layers.Dense(units=self.L5_1_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_1_dropout = tf.keras.layers.Dropout(rate=self.L5_1_dropout_rate, seed=param.OPERATION_SEED) + + self.L5_2 = tf.keras.layers.Dense(units=self.L5_2_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_2_dropout = tf.keras.layers.Dropout(rate=self.L5_2_dropout_rate, seed=param.OPERATION_SEED) + + self.Y_gt21_logits = tf.keras.layers.Dense(units=self.output_gt21_shape, activation='selu', kernel_regularizer=L2_regularizers) + + self.Y_genotype_logits = tf.keras.layers.Dense(units=self.output_genotype_shape, activation='selu', kernel_regularizer=L2_regularizers) + + if self.add_indel_length: + + self.L5_3 = tf.keras.layers.Dense(units=self.L5_3_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_3_dropout = tf.keras.layers.Dropout(rate=self.L5_3_dropout_rate, seed=param.OPERATION_SEED) + + self.L5_4 = tf.keras.layers.Dense(units=self.L5_4_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_4_dropout = tf.keras.layers.Dropout(rate=self.L5_4_dropout_rate, seed=param.OPERATION_SEED) + + self.Y_indel_length_logits_1 = tf.keras.layers.Dense(units=self.output_indel_length_shape_1, activation='selu', kernel_regularizer=L2_regularizers) + + self.Y_indel_length_logits_2 = tf.keras.layers.Dense(units=self.output_indel_length_shape_2, activation='selu', kernel_regularizer=L2_regularizers) + + self.softmax = tf.keras.layers.Softmax() + + + def call(self, x,): + + x = tf.cast(x, tf.float32) + + x = self.LSTM1(x) # (batch_size, inp_seq_len, d_model) + + x = self.LSTM2(x) + + x = self.L3_dropout(x) + + x = self.L3_dropout_flatten(x) + + x = self.L4(x) + + x = self.L4_dropout(x) + + l5_1_dropout = self.L5_1_dropout(self.L5_1(x)) + + l5_2_dropout = self.L5_2_dropout(self.L5_2(x)) + + y_gt21_logits = self.softmax(self.Y_gt21_logits(l5_1_dropout)) + + y_genotype_logits = self.softmax(self.Y_genotype_logits(l5_2_dropout)) + + if self.add_indel_length: + l5_3_dropout = self.L5_3_dropout(self.L5_3(x)) + + l5_4_dropout = self.L5_4_dropout(self.L5_4(x)) + + y_indel_length_logits_1 = self.softmax(self.Y_indel_length_logits_1(l5_3_dropout)) + + y_indel_length_logits_2 = self.softmax(self.Y_indel_length_logits_2(l5_4_dropout)) + + if self.predict: + return tf.concat([y_gt21_logits, y_genotype_logits, y_indel_length_logits_1, y_indel_length_logits_2], axis=1) + + return [y_gt21_logits, y_genotype_logits, y_indel_length_logits_1, y_indel_length_logits_2] + + if self.predict: + return tf.concat([y_gt21_logits, y_genotype_logits],axis=1) + + return [y_gt21_logits, y_genotype_logits] + + +class BasicConv2D(tf.keras.layers.Layer): + def __init__(self, filters, kernel_size, strides, padding, SeparableConv=False): + super(BasicConv2D, self).__init__() + conv = tf.keras.layers.SeparableConv2D if SeparableConv else tf.keras.layers.Conv2D + self.conv = conv(filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + kernel_regularizer=L2_regularizers) + self.bn = tf.keras.layers.BatchNormalization() + self.relu = tf.keras.layers.ReLU() + + def call(self, inputs): + output = self.conv(inputs) + output = self.bn(output) + output = self.relu(output) + + return output + +class BasicBlock(tf.keras.layers.Layer): + + def __init__(self, filter_num, stride=1,SeparableConv=False): + super(BasicBlock, self).__init__() + conv = tf.keras.layers.SeparableConv2D if SeparableConv else tf.keras.layers.Conv2D + + self.conv1 = conv(filters=filter_num, + kernel_size=(3, 3), + strides=stride, + padding="same", + kernel_regularizer=L2_regularizers) + self.bn1 = tf.keras.layers.BatchNormalization() + self.conv2 = conv(filters=filter_num, + kernel_size=(3, 3), + strides=1, + padding="same", + kernel_regularizer=L2_regularizers) + self.bn2 = tf.keras.layers.BatchNormalization() + if stride != 1: + self.downsample = tf.keras.Sequential() + self.downsample.add(tf.keras.layers.Conv2D(filters=filter_num, + kernel_size=(1, 1), + strides=stride, + kernel_regularizer=L2_regularizers)) + self.downsample.add(tf.keras.layers.BatchNormalization()) + else: + self.downsample = lambda x: x + + def call(self, inputs): + residual = self.downsample(inputs) + + x = self.conv1(inputs) + x = self.bn1(x, ) + x = tf.nn.relu(x) + x = self.conv2(x) + x = self.bn2(x, ) + + output = tf.nn.relu(tf.keras.layers.add([residual, x])) + + return output + +def make_basic_block_layer(filter_num, blocks, stride=1, SeparableConv=False): + + res_block = tf.keras.Sequential() + + res_block.add(BasicBlock(filter_num, stride=stride, SeparableConv=SeparableConv)) + + for _ in range(1, blocks): + res_block.add(BasicBlock(filter_num, stride=1,SeparableConv=SeparableConv)) + + return res_block + +class PyramidPolling(tf.keras.layers.Layer): + def __init__(self, spatial_pool_size=(3, 2, 1)): + super(PyramidPolling, self).__init__() + + self.spatial_pool_size = spatial_pool_size + self.pool_len = len(self.spatial_pool_size) + self.window_h = np.empty(self.pool_len, dtype=int) + self.stride_h = np.empty(self.pool_len, dtype=int) + self.window_w = np.empty(self.pool_len, dtype=int) + self.stride_w = np.empty(self.pool_len, dtype=int) + + self.flatten = tf.keras.layers.Flatten() + + def build(self, input_shape): + height = int(input_shape[1]) + width = int(input_shape[2]) + + for i in range(self.pool_len): + self.window_h[i] = self.stride_h[i] = int(np.ceil(height / self.spatial_pool_size[i])) + self.window_w[i] = self.stride_w[i] = int(np.ceil(width / self.spatial_pool_size[i])) + + def call(self, x): + for i in range(self.pool_len): + max_pool = tf.nn.max_pool(x, + ksize=[1, self.window_h[i], self.window_w[i], 1], + strides=[1, self.stride_h[i], self.stride_w[i], 1], + padding='SAME') + if i == 0: + pp = self.flatten(max_pool) + + else: + pp = tf.concat([pp, self.flatten(max_pool)], axis=-1) + + return pp + +class Clair3_F(tf.keras.Model): + # Residual CNN model for clair3 full alignment input + def __init__(self, add_indel_length=False, predict=False): + super(Clair3_F, self).__init__() + self.output_gt21_shape = params['output_gt21_shape'] + self.output_genotype_shape = params['output_genotype_shape'] + self.output_indel_length_shape_1 = params['output_indel_length_shape_1'] + self.output_indel_length_shape_2 = params['output_indel_length_shape_2'] + + self.L3_dropout_rate = params['L3_dropout_rate'] + self.L4_num_units = params['L4_num_units'] + self.L4_dropout_rate = params['L4_dropout_rate'] + self.L5_1_num_units = params['L5_1_num_units'] + self.L5_1_dropout_rate = params['L5_1_dropout_rate'] + self.L5_2_num_units = params['L5_2_num_units'] + self.L5_2_dropout_rate = params['L5_2_dropout_rate'] + self.L5_3_num_units = params['L5_3_num_units'] + self.L5_3_dropout_rate = params['L5_3_dropout_rate'] + self.L5_4_num_units = params['L5_4_num_units'] + self.L5_4_dropout_rate = params['L5_4_dropout_rate'] + + self.output_label_split = [ + self.output_gt21_shape, + self.output_genotype_shape, + self.output_indel_length_shape_1, + self.output_indel_length_shape_2 + ] + + self.add_indel_length = add_indel_length + self.predict = predict + + self.conv1 = BasicConv2D(filters=64, + kernel_size=(3, 3), + strides=2, + padding="same",) + + self.res_block1 = make_basic_block_layer(filter_num=64, + blocks=1, stride=1, SeparableConv=False) + + self.conv3 = BasicConv2D(filters=128, + kernel_size=(3, 3), + strides=2, + padding="same") + + self.res_block2 = make_basic_block_layer(filter_num=128, + blocks=1, stride=1, SeparableConv=False) + + self.conv5 = BasicConv2D(filters=256, + kernel_size=(3, 3), + strides=2, + padding="same") + + self.res_block3 = make_basic_block_layer(filter_num=256, + blocks=1, stride=1) + + self.pyramidpolling = PyramidPolling() + + self.L3_dropout = tf.keras.layers.Dropout(rate=self.L3_dropout_rate) + + self.flatten = tf.keras.layers.Flatten() + + self.L4 = tf.keras.layers.Dense(units=self.L4_num_units, activation='selu',kernel_regularizer=L2_regularizers) + + self.L4_dropout = tf.keras.layers.Dropout(rate=self.L4_dropout_rate, seed=param.OPERATION_SEED) + + self.L5_1 = tf.keras.layers.Dense(units=self.L5_1_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_1_dropout = tf.keras.layers.Dropout(rate=self.L5_1_dropout_rate, seed=param.OPERATION_SEED) + + self.L5_2 = tf.keras.layers.Dense(units=self.L5_1_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_2_dropout = tf.keras.layers.Dropout(rate=self.L5_2_dropout_rate, seed=param.OPERATION_SEED) + + self.Y_gt21_logits = tf.keras.layers.Dense(units=self.output_gt21_shape, activation='selu', kernel_regularizer=L2_regularizers) + + self.Y_genotype_logits = tf.keras.layers.Dense(units=self.output_genotype_shape, activation='selu', kernel_regularizer=L2_regularizers) + + if self.add_indel_length: + self.L5_3 = tf.keras.layers.Dense(units=self.L5_3_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_3_dropout = tf.keras.layers.Dropout(rate=self.L5_3_dropout_rate, seed=param.OPERATION_SEED) + + self.L5_4 = tf.keras.layers.Dense(units=self.L5_4_num_units, activation='selu', kernel_regularizer=L2_regularizers) + + self.L5_4_dropout = tf.keras.layers.Dropout(rate=self.L5_4_dropout_rate, seed=param.OPERATION_SEED) + + self.Y_indel_length_logits_1 = tf.keras.layers.Dense(units=self.output_indel_length_shape_1, activation='selu',kernel_regularizer=L2_regularizers) + + self.Y_indel_length_logits_2 = tf.keras.layers.Dense(units=self.output_indel_length_shape_2, activation='selu',kernel_regularizer=L2_regularizers) + + self.softmax = tf.keras.layers.Softmax() + + def call(self, inputs): + + x = tf.cast(inputs, tf.float32) / param.NORMALIZE_NUM + + x = self.conv1(x) + x = self.res_block1(x) + x = self.conv3(x) + x = self.res_block2(x) + x = self.conv5(x) + x = self.res_block3(x) + x = self.pyramidpolling(x) + x = self.flatten(self.L3_dropout(x)) + + x = self.L4(x) + x = self.L4_dropout(x) + + l5_1_dropout = self.L5_1_dropout(self.L5_1(x)) + + l5_2_dropout = self.L5_2_dropout(self.L5_2(x)) + + y_gt21_logits = self.softmax(self.Y_gt21_logits(l5_1_dropout)) + + y_genotype_logits = self.softmax(self.Y_genotype_logits(l5_2_dropout)) + + if self.add_indel_length: + + l5_3_dropout = self.L5_3_dropout(self.L5_3(x)) + + l5_4_dropout = self.L5_4_dropout(self.L5_4(x)) + + y_indel_length_logits_1 = self.softmax(self.Y_indel_length_logits_1(l5_3_dropout)) + + y_indel_length_logits_2 = self.softmax(self.Y_indel_length_logits_2(l5_4_dropout)) + + if self.predict: + + return tf.concat([y_gt21_logits, y_genotype_logits, y_indel_length_logits_1, y_indel_length_logits_2], axis=1) + + return [y_gt21_logits, y_genotype_logits, y_indel_length_logits_1, y_indel_length_logits_2] + + if self.predict: + + return tf.concat([y_gt21_logits, y_genotype_logits],axis=1) + + return [y_gt21_logits, y_genotype_logits] \ No newline at end of file diff --git a/benchmarks/nn-variant/clair/task/__init__.py b/benchmarks/nn-variant/Clair3/clair3/task/__init__.py similarity index 100% rename from benchmarks/nn-variant/clair/task/__init__.py rename to benchmarks/nn-variant/Clair3/clair3/task/__init__.py diff --git a/benchmarks/nn-variant/clair/task/genotype.py b/benchmarks/nn-variant/Clair3/clair3/task/genotype.py similarity index 100% rename from benchmarks/nn-variant/clair/task/genotype.py rename to benchmarks/nn-variant/Clair3/clair3/task/genotype.py diff --git a/benchmarks/nn-variant/clair/task/gt21.py b/benchmarks/nn-variant/Clair3/clair3/task/gt21.py similarity index 100% rename from benchmarks/nn-variant/clair/task/gt21.py rename to benchmarks/nn-variant/Clair3/clair3/task/gt21.py diff --git a/benchmarks/nn-variant/clair/task/main.py b/benchmarks/nn-variant/Clair3/clair3/task/main.py similarity index 54% rename from benchmarks/nn-variant/clair/task/main.py rename to benchmarks/nn-variant/Clair3/clair3/task/main.py index 0e995b2..0575c2b 100644 --- a/benchmarks/nn-variant/clair/task/main.py +++ b/benchmarks/nn-variant/Clair3/clair3/task/main.py @@ -1,8 +1,8 @@ from collections import namedtuple -from clair.task.genotype import Genotype, genotype_enum_from, genotype_enum_for_task -from clair.task.gt21 import gt21_enum_from_label, gt21_enum_from -from clair.task.variant_length import VariantLength +from clair3.task.genotype import Genotype, genotype_enum_from, genotype_enum_for_task +from clair3.task.gt21 import * +from clair3.task.variant_length import VariantLength OutputLabelNamedTuple = namedtuple( 'BasePredictNamedTuple', ['output_label_count', 'y_start_index', 'y_end_index'] @@ -48,7 +48,7 @@ def output_labels_from_reference(reference_base): return gt21_vec + genotype_vec + variant_length_vec_1 + variant_length_vec_2 -def output_labels_from_vcf_columns(columns): +def output_labels_from_vcf_columns(columns, homo_calling=False, haplotype=None): reference, alternate = columns[2], columns[3] genotype_1, genotype_2 = int(columns[4]), int(columns[5]) @@ -79,3 +79,54 @@ def output_labels_from_vcf_columns(columns): variant_length_vec_2[variant_lengths[1] + VariantLength.index_offset] = 1 return gt21_vec + genotype_vec + variant_length_vec_1 + variant_length_vec_2 + +def output_labels_from_reference_new(reference_base, base_idx): + gt21_vec = [0] * GT21.output_label_count + gt21_vec[gt21_enum_from_label(reference_base + reference_base)] = 1 + + genotype_vec = [0] * GENOTYPE.output_label_count + [0] + if base_idx == '2': + genotype_vec[Genotype.homo_reference] = 1 + elif base_idx == '1': + genotype_vec[3] = 1 + variant_length_vec_1 = [0] * VARIANT_LENGTH_1.output_label_count + variant_length_vec_2 = [0] * VARIANT_LENGTH_2.output_label_count + variant_length_vec_1[0 + VariantLength.index_offset] = 1 + variant_length_vec_2[0 + VariantLength.index_offset] = 1 + + return gt21_vec + genotype_vec + variant_length_vec_1 + variant_length_vec_2 + + +def output_labels_from_vcf_columns_new(columns): + reference, alternate = columns[2], columns[3] + genotype_1, genotype_2 = int(columns[4]), int(columns[5]) + + alternate_arr = alternate.split(',') + if len(alternate_arr) == 1: + alternate_arr = ( + [reference if genotype_1 == 0 or genotype_2 == 0 else alternate_arr[0]] + + alternate_arr + ) + + gt21 = gt21_enum_from(reference, alternate, genotype_1, genotype_2, alternate_arr) + gt21_vec = [0] * GT21.output_label_count + gt21_vec[gt21] = 1 + + genotype = genotype_enum_from(genotype_1, genotype_2) + genotype_for_task = genotype_enum_for_task(genotype) + genotype_vec = [0] * GENOTYPE.output_label_count + genotype_vec[genotype_for_task] = 1 + + genotype_vec += [0] + + variant_lengths = [ + min_max(len(alt) - len(reference), VariantLength.min, VariantLength.max) + for alt in alternate_arr + ] + variant_lengths.sort() + variant_length_vec_1 = [0] * VARIANT_LENGTH_1.output_label_count + variant_length_vec_2 = [0] * VARIANT_LENGTH_2.output_label_count + variant_length_vec_1[variant_lengths[0] + VariantLength.index_offset] = 1 + variant_length_vec_2[variant_lengths[1] + VariantLength.index_offset] = 1 + + return gt21_vec + genotype_vec + variant_length_vec_1 + variant_length_vec_2 diff --git a/benchmarks/nn-variant/clair/task/variant_length.py b/benchmarks/nn-variant/Clair3/clair3/task/variant_length.py similarity index 100% rename from benchmarks/nn-variant/clair/task/variant_length.py rename to benchmarks/nn-variant/Clair3/clair3/task/variant_length.py diff --git a/benchmarks/nn-variant/Clair3/clair3/utils.py b/benchmarks/nn-variant/Clair3/clair3/utils.py new file mode 100644 index 0000000..acfbc7e --- /dev/null +++ b/benchmarks/nn-variant/Clair3/clair3/utils.py @@ -0,0 +1,493 @@ +import sys +import gc +import copy +import shlex +import os +import tables +import numpy as np +from functools import partial + +from clair3.task.main import * +from shared.interval_tree import bed_tree_from, is_region_in +from shared.utils import subprocess_popen, IUPAC_base_to_ACGT_base_dict as BASE2BASE, IUPAC_base_to_num_dict as BASE2NUM + +FILTERS = tables.Filters(complib='blosc:lz4hc', complevel=5) +shuffle_bin_size = 50000 +PREFIX_CHAR_STR = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + +def setup_environment(): + gc.enable() + + +def batches_from(iterable, item_from, batch_size=1): + iterable = iter(iterable) + while True: + chunk = [] + for _ in range(batch_size): + try: + chunk.append(item_from(next(iterable))) + except StopIteration: + yield chunk + return + yield chunk + + +def tensor_generator_from(tensor_file_path, batch_size, pileup, platform): + global param + float_type = 'int32' + if pileup: + import shared.param_p as param + else: + import shared.param_f as param + float_type = 'int8' + + if tensor_file_path != "PIPE": + f = subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, tensor_file_path))) + fo = f.stdout + else: + fo = sys.stdin + + processed_tensors = 0 + tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape + prod_tensor_shape = np.prod(tensor_shape) + + def item_from(row): + chrom, coord, seq, tensor, alt_info = row.split("\t") + if pileup: + tensor = np.array(tensor.split(), dtype=np.dtype(float_type)) + depth = int(alt_info.split('-', maxsplit=1)[0]) + max_depth = param.max_depth_dict[platform] + # for extreme high coverage data, make sure we could have a truncated coverage + if depth > 0 and depth > max_depth * 1.5: + scale_factor = depth / max_depth + tensor = tensor / scale_factor + else: + # need add padding if depth is lower than maximum depth. + tensor = [int(item) for item in tensor.split()] + tensor_depth = len(tensor) // tensor_shape[1] // tensor_shape[2] + padding_depth = tensor_shape[0] - tensor_depth + prefix_padding_depth = int(padding_depth / 2) + suffix_padding_depth = padding_depth - int(padding_depth / 2) + prefix_zero_padding = [0] * prefix_padding_depth * tensor_shape[1] * tensor_shape[2] + suffix_zero_padding = [0] * suffix_padding_depth * tensor_shape[1] * tensor_shape[2] + tensor = prefix_zero_padding + tensor + suffix_zero_padding + tensor = np.array(tensor, dtype=np.dtype(float_type)) + + pos = chrom + ":" + coord + ":" + seq + return tensor, pos, seq, alt_info + + for batch in batches_from(fo, item_from=item_from, batch_size=batch_size): + tensors = np.empty(([batch_size, prod_tensor_shape]), dtype=np.dtype(float_type)) + positions = [] + alt_info_list = [] + for tensor, pos, seq, alt_info in batch: + if seq[param.flankingBaseNum] not in BASE2NUM: + continue + tensors[len(positions)] = tensor + positions.append(pos) + alt_info_list.append(alt_info) + + current_batch_size = len(positions) + X = np.reshape(tensors, ([batch_size] + tensor_shape)) + + if processed_tensors > 0 and processed_tensors % 20000 == 0: + print("Processed %d tensors" % processed_tensors, file=sys.stderr) + + processed_tensors += current_batch_size + + if current_batch_size <= 0: + continue + yield X[:current_batch_size], positions[:current_batch_size], alt_info_list[:current_batch_size] + + if tensor_file_path != "PIPE": + fo.close() + f.wait() + + +def remove_common_suffix(ref_base, alt_base): + min_length = min(len(ref_base) - 1, min([len(item) - 1 for item in alt_base])) # keep at least one base + prefix = ref_base[::-1] + for string in alt_base: + string = string[::-1] + while string[:len(prefix)] != prefix and prefix: + prefix = prefix[:len(prefix) - 1] + if not prefix: + break + res_length = len(prefix) + if res_length > min_length: + return ref_base, alt_base + return ref_base[:len(ref_base) - res_length], [item[:len(item) - res_length] for item in alt_base] + + return ref_base[-min_length], [item[-min_length] for item in alt_base] + + +def decode_alt(ref_base, alt_base): + if ',' not in alt_base: + return [ref_base], [alt_base] + alt_base = alt_base.split(',') + ref_base_list, alt_base_list = [], [] + for ab in alt_base: + rb,ab = remove_common_suffix(ref_base, [ab]) + ref_base_list.append(rb) + alt_base_list.append(ab[0]) + return ref_base_list, alt_base_list + + +def variant_map_from(var_fn, tree, is_tree_empty): + Y = {} + truth_alt_dict = {} + miss_variant_set = set() + if var_fn is None: + return Y, miss_variant_set, truth_alt_dict + + f = subprocess_popen(shlex.split("gzip -fdc %s" % (var_fn))) + for row in f.stdout: + if row[0] == "#": + continue + columns = row.strip().split() + ctg_name, position_str, ref_base, alt_base, genotype1, genotype2 = columns + key = ctg_name + ":" + position_str + if genotype1 == '-1' or genotype2 == '-1': + miss_variant_set.add(key) + continue + if not (is_tree_empty or is_region_in(tree, ctg_name, int(position_str))): + continue + + Y[key] = output_labels_from_vcf_columns(columns) + ref_base_list, alt_base_list = decode_alt(ref_base, alt_base) + truth_alt_dict[int(position_str)] = (ref_base_list, alt_base_list) + f.stdout.close() + f.wait() + return Y, miss_variant_set, truth_alt_dict + +def find_read_support(pos, truth_alt_dict, alt_info): + alt_info = alt_info.rstrip().split('-') + seqs = alt_info[1].split(' ') if len(alt_info) > 1 else '' + seq_alt_bases_dict = dict(zip(seqs[::2], [int(item) for item in seqs[1::2]])) if len(seqs) else {} + + pos = int(pos) + if pos not in truth_alt_dict: + # candidate position not in the truth vcf or unified truth vcf + return None + ref_base_list, alt_base_list = truth_alt_dict[pos] + found = 0 + for alt_type in seq_alt_bases_dict: + if '*' in alt_type or '#' in alt_type or 'R' in alt_type: + continue + if alt_type[0] == 'X': + if alt_type[1] in alt_base_list: + found += 1 + elif alt_type[0] == 'I': + if alt_type[1:] in alt_base_list: + found += 1 + elif alt_type[0] == 'D': + del_cigar = alt_type[1:] + for rb, ab in zip(ref_base_list, alt_base_list): + if rb[1:] == del_cigar and len(ab) == 1: + found += 1 + if found >= len(alt_base_list): + return True + # return False if we find any alternative bases missed in subsampled bam, then remove the position from training + return False + +def write_table_dict(table_dict, string, label, pos, total, alt_info, tensor_shape, pileup): + """ + Write pileup or full alignment tensor into a dictionary.compressed bin file. + table_dict: dictionary include all training information (tensor position, label, altnative bases). + string: input tensor string, need add padding to meet the depth requirement. + label: include gt21 genotype, indel length 1, indel length 2. + alt_info: altnative information for querying variant. + """ + + if len(string) == 1: + string = string[0] + position_matrix = string + position_matrix = position_matrix.split() + + if pileup: + table_dict['position_matrix'].append(position_matrix) + else: + tensor_depth = len(position_matrix) // tensor_shape[1] // tensor_shape[2] + padding_depth = tensor_shape[0] - tensor_depth + prefix_padding_depth = int(padding_depth / 2) + suffix_padding_depth = padding_depth - int(padding_depth / 2) + prefix_zero_padding = ['0'] * prefix_padding_depth * tensor_shape[1] * tensor_shape[2] + suffix_zero_padding = ['0'] * suffix_padding_depth * tensor_shape[1] * tensor_shape[2] + table_dict['position_matrix'].append(prefix_zero_padding + position_matrix + suffix_zero_padding) + + table_dict['position'].append(pos) + table_dict['label'].append(label) + table_dict['alt_info'].append(alt_info) + + return total + 1 + + +def update_table_dict(): + table_dict = {} + table_dict['position_matrix'] = [] + table_dict['alt_info'] = [] + table_dict['position'] = [] + table_dict['label'] = [] + return table_dict + + +def write_table_file(table_file, table_dict, tensor_shape, label_size, float_type): + """ + Write pileup or full alignment tensor into compressed bin file. + table_dict: dictionary include all training information (tensor position, label, altnative bases). + string: input tensor string, need add padding to meet the depth requirement. + tree: dictionary(contig name : intervaltree) for quick region querying. + miss_variant_set: sometimes there will have true variant missing after downsampling reads. + is_allow_duplicate_chr_pos: whether allow duplicate positions when training, if there exists downsampled data, lower depth will add a random prefix character. + non_variant_subsample_ratio: define a maximum non variant ratio for training, we always expect use more non variant data, while it would greatly increase training + time, especially in ont data, here we usually use 1:1 or 1:2 for variant candidate: non variant candidate. + """ + + position_matrix = np.array(table_dict['position_matrix'], np.dtype(float_type)).reshape([-1] + tensor_shape) + table_file.root.position_matrix.append(position_matrix) + + table_file.root.alt_info.append(np.array(table_dict['alt_info']).reshape(-1, 1)) + table_file.root.position.append(np.array(table_dict['position']).reshape(-1, 1)) + table_file.root.label.append(np.array(table_dict['label'], np.dtype(float_type)).reshape(-1, label_size)) + table_dict = update_table_dict() + + return table_dict + + +def print_bin_size(path, prefix=None): + import tables + import os + total = 0 + for file_name in os.listdir(path): + if prefix and not file_name.startswith(prefix): + continue + table = tables.open_file(os.path.join(path, file_name), 'r') + print("[INFO] {} size is: {}".format(file_name, len(table.root.label))) + total += len(table.root.label) + print('[INFO] total: {}'.format(total)) + + +def bin_reader_generator_from(tensor_fn, Y_true_var, Y, is_tree_empty, tree, miss_variant_set, truth_alt_dict, is_allow_duplicate_chr_pos=False, maximum_non_variant_ratio=None): + + """ + Bin reader generator for bin file generation. + tensor_fn: tensor file. + Y_true_var: dictionary (contig name: label information) containing all true variant information (should not be changed). + Y: dictionary (contig name: label information) to store all variant and non variant information. + tree: dictionary(contig name : intervaltree) for quick region querying. + miss_variant_set: sometimes there will have true variant missing after downsampling reads. + truth_alt_dict: unified truth reference base and alternative bases to find read support. + is_allow_duplicate_chr_pos: whether allow duplicate positions when training, if there exists downsampled data, lower depth will add a random prefix character. + maximum_non_variant_ratio: define a maximum non variant ratio for training, we always expect use more non variant data, while it would greatly increase training + time, especially in ont data, here we usually use 1:1 or 1:2 for variant candidate: non variant candidate. + """ + + X = {} + ref_list = [] + total = 0 + variant_set_with_read_support = set() + variants_without_read_support = 0 + for row_idx, row in enumerate(tensor_fn): + chrom, coord, seq, string, alt_info = row.split("\t") + alt_info = alt_info.rstrip() + if not (is_tree_empty or is_region_in(tree, chrom, int(coord))): + continue + seq = seq.upper() + if seq[param.flankingBaseNum] not in 'ACGT': + continue + key = chrom + ":" + coord + is_reference = key not in Y_true_var + + if key in miss_variant_set: + continue + + have_read_support = find_read_support(pos=coord, truth_alt_dict=truth_alt_dict, alt_info=alt_info) + if have_read_support is not None and not have_read_support: + miss_variant_set.add(key) + variants_without_read_support += 1 + continue + + variant_set_with_read_support.add(key) + if key not in X: + X[key] = (string, alt_info, seq) + if is_reference: + ref_list.append(key) + elif is_allow_duplicate_chr_pos: + new_key = "" + for character in PREFIX_CHAR_STR: + tmp_key = character + key + if tmp_key not in X: + new_key = tmp_key + break + if len(new_key) > 0: + X[new_key] = (string, alt_info, seq) + if is_reference: + ref_list.append(new_key) + + if is_reference and key not in Y: + Y[key] = output_labels_from_reference(BASE2BASE[seq[param.flankingBaseNum]]) + + if len(X) == shuffle_bin_size: + if maximum_non_variant_ratio is not None: + _filter_non_variants(X, ref_list, maximum_non_variant_ratio) + yield X, total, False + X = {} + ref_list = [] + total += 1 + if total % 100000 == 0: + print("[INFO] Processed %d tensors" % total, file=sys.stderr) + + print("[INFO] Variants with read support/variants without read support: {}/{}".format(len(variant_set_with_read_support), variants_without_read_support)) + if maximum_non_variant_ratio is not None: + _filter_non_variants(X, ref_list, maximum_non_variant_ratio) + yield X, total, True + + +def _filter_non_variants(X, ref_list, maximum_non_variant_ratio): + non_variant_num = len(ref_list) + variant_num = len(X) - non_variant_num + if non_variant_num > variant_num * maximum_non_variant_ratio: + non_variant_keep_fraction = maximum_non_variant_ratio * variant_num / (1. * non_variant_num) + probabilities = np.random.random_sample((non_variant_num,)) + for key, p in zip(ref_list, probabilities): + if p > non_variant_keep_fraction: + X.pop(key) + + +def get_training_array(tensor_fn, var_fn, bed_fn, bin_fn, shuffle=True, is_allow_duplicate_chr_pos=True, chunk_id=None, + chunk_num=None, platform='ont', pileup=False, maximum_non_variant_ratio=None, candidate_details_fn_prefix=None): + + """ + Generate training array for training. here pytables with blosc:lz4hc are used for extreme fast compression and decompression, + which can meet the requirement of gpu utilization. lz4hc decompression allows speed up training array decompression 4~5x compared + with tensorflow tfrecord file format, current gpu utilization could reach over 85% with only 10G memory. + tensor_fn: string format tensor acquired from CreateTensorPileup or CreateTensorFullAlign, include contig name position, tensor matrix, alternative information. + var_fn: simplified variant(vcf) format from GetTruths, which include contig name, position, reference base, alternative base, genotype. + bin_fn: pytables format output bin file name. + shuffle: whether apply index shuffling when generating training data, default True, which would promote robustness. + is_allow_duplicate_chr_pos: whether allow duplicate positions when training, if there exists downsampled data, lower depth will add a random prefix character. + chunk_id: specific chunk id works with total chunk_num for parallel execution. Here will merge all tensor file with sampe prefix. + chunk_num: total chunk number for parallel execution. Each chunk refer to a smaller reference regions. + platform: platform for tensor shape, ont give a larger maximum depth compared with pb and illumina. + pileup: whether in pileup mode. Define two calling mode, pileup or full alignment. + maximum_non_variant_ratio: define a maximum non variant ratio for training, we always expect use more non variant data, while it would greatly increase training + time, especially in ont data, here we usually use 1:1 or 1:2 for variant candidate: non variant candidate. + candidate_details_fn_prefix: a counter to calculate total variant and non variant from the information in alternative file. + """ + + tree = bed_tree_from(bed_file_path=bed_fn) + is_tree_empty = len(tree.keys()) == 0 + Y_true_var, miss_variant_set, truth_alt_dict = variant_map_from(var_fn, tree, is_tree_empty) + Y = copy.deepcopy(Y_true_var) + + global param + float_type = 'int32' + if pileup: + import shared.param_p as param + else: + import shared.param_f as param + float_type = 'int8' + + tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape + + subprocess_list = [] + if tensor_fn == 'PIPE': + subprocess_list.append(sys.stdin) + elif os.path.exists(tensor_fn): + subprocess_list.append(subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, tensor_fn))).stdout) + # select all match prefix if file path not exists + else: + tensor_fn = tensor_fn.split('/') + directry, file_prefix = '/'.join(tensor_fn[:-1]), tensor_fn[-1] + all_file_name = [] + for file_name in os.listdir(directry): + if file_name.startswith(file_prefix + '_') or file_name.startswith( + file_prefix + '.'): # add '_.' to avoid add other prefix chr + all_file_name.append(file_name) + all_file_name = sorted(all_file_name) + if chunk_id is not None: + chunk_size = len(all_file_name) // chunk_num if len(all_file_name) % chunk_num == 0 else len( + all_file_name) // chunk_num + 1 + chunk_start = chunk_size * chunk_id + chunk_end = chunk_start + chunk_size + all_file_name = all_file_name[chunk_start:chunk_end] + if not len(all_file_name): + print("[INFO] chunk_id exceed total file number, skip chunk", file=sys.stderr) + return 0 + for file_name in all_file_name: + subprocess_list.append( + subprocess_popen(shlex.split("{} -fdc {}".format(param.zstd, os.path.join(directry, file_name)))).stdout) + + tables.set_blosc_max_threads(64) + int_atom = tables.Atom.from_dtype(np.dtype(float_type)) + string_atom = tables.StringAtom(itemsize=param.no_of_positions + 50) + long_string_atom = tables.StringAtom(itemsize=5000) # max alt_info length + table_file = tables.open_file(bin_fn, mode='w', filters=FILTERS) + table_file.create_earray(where='/', name='position_matrix', atom=int_atom, shape=[0] + tensor_shape, + filters=FILTERS) + table_file.create_earray(where='/', name='position', atom=string_atom, shape=(0, 1), filters=FILTERS) + table_file.create_earray(where='/', name='label', atom=int_atom, shape=(0, param.label_size), filters=FILTERS) + table_file.create_earray(where='/', name='alt_info', atom=long_string_atom, shape=(0, 1), filters=FILTERS) + + table_dict = update_table_dict() + + # generator to avoid high memory occupy + bin_reader_generator = partial(bin_reader_generator_from, + Y_true_var=Y_true_var, + Y=Y, + is_tree_empty=is_tree_empty, + tree=tree, + miss_variant_set=miss_variant_set, + truth_alt_dict=truth_alt_dict, + is_allow_duplicate_chr_pos=is_allow_duplicate_chr_pos, + maximum_non_variant_ratio=maximum_non_variant_ratio) + + total_compressed = 0 + for fin in subprocess_list: + bin_g = bin_reader_generator(tensor_fn=fin) + completed = False + while not completed: + try: + X, total, completed = next(bin_g) + except StopIteration: + completed = True + + if X is None or not len(X): + break + all_chr_pos = sorted(X.keys()) + if shuffle == True: + np.random.shuffle(all_chr_pos) + for key in all_chr_pos: + + string, alt_info, seq = X[key] + del X[key] + label = None + if key in Y: + label = Y[key] + pos = key + ':' + seq + if not is_allow_duplicate_chr_pos: + del Y[key] + elif is_allow_duplicate_chr_pos: + tmp_key = key[1:] + label = Y[tmp_key] + pos = tmp_key + ':' + seq + if label is None: + print(key) + continue + total_compressed = write_table_dict(table_dict, string, label, pos, total_compressed, alt_info, + tensor_shape, pileup) + + if total_compressed % 500 == 0 and total_compressed > 0: + table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type) + + if total_compressed % 50000 == 0: + print("[INFO] Compressed %d tensor" % (total_compressed), file=sys.stderr) + fin.close() + + if total_compressed % 500 != 0 and total_compressed > 0: + table_dict = write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type) + + table_file.close() + print("[INFO] Compressed %d/%d tensor" % (total_compressed, total), file=sys.stderr) + diff --git a/benchmarks/nn-variant/Clair3/preprocess/CheckEnvs.py b/benchmarks/nn-variant/Clair3/preprocess/CheckEnvs.py new file mode 100644 index 0000000..f32ea5b --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/CheckEnvs.py @@ -0,0 +1,496 @@ +import os +import sys +import argparse +import shlex +import subprocess +import multiprocessing + +from collections import defaultdict +from argparse import SUPPRESS +from distutils.version import LooseVersion + +import shared.param_p as param +from shared.interval_tree import bed_tree_from +from shared.utils import file_path_from, folder_path_from, subprocess_popen, str2bool, \ + legal_range_from, log_error, log_warning + +MIN_CHUNK_LENGTH = 200000 +MAX_CHUNK_LENGTH = 20000000 +major_contigs = {"chr" + str(a) for a in list(range(1, 23)) + ["X", "Y"]}.union( + {str(a) for a in list(range(1, 23)) + ["X", "Y"]}) +major_contigs_order = ["chr" + str(a) for a in list(range(1, 23)) + ["X", "Y"]] + [str(a) for a in + list(range(1, 23)) + ["X", "Y"]] + +required_tool_version = { + 'python': LooseVersion('3.6.7'), + 'pypy': LooseVersion('3.6'), + 'samtools': LooseVersion('1.10'), +# 'whatshap': LooseVersion('1.0'), + 'parallel': LooseVersion('20191122'), +} + +def check_version(tool, pos=None, is_pypy=False): + try: + if is_pypy: + proc = subprocess.run("{} -c 'import sys; print (sys.version)'".format(tool), stdout=subprocess.PIPE, + shell=True) + else: + proc = subprocess.run([tool, "--version"], stdout=subprocess.PIPE) + if proc.returncode != 0: + return None + first_line = proc.stdout.decode().split("\n", 1)[0] + version = first_line.split()[pos] + version = LooseVersion(version) + except Exception: + return None + + return version + + +def check_python_path(): + python_path = subprocess.run("which python", stdout=subprocess.PIPE, shell=True).stdout.decode().rstrip() + sys.exit(log_error("[ERROR] Current python execution path: {}".format(python_path))) + + +def check_tools_version(tool_version, required_tool_version): + for tool, version in tool_version.items(): + required_version = required_tool_version[tool] + if version is None: + print(log_error("[ERROR] {} not found, please check you are in clair3 virtual environment".format(tool))) + check_python_path() + elif version < required_version: + print(log_error("[ERROR] Tool version not match, please check you are in clair3 virtual environment")) + print(' '.join([str(item).ljust(10) for item in ["Tool", "Version", "Required"]])) + error_info = ' '.join([str(item).ljust(10) for item in [tool, version, '>=' + str(required_version)]]) + print(error_info) + check_python_path() + return + + +def check_contig_in_bam(bam_fn, sorted_contig_list, samtools): + bai_process = subprocess_popen(shlex.split("{} idxstats {}".format(samtools, bam_fn))) + contig_with_read_support_set = set() + for row_id, row in enumerate(bai_process.stdout): + row = row.split('\t') + if len(row) != 4: + continue + contig_name, contig_length, mapped_reads, unmapped_reads = row + if contig_name not in sorted_contig_list: + continue + if int(mapped_reads) > 0: + contig_with_read_support_set.add(contig_name) + for contig_name in sorted_contig_list: + if contig_name not in contig_with_read_support_set: + print(log_warning( + "[WARNING] Contig name {} provided but no mapped reads in BAM, skip!".format(contig_name))) + filtered_sorted_contig_list = [item for item in sorted_contig_list if item in contig_with_read_support_set] + + found_contig = True + if len(filtered_sorted_contig_list) == 0: + found_contig = False + print(log_warning( + "[WARNING] No mapped reads support in BAM for provided contigs set {}".format( + ' '.join(sorted_contig_list)))) + return filtered_sorted_contig_list, found_contig + + +def split_extend_vcf(vcf_fn, output_fn): + expand_region_size = param.no_of_positions + output_ctg_dict = defaultdict(list) + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn))) + + for row_id, row in enumerate(unzip_process.stdout): + if row[0] == '#': + continue + columns = row.strip().split(maxsplit=3) + ctg_name = columns[0] + + center_pos = int(columns[1]) + ctg_start, ctg_end = center_pos - 1, center_pos + if ctg_start < 0: + sys.exit( + log_error("[ERROR] Invalid VCF input in {}-th row {} {} {}".format(row_id + 1, ctg_name, center_pos))) + if ctg_start - expand_region_size < 0: + continue + expand_ctg_start = ctg_start - expand_region_size + expand_ctg_end = ctg_end + expand_region_size + + output_ctg_dict[ctg_name].append( + ' '.join([ctg_name, str(expand_ctg_start), str(expand_ctg_end)])) + + for key, value in output_ctg_dict.items(): + ctg_output_fn = os.path.join(output_fn, key) + with open(ctg_output_fn, 'w') as output_file: + output_file.write('\n'.join(value)) + + unzip_process.stdout.close() + unzip_process.wait() + + know_vcf_contig_set = set(list(output_ctg_dict.keys())) + + return know_vcf_contig_set + + +def split_extend_bed(bed_fn, output_fn, contig_set=None): + expand_region_size = param.no_of_positions + output_ctg_dict = defaultdict(list) + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_fn))) + for row_id, row in enumerate(unzip_process.stdout): + if row[0] == '#': + continue + columns = row.strip().split() + ctg_name = columns[0] + if contig_set and ctg_name not in contig_set: + continue + + ctg_start, ctg_end = int(columns[1]), int(columns[2]) + + if ctg_end < ctg_start or ctg_start < 0 or ctg_end < 0: + sys.exit(log_error( + "[ERROR] Invalid BED input in {}-th row {} {} {}".format(row_id + 1, ctg_name, ctg_start, ctg_end))) + expand_ctg_start = max(0, ctg_start - expand_region_size) + expand_ctg_end = max(0, ctg_end + expand_region_size) + output_ctg_dict[ctg_name].append( + ' '.join([ctg_name, str(expand_ctg_start), str(expand_ctg_end)])) + + for key, value in output_ctg_dict.items(): + ctg_output_fn = os.path.join(output_fn, key) + with open(ctg_output_fn, 'w') as output_file: + output_file.write('\n'.join(value)) + + unzip_process.stdout.close() + unzip_process.wait() + + +def output_header(output_fn, reference_file_path, sample_name='SAMPLE'): + output_file = open(output_fn, "w") + from textwrap import dedent + output_file.write(dedent("""\ + ##fileformat=VCFv4.2 + ##FILTER= + ##FILTER= + ##FILTER= + ##INFO= + ##INFO= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT=""" + ) + '\n') + + if reference_file_path is not None: + reference_index_file_path = file_path_from(reference_file_path, suffix=".fai", exit_on_not_found=True, sep='.') + with open(reference_index_file_path, "r") as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + contig_name, contig_size = columns[0], columns[1] + output_file.write(("##contig=" % (contig_name, contig_size) + '\n')) + + output_file.write('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s' % (sample_name)) + output_file.close() + +def compress_index_vcf(input_vcf): + # use bgzip to compress vcf -> vcf.gz + # use tabix to index vcf.gz + proc = subprocess.run('bgzip -f {}'.format(input_vcf), shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + proc = subprocess.run('tabix -f -p vcf {}.gz'.format(input_vcf), shell=True, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + +def CheckEnvs(args): + basedir = os.path.dirname(__file__) + bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True) + ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True) + fai_fn = file_path_from(args.ref_fn, suffix=".fai", exit_on_not_found=True, sep='.') + bai_fn = file_path_from(args.bam_fn, suffix=".bai", exit_on_not_found=True, sep='.') + bed_fn = file_path_from(args.bed_fn) + vcf_fn = file_path_from(args.vcf_fn) + tree = bed_tree_from(bed_file_path=bed_fn) + + # create temp file folder + output_fn_prefix = args.output_fn_prefix + output_fn_prefix = folder_path_from(output_fn_prefix, create_not_found=True) + log_path = folder_path_from(os.path.join(output_fn_prefix, 'log'), create_not_found=True) + tmp_file_path = folder_path_from(os.path.join(output_fn_prefix, 'tmp'), create_not_found=True) + split_bed_path = folder_path_from(os.path.join(tmp_file_path, 'split_beds'), + create_not_found=True) if bed_fn or vcf_fn else None + pileup_vcf_path = folder_path_from(os.path.join(tmp_file_path, 'pileup_output'), create_not_found=True) + merge_vcf_path = folder_path_from(os.path.join(tmp_file_path, 'merge_output'), create_not_found=True) + phase_output_path = folder_path_from(os.path.join(tmp_file_path, 'phase_output'), create_not_found=True) + gvcf_temp_output_path = folder_path_from(os.path.join(tmp_file_path, 'gvcf_tmp_output'), create_not_found=True) + full_alignment_output_path = folder_path_from(os.path.join(tmp_file_path, 'full_alignment_output'), + create_not_found=True) + phase_vcf_path = folder_path_from(os.path.join(phase_output_path, 'phase_vcf'), create_not_found=True) + phase_bam_path = folder_path_from(os.path.join(phase_output_path, 'phase_bam'), create_not_found=True) + candidate_bed_path = folder_path_from(os.path.join(full_alignment_output_path, 'candidate_bed'), + create_not_found=True) + + # environment parameters + pypy = args.pypy + samtools = args.samtools + whatshap = args.whatshap + parallel = args.parallel + qual = args.qual + var_pct_full = args.var_pct_full + ref_pct_full = args.ref_pct_full + snp_min_af = args.snp_min_af + indel_min_af = args.indel_min_af + sample_name = args.sampleName + contig_name_list = os.path.join(tmp_file_path, 'CONTIGS') + chunk_list = os.path.join(tmp_file_path, 'CHUNK_LIST') + + legal_range_from(param_name="qual", x=qual, min_num=0, exit_out_of_range=True) + legal_range_from(param_name="var_pct_full", x=var_pct_full, min_num=0, max_num=1, exit_out_of_range=True) + legal_range_from(param_name="ref_pct_full", x=ref_pct_full, min_num=0, max_num=1, exit_out_of_range=True) + legal_range_from(param_name="snp_min_af", x=snp_min_af, min_num=0, max_num=1, exit_out_of_range=True) + legal_range_from(param_name="indel_min_af", x=indel_min_af, min_num=0, max_num=1, exit_out_of_range=True) + if ref_pct_full > 0.3: + print(log_warning( + "[WARNING] For efficiency, we use a maximum 30% reference candidates for full-alignment calling")) + tool_version = { + 'python': LooseVersion(sys.version.split()[0]), + 'pypy': check_version(tool=pypy, pos=0, is_pypy=True), + 'samtools': check_version(tool=samtools, pos=1), +# 'whatshap': check_version(tool=whatshap, pos=1), + 'parallel': check_version(tool=parallel, pos=2), + } + check_tools_version(tool_version, required_tool_version) + + is_include_all_contigs = args.include_all_ctgs + is_bed_file_provided = bed_fn is not None + is_known_vcf_file_provided = vcf_fn is not None + + if is_known_vcf_file_provided and is_bed_file_provided: + sys.exit(log_error("[ERROR] Please provide either --vcf_fn or --bed_fn only")) + + if is_known_vcf_file_provided: + know_vcf_contig_set = split_extend_vcf(vcf_fn=vcf_fn, output_fn=split_bed_path) + + ctg_name_list = args.ctg_name + is_ctg_name_list_provided = ctg_name_list is not None and ctg_name_list != "EMPTY" + contig_set = set(ctg_name_list.split(',')) if is_ctg_name_list_provided else set() + + if is_ctg_name_list_provided and is_bed_file_provided: + print(log_warning("[WARNING] both --ctg_name and --bed_fn provided, will only proceed contigs in intersection")) + + if is_ctg_name_list_provided and is_known_vcf_file_provided: + print(log_warning("[WARNING] both --ctg_name and --vcf_fn provided, will only proceed contigs in intersection")) + + if is_ctg_name_list_provided: + + contig_set = contig_set.intersection( + set(tree.keys())) if is_bed_file_provided else contig_set + + contig_set = contig_set.intersection( + know_vcf_contig_set) if is_known_vcf_file_provided else contig_set + else: + contig_set = contig_set.union( + set(tree.keys())) if is_bed_file_provided else contig_set + + contig_set = contig_set.union( + know_vcf_contig_set) if is_known_vcf_file_provided else contig_set + + # if each split region is too small(long) for given default chunk num, will increase(decrease) the total chunk num + default_chunk_num = args.chunk_num + DEFAULT_CHUNK_SIZE = args.chunk_size + contig_length_list = [] + contig_chunk_num = {} + + threads = args.threads + ''' + # A64FX does not support os.sched_getaffinity, therefore we replaced it with multiprocessing + sched_getaffinity_list = list(os.sched_getaffinity(0)) + numCpus = len(sched_getaffinity_list) + ''' + numCpus = multiprocessing.cpu_count() + + if threads > numCpus: + print(log_warning( + '[WARNING] Current maximum threads {} is larger than support cpu count {}, You may set a smaller parallel threads by setting --threads=$ for better parallelism.'.format( + threads, numCpus))) + + ## for better parallelism for create tensor and call variants, we over commit the overall threads/4 for 3 times, which is 0.75 * overall threads. + threads_over_commit = max(4, int(threads * 0.75)) + + with open(fai_fn, 'r') as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + contig_name, contig_length = columns[0], int(columns[1]) + if not is_include_all_contigs and ( + not (is_bed_file_provided or is_ctg_name_list_provided or is_known_vcf_file_provided)) and str( + contig_name) not in major_contigs: + continue + + if is_bed_file_provided and contig_name not in tree: + continue + if is_ctg_name_list_provided and contig_name not in contig_set: + continue + if is_known_vcf_file_provided and contig_name not in contig_set: + continue + + contig_set.add(contig_name) + contig_length_list.append(contig_length) + chunk_num = int( + contig_length / float(DEFAULT_CHUNK_SIZE)) + 1 if contig_length % DEFAULT_CHUNK_SIZE else int( + contig_length / float(DEFAULT_CHUNK_SIZE)) + contig_chunk_num[contig_name] = max(chunk_num, 1) + + if default_chunk_num > 0: + min_chunk_length = min(contig_length_list) / float(default_chunk_num) + max_chunk_length = max(contig_length_list) / float(default_chunk_num) + + contigs_order = major_contigs_order + list(contig_set) + + sorted_contig_list = sorted(list(contig_set), key=lambda x: contigs_order.index(x)) + + found_contig = True + if not len(contig_set): + if is_bed_file_provided: + all_contig_in_bed = ' '.join(list(tree.keys())) + print(log_warning("[WARNING] No contig intersection found by --bed_fn, contigs in BED {}: {}".format(bed_fn, all_contig_in_bed))) + if is_known_vcf_file_provided: + all_contig_in_vcf = ' '.join(list(know_vcf_contig_set)) + print(log_warning("[WARNING] No contig intersection found by --vcf_fn, contigs in VCF {}: {}".format(vcf_fn, all_contig_in_vcf))) + if is_ctg_name_list_provided: + all_contig_in_ctg_name = ' '.join(ctg_name_list.split(',')) + print(log_warning("[WARNING] No contig intersection found by --ctg_name, contigs in contigs list: {}".format(all_contig_in_ctg_name))) + found_contig = False + else: + for c in sorted_contig_list: + if c not in contig_chunk_num: + print(log_warning(("[WARNING] Contig {} given but not found in reference fai file".format(c)))) + + # check contig in bam have support reads + sorted_contig_list, found_contig = check_contig_in_bam(bam_fn=bam_fn, sorted_contig_list=sorted_contig_list, + samtools=samtools) + + if not found_contig: + # output header only to merge_output.vcf.gz + output_fn = os.path.join(output_fn_prefix, "merge_output.vcf") + output_header(output_fn=output_fn, reference_file_path=ref_fn, sample_name=sample_name) + compress_index_vcf(output_fn) + print(log_warning( + ("[WARNING] No contig intersection found, output header only in {}").format(output_fn + ".gz"))) + with open(contig_name_list, 'w') as output_file: + output_file.write("") + return + + print('[INFO] Call variant in contigs: {}'.format(' '.join(sorted_contig_list))) + print('[INFO] Chunk number for each contig: {}'.format( + ' '.join([str(contig_chunk_num[c]) for c in sorted_contig_list]))) + + if default_chunk_num > 0 and max_chunk_length > MAX_CHUNK_LENGTH: + print(log_warning( + '[WARNING] Current maximum chunk size {} is larger than default maximum chunk size {}, You may set a larger chunk_num by setting --chunk_num=$ for better parallelism.'.format( + min_chunk_length, MAX_CHUNK_LENGTH))) + + elif default_chunk_num > 0 and min_chunk_length < MIN_CHUNK_LENGTH: + print(log_warning( + '[WARNING] Current minimum chunk size {} is smaller than default minimum chunk size {}, You may set a smaller chunk_num by setting --chunk_num=$.'.format( + min_chunk_length, MIN_CHUNK_LENGTH))) + + if default_chunk_num == 0 and max(contig_length_list) < DEFAULT_CHUNK_SIZE / 5: + print(log_warning( + '[WARNING] Current maximum contig length {} is much smaller than default chunk size {}, You may set a smaller chunk size by setting --chunk_size=$ for better parallelism.'.format( + max(contig_length_list), DEFAULT_CHUNK_SIZE))) + + if is_bed_file_provided: + split_extend_bed(bed_fn=bed_fn, output_fn=split_bed_path, contig_set=contig_set) + + with open(contig_name_list, 'w') as output_file: + output_file.write('\n'.join(sorted_contig_list)) + + with open(chunk_list, 'w') as output_file: + for contig_name in sorted_contig_list: + chunk_num = contig_chunk_num[contig_name] + for chunk_id in range(1, chunk_num + 1): + output_file.write(contig_name + ' ' + str(chunk_id) + ' ' + str(chunk_num) + '\n') + + +def main(): + parser = argparse.ArgumentParser( + description="Check the environment and the validity of the input variables, preprocess the BED input if necessary") + + parser.add_argument('--bam_fn', type=str, default=None, + help="BAM file input, default: %(default)s") + + parser.add_argument('--output_fn_prefix', type=str, default=None, + help="Path to the output folder") + + parser.add_argument('--ctg_name', type=str, default='EMPTY', + help="The name of sequence to be processed, separated by comma") + + parser.add_argument('--bed_fn', type=str, nargs='?', action="store", default=None, + help="Call variant only in these regions. Will take an intersection if --ctg_name is set") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") + + parser.add_argument('--ref_fn', type=str, default="ref.fa", + help="Reference fasta file input, default: %(default)s") + + parser.add_argument('--chunk_size', type=int, default=5000000, + help="The size of each chuck for parallel processing, default: 5Mbp") + + parser.add_argument('--include_all_ctgs', type=str2bool, default=False, + help="Call variants on all contigs, default: chr{1..22,X,Y,M,MT} and {1..22,X,Y,MT}") + + parser.add_argument('--threads', type=int, default=16, + help="Max #threads to be used. The full genome will be divided into small chucks for parallel processing") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required, default: %(default)s") + + parser.add_argument('--pypy', type=str, default="pypy3", + help="Path to the 'pypy', pypy3 version >= 3.6 is required, default: %(default)s") + + parser.add_argument('--python', type=str, default="python3", + help="Path to the 'python3', default: %(default)s") + + parser.add_argument('--parallel', type=str, default="parallel", + help="Path to the 'parallel', default: %(default)s") + + parser.add_argument('--whatshap', type=str, default="whatshap", + help="Path to the 'whatshap', default: %(default)s") + + parser.add_argument('--sampleName', type=str, default="SAMPLE", + help="Define the sample name to be shown in the VCF file, optional") + + parser.add_argument('--qual', type=int, default=None, + help="If set, variants with >=$qual will be marked 'PASS', or 'LowQual' otherwise, optional") + + parser.add_argument('--var_pct_full', type=float, default=0.3, + help="Default variant call proportion for raw alignment or remove low quality proportion for whatshap phasing. (default: %(default)f)") + + parser.add_argument('--ref_pct_full', type=float, default=0.3, + help="Default reference call proportion for raw alignment or remove low quality proportion for whatshap phasing. (default: %(default)f)") + + parser.add_argument('--snp_min_af', type=float, default=0.08, + help="Minimum SNP allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--indel_min_af', type=float, default=0.08, + help="Minimum Indel allele frequency for a site to be considered as a candidate site, default: %(default)f") + + # options for internal process control + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=0, + help=SUPPRESS) + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + if not args.include_all_ctgs and args.ctg_name == 'EMPTY': + print("[INFO] --include_all_ctgs not enabled, use chr{1..22,X,Y} and {1..22,X,Y} by default") + elif args.include_all_ctgs: + print("[INFO] --include_all_ctgs enabled") + print(log_warning("[WARNING] Please enable --no_phasing_for_fa if calling variant in non-diploid organisms")) + + CheckEnvs(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/CreateTensorFullAlignment.py b/benchmarks/nn-variant/Clair3/preprocess/CreateTensorFullAlignment.py new file mode 100644 index 0000000..f6d8c28 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/CreateTensorFullAlignment.py @@ -0,0 +1,998 @@ +import sys +import shlex +import os +import json +import logging +import random +from subprocess import PIPE +from os.path import isfile +from argparse import ArgumentParser, SUPPRESS +from collections import Counter, defaultdict, OrderedDict + +import shared.param_f as param +from shared.utils import subprocess_popen, file_path_from, IUPAC_base_to_num_dict as BASE2NUM, region_from, \ + reference_sequence_from, str2bool, vcf_candidates_from +from shared.interval_tree import bed_tree_from, is_region_in +from shared.intervaltree.intervaltree import IntervalTree + +logging.basicConfig(format='%(message)s', level=logging.INFO) +BASES = set(list(BASE2NUM.keys()) + ["-"]) +no_of_positions = param.no_of_positions +flanking_base_num = param.flankingBaseNum +channel_size = param.channel_size +BASE2NUMBER = dict(zip("ACGTURYSWKMBDHVN-", (0, 1, 2, 3, 3, 0, 1, 1, 0, 2, 0, 1, 0, 0, 0, 0, 4))) +NORMALIZE_NUM = param.NORMALIZE_NUM +MAX_BQ = 40.0 +MAX_MQ = 60.0 +MAX_AF = 1.0 +STRAND_0 = 100 +STRAND_1 = 50 +HAP_TYPE = dict(zip((1, 0, 2), (30, 60, 90))) # hap1 UNKNOWN H2 +ACGT_NUM = dict(zip("ACGT+-*#N", (100, 25, 75, 50, -50, -100, 0, 0, 100))) + + +def _normalize_bq(x): + return int(NORMALIZE_NUM * min(x, MAX_BQ) / MAX_BQ) + + +def _normalize_mq(x): + return int(NORMALIZE_NUM * min(x, MAX_MQ) / MAX_MQ) + + +def _normalize_af(x): + return int(NORMALIZE_NUM * min(x, MAX_AF) / MAX_AF) + + +class Position(object): + def __init__(self, pos, ref_base=None, alt_base=None, read_name_list=None, base_list=None, raw_base_quality=None, + raw_mapping_quality=None, af=None, depth=None, genotype=None, phase_set=None): + self.pos = pos + self.ref_base = ref_base + self.alt_base = alt_base + self.read_name_list = read_name_list + self.base_list = base_list + self.raw_base_quality = raw_base_quality + self.raw_mapping_quality = raw_mapping_quality + self.af = af + self.depth = depth + self.read_channel = None + self.mapping_quality = None + self.update_info = False + self.read_info = defaultdict() + self.ref_seq = None + self.alt_seq = None + self.phase_set = phase_set + self.genotype = genotype + self.read_name_seq = defaultdict(str) + + def update_infos(self): + # only proceed when variant exists in candidate windows which greatly improves efficiency + self.update_info = True + self.read_name_dict = dict(zip(self.read_name_list, self.base_list)) + self.mapping_quality = [_normalize_mq(phredscore2raw_score(item)) for item in self.raw_mapping_quality] + self.base_quality = [_normalize_bq(phredscore2raw_score(item)) for item in self.raw_base_quality] + + for read_name, base_info, bq, mq in zip(self.read_name_list, self.base_list, self.base_quality, + self.mapping_quality): + read_channel, ins_base, query_base = get_tensor_info(base_info, bq, self.ref_base, mq) + self.read_info[read_name] = (read_channel, ins_base) + + +class PhasingRead(object): + def __init__(self): + self.read_seq = defaultdict(str) + self.read_start = None + self.read_end = None + + +def phredscore2raw_score(qual): + return ord(qual) - 33 + + +def evc_base_from(base): + if base == 'N': + return 'A' + elif base == 'n': + return 'a' + elif base in 'ACGTacgt': + return base + elif base.isupper(): + return 'A' + else: + return 'a' + + +def sorted_by_hap_read_name(center_pos, haplotag_dict, pileup_dict, hap_dict, platform): + """ + Sort by reads haplotype after haplotag reads otherwise sort by read start position. + center_pos: define the center candidate position for proccessing. + haplotag_dict: dictionary (read name : hap type) which keep the read name and haplotype mapping. + pileup_dict: dictionary (pos: pos info) which keep read information that cover specific position . + hap_dict: similar to haplotag_dict, dictionary (pos: pos info) which keep the read name and haplotype mapping, + while haplotype information directly acquire from BAM HP tag. + platform: select maximum depth for each platform. + """ + all_nearby_read_name = [] + start_pos, end_pos = center_pos - flanking_base_num, center_pos + flanking_base_num + 1 + for p in range(start_pos, end_pos): + if p in pileup_dict.keys(): + all_nearby_read_name += pileup_dict[p].read_name_list + all_nearby_read_name = list(OrderedDict.fromkeys(all_nearby_read_name)) # have sorted by order + matrix_depth = param.matrix_depth_dict[platform] + if len(all_nearby_read_name) > matrix_depth: + # set same seed for reproducibility + random.seed(0) + indices = random.sample(range(len(all_nearby_read_name)), matrix_depth) + all_nearby_read_name = [all_nearby_read_name[i] for i in sorted(indices)] + sorted_read_name_list = [] + for order, read_name in enumerate(all_nearby_read_name): + hap = max(haplotag_dict[read_name], hap_dict[read_name]) # no phasing is 0 + sorted_read_name_list.append((hap, order, read_name)) + + sorted_read_name_list = sorted(sorted_read_name_list) + return sorted_read_name_list + + +def get_tensor_info(base_info, bq, ref_base, read_mq): + """ + Create tensor information for each read level position. + base_info: base information include all alternative bases. + bq: normalized base quality. + ref_base: reference_base: upper reference base for cigar calculation. + read_mq: read mapping quality. + """ + + base, indel = base_info + ins_base = "" + query_base = "" + read_channel = [0] * channel_size + if base[0] in '*#': + return read_channel, ins_base, query_base + strand = STRAND_1 + if base[0] in 'ACGT': + strand = STRAND_0 + ALT_BASE = 0 + + base_upper = base.upper() + if indel != '': + ALT_BASE = ACGT_NUM[indel[0]] + elif (base_upper != ref_base and base_upper in 'ACGT'): + base_upper = evc_base_from(base_upper) + ALT_BASE = ACGT_NUM[base_upper] + + REF_BASE = ACGT_NUM[ref_base] + if len(indel) and indel[0] in '+-': + if indel[0] == "+": + ins_base = indel[1:].upper() + read_channel[:5] = REF_BASE, ALT_BASE, strand, read_mq, bq + query_base = "" if base_upper not in "ACGT" else base_upper + return read_channel, ins_base, query_base + + +def decode_pileup_bases(pileup_bases, reference_base, minimum_af_for_candidate, minimum_snp_af_for_candidate, minimum_indel_af_for_candidate, has_pileup_candidates, platform='ont'): + """ + Decode mpileup input string. + pileup_bases: pileup base string for each position, include all mapping information. + reference_base: upper reference base for cigar calculation. + pileup_dict: dictionary (pos: pos info) which keep read information that cover specific position. + ref_seq: chunked reference sequence in window, start: center pos - flankingBaseNum, end: center + flankingBaseNum + 1. + reference_sequence: reference sequence index by contig:start-end. 0-based. + minimum_af_for_candidate: default minimum alleic frequency for candidate filtering, filter if below specific thredshold. + has_pileup_candidates: if the candidate is directly obtained from pileup output, then no need to check the af filtering. + """ + + base_idx = 0 + base_list = [] + while base_idx < len(pileup_bases): + base = pileup_bases[base_idx] + if base == '+' or base == '-': + base_idx += 1 + advance = 0 + while True: + num = pileup_bases[base_idx] + if num.isdigit(): + advance = advance * 10 + int(num) + base_idx += 1 + else: + break + base_list[-1][1] = base + pileup_bases[base_idx: base_idx + advance] # add indel seq + base_idx += advance - 1 + + elif base in "ACGTNacgtn#*": + base_list.append([base, ""]) + elif base == '^': # start of read, next base is mq, update mq info + base_idx += 1 + # skip $, the end of read + base_idx += 1 + if has_pileup_candidates: + return base_list, None, None, None + + pileup_dict = defaultdict(int) + base_counter = Counter([''.join(item) for item in base_list]) + depth = 0 + for key, count in base_counter.items(): + if key[0].upper() in 'ACGT': + pileup_dict[key[0].upper()] += count + depth += count + if len(key) > 1 and key[1] == '+': + pileup_dict['I'] += count + elif len(key) > 1 and key[1] == '-': + pileup_dict['D'] += count + + minimum_snp_af_for_candidate = minimum_snp_af_for_candidate if minimum_snp_af_for_candidate > 0 else param.min_af + minimum_indel_af_for_candidate = minimum_indel_af_for_candidate if minimum_indel_af_for_candidate > 0 else param.min_af_dict[platform] + + denominator = depth if depth > 0 else 1 + pileup_list = sorted(list(pileup_dict.items()), key=lambda x: x[1], reverse=True) + + pass_af = len(pileup_list) and (pileup_list[0][0] != reference_base) + pass_snp_af = False + pass_indel_af = False + + for item, count in pileup_list: + if item == reference_base: + continue + elif item[0] in 'ID': + pass_indel_af = (pass_indel_af or (float(count) / denominator >= minimum_indel_af_for_candidate)) + continue + pass_snp_af = pass_snp_af or (float(count) / denominator >= minimum_snp_af_for_candidate) + + af = (float(pileup_list[1][1]) / denominator) if len(pileup_list) > 1 else 0.0 + af = (float(pileup_list[0][1]) / denominator) if len(pileup_list) >= 1 and pileup_list[0][ + 0] != reference_base else af + + pass_af = pass_af or pass_snp_af or pass_indel_af + + return base_list, depth, pass_af, af + + +def get_alt_info(center_pos, pileup_dict, ref_seq, reference_sequence, reference_start, hap_dict): + """ + Get alternative information for representation unification, keep all read level alignment information including phasing info. + center_pos: center position for processing, default window size = no_of_positions = flankingBaseNum + 1 + flankingBaseNum + pileup_dict: dictionary (pos: pos info) which keep read information that cover specific position . + ref_seq: chunked reference sequence in window, start: center pos - flankingBaseNum, end: center + flankingBaseNum + 1. + reference_sequence: reference sequence index by contig:start-end. 0-based. + reference_base: upper reference base for cigar calculation. + reference_start: upper reference base for cigar calculation. + hap_dict: dictionary (pos: pos info) which keep the read name and haplotype mapping. + """ + + reference_base = ref_seq[flanking_base_num] + alt_read_name_dict = defaultdict(set) + depth = 0 + for (base, indel), read_name in zip(pileup_dict[center_pos].base_list, pileup_dict[center_pos].read_name_list): + if base in "#*": + alt_read_name_dict['*'].add(read_name) + depth += 1 + continue + depth += 1 + if base.upper() == reference_base and indel == '': + alt_read_name_dict['R'].add(read_name) + if indel != '': + if indel[0] == '+': + indel = 'I' + base.upper() + indel.upper()[1:] + else: + del_bases_num = len(indel[1:]) + del_ref_bases = reference_sequence[ + center_pos - reference_start + 1:center_pos - reference_start + del_bases_num + 1] + indel = 'D' + del_ref_bases + alt_read_name_dict[indel].add(read_name) + + if indel == '' and base.upper() != reference_base: + alt_read_name_dict['X' + base.upper()].add(read_name) + + for alt_type, read_name_set in list(alt_read_name_dict.items()): + alt_read_name_dict[alt_type] = ' '.join( + [read_name + '_' + str(hap_dict[read_name]) for read_name in list(read_name_set)]) + + alt_info = str(depth) + '\t' + json.dumps(alt_read_name_dict) + + return alt_info + + +def generate_tensor(ctg_name, center_pos, sorted_read_name_list, pileup_dict, ref_seq, reference_sequence, + reference_start, platform, confident_bed_tree, add_no_phasing_data_training): + """ + Generate full alignment input tensor + ctg_name: provided contig name. + center_pos: center position for full alignment generation, default window size = no_of_positions = + flankingBaseNum + 1 + flankingBaseNum + sorted_read_name_list: read name list which have been sorted by read start position and haplotype. + pileup_dict: dictionary (pos: pos info) which keep read information that cover specific position . + ref_seq: chunked reference sequence in window, start: center pos - flankingBaseNum, end: center + flankingBaseNum + 1. + reference_sequence: reference sequence index by contig:start-end. 0-based. + reference_base: upper reference base for cigar calculation. + reference_start: upper reference base for cigar calculation. + platform: platform for tensor shape, ont give a larger maximum depth compared with pb and illumina. + confident_bed_tree: dictionary (contig name : intervaltree) for fast region query. + add_no_phasing_data_training: boolean option to decide whether add no phasing data in training, we will + resort the read and remove haplotype info when using this option. + """ + + tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape + reference_base = ref_seq[flanking_base_num] + tensor_depth = len(sorted_read_name_list) + if tensor_depth == 0: + return None, None + tensor = [[[0] * tensor_shape[2] for _ in range(tensor_shape[1])] for _ in range(tensor_depth)] + start_pos, end_pos = center_pos - flanking_base_num, center_pos + flanking_base_num + 1 + insert_tuple = [] + + alt_dict = defaultdict(int) + depth, max_del_length = 0, 0 + for base, indel in pileup_dict[center_pos].base_list: + if base in "#*": + depth += 1 + continue + depth += 1 + base_upper = base.upper() + if indel != '': + if indel[0] == '+': + alt_dict['+' + base_upper + indel[1:].upper()] += 1 + else: # del + alt_dict[indel.upper()] += 1 + max_del_length = max(len(indel), max_del_length) + elif base.upper() != reference_base: + alt_dict[base.upper()] += 1 + + # match deletion cases and bed format + pass_confident_bed = not len(confident_bed_tree) or is_region_in(confident_bed_tree, ctg_name, + center_pos - 2, + center_pos + max_del_length + 1) + if not pass_confident_bed: + return None, None + + for p in range(start_pos, end_pos): + if p in pileup_dict and not pileup_dict[p].update_info: + pileup_dict[p].update_infos() + for read_idx, read_name_info in enumerate(sorted_read_name_list): + hap, read_order, read_name = read_name_info + offset = p - start_pos + if p in pileup_dict and read_name in pileup_dict[p].read_info: + read_channel, ins_base = pileup_dict[p].read_info[read_name] + tensor[read_idx][offset] = read_channel + if ins_base != '' and p < end_pos - 1: + insert_tuple.append((read_idx, offset, ins_base, p)) + + for read_idx, p, ins_base, cp in insert_tuple: + + for ins_idx in range(min(len(ins_base), no_of_positions - p)): + tensor[read_idx][ins_idx + p][6] = ACGT_NUM[ins_base[ins_idx]] + + for row_idx, (hap, _, read_name) in enumerate(sorted_read_name_list): + af_num = 0 + if read_name in pileup_dict[center_pos].read_name_dict: + base, indel = pileup_dict[center_pos].read_name_dict[read_name] + base_upper = base.upper() + if indel != '': + if indel[0] == '+': + insert_str = ('+' + base_upper + indel.upper()[1:]) + af_num = alt_dict[insert_str] / max(1, float(depth)) if insert_str in alt_dict else af_num + else: + af_num = alt_dict[indel.upper()] / max(1, float(depth)) if indel.upper() in alt_dict else af_num + elif base.upper() in alt_dict: + af_num = alt_dict[base_upper] / max(1, float(depth)) + af_num = _normalize_af(af_num) if af_num != 0 else af_num + hap_type = HAP_TYPE[hap] + for p in range(no_of_positions): + if tensor[row_idx][p][2] != 0: # skip all del #* + tensor[row_idx][p][5] = af_num + tensor[row_idx][p][7] = hap_type + + alt_info = [] + for alt_type, alt_count in alt_dict.items(): + if alt_type[0] == '+': + alt_info.append(['I' + alt_type[1:].upper(), str(alt_count)]) + elif alt_type[0] == '-': + del_bases_num = len(alt_type[1:]) + del_ref_bases = reference_sequence[ + center_pos - reference_start + 1:center_pos - reference_start + del_bases_num + 1] + alt_info.append(['D' + del_ref_bases, str(alt_count)]) + else: + alt_info.append(['X' + alt_type, str(alt_count)]) + + alt_info = str(depth) + '-' + ' '.join([' '.join([item[0], str(item[1])]) for item in alt_info]) + tensor_string_list = [" ".join((" ".join(" ".join(str(x) for x in innerlist) for innerlist in outerlist)) for outerlist in tensor)] + + if add_no_phasing_data_training: + all_hap = [item[0] for item in sorted_read_name_list] + # skip if no phased reads exist + if sum(all_hap) != 0: + raw_read_name_index_mapping = [item[1] for item in sorted( + [(item[1], read_idx) for read_idx, item in enumerate(sorted_read_name_list)])] + no_phasing_tensor = [tensor[read_idx] for read_idx in raw_read_name_index_mapping] + for row_idx in range(len(no_phasing_tensor)): + for p in range(no_of_positions): + if tensor[row_idx][p][7] > 0: + tensor[row_idx][p][7] = HAP_TYPE[0] + + no_phasing_tensor_string = " ".join( + (" ".join(" ".join(str(x) for x in innerlist) for innerlist in outerlist)) for outerlist in + no_phasing_tensor) + tensor_string_list.append(no_phasing_tensor_string) + return '\n'.join(["%s\t%d\t%s\t%s\t%s" % ( + ctg_name, + center_pos, + ref_seq, + tensor_string, + alt_info + ) for tensor_string in tensor_string_list]), alt_info + + +class TensorStdout(object): + def __init__(self, handle): + self.stdin = handle + + def __del__(self): + self.stdin.close() + + +def update_hete_ref(pos, reference_sequence, reference_start, extend_bp, alt_base): + # if need phasing option enables, will store reference squence near hete snp candidate. + ref_start = pos - extend_bp + ref_end = pos + extend_bp + 1 + ref_seq = reference_sequence[ref_start - reference_start: ref_end - reference_start] + alt_seq = ref_seq[:extend_bp] + alt_base + ref_seq[extend_bp + 1:] + return ref_seq, alt_seq + + +def CreateTensorFullAlignment(args): + ctg_start = args.ctgStart + ctg_end = args.ctgEnd + full_aln_regions = args.full_aln_regions + fasta_file_path = args.ref_fn + ctg_name = args.ctgName + need_phasing = args.need_phasing + samtools_execute_command = args.samtools + bam_file_path = args.bam_fn + chunk_id = args.chunk_id - 1 if args.chunk_id else None # 1-base to 0-base + chunk_num = args.chunk_num + tensor_can_output_path = args.tensor_can_fn + is_full_aln_regions_given = full_aln_regions is not None + phasing_info_in_bam = args.phasing_info_in_bam + phasing_window_size = args.phasing_window_size + extend_bp = param.extend_bp + unify_repre = args.unify_repre + minimum_af_for_candidate = args.min_af + minimum_snp_af_for_candidate = args.snp_min_af + minimum_indel_af_for_candidate = args.indel_min_af + min_coverage = args.minCoverage + platform = args.platform + confident_bed_fn = args.bed_fn + is_confident_bed_file_given = confident_bed_fn is not None + phased_vcf_fn = args.phased_vcf_fn + alt_fn = args.indel_fn + extend_bed = args.extend_bed + is_extend_bed_file_given = extend_bed is not None + min_mapping_quality = args.minMQ + min_base_quality = args.minBQ + unify_repre_fn = args.unify_repre_fn + add_no_phasing_data_training = args.add_no_phasing_data_training + vcf_fn = args.vcf_fn + is_known_vcf_file_provided = vcf_fn is not None + + global test_pos + test_pos = None + hete_snp_pos_dict = defaultdict() + hete_snp_tree = IntervalTree() + need_phasing_pos_set = set() + add_read_regions = True + if full_aln_regions: + + """ + If given full alignment bed regions, all candidate positions will be directly selected from each row, define as + 'ctg start end', where 0-based center position is the candidate for full alignment calling. + if 'need_phasing' option enables, full alignment bed regions will also include nearby heterozygous snp candidates for reads + haplotag, which is faster than whatshap haplotag with more memory occupation. + """ + + candidate_file_path_process = subprocess_popen(shlex.split("gzip -fdc %s" % (full_aln_regions))) + candidate_file_path_output = candidate_file_path_process.stdout + + ctg_start, ctg_end = float('inf'), 0 + for row in candidate_file_path_output: + row = row.rstrip().split('\t') + if row[0] != ctg_name: continue + position = int(row[1]) + 1 + end = int(row[2]) + 1 + ctg_start = min(position, ctg_start) + ctg_end = max(end, ctg_end) + + if platform == "ilmn": + continue + if len(row) > 3: # hete snp positions + center_pos = position + extend_bp + 1 + ref_base, alt_base, genotype, phase_set = row[3].split('-') + hete_snp_pos_dict[center_pos] = Position(pos=center_pos, ref_base=ref_base, alt_base=alt_base, + genotype=int(genotype), phase_set=phase_set) + hete_snp_tree.addi(begin=center_pos - extend_bp, end=center_pos + extend_bp + 1) + else: + center = position + (end - position) // 2 - 1 + need_phasing_pos_set.add(center) + candidate_file_path_output.close() + candidate_file_path_process.wait() + + # currently deprecate using ctgName.start_end as file name, which will run similar regions for several times when start and end has slight difference + # if '.' in full_aln_regions.split('/')[-1] and len(full_aln_regions.split('/')[-1].split('.')[-1].split('_')) > 0: + # ctg_start, ctg_end = full_aln_regions.split('/')[-1].split('.')[-1].split('_') + # ctg_start, ctg_end = int(ctg_start), int(ctg_end) + if platform == 'ilmn' and bam_file_path == "PIPE": + add_read_regions = False + + fai_fn = file_path_from(fasta_file_path, suffix=".fai", exit_on_not_found=True, sep='.') + + if is_known_vcf_file_provided: + known_variants_list = vcf_candidates_from(vcf_fn=vcf_fn, contig_name=ctg_name) + known_variants_set = set(known_variants_list) + if not full_aln_regions and chunk_id is not None: + + """ + Whole genome calling option, acquire contig start end position from reference fasta index(.fai), then split the + reference accroding to chunk id and total chunk numbers. + """ + if is_confident_bed_file_given: + # consistent with pileup generation, faster to extract tensor using bed region + tree, bed_start, bed_end = bed_tree_from(bed_file_path=extend_bed, + contig_name=ctg_name, + return_bed_region=True) + + chunk_size = (bed_end - bed_start) // chunk_num + 1 if (bed_end - bed_start) % chunk_num else ( + bed_end - bed_start) // chunk_num + ctg_start = bed_start + 1 + chunk_size * chunk_id # 0-base to 1-base + ctg_end = ctg_start + chunk_size + else: + contig_length = 0 + with open(fai_fn, 'r') as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + + contig_name = columns[0] + if contig_name != ctg_name: + continue + contig_length = int(columns[1]) + chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num + ctg_start = chunk_size * chunk_id # 0-base to 1-base + ctg_end = ctg_start + chunk_size + + # for illumina platform, the reads alignment is acquired after reads realignment from ReadsRealign.py + if platform == 'ilmn' and bam_file_path != "PIPE": + bam_file_path += '.{}_{}'.format(ctg_start, ctg_end) + add_read_regions = False + if bam_file_path == "PIPE": + add_read_regions = False + + if need_phasing and phased_vcf_fn and os.path.exists(phased_vcf_fn): + # if need_phasing option enables, scan the phased vcf file and store the heterozygous snp candidates from each phase set + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (phased_vcf_fn))) + for row in unzip_process.stdout: + row = row.rstrip() + if row[0] == '#': + continue + columns = row.strip().split('\t') + contig_name = columns[0] + if ctg_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + if ctg_start and ctg_end: + if pos < ctg_start - phasing_window_size or pos > ctg_end + phasing_window_size: + continue + ref_base = columns[3] + alt_base = columns[4] + genotype_info = columns[9].split(':') + genotype, phase_set = genotype_info[0], genotype_info[-1] + if '|' not in genotype: # unphasable + continue + genotype = ('1' if genotype == '0|1' else '2') + # need in phasing_window + hete_snp_pos_dict[pos] = Position(pos=pos, ref_base=ref_base, alt_base=alt_base, + genotype=int(genotype), phase_set=phase_set) + hete_snp_tree.addi(begin=pos - extend_bp, end=pos + extend_bp + 1) + + # preparation for candidates near variants + need_phasing_pos_set = set([item for item in need_phasing_pos_set if item >= ctg_start and item <= ctg_end]) + # 1-based regions [start, end] (start and end inclusive) + ref_regions = [] + reads_regions = [] + + is_ctg_name_given = ctg_name is not None + is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None + extend_start, extend_end = None, None + if is_ctg_range_given: + extend_start = ctg_start - (phasing_window_size if need_phasing else no_of_positions) + extend_end = ctg_end + (phasing_window_size if need_phasing else no_of_positions) + reads_regions.append(region_from(ctg_name=ctg_name, ctg_start=extend_start, ctg_end=extend_end)) + reference_start, reference_end = ctg_start - param.expandReferenceRegion, ctg_end + param.expandReferenceRegion + reference_start = 1 if reference_start < 1 else reference_start + ref_regions.append(region_from(ctg_name=ctg_name, ctg_start=reference_start, ctg_end=reference_end)) + elif is_ctg_name_given: + reads_regions.append(region_from(ctg_name=ctg_name)) + ref_regions.append(region_from(ctg_name=ctg_name)) + reference_start = 1 + + reference_sequence = reference_sequence_from( + samtools_execute_command=samtools_execute_command, + fasta_file_path=fasta_file_path, + regions=ref_regions + ) + if reference_sequence is None or len(reference_sequence) == 0: + sys.exit("[ERROR] Failed to load reference sequence from file ({}).".format(fasta_file_path)) + + phasing_option = " --output-extra HP" if phasing_info_in_bam else " " + mq_option = ' --min-MQ {}'.format(min_mapping_quality) + bq_option = ' --min-BQ {}'.format(min_base_quality) + # pileup bed first + bed_option = ' -l {}'.format( + extend_bed) if is_extend_bed_file_given and platform != 'ilmn' else "" + bed_option = ' -l {}'.format(full_aln_regions) if is_full_aln_regions_given and platform != 'ilmn' else bed_option + flags_option = ' --excl-flags {}'.format(param.SAMTOOLS_VIEW_FILTER_FLAG) + max_depth_option = ' --max-depth {}'.format(args.max_depth) if args.max_depth > 0 else "" + reads_regions_option = ' -r {}'.format(" ".join(reads_regions)) if add_read_regions else "" + # print (add_read_regions, ctg_start, ctg_end, reference_start) + stdin = None if bam_file_path != "PIPE" else sys.stdin + bam_file_path = bam_file_path if bam_file_path != "PIPE" else "-" + samtools_command = "{} mpileup {} --reverse-del --output-QNAME --output-MQ".format(samtools_execute_command, + bam_file_path) + \ + reads_regions_option + phasing_option + mq_option + bq_option + bed_option + flags_option + max_depth_option + samtools_mpileup_process = subprocess_popen( + shlex.split(samtools_command), stdin=stdin) + + if not unify_repre: + if tensor_can_output_path != "PIPE": + tensor_can_fpo = open(tensor_can_output_path, "wb") + tensor_can_fp = subprocess_popen(shlex.split("{} -c".format(args.zstd)), stdin=PIPE, stdout=tensor_can_fpo) + else: + tensor_can_fp = TensorStdout(sys.stdout) + else: + if unify_repre_fn != "PIPE": + label_fp = open(unify_repre_fn, 'w') + else: + label_fp = sys.stdout + if alt_fn: + output_alt_fn = alt_fn + alt_fp = open(output_alt_fn, 'w') + + hap_dict = defaultdict(int) + haplotag_dict = defaultdict(int) + pileup_dict = defaultdict(str) + phasing_read_seq = defaultdict(PhasingRead) + extend_bp_distance = phasing_window_size if need_phasing else no_of_positions + param.extend_bp + confident_bed_tree = bed_tree_from(bed_file_path=confident_bed_fn, + contig_name=ctg_name, + bed_ctg_start=extend_start, + bed_ctg_end=extend_end) + + extend_bed_tree = bed_tree_from(bed_file_path=extend_bed, + contig_name=ctg_name, + bed_ctg_start=extend_start, + bed_ctg_end=extend_end) + + def samtools_pileup_generator_from(samtools_mpileup_process): + need_phasing_pos_list = sorted(list(need_phasing_pos_set)) + current_pos_index = 0 + has_pileup_candidates = len(need_phasing_pos_set) + for row in samtools_mpileup_process.stdout: # chr position N depth seq BQ read_name mapping_quality phasing_info + columns = row.strip().split('\t') + pos = int(columns[1]) + # pos that near bed region should include some indel cover in bed + pass_extend_bed = not is_extend_bed_file_given or is_region_in(extend_bed_tree, + ctg_name, pos - 1, + pos + 1) + pass_ctg_range = not ctg_start or (pos >= ctg_start and pos <= ctg_end) + if not has_pileup_candidates and not pass_extend_bed and pass_ctg_range: + continue + pileup_bases = columns[4] + raw_base_quality = columns[5] + read_name_list = columns[6].split(',') + raw_mapping_quality = columns[7] + reference_base = evc_base_from(reference_sequence[pos - reference_start].upper()) # ev + base_list, depth, pass_af, af = decode_pileup_bases(pileup_bases=pileup_bases, + reference_base=reference_base, + minimum_af_for_candidate=minimum_af_for_candidate, + minimum_snp_af_for_candidate=minimum_snp_af_for_candidate, + minimum_indel_af_for_candidate=minimum_indel_af_for_candidate, + has_pileup_candidates=has_pileup_candidates) + + if phasing_info_in_bam: + phasing_info = columns[8].split(',') + # https://github.com/HKU-BAL/Clair3/issues/32, skip adding phase info when BAM phase info lacks + # add read name list size check in following steps + if len(read_name_list) != len(phasing_info): + continue + else: + for hap_idx, hap in enumerate(phasing_info): + if hap in '12' and read_name_list[hap_idx] not in hap_dict: + hap_dict[read_name_list[hap_idx]] = int(hap) + + if len(read_name_list) != len(base_list): + continue + + if not is_known_vcf_file_provided and not has_pileup_candidates and reference_base in 'ACGT' and ( + pass_af and depth >= min_coverage): + need_phasing_pos_list.append(pos) + + if is_known_vcf_file_provided and not has_pileup_candidates and pos in known_variants_set: + need_phasing_pos_list.append(pos) + + pileup_dict[pos] = Position(pos=pos, + ref_base=reference_base, + read_name_list=read_name_list, + base_list=base_list, + raw_base_quality=raw_base_quality, + raw_mapping_quality=raw_mapping_quality, + af=af, + depth=depth) + + overlap_hete_region = hete_snp_tree.at(pos) + if need_phasing and len(overlap_hete_region): + for read_name, base_info, in zip(read_name_list, base_list): + query_base, indel_base = base_info + query_base = query_base.upper() + ins_base = "" if indel_base == "" or indel_base[0] != '+' else indel_base.upper() + for region in overlap_hete_region: + hete_center = region.begin + param.extend_bp + phasing_read_seq[read_name].read_seq[hete_center] += query_base + ins_base + + if current_pos_index < len(need_phasing_pos_list) and pos - need_phasing_pos_list[ + current_pos_index] > extend_bp_distance: + yield need_phasing_pos_list[current_pos_index] + for pre_pos in sorted(pileup_dict.keys()): + if need_phasing_pos_list[current_pos_index] - pre_pos > extend_bp_distance: + del pileup_dict[pre_pos] + else: + break + current_pos_index += 1 + while current_pos_index != len(need_phasing_pos_list): + yield need_phasing_pos_list[current_pos_index] + for pre_pos in sorted(pileup_dict.keys()): + if need_phasing_pos_list[current_pos_index] - pre_pos > extend_bp_distance: + del pileup_dict[pre_pos] + else: + break + current_pos_index += 1 + + yield None + + samtools_pileup_generator = samtools_pileup_generator_from(samtools_mpileup_process) + + for hete_pos in hete_snp_pos_dict: + if need_phasing and hete_snp_pos_dict[hete_pos].ref_seq is None: + hete_snp_pos_dict[hete_pos].ref_seq, hete_snp_pos_dict[hete_pos].alt_seq = update_hete_ref(pos=hete_pos, + reference_sequence=reference_sequence, + reference_start=reference_start, + extend_bp=extend_bp, + hete_snp_pos_dict=hete_snp_pos_dict[hete_pos].alt_base) + while True: + pos = next(samtools_pileup_generator) + if pos is None: + break + if pos not in pileup_dict: + continue + if need_phasing: + """ + Haplotag reads haplotype when create full alignment tensor, which is faster than whatshap haplotag while + occupy more memory. Whole haplotag logic follow whatshap haplotag function. + """ + + from Levenshtein import distance as edit_distance + need_phasing_read_set = set(pileup_dict[pos].read_name_list) + for read_name in need_phasing_read_set: + haplotype_costs = defaultdict(int) + if haplotag_dict[read_name] != 0 or len(phasing_read_seq[read_name].read_seq) == 0: + continue + for overlp_pos, query_seq in phasing_read_seq[read_name].read_seq.items(): + if query_seq == "": + continue + ref_seq = hete_snp_pos_dict[overlp_pos].ref_seq + alt_seq = hete_snp_pos_dict[overlp_pos].alt_seq + distance_ref = edit_distance(query_seq, ref_seq) + distance_alt = edit_distance(query_seq, alt_seq) + hap_match = 0 + if distance_alt > distance_ref: + hap_match = 1 + elif distance_alt < distance_ref: + hap_match = 2 + else: + # skip read with unkown hap type + continue + hete_hap_type = hete_snp_pos_dict[overlp_pos].genotype + if hap_match == hete_hap_type: + haplotype_costs[hete_snp_pos_dict[overlp_pos].phase_set] += 1 + else: + haplotype_costs[hete_snp_pos_dict[overlp_pos].phase_set] -= 1 + + haplotype_costs = sorted(list(haplotype_costs.items()), key=lambda x: -abs(x[1])) + if len(haplotype_costs) == 0 or haplotype_costs[0][1] == 0: + # no hap support or having same score + continue + # release memory resource + del phasing_read_seq[read_name] + phaseset, quality = haplotype_costs[0] + haplotype = 1 if quality > 0 else 2 + haplotag_dict[read_name] = haplotype + # skip if two scores are the same + + sorted_read_name_list = sorted_by_hap_read_name(pos, haplotag_dict, pileup_dict, hap_dict, platform) + ref_seq = reference_sequence[ + pos - reference_start - flanking_base_num: pos - reference_start + flanking_base_num + 1].upper() + + if not unify_repre: + tensor, alt_info = generate_tensor(ctg_name=ctg_name, + center_pos=pos, + sorted_read_name_list=sorted_read_name_list, + pileup_dict=pileup_dict, + ref_seq=ref_seq, + reference_sequence=reference_sequence, + reference_start=reference_start, + platform=platform, + confident_bed_tree=confident_bed_tree, + add_no_phasing_data_training=add_no_phasing_data_training) + if not tensor: + continue + + tensor_can_fp.stdin.write(tensor) + tensor_can_fp.stdin.write("\n") + if alt_fn: + alt_info = alt_info.replace('-', '\t') + alt_fp.write('\t'.join([ctg_name + ' ' + str(pos), alt_info]) + '\n') + + if unify_repre: + label_info = get_alt_info(center_pos=pos, + pileup_dict=pileup_dict, + ref_seq=ref_seq, + reference_sequence=reference_sequence, + reference_start=reference_start, + hap_dict=hap_dict) + label_fp.write('\t'.join([ctg_name + ' ' + str(pos), label_info]) + '\n') + + samtools_mpileup_process.stdout.close() + samtools_mpileup_process.wait() + + if not unify_repre and tensor_can_output_path != "PIPE": + tensor_can_fp.stdin.close() + tensor_can_fp.wait() + tensor_can_fpo.close() + + if alt_fn: + alt_fp.close() + + if unify_repre and unify_repre_fn != "PIPE": + label_fp.close() + + +def main(): + parser = ArgumentParser(description="Generate variant candidate tensors using phased full-alignment") + + parser.add_argument('--platform', type=str, default='ont', + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--bam_fn', type=str, default="input.bam", required=True, + help="Sorted BAM file input, required") + + parser.add_argument('--ref_fn', type=str, default="ref.fa", required=True, + help="Reference fasta file input, required") + + parser.add_argument('--tensor_can_fn', type=str, default="PIPE", + help="Tensor output, stdout by default, default: %(default)s") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") + + parser.add_argument('--min_af', type=float, default=0.08, + help="Minimum allele frequency for both SNP and Indel for a site to be considered as a condidate site, default: %(default)f") + + parser.add_argument('--snp_min_af', type=float, default=0.08, + help="Minimum snp allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--indel_min_af', type=float, default=0.15, + help="Minimum indel allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed, required if --bed_fn is not defined") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--bed_fn', type=str, default=None, + help="Call variant only in the provided regions. Will take an intersection if --ctgName and/or (--ctgStart, --ctgEnd) are set") + + parser.add_argument('--gvcf', type=str2bool, default=False, + help="Enable GVCF output, default: disabled") + + parser.add_argument('--sampleName', type=str, default="SAMPLE", + help="Define the sample name to be shown in the GVCF file") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required. default: %(default)s") + + # options for advanced users + parser.add_argument('--minCoverage', type=float, default=param.min_coverage, + help="EXPERIMENTAL: Minimum coverage required to call a variant, default: %(default)f") + + parser.add_argument('--minMQ', type=int, default=param.min_mq, + help="EXPERIMENTAL: If set, reads with mapping quality with <$minMQ are filtered, default: %(default)d") + + parser.add_argument('--minBQ', type=int, default=param.min_bq, + help="EXPERIMENTAL: If set, bases with base quality with <$minBQ are filtered, default: %(default)d") + + parser.add_argument('--max_depth', type=int, default=param.max_depth, + help="EXPERIMENTAL: Maximum full alignment depth to be processed. default: %(default)s") + + # options for debug purpose + parser.add_argument('--phasing_info_in_bam', action='store_true', + help="DEBUG: Skip phasing and use the phasing info provided in the input BAM (HP tag), default: False") + + parser.add_argument('--phasing_window_size', type=int, default=param.phasing_window_size, + help="DEBUG: The window size for read phasing") + + parser.add_argument('--extend_bed', nargs='?', action="store", type=str, default=None, + help="DEBUG: Extend the regions in the --bed_fn by a few bp for tensor creation, default extend 16bp") + + parser.add_argument('--indel_fn', type=str, default=None, + help="DEBUG: Output all alternative indel cigar for debug purpose") + + parser.add_argument('--base_err', default=0.001, type=float, + help='DEBUG: Estimated base error rate in gvcf option, default: %(default)f') + + parser.add_argument('--gq_bin_size', default=5, type=int, + help='DEBUG: Default gq bin size for merge non-variant block in gvcf option, default: %(default)d') + + parser.add_argument('--bp_resolution', action='store_true', + help="DEBUG: Enable bp resolution for GVCF, default: disabled") + + # options for internal process control + ## Path to the 'zstd' compression + parser.add_argument('--zstd', type=str, default=param.zstd, + help=SUPPRESS) + + ## Test in specific candidate position. Only for testing + parser.add_argument('--test_pos', type=int, default=0, + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + ## Only call variant in phased vcf file + parser.add_argument('--phased_vcf_fn', type=str, default=None, + help=SUPPRESS) + + ## Apply no phased data in training. Only works in data training, default: False + parser.add_argument('--add_no_phasing_data_training', action='store_true', + help=SUPPRESS) + + ## Output representation unification infos, which refines training labels + parser.add_argument('--unify_repre', action='store_true', + help=SUPPRESS) + + ## Path of representation unification output + parser.add_argument('--unify_repre_fn', type=str, default=None, + help=SUPPRESS) + + ## Provide the regions to be included in full-alignment based calling + parser.add_argument('--full_aln_regions', type=str, default=None, + help=SUPPRESS) + + ## Use Clair3's own phasing module for read level phasing when creating tensor, compared to using Whatshap, speed is faster but has higher memory footprint, default: False + parser.add_argument('--need_phasing', action='store_true', + help=SUPPRESS) + + ## Apply read realignment for illumina platform. Greatly boost indel performance in trade of running time + parser.add_argument('--need_realignment', action='store_true', + help=SUPPRESS) + + + args = parser.parse_args() + + CreateTensorFullAlignment(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/CreateTensorPileup.py b/benchmarks/nn-variant/Clair3/preprocess/CreateTensorPileup.py new file mode 100644 index 0000000..63e3095 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/CreateTensorPileup.py @@ -0,0 +1,558 @@ +import sys +import shlex +import logging +from subprocess import PIPE +from os.path import isfile +from argparse import ArgumentParser, SUPPRESS +from collections import Counter, defaultdict + +import shared.param_p as param +from shared.interval_tree import bed_tree_from, is_region_in +from preprocess.utils import variantInfoCalculator +from shared.utils import subprocess_popen, file_path_from, IUPAC_base_to_num_dict as BASE2NUM, region_from, \ + reference_sequence_from, str2bool, vcf_candidates_from, log_error + + +logging.getLogger().setLevel(logging.INFO) +BASES = set(list(BASE2NUM.keys()) + ["-"]) +flanking_base_num = param.flankingBaseNum +sliding_window_size = no_of_positions = 2 * flanking_base_num + 1 + +BASE2NUMBER = dict(zip( + "ACGTURYSWKMBDHVN-", + (0, 1, 2, 3, 3, 0, 1, 1, 0, 2, 0, 1, 0, 0, 0, 0, 4) +)) +channel = param.channel +channel_size = len(channel) +BASE2INDEX = dict(zip(channel, tuple(range(channel_size)))) + + +def phredscore2raw_score(qual): + return ord(qual) - 33 + + +def evc_base_from(base): + if base == 'N': + return 'A' + elif base == 'n': + return 'a' + elif base in 'ACGTacgt': + return base + elif base.isupper(): + return 'A' + else: + return 'a' + + +class CandidateStdout(object): + def __init__(self, handle): + self.stdin = handle + + def __del__(self): + self.stdin.close() + + +def generate_tensor(pos, pileup_bases, reference_sequence, reference_start, reference_base, minimum_af_for_candidate, + minimum_snp_af_for_candidate, minimum_indel_af_for_candidate, platform, fast_mode, call_snp_only): + """ + Generate pileup input tensor + pos: center position for pileup generation, default no_of_positions = flankingBaseNum + 1 + flankingBaseNum + pileup_bases: pileup bases list of each read in specific candidate position from samtools mpileup 1.10 + reference_sequence: the whole reference sequence index by contig:start-end. 0-based. + reference_base: upper reference base for cigar calculation. + reference_start: 0 based reference start position for region querying. + minimum_af_for_candidate: default minimum alleic frequency for candidate filtering, filter if below specific thredshold. + """ + + reference_base = evc_base_from(reference_base) + pileup_tensor = [0] * channel_size + base_idx = 0 + base_list = [] + alt_dict = defaultdict(int) + pileup_dict = defaultdict(int) + while base_idx < len(pileup_bases): + base = pileup_bases[base_idx] + if base in "ACGTNacgtn#*": + base_list.append(base) + elif base == '+' or base == '-': + base_idx += 1 + advance = 0 + while True: + num = pileup_bases[base_idx] + if num.isdigit(): + advance = advance * 10 + int(num) + base_idx += 1 + else: + break + base_list.append(base + pileup_bases[base_idx: base_idx + advance]) + base_idx += advance - 1 + + elif base == '^': # start of a read, next character is mapping quality + base_idx += 1 + # elif base == '$': # end of read with '$' symbol + base_idx += 1 + base_counter = Counter(base_list) + depth, max_ins_0, max_del_0, max_ins_1, max_del_1 = 0, 0, 0, 0, 0 + max_del_length = 0 + for key, count in base_counter.items(): + if key[0] == '+': + alt_dict['I' + reference_base + key[1:].upper()] += count + pileup_dict['I'] += count + # two strand + if key[1] in 'ACGTN*': + pileup_tensor[BASE2INDEX["I"]] += count + max_ins_0 = max(max_ins_0, count) + else: + pileup_tensor[BASE2INDEX["i"]] += count + max_ins_1 = max(max_ins_1, count) + elif key[0] == '-': + del_base = reference_sequence[pos - reference_start + 1: pos - reference_start + len(key[1:]) + 1] + alt_dict['D' + del_base] += count + pileup_dict['D'] += count + max_del_length = max(max_del_length, len(del_base)) + # two strand + if key[1] in 'N*ACGT': + pileup_tensor[BASE2INDEX["D"]] += count + max_del_0 = max(max_del_0, count) + else: + pileup_tensor[BASE2INDEX["d"]] += count + max_del_1 = max(max_del_1, count) + else: + if key.upper() in 'ACGT': + pileup_dict[key.upper()] += count + depth += count + if key.upper() != reference_base: + alt_dict['X' + key.upper()] += count + pileup_tensor[BASE2INDEX[key]] += count + elif key in '#*': + pileup_tensor[BASE2INDEX[key]] += count + depth += count + pileup_tensor[BASE2INDEX['I1']] = max_ins_0 + pileup_tensor[BASE2INDEX['i1']] = max_ins_1 + pileup_tensor[BASE2INDEX['D1']] = max_del_0 + pileup_tensor[BASE2INDEX['d1']] = max_del_1 + denominator = depth if depth > 0 else 1 + pileup_list = sorted(list(pileup_dict.items()), key=lambda x: x[1], reverse=True) + + pass_snp_af = False + pass_indel_af = False + fast_mode = platform == 'ont' and fast_mode + + minimum_snp_af_for_candidate = minimum_snp_af_for_candidate if minimum_snp_af_for_candidate > 0 else param.min_af + minimum_snp_af_for_candidate = max(minimum_snp_af_for_candidate, param.min_af_dict[platform]) if fast_mode else minimum_snp_af_for_candidate + minimum_indel_af_for_candidate = minimum_indel_af_for_candidate if minimum_indel_af_for_candidate > 0 else param.min_af_dict[platform] + + # check whether first non reference candidate in the first position + pass_af = len(pileup_list) and (pileup_list[0][0] != reference_base) + + for item, count in pileup_list: + if item == reference_base: + continue + elif item[0] in 'ID': + pass_indel_af = (pass_indel_af or (float(count) / denominator >= minimum_indel_af_for_candidate)) + continue + if fast_mode: + pass_snp_af = pass_snp_af or (float(count) / denominator >= minimum_snp_af_for_candidate and count >= 4) + else: + pass_snp_af = pass_snp_af or (float(count) / denominator >= minimum_snp_af_for_candidate) + + af = (float(pileup_list[1][1]) / denominator) if len(pileup_list) > 1 else 0.0 + af = (float(pileup_list[0][1]) / denominator) if len(pileup_list) >= 1 and pileup_list[0][ + 0] != reference_base else af + + pileup_tensor[BASE2INDEX[reference_base]] = -1 * sum([pileup_tensor[BASE2INDEX[item]] for item in 'ACGT']) + pileup_tensor[BASE2INDEX[reference_base.lower()]] = -1 * sum([pileup_tensor[BASE2INDEX[item]] for item in 'acgt']) + + pass_af = pass_snp_af if call_snp_only else (pass_af or pass_snp_af or pass_indel_af) + + # add a return: base_counter for generating GVCF + return pileup_tensor, alt_dict, af, depth, pass_af, pileup_list, max_del_length + + +class TensorStdout(object): + def __init__(self, handle): + self.stdin = handle + + def __del__(self): + self.stdin.close() + + +def CreateTensorPileup(args): + """ + Create pileup tensor for pileup model training or calling. + Use slide window to scan the whole candidate regions, keep all candidates over specific minimum allelic frequency + and minimum depth, use samtools mpileup to store pileup info for pileup tensor generation. Only scan candidate + regions once, we could directly get all variant candidates directly. + """ + ctg_start = args.ctgStart + ctg_end = args.ctgEnd + fasta_file_path = args.ref_fn + ctg_name = args.ctgName + samtools_execute_command = args.samtools + bam_file_path = args.bam_fn + chunk_id = args.chunk_id - 1 if args.chunk_id else None # 1-base to 0-base + chunk_num = args.chunk_num + tensor_can_output_path = args.tensor_can_fn + minimum_af_for_candidate = args.min_af + minimum_snp_af_for_candidate = args.snp_min_af + minimum_indel_af_for_candidate = args.indel_min_af + min_coverage = args.minCoverage + platform = args.platform + confident_bed_fn = args.bed_fn + is_confident_bed_file_given = confident_bed_fn is not None + alt_fn = args.indel_fn + extend_bed = args.extend_bed + is_extend_bed_file_given = extend_bed is not None + min_mapping_quality = args.minMQ + min_base_quality = args.minBQ + fast_mode = args.fast_mode + vcf_fn = args.vcf_fn + is_known_vcf_file_provided = vcf_fn is not None + call_snp_only = args.call_snp_only + + global test_pos + test_pos = None + + # 1-based regions [start, end] (start and end inclusive) + ref_regions = [] + reads_regions = [] + known_variants_set = set() + tree, bed_start, bed_end = bed_tree_from(bed_file_path=extend_bed, + contig_name=ctg_name, + return_bed_region=True) + + fai_fn = file_path_from(fasta_file_path, suffix=".fai", exit_on_not_found=True, sep='.') + if not is_confident_bed_file_given and chunk_id is not None: + contig_length = 0 + with open(fai_fn, 'r') as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + + contig_name = columns[0] + if contig_name != ctg_name: + continue + contig_length = int(columns[1]) + chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num + ctg_start = chunk_size * chunk_id # 0-base to 1-base + ctg_end = ctg_start + chunk_size + + if is_confident_bed_file_given and chunk_id is not None: + chunk_size = (bed_end - bed_start) // chunk_num + 1 if (bed_end - bed_start) % chunk_num else (bed_end - bed_start) // chunk_num + ctg_start = bed_start + 1 + chunk_size * chunk_id # 0-base to 1-base + ctg_end = ctg_start + chunk_size + + if is_known_vcf_file_provided and chunk_id is not None: + known_variants_list = vcf_candidates_from(vcf_fn=vcf_fn, contig_name=ctg_name) + total_variants_size = len(known_variants_list) + chunk_variants_size = total_variants_size // chunk_num if total_variants_size % chunk_num == 0 else total_variants_size // chunk_num + 1 + chunk_start_pos = chunk_id * chunk_variants_size + known_variants_set = set(known_variants_list[chunk_start_pos: chunk_start_pos + chunk_variants_size]) + if len(known_variants_set) == 0: + return + ctg_start, ctg_end = min(known_variants_set), max(known_variants_set) + + is_ctg_name_given = ctg_name is not None + is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None + if is_ctg_range_given: + extend_start = ctg_start - no_of_positions + extend_end = ctg_end + no_of_positions + reads_regions.append(region_from(ctg_name=ctg_name, ctg_start=extend_start, ctg_end=extend_end)) + reference_start, reference_end = ctg_start - param.expandReferenceRegion, ctg_end + param.expandReferenceRegion + reference_start = 1 if reference_start < 1 else reference_start + ref_regions.append(region_from(ctg_name=ctg_name, ctg_start=reference_start, ctg_end=reference_end)) + elif is_ctg_name_given: + reads_regions.append(region_from(ctg_name=ctg_name)) + ref_regions.append(region_from(ctg_name=ctg_name)) + reference_start = 1 + + reference_sequence = reference_sequence_from( + samtools_execute_command=samtools_execute_command, + fasta_file_path=fasta_file_path, + regions=ref_regions + ) + + if reference_sequence is None or len(reference_sequence) == 0: + sys.exit(log_error("[ERROR] Failed to load reference sequence from file ({}).".format(fasta_file_path))) + + if is_confident_bed_file_given and ctg_name not in tree: + sys.exit(log_error("[ERROR] ctg_name {} not exists in bed file({}).".format(ctg_name, confident_bed_fn))) + + # samtools mpileup options + # reverse-del: deletion in forward/reverse strand were marked as '*'/'#' + min_base_quality = 0 if args.gvcf else min_base_quality + max_depth = param.max_depth_dict[args.platform] if args.platform else args.max_depth + mq_option = ' --min-MQ {}'.format(min_mapping_quality) + bq_option = ' --min-BQ {}'.format(min_base_quality) + flags_option = ' --excl-flags {}'.format(param.SAMTOOLS_VIEW_FILTER_FLAG) + max_depth_option = ' --max-depth {}'.format(max_depth) + bed_option = ' -l {}'.format(extend_bed) if is_extend_bed_file_given else "" + gvcf_option = ' -a' if args.gvcf else "" + samtools_mpileup_process = subprocess_popen( + shlex.split( + "{} mpileup {} -r {} --reverse-del".format(samtools_execute_command, + bam_file_path, + " ".join(reads_regions), ) + + mq_option + bq_option + bed_option + flags_option + max_depth_option + gvcf_option)) + + if tensor_can_output_path != "PIPE": + tensor_can_fpo = open(tensor_can_output_path, "wb") + tensor_can_fp = subprocess_popen(shlex.split("{} -c".format(param.zstd)), stdin=PIPE, stdout=tensor_can_fpo) + else: + tensor_can_fp = TensorStdout(sys.stdout) + + # whether save all alternative information, only for debug mode + if alt_fn: + alt_fp = open(alt_fn, 'w') + + pos_offset = 0 + pre_pos = -1 + tensor = [[]] * sliding_window_size + candidate_position = [] + all_alt_dict = {} + depth_dict = {} + af_dict = {} + + # to generate gvcf, it is needed to record whole genome statistical information + if args.gvcf: + nonVariantCaller = variantInfoCalculator(gvcfWritePath=args.temp_file_dir, ref_path=args.ref_fn, + bp_resolution=args.bp_resolution, ctgName=ctg_name,sample_name='.'.join( + [args.sampleName, ctg_name, str(ctg_start), str(ctg_end)]), p_err=args.base_err, + gq_bin_size=args.gq_bin_size) + + confident_bed_tree = bed_tree_from(bed_file_path=confident_bed_fn, contig_name=ctg_name, bed_ctg_start=extend_start, + bed_ctg_end=extend_end) + + + empty_pileup_flag = True + for row in samtools_mpileup_process.stdout: + empty_pileup_flag = False + columns = row.strip().split('\t',maxsplit=5) + pos = int(columns[1]) + pileup_bases = columns[4] + reference_base = reference_sequence[pos - reference_start].upper() + valid_reference_flag = True + within_flag = True + if args.gvcf: + if not valid_reference_flag: + nonVariantCaller.make_gvcf_online({}, push_current=True) + if ctg_start != None and ctg_end != None: + within_flag = pos >= ctg_start and pos <= ctg_end + elif ctg_start != None and ctg_end == None: + within_flag = pos >= ctg_start + elif ctg_start == None and ctg_end != None: + within_flag = pos <= ctg_end + else: + within_flag = True + if columns[3] == '0' and within_flag and valid_reference_flag: + cur_site_info = {'chr': columns[0], 'pos': pos, 'ref': reference_base, 'n_total': 0, 'n_ref': 0} + nonVariantCaller.make_gvcf_online(cur_site_info) + continue + + # start with a new region, clear all sliding windows cache, avoid memory occupation + if pre_pos + 1 != pos: + pos_offset = 0 + tensor = [[]] * sliding_window_size + candidate_position = [] + pre_pos = pos + + # a condition to skip some positions creating tensor,but return allele summary + # allele count function + pileup_tensor, alt_dict, af, depth, pass_af, pileup_list, max_del_length = generate_tensor(pos=pos, + pileup_bases=pileup_bases, + reference_sequence=reference_sequence, + reference_start=reference_start, + reference_base=reference_base, + minimum_af_for_candidate=minimum_af_for_candidate, + minimum_snp_af_for_candidate=minimum_snp_af_for_candidate, + minimum_indel_af_for_candidate=minimum_indel_af_for_candidate, + platform=platform, + fast_mode=fast_mode, + call_snp_only=call_snp_only) + if args.gvcf and within_flag and valid_reference_flag: + cur_n_total = 0 + cur_n_ref = 0 + for _key, _value in pileup_list: + if (_key == reference_base): + cur_n_ref = _value + cur_n_total += _value + + cur_site_info = {'chr': columns[0], 'pos': pos, 'ref': reference_base, 'n_total': cur_n_total, + 'n_ref': cur_n_ref} + nonVariantCaller.make_gvcf_online(cur_site_info) + + pass_confident_bed = not is_confident_bed_file_given or is_region_in(tree=confident_bed_tree, + contig_name=ctg_name, + region_start=pos - 1, + region_end=pos + max_del_length + 1) # 0-based + if (pass_confident_bed and reference_base in 'ACGT' and (pass_af and depth >= min_coverage) and not is_known_vcf_file_provided) or ( + is_known_vcf_file_provided and pos in known_variants_set): + candidate_position.append(pos) + all_alt_dict[pos] = alt_dict + depth_dict[pos] = depth + af_dict[pos] = af + tensor[pos_offset] = pileup_tensor + + # save pileup tensor for each candidate position with nearby flanking_base_num bp distance + pos_offset = (pos_offset + 1) % sliding_window_size + if len(candidate_position) and pos - candidate_position[0] == flanking_base_num: + center = candidate_position.pop(0) + has_empty_tensor = sum([True for item in tensor if not len(item)]) + if not has_empty_tensor: + depth = depth_dict[center] + ref_seq = reference_sequence[center - ( + flanking_base_num) - reference_start: center + flanking_base_num + 1 - reference_start] + concat_tensor = tensor[pos_offset:] + tensor[0:pos_offset] + + alt_info = str(depth) + '-' + ' '.join( + [' '.join([item[0], str(item[1])]) for item in list(all_alt_dict[center].items())]) + l = "%s\t%d\t%s\t%s\t%s" % ( + ctg_name, + center, + ref_seq, + " ".join(" ".join("%d" % x for x in innerlist) for innerlist in concat_tensor), + alt_info + ) + tensor_can_fp.stdin.write(l) + tensor_can_fp.stdin.write("\n") + if alt_fn: + alt_info = ' '.join( + [' '.join([item[0], str(item[1])]) for item in list(all_alt_dict[center].items())]) + alt_fp.write( + '\t'.join([ctg_name + ' ' + str(center), str(depth), alt_info, str(af_dict[center])]) + '\n') + del all_alt_dict[center], depth_dict[center], af_dict[center] + + if args.gvcf and len(nonVariantCaller.current_block) != 0: + nonVariantCaller.write_to_gvcf_batch(nonVariantCaller.current_block, nonVariantCaller.cur_min_DP, + nonVariantCaller.cur_raw_gq) + + if args.gvcf and empty_pileup_flag: + nonVariantCaller.write_empty_pileup(ctg_name,ctg_start,ctg_end) + if args.gvcf: + nonVariantCaller.close_vcf_writer() + + samtools_mpileup_process.stdout.close() + samtools_mpileup_process.wait() + + if tensor_can_output_path != "PIPE": + tensor_can_fp.stdin.close() + tensor_can_fp.wait() + tensor_can_fpo.close() + + if alt_fn: + alt_fp.close() + + +def main(): + parser = ArgumentParser(description="Generate variant candidate tensors using pileup") + + parser.add_argument('--platform', type=str, default='ont', + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--bam_fn', type=str, default="input.bam", required=True, + help="Sorted BAM file input, required") + + parser.add_argument('--ref_fn', type=str, default="ref.fa", required=True, + help="Reference fasta file input, required") + + parser.add_argument('--tensor_can_fn', type=str, default="PIPE", + help="Tensor output, stdout by default, default: %(default)s") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") + + parser.add_argument('--min_af', type=float, default=0.08, + help="Minimum allele frequency for both SNP and Indel for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--snp_min_af', type=float, default=0.08, + help="Minimum snp allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--indel_min_af', type=float, default=0.15, + help="Minimum indel allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed, required if --bed_fn is not defined") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--bed_fn', type=str, default=None, + help="Call variant only in the provided regions. Will take an intersection if --ctgName and/or (--ctgStart, --ctgEnd) are set") + + parser.add_argument('--gvcf', type=str2bool, default=False, + help="Enable GVCF output, default: disabled") + + parser.add_argument('--sampleName', type=str, default="SAMPLE", + help="Define the sample name to be shown in the VCF file, default: %(default)s") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required. default: %(default)s") + + # options for advanced users + parser.add_argument('--fast_mode', type=str2bool, default=False, + help="EXPERIMENTAL: Skip variant candidates with AF <= 0.15, default: %(default)s") + + parser.add_argument('--minCoverage', type=float, default=2, + help="EXPERIMENTAL: Minimum coverage required to call a variant, default: %(default)f") + + parser.add_argument('--minMQ', type=int, default=param.min_mq, + help="EXPERIMENTAL: If set, reads with mapping quality with <$minMQ are filtered, default: %(default)d") + + parser.add_argument('--minBQ', type=int, default=param.min_bq, + help="EXPERIMENTAL: If set, bases with base quality with <$minBQ are filtered, default: %(default)d") + + parser.add_argument('--max_depth', type=int, default=param.max_depth, + help="EXPERIMENTAL: Maximum pileup depth to be processed. default: %(default)s") + + parser.add_argument('--call_snp_only', type=str2bool, default=False, + help="EXPERIMENTAL: Call candidates pass snp minimum AF only, ignore Indel candidates") + + # options for debug purpose + parser.add_argument('--extend_bed', type=str, default=None, + help="DEBUG: Extend the regions in the --bed_fn by a few bp for tensor creation, default extend 16bp") + + parser.add_argument('--temp_file_dir', type=str, default="./", + help="EXPERIMENTAL: The cache directory for storing temporary non-variant information if --gvcf is enabled, default: %(default)s") + + parser.add_argument('--indel_fn', type=str, default=None, + help="DEBUG: Output all alternative indel cigar for debug purpose") + + parser.add_argument('--base_err', default=param.base_err, type=float, + help='DEBUG: Estimated base error rate in gvcf option, default: %(default)f') + + parser.add_argument('--gq_bin_size', default=param.gq_bin_size, type=int, + help='DEBUG: Default gq bin size for merge non-variant block in gvcf option, default: %(default)d') + + parser.add_argument('--bp_resolution', action='store_true', + help="DEBUG: Enable bp resolution for GVCF, default: disabled") + + # options for internal process control + ## Path to the 'zstd' compression + parser.add_argument('--zstd', type=str, default=param.zstd, + help=SUPPRESS) + + ## Test in specific candidate position. Only for testing + parser.add_argument('--test_pos', type=int, default=0, + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + CreateTensorPileup(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/CreateTrainingTensor.py b/benchmarks/nn-variant/Clair3/preprocess/CreateTrainingTensor.py new file mode 100644 index 0000000..2f917cd --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/CreateTrainingTensor.py @@ -0,0 +1,307 @@ +import sys +import shlex +import subprocess +import signal +import random +import os +from os.path import dirname +from time import sleep +from argparse import ArgumentParser, SUPPRESS +import logging + +logging.getLogger().setLevel(logging.INFO) + + +from shared.command_options import ( + CommandOption, + CommandOptionWithNoValue, + ExecuteCommand, + command_string_from, + command_option_from +) +from shared.utils import file_path_from, executable_command_string_from, subprocess_popen, str2bool, log_warning +import shared.param_p as param + + +class InstancesClass(object): + def __init__(self): + self.create_tensor = None + self.compress_tensor = None + + def poll(self): + self.create_tensor.poll() + self.compress_tensor.poll() + + +c = InstancesClass() + + +def check_return_code(signum, frame): + c.poll() + if c.create_tensor.returncode != None and c.create_tensor.returncode != 0: + c.compress_tensor.kill() + sys.exit("CreateTensor.py exited with exceptions. Exiting...") + + if c.compress_tensor.returncode != None and c.compress_tensor.returncode != 0: + c.create_tensor.kill() + sys.exit("Tensor2Bin.py exited with exceptions. Exiting...") + + if ( + c.create_tensor.returncode == None or + c.compress_tensor.returncode == None + ): + signal.alarm(5) + + +def Run(args): + basedir = dirname(__file__) + + CTP_Bin = basedir + "/../clair3.py CreateTensorPileup" + CTFA_Bin = basedir + "/../clair3.py CreateTensorFullAlignment" + T2B_Bin = basedir + "/../clair3.py Tensor2Bin" + + if args.delay > 0: + delay = random.randrange(0, args.delay) + print("[INFO] Delay %d seconds before starting tensor creation ..." % (delay)) + sleep(delay) + + pypyBin = executable_command_string_from(args.pypy, exit_on_not_found=True) + pythonBin = executable_command_string_from(args.python, exit_on_not_found=True) + samtoolsBin = executable_command_string_from(args.samtools, exit_on_not_found=True) + + if args.pileup: + bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True) + else: + bam_fn = file_path_from(args.bam_fn) + if bam_fn is None or bam_fn == "": + print(log_warning( + "[WARNING] Skip full-alignment variant calling for empty full-alignment regions")) + return + ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True) + bed_fn = file_path_from(args.bed_fn) + vcf_fn = file_path_from(args.vcf_fn) + var_fn = file_path_from(args.var_fn, exit_on_not_found=True) + bin_fn = args.bin_fn + extend_bed = file_path_from(args.extend_bed) + full_aln_regions = file_path_from(args.full_aln_regions) + + platform = args.platform + if not platform or platform not in param.support_platform: + sys.exit("[ERROR] Provided platform are not in support platform list [ont, hifi, ilmn]") + + pileup = args.pileup + ctgName = args.ctgName + min_af = args.min_af if args.min_af else param.min_af_dict[platform] + snp_min_af = args.snp_min_af + indel_min_af = args.indel_min_af + + if ctgName is None: + sys.exit("--ctgName must be specified. You can call variants on multiple chromosomes simultaneously.") + + pileup_mode = command_option_from(args.pileup, 'pileup') + phasing_info_mode = command_option_from(args.phasing_info_in_bam, 'phasing_info_in_bam') + add_no_phasing_mode = command_option_from(args.add_no_phasing_data_training, 'add_no_phasing_data_training') + allow_duplicate_mode = command_option_from(args.allow_duplicate_chr_pos, 'allow_duplicate_chr_pos') + maximum_non_variant_ratio = CommandOption('maximum_non_variant_ratio', args.maximum_non_variant_ratio) + shuffle_mode = command_option_from(args.shuffle, 'shuffle') + + ctgStart = None + ctgEnd = None + chunk_id = None + chunk_num = None + if args.ctgStart is not None and args.ctgEnd is not None and int(args.ctgStart) <= int(args.ctgEnd): + ctgStart = CommandOption('ctgStart', args.ctgStart) + ctgEnd = CommandOption('ctgEnd', args.ctgEnd) + + if args.chunk_id is not None and args.chunk_num is not None and int(args.chunk_id) <= int(args.chunk_num): + chunk_id = CommandOption('chunk_id', args.chunk_id) + chunk_num = CommandOption('chunk_num', args.chunk_num) + + CT_Bin = CTP_Bin if pileup else CTFA_Bin + create_tensor_command_options = [ + pypyBin, + CT_Bin, + CommandOption('bam_fn', bam_fn), + CommandOption('ref_fn', ref_fn), + CommandOption('vcf_fn', vcf_fn), + CommandOption('ctgName', ctgName), + CommandOption('platform', platform), + CommandOption('samtools', samtoolsBin), + CommandOption('bed_fn', bed_fn), + CommandOption('extend_bed', extend_bed), + CommandOption('min_af', min_af), + CommandOption('snp_min_af', snp_min_af), + CommandOption('indel_min_af', indel_min_af), + ctgStart, + ctgEnd, + chunk_id, + chunk_num, + ] + + if not pileup: + create_tensor_command_options.append(phasing_info_mode) + create_tensor_command_options.append(add_no_phasing_mode) + create_tensor_command_options.append(CommandOption('full_aln_regions', full_aln_regions)) + + compress_tensor_command_options = [ + pythonBin, + T2B_Bin, + CommandOption('platform', platform), + CommandOption('var_fn', var_fn), + CommandOption('bin_fn', bin_fn), + CommandOption('bed_fn', bed_fn), + chunk_id, + chunk_num, + allow_duplicate_mode, + maximum_non_variant_ratio, + shuffle_mode, + ] + if pileup: + compress_tensor_command_options.append(pileup_mode) + + try: + c.create_tensor = subprocess_popen( + shlex.split(command_string_from(create_tensor_command_options)), + ) + + c.compress_tensor = subprocess_popen( + shlex.split(command_string_from(compress_tensor_command_options)), + stdin=c.create_tensor.stdout, stdout=sys.stderr + ) + except Exception as e: + print(e, file=sys.stderr) + sys.exit("Failed to start required processes. Exiting...") + + signal.signal(signal.SIGALRM, check_return_code) + signal.alarm(2) + + try: + c.compress_tensor.wait() + signal.alarm(0) + c.create_tensor.stdout.close() + c.create_tensor.wait() + except KeyboardInterrupt as e: + print("KeyboardInterrupt received when waiting at Tensor2Bin, terminating all scripts.") + try: + c.compress_tensor.terminate() + c.create_tensor.terminate() + except Exception as e: + print(e) + + raise KeyboardInterrupt + except Exception as e: + print("Exception received when waiting at CreateTensor, terminating all scripts.") + print(e) + try: + c.compress_tensor.terminate() + c.create_tensor.terminate() + except Exception as e: + print(e) + + raise e + + +def main(): + parser = ArgumentParser(description="Create tensor binaries for pileup or full-alignment training") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--bam_fn', type=str, default="bam.bam", required=True, + help="BAM file input, required") + + parser.add_argument('--ref_fn', type=str, default="ref.fa", required=True, + help="Reference fasta file input, required") + + parser.add_argument('--var_fn', type=str, default=None, required=True, + help="Unified VCF input filename, required") + + parser.add_argument('--bin_fn', type=str, default=None, required=True, + help="Compressed binary output filename, required") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed, required if --bed_fn is not defined") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--bed_fn', type=str, nargs='?', action="store", default=None, + help="Call variant only in the provided regions. Will take an intersection if --ctgName and/or (--ctgStart, --ctgEnd) are set") + + parser.add_argument('--min_af', type=float, default=None, + help="Minimum allele frequency for both SNP and Indel for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--snp_min_af', type=float, default=0.08, + help="Minimum SNP allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--indel_min_af', type=float, default=0.08, + help="Minimum Indel allele frequency for a site to be considered as a candidate site, default: %(default)f") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required, default: %(default)s") + + parser.add_argument('--pypy', type=str, default="pypy3", + help="Path to the 'pypy', pypy3 version >= 3.6 is required, default: %(default)s") + + parser.add_argument('--python', type=str, default="python3", + help="Path to the 'python3', default: %(default)s") + + # options for advanced users + parser.add_argument('--maximum_non_variant_ratio', default=None, type=float, + help='Maximum ratio of non-variants to variants, default: %(default)f') + + parser.add_argument('--extend_bed', nargs='?', action="store", type=str, default=None, + help="DEBUG: Extend the regions in the --bed_fn by a few bp for tensor creation, default extend 16bp") + + parser.add_argument('--phasing_info_in_bam', action='store_true', + help="DEBUG: Skip phasing and use the phasing info provided in the input BAM (HP tag), default: False") + + # options for internal process control, don't use any of them unless you are sure about the consequences + ## In pileup mode or not + parser.add_argument('--pileup', action='store_true', + help=SUPPRESS) + + ## Provide the regions to be included in full-alignment based calling + parser.add_argument('--full_aln_regions', type=str, default=None, + help=SUPPRESS) + + parser.add_argument('--add_no_phasing_data_training', action='store_true', + help=SUPPRESS) + + parser.add_argument('--allow_duplicate_chr_pos', action='store_true', + help=SUPPRESS) + + parser.add_argument('--shuffle', action='store_true', + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + ## Wait a short while for no more than a few seconds to start the job. This is to avoid starting multiple jobs simultaneously + ## that might use up the maximum number of threads allowed, because Tensorflow will create more threads than needed at the beginning of running the program + ## Obseleted after adding --tensorflow_threads defaulted at 4 + parser.add_argument('--delay', type=int, default=5, + help=SUPPRESS) + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + Run(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/GetTruth.py b/benchmarks/nn-variant/Clair3/preprocess/GetTruth.py new file mode 100644 index 0000000..e78f3a1 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/GetTruth.py @@ -0,0 +1,117 @@ +import sys +import shlex +from subprocess import PIPE +from argparse import ArgumentParser +from shared.utils import subprocess_popen, vcf_candidates_from + +class TruthStdout(object): + def __init__(self, handle): + self.stdin = handle + + def __del__(self): + self.stdin.close() + +def OutputVariant(args): + var_fn = args.var_fn + vcf_fn = args.vcf_fn + truth_vcf_fn = args.truth_vcf_fn + ctg_name = args.ctgName + ctg_start = args.ctgStart + ctg_end = args.ctgEnd + + truth_vcf_set = set() + variant_set = set() + if args.truth_vcf_fn is not None: + truth_vcf_set = set(vcf_candidates_from(vcf_fn=truth_vcf_fn, contig_name=ctg_name)) + if args.var_fn != "PIPE": + var_fpo = open(var_fn, "wb") + var_fp = subprocess_popen(shlex.split("gzip -c"), stdin=PIPE, stdout=var_fpo) + else: + var_fp = TruthStdout(sys.stdout) + + is_ctg_region_provided = ctg_start is not None and ctg_end is not None + + vcf_fp = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn))) + + for row in vcf_fp.stdout: + columns = row.strip().split() + if columns[0][0] == "#": + continue + + # position in vcf is 1-based + chromosome, position = columns[0], columns[1] + if chromosome != ctg_name: + continue + if is_ctg_region_provided and not (ctg_start <= int(position) <= ctg_end): + continue + reference, alternate, last_column = columns[3], columns[4], columns[-1] + # normal GetTruth + genotype = last_column.split(":")[0].replace("/", "|").replace(".", "0").split("|") + genotype_1, genotype_2 = genotype + + # 1000 Genome GetTruth (format problem) (no genotype is given) + if int(genotype_1) > int(genotype_2): + genotype_1, genotype_2 = genotype_2, genotype_1 + + #remove * to guarentee vcf match + if '*' in alternate: + alternate = alternate.split(',') + if int(genotype_1) + int(genotype_2) != 3 or len(alternate) != 2: + print ('error with variant represatation') + continue + alternate = ''.join([alt_base for alt_base in alternate if alt_base != '*']) + # * always have a genotype 1/2 + + genotype_1, genotype_2 = '0', '1' + + variant_set.add(int(position)) + var_fp.stdin.write(" ".join((chromosome, position, reference, alternate, genotype_1, genotype_2))) + var_fp.stdin.write("\n") + + for position in truth_vcf_set: + if position not in variant_set: + # miss variant set used in Tensor2Bin + var_fp.stdin.write(" ".join((chromosome, str(position), "None", "None", "-1", "-1"))) + var_fp.stdin.write("\n") + + vcf_fp.stdout.close() + vcf_fp.wait() + + if args.var_fn != "PIPE": + var_fp.stdin.close() + var_fp.wait() + var_fpo.close() + + +def main(): + parser = ArgumentParser(description="Extract variant type and allele from a truth dataset") + + parser.add_argument('--vcf_fn', type=str, default="input.vcf", required=True, + help="Truth VCF file input, required") + + parser.add_argument('--var_fn', type=str, default="PIPE", + help="Truth variants output, use PIPE for standard output, default: %(default)s") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed") + + parser.add_argument('--truth_vcf_fn', type=str, default=None, + help="Truth VCF file input, only used when vcf_fn is unified vcf. Marked truth variants not in unified as missing") + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + OutputVariant(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/MergeBin.py b/benchmarks/nn-variant/Clair3/preprocess/MergeBin.py new file mode 100644 index 0000000..42e41c4 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/MergeBin.py @@ -0,0 +1,96 @@ +import sys +import logging +import numpy as np +from argparse import ArgumentParser, SUPPRESS +import tables + +import clair3.utils as utils + +logging.basicConfig(format='%(message)s', level=logging.INFO) + +def Run(args): + in_fn_list = args.in_fn + out_fn = args.out_fn + platform = args.platform + pileup = args.pileup + + global param + float_type = 'int32' + if pileup: + import shared.param_p as param + else: + import shared.param_f as param + float_type = 'int8' + + tensor_shape = param.ont_input_shape if platform == 'ont' else param.input_shape + + # select all match prefix if file path not exists + tables.set_blosc_max_threads(64) + int_atom = tables.Atom.from_dtype(np.dtype(float_type)) + string_atom = tables.StringAtom(itemsize=param.no_of_positions + 50) + long_string_atom = tables.StringAtom(itemsize=5000) # max alt_info length + table_file = tables.open_file(out_fn, mode='w', filters=utils.FILTERS) + table_file.create_earray(where='/', name='position_matrix', atom=int_atom, shape=[0] + tensor_shape, + filters=utils.FILTERS) + table_file.create_earray(where='/', name='position', atom=string_atom, shape=(0, 1), filters=utils.FILTERS) + table_file.create_earray(where='/', name='label', atom=int_atom, shape=(0, param.label_size), filters=utils.FILTERS) + table_file.create_earray(where='/', name='alt_info', atom=long_string_atom, shape=(0, 1), filters=utils.FILTERS) + + table_dict = utils.update_table_dict() + total_compressed = 0 + + for f in in_fn_list: + print("[INFO] Merging file {}".format(f)) + fi = tables.open_file(f, model='r') + assert (len(fi.root.label) == len(fi.root.position) == len(fi.root.position_matrix) == len(fi.root.alt_info)) + for index in range(len(fi.root.label)): + table_dict['label'].append(fi.root.label[index]) + table_dict['position'].append(fi.root.position[index]) + table_dict['position_matrix'].append(fi.root.position_matrix[index]) + table_dict['alt_info'].append(fi.root.alt_info[index]) + + total_compressed += 1 + + if total_compressed % 500 == 0 and total_compressed > 0: + table_dict = utils.write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type) + + if total_compressed % 50000 == 0: + print("[INFO] Compressed %d tensor" % (total_compressed), file=sys.stderr) + fi.close() + + if total_compressed % 500 != 0 and total_compressed > 0: + table_dict = utils.write_table_file(table_file, table_dict, tensor_shape, param.label_size, float_type) + print("[INFO] Compressed %d tensor" % (total_compressed), file=sys.stderr) + + table_file.close() + + +def main(): + parser = ArgumentParser(description="Combine tensor binaries into a single file") + + parser.add_argument('in_fn', type=str, nargs='+', + help="Tensor input files, required") + + parser.add_argument('--out_fn', type=str, default=None, required=True, + help="Output a binary tensor file, required") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + ## In pileup mode or not (full alignment mode), default: False + parser.add_argument('--pileup', action='store_true', + help=SUPPRESS) + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + Run(args) + + +if __name__ == "__main__": + main() + + diff --git a/benchmarks/nn-variant/Clair3/preprocess/MergeVcf.py b/benchmarks/nn-variant/Clair3/preprocess/MergeVcf.py new file mode 100644 index 0000000..aca403c --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/MergeVcf.py @@ -0,0 +1,354 @@ +import sys +import os +import shlex +import logging +import heapq + +logging.basicConfig(format='%(message)s', level=logging.INFO) + +from argparse import ArgumentParser, SUPPRESS +from shared.utils import subprocess_popen, str2bool, log_error, log_warning +import shared.param_f as param +from shared.interval_tree import bed_tree_from, is_region_in +from preprocess.utils import gvcfGenerator + + +def update_haploid_precise_genotype(columns): + INFO = columns[9].split(':') + genotype_string = INFO[0].replace('|', '/') + + if genotype_string == '1/1': + genotype = ['1'] + elif genotype_string == '0/0': + genotype = ['0'] + else: + return "" + # update genotype + columns[9] = ':'.join(genotype + INFO[1:]) + row = '\t'.join(columns) + '\n' + return row + +def update_haploid_sensitive_genotype(columns): + INFO = columns[9].split(':') + genotype_string = INFO[0].replace('|', '/') + ref_base, alt_base = columns[3], columns[4] + is_multi = ',' in alt_base + + if is_multi: + return "" + + if genotype_string in ('0/1','1/0','1/1'): + genotype = ['1'] + else: + genotype = ['0'] + # update genotype + columns[9] = ':'.join(genotype + INFO[1:]) + row = '\t'.join(columns) + '\n' + return row + +def MarkLowQual(row, quality_score_for_pass, qual): + if row == '': + return row + + if quality_score_for_pass and qual <= quality_score_for_pass: + row = row.split("\t") + row[6] = "LowQual" + return '\t'.join(row) + return row + +def MergeVcf_illumina(args): + # region vcf merge for illumina, as read realignment will make candidate varaints shift and missing. + bed_fn_prefix = args.bed_fn_prefix + output_fn = args.output_fn + full_alignment_vcf_fn = args.full_alignment_vcf_fn + pileup_vcf_fn = args.pileup_vcf_fn # true vcf var + contig_name = args.ctgName + QUAL = args.qual + bed_fn = None + if not os.path.exists(bed_fn_prefix): + exit(log_error("[ERROR] Input directory: {} not exists!").format(bed_fn_prefix)) + + all_files = os.listdir(bed_fn_prefix) + all_files = [item for item in all_files if item.startswith(contig_name + '.')] + if len(all_files) != 0: + bed_fn = os.path.join(bed_fn_prefix, "full_aln_regions_{}".format(contig_name)) + with open(bed_fn, 'w') as output_file: + for file in all_files: + with open(os.path.join(bed_fn_prefix, file)) as f: + output_file.write(f.read()) + + is_haploid_precise_mode_enabled = args.haploid_precise + is_haploid_sensitive_mode_enabled = args.haploid_sensitive + print_ref = args.print_ref_calls + + tree = bed_tree_from(bed_file_path=bed_fn, padding=param.no_of_positions, contig_name=contig_name) + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (pileup_vcf_fn))) + output_dict = {} + header = [] + pileup_count = 0 + for row in unzip_process.stdout: + if row[0] == '#': + header.append(row) + continue + columns = row.strip().split() + ctg_name = columns[0] + if contig_name != None and ctg_name != contig_name: + continue + pos = int(columns[1]) + qual = float(columns[5]) + pass_bed = is_region_in(tree, ctg_name, pos) + ref_base, alt_base = columns[3], columns[4] + is_reference = (alt_base == "." or ref_base == alt_base) + if is_haploid_precise_mode_enabled: + row = update_haploid_precise_genotype(columns) + if is_haploid_sensitive_mode_enabled: + row = update_haploid_sensitive_genotype(columns) + + if not pass_bed: + if not is_reference: + row = MarkLowQual(row, QUAL, qual) + output_dict[pos] = row + pileup_count += 1 + elif print_ref: + output_dict[pos] = row + pileup_count += 1 + + unzip_process.stdout.close() + unzip_process.wait() + + realigned_vcf_unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (full_alignment_vcf_fn))) + realiged_read_num = 0 + for row in realigned_vcf_unzip_process.stdout: + if row[0] == '#': + continue + columns = row.strip().split() + ctg_name = columns[0] + if contig_name != None and ctg_name != contig_name: + continue + + pos = int(columns[1]) + qual = float(columns[5]) + ref_base, alt_base = columns[3], columns[4] + is_reference = (alt_base == "." or ref_base == alt_base) + + if is_haploid_precise_mode_enabled: + row = update_haploid_precise_genotype(columns) + if is_haploid_sensitive_mode_enabled: + row = update_haploid_sensitive_genotype(columns) + + if is_region_in(tree, ctg_name, pos): + if not is_reference: + row = MarkLowQual(row, QUAL, qual) + output_dict[pos] = row + realiged_read_num += 1 + elif print_ref: + output_dict[pos] = row + realiged_read_num += 1 + + logging.info('[INFO] Pileup positions variants proceeded in {}: {}'.format(contig_name, pileup_count)) + logging.info('[INFO] Realigned positions variants proceeded in {}: {}'.format(contig_name, realiged_read_num)) + realigned_vcf_unzip_process.stdout.close() + realigned_vcf_unzip_process.wait() + + with open(output_fn, 'w') as output_file: + output_list = header + [output_dict[pos] for pos in sorted(output_dict.keys())] + output_file.write(''.join(output_list)) + + +def MergeVcf(args): + """ + Merge pileup and full alignment vcf output. We merge the low quality score pileup candidates + recalled by full-alignment model with high quality score pileup output. + """ + + output_fn = args.output_fn + full_alignment_vcf_fn = args.full_alignment_vcf_fn + pileup_vcf_fn = args.pileup_vcf_fn # true vcf var + contig_name = args.ctgName + QUAL = args.qual + is_haploid_precise_mode_enabled = args.haploid_precise + is_haploid_sensitive_mode_enabled = args.haploid_sensitive + print_ref = args.print_ref_calls + full_alignment_vcf_unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (full_alignment_vcf_fn))) + + full_alignment_output = [] + full_alignment_output_set = set() + header = [] + + for row in full_alignment_vcf_unzip_process.stdout: + if row[0] == '#': + header.append(row) + continue + columns = row.strip().split() + ctg_name = columns[0] + if contig_name != None and ctg_name != contig_name: + continue + pos = int(columns[1]) + qual = float(columns[5]) + ref_base, alt_base = columns[3], columns[4] + is_reference = (alt_base == "." or ref_base == alt_base) + + full_alignment_output_set.add((ctg_name, pos)) + + if is_haploid_precise_mode_enabled: + row = update_haploid_precise_genotype(columns) + if is_haploid_sensitive_mode_enabled: + row = update_haploid_sensitive_genotype(columns) + + if not is_reference: + row = MarkLowQual(row, QUAL, qual) + full_alignment_output.append((pos, row)) + + elif print_ref: + full_alignment_output.append((pos, row)) + + full_alignment_vcf_unzip_process.stdout.close() + full_alignment_vcf_unzip_process.wait() + + pileup_vcf_unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (pileup_vcf_fn))) + + output_file = open(output_fn, 'w') + output_file.write(''.join(header)) + + def pileup_vcf_generator_from(pileup_vcf_unzip_process): + pileup_row_count = 0 + for row in pileup_vcf_unzip_process.stdout: + if row[0] == '#': + continue + + columns = row.rstrip().split('\t') + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + qual = float(columns[5]) + ref_base, alt_base = columns[3], columns[4] + is_reference = (alt_base == "." or ref_base == alt_base) + + if (ctg_name, pos) in full_alignment_output_set: + continue + + if is_haploid_precise_mode_enabled: + row = update_haploid_precise_genotype(columns) + if is_haploid_sensitive_mode_enabled: + row = update_haploid_sensitive_genotype(columns) + + if not is_reference: + row = MarkLowQual(row, QUAL, qual) + pileup_row_count += 1 + yield (pos, row) + elif print_ref: + pileup_row_count += 1 + yield (pos, row) + + logging.info('[INFO] Pileup variants processed in {}: {}'.format(contig_name, pileup_row_count)) + + pileup_vcf_generator = pileup_vcf_generator_from(pileup_vcf_unzip_process=pileup_vcf_unzip_process) + full_alignment_vcf_generator = iter(full_alignment_output) + for vcf_infos in heapq.merge(full_alignment_vcf_generator, pileup_vcf_generator): + if len(vcf_infos) != 2: + continue + pos, row = vcf_infos + output_file.write(row) + + logging.info('[INFO] Full-alignment variants processed in {}: {}'.format(contig_name, len(full_alignment_output))) + + pileup_vcf_unzip_process.stdout.close() + pileup_vcf_unzip_process.wait() + output_file.close() + +def mergeNonVariant(args): + ''' + merge the variant calls and non-variants + + ''' + gvcf_generator = gvcfGenerator(ref_path=args.ref_fn, samtools=args.samtools) + raw_gvcf_path = args.non_var_gvcf_fn + raw_vcf_path = args.output_fn + + if (args.gvcf_fn == None): + save_path = args.call_fn.split('.')[0] + '.g.vcf' + else: + save_path = args.gvcf_fn + logging.info("[INFO] Merge variants and non-variants to GVCF") + gvcf_generator.mergeCalls(raw_vcf_path, raw_gvcf_path, save_path, args.sampleName, args.ctgName, args.ctgStart, + args.ctgEnd) + pass + + +def main(): + parser = ArgumentParser(description="Generate 1-based variant candidates using alignments") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--ref_fn', type=str, default=None, + help="Reference fasta file input") + + parser.add_argument('--pileup_vcf_fn', type=str, default=None, + help="Path to the pileup vcf file") + + parser.add_argument('--full_alignment_vcf_fn', type=str, default=None, + help="Path to the full alignment vcf file") + + parser.add_argument('--gvcf', type=str2bool, default=False, + help="Enable GVCF output, default: disabled") + + parser.add_argument('--non_var_gvcf_fn', type=str, default=None, + help='Path to the non-variant GVCF') + + parser.add_argument('--gvcf_fn', type=str, default=None, + help='Filename of the GVCF output') + + parser.add_argument('--output_fn', type=str, default=None, + help="Filename of the merged output") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed") + + parser.add_argument('--bed_fn_prefix', type=str, default=None, + help="Process variant only in the provided regions prefix") + + parser.add_argument('--qual', type=int, default=2, + help="If set, variants with >$qual will be marked 'PASS', or 'LowQual' otherwise, optional") + + parser.add_argument('--sampleName', type=str, default="SAMPLE", + help="Define the sample name to be shown in the VCF file") + + parser.add_argument('--samtools', type=str, default='samtools', + help="Path to the 'samtools', samtools version >= 1.10 is required, default: %(default)s") + + parser.add_argument('--print_ref_calls', type=str2bool, default=False, + help="Show reference calls (0/0) in vcf file output") + + # options for advanced users + parser.add_argument('--haploid_precise', type=str2bool, default=False, + help="EXPERIMENTAL: Enable haploid calling mode. Only 1/1 is considered as a variant") + + parser.add_argument('--haploid_sensitive', type=str2bool, default=False, + help="EXPERIMENTAL: Enable haploid calling mode. 0/1 and 1/1 are considered as a variant") + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + # realignment region merge + if args.platform == 'ilmn': + MergeVcf_illumina(args=args) + else: + MergeVcf(args=args) + + if (args.gvcf): + mergeNonVariant(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/RealignReads.py b/benchmarks/nn-variant/Clair3/preprocess/RealignReads.py new file mode 100644 index 0000000..22e76ba --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/RealignReads.py @@ -0,0 +1,705 @@ +import sys +import os +import shlex +import ctypes +import re +from subprocess import PIPE +from os.path import isfile +from argparse import ArgumentParser, SUPPRESS +from collections import defaultdict + +import shared.param_f as param +from shared.utils import file_path_from, subprocess_popen, reference_sequence_from, \ + IUPAC_base_to_ACGT_base_dict as BASE2ACGT, IUPAC_base_to_num_dict as BASE2NUM + +from shared.interval_tree import bed_tree_from +from shared.intervaltree.intervaltree import IntervalTree + +min_dbg_mapping_quality = min_dbg_base_quality = 20 +region_expansion_in_bp = expand_align_ref_region = 20 +min_windows_distance = expand_align_ref_region * 4 +max_window_size = max_region_reads_num = 1000 + +# using 5 charaters for store long read name +CHAR_STR = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!#$%&()*+./:;<=>?[]^`{|}~" +L_CHAR_STR = len(CHAR_STR) +EXP = 5 +T_READ_NAME = L_CHAR_STR ** EXP +L_CHAR_STR_EXP = [L_CHAR_STR ** i for i in range(EXP - 1, 0, -1)] + +realigner_mod = os.path.join(*(os.path.split(__file__)[:-1] + ('realign/realigner',))) +dbg_mod = os.path.join(*(os.path.split(__file__)[:-1] + ('realign/debruijn_graph',))) + +realigner = ctypes.cdll.LoadLibrary(realigner_mod) +dbg = ctypes.cdll.LoadLibrary(dbg_mod) + + +class StructPointer(ctypes.Structure): + _fields_ = [("position", ctypes.c_int * max_region_reads_num), + ("cigar_string", ctypes.c_char_p * max_region_reads_num), + ] + + +class DBGPointer(ctypes.Structure): + _fields_ = [("consensus_size", ctypes.c_int), + ("consensus", ctypes.c_char_p * 200), + ] + + +#Read class for storing read information +cigar_indel_re = r"(\d+)(D)" +cigarRe = r"(\d+)([MIDNSHP=X])" +graph_min_mapping_quality = 14 +def get_len(seq, cigar): + if 'D' not in cigar: + return len(seq) + indel_length = 0 + for m in re.finditer(cigar_indel_re, cigar): + indel_length += int(m.group(1)) + return len(seq) + indel_length + + +class Read(object): + def __init__(self, read_start, seq, cigar, mapping_quality, base_quality, strand, raw_base_quality=None, + unalign=False, read_name=None, read_id=None, flag=None, RNEXT=0, PNEXT=0, TLEN=0, phasing=None): + self.read_start = read_start + self.cigar = cigar + self.mapping_quality = mapping_quality + self.seq = seq + self.base_quality = base_quality + self.read_id = read_id + self.read_end = self.read_start + get_len(seq, cigar) + self.strand = strand + self.graph_mq = True if self.mapping_quality >= graph_min_mapping_quality else False + self.raw_base_quality = raw_base_quality + self.read_name = read_name + self.region = {} + self.region_cigar = None + self.region_start = None + self.flag = str(flag) + self.RNEXT = RNEXT + self.PNEXT = PNEXT + self.TLEN = PNEXT + self.test_pos = None + self.best_cigar = cigar + self.best_pos = read_start + self.best_align_score = None + self.phasing = phasing + + def set_realign_flag(self): + self.unalign = True + + def count_align_score(self, cigar): + score = 0 + for m in re.finditer(cigarRe, cigar): + l, op, = int(m.group(1)), m.group(2) + if op in 'MX=S': + continue + elif op in 'ID': + score += l + return score + + def set_realignment_info(self, region_start, realignment_cigar, realignment_start): + realignment_cigar = realignment_cigar.replace('X', 'M') + if realignment_cigar == self.cigar and realignment_start == self.read_start: + return + + if self.best_align_score and realignment_cigar == self.best_cigar and realignment_start == self.best_pos: + return + realignment_align_score = self.count_align_score(realignment_cigar) + if not self.best_align_score or realignment_align_score >= self.best_align_score: + self.best_cigar = realignment_cigar + self.best_pos = realignment_start + self.best_align_score = realignment_align_score + + def decode_region(self, region_str): + if region_str == '-' or '-' not in region_str: + return + region_str = region_str.rstrip().split('_') + for region in region_str: + region, cigar, pos = region.split('-') + region, pos = int(region), int(pos) + self.region[region] = [cigar, pos] + + + +def byte(x): + return bytes(x, encoding="utf8") + + +def find_max_overlap_index(query_region, search_regions): + def overlap_length(region1, region2): + return max(0, (min(region1[1], region2[1]) - max(region1[0], region2[0]))) + + overlap_lengths = [overlap_length(query_region, search_region) for search_region in search_regions] + argmax = max(range(len(search_regions)), key=lambda idx: overlap_lengths[idx]) + return None if overlap_lengths[argmax] == 0 else argmax + + +def get_reference_seq(sequence, start, end, reference_start_0_based): + if end < start: + end, start = start, end + return sequence[start - reference_start_0_based: end - reference_start_0_based] + + +def phredscore2raw_score(qual): + return ord(qual) - 33 + + +def evc_base_from(base): + return base if base == "N" else BASE2ACGT[base] + + +def region_from(ctg_name, ctg_start=None, ctg_end=None): + """ + 1-based region string [start, end] + """ + if ctg_name is None: + return "" + if (ctg_start is None) != (ctg_end is None): + return "" + + if ctg_start is None and ctg_end is None: + return "{}".format(ctg_name) + return "{}:{}-{}".format(ctg_name, ctg_start, ctg_end) + + +class TensorStdout(object): + def __init__(self, handle): + self.stdin = handle + + def __del__(self): + self.stdin.close() + + +def get_halpotype_tag(samtools_view_columns): + found_hp_tag = False + tag = [c for c in samtools_view_columns if 'HP:i:' in c] + if not len(tag) or len(tag[0]) < 6 or not tag[0][5].isdigit(): + return None + return tag[0][5] + + +def is_too_many_soft_clipped_bases_for_a_read_from(CIGAR): + soft_clipped_bases = 0 + total_alignment_positions = 0 + + advance = 0 + for c in str(CIGAR): + if c.isdigit(): + advance = advance * 10 + int(c) + continue + if c == "S": + soft_clipped_bases += advance + total_alignment_positions += advance + advance = 0 + + # skip a read less than 55% aligned + return 1.0 - float(soft_clipped_bases) / (total_alignment_positions + 1) < 0.55 + + +def samtools_view_generator_from(samtools_view_process, aligned_reads, pileup, ctg_name, reference_sequence, + reference_start_0_based, header): + CHUNK_SIZE = param.realign_chunk_size + chunk_start, chunk_end = None, None + rs_idx = -1 + for row_id, row in enumerate(samtools_view_process.stdout): + if row[0] == '@': + header.append(row) + continue + columns = row.strip().split() + RNAME = columns[2] + if RNAME != ctg_name: + continue + + read_name = columns[0] + FLAG = int(columns[1]) + POS = int(columns[3]) - 1 # switch from 1-base to 0-base to match sequence index + MAPQ = int(columns[4]) + CIGAR = columns[5] + SEQ = columns[9].upper() # uppercase for SEQ (regexp is \*|[A-Za-z=.]+) + RNEXT = columns[6] + PNEXT = columns[7] + TLEN = columns[8] + reference_position = POS + query_position = 0 + raw_base_quality = columns[10] + QUAL = [phredscore2raw_score(item) for item in raw_base_quality] + STRAND = (16 == (FLAG & 16)) + HP_TAG = get_halpotype_tag(columns[11:]) + read_name += "_" + str(int(STRAND)) # distinguish two strand + read_name, rs_idx = simplfy_read_name(rs_idx) + if chunk_start is None: + chunk_start = POS + chunk_end = chunk_start + CHUNK_SIZE + if POS >= chunk_end + region_expansion_in_bp: + yield chunk_start, chunk_end + chunk_start += CHUNK_SIZE + chunk_end += CHUNK_SIZE + + read = Read(read_start=POS, + seq=SEQ, + cigar=CIGAR, + mapping_quality=MAPQ, + base_quality=QUAL, + strand=STRAND, + raw_base_quality=raw_base_quality, + read_name=read_name, + flag=FLAG, + PNEXT=PNEXT, + RNEXT=RNEXT, + TLEN=TLEN, + phasing=HP_TAG) + + if CIGAR == "*" or is_too_many_soft_clipped_bases_for_a_read_from(CIGAR): + continue + + aligned_reads[read_name] = read + if MAPQ < min_dbg_mapping_quality: + continue + advance = 0 + for c in str(CIGAR): + if c.isdigit(): + advance = advance * 10 + int(c) + continue + if c == '=': + reference_position += advance + query_position += advance + elif c == "M" or c == 'X': + for _ in range(advance): + if QUAL[query_position] >= min_dbg_base_quality: + reference_base = reference_sequence[reference_position - reference_start_0_based] # 0 base + query_base = SEQ[query_position] + if reference_base in 'ACGT' and query_base != reference_base: + pileup[reference_position]['X'] += 1 + reference_position += 1 + query_position += 1 + + elif c == "I" or c == 'S': + pre_base = reference_sequence[reference_position - reference_start_0_based - 1] + ins_base_quality = QUAL[query_position: query_position + advance] + out_of_region = reference_position < chunk_start - region_expansion_in_bp or reference_position > chunk_end + region_expansion_in_bp + if not out_of_region and pre_base in 'ACGT' and ( + sum([True for bq in ins_base_quality if bq < min_dbg_base_quality]) == 0): + # skip the bad seq + start = reference_position - advance + end = reference_position + advance + for ins_idx in range(start, end): + pileup[ins_idx]["X"] += 1 + + # insertion consumes query + query_position += advance + + elif c == "D": + out_of_region = reference_position < chunk_start - region_expansion_in_bp or reference_position > chunk_end + region_expansion_in_bp + pre_base = reference_sequence[reference_position - reference_start_0_based - 1] # 0-base + if not out_of_region and pre_base in 'ACGT': + start = reference_position + end = reference_position + advance + for ins_idx in range(start, end): + pileup[ins_idx]["X"] += 1 + # deletion consumes reference + reference_position += advance + # reset advance + advance = 0 + + yield chunk_start, chunk_end + yield None, None + + +def simplfy_read_name(rs_idx): + rs_idx = (rs_idx + 1) % T_READ_NAME + save_read_name = "" + div_num = rs_idx + for div_exp in L_CHAR_STR_EXP: + save_read_name += CHAR_STR[div_num // div_exp] + div_num = div_num % div_exp + if EXP != 1: + save_read_name += CHAR_STR[div_num % L_CHAR_STR] + return save_read_name, rs_idx + + +def reads_realignment(args): + bed_file_path = args.full_aln_regions + extend_bed = args.extend_bed + fasta_file_path = args.ref_fn + ctg_name = args.ctgName + ctg_start = args.ctgStart + ctg_end = args.ctgEnd + chunk_id = args.chunk_id - 1 if args.chunk_id else None # 1-base to 0-base + chunk_num = args.chunk_num + samtools_execute_command = args.samtools + bam_file_path = args.bam_fn + minMQ = args.minMQ + min_coverage = args.minCoverage + is_bed_file_given = bed_file_path is not None + is_ctg_name_given = ctg_name is not None + read_fn = args.read_fn + + global test_pos + test_pos = None + if is_bed_file_given: + candidate_file_path_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_file_path))) + candidate_file_path_output = candidate_file_path_process.stdout + + ctg_start, ctg_end = float('inf'), 0 + for row in candidate_file_path_output: + row = row.rstrip().split('\t') + if row[0] != ctg_name: continue + position = int(row[1]) + 1 + end = int(row[2]) + 1 + ctg_start = min(position, ctg_start) + ctg_end = max(end, ctg_end) + + candidate_file_path_output.close() + candidate_file_path_process.wait() + + if chunk_id is not None: + fai_fn = file_path_from(fasta_file_path, suffix=".fai", exit_on_not_found=True, sep='.') + contig_length = 0 + with open(fai_fn, 'r') as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + + contig_name = columns[0] + if contig_name != ctg_name: + continue + contig_length = int(columns[1]) + chunk_size = contig_length // chunk_num + 1 if contig_length % chunk_num else contig_length // chunk_num + ctg_start = chunk_size * chunk_id # 0-base to 1-base + ctg_end = ctg_start + chunk_size + + is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None + + # 1-based regions [start, end] (start and end inclusive) + ref_regions = [] + reads_regions = [] + reference_start, reference_end = None, None + + if is_ctg_range_given: + extend_start = ctg_start - max_window_size + extend_end = ctg_end + max_window_size + reads_regions.append(region_from(ctg_name=ctg_name, ctg_start=extend_start, ctg_end=extend_end)) + reference_start, reference_end = ctg_start - param.expandReferenceRegion, ctg_end + param.expandReferenceRegion + reference_start = 1 if reference_start < 1 else reference_start + ref_regions.append(region_from(ctg_name=ctg_name, ctg_start=reference_start, ctg_end=reference_end)) + elif is_ctg_name_given: + reads_regions.append(region_from(ctg_name=ctg_name)) + ref_regions.append(region_from(ctg_name=ctg_name)) + reference_start = 1 + + reference_sequence = reference_sequence_from( + samtools_execute_command=samtools_execute_command, + fasta_file_path=fasta_file_path, + regions=ref_regions + ) + if reference_sequence is None or len(reference_sequence) == 0: + sys.exit("[ERROR] Failed to load reference sequence from file ({}).".format(fasta_file_path)) + + tree = bed_tree_from(bed_file_path=bed_file_path) + if is_bed_file_given and ctg_name not in tree: + sys.exit("[ERROR] ctg_name({}) not exists in bed file({}).".format(ctg_name, bed_file_path)) + + bed_option = ' -L {}'.format(extend_bed) if extend_bed else "" + bed_option = ' -L {}'.format(bed_file_path) if is_bed_file_given else bed_option + mq_option = ' -q {}'.format(minMQ) if minMQ > 0 else "" + samtools_view_command = "{} view -h {} {}".format(samtools_execute_command, bam_file_path, + " ".join(reads_regions)) + mq_option + bed_option + samtools_view_process = subprocess_popen( + shlex.split(samtools_view_command) + ) + + if read_fn and read_fn == 'PIPE': + save_file_fp = TensorStdout(sys.stdout) + elif read_fn: + save_file_fp = subprocess_popen(shlex.split("{} view -bh - -o {}".format(samtools_execute_command, read_fn + ( + '.{}_{}'.format(ctg_start, ctg_end) if is_ctg_range_given and not test_pos else ""))), stdin=PIPE, + stdout=PIPE) + + reference_start_0_based = 0 if reference_start is None else (reference_start - 1) + + header = [] + add_header = False + aligned_reads = defaultdict() + pileup = defaultdict(lambda: {"X": 0}) + samtools_view_generator = samtools_view_generator_from(samtools_view_process=samtools_view_process, + aligned_reads=aligned_reads, + pileup=pileup, + ctg_name=ctg_name, + reference_sequence=reference_sequence, + reference_start_0_based=reference_start_0_based, + header=header) + pre_aligned_reads = defaultdict() + + while True: + chunk_start, chunk_end = next(samtools_view_generator) + if chunk_start is None: + break + if not add_header: + save_file_fp.stdin.write(''.join(header)) + add_header = True + + variant_allele_list = [[position, pileup[position]["X"]] for position in list(pileup.keys())] + candidate_position_list = [(position, support_allele_count) for position, support_allele_count in + variant_allele_list if + support_allele_count >= min_coverage and position >= chunk_start - region_expansion_in_bp - 1 and position <= chunk_end + region_expansion_in_bp - 1] + candidate_position_list.sort(key=(lambda x: x[0])) + + if not len(aligned_reads) or not len(candidate_position_list): + continue + if len(pre_aligned_reads): # update the read in previous chunk + for read_name, read in pre_aligned_reads.items(): + aligned_reads[read_name] = read + + region_dict = {} + split_region_size = max_window_size + region_tree = IntervalTree() + for split_idx in range((chunk_end - chunk_start) // split_region_size): + split_start = chunk_start + split_idx * split_region_size - region_expansion_in_bp - 1 + split_end = split_start + split_region_size + region_expansion_in_bp * 2 + 1 + region_dict[(split_start, split_end)] = [] + region_tree.addi(split_start, split_end) + for candidate_position in candidate_position_list: + for region in region_tree.at(candidate_position[0]): + region_dict[(region.begin, region.end)].append(candidate_position[0]) + + for key, split_candidate_position_list in region_dict.items(): + start_pos, end_pos = None, None + windows = [] + read_windows_dict = {} + for pos in split_candidate_position_list: + if start_pos is None: + start_pos = pos + end_pos = pos + + elif pos > end_pos + 2 * min_windows_distance: + temp_window = (start_pos - min_windows_distance, end_pos + min_windows_distance) + windows.append(temp_window) + read_windows_dict[temp_window] = [] + + start_pos = pos + end_pos = pos + else: + end_pos = pos + + if start_pos is not None: + temp_window = (start_pos - min_windows_distance, end_pos + min_windows_distance) + windows.append(temp_window) + read_windows_dict[temp_window] = [] + if not len(windows): continue + windows = sorted(windows, key=lambda x: x[0]) + max_window_end = max([item[1] for item in windows]) + # #find read windows overlap_pair + for read_name, read in aligned_reads.items(): + if read.read_start > max_window_end: continue + argmax_window_idx = find_max_overlap_index((read.read_start, read.read_end), windows) + if argmax_window_idx is not None: + read_windows_dict[windows[argmax_window_idx]].append(read_name) + + # realignment + for window in windows: + start_pos, end_pos = window + if end_pos - start_pos > max_window_size: # or (window not in need_align_windows_set): + continue + + ref_start = start_pos - reference_start_0_based + ref_end = end_pos - reference_start_0_based + ref = reference_sequence[ref_start:ref_end] + reads = [] + low_base_quality_pos_list = [] + # pypy binding with ctypes for DBG building + for read_name in read_windows_dict[window]: + read = aligned_reads[read_name] + if (not read.graph_mq) or read.read_start > end_pos or read.read_end < start_pos: + continue + reads.append(read.seq) + low_base_quality_pos_list.append( + ' '.join([str(bq_idx) for bq_idx, item in enumerate(read.base_quality) if int(item) < 15])) + totoal_read_num = len(reads) + c_ref = byte(ref) + read_list1 = ctypes.c_char_p(byte(','.join(reads))) + low_base_quality_pos_array = ctypes.c_char_p(byte(','.join(low_base_quality_pos_list))) + + dbg.get_consensus.restype = ctypes.POINTER(DBGPointer) + dbg.get_consensus.argtypes = [ctypes.c_char_p, ctypes.c_char_p, ctypes.c_char_p, ctypes.c_int] + + dbg_p = dbg.get_consensus(ctypes.c_char_p(c_ref), read_list1, low_base_quality_pos_array, + totoal_read_num) + + c_consensus, consensus_size = dbg_p.contents.consensus, dbg_p.contents.consensus_size + consensus = [item.decode() for item in c_consensus[:consensus_size]] + + if len(consensus) == 0 or len(consensus) == 1 and consensus[0] == ref or len( + read_windows_dict[window]) == 0: + continue + min_read_start = min([aligned_reads[item].read_start for item in read_windows_dict[window]]) + max_read_end = max([aligned_reads[item].read_end for item in read_windows_dict[window]]) + tmp_ref_start = max(0, min(min_read_start, start_pos) - expand_align_ref_region) + tmp_ref_end = max(max_read_end, end_pos) + expand_align_ref_region + + ref_prefix = get_reference_seq(reference_sequence, tmp_ref_start, start_pos, reference_start_0_based) + ref_center = get_reference_seq(reference_sequence, start_pos, end_pos, reference_start_0_based) + if tmp_ref_end < end_pos: + continue + ref_suffix = get_reference_seq(reference_sequence, end_pos, tmp_ref_end, reference_start_0_based) + ref_seq = ref_prefix + ref_center + ref_suffix + + # pypy binding with ctypes for realignment + read_name_list = [] + totoal_read_num = min(max_region_reads_num, len(read_windows_dict[window])) + seq_list = (ctypes.c_char_p * totoal_read_num)() + position_list = (ctypes.c_int * totoal_read_num)() + cigars_list = (ctypes.c_char_p * totoal_read_num)() + + for read_idx, read_name in enumerate(read_windows_dict[window]): + read = aligned_reads[read_name] + if read_idx >= totoal_read_num: break + seq_list[read_idx] = byte(read.seq.upper()) + position_list[read_idx] = read.read_start + cigars_list[read_idx] = byte(read.cigar) + read_name_list.append(read_name) + haplotypes_list = [ref_prefix + cons + ref_suffix for cons in consensus] + haplotypes = ' '.join(haplotypes_list) + + realigner.realign_reads.restype = ctypes.POINTER(StructPointer) + realigner.realign_reads.argtypes = [ctypes.c_char_p * totoal_read_num, ctypes.c_int * totoal_read_num, + ctypes.c_char_p * totoal_read_num, ctypes.c_char_p, ctypes.c_char_p, + ctypes.c_int, + ctypes.c_int, ctypes.c_int, ctypes.c_int] + + realigner_p = realigner.realign_reads(seq_list, position_list, cigars_list, + ctypes.c_char_p(byte(ref_seq)), + ctypes.c_char_p(byte(haplotypes)), tmp_ref_start, + len(ref_prefix), len(ref_suffix), totoal_read_num) + + realign_positions, realign_cigars = realigner_p.contents.position, realigner_p.contents.cigar_string + read_position_list = realign_positions[:totoal_read_num] + read_cigar_list = [item.decode() for item in realign_cigars[:totoal_read_num]] + + if len(read_name_list): + for read_id, read_name in enumerate(read_name_list): + if read_cigar_list[read_id] == "" or ( + aligned_reads[read_name].cigar == read_cigar_list[read_id] and aligned_reads[ + read_name].read_start == read_position_list[read_id]): + continue + # update cigar and read start position + aligned_reads[read_name].test_pos = test_pos + realignment_start = read_position_list[read_id] + realignment_cigar = read_cigar_list[read_id].replace('X', 'M') + if realignment_cigar == aligned_reads[read_name].cigar and realignment_start == aligned_reads[ + read_name].read_start: + continue + aligned_reads[read_name].set_realignment_info(split_start, read_cigar_list[read_id], + read_position_list[read_id]) + + realigner.free_memory.restype = ctypes.POINTER(ctypes.c_void_p) + realigner.free_memory.argtypes = [ctypes.POINTER(StructPointer), ctypes.c_int] + realigner.free_memory(realigner_p, totoal_read_num) + # # realignment end + + if read_fn: + sorted_key = sorted([(key, item.best_pos) for key, item in aligned_reads.items()], key=lambda x: x[1]) + for read_name, read_start in sorted_key: + read = aligned_reads[read_name] + if read_start < chunk_start - region_expansion_in_bp - max_window_size: # safe distance for save reads + phasing_info = 'HP:i:{}'.format(read.phasing) if read.phasing else "" + pass + read_str = '\t'.join([read_name, read.flag, ctg_name, str(read_start + 1), + str(read.mapping_quality), read.best_cigar, read.RNEXT, read.PNEXT, read.TLEN, + read.seq, + read.raw_base_quality, + phasing_info]) + save_file_fp.stdin.write(read_str + '\n') + del aligned_reads[read_name] + for pile_pos in list(pileup.keys()): + if pile_pos < chunk_start - region_expansion_in_bp - max_window_size: + del pileup[pile_pos] + + if read_fn and aligned_reads: + sorted_key = sorted([(key, item.best_pos) for key, item in aligned_reads.items()], key=lambda x: x[1]) + for read_name, read_start in sorted_key: + read = aligned_reads[read_name] + phasing_info = 'HP:i:{}'.format(read.phasing) if read.phasing else "" + read_str = '\t'.join([read_name, read.flag, ctg_name, str(read_start + 1), + str(read.mapping_quality), read.best_cigar, read.RNEXT, read.PNEXT, read.TLEN, + read.seq, + read.raw_base_quality, + phasing_info]) + save_file_fp.stdin.write(read_str + '\n') + del aligned_reads[read_name] + if read_fn != 'PIPE': + save_file_fp.stdin.close() + save_file_fp.wait() + samtools_view_process.stdout.close() + samtools_view_process.wait() + + if test_pos: + save_file_fp = subprocess_popen(shlex.split("samtools index {}".format( + read_fn + ('.{}_{}'.format(ctg_start, ctg_end) if is_ctg_range_given and not test_pos else ""))), + stdin=PIPE, stdout=PIPE) + save_file_fp.stdin.close() + save_file_fp.wait() + + +def main(): + parser = ArgumentParser(description="Reads realignment") + + parser.add_argument('--bam_fn', type=str, default=None, required=True, + help="Sorted BAM file input, required") + + parser.add_argument('--ref_fn', type=str, default="ref.fa", required=True, + help="Reference fasta file input, required") + + parser.add_argument('--read_fn', type=str, default="PIPE", + help="Output realigned BAM. Default directly pass reads to CreateTensor_phasing using PIPE. Default: %(default)s") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed") + + parser.add_argument('--full_aln_regions', type=str, default=None, + help="Realign reads only in the provided bed regions") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required, default: %(default)s") + + # options for advanced users + parser.add_argument('--minCoverage', type=float, default=2, + help="EXPERIMENTAL: Minimum coverage required to call a variant, default: %(default)f") + + parser.add_argument('--minMQ', type=int, default=5, + help="EXPERIMENTAL: Minimum Mapping Quality. Mapping quality lower than the setting will be filtered, default: %(default)d") + + # options for debug purpose + parser.add_argument('--extend_bed', type=str, default=None, + help="DEBUG: Extend the regions in the --bed_fn by a few bp for tensor creation, default extend 16bp") + + # options for internal process control + ## Test in specific candidate position. Only for testing + parser.add_argument('--test_pos', type=int, default=0, + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + args = parser.parse_args() + + reads_realignment(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/SelectCandidates.py b/benchmarks/nn-variant/Clair3/preprocess/SelectCandidates.py new file mode 100644 index 0000000..60dab58 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/SelectCandidates.py @@ -0,0 +1,384 @@ +import shlex +import math +import sys +import logging +import os + +from argparse import ArgumentParser, SUPPRESS +from collections import defaultdict + +from shared.intervaltree.intervaltree import IntervalTree +import shared.param_f as param +from shared.utils import subprocess_popen, IUPAC_base_to_num_dict as BASE2NUM, region_from, reference_sequence_from, str2bool, log_warning + +logging.basicConfig(format='%(message)s', level=logging.INFO) + + +def gaussian_distribution(x, mu, sig=16): + return math.exp(-math.pow(x - mu, 2.) / (2 * math.pow(sig, 2.))) + + +def discrete_gaussian_pro(entropy_windnow): + gaussian_pro = [gaussian_distribution(index, entropy_windnow / 2, 1) for index in range(entropy_windnow)] + return gaussian_pro + + +def calculate_sequence_entropy(sequence, entropy_window=None, kmer=5): + """ + We use a kmer-based sequence entropy calculation to measure the complexity of a region. + sequence: a chunked sequence around a candidate position, default no_of_positions = flankingBaseNum + 1 + flankingBaseNum + entropy_window: a maximum entropy window for scanning, if the sequence is larger than the entropy window, a slide + window would be adopted for measurement. + kmer: default kmer size for sequence entropy calculation. + """ + + count_repeat_kmer_counts = [0] * (entropy_window + 2) + count_repeat_kmer_counts[0] = entropy_window + + entropy = [0.0] * (entropy_window + 2) + for i in range(1, entropy_window + 2): + e = 1.0 / entropy_window * i + entropy[i] = e * math.log(e) + entropy_mul = -1 / math.log(entropy_window) + entropy_kmer_space = 1 << (2 * kmer) + + kmer_hash_counts = [0] * entropy_kmer_space # value should smaller than len(seq) + mask = -1 if kmer > 15 else ~((-1) << (2 * kmer)) + kmer_suffix, kmer_prefix = 0, 0 + + i = 0 + i2 = -entropy_window + entropy_sum = 0.0 + all_entropy_sum = [0.0] * len(sequence) + while (i2 < len(sequence)): + + if (i < len(sequence)): + n = BASE2NUM[sequence[i]] + kmer_suffix = ((kmer_suffix << 2) | n) & mask + + count_repeat_kmer_counts[kmer_hash_counts[kmer_suffix]] -= 1 + entropy_sum -= entropy[kmer_hash_counts[kmer_suffix]] + kmer_hash_counts[kmer_suffix] += 1 + count_repeat_kmer_counts[kmer_hash_counts[kmer_suffix]] += 1 + entropy_sum += entropy[kmer_hash_counts[kmer_suffix]] + + if i2 >= 0 and i < len(sequence): + n2 = BASE2NUM[sequence[i2]] + kmer_prefix = ((kmer_prefix << 2) | n2) & mask # add base info + count_repeat_kmer_counts[kmer_hash_counts[kmer_prefix]] -= 1 + entropy_sum -= entropy[kmer_hash_counts[kmer_prefix]] + kmer_hash_counts[kmer_prefix] -= 1 + count_repeat_kmer_counts[kmer_hash_counts[kmer_prefix]] += 1 + entropy_sum += entropy[kmer_hash_counts[kmer_prefix]] + all_entropy_sum[i] = entropy_sum + i += 1 + i2 += 1 + return entropy_sum * entropy_mul + + +def sqeuence_entropy_from(samtools_execute_command, fasta_file_path, contig_name, candidate_positions): + """ + Calculate sequence entropy in a specific candidate windows, variants in low sequence entropy regions (low + mappability regions, such as homopolymer, tandem repeat, segmental duplications regions) would more likely have + more complex variants representation, which is beyond pileup calling. Hence, those candidate variants are re-called by + full alignment calling. + We use a kmer-based sequence entropy calculation to measure the complexity of a region, we would directly query the + chunked reference sequence for sequence entropy calculation for each candidate variant. + """ + + ref_regions = [] + reference_start, reference_end = min(list(candidate_positions)) - param.no_of_positions, max( + list(candidate_positions)) + param.no_of_positions + 1 + reference_start = 1 if reference_start < 1 else reference_start + ref_regions.append(region_from(ctg_name=contig_name, ctg_start=reference_start, ctg_end=reference_end)) + reference_sequence = reference_sequence_from( + samtools_execute_command=samtools_execute_command, + fasta_file_path=fasta_file_path, + regions=ref_regions + ) + if reference_sequence is None or len(reference_sequence) == 0: + sys.exit("[ERROR] Failed to load reference seqeunce from file ({}).".format(fasta_file_path)) + + entropy_window = param.no_of_positions + candidate_positions_entropy_list = [] + for pos in candidate_positions: + ref_seq = reference_sequence[ + pos - param.flankingBaseNum - reference_start: pos + param.flankingBaseNum + 1 - reference_start] + sequence_entropy = calculate_sequence_entropy(sequence=ref_seq, entropy_window=entropy_window) + candidate_positions_entropy_list.append((pos, sequence_entropy)) + + return candidate_positions_entropy_list + + +def SelectCandidates(args): + """ + Select low quality and low sequence entropy candidate variants for full aligement. False positive pileup variants + and true variants missed by pileup calling would mostly have low quality score (reference quality score for missing + variants), so only use a proportion of low quality variants for full alignment while maintain high quality pileup + output, as full alignment calling is substantially slower than pileup calling. + """ + + phased_vcf_fn = args.phased_vcf_fn + pileup_vcf_fn = args.pileup_vcf_fn + var_pct_full = args.var_pct_full + ref_pct_full = args.ref_pct_full + seq_entropy_pro = args.seq_entropy_pro + contig_name = args.ctgName + phasing_window_size = param.phasing_window_size + platform = args.platform + split_bed_size = args.split_bed_size + split_folder = args.split_folder + extend_bp = param.extend_bp + call_low_seq_entropy = args.call_low_seq_entropy + phasing_info_in_bam = args.phasing_info_in_bam + need_phasing_list = [] + need_phasing_set = set() + ref_call_pos_list = [] + variant_dict = defaultdict(str) + flankingBaseNum = param.flankingBaseNum + qual_fn = args.qual_fn if args.qual_fn is not None else 'qual' + fasta_file_path = args.ref_fn + samtools_execute_command = args.samtools + + found_qual_cut_off = False + low_sequence_entropy_list = [] + # try to find the global quality cut off: + f_qual = os.path.join(split_folder, qual_fn) + if os.path.exists(f_qual): + with open(f_qual, 'r') as f: + line = f.read().rstrip().split(' ') + var_qual, ref_qual = float(line[0]), float(line[1]) + found_qual_cut_off = True + + all_full_aln_regions = [] + if phased_vcf_fn and os.path.exists(phased_vcf_fn): + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (phased_vcf_fn))) + for row in unzip_process.stdout: + row = row.rstrip() + if row[0] == '#': + continue + columns = row.strip().split('\t') + + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + ref_base = columns[3] + alt_base = columns[4] + genotype_info = columns[9].split(':') + genotype, phase_set = genotype_info[0], genotype_info[-1] + if '|' not in genotype: # unphasable + continue + variant_dict[pos] = '-'.join([ref_base, alt_base, ('1' if genotype == '0|1' else '2'), phase_set]) + + if pileup_vcf_fn and os.path.exists(pileup_vcf_fn): + # vcf format + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (pileup_vcf_fn))) + for row in unzip_process.stdout: + if row[0] == '#': + continue + columns = row.rstrip().split('\t') + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + ref_base = columns[3] + alt_base = columns[4] + qual = float(columns[5]) + + # reference calling + if alt_base == "." or ref_base == alt_base: + ref_call_pos_list.append((pos, qual)) + else: + need_phasing_list.append((pos, qual)) + need_phasing_set.add(pos) + + if found_qual_cut_off: + low_qual_ref_list = [[k, v] for k, v in ref_call_pos_list if v < ref_qual] + low_qual_variant_list = [[k, v] for k, v in need_phasing_list if v < var_qual] + else: + low_qual_ref_list = sorted(ref_call_pos_list, key=lambda x: x[1])[:int(ref_pct_full * len(ref_call_pos_list))] + low_qual_variant_list = sorted(need_phasing_list, key=lambda x: x[1])[ + :int(var_pct_full * len(need_phasing_list))] + + if call_low_seq_entropy: + candidate_positions = sorted(ref_call_pos_list, key=lambda x: x[1])[ + :int((var_pct_full + seq_entropy_pro) * len(ref_call_pos_list))] + sorted(need_phasing_list, + key=lambda x: x[ + 1])[:int( + (var_pct_full + seq_entropy_pro) * len(need_phasing_list))] + candidate_positions = set([item[0] for item in candidate_positions]) + + candidate_positions_entropy_list = sqeuence_entropy_from(samtools_execute_command=samtools_execute_command, + fasta_file_path=fasta_file_path, + contig_name=contig_name, + candidate_positions=candidate_positions) + + low_sequence_entropy_list = sorted(candidate_positions_entropy_list, key=lambda x: x[1])[ + :int(seq_entropy_pro * len(candidate_positions_entropy_list))] + + # calling with phasing_info_in_bam: select low qual ref and low qual vairant for phasing calling + if phasing_info_in_bam: + logging.info( + '[INFO] Low quality reference calls to be processed in {}: {}'.format(contig_name, len(low_qual_ref_list))) + logging.info( + '[INFO] Low quality variants to be processed in {}: {}'.format(contig_name, len(low_qual_variant_list))) + if call_low_seq_entropy: + logging.info('[INFO] Total low sequence entropy variants to be processed in {}: {}'.format(contig_name, len( + low_sequence_entropy_list))) + + need_phasing_row_list = set( + [item[0] for item in low_qual_ref_list] + [item[0] for item in low_qual_variant_list] + [item[0] for + item in + low_sequence_entropy_list]) + need_phasing_row_list = sorted(list(need_phasing_row_list)) + + if len(need_phasing_row_list) == 0: + print(log_warning( + "[WARNING] Cannot find any low-quality 0/0, 0/1 or 1/1 variant in pileup output in contig {}".format(contig_name))) + + region_num = len(need_phasing_row_list) // split_bed_size + 1 if len( + need_phasing_row_list) % split_bed_size else len(need_phasing_row_list) // split_bed_size + + for idx in range(region_num): + # a windows region for create tensor # samtools mpileup not include last position + split_output = need_phasing_row_list[idx * split_bed_size: (idx + 1) * split_bed_size] + + if platform == 'ilmn': + region_size = param.split_region_size + split_output = [(item // region_size * region_size - param.no_of_positions, + item // region_size * region_size + region_size + param.no_of_positions) for item + in split_output] + else: + split_output = [(item - flankingBaseNum, item + flankingBaseNum + 2) for item in + split_output] + + split_output = sorted(split_output, key=lambda x: x[0]) + + # currently deprecate using ctgName.start_end as file name, which will run similar regions for several times when start and end has slight difference + # output_path = os.path.join(split_folder, '{}.{}_{}'.format(contig_name, split_output[0][0], split_output[-1][1])) + output_path = os.path.join(split_folder, '{}.{}_{}'.format(contig_name, idx, region_num)) + all_full_aln_regions.append(output_path) + with open(output_path, 'w') as output_file: + output_file.write('\n'.join( + ['\t'.join([contig_name, str(x[0] - 1), str(x[1] - 1), ]) for x in + split_output]) + '\n') # bed format + + all_full_aln_regions_path = os.path.join(split_folder, 'FULL_ALN_FILE_{}'.format(contig_name)) + with open(all_full_aln_regions_path, 'w') as output_file: + output_file.write('\n'.join(all_full_aln_regions) + '\n') + return + + for pos, qual in low_qual_ref_list: + need_phasing_set.add(pos) + + # Call variant in all candidate position + elif args.all_alt_fn is not None: + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (args.all_alt_fn))) + for row in unzip_process.stdout: + if row[0] == '#': + continue + columns = row.rstrip().split('\t') + ctg_name, pos = columns[0].split() + pos = int(pos) + if contig_name and contig_name != ctg_name: + continue + need_phasing_set.add(pos) + + need_phasing_row_list = sorted(list(set(need_phasing_set))) + snp_tree = IntervalTree() + hete_snp_row_list = sorted(list(set(variant_dict.keys()).intersection(set(need_phasing_row_list)))) + print('[INFO] Total hete snp with reads support in {}: '.format(contig_name), len(hete_snp_row_list)) + print('[INFO] Total candidates need to be processed in {}: '.format(contig_name), len(need_phasing_row_list)) + + for item in hete_snp_row_list: + snp_tree.addi(item, item + 1) + + region_num = len(need_phasing_row_list) // split_bed_size + 1 if len( + need_phasing_row_list) % split_bed_size else len(need_phasing_row_list) // split_bed_size + for idx in range(region_num): + split_output = need_phasing_row_list[idx * split_bed_size: (idx + 1) * split_bed_size] + + start = split_output[0] + end = split_output[-1] + extend_start, extend_end = start - phasing_window_size, end + phasing_window_size + overlaps = snp_tree.overlap(extend_start, extend_end) + snp_split_out = [] + for overlap in overlaps: + snp_split_out.append((contig_name, overlap[0] - extend_bp - 1 - 1, overlap[0] + 1 + extend_bp - 1, + variant_dict[overlap[0]])) # bed format + split_output = [(contig_name, item - flankingBaseNum - 1, item + flankingBaseNum + 1 - 1) for item in + split_output] # a windows region for create tensor # bed format + + split_output += snp_split_out + split_output = sorted(split_output, key=lambda x: x[1]) + + with open(os.path.join(split_folder, '{}.{}_{}'.format(contig_name, start, end)), 'w') as output_file: + output_file.write('\n'.join(['\t'.join(map(str, x)) for x in split_output]) + '\n') # bed format + + +def main(): + parser = ArgumentParser(description="Select pileup candidates for full alignment") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--split_folder', type=str, default=None, required=True, + help="Path to directory that stores candidate region, required") + + parser.add_argument('--pileup_vcf_fn', type=str, default=None, required=True, + help="Input pileup pileup vcf, required") + + parser.add_argument('--ref_fn', type=str, default=None, + help="Reference fasta file input, required") + + parser.add_argument('--var_pct_full', type=float, default=0.3, + help="Specify an expected percentage of low quality 0/1 and 1/1 variants called in the pileup mode for full-alignment mode calling, default: %(default)f") + + parser.add_argument('--ref_pct_full', type=float, default=0.3, + help="Specify an expected percentage of low quality 0/0 variants called in the pileup mode for full-alignment mode calling, default: %(default)f") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed") + + parser.add_argument('--samtools', type=str, default="samtools", + help="Path to the 'samtools', samtools version >= 1.10 is required, default: %(default)s") + + # options for advanced users + parser.add_argument('--call_low_seq_entropy', type=str2bool, default=False, + help="EXPERIMENTAL: Enable full alignment calling on candidate variants with low sequence entropy") + + parser.add_argument('--seq_entropy_pro', type=float, default=0.05, + help="EXPERIMENTAL: Define the percentage of the candidate variants with the lowest sequence entropy for full alignment calling, default: %(default)f") + + parser.add_argument('--split_bed_size', type=int, default=10000, + help="EXPERIMENTAL: Define the candidate bed size for each split bed file. default: %(default)s") + + # options for debug purpose + parser.add_argument('--phasing_info_in_bam', action='store_false', + help="DEBUG: Skip phasing and use the phasing info provided in the input BAM (HP tag), default: True") + + # options for internal process control + ## Default chr prefix for contig name + parser.add_argument('--chr_prefix', type=str, default='chr', + help=SUPPRESS) + + ## Input phased pileup vcf + parser.add_argument('--phased_vcf_fn', type=str, default=None, + help=SUPPRESS) + + ## Output all alternative candidates path + parser.add_argument('--all_alt_fn', type=str, default=None, + help=SUPPRESS) + + ## Input the file that contains the quality cut-off for selecting low-quality pileup calls for phasing and full-alignment calling + parser.add_argument('--qual_fn', type=str, default=None, + help=SUPPRESS) + + args = parser.parse_args() + + SelectCandidates(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/SelectHetSnp.py b/benchmarks/nn-variant/Clair3/preprocess/SelectHetSnp.py new file mode 100644 index 0000000..d42a066 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/SelectHetSnp.py @@ -0,0 +1,400 @@ +import shlex +import os +import sys +from argparse import ArgumentParser, SUPPRESS +from collections import defaultdict +from shared.intervaltree.intervaltree import IntervalTree + + +import shared.param_f as param +from shared.utils import subprocess_popen + +def FiterHeteSnpPhasing(args): + + """ + Filter heterozygous snp variant for phasing, currently, we only filter snp variant with low quality socore as low + quality variant contains more false positive variant that would lead to a larger minimum error correction loss. + """ + qual_fn = args.qual_fn if args.qual_fn is not None else 'phase_qual' + vcf_fn = args.vcf_fn + var_pct_full = args.var_pct_full + contig_name = args.ctgName + split_folder = args.split_folder + variant_dict = defaultdict(str) + qual_set = defaultdict(int) + found_qual_cut_off = False + header = [] + + #try to find the global quality cut off: + f_qual = os.path.join(split_folder, qual_fn) + if os.path.exists(f_qual): + phase_qual_cut_off = float(open(f_qual, 'r').read().rstrip()) + found_qual_cut_off = True + + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn))) + for row in unzip_process.stdout: + row = row.rstrip() + if row[0] == '#': + header.append(row + '\n') + continue + columns = row.strip().split() + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + ref_base = columns[3] + alt_base = columns[4] + genotype = columns[9].split(':')[0].replace('|', '/') + + if len(ref_base) == 1 and len(alt_base) == 1: + if genotype == '0/1' or genotype=='1/0': + variant_dict[pos] = row + qual = float(columns[5]) + qual_set[pos] = qual + + if found_qual_cut_off: + remove_low_qual_list = [[k,v] for k,v in qual_set.items() if v < phase_qual_cut_off ] + else: + remove_low_qual_list = sorted(qual_set.items(), key=lambda x: x[1])[:int(var_pct_full * len(qual_set))] + for pos, qual in remove_low_qual_list: + del variant_dict[pos] + + print ('[INFO] Total heterozygous SNP positions selected: {}: {}'.format(contig_name, len(variant_dict))) + + f = open(os.path.join(split_folder, '{}.vcf'.format(contig_name)), 'w') + f.write(''.join(header)) + for key,row in sorted(variant_dict.items(), key=lambda x: x[0]): + f.write(row +'\n') + f.close() + + +def FiterHeteSnp_FP(args): + + """ + Filter heterozygous snp variant for calling, this is a testing function to validate various proportion of phasing + effect on full alignment calling, currently for testing only. + """ + + vcf_fn = args.vcf_fn + proportion = args.proportion + chr_prefix = args.chr_prefix + contig_name = args.ctgName + phasing_window_size = param.phasing_window_size + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn))) + split_bed_size = args.split_bed_size + split_folder = args.split_folder + output = [] + snp = [] + need_phasing_list = [] + chr_prefix_length = len(chr_prefix) + variant_dict = defaultdict(str) + for row in unzip_process.stdout: + + if row[0] == '#': + output.append(row.rstrip()) + continue + columns = row.strip().split() + + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + ref_base = columns[3] + alt_base = columns[4] + genotype = columns[9].split(':')[0].replace('|', '/') + qual = int(columns[5]) + if len(ref_base) == 1 and len(alt_base) == 1: + if genotype == '0/1': + snp.append((qual, pos)) + variant_dict[pos] = ref_base + '-' + alt_base + else: + need_phasing_list.append((qual, pos)) + + qual_list = sorted(snp, key=lambda x: x[0]) + print('[INFO] Total hete snp variants:', len(qual_list)) + cut_off_index = int(len(qual_list) * proportion) + hete_snp_row_list = [item[1] for item in qual_list[cut_off_index:]] + print ('[INFO] Total hete snp filter matches:', len(hete_snp_row_list)) + + qual_list = sorted(need_phasing_list, key=lambda x: -x[0]) + cut_off_index = int(len(qual_list) * proportion) + need_phasing_row_list = sorted([item[1] for item in qual_list[cut_off_index:]]) + print('[INFO] Total variants need to be phased:', len(need_phasing_row_list)) + phasing_tree = IntervalTree() + for item_idx, item in enumerate(need_phasing_list): + pos = item[1] + start = pos - phasing_window_size + end = pos + phasing_window_size + phasing_tree.addi(start, end) + + snp_tree = IntervalTree() + for item in hete_snp_row_list: + if len(phasing_tree.at(item)): snp_tree.addi(item, item + 1) + + region_num = len(need_phasing_row_list) // split_bed_size + 1 if len(need_phasing_row_list) % split_bed_size else len(need_phasing_row_list) // split_bed_size + for idx in range(region_num): + split_output = need_phasing_row_list[idx * split_bed_size : (idx+1) * split_bed_size] + start, end = split_output[0] - phasing_window_size, split_output[-1] + phasing_window_size + overlaps = snp_tree.overlap(start, end) + snp_split_out = [] + for overlap in overlaps: + snp_split_out.append((overlap[0], overlap[0] + 1, 1)) + split_output = [(item - param.flankingBaseNum, item+1 + param.flankingBaseNum, 0) for item in split_output] # a windows region for create tensor + print (len(split_output), len(snp_split_out)) + split_output += snp_split_out + split_output = sorted(split_output, key=lambda x: x[0]) + + with open(os.path.join(split_folder, 'split_{}.{}'.format(contig_name[chr_prefix_length:], idx)), 'w') as output_file: + output_file.write('\n'.join(['\t'.join([contig_name, str(x[0]-1), str(x[1]-1), str(x[2]), variant_dict[x[0]]]) for x in split_output]) + '\n') # bed format + +def FiterHeteSnp(args): + + """ + Filter heterozygous snp variant for training, if there are too many candidates for full alignment training, we + would select more in low quality variants, which is more challenging for pileup model to predict and using more + information will benefit calling those variants. + """ + + vcf_fn = args.vcf_fn # true vcf var + alt_fn = args.alt_fn + var_pct_full = args.var_pct_full + ref_pct_full = args.ref_pct_full if args.ref_pct_full is not None else var_pct_full + chr_prefix = args.chr_prefix + contig_name = args.ctgName + phasing_window_size = param.phasing_window_size + chunk_id = args.chunk_id - 1 if args.chunk_id else None # 1-base to 0-base + DEPTH = args.depth + chunk_num = args.chunk_num + sample_name = args.sampleName + split_bed_size = args.split_bed_size + split_folder = args.split_folder + extend_bp = param.extend_bp + phasing_info_in_bam = args.phasing_info_in_bam + need_phasing_list = [] + need_phasing_set = set() + ref_call_pos_list = [] + chr_prefix_length = len(chr_prefix) + variant_dict = defaultdict(str) + realign_window_size = args.realign_window_size if args.realign_window_size is not None else param.flankingBaseNum + candidate_positions = set() + + if vcf_fn and os.path.exists(vcf_fn): + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn))) + for row in unzip_process.stdout: + row =row.rstrip() + if row[0] == '#': + continue + columns = row.strip().split('\t') + + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + ref_base = columns[3] + alt_base = columns[4] + genotype_info = columns[9].split(':') + genotype, phase_set = genotype_info[0], genotype_info[-1] + if '|' not in genotype: #unphasable + continue + variant_dict[pos] = '-'.join([ref_base, alt_base, ('1' if genotype == '0|1' else '2'), phase_set]) + + if alt_fn is not None: + # vcf format + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (alt_fn))) + for row in unzip_process.stdout: + if row[0] == '#': + continue + columns =row.rstrip().split('\t') + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + ref_base = columns[3] + alt_base = columns[4] + qual = float(columns[5]) + + + candidate_positions.add(pos) + #ref_call was marked as '.' after v0.1-r5 + if ref_base == alt_base or alt_base == ".": + ref_call_pos_list.append((pos,qual)) + else: + need_phasing_list.append((pos,qual)) + need_phasing_set.add(pos) + + low_qual_ref_list = sorted(ref_call_pos_list, key=lambda x: x[1])[:int(ref_pct_full * len(ref_call_pos_list))] + low_qual_variant_list = sorted(need_phasing_list, key=lambda x: x[1])[:int(var_pct_full * len(need_phasing_list))] + + #calling with phasing_info_in_bam: select low qual ref and low qual vairant for phasing calling + if phasing_info_in_bam: + print('[INFO] {} {} total low qual ref calling to process: {}'.format(sample_name, contig_name, len(low_qual_ref_list))) + print('[INFO] {} {} total low qual variant calling to process: {}'.format(sample_name, contig_name, len(low_qual_variant_list))) + + need_phasing_row_list = set([item[0] for item in low_qual_ref_list] + [item[0] for item in low_qual_variant_list]) + need_phasing_row_list = sorted(list(need_phasing_row_list)) + + if chunk_num: + all_candidate_size = len(need_phasing_row_list) + chunk_size = all_candidate_size // chunk_num + 1 if all_candidate_size % chunk_num else all_candidate_size // chunk_num + + for chunk_idx in range(chunk_num): + start_pos = chunk_idx * chunk_size + end_pos = min(start_pos + chunk_size, all_candidate_size) + split_output = need_phasing_row_list[start_pos:end_pos] + split_output = [(item - realign_window_size, item + realign_window_size + 2) for item in + split_output] # a windows region for create tensor # samtools mpileup not include last position + + split_output = sorted(split_output, key=lambda x: x[0]) + with open(os.path.join(split_folder, + '{}_{}_{}_{}'.format(sample_name, DEPTH, contig_name[chr_prefix_length:], chunk_idx+1)), # zero-base to one-base + 'w') as output_file: + output_file.write('\n'.join( + ['\t'.join([contig_name, str(x[0] - 1), str(x[1] - 1), ]) for x in + split_output]) + '\n') # bed format + return + + region_num = len(need_phasing_row_list) // split_bed_size + 1 if len( + need_phasing_row_list) % split_bed_size else len(need_phasing_row_list) // split_bed_size + for idx in range(region_num): + split_output = need_phasing_row_list[idx * split_bed_size: (idx + 1) * split_bed_size] + split_output = [(item - realign_window_size, item + realign_window_size + 2) for item in + split_output] # a windows region for create tensor # samtools mpileup not include last position + + split_output = sorted(split_output, key=lambda x: x[0]) + + with open(os.path.join(split_folder, '{}.{}_{}'.format(contig_name[chr_prefix_length:], split_output[0][0], split_output[-1][1])), + 'w') as output_file: + output_file.write('\n'.join( + ['\t'.join([contig_name, str(x[0] - 1), str(x[1] - 1),]) for x in + split_output]) + '\n') # bed format + return + + for pos, qual in low_qual_ref_list: + need_phasing_set.add(pos) + + # train or call in all_pos + elif args.all_alt_fn is not None: + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (args.all_alt_fn))) + for row in unzip_process.stdout: + if row[0] == '#': + continue + columns = row.rstrip().split('\t') + ctg_name, pos = columns[0].split() + pos = int(pos) + if contig_name and contig_name != ctg_name: + continue + need_phasing_set.add(pos) + + need_phasing_row_list = sorted(list(set(need_phasing_set))) + snp_tree = IntervalTree() + hete_snp_row_list = sorted(list(set(variant_dict.keys()).intersection(set(need_phasing_row_list)))) + print('[INFO] Total hete snp with reads support in {}: '.format(contig_name), len(hete_snp_row_list)) + print('[INFO] Total candidates need to be processed in {}: '.format(contig_name), len(need_phasing_row_list)) + + for item in hete_snp_row_list: + snp_tree.addi(item, item + 1) + + region_num = len(need_phasing_row_list) // split_bed_size + 1 if len(need_phasing_row_list) % split_bed_size else len(need_phasing_row_list) // split_bed_size + for idx in range(region_num): + split_output = need_phasing_row_list[idx * split_bed_size : (idx+1) * split_bed_size] + + start = split_output[0] + end = split_output[-1] + extend_start, extend_end = start - phasing_window_size, end + phasing_window_size + overlaps = snp_tree.overlap(extend_start, extend_end) + snp_split_out = [] + for overlap in overlaps: + snp_split_out.append((contig_name, overlap[0] - extend_bp - 1 - 1, overlap[0] + 1 + extend_bp - 1, variant_dict[overlap[0]]))# bed format + split_output = [(contig_name, item - realign_window_size-1, item+realign_window_size+1-1) for item in split_output] # a windows region for create tensor # bed format + + split_output += snp_split_out + split_output = sorted(split_output, key=lambda x: x[1]) + + with open(os.path.join(split_folder, '{}.{}_{}'.format(contig_name[chr_prefix_length:], start, end)), 'w') as output_file: + output_file.write('\n'.join(['\t'.join(map(str, x)) for x in split_output]) + '\n') # bed format + +def main(): + parser = ArgumentParser(description="Select heterozygous snp candidates for WhatsHap phasing") + + parser.add_argument('--split_folder', type=str, default=None, + help="Path to directory that stores small bed region for raw alignment. (default: %(default)s)") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Path of the input vcf file. (default: %(default)s)") + + parser.add_argument('--var_pct_full', type=float, default=0.3, + help="Default variant call proportion for raw alignment or remove low quality proportion for whatshap phasing. (default: %(default)f)") + + parser.add_argument('--ref_pct_full', type=float, default=None, + help="Default reference call proportion for raw alignment or remove low quality proportion for whatshap phasing. (default: %(default)f)") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed, default: %(default)s") + + parser.add_argument('--phase', action='store_false', + help="Only select hete candidates for phasing, default: True") + + parser.add_argument('--sampleName', type=str, default="", + help="Define the sample name to be shown in the VCF file, optional") + + # options for debug purpose + parser.add_argument('--phasing_info_in_bam', action='store_true', + help="DEBUG: Input bam or sam have phasing info in HP tag, default: False") + + parser.add_argument('--split_bed_size', type=int, default=1000, + help="DEBUG: Default split bed size for parallel excution, default: %(default)s") + + parser.add_argument('--calling', type=int, default=0, + help="DEBUG: Path of the output folder, default: %(default)s") + + parser.add_argument('--realign_window_size', type=int, default=None, + help="DEBUG: The window size of read realignment, work with need_realignment") + + parser.add_argument('--split_region_size', type=int, default=40000000, + help="DEBUG: Vcf phasing split_region_size default: %(default)s") + + # options for internal process control + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + ## Output all alternative candidates path + parser.add_argument('--all_alt_fn', type=str, default=None, + help=SUPPRESS) + + ## Default chr prefix for contig name + parser.add_argument('--chr_prefix', type=str, default='chr', + help=SUPPRESS) + + ## Default subsample depth for subsample bam file, 1000 means no subsampling + parser.add_argument('--depth', type=int, default=1000, + help=SUPPRESS) + + ## Path of provided alternative file + parser.add_argument('--alt_fn', type=str, default=None, + help=SUPPRESS) + + ## Input the file that contains the quality cut-off for selecting low-quality pileup calls for phasing and full-alignment calling + parser.add_argument('--qual_fn', type=str, default=None, + help=SUPPRESS) + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + # + if args.phase: + FiterHeteSnpPhasing(args) + elif args.calling == 1: + FiterHeteSnp_FP(args) + else: + FiterHeteSnp(args) + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/SelectQual.py b/benchmarks/nn-variant/Clair3/preprocess/SelectQual.py new file mode 100644 index 0000000..1be2c22 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/SelectQual.py @@ -0,0 +1,151 @@ +from sys import stdin +from argparse import ArgumentParser, SUPPRESS +import os + +from shared.utils import file_path_from, log_warning + +major_contigs_order = ["chr" + str(a) for a in list(range(1, 23)) + ["X", "Y"]] + [str(a) for a in + list(range(1, 23)) + ["X", "Y"]] + +def select_phase_qual_from_stdin(args): + + """ + Select a global quality cut-off for phasing and reads haplotag. + """ + qual_fn = args.qual_fn if args.qual_fn is not None else "phase_qual" + var_pct_full = args.var_pct_full + var_pct_phasing = args.var_pct_phasing + low_qual_hete_var_pct = 1 - var_pct_phasing if var_pct_phasing is not None else var_pct_full + phase_qual_list = [] + for row in stdin: + if row[0] == '#': + continue + row = row.rstrip().split() + ref_base, alt_base = row[3], row[4] + # select heterozygous snp only + if len(ref_base) != 1 or len(alt_base) != 1: + continue + qual, gt_info = row[5], row[9] + genotype = gt_info.split(':')[0] + if genotype == '0/1': + phase_qual_list.append(float(qual)) + + # in phase mode, var_pct_full is the proportion of low-quality heterozygous variants to be discarded for whatshap phasing + phase_qual_list = sorted(phase_qual_list) + low_phase_qual_list = phase_qual_list[:int(low_qual_hete_var_pct * len(phase_qual_list))] + if len(low_phase_qual_list) == 0: + print(log_warning( + "[WARNING] Cannot find any 0/1 variant in pileup output using variant quality cut-off proportion: {}, total heterozygous variants: {}".format( + low_qual_hete_var_pct, len(low_phase_qual_list)))) + print(log_warning("[WARNING] Set low variant quality score cut-off to 0.0")) + qual_cut_off = 0.0 + else: + qual_cut_off = low_phase_qual_list[-1] + print ('[INFO] Select heterozygous pileup variants exceeding phasing quality cutoff {}'.format(round(qual_cut_off), 0)) + + if args.output_fn: + with open(os.path.join(args.output_fn, qual_fn), 'w') as output: + output.write(str(qual_cut_off)) + + + +def select_qual_from_stdin(args): + + """ + Select a global quality cut-off for full alignment calling from pileup vcf file. False positive pileup variants + and true variants missed by pileup calling would mostly have low quality score (reference quality score for missing + variants), so only use a proportion of low quality variants for full alignment while maintain high quality pileup + output, as full alignment calling is substantially slower than pileup calling. + """ + var_pct_full = args.var_pct_full + qual_fn = args.qual_fn if args.qual_fn is not None else "qual" + vcf_fn = file_path_from(args.vcf_fn) + ref_pct_full = args.ref_pct_full if args.ref_pct_full else var_pct_full + # for efficiency, we use a maximum 30% reference candidates proportion for full-alignment calling, which is almost cover all false negative candidates + # for ont platform, we set a default 10% reference candidates proportion for full-alignment calling unless a known vcf file is provided (genotyping mode) + # directly set default value in run_clair3.sh from v0.1-r5 + # ref_pct_full = 0.1 if args.platform == 'ont' else ref_pct_full + # ref_pct_full = min(ref_pct_full, 0.3) + + variant_qual_list = [] + ref_qual_list = [] + for row in stdin: + if row[0] == '#': + continue + row = row.rstrip().split() + + qual, gt_info = row[5], row[9] + genotype = gt_info.split(':')[0] + if genotype == '0/0': + ref_qual_list.append(float(qual)) + else: + variant_qual_list.append(float(qual)) + + ref_qual_list = sorted(ref_qual_list) + variant_qual_list = sorted(variant_qual_list) + low_variant_qual_list = variant_qual_list[:int(var_pct_full * len(variant_qual_list))] + if len(low_variant_qual_list) == 0: + print(log_warning( + "[WARNING] Cannot find any low-quality 0/1 or 1/1 variant in pileup output using variant quality cut-off proportion: {}, total variants: {}".format( + var_pct_full, len(variant_qual_list)))) + print(log_warning("[WARNING] Set low variant quality score cut-off to 0.0")) + var_qual_cut_off = 0.0 + else: + var_qual_cut_off = low_variant_qual_list[-1] + + # If a known vcf file is provided, use user-defined proportion + low_ref_qual_list = ref_qual_list[:int(ref_pct_full * len(ref_qual_list))] if vcf_fn is None else ref_qual_list[:int(args.ref_pct_full * len(ref_qual_list))] + if len(low_ref_qual_list) == 0: + print(log_warning( + "[WARNING] Cannot find any low-quality 0/0 reference calls in pileup output using reference quality cut-off proportion: {}, total reference calls: {}".format( + ref_pct_full, len(ref_qual_list)))) + print(log_warning("[WARNING] Set low reference quality score cut-off to 0.0")) + ref_qual_cut_off = 0.0 + else: + ref_qual_cut_off = low_ref_qual_list[-1] + print ('[INFO] Set variants quality cutoff {}'.format(round(var_qual_cut_off, 0))) + print ('[INFO] Set reference calls quality cutoff {}'.format(round(ref_qual_cut_off, 0))) + + if args.output_fn: + with open(os.path.join(args.output_fn, qual_fn), 'w') as output: + output.write(str(var_qual_cut_off) + ' ' + str(ref_qual_cut_off)) + + +def main(): + parser = ArgumentParser(description="Select quality cut-off for phasing and full alignment") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--output_fn', type=str, default=None, required=True, + help="Define the output folder, required") + + parser.add_argument('--var_pct_full', type=float, default=0.3, + help="Specify an expected percentage of low quality 0/1 and 1/1 variants called in the pileup mode for full-alignment mode calling, default: 0.3") + + parser.add_argument('--ref_pct_full', type=float, default=0.3, + help="Specify an expected percentage of low quality 0/0 variants called in the pileup mode for full-alignment mode calling, default: 0.3 for ilmn and hifi, 0.1 for ont") + + parser.add_argument('--var_pct_phasing', type=float, default=0.7, + help="Specify an expected percentage of high quality 0/1 variants used in WhatsHap phasing, default: 0.8 for ont guppy5 and 0.7 for other platforms") + + parser.add_argument('--phase', action='store_true', + help="Select only heterozygous candidates for phasing or not, default: False") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") + + # options for internal process control + ## Input the file that contains the quality cut-off for selecting low-quality pileup calls for phasing and full-alignment calling + parser.add_argument('--qual_fn', type=str, default=None, + help=SUPPRESS) + + args = parser.parse_args() + if args.phase: + select_phase_qual_from_stdin(args) + else: + select_qual_from_stdin(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/SortVcf.py b/benchmarks/nn-variant/Clair3/preprocess/SortVcf.py new file mode 100644 index 0000000..e08597c --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/SortVcf.py @@ -0,0 +1,271 @@ +import os +import subprocess +import shlex +from sys import stdin, exit +from argparse import ArgumentParser +from collections import defaultdict + + +from shared.utils import log_error, log_warning, file_path_from, subprocess_popen +major_contigs_order = ["chr" + str(a) for a in list(range(1, 23)) + ["X", "Y"]] + [str(a) for a in + list(range(1, 23)) + ["X", "Y"]] + + +def compress_index_vcf(input_vcf): + # use bgzip to compress vcf -> vcf.gz + # use tabix to index vcf.gz + proc = subprocess.run('bgzip -f {}'.format(input_vcf), shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + proc = subprocess.run('tabix -f -p vcf {}.gz'.format(input_vcf), shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + +def output_header(output_fn, reference_file_path, sample_name='SAMPLE'): + output_file = open(output_fn, "w") + from textwrap import dedent + output_file.write(dedent("""\ + ##fileformat=VCFv4.2 + ##FILTER= + ##FILTER= + ##FILTER= + ##INFO= + ##INFO= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT=""" + ) + '\n') + + if reference_file_path is not None: + reference_index_file_path = file_path_from(reference_file_path, suffix=".fai", exit_on_not_found=True, sep='.') + with open(reference_index_file_path, "r") as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + contig_name, contig_size = columns[0], columns[1] + output_file.write(("##contig=" % (contig_name, contig_size) + '\n')) + + output_file.write('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s' % (sample_name)) + output_file.close() + +def print_calling_step(output_fn=""): + + merge_output = os.path.join(os.path.dirname(output_fn), 'merge_output.vcf.gz') + pileup_output = os.path.join(os.path.dirname(output_fn), 'pileup.vcf.gz') + + print (log_warning("[WARNING] Copying pileup.vcf.gz to {}".format(merge_output))) + subprocess.run('cp {} {}'.format(pileup_output, merge_output), shell=True, stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + +def check_header_in_gvcf(header, contigs_list): + # Only output the contigs processed to be consistent with GATK + # Contig format: ##contig= + + update_header = [] + for row_id, row in enumerate(header): + if row.startswith("##contig="): + contig = row.split(',')[0].split('=')[2] + if contig not in contigs_list: + continue + update_header.append(row) + + return update_header + +def sort_vcf_from_stdin(args): + """ + Sort vcf file according to variants start position and contig name. + """ + + row_count = 0 + header = [] + contig_dict = defaultdict(defaultdict) + no_vcf_output = True + for row in stdin: + row_count += 1 + if row[0] == '#': + if row not in header: + header.append(row) + continue + # use the first vcf header + columns = row.strip().split(maxsplit=3) + ctg_name, pos = columns[0], columns[1] + contig_dict[ctg_name][int(pos)] = row + no_vcf_output = False + if row_count == 0: + print(log_warning("[WARNING] No vcf file found, please check the setting")) + if no_vcf_output: + print(log_warning("[WARNING] No variant found, please check the setting")) + + contigs_order = major_contigs_order + list(contig_dict.keys()) + contigs_order_list = sorted(contig_dict.keys(), key=lambda x: contigs_order.index(x)) + with open(args.output_fn, 'w') as output: + output.write(''.join(header)) + for contig in contigs_order_list: + all_pos = sorted(contig_dict[contig].keys()) + for pos in all_pos: + output.write(contig_dict[contig][pos]) + + +def sort_vcf_from(args): + """ + Sort vcf file from providing vcf filename prefix. + """ + output_fn = args.output_fn + input_dir = args.input_dir + vcf_fn_prefix = args.vcf_fn_prefix + vcf_fn_suffix = args.vcf_fn_suffix + sample_name = args.sampleName + ref_fn = args.ref_fn + contigs_fn = args.contigs_fn + + if not os.path.exists(input_dir): + exit(log_error("[ERROR] Input directory: {} not exists!").format(input_dir)) + all_files = os.listdir(input_dir) + + if vcf_fn_prefix is not None: + all_files = [item for item in all_files if item.startswith(vcf_fn_prefix)] + if len(all_files) == 0: + output_header(output_fn=output_fn, reference_file_path=ref_fn, sample_name=sample_name) + print (log_warning( + "[WARNING] No vcf file found with prefix:{}/{}, output empty vcf file".format(input_dir,vcf_fn_prefix))) + compress_index_vcf(output_fn) + print_calling_step(output_fn=output_fn) + return + + if vcf_fn_suffix is not None: + all_files = [item for item in all_files if item.endswith(vcf_fn_suffix)] + if len(all_files) == 0: + output_header(output_fn=output_fn, reference_file_path=ref_fn, sample_name=sample_name) + print (log_warning( + "[WARNING] No vcf file found with suffix:{}/{}, output empty vcf file".format(input_dir,vcf_fn_prefix))) + compress_index_vcf(output_fn) + print_calling_step(output_fn=output_fn) + return + + all_contigs_list = [] + if contigs_fn and os.path.exists(contigs_fn): + with open(contigs_fn) as f: + all_contigs_list = [item.rstrip() for item in f.readlines()] + else: + exit(log_error("[ERROR] Cannot find contig file {}. Exit!").format(contigs_fn)) + + contigs_order = major_contigs_order + all_contigs_list + contigs_order_list = sorted(all_contigs_list, key=lambda x: contigs_order.index(x)) + + row_count = 0 + header = [] + no_vcf_output = True + need_write_header = True + + # only compress intermediate gvcf using lz4 output and keep final gvcf in bgzip format + output_bgzip_gvcf = vcf_fn_suffix == '.gvcf' + compress_gvcf = 'gvcf' in vcf_fn_suffix + if compress_gvcf: + lz4_path = subprocess.run("which lz4", stdout=subprocess.PIPE, shell=True).stdout.decode().rstrip() + compress_gvcf = True if lz4_path != "" else False + is_lz4_format = compress_gvcf + compress_gvcf_output = compress_gvcf and not output_bgzip_gvcf + if compress_gvcf_output: + write_fpo = open(output_fn, 'w') + write_proc = subprocess_popen(shlex.split("lz4 -c"), stdin=subprocess.PIPE, stdout=write_fpo, stderr=subprocess.DEVNULL) + output = write_proc.stdin + else: + output = open(output_fn, 'w') + + for contig in contigs_order_list: + contig_dict = defaultdict(str) + contig_vcf_fns = [fn for fn in all_files if contig in fn] + for vcf_fn in contig_vcf_fns: + file = os.path.join(input_dir, vcf_fn) + if is_lz4_format: + read_proc = subprocess_popen(shlex.split("{} {}".format("lz4 -fdc", file)), stderr=subprocess.DEVNULL) + fn = read_proc.stdout + else: + fn = open(file, 'r') + for row in fn: + row_count += 1 + if row[0] == '#': + # skip phasing command line only occur with --enable_phasing, otherwise would lead to hap.py evaluation failure + if row.startswith('##commandline='): + continue + if row not in header: + header.append(row) + continue + # use the first vcf header + columns = row.strip().split(maxsplit=3) + ctg_name, pos = columns[0], columns[1] + # skip vcf file sharing same contig prefix, ie, chr1 and chr11 + if ctg_name != contig: + break + contig_dict[int(pos)] = row + no_vcf_output = False + fn.close() + if is_lz4_format: + read_proc.wait() + if need_write_header and len(header): + if output_bgzip_gvcf: + header = check_header_in_gvcf(header=header, contigs_list=all_contigs_list) + output.write(''.join(header)) + need_write_header = False + all_pos = sorted(contig_dict.keys()) + for pos in all_pos: + output.write(contig_dict[pos]) + + if compress_gvcf_output: + write_proc.stdin.close() + write_proc.wait() + write_fpo.close() + return + else: + output.close() + + if row_count == 0: + print (log_warning("[WARNING] No vcf file found, output empty vcf file")) + output_header(output_fn=output_fn, reference_file_path=ref_fn, sample_name=sample_name) + compress_index_vcf(output_fn) + print_calling_step(output_fn=output_fn) + return + if no_vcf_output: + output_header(output_fn=output_fn, reference_file_path=ref_fn, sample_name=sample_name) + print (log_warning("[WARNING] No variant found, output empty vcf file")) + compress_index_vcf(output_fn) + print_calling_step(output_fn=output_fn) + return + + if vcf_fn_suffix == ".tmp.gvcf": + return + if vcf_fn_suffix == ".gvcf": + print("[INFO] Need some time to compress and index GVCF file...") + compress_index_vcf(output_fn) + + +def main(): + parser = ArgumentParser(description="Sort a VCF file according to contig name and starting position") + + parser.add_argument('--output_fn', type=str, default=None, required=True, + help="Output VCF filename, required") + + parser.add_argument('--input_dir', type=str, default=None, + help="Input directory") + + parser.add_argument('--vcf_fn_prefix', type=str, default=None, + help="Input vcf filename prefix") + + parser.add_argument('--vcf_fn_suffix', type=str, default='.vcf', + help="Input vcf filename suffix") + + parser.add_argument('--ref_fn', type=str, default=None, + help="Reference fasta file input") + + parser.add_argument('--sampleName', type=str, default="SAMPLE", + help="Define the sample name to be shown in the VCF file, optional") + + parser.add_argument('--contigs_fn', type=str, default=None, + help="Contigs file with all processing contigs") + + args = parser.parse_args() + if args.input_dir is None: + sort_vcf_from_stdin(args) + else: + sort_vcf_from(args) + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/SplitExtendBed.py b/benchmarks/nn-variant/Clair3/preprocess/SplitExtendBed.py new file mode 100644 index 0000000..5d444c4 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/SplitExtendBed.py @@ -0,0 +1,89 @@ +import sys +import shlex +from argparse import ArgumentParser +import shared.param_p as param +from shared.utils import subprocess_popen + +def split_extend_bed(args): + + """ + Split bed file regions according to the contig name and extend bed region with no_of_positions = + flankingBaseNum + 1 + flankingBaseNum, which allow samtools mpileup submodule to scan the flanking windows. + """ + + bed_fn = args.bed_fn + output_fn = args.output_fn + contig_name = args.ctgName + region_start = args.ctgStart + region_end = args.ctgEnd + expand_region_size = args.expand_region_size + if bed_fn is None: + return + output = [] + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_fn))) + pre_end, pre_start = -1, -1 + + for row in unzip_process.stdout: + + if row[0] == '#': + continue + columns = row.strip().split() + ctg_name = columns[0] + if contig_name != None and ctg_name != contig_name: + continue + ctg_start, ctg_end = int(columns[1]), int(columns[2]) + if region_start and ctg_end < region_start: + continue + if region_end and ctg_start > region_end: + break + if pre_start == -1: + pre_start = ctg_start - expand_region_size + pre_end = ctg_end + expand_region_size + continue + if pre_end >= ctg_start - expand_region_size: + pre_end = ctg_end + expand_region_size + continue + else: + output.append(' '.join([contig_name, str(pre_start), str(pre_end)])) + pre_start = ctg_start - expand_region_size + pre_end = ctg_end + expand_region_size + + with open(output_fn, 'w') as output_file: + output_file.write('\n'.join(output)) + + unzip_process.stdout.close() + unzip_process.wait() + + +def main(): + parser = ArgumentParser(description="Extend bed region for pileup calling") + + parser.add_argument('--output_fn', type=str, default=None, + help="Path to directory that stores small bins, default: %(default)s)" + ) + parser.add_argument('--bed_fn', type=str, default=None, + help="Path of the output folder, default: %(default)s") + + parser.add_argument('--expand_region_size', type=int, default=param.no_of_positions, + help="Expand region size for each bed region, default: %(default)s") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of sequence to be processed, default: %(default)s") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed") + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + split_extend_bed(args) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmarks/nn-variant/Clair3/preprocess/Tensor2Bin.py b/benchmarks/nn-variant/Clair3/preprocess/Tensor2Bin.py new file mode 100644 index 0000000..bf86ca0 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/Tensor2Bin.py @@ -0,0 +1,85 @@ +import sys +import logging +from argparse import ArgumentParser, SUPPRESS + +import clair3.utils as utils + +logging.basicConfig(format='%(message)s', level=logging.INFO) + +def Run(args): + + + utils.setup_environment() + logging.info("Loading the dataset ...") + + utils.get_training_array( + tensor_fn=args.tensor_fn, + var_fn=args.var_fn, + bed_fn=args.bed_fn, + bin_fn=args.bin_fn, + shuffle=args.shuffle, + is_allow_duplicate_chr_pos=args.allow_duplicate_chr_pos, + chunk_id=args.chunk_id-1 if args.chunk_id else None, # 1-base to 0-base + chunk_num=args.chunk_num, + pileup=args.pileup, + platform=args.platform, + maximum_non_variant_ratio=args.maximum_non_variant_ratio, + candidate_details_fn_prefix=args.candidate_details_fn_prefix) + logging.info("Finish!") + + +def main(): + parser = ArgumentParser(description="Combine the variant and non-variant tensors and convert them to a binary") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--tensor_fn', type=str, default="PIPE", + help="Tensor input") + + parser.add_argument('--candidate_details_fn_prefix', type=str, default=None, + help="Candidate details input (unused, retained for compatibility)") + + parser.add_argument('--var_fn', type=str, default=None, required=True, + help="Truth variants list input, required") + + parser.add_argument('--bin_fn', type=str, default=None, required=True, + help="Output a binary tensor file, required") + + parser.add_argument('--bed_fn', type=str, default=None, + help="High confident genome regions input in the BED format") + + parser.add_argument('--shuffle', action='store_true', + help="Shuffle the inputs") + + parser.add_argument('--allow_duplicate_chr_pos', action='store_true', + help="Allow duplicated chromosome:position in the tensor input") + + # options for internal process control + ## In pileup mode or not (full alignment mode), default: False + parser.add_argument('--pileup', action='store_true', + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + ## Maximum non-variant ratio against variant in the training data + parser.add_argument('--maximum_non_variant_ratio', type=float, default=None, + help=SUPPRESS) + + args = parser.parse_args() + + if len(sys.argv[1:]) == 0: + parser.print_help() + sys.exit(1) + + Run(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/Clair3/preprocess/UnifyRepresentation.py b/benchmarks/nn-variant/Clair3/preprocess/UnifyRepresentation.py new file mode 100644 index 0000000..92a2ca6 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/UnifyRepresentation.py @@ -0,0 +1,1456 @@ +import collections +import heapq +import itertools +import json +import shlex +import signal +import sys +import os + +from collections import Counter +from argparse import ArgumentParser, SUPPRESS +from collections import defaultdict +from subprocess import PIPE, Popen +from shared.command_options import ( + CommandOption, + CommandOptionWithNoValue, + ExecuteCommand, + command_string_from, + command_option_from +) +from shared.utils import file_path_from, executable_command_string_from, subprocess_popen, str2bool, log_warning + +from shared.interval_tree import bed_tree_from, is_region_in +from shared.utils import subprocess_popen, region_from, reference_sequence_from +import shared.param_p as param + +class InstancesClass(object): + def __init__(self): + self.create_tensor = None + + def poll(self): + self.create_tensor.poll() + + +c = InstancesClass() + +def check_return_code(signum, frame): + c.poll() + if c.create_tensor.returncode != None and c.create_tensor.returncode != 0: + c.compress_tensor.kill() + sys.exit("CreateTensor.py exited with exceptions. Exiting...") + + if c.create_tensor.returncode == None: + signal.alarm(5) + + +extended_window_size = 200 +region_size = 50000 +reference_region_size = region_size * 2 +extend_bp = 100 +reference_allele_gap = 0.9 + +def subprocess_popen(args, stdin=None, stdout=PIPE, stderr=sys.stderr, bufsize=8388608): + return Popen(args, stdin=stdin, stdout=stdout, stderr=sys.stderr, bufsize=bufsize, universal_newlines=True) + +class Reference(object): + """ + Reference region query with given reference start and reference end, we cocnat the reference base with altertive base + with reference base to generate read-level haplotype. + """ + def __init__(self, seq, start, reference_sequence, reference_start): + self.start = start + self.end = start + len(seq) + self.seq = seq + self.reference_sequence = reference_sequence + self.reference_start = reference_start + + def query(self, start, end): + return self.seq[start - self.start:end - self.start] + + +def all_genotypes_combination(variant, alt_dict, variant_dict): + """ + Enumerate true variant site and candidate site genotype combination and find read. For a phased confident site, we + only enumerate one genotype with confident flag. + """ + + if variant.type == 'candidate': + num_ref_and_alts = len(variant.variant.alternate_bases) + + if variant.variant.start in alt_dict and alt_dict[variant.variant.start].phased_genotype: + return 1 + elif variant.variant.start in alt_dict and alt_dict[variant.variant.start].confident_variant: + return 2 + return (num_ref_and_alts + 1) * num_ref_and_alts / 2 + else: + if variant.variant.start in alt_dict and alt_dict[ + variant.variant.start].phased_genotype and variant.variant.start in variant_dict: + return 1 + return len(variant.variant.alternate_bases) + +def unique_genotypes_selection(genotype_options): + """ + Extend true two haplotypes according to the chosen genotype and only save the haplotype set with distinct haplotype, + if two haplotypes set match in either hap1(1) == hap2(1) or hap1(1) = hap2(1), remove the duplication to reduce + massive comparison. + """ + + genotypes_list = [] + existed_genotypes_path = set() + for genotype_pair in itertools.product(*genotype_options): + g1, g2 = "", "" + for h1, h2 in genotype_pair: + g1 += str(h1) + g2 += str(h2) + genotype_tupe = (g1, g2) + genotype_tupe_reverse = (g2, g1) + if genotype_tupe in existed_genotypes_path or genotype_tupe_reverse in existed_genotypes_path: + continue + existed_genotypes_path.add(genotype_tupe) + existed_genotypes_path.add(genotype_tupe_reverse) + genotypes_list.append(genotype_pair) + return genotypes_list + +def find_read_support(variants, ref, variant_type, max_calculate_count, variant_dict=None, read_name_info_dict=None, truths=None, alt_dict=None, no_match_found=False): + """ + Find read-level support for each matched haplotype, we only extended the reference sequence with the alternative base, + and discard low allele frequency systematic error. + """ + read_seqs_counter = None + if variant_type == 'candidate': + all_read_set = set() + all_read_seq_dict = defaultdict(str) + for v in variants: + pos = v.start + read_name_set = alt_dict[pos].read_name_set + all_read_set = all_read_set.union(read_name_set) + for read_name in all_read_set: + ref_start = ref.start + ref_end = ref.end + read_seq = read_name_info_dict[read_name].seq + if not len(read_seq): + continue + ref_offset, alt_offset, pre_end = 0, 0, 0 + for start, end, seq in read_seq: + if end < ref_start or start > ref_end: + continue + start_offset = start - ref_start + end_offset = end - ref_start + if start_offset >= pre_end: + all_read_seq_dict[read_name] += ref.seq[pre_end:start_offset] + seq + pre_end = end_offset + if pre_end < len(ref.seq): + all_read_seq_dict[read_name] += ref.seq[pre_end:] + read_seqs_counter = Counter(all_read_seq_dict.values()) + + def extend_genotype(variants_and_genotypes, next_pos_list): + + """ + We give two iterator of two haplotype start position and update them separately, which allow haplotype extension + more flexible when one of the haplotype has insertion or deletion, if the start position reach our provided region, + then the extension will stop. + """ + if next_pos_list is None or None in next_pos_list: + pass + if not variants_and_genotypes: + hap1_last_pos, hap2_last_pos = next_pos_list + + rest_seq_1 = ref.query(hap1_last_pos, + ref.end) if hap1_last_pos != ref.end and hap1_last_pos else "" + rest_seq_2 = ref.query(hap2_last_pos, + ref.end) if hap2_last_pos != ref.end and hap2_last_pos else "" + yield [rest_seq_1, rest_seq_2] # add last padding ref base + else: + current_variant, remaining_variants = [variants_and_genotypes[0]], variants_and_genotypes[1:] + prefix_seqs_list, next_pos_list = find_seqs( + current_variant, next_pos_list, ref) + prefix_seqs_list = list(prefix_seqs_list) + + if not prefix_seqs_list or next_pos_list is None: + pass + + for seqs in extend_genotype(remaining_variants, next_pos_list): + yield [prefix_seqs_list[0] + seqs[0], prefix_seqs_list[1] + seqs[1]] + + def extend_genotypes(variants_and_genotypes, next_pos): + + try: + for r in extend_genotype(variants_and_genotypes, next_pos): + yield r + except: + pass + + genotypes_combinations = genotypes_combination(variants, variant_type, variant_dict, max_calculate_count, truths, alt_dict=alt_dict, no_match=no_match_found) + genotypes_seqs_dict = defaultdict(list) + genotypes_list = unique_genotypes_selection(genotypes_combinations) + GT = collections.namedtuple('GT', ['variant', 'genotypes']) + for genotypes in genotypes_list: + variants_and_genotypes = [GT(v, g) for v, g in zip(variants, genotypes)] + for seqs in extend_genotypes(variants_and_genotypes, [ref.start, ref.start]): + genotypes_seqs_dict[frozenset(seqs)].append(genotypes) + return genotypes_seqs_dict, read_seqs_counter + +def remove_common_suffix(ref_base, alt_base): + """ + For each haploid match, we simplify the reference base and alternative base and remove their common suffix characters. + """ + + min_length = min(len(ref_base) - 1, min([len(item) - 1 for item in alt_base])) # keep at least one base + prefix = ref_base[::-1] + for string in alt_base: + string = string[::-1] + while string[:len(prefix)] != prefix and prefix: + prefix = prefix[:len(prefix) - 1] + if not prefix: + break + res_length = len(prefix) + if res_length > min_length: + return ref_base, alt_base + return ref_base[:len(ref_base) - res_length], [item[:len(item) - res_length] for item in alt_base] + + return ref_base[-min_length], [item[-min_length] for item in alt_base] + + +def has_multi_in_truths(truths, start=None): + for t in truths: + if len(t.alternate_bases) > 1: + return True + return False + +def count_combination(genotypes_combinations): + """ + Calculate the Cartesian product required for all genotype combinations + """ + product = 1 + for gc in genotypes_combinations: + product *= len(gc) + return product + +def genotypes_combination(variants, variant_type, variant_dict, max_calculate_count, truths=None, alt_dict=None, no_match=False, simplfy_combination=False): + + """ + Calculate genotype combination for haplotype set generation. For a locked confident variant or candidate site, we directly + extend with its phased genotype, while for other candidates, we need to assign with a missing flag to skip the candidate, or a + pass flag keep the genotype for further extension. + """ + + if no_match: + [{(0, 0)}] * len(variants) + if variant_type == 'truth': + output = [] + for v in variants: + if v: + gt = tuple(v.genotype) + + is_confident_pos = (variant_dict[v.start].confident_variant and variant_dict[v.start].phased_genotype) if v.start in variant_dict else False + output.append({tuple(variant_dict[v.start].phased_genotype)} if is_confident_pos else {(0, 0), tuple(gt), tuple(list(gt)[::-1])}) + else: + output.append({(-1, -1)}) + return output + elif variant_type == 'candidate': + output = [] + has_multi_in_truth = has_multi_in_truths(truths) + + for v in variants: + is_confident_pos = v.start in alt_dict and alt_dict[v.start].phased_genotype + pos_in_truths = is_confident_pos and v.start in variant_dict and variant_dict[v.start].phased_genotype + if pos_in_truths: + output.append({tuple(alt_dict[v.start].phased_genotype)}) + elif is_confident_pos: + if simplfy_combination: + output.append({tuple(alt_dict[v.start].phased_genotype)}) + else: + output.append({(0, 0), tuple(alt_dict[v.start].phased_genotype)}) + else: + gt_set = set() + for idx_1 in range(len(v.alternate_bases) + 1): + for idx_2 in range(len(v.alternate_bases) + 1): + if simplfy_combination and not has_multi_in_truth: + if idx_1 != 0 and idx_2 != 0 and idx_1 != idx_2: + continue + gt_set.add((idx_1, idx_2)) + output.append(gt_set) + # in extra mass combination, need to simplfy low possible cases: + if count_combination(output) > max_calculate_count: + if simplfy_combination: + # skip + return [{(0, 0)}] * len(variants) + else: + return genotypes_combination(variants, variant_type, variant_dict, max_calculate_count, truths, alt_dict, no_match, simplfy_combination=True) + return output + + +def find_seqs(variants_and_genotypes, last_pos_list, ref): + seqs = ["", ""] + genotypes = [vg.genotypes for vg in variants_and_genotypes] + variants = [vg.variant for vg in variants_and_genotypes] + all_genotypes = [tuple([item[i] for item in genotypes]) for i in [0,1]] + next_last_pos_list = [0,0] + for idx, phased_genotype in enumerate(all_genotypes): + current_seq, hap_end = build_seq(variants, phased_genotype, ref, last_pos_list[idx], None) # if overlap, merge multiple variants together + next_last_pos_list[idx] = hap_end + if current_seq: + seqs[idx] = current_seq + return seqs, next_last_pos_list + + +def build_seq(variants, phased_genotype, ref, pre_start, ref_end=None): + + """ + Build or extend the haplotype according to provided genotype. We marked the start position iterator of each haplotype and + update with variant alternative base. + """ + + seqs = "" + position = pre_start + for variant, phased in zip(variants, phased_genotype): + if variant.start < pre_start: + if variant.start == pre_start - 1 and phased != 0: # this only happen when pre pos is deletion and current pos is insertion + ref_base = variant.reference_bases + alt_base = variant.alternate_bases[phased - 1] + if len(alt_base) > len(ref_base): # is an insertion + # print ('has insertion and deletion overlap'.format(variant.start)) + return alt_base[1:], position + if phased != 0: # impossible # sometimes happen in true vcf + return None, None + else: + return "", pre_start # do not do anything if 0 allele + else: + seqs += ref.query(pre_start, variant.start) + + allele = variant.reference_bases if phased == 0 else variant.alternate_bases[phased - 1] + if phased == 0: + allele = allele[0] + position = variant.start + 1 + seqs += allele # only add one ref base + else: + ref_base = variant.reference_bases + alt_base = variant.alternate_bases[phased-1] + ref_base, alt_base = remove_common_suffix(ref_base, [alt_base]) + end = variant.start + len(ref_base) + position = end + seqs += alt_base[0] + + return seqs, position + +class ReadMatch(object): + def __init__(self, sample_ctg_info, candidates, candidate_genotypes, truths,match_seq,truth_genotypes,match_reads_count=0): + + self.sample_ctg_info = sample_ctg_info + self.candidates = candidates + self.truths = truths + self.candidate_genotypes = candidate_genotypes + self.truth_genotypes = truth_genotypes + self.match_reads_count = match_reads_count + self.match_seq = sorted(match_seq) + self.truths_pos_list = [t.start for t in truths] + self.candidates_pos_list = [c.start for c in candidates] + self.raw_genotypes = [tuple(v.genotype) for v in self.truths] + self.non_variants = [int(sum(cg) == 0) for cg in self.candidate_genotypes] + self.miss_variants_count = sum([1 if sum(gt) < sum(raw_gt) else 0 for raw_gt, gt in zip(self.raw_genotypes,self.truth_genotypes)]) + self.match_variants_count = sum([1 if item < 1 else 0 for item in self.non_variants]) + self.non_variants_count = sum(self.non_variants) + self.match_order = (self.match_reads_count, self.miss_variants_count, self.non_variants_count, self.match_variants_count) + + def match_info(self): + can_info, truth_info = "", "" + for can, gt in zip(self.candidates, self.candidate_genotypes): + gt_str = '_' + '/'.join(map(str, gt)) + ' ' + can_info += str(can.start) + '-' + can.reference_bases + '->' + '-'.join(can.alternate_bases) + gt_str + + for truth, gt in zip(self.truths, self.truth_genotypes): + gt_str = '_' + '/'.join(map(str, gt)) + ' ' + truth_info += str(truth.start) + '-' + truth.reference_bases + '->' + '-'.join(truth.alternate_bases) + gt_str + + extro_info = "" + if self.match_reads_count >=-6 and self.miss_variants_count > 0 and self.match_variants_count > 0: + extro_info = '\nthis match has few read support' + return ('ctg_info={}, read_support,miss_variants,non_variants,match_variants={}, candidate={}, truth={} {}').format(self.sample_ctg_info, + self.match_order, + can_info, truth_info, extro_info) + +class Position(object): + def __init__(self, pos, genotype1, genotype2, ref_base=None, alt_base=None, candidate=False, cigar_count=None, + confident_variant=False, depth=None, alt_list=None, af_list=None, alt_type_mapping_dict=None): + self.pos = pos + self.reference_bases = ref_base + self.candidate = candidate + + if candidate == True: + self.alternate_bases = alt_base + else: + self.alternate_bases = [alt_base] if ',' not in alt_base else alt_base.split(',') + + self.start = pos + self.end = self.pos + len(ref_base) + self.genotype = [genotype1, genotype2] + self.cigar_count = cigar_count + self.confident_variant = confident_variant + self.read_name_set = set() + self.depth = depth + self.variant_hap_dict = defaultdict(defaultdict) + self.phased_genotype = None + self.hap_count_dict = defaultdict(int) + self.alt_list = alt_list + def update_info(self, ref_base, alt_base, genotype): + self.reference_bases = ref_base + self.alternate_bases = alt_base + self.genotype = genotype + + +class Read(object): + def __init__(self, hap=0): + self.hap = hap + self.pos_alt_dict = defaultdict(str) + self.start = None + self.end = None + self.seq = [] + + +def decode_alt_info(cigar_count, ref_base, depth, minimum_allele_gap): + """ + Decode the input read-level alternative information + cigar_count: each alternative base including snp, insertion and deletion of each position + pileup_bases: pileup bases list of each read in specific candidate position from samtools mpileup 1.10 + reference_sequence: the whole reference sequence index by contig:start-end. 0-based. + ref_base: upper reference base for cigar calculation. + depth: depth of candidate position for calculation. + minimum_allele_gap: default minimum allele frequency for candidate to consider as a potential true variant for unification. + """ + alt_type_list = [] # SNP I D + seqs = cigar_count.split(' ') + seq_alt_bases_dict = dict(zip(seqs[::2], [int(item) for item in seqs[1::2]])) if len(seqs) else {} + if '*' in seq_alt_bases_dict: + del seq_alt_bases_dict['*'] + max_del_cigar = "" + del_list = [] + ref_represatation = ref_base + alt_list = sorted(list(seq_alt_bases_dict.items()), key=lambda x: x[1], reverse=True) + + seq_insertion_bases_list = alt_list[:2] + af_list = [] + for alt_type, count in seq_insertion_bases_list: + count = int(count) + if '*' in alt_type or '#' in alt_type: + continue + if count / float(depth) < minimum_allele_gap: + continue + af_list.append(count/ float(depth)) + if alt_type[0] == 'X': + alt_type_list.append(alt_type[1]) + elif alt_type[0] == 'I': + alt_type_list.append(alt_type[1:]) + elif alt_type[0] == 'D': + if len(alt_type[1:]) > len(max_del_cigar): + max_del_cigar = alt_type[1:] + del_list.append(ref_base + alt_type[1:]) + new_del_list = [] + if len(max_del_cigar): + ref_represatation = ref_base + max_del_cigar + alt_type_list = [item + max_del_cigar for item in alt_type_list] + for idx, item in enumerate(del_list): + start_pos = len(item[1:]) + append_del_bases = max_del_cigar[start_pos:] + new_del_list.append( + ref_base + append_del_bases) # ACG-> A, ACGTT -> A, max_del_cigar is CGTT, represent ACG-> A to ACGTT->ATT + alt_base_list = alt_type_list + new_del_list + return ref_represatation, alt_base_list,af_list, alt_list + +def has_variant_suport(ref_base, alt_base, pos, alt_dict): + """ + ref_base: reference base of the true varaint. + alt_base: alternative base of the true variant + pos: pos: candidate position for unification. + alt_dict: dictionary (pos: pos info) which keep position level candidate reference base and alternative base information. + return the alternative index of each candidate site if found match + """ + + alt_index = -1 + if pos not in alt_dict or not len(alt_dict[pos]): + return alt_index + cigar_count = alt_dict[pos] + if len(ref_base) == 1 and len(alt_base) == 1: # Snp + if alt_base in cigar_count: + alt_index = cigar_count[0][alt_base] + elif len(ref_base) > len(alt_base): # D + del_base = ref_base[1:len(ref_base) - len(alt_base) + 1] + if del_base in cigar_count: + alt_index = cigar_count[1][alt_base] + elif len(ref_base) < len(alt_base): # I + ins_base = alt_base[1:len(alt_base) - len(ref_base) + 1] + if ins_base in cigar_count: + alt_index = cigar_count[2][alt_base] + return alt_index + + +def get_ref(ref_fn, ctg_name): + refernce_sequences = [] + samtools_faidx_process = subprocess_popen( + shlex.split("samtools faidx {} {}".format(ref_fn, ctg_name)) + ) + while True: + row = samtools_faidx_process.stdout.readline() + is_finish_reading_output = row == '' and samtools_faidx_process.poll() is not None + if is_finish_reading_output: + break + if row: + refernce_sequences.append(row.rstrip()) + + reference_sequence = "".join(refernce_sequences[1:]) + + reference_sequence = reference_sequence.upper() + samtools_faidx_process.stdout.close() + samtools_faidx_process.wait() + reference_start = 1 + return reference_sequence, reference_start + + +def get_genotype(genotype): + g1, g2 = genotype + min_gt = min(int(g1), int(g2)) + max_gt = max(int(g1), int(g2)) + return str(min_gt) + '/' + str(max_gt) + +def lock_variant(variant, truth): + """ + Find potential locked true variant and candidate site if we consider it as a confident site, there are only exactly one + candidate match with the true variant in a specific site, and we can further take the candidate into consideration with + a match flag. + """ + if truth is None: + return None, False + variant_ref_base = variant.reference_bases + variant_alt_base = variant.alternate_bases + truth_ref_base = truth.reference_bases + truth_alt_base = truth.alternate_bases + + if len(variant_alt_base) != len(truth_alt_base): + return None, False + tmp_alt_list = [] + for ab in variant_alt_base: + ref_base1, alt_base1 = remove_common_suffix(variant_ref_base, [ab]) + tmp_alt_list.append((alt_base1[0], ref_base1)) + match_index = [-1] * len(truth_alt_base) + + for t_idx, ab in enumerate(truth_alt_base): + ref_base1, alt_base1 = remove_common_suffix(truth_ref_base, [ab]) + for a_idx, (alt_base, ref_base) in enumerate(tmp_alt_list): + if alt_base1[0] == alt_base and ref_base1 == ref_base: + match_index[t_idx] = a_idx + match = sum([1 if item >= 0 else 0 for item in match_index]) == len(truth_alt_base) # can find all alt_base + return match_index, match + +def decode_variant(variant, reference_base): + if variant == 'R': + return 'R', 'R' + if variant[0] == 'X': + return reference_base, variant[1] + elif variant[0] == 'I': + return reference_base, variant[1:] + elif variant[0] == 'D': + return reference_base + variant[1:], reference_base + +def update_variant_hap_dict(alt_dict, pos, reference_sequence, reference_start, is_variant_confident, variant_dict, allele_gap, platform): + """ + For a phased alignment, the candidates are easier to lock as confident if the signal exists strongly in one side and have confident + match with true variant. + """ + phased_genotype = [-1,-1] + reference_base = reference_sequence[pos- reference_start] + variant_hap_dict = alt_dict[pos].variant_hap_dict + if not len(variant_hap_dict): + return None + hap_count_dict = alt_dict[pos].hap_count_dict + variant_ref_base = alt_dict[pos].reference_bases + variant_alt_base = alt_dict[pos].alternate_bases + for variant, hap_dict in variant_hap_dict.items(): + if variant in '*#': + continue + hap_0 = hap_dict[0] if 0 in hap_dict else 0 + # for illumina unification, phased information contributes less, we safely denote with reference allele gap + if platform == 'ilmn': + hap_total_af = (hap_0) / float(sum(list(hap_count_dict.values()))) + if variant not in 'R*' and hap_total_af > reference_allele_gap and is_variant_confident and variant_dict[pos].genotype == [1, 1]: + phased_genotype = [1, 1] + return phased_genotype + if -1 in phased_genotype: + return None + return phased_genotype + + hap_1 = hap_dict[1] if 1 in hap_dict else 0 + hap_2 = hap_dict[2] if 2 in hap_dict else 0 + hap_0 = hap_0 if hap_0 > 3 else 0 + hap_1 = hap_1 if hap_1 > 3 else 0 + hap_2 = hap_2 if hap_2 > 3 else 0 + + hap_total_af = (hap_0 + hap_1 + hap_2) / float(sum(list(hap_count_dict.values()))) + if variant not in 'R*#' and hap_total_af > 1 - allele_gap / 2 and is_variant_confident and variant_dict[pos].genotype== [1, 1]: + phased_genotype = [1, 1] + return phased_genotype + hap_1_af = hap_1 / float(hap_count_dict[1]) if 1 in hap_count_dict else 0.0 + hap_2_af = hap_2 / float(hap_count_dict[2]) if 2 in hap_count_dict else 0.0 + + if variant == 'R': + if hap_1_af >= 1 - allele_gap: + phased_genotype[0] = 0 + if hap_2_af >= 1 - allele_gap: + phased_genotype[1] = 0 + continue + ref_base, alt_base = decode_variant(variant, reference_base) + + alt_index = -1 + for ab_idx, ab in enumerate(variant_alt_base): + ref_base1, alt_base1 = remove_common_suffix(variant_ref_base, [ab]) + if alt_base1[0] == alt_base and ref_base1 == ref_base: + alt_index = ab_idx + break + if alt_index == -1: + continue + if hap_1_af >= 1 - allele_gap * 2: + phased_genotype[0] = alt_index + 1 + if hap_2_af >= 1 - allele_gap * 2: + phased_genotype[1] = alt_index + 1 + + if -1 in phased_genotype: + return None + return phased_genotype + + +def match_alt_base(alt_list, ref_base, alt_base): + if not len(alt_list) or (len(alt_list) == 1 and 'R' in alt_list): + return False + alt_set = set([item[0] for item in alt_list]) + + for ab in alt_base: + rb, ab = remove_common_suffix(ref_base, [ab]) + if len(rb) == len(ab[0]): #snp + ab = 'X' + ab[0] + if ab in alt_set: + return True + elif len(rb) < len(ab[0]): # insertion + ab = 'I' + ab[0] + if ab in alt_set: + return True + elif len(rb) > len(ab[0]): + ab = 'D' + rb[1:] + if ab in alt_set: + return True + return False + +def check_confident_match(candidates, truths): + + """ + Double check whether the candidate site match the representation in the truth variant site in reference base, + alternative base, genotype and start position. + """ + + if len(candidates) != len(truths): + return False + all_candidate_positions = set([c.start for c in candidates]) + for truth in truths: + if truth.start not in all_candidate_positions: + return False + for candidate in candidates: + if candidate.start == truth.start: + if candidate.reference_bases != truth.reference_bases or candidate.alternate_bases != truth.alternate_bases: + return False + return True + +def split_variants_truths(candidates, + truths, + partition_size, + max_candidates_distance, + max_calculate_count, + variant_dict=None, + alt_dict=None): + """ + Split all candidate sites and true variant according to the start position, for true variant site, we extend the + at least one candidate site in both two sides to aviod missing match. + """ + INFO = collections.namedtuple('INFO', ['start', 'type', 'variant']) + def match_max_candidate_distance(partition, variants, new_count): + if not partition: + return True + n_of_type = sum(1 for g in partition if g.type == variants.type) + if new_count >= max_calculate_count or n_of_type >= partition_size: + if new_count >= max_calculate_count: + print('{} exceed max calculation count'.format(new_count)) + return False + else: + for g in partition: + if variants.variant.start - g.variant.end + 1 > max_candidates_distance: + return False + + last_par = partition[-1].variant.end + if variants.variant.start - last_par + 1 > extend_bp: + return False + return True + + truths_pos_set = set([v.start for v in truths]) + sorted_variants = list(heapq.merge( + [INFO(v.start, 'candidate', v) for v in candidates], + [INFO(t.start, 'truth', t) for t in truths])) + + all_partitions = [] + partition = [] + product_count = 1 + for sv_idx in range(len(sorted_variants)): + variants = sorted_variants[sv_idx] + new_count = product_count * all_genotypes_combination( + variants, variant_dict, alt_dict) + if match_max_candidate_distance(partition, variants, + new_count): + partition.append(variants) + product_count = new_count + else: + if variants.start == partition[-1].start and variants.type != partition[ + -1].type: # + # add same truths or variants together and add at least one nearby candidate + partition.append(variants) + if sv_idx < len(sorted_variants) - 1 and sorted_variants[sv_idx + 1].start not in truths_pos_set and \ + sorted_variants[sv_idx + 1].start - variants.start <= extend_bp: + partition.append(sorted_variants[sv_idx + 1]) + all_partitions.append(partition) + partition = [] + product_count = 1 + else: + all_partitions.append(partition) + partition = [variants] + product_count = all_genotypes_combination(variants, variant_dict, alt_dict) + if partition: + all_partitions.append(partition) + + split_partitions = [] + for partitions in all_partitions: + candidate_partitions = [] + truth_partitions = [] + for p in partitions: + if p.type == 'candidate': + candidate_partitions.append(p.variant) + elif p.type == 'truth': + truth_partitions.append(p.variant) + split_partitions.append([candidate_partitions, truth_partitions]) + return split_partitions + +class RepresentationUnification(object): + + def __init__(self, + sample_name, + contig_name, + reference_sequence, + reference_start, + partition_size, + max_candidates_distance, + max_calculate_count, + subsample_ratio): + + self.sample_name = sample_name + self.contig_name = contig_name + self.subsample_ratio = subsample_ratio + self.sample_ctg_info = '_'.join([sample_name, str(subsample_ratio), contig_name]) + self.partition_size = partition_size + self.max_candidates_distance = max_candidates_distance + self.max_calculate_count = max_calculate_count + self.reference_sequence = reference_sequence + self.reference_start = reference_start + + + def get_reference_seq(self, candidates, true_variants, bufsize=50): + all_variants = candidates + true_variants + start = min(x.start for x in all_variants) + end = max(x.end for x in all_variants) + + ref_bases = self.reference_sequence[start - self.reference_start - 1:end + bufsize - self.reference_start] + return Reference(seq=ref_bases, + start=start - 1, + reference_sequence=self.reference_sequence, + reference_start=self.reference_start) + + def find_match_pairs(self, candidates, truths, ref, variant_dict, read_name_info_dict=None, alt_dict=None): + no_match_found = len(candidates) == 0 or len(truths) == 0 + + if no_match_found: + can_info, truth_info = "", "" + for can in candidates: + gt = can.genotype + gt_str = '_' + '/'.join(map(str, gt)) + ' ' + can_info += str(can.start) + '-' + can.reference_bases + '->' + '-'.join(can.alternate_bases) + gt_str + + for truth in truths: + gt = truth.genotype + gt_str = '_' + '/'.join(map(str, gt)) + ' ' + truth_info += str(truth.start) + '-' + truth.reference_bases + '->' + '-'.join( + truth.alternate_bases) + gt_str + + print ('[INFO] Missing match: ctg_info={}, read_support,miss_variants,non_variants,match_variants=(0, {}, {}, 0), candidate={}, truth={}'.format(self.sample_ctg_info,len(truths), len(candidates),can_info, truth_info)) + return None + + confident_match = check_confident_match(candidates, truths) + if confident_match: + can_info, truth_info = "", "" + for can in candidates: + gt = can.genotype + gt_str = '_' + '/'.join(map(str, gt)) + ' ' + can_info += str(can.start) + '-' + can.reference_bases + '->' + '-'.join(can.alternate_bases) + gt_str + + for truth in truths: + gt = truth.genotype + gt_str = '_' + '/'.join(map(str, gt)) + ' ' + truth_info += str(truth.start) + '-' + truth.reference_bases + '->' + '-'.join( + truth.alternate_bases) + gt_str + + print ( + '[INFO] Found confident match: ctg_info={}, read_support,miss_variants,non_variants,match_variants=(None, 0, 0, {}), candidate={}, truth={}'.format( + self.sample_ctg_info, len(truths), can_info, truth_info)) + + match_genotype = [tuple(v.genotype) for v in truths] + + return ReadMatch( + sample_ctg_info=self.sample_ctg_info, + candidates=truths, + candidate_genotypes=match_genotype, + truths=truths, + truth_genotypes=match_genotype, + match_seq=ref.seq, + match_reads_count=100) + + truths_candidate_gentoypes = genotypes_combination(truths, 'truth', variant_dict, max_calculate_count, + truths, alt_dict=alt_dict, no_match=no_match_found) + candidates_candidate_gentoypes = genotypes_combination(candidates, 'candidate', variant_dict, + max_calculate_count, + truths, alt_dict=alt_dict, + no_match=no_match_found) + + truths_genotypes_list = unique_genotypes_selection(truths_candidate_gentoypes) + candidates_genotypes_list = unique_genotypes_selection(candidates_candidate_gentoypes) + + print (len(truths_genotypes_list) * len(candidates_genotypes_list)) + if len(truths_genotypes_list) * len(candidates_genotypes_list) > self.max_calculate_count: + return None + + truth_seqs, _ = find_read_support( + variants=truths, + ref=ref, + variant_type='truth', + max_calculate_count=self.max_calculate_count, + variant_dict=variant_dict, + truths=truths, + read_name_info_dict=read_name_info_dict, + alt_dict=alt_dict, + no_match_found=no_match_found) + + variant_seqs, read_seqs_counter = find_read_support( + variants=candidates, + ref=ref, + variant_type='candidate', + max_calculate_count=self.max_calculate_count, + variant_dict=variant_dict, + truths=truths, + read_name_info_dict=read_name_info_dict, + alt_dict=alt_dict, + no_match_found=no_match_found) + + matches = [] + for variant_seq, variant_genotypes in variant_seqs.items(): + if variant_seq not in truth_seqs: + continue + truth_seq = truth_seqs[variant_seq] + for variant_genotype in variant_genotypes: + match_reads_count = -sum([read_seqs_counter[seq] if read_seqs_counter and seq in read_seqs_counter else 0 for seq in variant_seq ]) # more match reads and negative count is better if smaller + matches.append(ReadMatch( + sample_ctg_info=self.sample_ctg_info, + candidates=candidates, + candidate_genotypes=variant_genotype, + truths=truths, + truth_genotypes=truth_seq[0], + match_seq=variant_seq, + match_reads_count=match_reads_count)) + if not matches: + return None + else: + best_matches = sorted(matches, key=lambda x: x.match_order)[0] + print ('[INFO] Found match case:', best_matches.match_info()) + return best_matches + + + def unify_label(self, variants, truths, region, ctg_start, ctg_end, all_pos, variant_dict, + rescue_dict=None, output_vcf_fn=None, test_pos=None, read_name_info_dict=None, alt_dict=None): + split_start, split_end = region + + all_partitions = split_variants_truths( + candidates=list(variants), + truths=list(truths), + partition_size=self.partition_size, + max_candidates_distance=self.max_candidates_distance, + max_calculate_count=self.max_calculate_count, + variant_dict=variant_dict, + alt_dict=alt_dict) + + for all_candidates, all_truths in all_partitions: + ref = self.get_reference_seq(all_candidates, all_truths) + match_pairs = self.find_match_pairs(candidates=all_candidates, + truths=all_truths, + ref=ref, + variant_dict=variant_dict, + read_name_info_dict=read_name_info_dict, + alt_dict=alt_dict) + + if match_pairs is None: + if not len(truths): + continue + # double check to rescue true variants + for truth in all_truths: + pos = truth.start + # add missing low-confident tp position + if not (pos >= split_start and pos < split_end) or (ctg_start is not None and ctg_end is not None + and not (pos >= ctg_start and pos < ctg_end)): + continue + if pos in alt_dict and pos in variant_dict: + ref_base = variant_dict[pos].reference_bases + alt_base = variant_dict[pos].alternate_bases + alt_list = alt_dict[pos].alt_list + if not match_alt_base(alt_list, ref_base, alt_base): + print('[INFO] {} {} miss and has no cigar support'.format(self.sample_ctg_info, pos)) + continue + print('[INFO] {} {} miss by match, append to vcf'.format(self.sample_ctg_info, pos)) + if pos in all_pos or pos in rescue_dict: + continue + ref_base = variant_dict[pos].reference_bases + variant = ','.join(variant_dict[pos].alternate_bases) + genotype_string = '/'.join(map(str, variant_dict[pos].genotype)) + # For efficiency, we currently only compute reference base, altnertive base and genotype from GetTruth.py + rescue_dict[pos] = "%s\t%d\t.\t%s\t%s\t%d\t%s\t%s\tGT:GQ:DP:AF\t%s:%d:%d:%.4f" % ( + self.contig_name, + pos, + ref_base, + variant, + 10, + 'PASS', + '.', + genotype_string, + 10, + 10, + 0.5) + else: + print('[INFO] {} {} miss and no variant support'.format(self.sample_ctg_info, pos)) + continue + for idx, (candidate, candidate_genotypes) in enumerate( + zip(match_pairs.candidates, match_pairs.candidate_genotypes)): + pos = candidate.start + + have_miss_variants = True if sum([1 for gt in match_pairs.truth_genotypes if sum(gt) == 0]) else False + + # append a position into rescue queue if it was missed by the unification + if sum(candidate_genotypes) == 0 and pos not in rescue_dict and have_miss_variants and pos not in variant_dict and pos in alt_dict and alt_dict[pos].phased_genotype: + genotype_string = '/'.join(map(str, alt_dict[pos].phased_genotype)) + variant = ','.join(candidate.alternate_bases) + ref_base = candidate.reference_bases + rescue_dict[pos] = "%s\t%d\t.\t%s\t%s\t%d\t%s\t%s\tGT:GQ:DP:AF\t%s:%d:%d:%.4f" % ( + self.contig_name, pos, ref_base, variant, 10, 'PASS', '.', genotype_string, 10, 10, 0.5) + continue + if sum(candidate_genotypes) == 0: + continue + if not len(candidate.alternate_bases): + continue + g1, g2 = candidate_genotypes + variant = set() + ref_base = candidate.reference_bases + if g1 != 0: + variant.add(candidate.alternate_bases[g1 - 1]) + if g2 != 0: + variant.add(candidate.alternate_bases[g2 - 1]) + if g1 == 0 or g2 == 0: + genotype_string = '0/1' + elif g1 == g2: + genotype_string = '1/1' + elif g1 != g2: + genotype_string = '1/2' + ref_base, variant = remove_common_suffix(ref_base, list(variant)) + variant = ','.join(variant) + if candidate.start in all_pos: + continue + all_pos.add(pos) + + if output_vcf_fn is not None: + # For efficiency, we only compute reference base, altnertive base and genotype for GetTruth.py currently + print("%s\t%d\t.\t%s\t%s\t%d\t%s\t%s\tGT:GQ:DP:AF\t%s:%d:%d:%.4f" % ( + self.contig_name, candidate.start, ref_base, variant, 10, 'PASS', '.', genotype_string, 10, 10, 0.5), file=output_vcf_fn) + if pos in rescue_dict: + del rescue_dict[pos] + for idx, (pos, raw_genotype, truth_genotype) in enumerate( + zip(match_pairs.truths_pos_list, match_pairs.raw_genotypes, + match_pairs.truth_genotypes)): + if not (pos >= split_start and pos < split_end) or (ctg_start is not None and ctg_end is not None + and not (pos >= ctg_start and pos < ctg_end)): + continue + + if truth_genotype == (0, 0) and sum(raw_genotype) > 0:# miss genoytpe + if pos in alt_dict and pos in variant_dict: + ref_base = variant_dict[pos].reference_bases + alt_base = variant_dict[pos].alternate_bases + alt_list = alt_dict[pos].alt_list + if not match_alt_base(alt_list, ref_base, alt_base): + print('{} {} miss and has no cigar support'.format(self.sample_ctg_info, pos)) + continue + print('{} {} miss by match, append to vcf'.format(self.sample_ctg_info, pos)) + if pos in all_pos: + continue + all_pos.add(pos) + + ref_base = variant_dict[pos].reference_bases + variant = ','.join(variant_dict[pos].alternate_bases) + genotype_string = '/'.join(map(str, variant_dict[pos].genotype)) + + if output_vcf_fn is not None: + rescue_dict[pos] = "%s\t%d\t.\t%s\t%s\t%d\t%s\t%s\tGT:GQ:DP:AF\t%s:%d:%d:%.4f" % ( + self.contig_name, pos, ref_base, variant, 10, 'PASS', '.', genotype_string, 10, 10, 0.5) + + +def parse_candidates_file(candidate_details_fn, contig_name=None): + candidate_details_list = [] + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % candidate_details_fn)) + for row in unzip_process.stdout: + candidate_details_list.append(row) + return candidate_details_list + + +def parse_line(row): + row = row.strip().split('\t') # ['chr_pos', 'depth', 'cigar_count'] + chr_pos, depth, var_read_json = row[:3] + ctg_name, pos = chr_pos.split() + pos, depth = int(pos), int(depth) + return (ctg_name, pos, depth, var_read_json) + + +def UnifyRepresentation(args): + + """ + Representation Unification algorithm main function, this algorithm aims to unify variant representation + between training material and true variant set. + All candidate sites with sufficient read support and over a certain allele frequency were selected, of which the + same variant information with true set was locked as confident candidate sites. Secondly, for each remaining + candidate site and true variant site, a matching or missing flag was assigned. We build haplotype pairs based on + all possible flag combinations. For each fully matching pair, we use the match pair with the most support reads as + the final best match pair. For the remaining unmatched sites, we decrease allele frequency to further seek remaining + candidate sites. We rescue the candidate sites matching true variant type. In the end, the unified VCF consists of + locked variant sites, unified variant sites, and rescued variant sites. + """ + sample_name = args.sampleName + var_fn = args.var_fn # true vcf var + candidate_details_fn = args.candidate_details_fn + contig_name = args.ctgName + ctg_start = args.ctgStart + ctg_end = args.ctgEnd + bed_fn = args.bed_fn + is_confident_bed_file_given = bed_fn is not None + partition_size = args.partition_size + minimum_allele_gap = args.minimum_allele_gap + max_candidates_distance = args.max_candidates_distance + global max_calculate_count + max_calculate_count = args.max_calculate_count + subsample_ratio = args.subsample_ratio + platform = args.platform + chunk_id = args.chunk_id + chunk_num = args.chunk_num + + global test_pos + test_pos = None + + alt_dict = defaultdict() + read_name_info_dict = defaultdict(Read) + + if candidate_details_fn is None: + basedir = os.path.dirname(__file__) + CTFA_Bin = basedir + "/../clair3.py CreateTensorFullAlignment" + pypyBin = executable_command_string_from(args.pypy, exit_on_not_found=True) + bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True) + ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True) + vcf_fn = file_path_from(args.vcf_fn) + extend_bed = file_path_from(args.extend_bed) + min_af = args.min_af + ctgStart, ctgEnd = None, None + if ctg_start is not None and ctg_end is not None and int(ctg_start) <= int(ctg_end): + ctgStart = CommandOption('ctgStart', ctg_start) + ctgEnd = CommandOption('ctgEnd', ctg_end) + chunkId, chunkNum = None, None + if chunk_id is not None and chunk_num is not None and int(chunk_id) <= int(chunk_num): + chunkId = CommandOption('chunk_id', chunk_id) + chunkNum = CommandOption('chunk_num', chunk_num) + + create_tensor_command_options = [ + pypyBin, + CTFA_Bin, + CommandOption('bam_fn', bam_fn), + CommandOption('ref_fn', ref_fn), + CommandOption('vcf_fn', vcf_fn), + CommandOption('ctgName', contig_name), + CommandOption('platform', platform), + CommandOption('bed_fn', bed_fn), + CommandOption('extend_bed', extend_bed), + ctgStart, + ctgEnd, + chunkId, + chunkNum, + CommandOptionWithNoValue('unify_repre'), + CommandOptionWithNoValue('phasing_info_in_bam'), + CommandOption('unify_repre_fn', 'PIPE') + ] + if min_af is not None: + create_tensor_command_options.append(CommandOption('min_af', min_af)) + else: + candidate_details_list = [] + if os.path.exists(candidate_details_fn): + candidate_details_list = parse_candidates_file(candidate_details_fn, contig_name) + else: + directory, prefix = os.path.split(candidate_details_fn) + for f in os.listdir(directory): + if not f.startswith(prefix): + continue + candidate_details_list += parse_candidates_file(os.path.join(directory, f), contig_name) + + candidate_details_list = sorted(candidate_details_list, key=lambda x: x[0]) + + if not len(candidate_details_list): + return + + if ctg_start is None or ctg_end is None: + alt_start = parse_line(candidate_details_list[0])[1] + alt_end = parse_line(candidate_details_list[-1])[1] + chunk_id = args.chunk_id - 1 # 1-base to 0-base + chunk_num = args.chunk_num + chunk_size = (alt_end - alt_start) // chunk_num + 1 + ctg_start = alt_start + chunk_size * chunk_id + ctg_end = ctg_start + chunk_size + + is_ctg_name_given = contig_name is not None + is_ctg_range_given = is_ctg_name_given and ctg_start is not None and ctg_end is not None + ref_regions = [] + reference_start = 1 + if is_ctg_range_given: + reference_start, reference_end = ctg_start - reference_region_size, ctg_end + reference_region_size + reference_start = 1 if reference_start < 1 else reference_start + ref_regions.append(region_from(ctg_name=contig_name, ctg_start=reference_start, ctg_end=reference_end)) + elif is_ctg_name_given: + ref_regions.append(region_from(ctg_name=contig_name)) + reference_start = 1 + + reference_sequence = reference_sequence_from( + samtools_execute_command='samtools', + fasta_file_path=args.ref_fn, + regions=ref_regions + ) + + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (var_fn))) + variant_dict = defaultdict(list) + for row in unzip_process.stdout: + if row[0] == '#': + continue + columns = row.strip().split() + ctg_name = columns[0] + if contig_name and contig_name != ctg_name: + continue + pos = int(columns[1]) + if ctg_start is not None and ctg_end is not None and \ + (pos < ctg_start - extended_window_size or pos > ctg_end + extended_window_size): + continue + ref_base = columns[2] + alt_base = columns[3] + genotype1 = int(columns[4]) + genotype2 = int(columns[5]) + variant_dict[pos] = Position(pos=pos, + ref_base=ref_base, + alt_base=alt_base, + genotype1=genotype1, + genotype2=genotype2) + + if candidate_details_fn is None: + try: + c.create_tensor = subprocess_popen( + shlex.split(command_string_from(create_tensor_command_options)) + ) + candidate_source = c.create_tensor.stdout + + signal.signal(signal.SIGALRM, check_return_code) + signal.alarm(2) + except Exception as e: + print(e, file=sys.stderr) + sys.exit("Failed to start required processes. Exiting...") + else: + candidate_source = candidate_details_list + + for row in candidate_source: + ctg_name, pos, depth, var_read_json = parse_line(row) + if contig_name != ctg_name: + continue + if ctg_start is not None and ctg_end is not None and \ + (pos < ctg_start - extended_window_size or pos > ctg_end + extended_window_size): + continue + + var_read_dict = json.loads(var_read_json) + if not len(var_read_dict): + continue + + cigar_count = ' '.join([' '.join([item, str(len(var_read_dict[item].split(' ')))]) for item in var_read_dict.keys()]) + ref_base = reference_sequence[pos - reference_start] + pos_in_truths = pos in variant_dict + ref_base, alt_base, af_list,alt_list = decode_alt_info(cigar_count=cigar_count, + ref_base=ref_base, + depth=depth, + minimum_allele_gap=minimum_allele_gap) + + alt_dict[pos] = Position(pos=pos, + ref_base=ref_base, + alt_base=alt_base, + genotype1=-1, + genotype2=-1, + candidate=True, + depth=depth, + alt_list=alt_list) + + for variant, read_str in var_read_dict.items(): + read_list = read_str.split(' ') + for read_name in read_list: + read_name, hap = read_name[:-2], read_name[-1] + if read_name not in read_name_info_dict or read_name_info_dict[read_name].hap == 0 and hap != 0: + read_name_info_dict[read_name].hap = int(hap) + + read_hap = read_name_info_dict[read_name].hap if read_name in read_name_info_dict else 0 + if read_hap in alt_dict[pos].variant_hap_dict[variant]: + alt_dict[pos].variant_hap_dict[variant][read_hap] += 1 + else: + alt_dict[pos].variant_hap_dict[variant][read_hap] = 1 + alt_dict[pos].hap_count_dict[read_hap] += 1 + alt_dict[pos].read_name_set.add(read_name) + read_name_info_dict[read_name].pos_alt_dict[pos] = variant + + match_index, is_variant_confident = lock_variant(alt_dict[pos], variant_dict[pos] if pos_in_truths else None) + if is_variant_confident: + variant_dict[pos].confident_variant = match_index + alt_dict[pos].phased_genotype = update_variant_hap_dict(alt_dict=alt_dict, + pos=pos, + reference_sequence=reference_sequence, + reference_start=reference_start, + is_variant_confident=is_variant_confident, + variant_dict=variant_dict, + allele_gap=minimum_allele_gap, + platform=platform) + # lock the candidate if it has meet the phased_genotype requirement and have a exactly one match true variant site + if alt_dict[pos].phased_genotype and pos_in_truths and is_variant_confident: + if alt_dict[pos].phased_genotype.count(0) != variant_dict[pos].genotype.count(0) or (sum(variant_dict[pos].genotype) == 3 and sum(alt_dict[pos].phased_genotype) != 3): + # skip wrong genotype + alt_dict[pos].phased_genotype = None + variant_dict[pos].reference_bases = alt_dict[pos].reference_bases + variant_dict[pos].alternate_bases = alt_dict[pos].alternate_bases + variant_dict[pos].phased_genotype = alt_dict[pos].phased_genotype + + + if is_confident_bed_file_given: + tree = bed_tree_from(bed_fn, contig_name=contig_name) + for read_name, read in read_name_info_dict.items(): + if not len(read_name_info_dict[read_name].pos_alt_dict): + continue + for pos, alt_base in read_name_info_dict[read_name].pos_alt_dict.items(): + read.start = min(read.start, pos) if read.start is not None else pos + if alt_base[0] == 'X': + read.seq.append((pos, pos+1, alt_base[1])) + elif alt_base[0] == 'I': + read.seq.append((pos, pos+1, alt_base[1:])) + elif alt_base[0] == 'D': + del_length = len(alt_base[1:]) + read.seq.append((pos, pos+del_length+1, reference_sequence[pos - reference_start])) + else: #"R" + continue + read.end = max([item[1] for item in read.seq]) if len(read.seq) else None + + if candidate_details_fn is None: + c.create_tensor.stdout.close() + c.create_tensor.wait() + signal.alarm(0) + + if not len(alt_dict) or not len(variant_dict): + return + + region_start = ctg_start if ctg_start else min(alt_dict.keys()) + region_end = ctg_end if ctg_end else max(alt_dict.keys()) + rescue_dict = defaultdict() + all_pos = set() + + output_vcf_fn = None + if args.output_vcf_fn: + output_vcf_fn = open(args.output_vcf_fn, "w") + + def output(string_value): + print(string_value, file=output_vcf_fn) + + from textwrap import dedent + + output(dedent("""\ + ##fileformat=VCFv4.2 + ##FILTER= + ##FILTER= + ##FILTER= + ##ALT= + ##ALT= + ##INFO= + ##INFO= + ##INFO= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT=""" + )) + output('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s' % (sample_name)) + + # For pacbio hifi platform, we select larger candidate distance for better performance + if platform == 'hifi': + max_candidates_distance = 200 + partition_size = 20 + + for split_idx in range((region_end - region_start) // region_size + 1): + split_start = region_start + split_idx * region_size + split_end = split_start + region_size + extend_split_start = split_start - extend_bp + extend_split_end = split_end + extend_bp + + #bed region include last pos + variants = sorted([(item, alt_dict[item]) for item in alt_dict.keys() if + item >= extend_split_start and item < extend_split_end and len(alt_dict[item].alternate_bases) and is_region_in( + tree=tree, + contig_name=contig_name, + region_start=item-2, + region_end=alt_dict[item].end + 2)], key=lambda x: x[0]) + variants = [item[1] for item in variants] + + truths = sorted([(item, variant_dict[item]) for item in variant_dict.keys() if + item >= extend_split_start and item < extend_split_end and is_region_in( + tree=tree, + contig_name=contig_name, + region_start=item-2, + region_end=variant_dict[item].end + 2)], key=lambda x: x[0]) + truths = [item[1] for item in truths] + + if not len(variants) and not len(truths): + continue + RU = RepresentationUnification( + sample_name=sample_name, + contig_name=contig_name, + reference_sequence=reference_sequence, + reference_start=reference_start, + partition_size=partition_size, + max_candidates_distance=max_candidates_distance, + max_calculate_count=max_calculate_count, + subsample_ratio=subsample_ratio) + + RU.unify_label(variants=variants, + truths=truths, + region=(split_start, split_end), + ctg_start=ctg_start, + ctg_end=ctg_end, + all_pos=all_pos, + variant_dict=variant_dict, + rescue_dict=rescue_dict, + output_vcf_fn=output_vcf_fn, + test_pos=test_pos, + read_name_info_dict=read_name_info_dict, + alt_dict=alt_dict) + + if not len(rescue_dict): + return + if output_vcf_fn is not None: + for pos, vcf_info in rescue_dict.items(): + print(vcf_info, file=output_vcf_fn) + output_vcf_fn.close() + + if os.path.exists(args.output_vcf_fn): + for row in open(args.output_vcf_fn, 'r'): + if row[0] != '#': + return + os.remove(args.output_vcf_fn) + print("[INFO] No vcf output for file {}, remove empty file".format(args.output_vcf_fn)) + +def main(): + parser = ArgumentParser(description="Representation unification for candidate site and true variant") + + parser.add_argument('--platform', type=str, default="ont", + help="Sequencing platform of the input. Options: 'ont,hifi,ilmn', default: %(default)s") + + parser.add_argument('--var_fn', type=str, default=None, + help="Truth variants list input from GetTruth.py") + + parser.add_argument('--ref_fn', type=str, default="ref.fa", + help="Reference fasta file input, default: %(default)s") + + parser.add_argument('--candidate_details_fn', type=str, default=None, + help="Read-level candidate details file, default: %(default)s") + + parser.add_argument('--output_vcf_fn', type=str, default=None, + help="VCF output filename or stdout if not set,default: %(default)s") + + parser.add_argument('--sampleName', type=str, default="SAMPLE", + help="Define the sample name to be shown in the VCF file, optional") + + parser.add_argument('--ctgName', type=str, default=None, + help="The name of the sequence to be processed") + + parser.add_argument('--ctgStart', type=int, default=None, + help="The 1-based starting position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + parser.add_argument('--ctgEnd', type=int, default=None, + help="The 1-based inclusive ending position of the sequence to be processed, optional, will process the whole --ctgName if not set") + + # options for advanced users + parser.add_argument('--max_candidates_distance', type=int, default=100, + help="EXPERIMENTAL: Maximum distance between subsequent variants within a group") + + parser.add_argument('--max_calculate_count', type=int, default=10000, + help="EXPERIMENTAL: Maximum calculation times for chunk window ") + + parser.add_argument('--partition_size', type=int, default=15, + help="EXPERIMENTAL: Maximum variants in per group size") + + parser.add_argument('--minimum_allele_gap', type=int, default=0.15, + help="EXPERIMENTAL: Minimum allele gap filtering candidate path generation") + + parser.add_argument('--bed_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, will choose candidate +/- 1 or +/- 2. Use together with gen4Training. default: %(default)s") + + parser.add_argument('--vcf_fn', type=str, default=None, + help="Candidate sites VCF file input, if provided, will choose candidate +/- 1 or +/- 2. Use together with gen4Training. default: %(default)s") + + parser.add_argument('--extend_bed', type=str, default=None, + help=SUPPRESS) + + # options for internal process control + ## Subsample ratio tag for sub-sampled BAM file + parser.add_argument('--subsample_ratio', type=int, default=1000, + help=SUPPRESS) + + ## Test in specific candidate site. Only use for analysis + parser.add_argument('--test_pos', type=int, default=0, + help=SUPPRESS) + + ## The number of chucks to be divided into for parallel processing + parser.add_argument('--chunk_num', type=int, default=None, + help=SUPPRESS) + + ## The chuck ID to work on + parser.add_argument('--chunk_id', type=int, default=None, + help=SUPPRESS) + + parser.add_argument('--min_af', type=float, default=None, + help=SUPPRESS) + + parser.add_argument('--bam_fn', type=str, default=None, + help=SUPPRESS) + + parser.add_argument('--pypy', type=str, default="pypy3", + help=SUPPRESS) + + args = parser.parse_args() + UnifyRepresentation(args) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/nn-variant/shared/__init__.py b/benchmarks/nn-variant/Clair3/preprocess/__init__.py similarity index 100% rename from benchmarks/nn-variant/shared/__init__.py rename to benchmarks/nn-variant/Clair3/preprocess/__init__.py diff --git a/benchmarks/nn-variant/Clair3/preprocess/__pycache__/CheckEnvs.cpython-39.pyc b/benchmarks/nn-variant/Clair3/preprocess/__pycache__/CheckEnvs.cpython-39.pyc new file mode 100644 index 0000000..bf7fa92 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/preprocess/__pycache__/CheckEnvs.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/preprocess/__pycache__/__init__.cpython-39.pyc b/benchmarks/nn-variant/Clair3/preprocess/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..1df8a3f Binary files /dev/null and b/benchmarks/nn-variant/Clair3/preprocess/__pycache__/__init__.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/__init__.py b/benchmarks/nn-variant/Clair3/preprocess/realign/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/debruijn_graph.cpp b/benchmarks/nn-variant/Clair3/preprocess/realign/debruijn_graph.cpp new file mode 100644 index 0000000..8708bff --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/debruijn_graph.cpp @@ -0,0 +1,438 @@ +/* +Copyright 2020 Google LLC. +Copyright 2021 The University of Hong Kong, Department of Computer Science + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include "debruijn_graph.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +using Vertex = DeBruijnGraph::Vertex; +using VertexIndexMap = DeBruijnGraph::VertexIndexMap; +using Edge = DeBruijnGraph::Edge; +using Path = DeBruijnGraph::Path; + + +class CycleDetector : public boost::dfs_visitor<> { +public: + explicit CycleDetector(bool* has_cycle) : has_cycle(has_cycle) {} + +template +void back_edge(Edge, const Graph&) { + *has_cycle = true; +} + +private: + bool* has_cycle; +}; + +template +class EdgeLabelWriter { +public: + explicit EdgeLabelWriter(const BoostGraph& g) : g_(g) {} + +void operator()(ostream& out, const Edge e) const { + EdgeInfo ei = g_[e]; + out << "[label=" << ei.weight << (ei.is_ref ? " color=red" : "") << "]"; +} + +private: + const BoostGraph& g_; +}; + +class ReachableVertexVisitor : public boost::dfs_visitor<> { +public: + explicit ReachableVertexVisitor(set* reachable_vertices) + : reachable_vertices(reachable_vertices) {} + +template +void tree_edge(Edge e, const Graph& g) { + Vertex from = boost::source(e, g); + if (reachable_vertices->find(from) != reachable_vertices->end()) { + Vertex to = boost::target(e, g); + reachable_vertices->insert(to); + } +} + +private: + set* reachable_vertices; +}; + +template +set VerticesReachableFrom( + Vertex v, const BoostGraphT& g, const VertexIndexMapT& vertex_index_map) { + set reachable_vertices{v}; + ReachableVertexVisitor vis(&reachable_vertices); + boost::depth_first_search( + g, boost::visitor(vis).root_vertex(v).vertex_index_map(vertex_index_map)); + return reachable_vertices; +} + + +Vertex DeBruijnGraph::EnsureVertex(string kmer) { + Vertex v; + auto vertex_find = kmer_to_vertex_.find(kmer); + if (vertex_find != kmer_to_vertex_.end()) { + v = (*vertex_find).second; + } else { + string kmer_copy(kmer); + v = boost::add_vertex(VertexInfo{kmer_copy}, g_); + kmer_to_vertex_[string(g_[v].kmer)] = v; + } + + return v; +} + +Vertex DeBruijnGraph::VertexForKmer(string kmer) const { + return kmer_to_vertex_.at(kmer); +} + +void DeBruijnGraph::RebuildIndexMap() { + map table; + VertexIterator vi, vend; + tie(vi, vend) = boost::vertices(g_); + int index = 0; + for (; vi != vend; ++vi) { + table[*vi] = index; + ++index; + } + vertex_index_map_ = table; +} + +VertexIndexMap DeBruijnGraph::IndexMap() const { + boost::const_associative_property_map vmap( + vertex_index_map_); + return vmap; +} + +bool DeBruijnGraph::HasCycle() const { + bool has_cycle = false; + CycleDetector cycle_detector(&has_cycle); + boost::depth_first_search( + g_, boost::visitor(cycle_detector).vertex_index_map(IndexMap())); + return has_cycle; +} + +DeBruijnGraph::DeBruijnGraph( + const string& ref, + const vector& reads, + vector >& base_quality, + int k) + : k_(k) { + + AddEdgesForReference(ref); + source_ = VertexForKmer(ref.substr(0, k_)); + sink_ = VertexForKmer(ref.substr(ref.size() - k_, k_)); + for (int i = 0; i < reads.size(); i++) { + AddEdgesForRead(reads[i], base_quality[i]); + // } + } + RebuildIndexMap(); +} + + +constexpr int kBoundsNoWorkingK = -1; +struct KBounds { + int min_k; + int max_k; +}; + +KBounds KMinMaxFromReference(const string ref) { +KBounds bounds; +bounds.min_k = kBoundsNoWorkingK; +bounds.max_k = min(101, static_cast(ref.size()) - 1); + +for (int k = 10; k <= bounds.max_k; k++) { + bool has_cycle = false; + set kmers; + + for (int i = 0; i < ref.size() - k + 1; i++) { + string kmer = ref.substr(i, k); + if (kmers.insert(kmer).second == false) { + has_cycle = true; + break; + } + } + + if (!has_cycle) { + bounds.min_k = k; + break; + } + } + + return bounds; +} + + +vector DeBruijnGraph::Build( + const string& ref, + const vector& reads, + vector >& base_quality) { +vector haplotypes; +KBounds bounds = KMinMaxFromReference(ref); + if (bounds.min_k == kBoundsNoWorkingK) return haplotypes; + +for (int k = bounds.min_k; k <= bounds.max_k; k++) { + shared_ptr graph = shared_ptr( + new DeBruijnGraph(ref, reads, base_quality, k)); + if (graph->HasCycle()) { + continue; + } else { + graph->Prune(); + + for (const Path& path : graph->CandidatePaths()) { + haplotypes.push_back(graph->HaplotypeForPath(path)); + } + sort(haplotypes.begin(), haplotypes.end()); + return haplotypes; + } +} + return haplotypes; +} + +Edge DeBruijnGraph::AddEdge(Vertex from_vertex, Vertex to_vertex, bool is_ref) { + bool was_present; + Edge edge; + tie(edge, was_present) = boost::edge(from_vertex, to_vertex, g_); + if (!was_present) { + tie(edge, ignore) = boost::add_edge(from_vertex, to_vertex, + EdgeInfo{0, false}, g_); + } + EdgeInfo& ei = g_[edge]; + ei.weight++; + ei.is_ref |= is_ref; + return edge; +} + +void DeBruijnGraph::AddKmersAndEdges(string bases, int start, int end, + bool is_ref) { + if (end > 0) { + Vertex vertex_prev = EnsureVertex(bases.substr(start, k_)); + for (int i = start + 1; i <= end; ++i) { + Vertex vertex_cur = EnsureVertex(bases.substr(i, k_)); + AddEdge(vertex_prev, vertex_cur, is_ref); + vertex_prev = vertex_cur; + } + } +} + +void DeBruijnGraph::AddEdgesForReference(string ref) { + AddKmersAndEdges(ref, 0, ref.size() - k_, true); +} + + +void DeBruijnGraph::AddEdgesForRead(const string& read, set& base_quality_set) { + const string bases = read; + + auto NextBadPosition = [&read, &bases, &base_quality_set, this](int start) -> int { + string ACGT = "ACGT"; + + for (int i = start; i < bases.size(); ++i) { + if (ACGT.find(bases[i]) == string::npos || base_quality_set.find(i) != base_quality_set.end()) { + return i; + } + } + return bases.size(); +}; + +const string bases_view(bases); + const int stop = bases.size() - k_; + int i = 0; + while (i < stop) { + int next_bad_position = NextBadPosition(i); + AddKmersAndEdges(bases_view, i, next_bad_position - k_, false /* is_ref */); + i = next_bad_position + 1; + } +} + +vector DeBruijnGraph::CandidatePaths() const { +vector terminated_paths; +queue extendable_paths; + +extendable_paths.push({source_}); + +while (!extendable_paths.empty()) { + int n_total_paths = terminated_paths.size() + extendable_paths.size(); + if (n_total_paths > 256) { + return {}; + } + + Path path = extendable_paths.front(); + extendable_paths.pop(); + Vertex last_v = path.back(); + AdjacencyIterator vi, vend; + tie(vi, vend) = boost::adjacent_vertices(last_v, g_); + for (; vi != vend; ++vi) { + Path extended_path(path); + extended_path.push_back(*vi); + if (*vi == sink_ || boost::out_degree(*vi, g_) == 0) { + terminated_paths.push_back(extended_path); + } else { + extendable_paths.push(extended_path); + } + } + } + return terminated_paths; +} + +string DeBruijnGraph::HaplotypeForPath(const Path& path) const { + stringstream haplotype; + for (Vertex v : path) { + haplotype << g_[v].kmer[0]; + } + if (!path.empty()) { + haplotype << g_[path.back()].kmer.substr(1, k_ - 1); + } + return haplotype.str(); +} + +vector DeBruijnGraph::CandidateHaplotypes() const { + vector haplotypes; + for (const Path& path : CandidatePaths()) { + haplotypes.push_back(HaplotypeForPath(path)); + } + sort(haplotypes.begin(), haplotypes.end()); + return haplotypes; +} + +string DeBruijnGraph::GraphViz() const { + stringstream graphviz; + auto vertex_label_writer = boost::make_label_writer( + boost::get(&VertexInfo::kmer, g_)); + boost::write_graphviz( + graphviz, + g_, + vertex_label_writer, + EdgeLabelWriter(g_), + boost::default_writer(), + IndexMap()); + return graphviz.str(); +} + +void DeBruijnGraph::Prune() { + boost::remove_edge_if( + [this](const Edge& e) { + return !g_[e].is_ref && g_[e].weight < 2; + }, + g_); + + // Remove vertices not reachable forward from src or backward from sink. + VertexIterator vbegin, vend; + tie(vbegin, vend) = boost::vertices(g_); + set all_vertices(vbegin, vend); + + set fwd_reachable_vertices, rev_reachable_vertices; + fwd_reachable_vertices = VerticesReachableFrom( + source_, g_, IndexMap()); + rev_reachable_vertices = VerticesReachableFrom( + sink_, boost::make_reverse_graph(g_), IndexMap()); + + set reachable_vertices; + set_intersection( + fwd_reachable_vertices.begin(), fwd_reachable_vertices.end(), + rev_reachable_vertices.begin(), rev_reachable_vertices.end(), + inserter(reachable_vertices, reachable_vertices.end())); + for (Vertex v : all_vertices) { + if (reachable_vertices.find(v) == reachable_vertices.end()) { + kmer_to_vertex_.erase(g_[v].kmer); + boost::clear_vertex(v, g_); + boost::remove_vertex(v, g_); + } + } + RebuildIndexMap(); +} + + +extern "C" { + struct_str_arr* get_consensus(char* reference, char* c_reads, char* c_base_quality, int read_size) { + const string ref = reference; + vector reads; + string r = c_reads; + string r2 = c_base_quality; + string item; + boost::split(reads, c_reads, boost::is_any_of(",")); + vector > base_quality_set; + vector base_quality_array; //( c_base_quality, c_base_quality + read_size); + + + string delimiter = " "; + boost::split(base_quality_array, c_base_quality, boost::is_any_of(",")); + + for (int i = 0; i < base_quality_array.size(); i++) { + string s = base_quality_array[i]; + + size_t pos = 0; + set bq_pos; + string token; + stringstream ss(s); + int temp; + while (ss >> temp) { + bq_pos.insert(temp); + } + + base_quality_set.push_back(bq_pos); + + } + vector output = DeBruijnGraph::Build(ref, reads, base_quality_set); + struct_str_arr* str_arr_ptr = new struct_str_arr(); + + str_arr_ptr->consensus_size = output.size(); + for (int i=0; i < output.size(); i++) { + str_arr_ptr->consensus[i] = new char[output[i].size() + 1]; + strcpy(str_arr_ptr->consensus[i], output[i].c_str()); + } + return str_arr_ptr; + +} +} + +extern "C" { + void free_memory(struct_str_arr* pointer, int size) { + for (int i=0; i< size; i++) { + delete [] pointer->consensus[i]; + } + delete pointer; + pointer = NULL; + } +} diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/debruijn_graph.h b/benchmarks/nn-variant/Clair3/preprocess/realign/debruijn_graph.h new file mode 100644 index 0000000..40675f6 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/debruijn_graph.h @@ -0,0 +1,138 @@ +/* +Copyright 2020 Google LLC. +Copyright 2021 The University of Hong Kong, Department of Computer Science + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#include +#include +#include +#include +#include +#include +#include + +using namespace std; + +struct struct_str_arr +{ + int consensus_size; + char* consensus[500]; +}; + + +struct VertexInfo { + string kmer; +}; + + +struct EdgeInfo { + int weight; + bool is_ref; +}; + + +class DeBruijnGraph { + public: + using BoostGraph = boost::adjacency_list< + boost::setS, + boost::listS, + boost::bidirectionalS, + VertexInfo, + EdgeInfo>; + + using VertexIterator = boost::graph_traits::vertex_iterator; + using EdgeIterator = boost::graph_traits::edge_iterator; + using AdjacencyIterator = boost::graph_traits::adjacency_iterator; + + public: + using Vertex = boost::graph_traits::vertex_descriptor; + using Edge = boost::graph_traits::edge_descriptor; + using Path = vector; + + using RawVertexIndexMap = map; + using VertexIndexMap = + boost::const_associative_property_map; + + + public: + + void RebuildIndexMap(); + + VertexIndexMap IndexMap() const; + + Vertex EnsureVertex(string kmer); + + Vertex VertexForKmer(string kmer) const; + + bool HasCycle() const; + + DeBruijnGraph( + const string& ref, + const vector& reads, + vector >& base_quality, + int k); + + Edge AddEdge(Vertex from_vertex, Vertex to_vertex, bool is_ref); + + + void AddKmersAndEdges(string bases, int start, int end, + bool is_ref); + + void AddEdgesForReference(string ref); + + void AddEdgesForRead(const string& read, set& base_quality_set); + + vector CandidatePaths() const; + + string HaplotypeForPath(const Path& path) const; + + void Prune(); + + public: + + static vector Build( + const string& ref, + + const vector& reads, + vector >& base_quality); + + vector CandidateHaplotypes() const; + + string GraphViz() const; + + int KmerSize() const { return k_; } + + public: + BoostGraph g_; + int k_; + Vertex source_; + Vertex sink_; + + unordered_map kmer_to_vertex_; + RawVertexIndexMap vertex_index_map_; +}; + diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/realigner.cpp b/benchmarks/nn-variant/Clair3/preprocess/realign/realigner.cpp new file mode 100644 index 0000000..7bf5893 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/realigner.cpp @@ -0,0 +1,871 @@ +/* +Copyright 2020 Google LLC. +Copyright 2021 The University of Hong Kong, Department of Computer Science + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +#include "realigner.h" +#include +#include +#include +#include +#include +#include +#include "ssw_cpp.h" +#include + +#include +#include +#include +#include +#include +using namespace std; + +void ReAligner::set_reference(const string& reference) { + this->reference_ = reference; +} + +void ReAligner::set_reads(const vector& reads) { + this->reads_ = reads; +} + +void ReAligner::set_ref_start(uint64_t position) { + this->region_position_in_chr_ = position; +} + +void ReAligner::set_haplotypes(const vector& haplotypes) { + this->haplotypes_ = haplotypes; +} + +void ReAligner::set_options() { + + this->kmer_size_ = 32; + this->read_size_ = 250; + this->max_num_of_mismatches_ = 2; + this->similarity_threshold_ = 0.16934; + this->match_score_ = 4; + this->mismatch_penalty_ = 6; + this->gap_opening_penalty_ = 8; + this->gap_extending_penalty_ = 2; +} +// +void ReAligner::CalculateSswAlignmentScoreThreshold() { + ssw_alignment_score_threshold_ = match_score_ + * read_size_ + * similarity_threshold_ + - mismatch_penalty_ + * read_size_ + * (1 - similarity_threshold_); + if (ssw_alignment_score_threshold_ < 0) { + ssw_alignment_score_threshold_ = 1; + } +} + +vector ReAligner::AlignReads( + const vector& reads_param) { + + for (const auto& read : reads_param) { + reads_.push_back(read.seq); + } + CalculateSswAlignmentScoreThreshold(); + + BuildIndex(); + + FastAlignReadsToHaplotypes(); + + InitSswLib(); + + AlignHaplotypesToReference(); + + CalculatePositionMaps(); + + SswAlignReadsToHaplotypes(ssw_alignment_score_threshold_); + + sort(read_to_haplotype_alignments_.begin(), + read_to_haplotype_alignments_.end()); + + vector realigned_reads; + +// vector* realigned_reads(new vector()); + RealignReadsToReference(reads_param, realigned_reads); + + return realigned_reads; +} + +void ReAligner::InitSswLib() { + StripedSmithWaterman::Aligner ssw_aligner_(match_score_, + mismatch_penalty_, + gap_opening_penalty_, + gap_extending_penalty_); + StripedSmithWaterman::Filter filter; +} + +void ReAligner::SswSetReference(const string& reference) { + + ssw_aligner_.SetReferenceSequence(reference.c_str(), reference.length()); +} + +StripedSmithWaterman::Alignment ReAligner::SswAlign(const string& target) const { + + StripedSmithWaterman::Filter filter; + StripedSmithWaterman::Alignment alignment; + + int32_t maskLen = strlen(target.c_str())/2; + maskLen = maskLen < 15 ? 15 : maskLen; + if (ssw_aligner_.Align(target.c_str(), filter, &alignment)) { + return alignment; + } else { + return StripedSmithWaterman::Alignment(); + } +} +// + +void ReAligner::FastAlignReadsToHaplotypes() { + vector read_alignment_scores(reads_.size()); + for (int i = 0; i < haplotypes_.size(); i++) { + const auto& haplotype = haplotypes_[i]; + int haplotype_score = 0; + for (auto& readAlignment : read_alignment_scores) { + readAlignment.reset(); + } + FastAlignReadsToHaplotype(haplotype, + &haplotype_score, + &read_alignment_scores); + + if (haplotype_score == 0) { + for (auto& readAlignment : read_alignment_scores) { + readAlignment.reset(); + } + } + + read_to_haplotype_alignments_.push_back( + HaplotypeReadsAlignment(i, haplotype_score, read_alignment_scores)); + } +} + +void ReAligner::FastAlignReadsToHaplotype( + const string& haplotype, int* haplotype_score, + vector* haplotype_read_alignment_scores) { + + string bases_view = haplotype; + + bool is_ref = (haplotype == reference_); + vector coverage(haplotype.size(), 0); + const auto& lastPos = haplotype.length() - kmer_size_; + for (int i = 0; i <= lastPos; i++) { + auto index_it = kmer_index_.find(bases_view.substr(i, kmer_size_)); + if (index_it == kmer_index_.end()) { + continue; + } + for (const auto& it : index_it->second) { + uint64_t read_id_index = static_cast(it.read_id); + int target_start_pos = max( + static_cast(0), + static_cast(i) - static_cast(it.read_pos.pos)); + int cur_read_size = reads_[read_id_index].size(); + int span = cur_read_size; + if (target_start_pos + cur_read_size > haplotype.length()) { + continue; + } + auto& read_alignment = + (*haplotype_read_alignment_scores)[read_id_index]; + + if (read_alignment.position != ReadAlignment::kNotAligned && + read_alignment.position == target_start_pos) { + continue; + } + int num_of_mismatches = 0; + int new_read_alignment_score = FastAlignStrings( + bases_view.substr(target_start_pos, span), + reads_[read_id_index], + max_num_of_mismatches_ + 1, &num_of_mismatches); + + if (num_of_mismatches <= max_num_of_mismatches_) { + int oldScore = read_alignment.score; + for (auto pos = target_start_pos; pos < target_start_pos + span; + pos++) { + coverage[pos]++; + } + + if (oldScore < new_read_alignment_score) { + read_alignment.score = new_read_alignment_score; + *haplotype_score -= oldScore; + *haplotype_score += read_alignment.score; + read_alignment.position = target_start_pos; + read_alignment.cigar = to_string(cur_read_size) + "="; + } + } + } + + if (coverage[i] == 0 && i >= ref_prefix_len_ && + i < haplotype.size() - ref_suffix_len_ && !is_ref) { + *haplotype_score = 0; + return; + } + } +} + +int ReAligner::FastAlignStrings(string s1, + string s2, + int max_mismatches, + int* num_of_mismatches) const { + int num_of_matches = 0; + *num_of_mismatches = 0; + for (int i = 0; i < s1.size(); i++) { + const auto& c1 = s1[i]; + const auto& c2 = s2[i]; + if (c1 != c2 && (c1 != 'N' && c2 != 'N')) { + if (c1 != c2) { + (*num_of_mismatches)++; + } + if (*num_of_mismatches == max_mismatches) { + return 0; + } + } else { + num_of_matches++; + } + } + return num_of_matches * match_score_ - *num_of_mismatches * mismatch_penalty_; +} + +Operation CigarOperationFromChar(char op) { + switch (op) { + case '=': + case 'X': + return ALIGNMENT_MATCH; + case 'S': + return CLIP_SOFT; + case 'D': + return DELETE; + case 'I': + return INSERT; + default: + return OPERATION_UNSPECIFIED; + } +} + + +list CigarStringToVector(const string& cigar) { + list cigarOps; + string str = cigar; + smatch result; + string regex_str("(\\d+)([XIDS=])"); + regex pattern1(regex_str,regex::icase); + + string::const_iterator iter = str.begin(); + string::const_iterator iterEnd= str.end(); + string temp; + while (regex_search(iter,iterEnd,result,pattern1)) { + temp=result[0]; + int op_len = atoi(temp.substr(0, temp.length()-1).c_str()); + char op_char = temp[temp.length()-1]; + Operation op = CigarOperationFromChar(op_char); + cigarOps.push_back(CigarOp(op, op_len)); + iter = result[0].second; + } + return cigarOps; +} + + +string CigarVectorToString(const list& cigar) { + string cigar_string = ""; + for (auto& op : cigar) { + int len = op.length; + int op_index = static_cast(op.operation); + string op_string = ""; + switch (op_index) { + case 1: + op_string = "X"; + break; + case 2: + op_string = "I"; + break; + case 3: + op_string = "D"; + break; + case 5: + op_string = "S"; + break; + } + cigar_string += to_string(len)+ op_string; + } + return cigar_string; +} + + + +inline bool AlignmentIsRef(const string& cigar, int target_len) { + return cigar == to_string(target_len) + "="; +} + +void ReAligner::AlignHaplotypesToReference() { + SswSetReference(reference_); + + if (read_to_haplotype_alignments_.empty()) { + for (int i = 0; i < haplotypes_.size(); i++) { + read_to_haplotype_alignments_.push_back(HaplotypeReadsAlignment( + i, -1, vector(reads_.size()))); + } + } + + for (auto& haplotype_alignment : read_to_haplotype_alignments_) { + StripedSmithWaterman::Filter filter; + StripedSmithWaterman::Alignment alignment = + SswAlign(haplotypes_[haplotype_alignment.haplotype_index]); + auto hap_len = haplotypes_[haplotype_alignment.haplotype_index].size(); + if (alignment.sw_score > 0) { + haplotype_alignment.is_reference = + AlignmentIsRef(alignment.cigar_string, hap_len); + haplotype_alignment.cigar = alignment.cigar_string; + haplotype_alignment.cigar_ops = + CigarStringToVector(haplotype_alignment.cigar); + haplotype_alignment.ref_pos = alignment.ref_begin; + } + } +} + +void ReAligner::SswAlignReadsToHaplotypes(int score_threshold) { + + for (int i = 0; i < reads_.size(); i++) { + bool has_at_least_one_alignment = false; + for (const auto& hap_alignment : read_to_haplotype_alignments_) { + if (hap_alignment.read_alignment_scores[i].score > 0) { + has_at_least_one_alignment = true; + break; + } + } + if (!has_at_least_one_alignment) { + for (auto& hap_alignment : read_to_haplotype_alignments_) { + + if (hap_alignment.haplotype_score == 0) { + continue; + } + SswSetReference(haplotypes_[hap_alignment.haplotype_index]); + StripedSmithWaterman::Alignment alignment = SswAlign(reads_[i]); + if (alignment.sw_score > 0) { + if (alignment.sw_score >= score_threshold) { + if (hap_alignment.read_alignment_scores[i].score < + alignment.sw_score) { + hap_alignment.read_alignment_scores[i].score = alignment.sw_score; + hap_alignment.read_alignment_scores[i].cigar = + alignment.cigar_string; + hap_alignment.read_alignment_scores[i].position = + alignment.ref_begin; + } + } + } + } + } + } // for all reads +} + +void ReAligner::RealignReadsToReference( + const vector& reads, + vector& realigned_reads) { + + for (int read_index = 0; read_index < reads.size(); read_index++) { + const struct Read& read = reads[read_index]; + struct Read realigned_read = read; + + int best_hap_index = -1; + + if (GetBestReadAlignment(read_index, &best_hap_index)) { + const HaplotypeReadsAlignment& bestHaplotypeAlignments = + read_to_haplotype_alignments_[best_hap_index]; + + int new_position; + auto read_to_hap_pos = bestHaplotypeAlignments + .read_alignment_scores[read_index] + .position; + + new_position = region_position_in_chr_ + + bestHaplotypeAlignments.ref_pos + + read_to_hap_pos + + bestHaplotypeAlignments + .hap_to_ref_positions_map[read_to_hap_pos]; + list readToRefCigarOps; + + CalculateReadToRefAlignment( + read_index, bestHaplotypeAlignments.read_alignment_scores[read_index], + bestHaplotypeAlignments.cigar_ops, &readToRefCigarOps); + + if (readToRefCigarOps.size() > 0) { + realigned_read.cigar.clear(); + realigned_read.cigar_string = CigarVectorToString(readToRefCigarOps); + realigned_read.position = new_position; + } + realigned_reads.push_back(realigned_read); + + } else { + realigned_reads.push_back(realigned_read); + } + } // for +} + +void ReAligner::AddKmerToIndex(string kmer, + ReadId read_id, KmerOffset pos) { + kmer_index_[kmer].push_back(KmerOccurrence(read_id, pos)); +} + +void ReAligner::AddReadToIndex(const string& read, ReadId read_id) { + + if (read.length() <= kmer_size_) { + return; + } + auto last_pos = read.length() - kmer_size_; + string bases_view = read; + for (int i = 0; i <= last_pos; i++) { + AddKmerToIndex(bases_view.substr(i, kmer_size_), read_id, KmerOffset(i)); + } +} + +void ReAligner::BuildIndex() { + int read_id = 0; + for (const auto& read : reads_) { + AddReadToIndex(read, ReadId(read_id++)); + } +} + +void SetPositionsMap(int haplotype_size, + HaplotypeReadsAlignment* hyplotype_alignment) { + vector& positions_map = + hyplotype_alignment->hap_to_ref_positions_map; + positions_map.resize(haplotype_size); + string str = hyplotype_alignment->cigar; + int cur_shift = 0; + int haplotype_pos = 0; + int last_pos = 0; + int operation_len; + string operation_type; + + smatch result; + string regex_str("(\\d+)([XIDS=])"); + regex pattern1(regex_str,regex::icase); + + string::const_iterator iter = str.begin(); + string::const_iterator iterEnd = str.end(); + string temp; + while (regex_search(iter,iterEnd,result,pattern1)) { + temp=result[0]; + int operation_len = atoi(temp.substr(0, temp.length()-1).c_str()); + char op = temp[temp.length()-1]; + switch (op) { + case '=': + case 'X': + last_pos = haplotype_pos + operation_len; + while (haplotype_pos != last_pos) { + positions_map[haplotype_pos] = cur_shift; + haplotype_pos++; + } + break; + case 'S': + last_pos = haplotype_pos + operation_len; + cur_shift -= operation_len; + while (haplotype_pos != last_pos) { + positions_map[haplotype_pos] = cur_shift; + haplotype_pos++; + } + break; + case 'D': + cur_shift += operation_len; + break; + case 'I': + last_pos = haplotype_pos + operation_len; + while (haplotype_pos != last_pos) { + positions_map[haplotype_pos] = cur_shift; + cur_shift--; + haplotype_pos++; + } + break; + } + iter = result[0].second; + } +} +// +void ReAligner::CalculatePositionMaps() { + for (auto& hyplotype_alignment : read_to_haplotype_alignments_) { + SetPositionsMap(haplotypes_[hyplotype_alignment.haplotype_index].size(), + &hyplotype_alignment); + } +} +// +bool ReAligner::GetBestReadAlignment( + int readId, + int* best_hap_index) const { + int best_score = 0; + bool best_haplotype_found = false; + for (int hap_index = 0; hap_index < haplotypes_.size(); hap_index++) { + if (read_to_haplotype_alignments_[hap_index] + .read_alignment_scores[readId] + .score > best_score + || (best_score > 0 && + read_to_haplotype_alignments_[hap_index] + .read_alignment_scores[readId] + .score == best_score && + !read_to_haplotype_alignments_[hap_index].is_reference)) { + best_score = read_to_haplotype_alignments_[hap_index] + .read_alignment_scores[readId] + .score; + *best_hap_index = hap_index; + best_haplotype_found = true; + } + } + return best_haplotype_found; +} + +int AlignedLength(const list& cigar) { + int len = 0; + for (auto& op : cigar) { + if (op.operation != DELETE) { + len += op.length; + } + } + return len; +} + + +void MergeCigarOp(const CigarOp& op, int read_len, list* cigar) { + const auto& last_cigar_op = + cigar->empty() ? OPERATION_UNSPECIFIED + : cigar->back().operation; + int aligned_length_before_merge = AlignedLength(*cigar); + int new_op_length = 0; + if (op.operation != DELETE) { + new_op_length = min(op.length, read_len - aligned_length_before_merge); + } else { + new_op_length = op.length; + } + + + if (new_op_length <= 0 || aligned_length_before_merge == read_len) { + return; + } + + if (op.operation == last_cigar_op) { + cigar->back().length += new_op_length; + + } else { + cigar->push_back(CigarOp(op.operation, new_op_length)); + } +} + + + +list LeftTrimHaplotypeToRefAlignment( + const list& haplotype_to_ref_cigar_ops_input, + int read_to_haplotype_pos) { + int cur_pos = 0; + list haplotype_to_ref_cigar_ops( + haplotype_to_ref_cigar_ops_input); + while (cur_pos != read_to_haplotype_pos) { + CigarOp cur_hap_op = haplotype_to_ref_cigar_ops.front(); + haplotype_to_ref_cigar_ops.pop_front(); + if (cur_hap_op.operation == + ALIGNMENT_MATCH || + cur_hap_op.operation == CLIP_HARD || + cur_hap_op.operation == CLIP_SOFT || + cur_hap_op.operation == INSERT) { + if (cur_hap_op.length + cur_pos > read_to_haplotype_pos) { + haplotype_to_ref_cigar_ops.push_front( + CigarOp(cur_hap_op.operation, + cur_hap_op.length - (read_to_haplotype_pos - cur_pos))); + } + cur_pos = min(cur_hap_op.length + cur_pos, read_to_haplotype_pos); + } + } + + if (haplotype_to_ref_cigar_ops.front().operation == + DELETE) { + haplotype_to_ref_cigar_ops.pop_front(); + } + + return haplotype_to_ref_cigar_ops; +} + +inline bool BothOpsAreMatch(const CigarOp& op1, const CigarOp& op2) { + return (op1.operation == ALIGNMENT_MATCH || + op1.operation == CLIP_SOFT) && + (op2.operation == ALIGNMENT_MATCH || + op2.operation == CLIP_SOFT); +} + +inline bool OneOfOpsIsSoftClip(const CigarOp& op1, const CigarOp& op2) { + return op1.operation == CLIP_SOFT || + op2.operation == CLIP_SOFT; +} + +inline bool DelAndMatch(const CigarOp& op1, const CigarOp& op2) { + return op1.operation == DELETE && + (op2.operation == ALIGNMENT_MATCH || + op2.operation == CLIP_SOFT); +} + +inline bool BothOpsAreDel(const CigarOp& op1, const CigarOp& op2) { + return op1.operation == DELETE && + op2.operation == DELETE; +} + +inline bool InsAndMatch(const CigarOp& op1, const CigarOp& op2) { + return op1.operation == INSERT && + (op2.operation == ALIGNMENT_MATCH || + op2.operation == CLIP_SOFT); +} + +inline bool BothOpsAreIns(const CigarOp& op1, const CigarOp& op2) { + return (op1.operation == INSERT && + op2.operation == INSERT); +} + +inline void PushFrontIfNotEmpty(const CigarOp& op, list* cigar) { + if (cigar == nullptr) { + return; + } + if (op.length > 0) { + cigar->push_front(op); + } +} + + +void ReAligner::CalculateReadToRefAlignment( + int read_index, + const ReadAlignment& read_to_haplotype_alignment, + const list& haplotype_to_ref_cigar_ops_input, + list* read_to_ref_cigar_ops) const { + int read_len = reads_[read_index].length(); + int read_to_haplotype_pos = read_to_haplotype_alignment.position; + list read_to_haplotype_cigar_ops = + CigarStringToVector(read_to_haplotype_alignment.cigar); + + + list haplotype_to_ref_cigar_ops = + LeftTrimHaplotypeToRefAlignment(haplotype_to_ref_cigar_ops_input, + read_to_haplotype_pos); + + + if (!read_to_haplotype_cigar_ops.empty() && + read_to_haplotype_cigar_ops.front().operation == + CLIP_SOFT) { + MergeCigarOp(CigarOp(CLIP_SOFT, + read_to_haplotype_cigar_ops.front().length), + read_len, read_to_ref_cigar_ops); + read_to_haplotype_cigar_ops.pop_front(); + } + + while ((!read_to_haplotype_cigar_ops.empty() || + !haplotype_to_ref_cigar_ops.empty()) && + AlignedLength(*read_to_ref_cigar_ops) < read_len) { + + if (!read_to_haplotype_cigar_ops.empty() && + haplotype_to_ref_cigar_ops.empty()) { + MergeCigarOp(read_to_haplotype_cigar_ops.front(), read_len, + read_to_ref_cigar_ops); + read_to_haplotype_cigar_ops.pop_front(); + continue; + } + + if (read_to_haplotype_cigar_ops.empty() && + !haplotype_to_ref_cigar_ops.empty()) { + break; + } + + CigarOp cur_read_to_hap_op = read_to_haplotype_cigar_ops.front(); + read_to_haplotype_cigar_ops.pop_front(); + CigarOp cur_hap_to_ref_op = haplotype_to_ref_cigar_ops.front(); + haplotype_to_ref_cigar_ops.pop_front(); + + + // cur_read_to_hap_op, cur_hap_to_ref_op = M|S, M|S + if (BothOpsAreMatch(cur_read_to_hap_op, cur_hap_to_ref_op)) { + int new_op_len = + min(cur_read_to_hap_op.length, cur_hap_to_ref_op.length); + if (OneOfOpsIsSoftClip(cur_read_to_hap_op, cur_hap_to_ref_op)) { + MergeCigarOp( + CigarOp(CLIP_SOFT, new_op_len), + read_len, read_to_ref_cigar_ops); + } else { + MergeCigarOp(CigarOp(ALIGNMENT_MATCH, + new_op_len), + read_len, read_to_ref_cigar_ops); + } + cur_read_to_hap_op.length -= new_op_len; + PushFrontIfNotEmpty(cur_read_to_hap_op, &read_to_haplotype_cigar_ops); + cur_hap_to_ref_op.length -= new_op_len; + PushFrontIfNotEmpty(cur_hap_to_ref_op, &haplotype_to_ref_cigar_ops); + + // cur_read_to_hap_op, cur_hap_to_ref_op = D, M + } else if (DelAndMatch(cur_read_to_hap_op, cur_hap_to_ref_op)) { + MergeCigarOp(CigarOp(DELETE, + cur_read_to_hap_op.length), + read_len, read_to_ref_cigar_ops); + cur_hap_to_ref_op.length -= cur_read_to_hap_op.length; + PushFrontIfNotEmpty(cur_hap_to_ref_op, &haplotype_to_ref_cigar_ops); + + // cur_read_to_hap_op, cur_hap_to_ref_op = M, D + } else if (DelAndMatch(cur_hap_to_ref_op, cur_read_to_hap_op)) { + MergeCigarOp(CigarOp(DELETE, + cur_hap_to_ref_op.length), + read_len, read_to_ref_cigar_ops); + PushFrontIfNotEmpty(cur_read_to_hap_op, &read_to_haplotype_cigar_ops); + + // cur_read_to_hap_op, cur_hap_to_ref_op = D, D + } else if (BothOpsAreDel(cur_read_to_hap_op, cur_hap_to_ref_op)) { + MergeCigarOp( + CigarOp(DELETE, + cur_hap_to_ref_op.length + cur_read_to_hap_op.length), + read_len, read_to_ref_cigar_ops); + + // cur_read_to_hap_op, cur_hap_to_ref_op = I, M + } else if (InsAndMatch(cur_read_to_hap_op, cur_hap_to_ref_op)) { + cur_read_to_hap_op.length = + min(read_len - AlignedLength(*read_to_ref_cigar_ops), + cur_read_to_hap_op.length); + MergeCigarOp(CigarOp(INSERT, + cur_read_to_hap_op.length), + read_len, read_to_ref_cigar_ops); + PushFrontIfNotEmpty(cur_hap_to_ref_op, &haplotype_to_ref_cigar_ops); + + // cur_read_to_hap_op, cur_hap_to_ref_op = M, I + } else if (InsAndMatch(cur_hap_to_ref_op, cur_read_to_hap_op)) { + cur_hap_to_ref_op.length = + min(read_len - AlignedLength(*read_to_ref_cigar_ops), + cur_hap_to_ref_op.length); + MergeCigarOp(CigarOp(INSERT, + cur_hap_to_ref_op.length), + read_len, read_to_ref_cigar_ops); + // We need to decrease the length of cur_read_to_hap_op by INS length + cur_read_to_hap_op.length = + max(0, cur_read_to_hap_op.length - cur_hap_to_ref_op.length); + PushFrontIfNotEmpty(cur_read_to_hap_op, &read_to_haplotype_cigar_ops); + + // cur_read_to_hap_op, cur_hap_to_ref_op = I, I + } else if (BothOpsAreIns(cur_hap_to_ref_op, cur_read_to_hap_op)) { + cur_hap_to_ref_op.length = + cur_hap_to_ref_op.length + cur_read_to_hap_op.length; + MergeCigarOp(CigarOp(INSERT, + cur_hap_to_ref_op.length), + read_len, read_to_ref_cigar_ops); + } else { + + read_to_ref_cigar_ops->clear(); + return; + } + } +} + + + + +struct_str_arr* ReAligner::realign_reads(char* seqs[], int* positions, char* cigars[], char* reference, char* haplotypes, int ref_start, int ref_prefix, int ref_suffix, int read_size) { + + string reference_string = reference; + string str = haplotypes; + + istringstream in_hal(str); + vector haplotypes_string; + string t; + while (in_hal >> t) { + haplotypes_string.push_back(t); + } + + vector seq_list; + for (int i=0; i < read_size; i++) { + t = seqs[i]; + seq_list.push_back(t); + } + + vector position_list; + for (int i=0; i < read_size; i++) { + int pos = positions[i]; + position_list.push_back(pos); + } + + + vector cigar_list; + for (int i=0; i < read_size; i++) { + t = cigars[i]; + cigar_list.push_back(t); + } + + ReAligner aligner; + aligner.set_options(); + aligner.set_reference(reference_string); + aligner.set_haplotypes(haplotypes_string); + aligner.set_ref_start(static_cast(ref_start)); + aligner.set_ref_prefix_len(ref_prefix); + aligner.set_ref_suffix_len(ref_suffix); + vector Reads; + for (int i = 0; i < read_size; i++) { + string seq = seq_list[i]; + int position = position_list[i]; + string cigar_string = cigar_list[i]; + struct Read tmp_read(position, cigar_string, seq); + + list cigar_vector = CigarStringToVector(cigar_string); + for (auto& op : cigar_vector) { + tmp_read.cigar.push_back(op); + } + Reads.push_back(tmp_read); + } + vector result = aligner.AlignReads(Reads); + + vector result_cigar; + vector result_position; + for (auto& read : result) { + result_cigar.push_back(read.cigar_string); + result_position.push_back(read.position); + } + struct_str_arr* str_arr_ptr = new struct_str_arr(); + + for (int i=0; i < read_size; i++) { + str_arr_ptr->cigar_string[i] = new char[result_cigar[i].size() + 1]; + strcpy(str_arr_ptr->cigar_string[i], result_cigar[i].c_str()); + str_arr_ptr->position[i] = result_position[i]; + } + + return str_arr_ptr; + } + + + +extern "C" { + ReAligner obj; + struct_str_arr* realign_reads(char * seqs[], int* positions, char* cigars[], char* reference, char* haplotypes, int ref_start, int ref_prefix, int ref_suffix, int read_size) { + return obj.realign_reads(seqs, positions, cigars, reference, haplotypes, ref_start, ref_prefix, ref_suffix, read_size); + } +} + +extern "C" { + void free_memory(struct_str_arr* pointer, int size) { + for (int i=0; i< size; i++) { + delete [] pointer->cigar_string[i]; + } + delete pointer; + pointer = NULL; + } +} + + diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/realigner.h b/benchmarks/nn-variant/Clair3/preprocess/realign/realigner.h new file mode 100644 index 0000000..4321197 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/realigner.h @@ -0,0 +1,326 @@ +/* +Copyright 2020 Google LLC. +Copyright 2021 The University of Hong Kong, Department of Computer Science + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#include +#include +#include +#include +#include +#include +#include +#include "ssw_cpp.h" + + +using namespace std; + +struct struct_str_arr +{ + int position[1000]; + char* cigar_string[1000]; +}; + +enum Operation { + OPERATION_UNSPECIFIED = 0, + ALIGNMENT_MATCH = 1, + INSERT = 2, + DELETE = 3, + SKIP = 4, + CLIP_SOFT = 5, + CLIP_HARD = 6, +}; + +struct Kmer { + string sequence; +}; + +struct ReadId { + ReadId() : is_set(false), id(0) {} + explicit ReadId(int id) : is_set(true), id(id) {} + + explicit operator int64_t() const { return id; } + explicit operator uint64_t() const { return id; } + + bool operator<(const ReadId& that) const { return id < that.id; } + bool operator==(const ReadId& that) const { return id == that.id; } + + bool is_set; + int id; +}; + + +struct KmerOffset { + KmerOffset() : is_set(false), pos(0) {} + explicit KmerOffset(int pos) : is_set(true), pos(pos) {} + + bool operator==(const KmerOffset& that) const { + return pos == that.pos && is_set == that.is_set; + } + + bool is_set; + int pos; +}; + +struct KmerOccurrence { + KmerOccurrence() {} + KmerOccurrence(ReadId read_id, KmerOffset pos) + : read_id(read_id), read_pos(pos) {} + + bool operator==(const KmerOccurrence& that) const { + return read_id == that.read_id && read_pos == that.read_pos; + } + + ReadId read_id; + KmerOffset read_pos; +}; + + +struct ReadAlignment { + static const int kNotAligned = -1; + ReadAlignment() : position(kNotAligned), cigar(""), score(0) {} + + ReadAlignment(int position_param, const string& cigar_param, + int score_param) + : position(position_param), cigar(cigar_param), score(score_param) {} + + bool operator==(const ReadAlignment& that) const { + return score == that.score && position == that.position && + cigar == that.cigar; + } + + void reset() { + score = 0; + position = kNotAligned; + cigar = ""; + } + + int position; + string cigar; + int score; +}; + +struct CigarOp { + CigarOp() + : operation(OPERATION_UNSPECIFIED), + length(0) {} + CigarOp(Operation op, int len) : operation(op), length(len) {} + + bool operator==(const CigarOp& that) const { + return operation == that.operation && length == that.length; + } + + Operation operation; + int length; +}; + + +struct Read { + int position; + list cigar; + string cigar_string; + string seq; + + Read(int position, string cigar_string, string seq) + { + this->position = position; + this->cigar_string = cigar_string; + this->seq = seq; + }; + +}; + + +struct HaplotypeReadsAlignment { + HaplotypeReadsAlignment() : haplotype_index(0), haplotype_score(0) {} + HaplotypeReadsAlignment( + int haplotype_index, + int score, + const vector& read_alignment_scores) + : haplotype_index(haplotype_index), haplotype_score(score) { + this->read_alignment_scores.assign(read_alignment_scores.begin(), + read_alignment_scores.end()); + } + + bool operator==(const HaplotypeReadsAlignment& that) const { + return haplotype_index == that.haplotype_index && + haplotype_score == that.haplotype_score && + read_alignment_scores == that.read_alignment_scores && + cigar == that.cigar && cigar_ops == that.cigar_ops && + is_reference == that.is_reference && + hap_to_ref_positions_map == that.hap_to_ref_positions_map; + } + + bool operator<(const HaplotypeReadsAlignment& that) const { + return haplotype_score < that.haplotype_score; + } + + int haplotype_index; + + int haplotype_score; + + vector read_alignment_scores; + + string cigar; + + + list cigar_ops; + + int ref_pos; + + + vector hap_to_ref_positions_map; + + + bool is_reference; +}; + + +void SetPositionsMap(int haplotype_size, + HaplotypeReadsAlignment* hyplotype_alignment); + +void MergeCigarOp(const CigarOp& op, int read_len, list* cigar); + +using KmerIndexType = + unordered_map >; + +class ReAligner { + public: + + struct_str_arr* realign_reads(char* seqs[], int* positions, char* cigars[], char* reference, char* haplotypes, int ref_start, int ref_prefix, int ref_suffix, int read_size); + + void set_reference(const string& reference); + void set_reads(const vector& reads); + vector get_reads() const { return reads_; } + void set_ref_start(uint64_t position); + void set_haplotypes(const vector& haplotypes); + uint8_t get_match_score() const { return match_score_; } + uint8_t get_mismatch_penalty() const { return mismatch_penalty_; } + void set_options(); + void set_ref_prefix_len(int ref_prefix_len) { + ref_prefix_len_ = ref_prefix_len; + } + void set_ref_suffix_len(int ref_suffix_len) { + ref_suffix_len_ = ref_suffix_len; + } + int get_ssw_alignment_score_threshold() const { + return ssw_alignment_score_threshold_; + } + +vector AlignReads( + const vector& reads_param); + + void BuildIndex(); +// + KmerIndexType GetKmerIndex() const { return kmer_index_; } + + void FastAlignReadsToHaplotype( + const string& haplotype, int* haplotype_score, + vector* haplotype_read_alignment_scores); + void SswAlignReadsToHaplotypes(int score_threshold); + + void InitSswLib(); + + void SswSetReference(const string& reference); +// + StripedSmithWaterman::Alignment SswAlign(const string& target) const; +// + void AlignHaplotypesToReference(); +// + const vector& GetReadToHaplotypeAlignments() + const { + return read_to_haplotype_alignments_; + } + + bool GetBestReadAlignment(int readId, int* best_hap_index) const; + void CalculateReadToRefAlignment( + int read_index, + const ReadAlignment& read_to_haplotype_alignment, + const list& haplotype_to_ref_cigar_ops_input, + list* read_to_ref_cigar_ops) const; + + void RealignReadsToReference( + const vector& reads, + vector& realigned_reads); + + void CalculateSswAlignmentScoreThreshold(); + +// private: + + string reference_; + + string region_chromosome_ ; + + int region_position_in_chr_; + + vector haplotypes_; + + vector read_to_haplotype_alignments_; + + KmerIndexType kmer_index_; + + vector reads_; + + int kmer_size_ = 32; + + int read_size_ = 250; + + int max_num_of_mismatches_ = 2; + + int ssw_alignment_score_threshold_ = 1; + + int match_score_ = 4; + int mismatch_penalty_ = 6; + int gap_opening_penalty_ = 8; + int gap_extending_penalty_ = 1; + + double similarity_threshold_ = 0.85; + +// unique_ptr ssw_aligner_; + StripedSmithWaterman::Aligner ssw_aligner_; +// StripedSmithWaterman::Filter filter; +// StripedSmithWaterman::Alignment alignment; + + int ref_prefix_len_; + int ref_suffix_len_; + + void FastAlignReadsToHaplotypes(); + + void AddReadToIndex(const string& read, ReadId read_id); + + void AddKmerToIndex(string kmer, ReadId read_id, + KmerOffset pos); + + int FastAlignStrings(string s1, string s2, + int max_mismatches, int* num_of_mismatches) const; + + void UpdateBestHaplotypes( + int haplotype_index, int haplotype_score, + const vector& current_read_scores); + + void CalculatePositionMaps(); +}; diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/ssw.c b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw.c new file mode 100644 index 0000000..825c39b --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw.c @@ -0,0 +1,867 @@ +/* The MIT License + + Copyright (c) 2012-1015 Boston College. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS + BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN + ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN + CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. +*/ + +/* Contact: Mengyao Zhao */ + +/* + * ssw.c + * + * Created by Mengyao Zhao on 6/22/10. + * Copyright 2010 Boston College. All rights reserved. + * Version 0.1.4 + * Last revision by Mengyao Zhao on 02/11/16. + * + */ + +#include +#include +#include +#include +#include +#include +#include "ssw.h" + +#ifdef __GNUC__ +#define LIKELY(x) __builtin_expect((x),1) +#define UNLIKELY(x) __builtin_expect((x),0) +#else +#define LIKELY(x) (x) +#define UNLIKELY(x) (x) +#endif + +/* Convert the coordinate in the scoring matrix into the coordinate in one line of the band. */ +#define set_u(u, w, i, j) { int x=(i)-(w); x=x>0?x:0; (u)=(j)-x+1; } + +/* Convert the coordinate in the direction matrix into the coordinate in one line of the band. */ +#define set_d(u, w, i, j, p) { int x=(i)-(w); x=x>0?x:0; x=(j)-x; (u)=x*3+p; } + +/*! @function + @abstract Round an integer to the next closest power-2 integer. + @param x integer to be rounded (in place) + @discussion x will be modified. + */ +#define kroundup32(x) (--(x), (x)|=(x)>>1, (x)|=(x)>>2, (x)|=(x)>>4, (x)|=(x)>>8, (x)|=(x)>>16, ++(x)) + +typedef struct { + uint16_t score; + int32_t ref; //0-based position + int32_t read; //alignment ending position on read, 0-based +} alignment_end; + +typedef struct { + uint32_t* seq; + int32_t length; +} cigar; + +struct _profile{ + __m128i* profile_byte; // 0: none + __m128i* profile_word; // 0: none + const int8_t* read; + const int8_t* mat; + int32_t readLen; + int32_t n; + uint8_t bias; +}; + +/* Generate query profile rearrange query sequence & calculate the weight of match/mismatch. */ +static __m128i* qP_byte (const int8_t* read_num, + const int8_t* mat, + const int32_t readLen, + const int32_t n, /* the edge length of the squre matrix mat */ + uint8_t bias) { + + int32_t segLen = (readLen + 15) / 16; /* Split the 128 bit register into 16 pieces. + Each piece is 8 bit. Split the read into 16 segments. + Calculat 16 segments in parallel. + */ + __m128i* vProfile = (__m128i*)malloc(n * segLen * sizeof(__m128i)); + int8_t* t = (int8_t*)vProfile; + int32_t nt, i, j, segNum; + + /* Generate query profile rearrange query sequence & calculate the weight of match/mismatch */ + for (nt = 0; LIKELY(nt < n); nt ++) { + for (i = 0; i < segLen; i ++) { + j = i; + for (segNum = 0; LIKELY(segNum < 16) ; segNum ++) { + *t++ = j>= readLen ? bias : mat[nt * n + read_num[j]] + bias; + j += segLen; + } + } + } + return vProfile; +} + +/* Striped Smith-Waterman + Record the highest score of each reference position. + Return the alignment score and ending position of the best alignment, 2nd best alignment, etc. + Gap begin and gap extension are different. + wight_match > 0, all other weights < 0. + The returned positions are 0-based. + */ +static alignment_end* sw_sse2_byte (const int8_t* ref, + int8_t ref_dir, // 0: forward ref; 1: reverse ref + int32_t refLen, + int32_t readLen, + const uint8_t weight_gapO, /* will be used as - */ + const uint8_t weight_gapE, /* will be used as - */ + const __m128i* vProfile, + uint8_t terminate, /* the best alignment score: used to terminate + the matrix calculation when locating the + alignment beginning point. If this score + is set to 0, it will not be used */ + uint8_t bias, /* Shift 0 point to a positive value. */ + int32_t maskLen) { + +#define max16(m, vm) (vm) = _mm_max_epu8((vm), _mm_srli_si128((vm), 8)); \ + (vm) = _mm_max_epu8((vm), _mm_srli_si128((vm), 4)); \ + (vm) = _mm_max_epu8((vm), _mm_srli_si128((vm), 2)); \ + (vm) = _mm_max_epu8((vm), _mm_srli_si128((vm), 1)); \ + (m) = _mm_extract_epi16((vm), 0) + + uint8_t max = 0; /* the max alignment score */ + int32_t end_read = readLen - 1; + int32_t end_ref = -1; /* 0_based best alignment ending point; Initialized as isn't aligned -1. */ + int32_t segLen = (readLen + 15) / 16; /* number of segment */ + + /* array to record the largest score of each reference position */ + uint8_t* maxColumn = (uint8_t*) calloc(refLen, 1); + + /* array to record the alignment read ending position of the largest score of each reference position */ + int32_t* end_read_column = (int32_t*) calloc(refLen, sizeof(int32_t)); + + /* Define 16 byte 0 vector. */ + __m128i vZero = _mm_set1_epi32(0); + + __m128i* pvHStore = (__m128i*) calloc(segLen, sizeof(__m128i)); + __m128i* pvHLoad = (__m128i*) calloc(segLen, sizeof(__m128i)); + __m128i* pvE = (__m128i*) calloc(segLen, sizeof(__m128i)); + __m128i* pvHmax = (__m128i*) calloc(segLen, sizeof(__m128i)); + + int32_t i, j; + /* 16 byte insertion begin vector */ + __m128i vGapO = _mm_set1_epi8(weight_gapO); + + /* 16 byte insertion extension vector */ + __m128i vGapE = _mm_set1_epi8(weight_gapE); + + /* 16 byte bias vector */ + __m128i vBias = _mm_set1_epi8(bias); + + __m128i vMaxScore = vZero; /* Trace the highest score of the whole SW matrix. */ + __m128i vMaxMark = vZero; /* Trace the highest score till the previous column. */ + __m128i vTemp; + int32_t edge, begin = 0, end = refLen, step = 1; + + /* outer loop to process the reference sequence */ + if (ref_dir == 1) { + begin = refLen - 1; + end = -1; + step = -1; + } + for (i = begin; LIKELY(i != end); i += step) { + int32_t cmp; + __m128i e, vF = vZero, vMaxColumn = vZero; /* Initialize F value to 0. + Any errors to vH values will be corrected in the Lazy_F loop. + */ + + __m128i vH = pvHStore[segLen - 1]; + vH = _mm_slli_si128 (vH, 1); /* Shift the 128-bit value in vH left by 1 byte. */ + const __m128i* vP = vProfile + ref[i] * segLen; /* Right part of the vProfile */ + + /* Swap the 2 H buffers. */ + __m128i* pv = pvHLoad; + pvHLoad = pvHStore; + pvHStore = pv; + + /* inner loop to process the query sequence */ + for (j = 0; LIKELY(j < segLen); ++j) { + vH = _mm_adds_epu8(vH, _mm_load_si128(vP + j)); + vH = _mm_subs_epu8(vH, vBias); /* vH will be always > 0 */ + + /* Get max from vH, vE and vF. */ + e = _mm_load_si128(pvE + j); + vH = _mm_max_epu8(vH, e); + vH = _mm_max_epu8(vH, vF); + vMaxColumn = _mm_max_epu8(vMaxColumn, vH); + + /* Save vH values. */ + _mm_store_si128(pvHStore + j, vH); + + /* Update vE value. */ + vH = _mm_subs_epu8(vH, vGapO); /* saturation arithmetic, result >= 0 */ + e = _mm_subs_epu8(e, vGapE); + e = _mm_max_epu8(e, vH); + _mm_store_si128(pvE + j, e); + + /* Update vF value. */ + vF = _mm_subs_epu8(vF, vGapE); + vF = _mm_max_epu8(vF, vH); + + /* Load the next vH. */ + vH = _mm_load_si128(pvHLoad + j); + } + + /* Lazy_F loop: has been revised to disallow adjecent insertion and then deletion, so don't update E(i, j), learn from SWPS3 */ + /* reset pointers to the start of the saved data */ + j = 0; + vH = _mm_load_si128 (pvHStore + j); + + /* the computed vF value is for the given column. since */ + /* we are at the end, we need to shift the vF value over */ + /* to the next column. */ + vF = _mm_slli_si128 (vF, 1); + vTemp = _mm_subs_epu8 (vH, vGapO); + vTemp = _mm_subs_epu8 (vF, vTemp); + vTemp = _mm_cmpeq_epi8 (vTemp, vZero); + cmp = _mm_movemask_epi8 (vTemp); + + while (cmp != 0xffff) + { + vH = _mm_max_epu8 (vH, vF); + vMaxColumn = _mm_max_epu8(vMaxColumn, vH); + _mm_store_si128 (pvHStore + j, vH); + vF = _mm_subs_epu8 (vF, vGapE); + j++; + if (j >= segLen) + { + j = 0; + vF = _mm_slli_si128 (vF, 1); + } + vH = _mm_load_si128 (pvHStore + j); + + vTemp = _mm_subs_epu8 (vH, vGapO); + vTemp = _mm_subs_epu8 (vF, vTemp); + vTemp = _mm_cmpeq_epi8 (vTemp, vZero); + cmp = _mm_movemask_epi8 (vTemp); + } + + vMaxScore = _mm_max_epu8(vMaxScore, vMaxColumn); + vTemp = _mm_cmpeq_epi8(vMaxMark, vMaxScore); + cmp = _mm_movemask_epi8(vTemp); + if (cmp != 0xffff) { + uint8_t temp; + vMaxMark = vMaxScore; + max16(temp, vMaxScore); + vMaxScore = vMaxMark; + + if (LIKELY(temp > max)) { + max = temp; + if (max + bias >= 255) break; //overflow + end_ref = i; + + /* Store the column with the highest alignment score in order to trace the alignment ending position on read. */ + for (j = 0; LIKELY(j < segLen); ++j) pvHmax[j] = pvHStore[j]; + } + } + + /* Record the max score of current column. */ + max16(maxColumn[i], vMaxColumn); + if (maxColumn[i] == terminate) break; + } + + /* Trace the alignment ending position on read. */ + uint8_t *t = (uint8_t*)pvHmax; + int32_t column_len = segLen * 16; + for (i = 0; LIKELY(i < column_len); ++i, ++t) { + int32_t temp; + if (*t == max) { + temp = i / 16 + i % 16 * segLen; + if (temp < end_read) end_read = temp; + } + } + + free(pvHmax); + free(pvE); + free(pvHLoad); + free(pvHStore); + + /* Find the most possible 2nd best alignment. */ + alignment_end* bests = (alignment_end*) calloc(2, sizeof(alignment_end)); + bests[0].score = max + bias >= 255 ? 255 : max; + bests[0].ref = end_ref; + bests[0].read = end_read; + + bests[1].score = 0; + bests[1].ref = 0; + bests[1].read = 0; + + edge = (end_ref - maskLen) > 0 ? (end_ref - maskLen) : 0; + for (i = 0; i < edge; i ++) { + if (maxColumn[i] > bests[1].score) { + bests[1].score = maxColumn[i]; + bests[1].ref = i; + } + } + edge = (end_ref + maskLen) > refLen ? refLen : (end_ref + maskLen); + for (i = edge + 1; i < refLen; i ++) { + if (maxColumn[i] > bests[1].score) { + bests[1].score = maxColumn[i]; + bests[1].ref = i; + } + } + + free(maxColumn); + free(end_read_column); + return bests; +} + +static __m128i* qP_word (const int8_t* read_num, + const int8_t* mat, + const int32_t readLen, + const int32_t n) { + + int32_t segLen = (readLen + 7) / 8; + __m128i* vProfile = (__m128i*)malloc(n * segLen * sizeof(__m128i)); + int16_t* t = (int16_t*)vProfile; + int32_t nt, i, j; + int32_t segNum; + + /* Generate query profile rearrange query sequence & calculate the weight of match/mismatch */ + for (nt = 0; LIKELY(nt < n); nt ++) { + for (i = 0; i < segLen; i ++) { + j = i; + for (segNum = 0; LIKELY(segNum < 8) ; segNum ++) { + *t++ = j>= readLen ? 0 : mat[nt * n + read_num[j]]; + j += segLen; + } + } + } + return vProfile; +} + +static alignment_end* sw_sse2_word (const int8_t* ref, + int8_t ref_dir, // 0: forward ref; 1: reverse ref + int32_t refLen, + int32_t readLen, + const uint8_t weight_gapO, /* will be used as - */ + const uint8_t weight_gapE, /* will be used as - */ + const __m128i* vProfile, + uint16_t terminate, + int32_t maskLen) { + +#define max8(m, vm) (vm) = _mm_max_epi16((vm), _mm_srli_si128((vm), 8)); \ + (vm) = _mm_max_epi16((vm), _mm_srli_si128((vm), 4)); \ + (vm) = _mm_max_epi16((vm), _mm_srli_si128((vm), 2)); \ + (m) = _mm_extract_epi16((vm), 0) + + uint16_t max = 0; /* the max alignment score */ + int32_t end_read = readLen - 1; + int32_t end_ref = 0; /* 1_based best alignment ending point; Initialized as isn't aligned - 0. */ + int32_t segLen = (readLen + 7) / 8; /* number of segment */ + + /* array to record the largest score of each reference position */ + uint16_t* maxColumn = (uint16_t*) calloc(refLen, 2); + + /* array to record the alignment read ending position of the largest score of each reference position */ + int32_t* end_read_column = (int32_t*) calloc(refLen, sizeof(int32_t)); + + /* Define 16 byte 0 vector. */ + __m128i vZero = _mm_set1_epi32(0); + + __m128i* pvHStore = (__m128i*) calloc(segLen, sizeof(__m128i)); + __m128i* pvHLoad = (__m128i*) calloc(segLen, sizeof(__m128i)); + __m128i* pvE = (__m128i*) calloc(segLen, sizeof(__m128i)); + __m128i* pvHmax = (__m128i*) calloc(segLen, sizeof(__m128i)); + + int32_t i, j, k; + /* 16 byte insertion begin vector */ + __m128i vGapO = _mm_set1_epi16(weight_gapO); + + /* 16 byte insertion extension vector */ + __m128i vGapE = _mm_set1_epi16(weight_gapE); + + __m128i vMaxScore = vZero; /* Trace the highest score of the whole SW matrix. */ + __m128i vMaxMark = vZero; /* Trace the highest score till the previous column. */ + __m128i vTemp; + int32_t edge, begin = 0, end = refLen, step = 1; + + /* outer loop to process the reference sequence */ + if (ref_dir == 1) { + begin = refLen - 1; + end = -1; + step = -1; + } + for (i = begin; LIKELY(i != end); i += step) { + int32_t cmp; + __m128i e, vF = vZero; /* Initialize F value to 0. + Any errors to vH values will be corrected in the Lazy_F loop. + */ + __m128i vH = pvHStore[segLen - 1]; + vH = _mm_slli_si128 (vH, 2); /* Shift the 128-bit value in vH left by 2 byte. */ + + /* Swap the 2 H buffers. */ + __m128i* pv = pvHLoad; + + __m128i vMaxColumn = vZero; /* vMaxColumn is used to record the max values of column i. */ + + const __m128i* vP = vProfile + ref[i] * segLen; /* Right part of the vProfile */ + pvHLoad = pvHStore; + pvHStore = pv; + + /* inner loop to process the query sequence */ + for (j = 0; LIKELY(j < segLen); j ++) { + vH = _mm_adds_epi16(vH, _mm_load_si128(vP + j)); + + /* Get max from vH, vE and vF. */ + e = _mm_load_si128(pvE + j); + vH = _mm_max_epi16(vH, e); + vH = _mm_max_epi16(vH, vF); + vMaxColumn = _mm_max_epi16(vMaxColumn, vH); + + /* Save vH values. */ + _mm_store_si128(pvHStore + j, vH); + + /* Update vE value. */ + vH = _mm_subs_epu16(vH, vGapO); /* saturation arithmetic, result >= 0 */ + e = _mm_subs_epu16(e, vGapE); + e = _mm_max_epi16(e, vH); + _mm_store_si128(pvE + j, e); + + /* Update vF value. */ + vF = _mm_subs_epu16(vF, vGapE); + vF = _mm_max_epi16(vF, vH); + + /* Load the next vH. */ + vH = _mm_load_si128(pvHLoad + j); + } + + /* Lazy_F loop: has been revised to disallow adjecent insertion and then deletion, so don't update E(i, j), learn from SWPS3 */ + for (k = 0; LIKELY(k < 8); ++k) { + vF = _mm_slli_si128 (vF, 2); + for (j = 0; LIKELY(j < segLen); ++j) { + vH = _mm_load_si128(pvHStore + j); + vH = _mm_max_epi16(vH, vF); + vMaxColumn = _mm_max_epi16(vMaxColumn, vH); //newly added line + _mm_store_si128(pvHStore + j, vH); + vH = _mm_subs_epu16(vH, vGapO); + vF = _mm_subs_epu16(vF, vGapE); + if (UNLIKELY(! _mm_movemask_epi8(_mm_cmpgt_epi16(vF, vH)))) goto end; + } + } + +end: + vMaxScore = _mm_max_epi16(vMaxScore, vMaxColumn); + vTemp = _mm_cmpeq_epi16(vMaxMark, vMaxScore); + cmp = _mm_movemask_epi8(vTemp); + if (cmp != 0xffff) { + uint16_t temp; + vMaxMark = vMaxScore; + max8(temp, vMaxScore); + vMaxScore = vMaxMark; + + if (LIKELY(temp > max)) { + max = temp; + end_ref = i; + for (j = 0; LIKELY(j < segLen); ++j) pvHmax[j] = pvHStore[j]; + } + } + + /* Record the max score of current column. */ + max8(maxColumn[i], vMaxColumn); + if (maxColumn[i] == terminate) break; + } + + /* Trace the alignment ending position on read. */ + uint16_t *t = (uint16_t*)pvHmax; + int32_t column_len = segLen * 8; + for (i = 0; LIKELY(i < column_len); ++i, ++t) { + int32_t temp; + if (*t == max) { + temp = i / 8 + i % 8 * segLen; + if (temp < end_read) end_read = temp; + } + } + + free(pvHmax); + free(pvE); + free(pvHLoad); + free(pvHStore); + + /* Find the most possible 2nd best alignment. */ + alignment_end* bests = (alignment_end*) calloc(2, sizeof(alignment_end)); + bests[0].score = max; + bests[0].ref = end_ref; + bests[0].read = end_read; + + bests[1].score = 0; + bests[1].ref = 0; + bests[1].read = 0; + + edge = (end_ref - maskLen) > 0 ? (end_ref - maskLen) : 0; + for (i = 0; i < edge; i ++) { + if (maxColumn[i] > bests[1].score) { + bests[1].score = maxColumn[i]; + bests[1].ref = i; + } + } + edge = (end_ref + maskLen) > refLen ? refLen : (end_ref + maskLen); + for (i = edge; i < refLen; i ++) { + if (maxColumn[i] > bests[1].score) { + bests[1].score = maxColumn[i]; + bests[1].ref = i; + } + } + + free(maxColumn); + free(end_read_column); + return bests; +} + +static cigar* banded_sw (const int8_t* ref, + const int8_t* read, + int32_t refLen, + int32_t readLen, + int32_t score, + const uint32_t weight_gapO, /* will be used as - */ + const uint32_t weight_gapE, /* will be used as - */ + int32_t band_width, + const int8_t* mat, /* pointer to the weight matrix */ + int32_t n) { + + uint32_t *c = (uint32_t*)malloc(16 * sizeof(uint32_t)), *c1; + int32_t i, j, e, f, temp1, temp2, s = 16, s1 = 8, l, max = 0; + int64_t s2 = 1024; + char op, prev_op; + int32_t width, width_d, *h_b, *e_b, *h_c; + int8_t *direction, *direction_line; + cigar* result = (cigar*)malloc(sizeof(cigar)); + h_b = (int32_t*)malloc(s1 * sizeof(int32_t)); + e_b = (int32_t*)malloc(s1 * sizeof(int32_t)); + h_c = (int32_t*)malloc(s1 * sizeof(int32_t)); + direction = (int8_t*)malloc(s2 * sizeof(int8_t)); + + do { + width = band_width * 2 + 3, width_d = band_width * 2 + 1; + while (width >= s1) { + ++s1; + kroundup32(s1); + h_b = (int32_t*)realloc(h_b, s1 * sizeof(int32_t)); + e_b = (int32_t*)realloc(e_b, s1 * sizeof(int32_t)); + h_c = (int32_t*)realloc(h_c, s1 * sizeof(int32_t)); + } + while (width_d * readLen * 3 >= s2) { + ++s2; + kroundup32(s2); + if (s2 < 0) { + fprintf(stderr, "Alignment score and position are not consensus.\n"); + exit(1); + } + direction = (int8_t*)realloc(direction, s2 * sizeof(int8_t)); + } + direction_line = direction; + for (j = 1; LIKELY(j < width - 1); j ++) h_b[j] = 0; + for (i = 0; LIKELY(i < readLen); i ++) { + int32_t beg = 0, end = refLen - 1, u = 0, edge; + j = i - band_width; beg = beg > j ? beg : j; // band start + j = i + band_width; end = end < j ? end : j; // band end + edge = end + 1 < width - 1 ? end + 1 : width - 1; + f = h_b[0] = e_b[0] = h_b[edge] = e_b[edge] = h_c[0] = 0; + direction_line = direction + width_d * i * 3; + + for (j = beg; LIKELY(j <= end); j ++) { + int32_t b, e1, f1, d, de, df, dh; + set_u(u, band_width, i, j); set_u(e, band_width, i - 1, j); + set_u(b, band_width, i, j - 1); set_u(d, band_width, i - 1, j - 1); + set_d(de, band_width, i, j, 0); + set_d(df, band_width, i, j, 1); + set_d(dh, band_width, i, j, 2); + + temp1 = i == 0 ? -weight_gapO : h_b[e] - weight_gapO; + temp2 = i == 0 ? -weight_gapE : e_b[e] - weight_gapE; + e_b[u] = temp1 > temp2 ? temp1 : temp2; + direction_line[de] = temp1 > temp2 ? 3 : 2; + + temp1 = h_c[b] - weight_gapO; + temp2 = f - weight_gapE; + f = temp1 > temp2 ? temp1 : temp2; + direction_line[df] = temp1 > temp2 ? 5 : 4; + + e1 = e_b[u] > 0 ? e_b[u] : 0; + f1 = f > 0 ? f : 0; + temp1 = e1 > f1 ? e1 : f1; + temp2 = h_b[d] + mat[ref[j] * n + read[i]]; + h_c[u] = temp1 > temp2 ? temp1 : temp2; + + if (h_c[u] > max) max = h_c[u]; + + if (temp1 <= temp2) direction_line[dh] = 1; + else direction_line[dh] = e1 > f1 ? direction_line[de] : direction_line[df]; + } + for (j = 1; j <= u; j ++) h_b[j] = h_c[j]; + } + band_width *= 2; + } while (LIKELY(max < score)); + band_width /= 2; + + // trace back + i = readLen - 1; + j = refLen - 1; + e = 0; // Count the number of M, D or I. + l = 0; // record length of current cigar + op = prev_op = 'M'; + temp2 = 2; // h + while (LIKELY(i > 0)) { + set_d(temp1, band_width, i, j, temp2); + switch (direction_line[temp1]) { + case 1: + --i; + --j; + temp2 = 2; + direction_line -= width_d * 3; + op = 'M'; + break; + case 2: + --i; + temp2 = 0; // e + direction_line -= width_d * 3; + op = 'I'; + break; + case 3: + --i; + temp2 = 2; + direction_line -= width_d * 3; + op = 'I'; + break; + case 4: + --j; + temp2 = 1; + op = 'D'; + break; + case 5: + --j; + temp2 = 2; + op = 'D'; + break; + default: + fprintf(stderr, "Trace back error: %d.\n", direction_line[temp1 - 1]); + free(direction); + free(h_c); + free(e_b); + free(h_b); + free(c); + free(result); + return 0; + } + if (op == prev_op) ++e; + else { + ++l; + while (l >= s) { + ++s; + kroundup32(s); + c = (uint32_t*)realloc(c, s * sizeof(uint32_t)); + } + c[l - 1] = to_cigar_int(e, prev_op); + prev_op = op; + e = 1; + } + } + if (op == 'M') { + ++l; + while (l >= s) { + ++s; + kroundup32(s); + c = (uint32_t*)realloc(c, s * sizeof(uint32_t)); + } + c[l - 1] = to_cigar_int(e + 1, op); + }else { + l += 2; + while (l >= s) { + ++s; + kroundup32(s); + c = (uint32_t*)realloc(c, s * sizeof(uint32_t)); + } + c[l - 2] = to_cigar_int(e, op); + c[l - 1] = to_cigar_int(1, 'M'); + } + + // reverse cigar + c1 = (uint32_t*)malloc(l * sizeof(uint32_t)); + s = 0; + e = l - 1; + while (LIKELY(s <= e)) { + c1[s] = c[e]; + c1[e] = c[s]; + ++ s; + -- e; + } + result->seq = c1; + result->length = l; + + free(direction); + free(h_c); + free(e_b); + free(h_b); + free(c); + return result; +} + +static int8_t* seq_reverse(const int8_t* seq, int32_t end) /* end is 0-based alignment ending position */ +{ + int8_t* reverse = (int8_t*)calloc(end + 1, sizeof(int8_t)); + int32_t start = 0; + while (LIKELY(start <= end)) { + reverse[start] = seq[end]; + reverse[end] = seq[start]; + ++ start; + -- end; + } + return reverse; +} + +s_profile* ssw_init (const int8_t* read, const int32_t readLen, const int8_t* mat, const int32_t n, const int8_t score_size) { + s_profile* p = (s_profile*)calloc(1, sizeof(struct _profile)); + p->profile_byte = 0; + p->profile_word = 0; + p->bias = 0; + + if (score_size == 0 || score_size == 2) { + /* Find the bias to use in the substitution matrix */ + int32_t bias = 0, i; + for (i = 0; i < n*n; i++) if (mat[i] < bias) bias = mat[i]; + bias = abs(bias); + + p->bias = bias; + p->profile_byte = qP_byte (read, mat, readLen, n, bias); + } + if (score_size == 1 || score_size == 2) p->profile_word = qP_word (read, mat, readLen, n); + p->read = read; + p->mat = mat; + p->readLen = readLen; + p->n = n; + return p; +} + +void init_destroy (s_profile* p) { + free(p->profile_byte); + free(p->profile_word); + free(p); +} + +s_align* ssw_align (const s_profile* prof, + const int8_t* ref, + int32_t refLen, + const uint8_t weight_gapO, + const uint8_t weight_gapE, + const uint8_t flag, // (from high to low) bit 5: return the best alignment beginning position; 6: if (ref_end1 - ref_begin1 <= filterd) && (read_end1 - read_begin1 <= filterd), return cigar; 7: if max score >= filters, return cigar; 8: always return cigar; if 6 & 7 are both setted, only return cigar when both filter fulfilled + const uint16_t filters, + const int32_t filterd, + const int32_t maskLen) { + + alignment_end* bests = 0, *bests_reverse = 0; + __m128i* vP = 0; + int32_t word = 0, band_width = 0, readLen = prof->readLen; + int8_t* read_reverse = 0; + cigar* path; + s_align* r = (s_align*)calloc(1, sizeof(s_align)); + r->ref_begin1 = -1; + r->read_begin1 = -1; + r->cigar = 0; + r->cigarLen = 0; + if (maskLen < 15) { + fprintf(stderr, "When maskLen < 15, the function ssw_align doesn't return 2nd best alignment information.\n"); + } + + // Find the alignment scores and ending positions + if (prof->profile_byte) { + bests = sw_sse2_byte(ref, 0, refLen, readLen, weight_gapO, weight_gapE, prof->profile_byte, -1, prof->bias, maskLen); + if (prof->profile_word && bests[0].score == 255) { + free(bests); + bests = sw_sse2_word(ref, 0, refLen, readLen, weight_gapO, weight_gapE, prof->profile_word, -1, maskLen); + word = 1; + } else if (bests[0].score == 255) { + fprintf(stderr, "Please set 2 to the score_size parameter of the function ssw_init, otherwise the alignment results will be incorrect.\n"); + free(r); + return NULL; + } + }else if (prof->profile_word) { + bests = sw_sse2_word(ref, 0, refLen, readLen, weight_gapO, weight_gapE, prof->profile_word, -1, maskLen); + word = 1; + }else { + fprintf(stderr, "Please call the function ssw_init before ssw_align.\n"); + free(r); + return NULL; + } + r->score1 = bests[0].score; + r->ref_end1 = bests[0].ref; + r->read_end1 = bests[0].read; + if (maskLen >= 15) { + r->score2 = bests[1].score; + r->ref_end2 = bests[1].ref; + } else { + r->score2 = 0; + r->ref_end2 = -1; + } + free(bests); + if (flag == 0 || (flag == 2 && r->score1 < filters)) goto end; + + // Find the beginning position of the best alignment. + read_reverse = seq_reverse(prof->read, r->read_end1); + if (word == 0) { + vP = qP_byte(read_reverse, prof->mat, r->read_end1 + 1, prof->n, prof->bias); + bests_reverse = sw_sse2_byte(ref, 1, r->ref_end1 + 1, r->read_end1 + 1, weight_gapO, weight_gapE, vP, r->score1, prof->bias, maskLen); + } else { + vP = qP_word(read_reverse, prof->mat, r->read_end1 + 1, prof->n); + bests_reverse = sw_sse2_word(ref, 1, r->ref_end1 + 1, r->read_end1 + 1, weight_gapO, weight_gapE, vP, r->score1, maskLen); + } + free(vP); + free(read_reverse); + r->ref_begin1 = bests_reverse[0].ref; + r->read_begin1 = r->read_end1 - bests_reverse[0].read; + free(bests_reverse); + if ((7&flag) == 0 || ((2&flag) != 0 && r->score1 < filters) || ((4&flag) != 0 && (r->ref_end1 - r->ref_begin1 > filterd || r->read_end1 - r->read_begin1 > filterd))) goto end; + + // Generate cigar. + refLen = r->ref_end1 - r->ref_begin1 + 1; + readLen = r->read_end1 - r->read_begin1 + 1; + band_width = abs(refLen - readLen) + 1; + path = banded_sw(ref + r->ref_begin1, prof->read + r->read_begin1, refLen, readLen, r->score1, weight_gapO, weight_gapE, band_width, prof->mat, prof->n); + if (path == 0) { + free(r); + r = NULL; + } + else { + r->cigar = path->seq; + r->cigarLen = path->length; + free(path); + } + +end: + return r; +} + +void align_destroy (s_align* a) { + free(a->cigar); + free(a); +} +/* +inline char cigar_int_to_op(uint32_t cigar_int) { + return UNLIKELY((cigar_int & 0xfU) > 8) ? 'M': MAPSTR[cigar_int & 0xfU]; +} + + +inline uint32_t cigar_int_to_len (uint32_t cigar_int) +{ + return cigar_int >> BAM_CIGAR_SHIFT; +}*/ diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/ssw.h b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw.h new file mode 100644 index 0000000..685ecf3 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw.h @@ -0,0 +1,188 @@ +/* + * ssw.h + * + * Created by Mengyao Zhao on 6/22/10. + * Copyright 2010 Boston College. All rights reserved. + * Version 0.1.4 + * Last revision by Mengyao Zhao on 02/11/16. + * + */ + +#ifndef SSW_H +#define SSW_H + +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +#define MAPSTR "MIDNSHP=X" +#ifndef BAM_CIGAR_SHIFT +#define BAM_CIGAR_SHIFT 4 +#endif + + +/*! @typedef structure of the query profile */ +struct _profile; +typedef struct _profile s_profile; + +/*! @typedef structure of the alignment result + @field score1 the best alignment score + @field score2 sub-optimal alignment score + @field ref_begin1 0-based best alignment beginning position on reference; ref_begin1 = -1 when the best alignment beginning + position is not available + @field ref_end1 0-based best alignment ending position on reference + @field read_begin1 0-based best alignment beginning position on read; read_begin1 = -1 when the best alignment beginning + position is not available + @field read_end1 0-based best alignment ending position on read + @field read_end2 0-based sub-optimal alignment ending position on read + @field cigar best alignment cigar; stored the same as that in BAM format, high 28 bits: length, low 4 bits: M/I/D (0/1/2); + cigar = 0 when the best alignment path is not available + @field cigarLen length of the cigar string; cigarLen = 0 when the best alignment path is not available +*/ +typedef struct { + uint16_t score1; + uint16_t score2; + int32_t ref_begin1; + int32_t ref_end1; + int32_t read_begin1; + int32_t read_end1; + int32_t ref_end2; + uint32_t* cigar; + int32_t cigarLen; +} s_align; + +/*! @function Create the query profile using the query sequence. + @param read pointer to the query sequence; the query sequence needs to be numbers + @param readLen length of the query sequence + @param mat pointer to the substitution matrix; mat needs to be corresponding to the read sequence + @param n the square root of the number of elements in mat (mat has n*n elements) + @param score_size estimated Smith-Waterman score; if your estimated best alignment score is surely < 255 please set 0; if + your estimated best alignment score >= 255, please set 1; if you don't know, please set 2 + @return pointer to the query profile structure + @note example for parameter read and mat: + If the query sequence is: ACGTATC, the sequence that read points to can be: 1234142 + Then if the penalty for match is 2 and for mismatch is -2, the substitution matrix of parameter mat will be: + //A C G T + 2 -2 -2 -2 //A + -2 2 -2 -2 //C + -2 -2 2 -2 //G + -2 -2 -2 2 //T + mat is the pointer to the array {2, -2, -2, -2, -2, 2, -2, -2, -2, -2, 2, -2, -2, -2, -2, 2} +*/ +s_profile* ssw_init (const int8_t* read, const int32_t readLen, const int8_t* mat, const int32_t n, const int8_t score_size); + +/*! @function Release the memory allocated by function ssw_init. + @param p pointer to the query profile structure +*/ +void init_destroy (s_profile* p); + +// @function ssw alignment. +/*! @function Do Striped Smith-Waterman alignment. + @param prof pointer to the query profile structure + @param ref pointer to the target sequence; the target sequence needs to be numbers and corresponding to the mat parameter of + function ssw_init + @param refLen length of the target sequence + @param weight_gapO the absolute value of gap open penalty + @param weight_gapE the absolute value of gap extension penalty + @param flag bitwise FLAG; (from high to low) bit 5: when setted as 1, function ssw_align will return the best alignment + beginning position; bit 6: when setted as 1, if (ref_end1 - ref_begin1 < filterd && read_end1 - read_begin1 + < filterd), (whatever bit 5 is setted) the function will return the best alignment beginning position and + cigar; bit 7: when setted as 1, if the best alignment score >= filters, (whatever bit 5 is setted) the function + will return the best alignment beginning position and cigar; bit 8: when setted as 1, (whatever bit 5, 6 or 7 is + setted) the function will always return the best alignment beginning position and cigar. When flag == 0, only + the optimal and sub-optimal scores and the optimal alignment ending position will be returned. + @param filters score filter: when bit 7 of flag is setted as 1 and bit 8 is setted as 0, filters will be used (Please check the + decription of the flag parameter for detailed usage.) + @param filterd distance filter: when bit 6 of flag is setted as 1 and bit 8 is setted as 0, filterd will be used (Please check + the decription of the flag parameter for detailed usage.) + @param maskLen The distance between the optimal and suboptimal alignment ending position >= maskLen. We suggest to use + readLen/2, if you don't have special concerns. Note: maskLen has to be >= 15, otherwise this function will NOT + return the suboptimal alignment information. Detailed description of maskLen: After locating the optimal + alignment ending position, the suboptimal alignment score can be heuristically found by checking the second + largest score in the array that contains the maximal score of each column of the SW matrix. In order to avoid + picking the scores that belong to the alignments sharing the partial best alignment, SSW C library masks the + reference loci nearby (mask length = maskLen) the best alignment ending position and locates the second largest + score from the unmasked elements. + @return pointer to the alignment result structure + @note Whatever the parameter flag is setted, this function will at least return the optimal and sub-optimal alignment score, + and the optimal alignment ending positions on target and query sequences. If both bit 6 and 7 of the flag are setted + while bit 8 is not, the function will return cigar only when both criteria are fulfilled. All returned positions are + 0-based coordinate. +*/ +s_align* ssw_align (const s_profile* prof, + const int8_t* ref, + int32_t refLen, + const uint8_t weight_gapO, + const uint8_t weight_gapE, + const uint8_t flag, + const uint16_t filters, + const int32_t filterd, + const int32_t maskLen); + +/*! @function Release the memory allocated by function ssw_align. + @param a pointer to the alignment result structure +*/ +void align_destroy (s_align* a); + +/*! @function Produce CIGAR 32-bit unsigned integer from CIGAR operation and CIGAR length + @param length length of CIGAR + @param op_letter CIGAR operation character ('M', 'I', etc) + @return 32-bit unsigned integer, representing encoded CIGAR operation and length +*/ +static inline uint32_t to_cigar_int (uint32_t length, char op_letter) +{ + switch (op_letter) { + case 'M': /* alignment match (can be a sequence match or mismatch */ + default: + return length << BAM_CIGAR_SHIFT; + case 'S': /* soft clipping (clipped sequences present in SEQ) */ + return (length << BAM_CIGAR_SHIFT) | (4u); + case 'D': /* deletion from the reference */ + return (length << BAM_CIGAR_SHIFT) | (2u); + case 'I': /* insertion to the reference */ + return (length << BAM_CIGAR_SHIFT) | (1u); + case 'H': /* hard clipping (clipped sequences NOT present in SEQ) */ + return (length << BAM_CIGAR_SHIFT) | (5u); + case 'N': /* skipped region from the reference */ + return (length << BAM_CIGAR_SHIFT) | (3u); + case 'P': /* padding (silent deletion from padded reference) */ + return (length << BAM_CIGAR_SHIFT) | (6u); + case '=': /* sequence match */ + return (length << BAM_CIGAR_SHIFT) | (7u); + case 'X': /* sequence mismatch */ + return (length << BAM_CIGAR_SHIFT) | (8u); + } + return (uint32_t)-1; // This never happens +} + + +/*! @function Extract CIGAR operation character from CIGAR 32-bit unsigned integer + @param cigar_int 32-bit unsigned integer, representing encoded CIGAR operation and length + @return CIGAR operation character ('M', 'I', etc) +*/ +//char cigar_int_to_op (uint32_t cigar_int); +static inline char cigar_int_to_op(uint32_t cigar_int) +{ + return (cigar_int & 0xfU) > 8 ? 'M': MAPSTR[cigar_int & 0xfU]; +} + + +/*! @function Extract length of a CIGAR operation from CIGAR 32-bit unsigned integer + @param cigar_int 32-bit unsigned integer, representing encoded CIGAR operation and length + @return length of CIGAR operation +*/ +//uint32_t cigar_int_to_len (uint32_t cigar_int); +static inline uint32_t cigar_int_to_len (uint32_t cigar_int) +{ + return cigar_int >> BAM_CIGAR_SHIFT; +} +#ifdef __cplusplus +} +#endif // __cplusplus + +#endif // SSW_H diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/ssw_cpp.cpp b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw_cpp.cpp new file mode 100644 index 0000000..75c7def --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw_cpp.cpp @@ -0,0 +1,477 @@ +#include "ssw_cpp.h" +#include "ssw.h" + +#include + +namespace { + +static const int8_t kBaseTranslation[128] = { + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + // A C G + 4, 0, 4, 1, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, + // T + 4, 4, 4, 4, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + // a c g + 4, 0, 4, 1, 4, 4, 4, 2, 4, 4, 4, 4, 4, 4, 4, 4, + // t + 4, 4, 4, 4, 3, 0, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4 +}; + +void BuildSwScoreMatrix(const uint8_t& match_score, + const uint8_t& mismatch_penalty, + int8_t* matrix) { + + // The score matrix looks like + // // A, C, G, T, N + // score_matrix_ = { 2, -2, -2, -2, -2, // A + // -2, 2, -2, -2, -2, // C + // -2, -2, 2, -2, -2, // G + // -2, -2, -2, 2, -2, // T + // -2, -2, -2, -2, -2};// N + + int id = 0; + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 4; ++j) { + matrix[id] = ((i == j) ? match_score : static_cast(-mismatch_penalty)); + ++id; + } + matrix[id] = static_cast(-mismatch_penalty); // For N + ++id; + } + + for (int i = 0; i < 5; ++i) + matrix[id++] = static_cast(-mismatch_penalty); // For N + +} + +void ConvertAlignment(const s_align& s_al, + const int& query_len, + StripedSmithWaterman::Alignment* al) { + al->sw_score = s_al.score1; + al->sw_score_next_best = s_al.score2; + al->ref_begin = s_al.ref_begin1; + al->ref_end = s_al.ref_end1; + al->query_begin = s_al.read_begin1; + al->query_end = s_al.read_end1; + al->ref_end_next_best = s_al.ref_end2; + + al->cigar.clear(); + al->cigar_string.clear(); + + if (s_al.cigarLen > 0) { + std::ostringstream cigar_string; + if (al->query_begin > 0) { + uint32_t cigar = to_cigar_int(al->query_begin, 'S'); + al->cigar.push_back(cigar); + cigar_string << al->query_begin << 'S'; + } + + for (int i = 0; i < s_al.cigarLen; ++i) { + al->cigar.push_back(s_al.cigar[i]); + cigar_string << cigar_int_to_len(s_al.cigar[i]) << cigar_int_to_op(s_al.cigar[i]); + } + + int end = query_len - al->query_end - 1; + if (end > 0) { + uint32_t cigar = to_cigar_int(end, 'S'); + al->cigar.push_back(cigar); + cigar_string << end << 'S'; + } + + al->cigar_string = cigar_string.str(); + } // end if +} + +// @Function: +// Calculate the length of the previous cigar operator +// and store it in new_cigar and new_cigar_string. +// Clean up in_M (false), in_X (false), length_M (0), and length_X(0). +void CleanPreviousMOperator( + bool* in_M, + bool* in_X, + uint32_t* length_M, + uint32_t* length_X, + std::vector* new_cigar, + std::ostringstream* new_cigar_string) { + if (*in_M) { + uint32_t match = to_cigar_int(*length_M, '='); + new_cigar->push_back(match); + (*new_cigar_string) << *length_M << '='; + } else if (*in_X){ //in_X + uint32_t match = to_cigar_int(*length_X, 'X'); + new_cigar->push_back(match); + (*new_cigar_string) << *length_X << 'X'; + } + + // Clean up + *in_M = false; + *in_X = false; + *length_M = 0; + *length_X = 0; +} + +// @Function: +// 1. Calculate the number of mismatches. +// 2. Modify the cigar string: +// differentiate matches (M) and mismatches(X). +// Note that SSW does not differentiate matches and mismatches. +// @Return: +// The number of mismatches. +int CalculateNumberMismatch( + StripedSmithWaterman::Alignment* al, + int8_t const *ref, + int8_t const *query, + const int& query_len) { + + ref += al->ref_begin; + query += al->query_begin; + int mismatch_length = 0; + + std::vector new_cigar; + std::ostringstream new_cigar_string; + + if (al->query_begin > 0) { + uint32_t cigar = to_cigar_int(al->query_begin, 'S'); + new_cigar.push_back(cigar); + new_cigar_string << al->query_begin << 'S'; + } + + bool in_M = false; // the previous is match + bool in_X = false; // the previous is mismatch + uint32_t length_M = 0; + uint32_t length_X = 0; + + for (unsigned int i = 0; i < al->cigar.size(); ++i) { + char op = cigar_int_to_op(al->cigar[i]); + uint32_t length = cigar_int_to_len(al->cigar[i]); + if (op == 'M') { + for (uint32_t j = 0; j < length; ++j) { + if (*ref != *query) { + ++mismatch_length; + if (in_M) { // the previous is match; however the current one is mismatche + uint32_t match = to_cigar_int(length_M, '='); + new_cigar.push_back(match); + new_cigar_string << length_M << '='; + } + length_M = 0; + ++length_X; + in_M = false; + in_X = true; + } else { // *ref == *query + if (in_X) { // the previous is mismatch; however the current one is matche + uint32_t match = to_cigar_int(length_X, 'X'); + new_cigar.push_back(match); + new_cigar_string << length_X << 'X'; + } + ++length_M; + length_X = 0; + in_M = true; + in_X = false; + } // end of if (*ref != *query) + ++ref; + ++query; + } + } else if (op == 'I') { + query += length; + mismatch_length += length; + CleanPreviousMOperator(&in_M, &in_X, &length_M, &length_X, &new_cigar, &new_cigar_string); + new_cigar.push_back(al->cigar[i]); + new_cigar_string << length << 'I'; + } else if (op == 'D') { + ref += length; + mismatch_length += length; + CleanPreviousMOperator(&in_M, &in_X, &length_M, &length_X, &new_cigar, &new_cigar_string); + new_cigar.push_back(al->cigar[i]); + new_cigar_string << length << 'D'; + } + } + + CleanPreviousMOperator(&in_M, &in_X, &length_M, &length_X, &new_cigar, &new_cigar_string); + + int end = query_len - al->query_end - 1; + if (end > 0) { + uint32_t cigar = to_cigar_int(end, 'S'); + new_cigar.push_back(cigar); + new_cigar_string << end << 'S'; + } + + al->cigar_string.clear(); + al->cigar.clear(); + al->cigar_string = new_cigar_string.str(); + al->cigar = new_cigar; + + return mismatch_length; +} + +void SetFlag(const StripedSmithWaterman::Filter& filter, uint8_t* flag) { + if (filter.report_begin_position) *flag |= 0x08; + if (filter.report_cigar) *flag |= 0x0f; +} + +// http://www.cplusplus.com/faq/sequences/arrays/sizeof-array/#cpp +template +inline size_t SizeOfArray( const T(&)[ N ] ) +{ + return N; +} + +} // namespace + + + +namespace StripedSmithWaterman { + +Aligner::Aligner(void) + : score_matrix_(NULL) + , score_matrix_size_(5) + , translation_matrix_(NULL) + , match_score_(4) + , mismatch_penalty_(6) + , gap_opening_penalty_(8) + , gap_extending_penalty_(2) + , translated_reference_(NULL) + , reference_length_(0) +{ + BuildDefaultMatrix(); +} + +Aligner::Aligner( + const uint8_t& match_score, + const uint8_t& mismatch_penalty, + const uint8_t& gap_opening_penalty, + const uint8_t& gap_extending_penalty) + + : score_matrix_(NULL) + , score_matrix_size_(5) + , translation_matrix_(NULL) + , match_score_(match_score) + , mismatch_penalty_(mismatch_penalty) + , gap_opening_penalty_(gap_opening_penalty) + , gap_extending_penalty_(gap_extending_penalty) + , translated_reference_(NULL) + , reference_length_(0) +{ + BuildDefaultMatrix(); +} + +Aligner::Aligner(const int8_t* score_matrix, + const int& score_matrix_size, + const int8_t* translation_matrix, + const int& translation_matrix_size) + + : score_matrix_(NULL) + , score_matrix_size_(score_matrix_size) + , translation_matrix_(NULL) + , match_score_(4) + , mismatch_penalty_(6) + , gap_opening_penalty_(8) + , gap_extending_penalty_(2) + , translated_reference_(NULL) + , reference_length_(0) +{ + score_matrix_ = new int8_t[score_matrix_size_ * score_matrix_size_]; + memcpy(score_matrix_, score_matrix, sizeof(int8_t) * score_matrix_size_ * score_matrix_size_); + translation_matrix_ = new int8_t[translation_matrix_size]; + memcpy(translation_matrix_, translation_matrix, sizeof(int8_t) * translation_matrix_size); +} + + +Aligner::~Aligner(void){ + Clear(); +} + +int Aligner::SetReferenceSequence(const char* seq, const int& length) { + + int len = 0; + if (translation_matrix_) { + // calculate the valid length + //int calculated_ref_length = static_cast(strlen(seq)); + //int valid_length = (calculated_ref_length > length) + // ? length : calculated_ref_length; + int valid_length = length; + // delete the current buffer + CleanReferenceSequence(); + // allocate a new buffer + translated_reference_ = new int8_t[valid_length]; + + len = TranslateBase(seq, valid_length, translated_reference_); + } else { + // nothing + } + + reference_length_ = len; + return len; + + +} + +int Aligner::TranslateBase(const char* bases, const int& length, + int8_t* translated) const { + + const char* ptr = bases; + int len = 0; + for (int i = 0; i < length; ++i) { + translated[i] = translation_matrix_[(int) *ptr]; + ++ptr; + ++len; + } + + return len; +} + + +bool Aligner::Align(const char* query, const Filter& filter, + Alignment* alignment) const +{ + if (!translation_matrix_) return false; + if (reference_length_ == 0) return false; + + int query_len = strlen(query); + if (query_len == 0) return false; + int8_t* translated_query = new int8_t[query_len]; + TranslateBase(query, query_len, translated_query); + + const int8_t score_size = 2; + s_profile* profile = ssw_init(translated_query, query_len, score_matrix_, + score_matrix_size_, score_size); + + uint8_t flag = 0; + SetFlag(filter, &flag); + s_align* s_al = ssw_align(profile, translated_reference_, reference_length_, + static_cast(gap_opening_penalty_), + static_cast(gap_extending_penalty_), + flag, filter.score_filter, filter.distance_filter, query_len); + + alignment->Clear(); + ConvertAlignment(*s_al, query_len, alignment); + alignment->mismatches = CalculateNumberMismatch(&*alignment, translated_reference_, translated_query, query_len); + + + // Free memory + delete [] translated_query; + align_destroy(s_al); + init_destroy(profile); + + return true; +} + + +bool Aligner::Align(const char* query, const char* ref, const int& ref_len, + const Filter& filter, Alignment* alignment) const +{ + if (!translation_matrix_) return false; + + int query_len = strlen(query); + if (query_len == 0) return false; + int8_t* translated_query = new int8_t[query_len]; + TranslateBase(query, query_len, translated_query); + + // calculate the valid length + //int calculated_ref_length = static_cast(strlen(ref)); + //int valid_ref_len = (calculated_ref_length > ref_len) + // ? ref_len : calculated_ref_length; + int valid_ref_len = ref_len; + int8_t* translated_ref = new int8_t[valid_ref_len]; + TranslateBase(ref, valid_ref_len, translated_ref); + + + const int8_t score_size = 2; + s_profile* profile = ssw_init(translated_query, query_len, score_matrix_, + score_matrix_size_, score_size); + + uint8_t flag = 0; + SetFlag(filter, &flag); + s_align* s_al = ssw_align(profile, translated_ref, valid_ref_len, + static_cast(gap_opening_penalty_), + static_cast(gap_extending_penalty_), + flag, filter.score_filter, filter.distance_filter, query_len); + + alignment->Clear(); + ConvertAlignment(*s_al, query_len, alignment); + alignment->mismatches = CalculateNumberMismatch(&*alignment, translated_ref, translated_query, query_len); + + // Free memory + delete [] translated_query; + delete [] translated_ref; + align_destroy(s_al); + init_destroy(profile); + + return true; +} + +void Aligner::Clear(void) { + ClearMatrices(); + CleanReferenceSequence(); +} + +void Aligner::SetAllDefault(void) { + score_matrix_size_ = 5; + match_score_ = 4; + mismatch_penalty_ = 6; + gap_opening_penalty_ = 8; + gap_extending_penalty_ = 2; + reference_length_ = 0; +} + +bool Aligner::ReBuild(void) { + if (translation_matrix_) return false; + + SetAllDefault(); + BuildDefaultMatrix(); + + return true; +} + +bool Aligner::ReBuild( + const uint8_t& match_score, + const uint8_t& mismatch_penalty, + const uint8_t& gap_opening_penalty, + const uint8_t& gap_extending_penalty) { + if (translation_matrix_) return false; + + SetAllDefault(); + + match_score_ = match_score; + mismatch_penalty_ = mismatch_penalty; + gap_opening_penalty_ = gap_opening_penalty; + gap_extending_penalty_ = gap_extending_penalty; + + BuildDefaultMatrix(); + + return true; +} + +bool Aligner::ReBuild( + const int8_t* score_matrix, + const int& score_matrix_size, + const int8_t* translation_matrix, + const int& translation_matrix_size) { + + ClearMatrices(); + score_matrix_ = new int8_t[score_matrix_size_ * score_matrix_size_]; + memcpy(score_matrix_, score_matrix, sizeof(int8_t) * score_matrix_size_ * score_matrix_size_); + translation_matrix_ = new int8_t[translation_matrix_size]; + memcpy(translation_matrix_, translation_matrix, sizeof(int8_t) * translation_matrix_size); + + return true; +} + +void Aligner::BuildDefaultMatrix(void) { + ClearMatrices(); + score_matrix_ = new int8_t[score_matrix_size_ * score_matrix_size_]; + BuildSwScoreMatrix(match_score_, mismatch_penalty_, score_matrix_); + translation_matrix_ = new int8_t[SizeOfArray(kBaseTranslation)]; + memcpy(translation_matrix_, kBaseTranslation, sizeof(int8_t) * SizeOfArray(kBaseTranslation)); +} + +void Aligner::ClearMatrices(void) { + delete [] score_matrix_; + score_matrix_ = NULL; + + delete [] translation_matrix_; + translation_matrix_ = NULL; +} +} // namespace StripedSmithWaterman diff --git a/benchmarks/nn-variant/Clair3/preprocess/realign/ssw_cpp.h b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw_cpp.h new file mode 100644 index 0000000..cdcf717 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/realign/ssw_cpp.h @@ -0,0 +1,219 @@ +#ifndef COMPLETE_STRIPED_SMITH_WATERMAN_CPP_H_ +#define COMPLETE_STRIPED_SMITH_WATERMAN_CPP_H_ + +#include +#include +#include + +namespace StripedSmithWaterman { + +struct Alignment { + uint16_t sw_score; // The best alignment score + uint16_t sw_score_next_best; // The next best alignment score + int32_t ref_begin; // Reference begin position of the best alignment + int32_t ref_end; // Reference end position of the best alignment + int32_t query_begin; // Query begin position of the best alignment + int32_t query_end; // Query end position of the best alignment + int32_t ref_end_next_best; // Reference end position of the next best alignment + int32_t mismatches; // Number of mismatches of the alignment + std::string cigar_string; // Cigar string of the best alignment + std::vector cigar; // Cigar stored in the BAM format + // high 28 bits: length + // low 4 bits: M/I/D/S/X (0/1/2/4/8); + void Clear() { + sw_score = 0; + sw_score_next_best = 0; + ref_begin = 0; + ref_end = 0; + query_begin = 0; + query_end = 0; + ref_end_next_best = 0; + mismatches = 0; + cigar_string.clear(); + cigar.clear(); + }; +}; + +struct Filter { + // NOTE: No matter the filter, those five fields of Alignment will be given anyway. + // sw_score; sw_score_next_best; ref_end; query_end; ref_end_next_best. + // NOTE: Only need score of alignments, please set 'report_begin_position' + // and 'report_cigar' false. + + bool report_begin_position; // Give ref_begin and query_begin. + // If it is not set, ref_begin and query_begin are -1. + bool report_cigar; // Give cigar_string and cigar. + // report_begin_position is automatically TRUE. + + // When *report_cigar* is true and alignment passes these two filters, + // cigar_string and cigar will be given. + uint16_t score_filter; // score >= score_filter + uint16_t distance_filter; // ((ref_end - ref_begin) < distance_filter) && + // ((query_end - read_begin) < distance_filter) + + Filter() + : report_begin_position(true) + , report_cigar(true) + , score_filter(0) + , distance_filter(32767) + {}; + + Filter(const bool& pos, const bool& cigar, const uint16_t& score, const uint16_t& dis) + : report_begin_position(pos) + , report_cigar(cigar) + , score_filter(score) + , distance_filter(dis) + {}; +}; + +class Aligner { + public: + // ========= + // @function Construct an Aligner on default values. + // The function will build the {A.C,G,T,N} aligner. + // If you target for other character aligners, then please + // use the other constructor and pass the corresponding matrix in. + // ========= + Aligner(void); + + // ========= + // @function Construct an Aligner by assigning scores. + // The function will build the {A.C,G,T,N} aligner. + // If you target for other character aligners, then please + // use the other constructor and pass the corresponding matrix in. + // ========= + Aligner(const uint8_t& match_score, + const uint8_t& mismatch_penalty, + const uint8_t& gap_opening_penalty, + const uint8_t& gap_extending_penalty); + + // ========= + // @function Construct an Aligner by the specific matrixs. + // ========= + Aligner(const int8_t* score_matrix, + const int& score_matrix_size, + const int8_t* translation_matrix, + const int& translation_matrix_size); + + ~Aligner(void); + + // ========= + // @function Build the reference sequence and thus make + // Align(const char* query, s_align* alignment) function; + // otherwise the reference should be given when aligning. + // [NOTICE] If there exists a sequence, that one will be deleted + // and replaced. + // @param seq The reference bases; + // [NOTICE] It is not necessary null terminated. + // @param length The length of bases will be be built. + // @return The length of the built bases. + // ========= + int SetReferenceSequence(const char* seq, const int& length); + + void CleanReferenceSequence(void); + + // ========= + // @function Set penalties for opening and extending gaps + // [NOTICE] The defaults are 3 and 1 respectively. + // ========= + void SetGapPenalty(const uint8_t& opening, const uint8_t& extending) { + gap_opening_penalty_ = opening; + gap_extending_penalty_ = extending; + }; + + // ========= + // @function Align the query againt the reference that is set by + // SetReferenceSequence. + // @param query The query sequence. + // @param filter The filter for the alignment. + // @param alignment The container contains the result. + // @return True: succeed; false: fail. + // ========= + bool Align(const char* query, const Filter& filter, Alignment* alignment) const; + + // ========= + // @function Align the query againt the reference. + // [NOTICE] The reference won't replace the reference + // set by SetReferenceSequence. + // @param query The query sequence. + // @param ref The reference sequence. + // [NOTICE] It is not necessary null terminated. + // @param ref_len The length of the reference sequence. + // @param filter The filter for the alignment. + // @param alignment The container contains the result. + // @return True: succeed; false: fail. + // ========= + bool Align(const char* query, const char* ref, const int& ref_len, + const Filter& filter, Alignment* alignment) const; + + // @function Clear up all containers and thus the aligner is disabled. + // To rebuild the aligner please use Build functions. + void Clear(void); + + // ========= + // @function Rebuild the aligner's ability on default values. + // [NOTICE] If the aligner is not cleaned, rebuilding will fail. + // @return True: succeed; false: fail. + // ========= + bool ReBuild(void); + + // ========= + // @function Rebuild the aligner's ability by the specific matrixs. + // [NOTICE] If the aligner is not cleaned, rebuilding will fail. + // @return True: succeed; false: fail. + // ========= + bool ReBuild( + const uint8_t& match_score, + const uint8_t& mismatch_penalty, + const uint8_t& gap_opening_penalty, + const uint8_t& gap_extending_penalty); + + // ========= + // @function Construct an Aligner by the specific matrixs. + // [NOTICE] If the aligner is not cleaned, rebuilding will fail. + // @return True: succeed; false: fail. + // ========= + bool ReBuild( + const int8_t* score_matrix, + const int& score_matrix_size, + const int8_t* translation_matrix, + const int& translation_matrix_size); + + private: + int8_t* score_matrix_; + int score_matrix_size_; + int8_t* translation_matrix_; + + uint8_t match_score_; // default: 2 + uint8_t mismatch_penalty_; // default: 2 + uint8_t gap_opening_penalty_; // default: 3 + uint8_t gap_extending_penalty_; // default: 1 + + int8_t* translated_reference_; + int32_t reference_length_; + + int TranslateBase(const char* bases, const int& length, int8_t* translated) const; + void SetAllDefault(void); + void BuildDefaultMatrix(void); + void ClearMatrices(void); + + Aligner& operator= (const Aligner&); + Aligner (const Aligner&); +}; // class Aligner + + +// ================ +// inline functions +// ================ +inline void Aligner::CleanReferenceSequence(void) { + if (reference_length_ == 0) return; + + // delete the current buffer + if (reference_length_ > 1) delete [] translated_reference_; + else delete translated_reference_; + + reference_length_ = 0; +} +} // namespace StripedSmithWaterman + +#endif // COMPLETE_STRIPED_SMITH_WATERMAN_CPP_H_ diff --git a/benchmarks/nn-variant/Clair3/preprocess/utils.py b/benchmarks/nn-variant/Clair3/preprocess/utils.py new file mode 100644 index 0000000..a980776 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/preprocess/utils.py @@ -0,0 +1,759 @@ +from cffi import FFI +import logging +import os +import sys +import re +import subprocess +import shlex +logging.getLogger().setLevel(logging.INFO) +from shared.utils import file_path_from, subprocess_popen + +use_mpmath = True +try: + import mpmath as math +except ImportError: + import math + use_mpmath = False + +LOG_10 = 2.3025 +LOG_2 = 0.3010 + +# compress intermediate gvcf row using lz4, issue:https://github.com/HKU-BAL/Clair3/issues/48 +lz4_path = subprocess.run("which lz4", stdout=subprocess.PIPE, shell=True).stdout.decode().rstrip() +COMPRESS_GVCF = True if lz4_path != "" else False +LZ4_COMPRESS = "lz4 -c" +LZ4_DECOMPRESS = "lz4 -fdc" +GVCF_SUFFIX = ".tmp.gvcf" + +class compressReaderWriter(object): + def __init__(self, input_path=None, output_path=None, compress=False): + self.input_path = input_path + self.output_path = output_path + self.compress = compress + self.read_proc = None + self.reader = None + + self.writer = None + self.write_proc = None + self.write_fpo = None + + def read_input(self): + if self.compress: + self.read_proc = subprocess_popen(shlex.split("{} {}".format(LZ4_DECOMPRESS, self.input_path)), stderr=subprocess.DEVNULL) + a = subprocess_popen(shlex.split("{} {}".format(LZ4_DECOMPRESS, self.input_path)), stderr=subprocess.DEVNULL) + streamdata = a.communicate()[0] + rc = a.returncode + self.reader = self.read_proc.stdout + else: + self.reader = open(self.input_path, 'r') + return self.reader + + def close_reader(self): + if self.compress: + self.read_proc.stdout.close() + self.read_proc.wait() + else: + self.reader.close() + + def write_output(self): + if self.compress: + self.write_fpo = open(self.output_path, 'w') + self.write_proc = subprocess_popen(shlex.split(LZ4_COMPRESS), stdin=subprocess.PIPE, stdout=self.write_fpo, stderr=subprocess.DEVNULL) + self.writer = self.write_proc.stdin + + else: + self.writer = open(self.output_path, 'w') + return self.writer + + def close_writer(self): + if self.compress: + self.write_proc.stdin.close() + self.write_proc.wait() + self.write_fpo.close() + else: + self.writer.close() + + +class gvcfGenerator(object): + + def __init__(self, ref_path, samtools='samtools'): + + self.reference_file_path = ref_path + self.samtools = samtools + pass + + def readCalls(self, callPath, callType='variant', ctgName=None, ctgStart=None, ctgEnd=None, add_header=False, writer=None): + + CR = compressReaderWriter(input_path=callPath, compress=COMPRESS_GVCF) + reader = CR.read_input() + need_write_header = True + header = [] + for line in reader: + if (line.startswith('#')): + if add_header and line not in header: + header.append(line) + + continue + if add_header and len(header) and need_write_header: + print(''.join(header).rstrip(), file=writer) + need_write_header = False + if (callType == 'non-variant'): + cur_non_variant_start = int(line.strip('\n').split('\t')[1]) + cur_non_variant_end = int(re.search(r'.*END=(.*)\tGT.*', line).group(1)) + cur_non_variant_chr = line.strip('\n').split('\t')[0] + if ((ctgName and cur_non_variant_chr == ctgName) or (not ctgName)): + if ((ctgStart and cur_non_variant_start >= ctgStart) or (not ctgStart)): + if ((ctgEnd and cur_non_variant_end <= ctgEnd) or (not ctgEnd)): + yield line.strip('\n'), cur_non_variant_start, cur_non_variant_end, 'original' + else: + # for variant calls, return "pos" + # DEL and INS should be considered here + tmp = line.strip('\n').split('\t') + ref = tmp[3] + alt = tmp[4] + n_alt = len(alt.split(',')) + cur_variant_start = int(line.strip('\n').split('\t')[1]) + cur_variant_end = cur_variant_start - 1 + len(ref) + is_reference_call = (alt == '.') or (ref == alt) + if not is_reference_call: + # assuming AD is at the columns [-3], add 0 to AD for gVCF + ori_info = tmp[-1].split(':') + ori_info[-3] += ',0' + tmp[-1] = ':'.join(ori_info) + + # assumeing PL is at the last column + # add to variant calls + tmp[4] = tmp[4] + ',' + if (n_alt == 1): + + tmp[-1] = tmp[-1] + ',990,990,990' + + elif (n_alt == 2): + tmp[-1] = tmp[-1] + ',990,990,990,990' + else: + # skip reference calls + continue + new_line = '\t'.join(tmp) + + cur_variant_chr = tmp[0] + + + if ((ctgName and cur_variant_chr == ctgName) or (not ctgName)): + if ((ctgStart and cur_variant_start >= ctgStart) or (not ctgStart)): + if ((ctgEnd and cur_variant_end <= ctgEnd) or (not ctgEnd)): + yield new_line, cur_variant_start, cur_variant_end + + CR.close_reader() + + def readReferenceBaseAtPos(self, pos): + + cmd = self.samtools + ' faidx ' + self.reference_file_path + ' ' + pos + + reader = os.popen(cmd) + for line in reader: + if (line.startswith('>')): + continue + else: + ref_base = line.strip('\n').upper() + reader.close() + return ref_base + + def _writeRightBlock(self, block_new_start, curNonVarEnd, curNonVarCall, save_writer): + + pos_cmd = str(curNonVarCall.split('\t')[0]) + ':' + str(block_new_start) + '-' + str(block_new_start) + new_ref = self.readReferenceBaseAtPos(pos_cmd) + tmp = curNonVarCall.split('\t') + tmp[1] = str(block_new_start) + tmp[3] = str(new_ref) + print('\t'.join(tmp), file=save_writer) + + def _writeLeftBlock(self, end_pos, curNonVarCall, save_writer): + + new_left_block = re.sub("END=[0-9]*\t", "END=" + str(end_pos) + '\t', curNonVarCall) + print(new_left_block, file=save_writer) + pass + + def writeNonVarBlock(self, start, end, pos_flag, curNonVarCall, save_writer): + + if (pos_flag == 'left'): + self._writeLeftBlock(end, curNonVarCall, save_writer) + elif (pos_flag == 'right'): + self._writeRightBlock(start, end, curNonVarCall, save_writer) + else: + print(curNonVarCall, file=save_writer) + def mergeCalls(self, variantCallPath, nonVarCallPath, savePath, sampleName, ctgName=None, ctgStart=None, + ctgEnd=None): + + ''' + merge calls between variant and non-variant + ''' + + varCallStop = False + nonVarCallStop = False + + #output writer + CW = compressReaderWriter(output_path=savePath, compress=COMPRESS_GVCF) + save_writer = CW.write_output() + + varCallGenerator = self.readCalls(variantCallPath, 'variant', ctgName, ctgStart, ctgEnd) + nonVarCallGenerator = self.readCalls(nonVarCallPath, 'non-variant', ctgName, ctgStart, ctgEnd, add_header=True, writer=save_writer) + hasVar = True + # in case of empty file + try: + curVarCall, curVarStart, curVarEnd = next(varCallGenerator) + except StopIteration: + varCallStop = True + hasVar = False + try: + curNonVarCall, curNonVarStart, curNonVarEnd, curNonVarPos = next(nonVarCallGenerator) + except StopIteration: + nonVarCallStop = True + + while True and (not varCallStop) and (not nonVarCallStop): + if (curNonVarEnd < curVarStart): + + ''' + |____| {____} + nonVar Var + nonVar region is on the left, no overlapped region + ''' + # print(curNonVarCall,file=save_writer) + self.writeNonVarBlock(curNonVarStart, curNonVarEnd, curNonVarPos, curNonVarCall, save_writer) + # move non variant calls to the next + try: + curNonVarCall, curNonVarStart, curNonVarEnd, curNonVarPos = next(nonVarCallGenerator) + except StopIteration: + nonVarCallStop = True + break + elif (curVarEnd < curNonVarStart): + ''' + {____} |____| + Var nonVar + var region is on the left, no overlapped region + ''' + # print("{_____} |_____|") + print(curVarCall, file=save_writer) + try: + curVarCall, curVarStart, curVarEnd = next(varCallGenerator) + except StopIteration: + varCallStop = True + break + + elif (curVarStart <= curNonVarStart and curVarEnd >= curNonVarStart): + ''' + {____|____}___| + or + {____|_______|____} + the left point of nonvar block can be included + var region is on the left, has overlapped region + ''' + # write the current variant Call + print(curVarCall, file=save_writer) + block_new_start = curVarEnd + 1 + try: + curVarCall, curVarStart, curVarEnd = next(varCallGenerator) + except StopIteration: + varCallStop = True + break + + while (block_new_start > curNonVarEnd): + # skip the non-variant block within the current variant block + + try: + + curNonVarCall, curNonVarStart, curNonVarEnd, curNonVarPos = next(nonVarCallGenerator) + except StopIteration: + nonVarCallStop = True + break + + if (nonVarCallStop): + break + + # check if the start of the current non-variant block + if ((block_new_start - 1) >= curNonVarStart): + # there is overlap between variants and non-variant block + # just write the right part of the non-variant block + curNonVarStart = block_new_start + curNonVarPos = 'right' + + elif (curVarStart > curNonVarStart): + ''' + |_{__________________}__| + or + |__{______________|____} + ''' + # var call is within the non-var block + # split the non-var block + non_var_block_left_end = curVarStart - 1 + if (non_var_block_left_end >= curNonVarStart): + self._writeLeftBlock(non_var_block_left_end, curNonVarCall, save_writer) + # print out variant call + print(curVarCall, file=save_writer) + # take care here, whether write the left right non-variant block, + # it dependes on the position of the next variant calls + non_var_block_right_start = curVarEnd + 1 + + try: + curVarCall, curVarStart, curVarEnd = next(varCallGenerator) + except StopIteration: + varCallStop = True + break + + # still has the right left block + if (non_var_block_right_start <= curNonVarEnd): + curNonVarStart = non_var_block_right_start + curNonVarPos = 'right' + else: + # get the next non-variant block,skip the non-var block that is within the variant + while True: + try: + curNonVarCall, curNonVarStart, curNonVarEnd, curNonVarPos = next(nonVarCallGenerator) + except StopIteration: + nonVarCallStop = True + break + if (non_var_block_right_start <= curNonVarEnd): + break + + if (nonVarCallStop): + break + + curNonVarStart = non_var_block_right_start + curNonVarPos = 'right' + + else: + print("[ERROR] CurVarStart", curVarStart, 'curVarEnd', curVarEnd, 'curNonVarStart', curNonVarStart, + 'curNonVarEnd', curNonVarEnd) + + # printout the remain content + if (not varCallStop): + # print out the left + + print(curVarCall, file=save_writer) + for curVarCall, curVarStart, curVarEnd in varCallGenerator: + print(curVarCall, file=save_writer) + if (not nonVarCallStop): + if (hasVar and curNonVarEnd > curVarEnd): + self.writeNonVarBlock(curVarEnd + 1, curNonVarEnd, curNonVarPos, curNonVarCall, save_writer) + for curNonVarCall, curNonVarStart, curNonVarEnd, curNonVarPos in nonVarCallGenerator: + print(curNonVarCall, file=save_writer) + + CW.close_writer() + + +class variantInfoCalculator(object): + + def __init__(self, gvcfWritePath, ref_path, p_err, gq_bin_size, ctgName, bp_resolution=False, sample_name='None', mode='L'): + + # default p_error is 0.001, while it could be set by the users' option + self.p_error = p_err + self.LOG_10 = LOG_10 + self.logp = math.log(self.p_error) / self.LOG_10 + self.log1p = math.log1p(-self.p_error) / self.LOG_10 + self.LOG_2 = LOG_2 + # need to check with the clair3 settings + #self.max_gq = 255 + self.max_gq = 50 + self.variantMath = mathcalculator() + self.constant_log10_probs = self.variantMath.normalize_log10_prob([-1.0, -1.0, -1.0]) + self.gq_bin_size = gq_bin_size + self.CW = None + # set by the users + if (gvcfWritePath != "PIPE"): + if (not os.path.exists(gvcfWritePath)): + os.mkdir(gvcfWritePath) + + self.CW = compressReaderWriter(output_path=os.path.join(gvcfWritePath, sample_name + GVCF_SUFFIX), compress=COMPRESS_GVCF) + self.vcf_writer = self.CW.write_output() + else: + self.vcf_writer = sys.stdout + self.writePath = gvcfWritePath + self.sampleName = sample_name.split('.')[0] + self.bp_resolution = bp_resolution + self.reference_file_path = ref_path + + if (mode == 'L'): + # dictionary to store constant log values for speeding up + self.normalized_prob_pool = {} + + self.current_block = [] + self._print_vcf_header() + self.cur_gq_bin_index = None + self.cur_gt = None + self.cur_min_DP = None + self.cur_max_DP = None + self.cur_chr = None + self.cur_raw_gq = None + pass + def write_empty_pileup(self,ctgName,ctgStart,ctgEnd): + + non_variant_info = {"validPL": False, "gq": 1, "binned_gq": 1, "pl": [0,0,0], + "chr": ctgName, 'pos': max(1,ctgStart), 'ref': 'N', + "gt": './.', 'min_dp': 0, 'END': ctgEnd} + self.write_to_gvcf(non_variant_info) + def make_gvcf_online(self, variant_summary, push_current=False): + + ''' + + make gvcf while reading from pileup + ''' + + if (push_current): + if (len(self.current_block) > 0): + self.write_to_gvcf_batch(self.current_block, self.cur_min_DP, self.cur_raw_gq) + self.current_block = [] + self.cur_gq_bin_index = None + self.cur_gt = None + self.cur_min_DP = None + self.cur_max_DP = None + self.cur_chr = None + self.cur_raw_gq = None + return + + cur_item = self.reference_likelihood(variant_summary) + _gq_bin = cur_item['binned_gq'] + _gt = cur_item["gt"] + _DP = cur_item["min_dp"] + _chr = cur_item['chr'] + _raw_gq = cur_item['gq'] + _cur_ref = variant_summary['ref'] + + if (self.cur_gq_bin_index == None): + self.current_block, self.cur_gq_bin_index, self.cur_gt, self.cur_min_DP, self.cur_max_DP, self.cur_chr, self.cur_raw_gq = ( + [cur_item], _gq_bin, _gt, _DP, _DP, _chr, _raw_gq) + self.cur_ref = _cur_ref + + elif (_gq_bin != self.cur_gq_bin_index): + self.write_to_gvcf_batch(self.current_block, self.cur_min_DP, self.cur_raw_gq) + self.current_block, self.cur_gq_bin_index, self.cur_gt, self.cur_min_DP, self.cur_max_DP, self.cur_chr, self.cur_raw_gq = ( + [cur_item], _gq_bin, _gt, _DP, _DP, _chr, _raw_gq) + self.cur_ref = _cur_ref + + elif (_gt != self.cur_gt): + self.write_to_gvcf_batch(self.current_block, self.cur_min_DP, self.cur_raw_gq) + self.current_block, self.cur_gq_bin_index, self.cur_gt, self.cur_min_DP, self.cur_max_DP, self.cur_chr, self.cur_raw_gq = ( + [cur_item], _gq_bin, _gt, _DP, _DP, _chr, _raw_gq) + self.cur_ref = _cur_ref + + elif (_chr != self.cur_chr): + self.write_to_gvcf_batch(self.current_block, self.cur_min_DP, self.cur_raw_gq) + self.current_block, self.cur_gq_bin_index, self.cur_gt, self.cur_min_DP, self.cur_max_DP, self.cur_chr, self.cur_raw_gq = ( + [cur_item], _gq_bin, _gt, _DP, _DP, _chr, _raw_gq) + self.cur_ref = _cur_ref + elif( (_cur_ref != self.cur_ref) and ((_cur_ref=='N') or (self.cur_ref=='N'))): + self.write_to_gvcf_batch(self.current_block, self.cur_min_DP, self.cur_raw_gq) + self.current_block, self.cur_gq_bin_index, self.cur_gt, self.cur_min_DP, self.cur_max_DP, self.cur_chr, self.cur_raw_gq = ( + [cur_item], _gq_bin, _gt, _DP, _DP, _chr, _raw_gq) + self.cur_ref = _cur_ref + else: + ''' + # do not consider DP + if(_DP < self.cur_min_DP): + self.cur_min_DP = _DP + if(_raw_gq < self.cur_raw_gq): + self.cur_raw_gq = _raw_gq + self.current_block.append(cur_item) + ''' + + if (_DP < self.cur_min_DP): + tmp_cur_min_DP = _DP + #if (self.cur_max_DP > math.ceil((tmp_cur_min_DP + min(3, tmp_cur_min_DP * 0.3)))): + if (self.cur_max_DP > math.ceil(tmp_cur_min_DP + tmp_cur_min_DP * 0.3)): + self.write_to_gvcf_batch(self.current_block, self.cur_min_DP, self.cur_raw_gq) + self.current_block, self.cur_gq_bin_index, self.cur_gt, self.cur_min_DP, self.cur_max_DP, self.cur_chr, self.cur_raw_gq = ( + [cur_item], _gq_bin, _gt, _DP, _DP, _chr, _raw_gq) + else: + self.cur_min_DP = tmp_cur_min_DP + if (_raw_gq < self.cur_raw_gq): + self.cur_raw_gq = _raw_gq + self.current_block.append(cur_item) + elif (_DP > self.cur_max_DP): + #if (_DP <= math.ceil(self.cur_min_DP + min(3, self.cur_min_DP * 0.3))): + if (_DP <= math.ceil(self.cur_min_DP + self.cur_min_DP * 0.3)): + self.cur_max_DP = _DP + if (_raw_gq < self.cur_raw_gq): + self.cur_raw_gq = _raw_gq + self.current_block.append(cur_item) + else: + self.write_to_gvcf_batch(self.current_block, self.cur_min_DP, self.cur_raw_gq) + self.current_block, self.cur_gq_bin_index, self.cur_gt, self.cur_min_DP, self.cur_max_DP, self.cur_chr, self.cur_raw_gq = ( + [cur_item], _gq_bin, _gt, _DP, _DP, _chr, _raw_gq) + else: + if (_raw_gq < self.cur_raw_gq): + self.cur_raw_gq = _raw_gq + self.current_block.append(cur_item) + + def reference_likelihood(self, variant_summary): + + ''' + for non-variant sites, this function is calculate the GQ,QUAL,PL,etc for the genotype 0/0 or ./. + ''' + + n_ref = variant_summary["n_ref"] + n_total = variant_summary['n_total'] + + + validPL, gq, binned_gq, log10_probs = self._cal_reference_likelihood(n_ref, n_total) + if (validPL): + gt = '0/0' + else: + gt = './.' + _tmp_phred_probs = [-10 * x for x in log10_probs] + min_phred_probs = min(_tmp_phred_probs) + + phred_probs = [int(x - min_phred_probs) for x in _tmp_phred_probs] + + if(variant_summary['ref'] not in ['A','T','C','G']): + tmp_ref = 'N' + gq = 1 + binned_gq = 1 + phred_probs = [0,0,0] + else: + tmp_ref = variant_summary['ref'] + non_variant_info = {"validPL": validPL, "gq": gq, "binned_gq": binned_gq, "pl": phred_probs, + "chr": variant_summary['chr'], 'pos': variant_summary['pos'], 'ref': tmp_ref, + "gt": gt, 'min_dp': variant_summary['n_total'], 'END': variant_summary['pos']} + + return non_variant_info + pass + + def _cal_reference_likelihood(self, n_ref, n_total): + + ''' + calculate the phred genotype likelihood for a single non-variant site. + n_ref: number of referece bases + n_total: number of all bases by ignoring Ns + P(hom_ref) = (1-prr)^n_ref*prr^(n_total-n_ref) + P(Het_alt) = (1/2)^n_total + P(hom_alt) = prr^n_ref*(1-prr)^(n_total-n_ref) + return flag of validPL, raw GQ, binned GQ, PLs + ''' + + validPL = True + if (n_total == 0): + # when the coverage is 0 + log10_probs = self.constant_log10_probs + + pass + else: + + n_alts = n_total - n_ref + + log10_p_ref = n_ref * self.log1p + n_alts * self.logp + + log10_p_het = -n_total * self.LOG_2 + log10_p_hom_alt = n_ref * self.logp + n_alts * self.log1p + + # normalization + + log10_probs = self.variantMath.normalize_log10_prob([log10_p_ref, log10_p_het, log10_p_hom_alt]) + + + + gq = self.variantMath.log10p_to_phred(log10_probs[0]) + + gq = int(min(int(gq), self.max_gq)) + if (gq >= 1): + binned_index = (gq - 1) // self.gq_bin_size + binned_gq = binned_index * self.gq_bin_size + 1 + else: + binned_gq = 0 + + + validPL = log10_probs[0] == max(log10_probs) + return validPL, gq, binned_gq, log10_probs + + def _print_vcf_header(self): + + from textwrap import dedent + print(dedent("""\ + ##fileformat=VCFv4.2 + ##FILTER= + ##FILTER= + ##FILTER= + ##INFO= + ##INFO= + ##ALT= + ##INFO= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT= + ##FORMAT="""), file + =self.vcf_writer) + if self.reference_file_path is not None: + reference_index_file_path = file_path_from(self.reference_file_path, suffix=".fai", exit_on_not_found=True, sep='.') + with open(reference_index_file_path, "r") as fai_fp: + for row in fai_fp: + columns = row.strip().split("\t") + contig_name, contig_size = columns[0], columns[1] + print("##contig=" % (contig_name, contig_size), file=self.vcf_writer) + + print('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s' % (self.sampleName), file=self.vcf_writer) + + pass + + + def write_to_gvcf_batch(self, block, block_min_dp, block_min_raw_gq): + + if ((self.bp_resolution or block[0]['gt'] == "./.") and block[0]['ref']!='N'): + # write it to VCF + for item in block: + self.write_to_gvcf(item) + + else: + start_pos = block[0]['pos'] + end_pos = block[-1]['pos'] + first_PL = block[0]['pl'] + first_gq = block[0]['gq'] + first_binned_gq = block[0]['binned_gq'] + first_gt = block[0]['gt'] + first_ref = block[0]['ref'] + first_chr = block[0]['chr'] + if(first_ref=='N'): + # special case for reference N + non_variant_info = {"gq": 1, "binned_gq": 1, "pl": [0,0,0], "chr": first_chr, + 'pos': start_pos, 'ref': first_ref, "gt": "./.", 'min_dp': block_min_dp, + 'END': end_pos} + + # write binned_gq + # non_variant_info = { "gq":first_gq, "binned_gq":first_binned_gq, "pl":first_PL,"chr":first_chr,'pos':start_pos,'ref':first_ref,"gt":first_gt,'min_dp':block_min_dp,'END':end_pos} + # write min raw gq + else: + non_variant_info = {"gq": first_gq, "binned_gq": block_min_raw_gq, "pl": first_PL, "chr": first_chr, + 'pos': start_pos, 'ref': first_ref, "gt": first_gt, 'min_dp': block_min_dp, + 'END': end_pos} + self.write_to_gvcf(non_variant_info) + + def write_to_gvcf(self, variant_info): + + ''' + write a temporary file gvcf. This file is needed to be merged with model variant calls. + ''' + + _tmpLine = str(variant_info["chr"]) + '\t' + str(variant_info["pos"]) + "\t.\t" + variant_info[ + 'ref'] + '\t\t0\t.\tEND=' + str(variant_info['END']) + '\tGT:GQ:MIN_DP:PL\t' + variant_info[ + 'gt'] + ':' + str(variant_info['binned_gq']) + ':' + str(variant_info['min_dp']) + ':' + str( + variant_info['pl'][0]) + ',' + str(variant_info['pl'][1]) + ',' + str(variant_info['pl'][2]) + print(_tmpLine, file=self.vcf_writer) + + + def close_vcf_writer(self): + self.CW.close_writer() + +class mathcalculator(object): + + + def __init__(self,speedUp=True): + + self.LOG_10 = LOG_10 + self.maxPhredScore = 255 + self.speedUp = speedUp + if(speedUp): + try: + self._creatCFFIFunc() + except: + self.speedUp = False + pass + + + + def _creatCFFIFunc(self): + self.ffi = FFI() + self.ffi.cdef(""" + double log10p_to_phred(double log10p); + double log10sumexp(double log10_array[],int n_array); + double getMyMaxItem(double list[],int n_list); + """) + self.lib = self.ffi.verify(""" + #include + #include + double LOG_10 = 2.3025; + double log10p_to_phred(double log10p){ + double ptrue; + ptrue = pow(10,log10p); + if(ptrue==1){ + + return 50; + } + return -10*(log(1.0-ptrue)/LOG_10); + } + double f_log10(double myInput){ + double res; + res=log(myInput)/LOG_10; + return res; + } + double getMyMaxItem(double list[],int n_list){ + double curMax; + int i; + curMax = list[0]; + for(i=1;i<=n_list;i++){ + if(list[i]>curMax){ + curMax= list[i]; + } + } + return curMax; + } + double log10sumexp(double log10_array[],int n_array){ + double m, mySum,tmp; + int i; + + m = getMyMaxItem(log10_array,n_array); + mySum = 0.0; + for(i=0;i$qual will be marked PASS, or LowQual otherwise.' + echo $' --samtools=STR Path of samtools, samtools version >= 1.10 is required.' + echo $' --python=STR Path of python, python3 >= 3.6 is required.' + echo $' --pypy=STR Path of pypy3, pypy3 >= 3.6 is required.' + echo $' --parallel=STR Path of parallel, parallel >= 20191122 is required.' + echo $' --whatshap=STR Path of whatshap, whatshap >= 1.0 is required.' + echo $' --chunk_size=INT The size of each chuck for parallel processing, default: 5000000.' + echo $' --pileup_only Use the pileup model only when calling, default: disable.' + echo $' --print_ref_calls Show reference calls (0/0) in VCF file, default: disable.' + echo $' --include_all_ctgs Call variants on all contigs, otherwise call in chr{1..22,X,Y} and {1..22,X,Y}, default: disable.' + echo $' --gvcf Enable GVCF output, default: disable.' + echo $' --enable_phasing Output phased variants using whatshap, default: disable.' + echo $' --remove_intermediate_dir Remove intermediate directory, including intermediate phased BAM, pileup and full-alignment results. default: disable.' + echo $' --snp_min_af=FLOAT Minimum SNP AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.08,hifi:0.08,ilmn:0.08.' + echo $' --indel_min_af=FLOAT Minimum Indel AF required for a candidate variant. Lowering the value might increase a bit of sensitivity in trade of speed and accuracy, default: ont:0.15,hifi:0.08,ilmn:0.08.' + echo $' --var_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/1 and 1/1 variants called in the pileup mode for full-alignment mode calling, default: 0.3.' + echo $' --ref_pct_full=FLOAT EXPERIMENTAL: Specify an expected percentage of low quality 0/0 variants called in the pileup mode for full-alignment mode calling, default: 0.3 for ilmn and hifi, 0.1 for ont.' + echo $' --var_pct_phasing=FLOAT EXPERIMENTAL: Specify an expected percentage of high quality 0/1 variants used in WhatsHap phasing, default: 0.8 for ont guppy5 and 0.7 for other platforms.' + echo $' --pileup_model_prefix=STR EXPERIMENTAL: Model prefix in pileup calling, including $prefix.data-00000-of-00002, $prefix.data-00001-of-00002 $prefix.index. default: pileup.' + echo $' --fa_model_prefix=STR EXPERIMENTAL: Model prefix in full-alignment calling, including $prefix.data-00000-of-00002, $prefix.data-00001-of-00002 $prefix.index, default: full_alignment.' + echo $' --fast_mode EXPERIMENTAL: Skip variant candidates with AF <= 0.15, default: disable.' + echo $' --haploid_precise EXPERIMENTAL: Enable haploid calling mode. Only 1/1 is considered as a variant, default: disable.' + echo $' --haploid_sensitive EXPERIMENTAL: Enable haploid calling mode. 0/1 and 1/1 are considered as a variant, default: disable.' + echo $' --no_phasing_for_fa EXPERIMENTAL: Call variants without whatshap phasing in full alignment calling, default: disable.' + echo $' --call_snp_only EXPERIMENTAL: Call candidates pass SNP minimum AF only, ignore Indel candidates, default: disable.' + echo $' --enable_long_indel EXPERIMENTAL: Call long Indel variants(>50 bp), default: disable.' + echo $'' +} + +print_version() +{ + echo "Clair3 ${VERSION}" + exit 0 +} + +ERROR="\\033[31m[ERROR]" +WARNING="\\033[33m[WARNING]" +NC="\\033[0m" + +ARGS=`getopt -o b:f:t:m:p:o:hv \ +-l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\ +bed_fn::,vcf_fn::,ctg_name::,sample_name::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,ref_pct_full::,var_pct_phasing::,\ +snp_min_af::,indel_min_af::,pileup_model_prefix::,fa_model_prefix::,fast_mode,gvcf,pileup_only,print_ref_calls,haploid_precise,haploid_sensitive,include_all_ctgs,\ +remove_intermediate_dir,no_phasing_for_fa,call_snp_only,enable_phasing,enable_long_indel,help,version -n 'run_clair3.sh' -- "$@"` + +if [ $? != 0 ] ; then echo"No input. Terminating...">&2 ; exit 1 ; fi +eval set -- "${ARGS}" + +# default options +SAMPLE="SAMPLE" +BED_FILE_PATH="EMPTY" +VCF_FILE_PATH='EMPTY' +CONTIGS="EMPTY" +SAMTOOLS="samtools" +PYPY="pypy3" +PYTHON='python3' +PARALLEL='parallel' +WHATSHAP='whatshap' +CHUNK_NUM=0 +CHUNK_SIZE=5000000 +QUAL=2 +PHASING_PCT="0" +PRO="0" +REF_PRO="0" +GVCF=False +PILEUP_ONLY=False +FAST_MODE=False +SHOW_REF=False +SNP_AF="0" +INDEL_AF="0" +HAP_PRE=False +HAP_SEN=False +SNP_ONLY=False +INCLUDE_ALL_CTGS=False +NO_PHASING=False +RM_TMP_DIR=False +ENABLE_PHASING=False +ENABLE_LONG_INDEL=False +PILEUP_PREFIX="pileup" +FA_PREFIX="full_alignment" + +while true; do + case "$1" in + -b|--bam_fn ) BAM_FILE_PATH="$2"; shift 2 ;; + -f|--ref_fn ) REFERENCE_FILE_PATH="$2"; shift 2 ;; + -t|--threads ) THREADS="$2"; shift 2 ;; + -m|--model_path ) MODEL_PATH="$2"; shift 2 ;; + -p|--platform ) PLATFORM="$2"; shift 2 ;; + -o|--output ) OUTPUT_FOLDER="$2"; shift 2 ;; + --bed_fn ) BED_FILE_PATH="$2"; shift 2 ;; + --vcf_fn ) VCF_FILE_PATH="$2"; shift 2 ;; + --ctg_name ) CONTIGS="$2"; shift 2 ;; + --sample_name ) SAMPLE="$2"; shift 2 ;; + --chunk_num ) CHUNK_NUM="$2"; shift 2 ;; + --chunk_size ) CHUNK_SIZE="$2"; shift 2 ;; + --qual ) QUAL="$2"; shift 2 ;; + --samtools ) SAMTOOLS="$2"; shift 2 ;; + --python ) PYTHON="$2"; shift 2 ;; + --pypy ) PYPY="$2"; shift 2 ;; + --parallel ) PARALLEL="$2"; shift 2 ;; + --whatshap ) WHATSHAP="$2"; shift 2 ;; + --var_pct_full ) PRO="$2"; shift 2 ;; + --ref_pct_full ) REF_PRO="$2"; shift 2 ;; + --var_pct_phasing ) PHASING_PCT="$2"; shift 2 ;; + --snp_min_af ) SNP_AF="$2"; shift 2 ;; + --indel_min_af ) INDEL_AF="$2"; shift 2 ;; + --pileup_model_prefix ) PILEUP_PREFIX="$2"; shift 2 ;; + --fa_model_prefix ) FA_PREFIX="$2"; shift 2 ;; + --gvcf ) GVCF=True; shift 1 ;; + --pileup_only ) PILEUP_ONLY=True; shift 1 ;; + --fast_mode ) FAST_MODE=True; shift 1 ;; + --call_snp_only ) SNP_ONLY=True; shift 1 ;; + --print_ref_calls ) SHOW_REF=True; shift 1 ;; + --haploid_precise ) HAP_PRE=True; shift 1 ;; + --haploid_sensitive ) HAP_SEN=True; shift 1 ;; + --include_all_ctgs ) INCLUDE_ALL_CTGS=True; shift 1 ;; + --no_phasing_for_fa ) NO_PHASING=True; shift 1 ;; + --remove_intermediate_dir ) RM_TMP_DIR=True; shift 1 ;; + --enable_phasing ) ENABLE_PHASING=True; shift 1 ;; + --enable_long_indel ) ENABLE_LONG_INDEL=True; shift 1 ;; + + -- ) shift; break; ;; + -h|--help ) print_help_messages; exit 0 ;; + -v|--version ) print_version; exit 0 ;; + * ) print_help_messages; break ;; + esac +done + +if [ -z ${BAM_FILE_PATH} ] || [ -z ${REFERENCE_FILE_PATH} ] || [ -z ${THREADS} ] || [ -z ${OUTPUT_FOLDER} ] || [ -z ${PLATFORM} ] || [ -z ${MODEL_PATH} ]; then + if [ -z ${BAM_FILE_PATH} ] && [ -z ${REFERENCE_FILE_PATH} ] && [ -z ${THREADS} ] && [ -z ${OUTPUT_FOLDER} ] && [ -z ${PLATFORM} ] && [ -z ${MODEL_PATH} ]; then print_help_messages; exit 0; fi + if [ -z ${BAM_FILE_PATH} ]; then echo -e "${ERROR} Require to define index BAM input by --bam_fn=BAM${NC}"; fi + if [ -z ${REFERENCE_FILE_PATH} ]; then echo -e "${ERROR} Require to define FASTA reference file input by --ref_fn=REF${NC}"; fi + if [ -z ${THREADS} ]; then echo -e "${ERROR} Require to define max threads to be used by --threads=THREADS${NC}"; fi + if [ -z ${OUTPUT_FOLDER} ]; then echo -e "${ERROR} Require to define output folder by --output=OUTPUT_DIR${NC}"; fi + if [ -z ${PLATFORM} ]; then echo -e "${ERROR} Require to define platform by --platform={ont,hifi,ilmn}${NC}"; fi + if [ -z ${MODEL_PATH} ]; then echo -e "${ERROR} Require to define model path by --model_path=MODEL_PREFIX${NC}"; fi + exit 1; +fi + +# force to use absolute path when in docker or singularity environment +if [ `pwd` = "/opt/bin" ]; then + if [[ ! "${BAM_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --bam_fn=FILE${NC}"; exit 1; fi + if [[ ! "${REFERENCE_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --ref_fn=FILE${NC}"; exit 1; fi + if [[ ! "${MODEL_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --model_path=PATH${NC}"; exit 1; fi + if [[ ! "${OUTPUT_FOLDER}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --output=PATH${NC}"; exit 1; fi + if [ "${BED_FILE_PATH}" != "EMPTY" ] && [ ! -z ${BED_FILE_PATH} ] && [[ ! "${BED_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --bef_fn=FILE${NC}"; exit 1; fi + if [ "${VCF_FILE_PATH}" != "EMPTY" ] && [ ! -z ${VCF_FILE_PATH} ] && [[ ! "${VCF_FILE_PATH}" = /* ]]; then echo -e "${ERROR} Require to use absolute file path --vcf_fn=FILE${NC}"; exit 1; fi +fi + +# relative path support +if [[ ! "${BAM_FILE_PATH}" = /* ]] && [ -f ${BAM_FILE_PATH} ]; then BAM_FILE_PATH=`pwd`/${BAM_FILE_PATH}; fi +if [[ ! "${REFERENCE_FILE_PATH}" = /* ]] && [ -f ${REFERENCE_FILE_PATH} ]; then REFERENCE_FILE_PATH=`pwd`/${REFERENCE_FILE_PATH}; fi +if [[ ! "${MODEL_PATH}" = /* ]] && [ -d ${MODEL_PATH} ]; then MODEL_PATH=`pwd`/${MODEL_PATH}; fi +if [ "${BED_FILE_PATH}" != "EMPTY" ] && [ ! -z ${BED_FILE_PATH} ] && [[ ! "${BED_FILE_PATH}" = /* ]] && [ -f ${BED_FILE_PATH} ]; then BED_FILE_PATH=`pwd`/${BED_FILE_PATH}; fi +if [ "${VCF_FILE_PATH}" != "EMPTY" ] && [ ! -z ${VCF_FILE_PATH} ] && [[ ! "${VCF_FILE_PATH}" = /* ]] && [ -f ${VCF_FILE_PATH} ]; then VCF_FILE_PATH=`pwd`/${VCF_FILE_PATH}; fi +if [[ ! "${OUTPUT_FOLDER}" = /* ]]; then echo -e "${WARNING} No absolute output path provided, using current directory as prefix${NC}"; OUTPUT_FOLDER=`pwd`/${OUTPUT_FOLDER}; fi + +mkdir -p ${OUTPUT_FOLDER} +if [ ! -d ${OUTPUT_FOLDER} ]; then echo -e "${ERROR} Cannot create output folder ${OUTPUT_FOLDER}${NC}"; exit 1; fi + +# show default reference proportion 0.3 for ilmn and hifi, 0.1 for ont +if [ "${PLATFORM}" = "ont" ] && [ "${REF_PRO}" = "0" ]; then REF_PRO=0.1; fi +if [ "${PLATFORM}" != "ont" ] && [ "${REF_PRO}" = "0" ]; then REF_PRO=0.3; fi + +# show default variant proportion 0.3 for ilmn and hifi, 0.7 for ont +if [ "${PLATFORM}" = "ont" ] && [ "${PRO}" = "0" ]; then PRO=0.7; fi +if [ "${PLATFORM}" != "ont" ] && [ "${PRO}" = "0" ]; then PRO=0.3; fi + +# show default high quality hete variant proportion for whatshap phasing, 0.8 for ont guppy5 and 0.7 for others +if [ "${PHASING_PCT}" = "0" ]; then PHASING_PCT=0.7; fi +BASE_MODEL=$(basename ${MODEL_PATH}) +if [ "${BASE_MODEL}" = "r941_prom_sup_g5014" ] || [ "${BASE_MODEL}" = "r941_prom_hac_g5014" ] || [ "${BASE_MODEL}" = "ont_guppy5" ]; then PHASING_PCT=0.8; fi + +# remove the last '/' character in directory input +OUTPUT_FOLDER=$(echo ${OUTPUT_FOLDER%*/}) +MODEL_PATH=$(echo ${MODEL_PATH%*/}) + +# optional parameters should use "=" +(time ( +echo "[INFO] CLAIR3 VERSION: ${VERSION}" +echo "[INFO] BAM FILE PATH: ${BAM_FILE_PATH}" +echo "[INFO] REFERENCE FILE PATH: ${REFERENCE_FILE_PATH}" +echo "[INFO] MODEL PATH: ${MODEL_PATH}" +echo "[INFO] OUTPUT FOLDER: ${OUTPUT_FOLDER}" +echo "[INFO] PLATFORM: ${PLATFORM}" +echo "[INFO] THREADS: ${THREADS}" +echo "[INFO] BED FILE PATH: ${BED_FILE_PATH}" +echo "[INFO] VCF FILE PATH: ${VCF_FILE_PATH}" +echo "[INFO] CONTIGS: ${CONTIGS}" +echo "[INFO] CONDA PREFIX: ${CONDA_PREFIX}" +echo "[INFO] SAMTOOLS PATH: ${SAMTOOLS}" +echo "[INFO] PYTHON PATH: ${PYTHON}" +echo "[INFO] PYPY PATH: ${PYPY}" +echo "[INFO] PARALLEL PATH: ${PARALLEL}" +echo "[INFO] WHATSHAP PATH: ${WHATSHAP}" +echo "[INFO] CHUNK SIZE: ${CHUNK_SIZE}" +if [ ${CHUNK_NUM} -gt 0 ]; then echo "[INFO] CHUNK NUM: ${CHUNK_NUM}"; fi +echo "[INFO] FULL ALIGN PROPORTION: ${PRO}" +echo "[INFO] FULL ALIGN REFERENCE PROPORTION: ${REF_PRO}" +echo "[INFO] PHASING PROPORTION: ${PHASING_PCT}" +if [ "${SNP_AF}" != "0" ]; then echo "[INFO] USER DEFINED SNP THRESHOLD: ${SNP_AF}"; fi +if [ "${INDEL_AF}" != "0" ]; then echo "[INFO] USER DEFINED INDEL THRESHOLD: ${INDEL_AF}"; fi +echo "[INFO] ENABLE FILEUP ONLY CALLING: ${PILEUP_ONLY}" +echo "[INFO] ENABLE FAST MODE CALLING: ${FAST_MODE}" +echo "[INFO] ENABLE CALLING SNP CANDIDATES ONLY: ${SNP_ONLY}" +echo "[INFO] ENABLE PRINTING REFERENCE CALLS: ${SHOW_REF}" +echo "[INFO] ENABLE OUTPUT GVCF: ${GVCF}" +echo "[INFO] ENABLE HAPLOID PRECISE MODE: ${HAP_PRE}" +echo "[INFO] ENABLE HAPLOID SENSITIVE MODE: ${HAP_SEN}" +echo "[INFO] ENABLE INCLUDE ALL CTGS CALLING: ${INCLUDE_ALL_CTGS}" +echo "[INFO] ENABLE NO PHASING FOR FULL ALIGNMENT: ${NO_PHASING}" +echo "[INFO] ENABLE REMOVING INTERMEDIATE FILES: ${RM_TMP_DIR}" +echo "[INFO] ENABLE PHASING VCF OUTPUT: ${ENABLE_PHASING}" +echo "[INFO] ENABLE LONG INDEL CALLING: ${ENABLE_LONG_INDEL}" +echo $'' + +# file check +if [ ! -f ${BAM_FILE_PATH} ]; then echo -e "${ERROR} BAM file ${BAM_FILE_PATH} not found${NC}"; exit 1; fi +if [ ! -f ${BAM_FILE_PATH}.bai ] && [ ! -f ${BAM_FILE_PATH%.*}.bai ]; then echo -e "${ERROR} BAM index bai file not found, please use 'samtools index \$BAM' first${NC}"; exit 1; fi +if [ ! -f ${REFERENCE_FILE_PATH} ]; then echo -e "${ERROR} Reference file ${REFERENCE_FILE_PATH} not found${NC}"; exit 1; fi +if [ ! -f ${REFERENCE_FILE_PATH}.fai ] && [ ! -f ${REFERENCE_FILE_PATH%.*}.fai ]; then echo -e "${ERROR} Reference index fai file not found, please use 'samtools faidx \$REF' first${NC}"; exit 1; fi + +if [ "${BED_FILE_PATH}" != "EMPTY" ] && [ ! -z ${BED_FILE_PATH} ] && [ ! -f ${BED_FILE_PATH} ]; then echo -e "${ERROR} BED file ${BED_FILE_PATH} provides but not found${NC}"; exit 1; fi +if [ "${VCF_FILE_PATH}" != "EMPTY" ] && [ ! -z ${VCF_FILE_PATH} ] && [ ! -f ${VCF_FILE_PATH} ]; then echo -e "${ERROR} VCF file ${VCF_FILE_PATH} provides but not found${NC}"; exit 1; fi +if [ ! -d ${MODEL_PATH} ] && [ -z ${CONDA_PREFIX} ]; then echo -e "${ERROR} Conda prefix not found, please activate clair3 conda environment first, model path: ${MODEL_PATH}${NC}"; exit 1; fi +if [ ! -d ${MODEL_PATH} ]; then echo -e "${ERROR} Model path not found${NC}"; exit 1; fi + +# max threads detection +MAX_THREADS=$(nproc) +if [[ ! ${THREADS} =~ ^[\-0-9]+$ ]] || (( ${THREADS} <= 0)); then echo -e "${ERROR} Invalid threads input --threads=INT ${NC}"; exit 1; fi +if [[ ${THREADS} -gt ${MAX_THREADS} ]]; then echo -e "${WARNING} Threads setting exceeds maximum available threads ${MAX_THREADS}, set threads=${MAX_THREADS}${NC}"; THREADS=${MAX_THREADS}; fi + +# max user ulimit threads detection +MAX_ULIMIT_THREADS=`ulimit -u` +if [ ! -z ${MAX_ULIMIT_THREADS} ]; then PER_ULIMIT_THREADS=$((${MAX_ULIMIT_THREADS}/30)); else MAX_ULIMIT_THREADS="unlimited"; PER_ULIMIT_THREADS=${THREADS}; fi +if [[ ${PER_ULIMIT_THREADS} < 1 ]]; then PER_ULIMIT_THREADS=1; fi +if [ "${MAX_ULIMIT_THREADS}" != "unlimited" ] && [[ ${THREADS} -gt ${PER_ULIMIT_THREADS} ]]; then echo -e "${WARNING} Threads setting exceeds maximum ulimit threads ${THREADS} * 30 > ${MAX_ULIMIT_THREADS} (ulimit -u), set threads=${PER_ULIMIT_THREADS}${NC}"; THREADS=${PER_ULIMIT_THREADS}; fi + +# platform check +if [ ! ${PLATFORM} = "ont" ] && [ ! ${PLATFORM} = "hifi" ] && [ ! ${PLATFORM} = "ilmn" ]; then echo -e "${ERROR} Invalid platform input, optional: {ont, hifi, ilmn}${NC}"; exit 1; fi + +# optional parameter detection +if [ -z ${BED_FILE_PATH} ]; then echo -e "${ERROR} Use '--bed_fn=FILE' instead of '--bed_fn FILE' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${VCF_FILE_PATH} ]; then echo -e "${ERROR} Use '--vcf_fn=FILE' instead of '--vcf_fn =FILE' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${CONTIGS} ]; then echo -e "${ERROR} Use '--ctg_name=STR' instead of '--ctg_name STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${SAMPLE} ]; then echo -e "${ERROR} Use '--sample_name=STR' instead of '--sample_name STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${QUAL} ]; then echo -e "${ERROR} Use '--qual=INT' instead of '--qual INT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${SAMTOOLS} ]; then echo -e "${ERROR} Use '--samtools=STR' instead of '--samtools STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PYTHON} ]; then echo -e "${ERROR} Use '--python=STR' instead of '--python STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PYPY} ]; then echo -e "${ERROR} Use '--pypy=STR' instead of '--pypy STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PARALLEL} ]; then echo -e "${ERROR} Use '--parallel=STR' instead of '--parallel STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${WHATSHAP} ]; then echo -e "${ERROR} Use '--whatshap=STR' instead of '--whatshap STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${CHUNK_SIZE} ]; then echo -e "${ERROR} Use '--chunk_size=INT' instead of '--chunk_size INT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${SNP_AF} ]; then echo -e "${ERROR} Use '--snp_min_af=FLOAT' instead of '--snp_min_af FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${INDEL_AF} ]; then echo -e "${ERROR} Use '--indel_min_af=FLOAT' instead of '--indel_min_af FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PRO} ]; then echo -e "${ERROR} Use '--var_pct_full=FLOAT' instead of '--var_pct_full FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${REF_PRO} ]; then echo -e "${ERROR} Use '--ref_pct_full=FLOAT' instead of '--ref_pct_full FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PHASING_PCT} ]; then echo -e "${ERROR} Use '--var_pct_phasing=FLOAT' instead of '--var_pct_phasing FLOAT' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${PILEUP_PREFIX} ]; then echo -e "${ERROR} Use '--pileup_model_prefix=STR' instead of '--pileup_model_prefix STR' for optional parameters${NC}"; exit 1 ; fi +if [ -z ${FA_PREFIX} ]; then echo -e "${ERROR} Use '--fa_model_prefix=STR' instead of '--fa_model_prefix STR' for optional parameters${NC}"; exit 1 ; fi + +# model prefix detection +if [ ! -f ${MODEL_PATH}/${PILEUP_PREFIX}.index ]; then echo -e "${ERROR} No pileup model found in provided model path and model prefix ${MODEL_PATH}/${PILEUP_PREFIX} ${NC}"; exit 1; fi +if [ ! -f ${MODEL_PATH}/${FA_PREFIX}.index ]; then echo -e "${ERROR} No full-alignment model found in provided model path and model prefix ${MODEL_PATH}/${FA_PREFIX} ${NC}"; exit 1; fi + + +set -x +${SCRIPT_PATH}/scripts/clair3.sh \ + --bam_fn ${BAM_FILE_PATH} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --threads ${THREADS} \ + --model_path ${MODEL_PATH} \ + --platform ${PLATFORM} \ + --output ${OUTPUT_FOLDER} \ + --bed_fn=${BED_FILE_PATH} \ + --vcf_fn=${VCF_FILE_PATH} \ + --ctg_name=${CONTIGS} \ + --sample_name=${SAMPLE} \ + --chunk_num=${CHUNK_NUM} \ + --chunk_size=${CHUNK_SIZE} \ + --samtools=${SAMTOOLS} \ + --python=${PYTHON} \ + --pypy=${PYPY} \ + --parallel=${PARALLEL} \ + --whatshap=${WHATSHAP} \ + --qual=${QUAL} \ + --var_pct_full=${PRO} \ + --ref_pct_full=${REF_PRO} \ + --var_pct_phasing=${PHASING_PCT} \ + --snp_min_af=${SNP_AF} \ + --indel_min_af=${INDEL_AF} \ + --pileup_only=${PILEUP_ONLY} \ + --gvcf=${GVCF} \ + --fast_mode=${FAST_MODE} \ + --call_snp_only=${SNP_ONLY} \ + --print_ref_calls=${SHOW_REF} \ + --haploid_precise=${HAP_PRE} \ + --haploid_sensitive=${HAP_SEN} \ + --include_all_ctgs=${INCLUDE_ALL_CTGS} \ + --no_phasing_for_fa=${NO_PHASING} \ + --pileup_model_prefix=${PILEUP_PREFIX} \ + --fa_model_prefix=${FA_PREFIX} \ + --remove_intermediate_dir=${RM_TMP_DIR} \ + --enable_phasing=${ENABLE_PHASING} \ + --enable_long_indel=${ENABLE_LONG_INDEL} + + +)) |& tee ${OUTPUT_FOLDER}/run_clair3.log diff --git a/benchmarks/nn-variant/Clair3/scripts/__init__.py b/benchmarks/nn-variant/Clair3/scripts/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/nn-variant/Clair3/scripts/clair3.sh b/benchmarks/nn-variant/Clair3/scripts/clair3.sh new file mode 100755 index 0000000..5efc6d0 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/scripts/clair3.sh @@ -0,0 +1,329 @@ +#!/bin/bash +SCRIPT_NAME=$(basename "$0") +Usage="Usage: ./${SCRIPT_NAME} --bam_fn=BAM --ref_fn=REF --output=OUTPUT_DIR --threads=THREADS --platform=PLATFORM --model_path=MODEL_PREFIX [--bed_fn=BED] [options]" +# INFO: whole calling workflow of clair3 + +set -e +ARGS=`getopt -o b:f:t:m:p:o:r::c::s::h::g \ +-l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\ +bed_fn::,vcf_fn::,ctg_name::,sample_name::,help::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,var_pct_phasing::,\ +snp_min_af::,indel_min_af::,ref_pct_full::,pileup_only::,fast_mode::,gvcf::,print_ref_calls::,haploid_precise::,haploid_sensitive::,include_all_ctgs::,\ +no_phasing_for_fa::,pileup_model_prefix::,fa_model_prefix::,call_snp_only::,remove_intermediate_dir::,enable_phasing::,enable_long_indel:: -n 'run_clair3.sh' -- "$@"` + +if [ $? != 0 ] ; then echo"No input. Terminating...">&2 ; exit 1 ; fi +eval set -- "${ARGS}" + +while true; do + case "$1" in + -b|--bam_fn ) BAM_FILE_PATH="$2"; shift 2 ;; + -f|--ref_fn ) REFERENCE_FILE_PATH="$2"; shift 2 ;; + -t|--threads ) THREADS="$2"; shift 2 ;; + -m|--model_path ) MODEL_PATH="$2"; shift 2 ;; + -p|--platform ) PLATFORM="$2"; shift 2 ;; + -o|--output ) OUTPUT_FOLDER="$2"; shift 2 ;; + --bed_fn ) BED_FILE_PATH="$2"; shift 2 ;; + --vcf_fn ) VCF_FILE_PATH="$2"; shift 2 ;; + --ctg_name ) CONTIGS="$2"; shift 2 ;; + --sample_name ) SAMPLE="$2"; shift 2 ;; + --chunk_num ) CHUNK_NUM="$2"; shift 2 ;; + --chunk_size ) CHUNK_SIZE="$2"; shift 2 ;; + --qual ) QUAL="$2"; shift 2 ;; + --samtools ) SAMTOOLS="$2"; shift 2 ;; + --python ) PYTHON="$2"; shift 2 ;; + --pypy ) PYPY="$2"; shift 2 ;; + --parallel ) PARALLEL="$2"; shift 2 ;; + --whatshap ) WHATSHAP="$2"; shift 2 ;; + --var_pct_full ) PRO="$2"; shift 2 ;; + --ref_pct_full ) REF_PRO="$2"; shift 2 ;; + --var_pct_phasing ) PHASING_PCT="$2"; shift 2 ;; + --pileup_only ) PILEUP_ONLY="$2"; shift 2 ;; + --fast_mode ) FAST_MODE="$2"; shift 2 ;; + --call_snp_only ) SNP_ONLY="$2"; shift 2 ;; + --print_ref_calls ) SHOW_REF="$2"; shift 2 ;; + --gvcf ) GVCF="$2"; shift 2 ;; + --snp_min_af ) SNP_AF="$2"; shift 2 ;; + --indel_min_af ) INDEL_AF="$2"; shift 2 ;; + --pileup_model_prefix ) PILEUP_PREFIX="$2"; shift 2 ;; + --fa_model_prefix ) FA_PREFIX="$2"; shift 2 ;; + --haploid_precise ) HAP_PRE="$2"; shift 2 ;; + --haploid_sensitive ) HAP_SEN="$2"; shift 2 ;; + --include_all_ctgs ) INCLUDE_ALL_CTGS="$2"; shift 2 ;; + --no_phasing_for_fa ) NO_PHASING="$2"; shift 2 ;; + --remove_intermediate_dir ) RM_TMP_DIR="$2"; shift 2 ;; + --enable_phasing ) ENABLE_PHASING="$2"; shift 2 ;; + --enable_long_indel ) ENABLE_LONG_INDEL="$2"; shift 2 ;; + + -- ) shift; break; ;; + -h|--help ) print_help_messages; break ;; + * ) print_help_messages; exit 0 ;; + esac +done + + +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +CLAIR3="${SHELL_FOLDER}/../clair3.py" + +if [ ${BED_FILE_PATH} = "EMPTY" ] ; then BED_FILE_PATH= ; fi +RETRIES=4 + +PILEUP_CHECKPOINT_PATH="${MODEL_PATH}/${PILEUP_PREFIX}" +FULL_ALIGNMENT_CHECKPOINT_PATH="${MODEL_PATH}/${FA_PREFIX}" +LOG_PATH="${OUTPUT_FOLDER}/log" +TMP_FILE_PATH="${OUTPUT_FOLDER}/tmp" +SPLIT_BED_PATH="${TMP_FILE_PATH}/split_beds" +PILEUP_VCF_PATH="${TMP_FILE_PATH}/pileup_output" +GVCF_TMP_PATH="${TMP_FILE_PATH}/gvcf_tmp_output" +PHASE_OUTPUT_PATH="${TMP_FILE_PATH}/phase_output" +FULL_ALIGNMENT_OUTPUT_PATH="${TMP_FILE_PATH}/full_alignment_output" +PHASE_VCF_PATH="${PHASE_OUTPUT_PATH}/phase_vcf" +PHASE_BAM_PATH="${PHASE_OUTPUT_PATH}/phase_bam" +CANDIDATE_BED_PATH="${FULL_ALIGNMENT_OUTPUT_PATH}/candidate_bed" +export OPENBLAS_NUM_THREADS=1 +export GOTO_NUM_THREADS=1 +export OMP_NUM_THREADS=1 + +echo "[INFO] Check environment variables" +${PYTHON} ${CLAIR3} CheckEnvs \ + --bam_fn ${BAM_FILE_PATH} \ + --bed_fn ${BED_FILE_PATH} \ + --output_fn_prefix ${OUTPUT_FOLDER} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --vcf_fn ${VCF_FILE_PATH} \ + --ctg_name ${CONTIGS} \ + --chunk_num ${CHUNK_NUM} \ + --chunk_size ${CHUNK_SIZE} \ + --include_all_ctgs ${INCLUDE_ALL_CTGS} \ + --threads ${THREADS} \ + --python ${PYTHON} \ + --pypy ${PYPY} \ + --samtools ${SAMTOOLS} \ + --whatshap ${WHATSHAP} \ + --parallel ${PARALLEL} \ + --qual ${QUAL} \ + --sampleName ${SAMPLE} \ + --var_pct_full ${PRO} \ + --ref_pct_full ${REF_PRO} \ + --snp_min_af ${SNP_AF} \ + --indel_min_af ${INDEL_AF} +readarray -t CHR < "${OUTPUT_FOLDER}/tmp/CONTIGS" +if [ ${#CHR[@]} -eq 0 ]; then echo "[INFO] Exit in environment checking"; exit 0; fi +THREADS_LOW=$((${THREADS}*3/4)) +if [[ ${THREADS_LOW} < 1 ]]; then THREADS_LOW=1; fi + +cd ${OUTPUT_FOLDER} +# Pileup calling +#----------------------------------------------------------------------------------------------------------------------- +export CUDA_VISIBLE_DEVICES="" +echo "[INFO] 1/7 Call variants using pileup model" +\time -v ${PARALLEL} --retries ${RETRIES} -C ' ' --joblog ${LOG_PATH}/parallel_1_call_var_bam_pileup.log -j ${THREADS_LOW} \ +"${PYTHON} ${CLAIR3} CallVarBam \ + --chkpnt_fn ${PILEUP_CHECKPOINT_PATH} \ + --bam_fn ${BAM_FILE_PATH} \ + --call_fn ${PILEUP_VCF_PATH}/pileup_{1}_{2}.vcf \ + --sampleName ${SAMPLE} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --extend_bed ${SPLIT_BED_PATH}/{1} \ + --bed_fn ${BED_FILE_PATH} \ + --vcf_fn ${VCF_FILE_PATH} \ + --ctgName {1} \ + --chunk_id {2} \ + --chunk_num {3} \ + --platform ${PLATFORM} \ + --fast_mode ${FAST_MODE} \ + --snp_min_af ${SNP_AF} \ + --indel_min_af ${INDEL_AF} \ + --call_snp_only ${SNP_ONLY} \ + --gvcf ${GVCF} \ + --enable_long_indel ${ENABLE_LONG_INDEL} \ + --python ${PYTHON} \ + --pypy ${PYPY} \ + --samtools ${SAMTOOLS} \ + --temp_file_dir ${GVCF_TMP_PATH} \ + --pileup" :::: ${OUTPUT_FOLDER}/tmp/CHUNK_LIST |& tee ${LOG_PATH}/1_call_var_bam_pileup.log + + +\time -v ${PYPY} ${CLAIR3} SortVcf \ + --input_dir ${PILEUP_VCF_PATH} \ + --vcf_fn_prefix "pileup" \ + --output_fn ${OUTPUT_FOLDER}/pileup.vcf \ + --sampleName ${SAMPLE} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --contigs_fn ${TMP_FILE_PATH}/CONTIGS + +if [ "$( gzip -fdc ${OUTPUT_FOLDER}/pileup.vcf.gz | grep -v '#' | wc -l )" -eq 0 ]; then echo "[INFO] Exit in pileup variant calling"; exit 0; fi +if [ ${PILEUP_ONLY} == True ]; then + if [ ${RM_TMP_DIR} == True ]; then echo "[INFO] Removing intermediate files in ${OUTPUT_FOLDER}/tmp"; rm -rf ${OUTPUT_FOLDER}/tmp; fi + echo "[INFO] Only call pileup output with --pileup_only, output file: ${OUTPUT_FOLDER}/pileup.vcf.gz" + echo "[INFO] Finish calling!" + exit 0; +fi + +# Whatshap phasing and haplotaging +#----------------------------------------------------------------------------------------------------------------------- +if [ ${NO_PHASING} == True ] +then + echo "[INFO] 2/7 No phasing for full alignment calling" + ${PARALLEL} -j${THREADS} ln -sf ${BAM_FILE_PATH} ${PHASE_BAM_PATH}/{1}.bam ::: ${CHR[@]} + if [ -f ${BAM_FILE_PATH}.bai ]; then ${PARALLEL} --retries ${RETRIES} -j${THREADS} ln -sf ${BAM_FILE_PATH}.bai ${PHASE_BAM_PATH}/{1}.bam.bai ::: ${CHR[@]}; fi + if [ -f ${BAM_FILE_PATH%.*}.bai ]; then ${PARALLEL} --retries ${RETRIES} -j${THREADS} ln -sf ${BAM_FILE_PATH%.*}.bai ${PHASE_BAM_PATH}/{1}.bam.bai ::: ${CHR[@]}; fi +else + echo $'' + echo "[INFO] 2/7 Select heterozygous SNP variants for Whatshap phasing and haplotagging" + gzip -fdc ${OUTPUT_FOLDER}/pileup.vcf.gz | ${PYPY} ${CLAIR3} SelectQual --phase --output_fn ${PHASE_VCF_PATH} --var_pct_phasing ${PHASING_PCT} + \time -v ${PARALLEL} --retries ${RETRIES} --joblog ${LOG_PATH}/parallel_2_select_hetero_snp.log -j${THREADS} \ + "${PYPY} ${CLAIR3} SelectHetSnp \ + --vcf_fn ${OUTPUT_FOLDER}/pileup.vcf.gz \ + --split_folder ${PHASE_VCF_PATH} \ + --ctgName {1}" ::: ${CHR[@]} ::: ${ALL_SAMPLE[@]} |& tee ${LOG_PATH}/2_select_hetero_snp.log + + echo $'' + echo "[INFO] 3/7 Phase VCF file using Whatshap" + \time -v ${PARALLEL} --retries ${RETRIES} --joblog ${LOG_PATH}/parallel_3_phase.log -j${THREADS} \ + "${WHATSHAP} phase \ + --output ${PHASE_VCF_PATH}/phased_{1}.vcf.gz \ + --reference ${REFERENCE_FILE_PATH} \ + --chromosome {1} \ + --distrust-genotypes \ + --ignore-read-groups \ + ${PHASE_VCF_PATH}/{1}.vcf \ + ${BAM_FILE_PATH}" ::: ${CHR[@]} |& tee ${LOG_PATH}/3_phase.log + ${PARALLEL} -j${THREADS} tabix -f -p vcf ${PHASE_VCF_PATH}/phased_{}.vcf.gz ::: ${CHR[@]} + + echo $'' + echo "[INFO] 4/7 Haplotag input BAM file using Whatshap" + \time -v ${PARALLEL} --retries ${RETRIES} --joblog ${LOG_PATH}/parallel_4_haplotag.log -j${THREADS} \ + "${WHATSHAP} haplotag \ + --output ${PHASE_BAM_PATH}/{1}.bam \ + --reference ${REFERENCE_FILE_PATH} \ + --ignore-read-groups \ + --regions {1} \ + ${PHASE_VCF_PATH}/phased_{1}.vcf.gz \ + ${BAM_FILE_PATH}" ::: ${CHR[@]} |& tee ${LOG_PATH}/4_haplotag.log + ${PARALLEL} -j${THREADS} ${SAMTOOLS} index -@12 ${PHASE_BAM_PATH}/{1}.bam ::: ${CHR[@]} +fi + +# Full alignment calling +#----------------------------------------------------------------------------------------------------------------------- +echo $'' +echo "[INFO] 5/7 Select candidates for full-alignment calling" +gzip -fdc ${OUTPUT_FOLDER}/pileup.vcf.gz | ${PYPY} ${CLAIR3} SelectQual --output_fn ${CANDIDATE_BED_PATH} \ +--var_pct_full ${PRO} --ref_pct_full ${REF_PRO} --platform ${PLATFORM} --vcf_fn ${VCF_FILE_PATH} +\time -v ${PARALLEL} --retries ${RETRIES} --joblog ${LOG_PATH}/parallel_5_select_candidate.log -j${THREADS} \ +"${PYPY} ${CLAIR3} SelectCandidates \ + --pileup_vcf_fn ${OUTPUT_FOLDER}/pileup.vcf.gz \ + --split_folder ${CANDIDATE_BED_PATH} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --var_pct_full ${PRO} \ + --ref_pct_full ${REF_PRO} \ + --platform ${PLATFORM} \ + --ctgName {1}" ::: ${CHR[@]} |& tee ${LOG_PATH}/5_select_candidate.log + +echo $'' +echo "[INFO] 6/7 Call low-quality variants using full-alignment model" +cat ${CANDIDATE_BED_PATH}/FULL_ALN_FILE_* > ${CANDIDATE_BED_PATH}/FULL_ALN_FILES +\time -v ${PARALLEL} --retries ${RETRIES} --joblog ${LOG_PATH}/parallel_6_call_var_bam_full_alignment.log -j ${THREADS_LOW} \ +"${PYTHON} ${CLAIR3} CallVarBam \ + --chkpnt_fn ${FULL_ALIGNMENT_CHECKPOINT_PATH} \ + --bam_fn ${PHASE_BAM_PATH}/{1/.}.bam \ + --call_fn ${FULL_ALIGNMENT_OUTPUT_PATH}/full_alignment_{1/}.vcf \ + --sampleName ${SAMPLE} \ + --vcf_fn ${VCF_FILE_PATH} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --full_aln_regions {1} \ + --ctgName {1/.} \ + --add_indel_length \ + --phasing_info_in_bam \ + --gvcf ${GVCF} \ + --enable_long_indel ${ENABLE_LONG_INDEL} \ + --python ${PYTHON} \ + --pypy ${PYPY} \ + --samtools ${SAMTOOLS} \ + --platform ${PLATFORM}" :::: ${CANDIDATE_BED_PATH}/FULL_ALN_FILES |& tee ${LOG_PATH}/6_call_var_bam_full_alignment.log + +${PYPY} ${CLAIR3} SortVcf \ + --input_dir ${FULL_ALIGNMENT_OUTPUT_PATH} \ + --vcf_fn_prefix "full_alignment" \ + --output_fn ${OUTPUT_FOLDER}/full_alignment.vcf \ + --sampleName ${SAMPLE} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --contigs_fn ${TMP_FILE_PATH}/CONTIGS + +if [ "$( gzip -fdc ${OUTPUT_FOLDER}/full_alignment.vcf.gz | grep -v '#' | wc -l )" -eq 0 ]; then echo "[INFO] Exit in full-alignment variant calling"; exit 0; fi +# Compress GVCF output using lz4 +if [ ${GVCF} == True ] +then + ${PYPY} ${CLAIR3} SortVcf \ + --input_dir ${GVCF_TMP_PATH} \ + --vcf_fn_suffix ".tmp.gvcf" \ + --output_fn ${GVCF_TMP_PATH}/non_var.gvcf \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --contigs_fn ${TMP_FILE_PATH}/CONTIGS +fi + +##Merge pileup and full alignment vcf +##----------------------------------------------------------------------------------------------------------------------- +echo $'' +echo "[INFO] 7/7 Merge pileup VCF and full-alignment VCF" +\time -v ${PARALLEL} --retries ${RETRIES} --joblog ${LOG_PATH}/parallel_7_merge_vcf.log -j${THREADS} \ +"${PYPY} ${CLAIR3} MergeVcf \ + --pileup_vcf_fn ${OUTPUT_FOLDER}/pileup.vcf.gz \ + --bed_fn_prefix ${CANDIDATE_BED_PATH} \ + --full_alignment_vcf_fn ${OUTPUT_FOLDER}/full_alignment.vcf.gz \ + --output_fn ${TMP_FILE_PATH}/merge_output/merge_{1}.vcf \ + --platform ${PLATFORM} \ + --print_ref_calls ${SHOW_REF} \ + --gvcf ${GVCF} \ + --haploid_precise ${HAP_PRE} \ + --haploid_sensitive ${HAP_SEN} \ + --gvcf_fn ${TMP_FILE_PATH}/merge_output/merge_{1}.gvcf \ + --non_var_gvcf_fn ${GVCF_TMP_PATH}/non_var.gvcf \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --ctgName {1}" ::: ${CHR[@]} |& tee ${LOG_PATH}/7_merge_vcf.log + +${PYPY} ${CLAIR3} SortVcf \ + --input_dir ${TMP_FILE_PATH}/merge_output \ + --vcf_fn_prefix "merge" \ + --output_fn ${OUTPUT_FOLDER}/merge_output.vcf \ + --sampleName ${SAMPLE} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --contigs_fn ${TMP_FILE_PATH}/CONTIGS + +if [ "$( gzip -fdc ${OUTPUT_FOLDER}/merge_output.vcf.gz | grep -v '#' | wc -l )" -eq 0 ]; then echo "[INFO] Exit in variant merging"; exit 0; fi +if [ ${GVCF} == True ] +then + ${PYPY} ${CLAIR3} SortVcf \ + --input_dir ${TMP_FILE_PATH}/merge_output \ + --vcf_fn_prefix "merge" \ + --vcf_fn_suffix ".gvcf" \ + --output_fn ${OUTPUT_FOLDER}/merge_output.gvcf \ + --sampleName ${SAMPLE} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --contigs_fn ${TMP_FILE_PATH}/CONTIGS +fi + +if [ ${ENABLE_PHASING} == True ] +then + echo "[INFO] 7/7 Phasing VCF output in parallel using WhatsHap" + \time -v ${PARALLEL} --retries ${RETRIES} --joblog ${LOG_PATH}/parallel_8_phase_vcf_output.log -j${THREADS} \ + "${WHATSHAP} phase \ + --output ${TMP_FILE_PATH}/merge_output/phased_merge_{1}.vcf \ + --reference ${REFERENCE_FILE_PATH} \ + --ignore-read-groups \ + ${TMP_FILE_PATH}/merge_output/merge_{1}.vcf \ + ${BAM_FILE_PATH}" ::: ${CHR[@]} |& tee ${LOG_PATH}/8_phase_vcf_output.log + + ${PYPY} ${CLAIR3} SortVcf \ + --input_dir ${TMP_FILE_PATH}/merge_output \ + --vcf_fn_prefix "phased_merge" \ + --output_fn ${OUTPUT_FOLDER}/phased_merge_output.vcf \ + --sampleName ${SAMPLE} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --contigs_fn ${TMP_FILE_PATH}/CONTIGS +fi + +if [ ${RM_TMP_DIR} == True ]; then echo "[INFO] Removing intermediate files in ${OUTPUT_FOLDER}/tmp"; rm -rf ${OUTPUT_FOLDER}/tmp; fi + +echo $'' +echo "[INFO] Finish calling, output file: ${OUTPUT_FOLDER}/merge_output.vcf.gz" + +if [ ${ENABLE_PHASING} == True ]; then echo "[INFO] Finish calling, phased output file: ${OUTPUT_FOLDER}/phased_merge_output.vcf.gz"; fi diff --git a/benchmarks/nn-variant/Clair3/scripts/clair3_CallVar.sh b/benchmarks/nn-variant/Clair3/scripts/clair3_CallVar.sh new file mode 100755 index 0000000..1aa21b4 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/scripts/clair3_CallVar.sh @@ -0,0 +1,157 @@ +#!/bin/bash +SCRIPT_NAME=$(basename "$0") +Usage="Usage: ./${SCRIPT_NAME} --bam_fn=BAM --ref_fn=REF --output=OUTPUT_DIR --threads=THREADS --platform=PLATFORM --model_path=MODEL_PREFIX [--bed_fn=BED] [options]" +# INFO: whole calling workflow of clair3 + +set -e +ARGS=`getopt -o b:f:t:m:p:o:r::c::s::h::g \ +-l bam_fn:,ref_fn:,threads:,model_path:,platform:,output:,\ +bed_fn::,vcf_fn::,ctg_name::,sample_name::,help::,qual::,samtools::,python::,pypy::,parallel::,whatshap::,chunk_num::,chunk_size::,var_pct_full::,var_pct_phasing::,\ +snp_min_af::,indel_min_af::,ref_pct_full::,pileup_only::,fast_mode::,gvcf::,print_ref_calls::,haploid_precise::,haploid_sensitive::,include_all_ctgs::,\ +no_phasing_for_fa::,pileup_model_prefix::,fa_model_prefix::,call_snp_only::,remove_intermediate_dir::,enable_phasing::,enable_long_indel:: -n 'run_clair3.sh' -- "$@"` + +if [ $? != 0 ] ; then echo"No input. Terminating...">&2 ; exit 1 ; fi +eval set -- "${ARGS}" + +while true; do + case "$1" in + -b|--bam_fn ) BAM_FILE_PATH="$2"; shift 2 ;; + -f|--ref_fn ) REFERENCE_FILE_PATH="$2"; shift 2 ;; + -t|--threads ) THREADS="$2"; shift 2 ;; + -m|--model_path ) MODEL_PATH="$2"; shift 2 ;; + -p|--platform ) PLATFORM="$2"; shift 2 ;; + -o|--output ) OUTPUT_FOLDER="$2"; shift 2 ;; + --bed_fn ) BED_FILE_PATH="$2"; shift 2 ;; + --vcf_fn ) VCF_FILE_PATH="$2"; shift 2 ;; + --ctg_name ) CONTIGS="$2"; shift 2 ;; + --sample_name ) SAMPLE="$2"; shift 2 ;; + --chunk_num ) CHUNK_NUM="$2"; shift 2 ;; + --chunk_size ) CHUNK_SIZE="$2"; shift 2 ;; + --qual ) QUAL="$2"; shift 2 ;; + --samtools ) SAMTOOLS="$2"; shift 2 ;; + --python ) PYTHON="$2"; shift 2 ;; + --pypy ) PYPY="$2"; shift 2 ;; + --parallel ) PARALLEL="$2"; shift 2 ;; + --whatshap ) WHATSHAP="$2"; shift 2 ;; + --var_pct_full ) PRO="$2"; shift 2 ;; + --ref_pct_full ) REF_PRO="$2"; shift 2 ;; + --var_pct_phasing ) PHASING_PCT="$2"; shift 2 ;; + --pileup_only ) PILEUP_ONLY="$2"; shift 2 ;; + --fast_mode ) FAST_MODE="$2"; shift 2 ;; + --call_snp_only ) SNP_ONLY="$2"; shift 2 ;; + --print_ref_calls ) SHOW_REF="$2"; shift 2 ;; + --gvcf ) GVCF="$2"; shift 2 ;; + --snp_min_af ) SNP_AF="$2"; shift 2 ;; + --indel_min_af ) INDEL_AF="$2"; shift 2 ;; + --pileup_model_prefix ) PILEUP_PREFIX="$2"; shift 2 ;; + --fa_model_prefix ) FA_PREFIX="$2"; shift 2 ;; + --haploid_precise ) HAP_PRE="$2"; shift 2 ;; + --haploid_sensitive ) HAP_SEN="$2"; shift 2 ;; + --include_all_ctgs ) INCLUDE_ALL_CTGS="$2"; shift 2 ;; + --no_phasing_for_fa ) NO_PHASING="$2"; shift 2 ;; + --remove_intermediate_dir ) RM_TMP_DIR="$2"; shift 2 ;; + --enable_phasing ) ENABLE_PHASING="$2"; shift 2 ;; + --enable_long_indel ) ENABLE_LONG_INDEL="$2"; shift 2 ;; + + -- ) shift; break; ;; + -h|--help ) print_help_messages; break ;; + * ) print_help_messages; exit 0 ;; + esac +done + + +SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) +CLAIR3="${SHELL_FOLDER}/../clair3.py" + +if [ ${BED_FILE_PATH} = "EMPTY" ] ; then BED_FILE_PATH= ; fi +RETRIES=4 + +PILEUP_CHECKPOINT_PATH="${MODEL_PATH}/${PILEUP_PREFIX}" +echo "PILEUP_CHECKPOINT_PATH ${PILEUP_CHECKPOINT_PATH}" +FULL_ALIGNMENT_CHECKPOINT_PATH="${MODEL_PATH}/${FA_PREFIX}" +LOG_PATH="${OUTPUT_FOLDER}/log" +TMP_FILE_PATH="${OUTPUT_FOLDER}/tmp" +SPLIT_BED_PATH="${TMP_FILE_PATH}/split_beds" +PILEUP_VCF_PATH="${TMP_FILE_PATH}/pileup_output" +GVCF_TMP_PATH="${TMP_FILE_PATH}/gvcf_tmp_output" +PHASE_OUTPUT_PATH="${TMP_FILE_PATH}/phase_output" +FULL_ALIGNMENT_OUTPUT_PATH="${TMP_FILE_PATH}/full_alignment_output" +PHASE_VCF_PATH="${PHASE_OUTPUT_PATH}/phase_vcf" +PHASE_BAM_PATH="${PHASE_OUTPUT_PATH}/phase_bam" +CANDIDATE_BED_PATH="${FULL_ALIGNMENT_OUTPUT_PATH}/candidate_bed" +export OPENBLAS_NUM_THREADS=1 +export GOTO_NUM_THREADS=1 +export OMP_NUM_THREADS=1 + +echo "[INFO] Check environment variables" +${PYTHON} ${CLAIR3} CheckEnvs \ + --bam_fn ${BAM_FILE_PATH} \ + --bed_fn ${BED_FILE_PATH} \ + --output_fn_prefix ${OUTPUT_FOLDER} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --vcf_fn ${VCF_FILE_PATH} \ + --ctg_name ${CONTIGS} \ + --chunk_num ${CHUNK_NUM} \ + --chunk_size ${CHUNK_SIZE} \ + --include_all_ctgs ${INCLUDE_ALL_CTGS} \ + --threads ${THREADS} \ + --python ${PYTHON} \ + --pypy ${PYPY} \ + --samtools ${SAMTOOLS} \ + --whatshap ${WHATSHAP} \ + --parallel ${PARALLEL} \ + --qual ${QUAL} \ + --sampleName ${SAMPLE} \ + --var_pct_full ${PRO} \ + --ref_pct_full ${REF_PRO} \ + --snp_min_af ${SNP_AF} \ + --indel_min_af ${INDEL_AF} +readarray -t CHR < "${OUTPUT_FOLDER}/tmp/CONTIGS" +if [ ${#CHR[@]} -eq 0 ]; then echo "[INFO] Exit in environment checking"; exit 0; fi +# GenarchBench: Tensorflow sometimes uses more than 1 thread, so Clair3 uses +# ${THREADS}*3/4 simultaneous threads as an heuristic. We obtained better scalability +# using all threads. +# THREADS_LOW=$((${THREADS}*3/4)) +THREADS_LOW=${THREADS} +if [[ ${THREADS_LOW} < 1 ]]; then THREADS_LOW=1; fi + +cd ${OUTPUT_FOLDER} +# Pileup calling +#----------------------------------------------------------------------------------------------------------------------- +# GenarchBnech: --delay = 0 (no need to wait at startup) and +# --tensorflow_threads=1 (we want that each chunk is processed by a single physical CPU) + +export CUDA_VISIBLE_DEVICES="" +echo "[INFO] 1/1 Call variants using pileup model" +\time -v ${PARALLEL} --retries ${RETRIES} -C ' ' --joblog ${LOG_PATH}/parallel_1_call_var_bam_pileup.log -j ${THREADS_LOW} \ +"${PYTHON} ${CLAIR3} CallVarBam \ + --chkpnt_fn ${PILEUP_CHECKPOINT_PATH} \ + --bam_fn ${BAM_FILE_PATH} \ + --call_fn ${PILEUP_VCF_PATH}/pileup_{1}_{2}.vcf \ + --sampleName ${SAMPLE} \ + --ref_fn ${REFERENCE_FILE_PATH} \ + --extend_bed ${SPLIT_BED_PATH}/{1} \ + --bed_fn ${BED_FILE_PATH} \ + --vcf_fn ${VCF_FILE_PATH} \ + --ctgName {1} \ + --chunk_id {2} \ + --chunk_num {3} \ + --platform ${PLATFORM} \ + --fast_mode ${FAST_MODE} \ + --snp_min_af ${SNP_AF} \ + --indel_min_af ${INDEL_AF} \ + --call_snp_only ${SNP_ONLY} \ + --gvcf ${GVCF} \ + --enable_long_indel ${ENABLE_LONG_INDEL} \ + --python ${PYTHON} \ + --pypy ${PYPY} \ + --samtools ${SAMTOOLS} \ + --temp_file_dir ${GVCF_TMP_PATH} \ + --delay=0 \ + --tensorflow_threads=1 \ + --pileup" :::: ${OUTPUT_FOLDER}/tmp/CHUNK_LIST |& tee ${LOG_PATH}/1_call_var_bam_pileup.log + +# Write to prof_pipe to tell the c++ wrapper to stop profiling. +echo "ee" > "prof_pipe" + + diff --git a/benchmarks/nn-variant/Clair3/scripts/clair3_hifi_quick_demo.sh b/benchmarks/nn-variant/Clair3/scripts/clair3_hifi_quick_demo.sh new file mode 100755 index 0000000..ac7f5be --- /dev/null +++ b/benchmarks/nn-variant/Clair3/scripts/clair3_hifi_quick_demo.sh @@ -0,0 +1,61 @@ +PLATFORM='hifi' +INPUT_DIR="${HOME}/clair3_pacbio_hifi_quickDemo" +OUTPUT_DIR="${INPUT_DIR}/output" +THREADS=4 + +## Create local directory structure +mkdir -p ${INPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +# Download quick demo data +#GRCh38_no_alt Reference +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/pacbio_hifi/GRCh38_no_alt_chr20.fa +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/pacbio_hifi/GRCh38_no_alt_chr20.fa.fai +# BAM chr20:100000-300000 +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/pacbio_hifi/HG003_chr20_demo.bam +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/pacbio_hifi/HG003_chr20_demo.bam.bai +# GIAB Truth VCF and BED +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/pacbio_hifi/HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/pacbio_hifi/HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz.tbi +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/pacbio_hifi/HG003_GRCh38_chr20_v4.2.1_benchmark_noinconsistent.bed + +REF="GRCh38_no_alt_chr20.fa" +BAM="HG003_chr20_demo.bam" +BASELINE_VCF_FILE_PATH="HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz" +BASELINE_BED_FILE_PATH="HG003_GRCh38_chr20_v4.2.1_benchmark_noinconsistent.bed" +OUTPUT_VCF_FILE_PATH="merge_output.vcf.gz" + +CONTIGS="chr20" +START_POS=100000 +END_POS=300000 +echo -e "${CONTIGS}\t${START_POS}\t${END_POS}" > ${INPUT_DIR}/quick_demo.bed + +cd ${OUTPUT_DIR} +# Run Clair3 using one command +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/${BAM} \ + --ref_fn=${INPUT_DIR}/${REF} \ + --threads=${THREADS} \ + --platform=${PLATFORM} \ + --model_path="/opt/models/${PLATFORM}" \ + --output=${OUTPUT_DIR} \ + --bed_fn=${INPUT_DIR}/quick_demo.bed + +# Run hap.py +docker run \ +-v "${INPUT_DIR}":"${INPUT_DIR}" \ +-v "${OUTPUT_DIR}":"${OUTPUT_DIR}" \ +jmcdani20/hap.py:v0.3.12 /opt/hap.py/bin/hap.py \ +${INPUT_DIR}/${BASELINE_VCF_FILE_PATH} \ +${OUTPUT_DIR}/${OUTPUT_VCF_FILE_PATH} \ +-f "${INPUT_DIR}/${BASELINE_BED_FILE_PATH}" \ +-r "${INPUT_DIR}/${REF}" \ +-o "${OUTPUT_DIR}/happy" \ +-l ${CONTIGS}:${START_POS}-${END_POS} \ +--engine=vcfeval \ +--threads="${THREADS}" \ +--pass-only \ No newline at end of file diff --git a/benchmarks/nn-variant/Clair3/scripts/clair3_ilmn_quick_demo.sh b/benchmarks/nn-variant/Clair3/scripts/clair3_ilmn_quick_demo.sh new file mode 100755 index 0000000..66699db --- /dev/null +++ b/benchmarks/nn-variant/Clair3/scripts/clair3_ilmn_quick_demo.sh @@ -0,0 +1,62 @@ +# Parameters +PLATFORM='ilmn' +INPUT_DIR="${HOME}/clair3_illumina_quickDemo" +OUTPUT_DIR="${INPUT_DIR}/output" +THREADS=4 + +## Create local directory structure +mkdir -p ${INPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +# Download quick demo data +#GRCh38_no_alt Reference +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/illumina/GRCh38_chr20.fa +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/illumina/GRCh38_chr20.fa.fai +# BAM chr20:100000-300000 +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/illumina/HG003_chr20_demo.bam +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/illumina/HG003_chr20_demo.bam.bai +# GIAB Truth VCF and BED +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/illumina/HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/illumina/HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz.tbi +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/illumina/HG003_GRCh38_chr20_v4.2.1_benchmark_noinconsistent.bed + +REF="GRCh38_chr20.fa" +BAM="HG003_chr20_demo.bam" +BASELINE_VCF_FILE_PATH="HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz" +BASELINE_BED_FILE_PATH="HG003_GRCh38_chr20_v4.2.1_benchmark_noinconsistent.bed" +OUTPUT_VCF_FILE_PATH="merge_output.vcf.gz" + +CONTIGS="chr20" +START_POS=100000 +END_POS=300000 +echo -e "${CONTIGS}\t${START_POS}\t${END_POS}" > ${INPUT_DIR}/quick_demo.bed + +cd ${OUTPUT_DIR} +# Run Clair3 using one command +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/${BAM} \ + --ref_fn=${INPUT_DIR}/${REF} \ + --threads=${THREADS} \ + --platform=${PLATFORM} \ + --model_path="/opt/models/${PLATFORM}" \ + --output=${OUTPUT_DIR} \ + --bed_fn=${INPUT_DIR}/quick_demo.bed + +# Run hap.py +docker run \ +-v "${INPUT_DIR}":"${INPUT_DIR}" \ +-v "${OUTPUT_DIR}":"${OUTPUT_DIR}" \ +jmcdani20/hap.py:v0.3.12 /opt/hap.py/bin/hap.py \ +${INPUT_DIR}/${BASELINE_VCF_FILE_PATH} \ +${OUTPUT_DIR}/${OUTPUT_VCF_FILE_PATH} \ +-f "${INPUT_DIR}/${BASELINE_BED_FILE_PATH}" \ +-r "${INPUT_DIR}/${REF}" \ +-o "${OUTPUT_DIR}/happy" \ +-l ${CONTIGS}:${START_POS}-${END_POS} \ +--engine=vcfeval \ +--threads="${THREADS}" \ +--pass-only diff --git a/benchmarks/nn-variant/Clair3/scripts/clair3_ont_quick_demo.sh b/benchmarks/nn-variant/Clair3/scripts/clair3_ont_quick_demo.sh new file mode 100755 index 0000000..095c8f2 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/scripts/clair3_ont_quick_demo.sh @@ -0,0 +1,62 @@ +# Parameters +PLATFORM='ont' +INPUT_DIR="${HOME}/clair3_ont_quickDemo" +OUTPUT_DIR="${INPUT_DIR}/output" +THREADS=4 + +## Create local directory structure +mkdir -p ${INPUT_DIR} +mkdir -p ${OUTPUT_DIR} + +# Download quick demo data +#GRCh38_no_alt Reference +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/ont/GRCh38_no_alt_chr20.fa +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/ont/GRCh38_no_alt_chr20.fa.fai +# BAM chr20:100000-300000 +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/ont/HG003_chr20_demo.bam +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/ont/HG003_chr20_demo.bam.bai +# GIAB Truth VCF and BED +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/ont/HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/ont/HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz.tbi +wget -P ${INPUT_DIR} http://www.bio8.cs.hku.hk/clair3/demo/quick_demo/ont/HG003_GRCh38_chr20_v4.2.1_benchmark_noinconsistent.bed + +REF="GRCh38_no_alt_chr20.fa" +BAM="HG003_chr20_demo.bam" +BASELINE_VCF_FILE_PATH="HG003_GRCh38_chr20_v4.2.1_benchmark.vcf.gz" +BASELINE_BED_FILE_PATH="HG003_GRCh38_chr20_v4.2.1_benchmark_noinconsistent.bed" +OUTPUT_VCF_FILE_PATH="merge_output.vcf.gz" + +CONTIGS="chr20" +START_POS=100000 +END_POS=300000 +echo -e "${CONTIGS}\t${START_POS}\t${END_POS}" > ${INPUT_DIR}/quick_demo.bed + +cd ${OUTPUT_DIR} +# Run Clair3 using one command +docker run -it \ + -v ${INPUT_DIR}:${INPUT_DIR} \ + -v ${OUTPUT_DIR}:${OUTPUT_DIR} \ + hkubal/clair3:latest \ + /opt/bin/run_clair3.sh \ + --bam_fn=${INPUT_DIR}/${BAM} \ + --ref_fn=${INPUT_DIR}/${REF} \ + --threads=${THREADS} \ + --platform=${PLATFORM} \ + --model_path="/opt/models/${PLATFORM}" \ + --output=${OUTPUT_DIR} \ + --bed_fn=${INPUT_DIR}/quick_demo.bed + +# Run hap.py +docker run \ +-v "${INPUT_DIR}":"${INPUT_DIR}" \ +-v "${OUTPUT_DIR}":"${OUTPUT_DIR}" \ +jmcdani20/hap.py:v0.3.12 /opt/hap.py/bin/hap.py \ +${INPUT_DIR}/${BASELINE_VCF_FILE_PATH} \ +${OUTPUT_DIR}/${OUTPUT_VCF_FILE_PATH} \ +-f "${INPUT_DIR}/${BASELINE_BED_FILE_PATH}" \ +-r "${INPUT_DIR}/${REF}" \ +-o "${OUTPUT_DIR}/happy" \ +-l ${CONTIGS}:${START_POS}-${END_POS} \ +--engine=vcfeval \ +--threads="${THREADS}" \ +--pass-only diff --git a/benchmarks/nn-variant/Clair3/shared/__init__.py b/benchmarks/nn-variant/Clair3/shared/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/benchmarks/nn-variant/Clair3/shared/__pycache__/__init__.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..7d25a35 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/__pycache__/__init__.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/__pycache__/interval_tree.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/__pycache__/interval_tree.cpython-39.pyc new file mode 100644 index 0000000..5143584 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/__pycache__/interval_tree.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/__pycache__/param_p.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/__pycache__/param_p.cpython-39.pyc new file mode 100644 index 0000000..a39e03d Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/__pycache__/param_p.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/__pycache__/utils.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000..5934e38 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/__pycache__/utils.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/shared/command_options.py b/benchmarks/nn-variant/Clair3/shared/command_options.py similarity index 100% rename from benchmarks/nn-variant/shared/command_options.py rename to benchmarks/nn-variant/Clair3/shared/command_options.py diff --git a/benchmarks/nn-variant/Clair3/shared/interval_tree.py b/benchmarks/nn-variant/Clair3/shared/interval_tree.py new file mode 100644 index 0000000..3bb4ef3 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/interval_tree.py @@ -0,0 +1,67 @@ +import shlex +import sys +from shared.intervaltree.intervaltree import IntervalTree + +from shared.utils import subprocess_popen + + +def bed_tree_from(bed_file_path, expand_region=None, contig_name=None, bed_ctg_start=None, bed_ctg_end=None, + return_bed_region=False, padding=None): + """ + 0-based interval tree [start, end) + """ + + tree = {} + if bed_file_path is None: + if return_bed_region: + return tree, None, None + return tree + + bed_start, bed_end = float('inf'), 0 + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_file_path))) + for row_id, row in enumerate(unzip_process.stdout): + if row[0] == '#': + continue + columns = row.strip().split() + + ctg_name = columns[0] + if contig_name != None and ctg_name != contig_name: + continue + if ctg_name not in tree: + tree[ctg_name] = IntervalTree() + + ctg_start, ctg_end = int(columns[1]), int(columns[2]) + + if ctg_end < ctg_start or ctg_start < 0 or ctg_end < 0: + sys.exit("[ERROR] Invalid bed input in {}-th row {} {} {}".format(row_id+1, ctg_name, ctg_start, ctg_end)) + + if bed_ctg_start and bed_ctg_end: + if ctg_end < bed_ctg_start or ctg_start > bed_ctg_end: + continue + if padding: + ctg_start += padding + ctg_end -= padding + bed_start = min(ctg_start, bed_start) + bed_end = max(ctg_end, bed_end) + if ctg_start == ctg_end: + ctg_end += 1 + + tree[ctg_name].addi(ctg_start, ctg_end) + + unzip_process.stdout.close() + unzip_process.wait() + if return_bed_region: + return tree, bed_start, bed_end + return tree + + +def is_region_in(tree, contig_name, region_start=None, region_end=None): + if not tree or (contig_name is None) or (contig_name not in tree): + return False + + interval_tree = tree[contig_name] + return len( + interval_tree.at(region_start) + if region_end is None else + interval_tree.overlap(begin=region_start, end=region_end) + ) > 0 diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/__init__.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/__init__.py new file mode 100644 index 0000000..50dcc6a --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/__init__.py @@ -0,0 +1,22 @@ +""" +intervaltree: A mutable, self-balancing interval tree for Python 2 and 3. +Queries may be by point, by range overlap, or by range envelopment. + +Root package. + +Copyright 2013-2018 Chaim Leib Halbert + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from .interval import Interval +from .intervaltree import IntervalTree diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/__init__.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..6a29423 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/__init__.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/interval.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/interval.cpython-39.pyc new file mode 100644 index 0000000..8218789 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/interval.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/intervaltree.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/intervaltree.cpython-39.pyc new file mode 100644 index 0000000..8b7a288 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/intervaltree.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/node.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/node.cpython-39.pyc new file mode 100644 index 0000000..b50d93c Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/__pycache__/node.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/interval.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/interval.py new file mode 100644 index 0000000..d9d80d6 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/interval.py @@ -0,0 +1,302 @@ +""" +intervaltree: A mutable, self-balancing interval tree for Python 2 and 3. +Queries may be by point, by range overlap, or by range envelopment. + +Interval class + +Copyright 2013-2018 Chaim Leib Halbert +Modifications copyright 2014 Konstantin Tretyakov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from numbers import Number +from collections import namedtuple + + +# noinspection PyBroadException +class Interval(namedtuple('IntervalBase', ['begin', 'end', 'data'])): + __slots__ = () # Saves memory, avoiding the need to create __dict__ for each interval + + def __new__(cls, begin, end, data=None): + return super(Interval, cls).__new__(cls, begin, end, data) + + def overlaps(self, begin, end=None): + """ + Whether the interval overlaps the given point, range or Interval. + :param begin: beginning point of the range, or the point, or an Interval + :param end: end point of the range. Optional if not testing ranges. + :return: True or False + :rtype: bool + """ + if end is not None: + # An overlap means that some C exists that is inside both ranges: + # begin <= C < end + # and + # self.begin <= C < self.end + # See https://stackoverflow.com/questions/3269434/whats-the-most-efficient-way-to-test-two-integer-ranges-for-overlap/3269471#3269471 + return begin < self.end and end > self.begin + try: + return self.overlaps(begin.begin, begin.end) + except: + return self.contains_point(begin) + + def contains_point(self, p): + """ + Whether the Interval contains p. + :param p: a point + :return: True or False + :rtype: bool + """ + return self.begin <= p < self.end + + def range_matches(self, other): + """ + Whether the begins equal and the ends equal. Compare __eq__(). + :param other: Interval + :return: True or False + :rtype: bool + """ + return ( + self.begin == other.begin and + self.end == other.end + ) + + def contains_interval(self, other): + """ + Whether other is contained in this Interval. + :param other: Interval + :return: True or False + :rtype: bool + """ + return ( + self.begin <= other.begin and + self.end >= other.end + ) + + def distance_to(self, other): + """ + Returns the size of the gap between intervals, or 0 + if they touch or overlap. + :param other: Interval or point + :return: distance + :rtype: Number + """ + if self.overlaps(other): + return 0 + try: + if self.begin < other.begin: + return other.begin - self.end + else: + return self.begin - other.end + except: + if self.end <= other: + return other - self.end + else: + return self.begin - other + + def is_null(self): + """ + Whether this equals the null interval. + :return: True if end <= begin else False + :rtype: bool + """ + return self.begin >= self.end + + def length(self): + """ + The distance covered by this Interval. + :return: length + :type: Number + """ + if self.is_null(): + return 0 + return self.end - self.begin + + def __hash__(self): + """ + Depends on begin and end only. + :return: hash + :rtype: Number + """ + return hash((self.begin, self.end)) + + def __eq__(self, other): + """ + Whether the begins equal, the ends equal, and the data fields + equal. Compare range_matches(). + :param other: Interval + :return: True or False + :rtype: bool + """ + return ( + self.begin == other.begin and + self.end == other.end and + self.data == other.data + ) + + def __cmp__(self, other): + """ + Tells whether other sorts before, after or equal to this + Interval. + + Sorting is by begins, then by ends, then by data fields. + + If data fields are not both sortable types, data fields are + compared alphabetically by type name. + :param other: Interval + :return: -1, 0, 1 + :rtype: int + """ + s = self[0:2] + try: + o = other[0:2] + except: + o = (other,) + if s != o: + return -1 if s < o else 1 + try: + if self.data == other.data: + return 0 + return -1 if self.data < other.data else 1 + except TypeError: + s = type(self.data).__name__ + o = type(other.data).__name__ + if s == o: + return 0 + return -1 if s < o else 1 + + def __lt__(self, other): + """ + Less than operator. Parrots __cmp__() + :param other: Interval or point + :return: True or False + :rtype: bool + """ + return self.__cmp__(other) < 0 + + def __gt__(self, other): + """ + Greater than operator. Parrots __cmp__() + :param other: Interval or point + :return: True or False + :rtype: bool + """ + return self.__cmp__(other) > 0 + + def _raise_if_null(self, other): + """ + :raises ValueError: if either self or other is a null Interval + """ + if self.is_null(): + raise ValueError("Cannot compare null Intervals!") + if hasattr(other, 'is_null') and other.is_null(): + raise ValueError("Cannot compare null Intervals!") + + def lt(self, other): + """ + Strictly less than. Returns True if no part of this Interval + extends higher than or into other. + :raises ValueError: if either self or other is a null Interval + :param other: Interval or point + :return: True or False + :rtype: bool + """ + self._raise_if_null(other) + return self.end <= getattr(other, 'begin', other) + + def le(self, other): + """ + Less than or overlaps. Returns True if no part of this Interval + extends higher than other. + :raises ValueError: if either self or other is a null Interval + :param other: Interval or point + :return: True or False + :rtype: bool + """ + self._raise_if_null(other) + return self.end <= getattr(other, 'end', other) + + def gt(self, other): + """ + Strictly greater than. Returns True if no part of this Interval + extends lower than or into other. + :raises ValueError: if either self or other is a null Interval + :param other: Interval or point + :return: True or False + :rtype: bool + """ + self._raise_if_null(other) + if hasattr(other, 'end'): + return self.begin >= other.end + else: + return self.begin > other + + def ge(self, other): + """ + Greater than or overlaps. Returns True if no part of this Interval + extends lower than other. + :raises ValueError: if either self or other is a null Interval + :param other: Interval or point + :return: True or False + :rtype: bool + """ + self._raise_if_null(other) + return self.begin >= getattr(other, 'begin', other) + + def _get_fields(self): + """ + Used by str, unicode, repr and __reduce__. + + Returns only the fields necessary to reconstruct the Interval. + :return: reconstruction info + :rtype: tuple + """ + if self.data is not None: + return self.begin, self.end, self.data + else: + return self.begin, self.end + + def __repr__(self): + """ + Executable string representation of this Interval. + :return: string representation + :rtype: str + """ + if isinstance(self.begin, Number): + s_begin = str(self.begin) + s_end = str(self.end) + else: + s_begin = repr(self.begin) + s_end = repr(self.end) + if self.data is None: + return "Interval({0}, {1})".format(s_begin, s_end) + else: + return "Interval({0}, {1}, {2})".format(s_begin, s_end, repr(self.data)) + + __str__ = __repr__ + + def copy(self): + """ + Shallow copy. + :return: copy of self + :rtype: Interval + """ + return Interval(self.begin, self.end, self.data) + + def __reduce__(self): + """ + For pickle-ing. + :return: pickle data + :rtype: tuple + """ + return Interval, self._get_fields() diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/intervaltree.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/intervaltree.py new file mode 100644 index 0000000..b7f426e --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/intervaltree.py @@ -0,0 +1,1140 @@ + +""" +intervaltree: A mutable, self-balancing interval tree for Python 2 and 3. +Queries may be by point, by range overlap, or by range envelopment. + +Core logic. + +Copyright 2013-2018 Chaim Leib Halbert +Modifications Copyright 2014 Konstantin Tretyakov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from .interval import Interval +from .node import Node +from numbers import Number +from .sortedcontainers.sorteddict import SortedDict +from copy import copy +from warnings import warn + +try: + from collections.abc import MutableSet # Python 3? +except ImportError: + from collections import MutableSet + +try: + xrange # Python 2? +except NameError: # pragma: no cover + xrange = range + + +# noinspection PyBroadException +class IntervalTree(MutableSet): + """ + A binary lookup tree of intervals. + The intervals contained in the tree are represented using ``Interval(a, b, data)`` objects. + Each such object represents a half-open interval ``[a, b)`` with optional data. + + Examples: + --------- + + Initialize a blank tree:: + + >>> tree = IntervalTree() + >>> tree + IntervalTree() + + Initialize a tree from an iterable set of Intervals in O(n * log n):: + + >>> tree = IntervalTree([Interval(-10, 10), Interval(-20.0, -10.0)]) + >>> tree + IntervalTree([Interval(-20.0, -10.0), Interval(-10, 10)]) + >>> len(tree) + 2 + + Note that this is a set, i.e. repeated intervals are ignored. However, + Intervals with different data fields are regarded as different:: + + >>> tree = IntervalTree([Interval(-10, 10), Interval(-10, 10), Interval(-10, 10, "x")]) + >>> tree + IntervalTree([Interval(-10, 10), Interval(-10, 10, 'x')]) + >>> len(tree) + 2 + + Insertions:: + >>> tree = IntervalTree() + >>> tree[0:1] = "data" + >>> tree.add(Interval(10, 20)) + >>> tree.addi(19.9, 20) + >>> tree + IntervalTree([Interval(0, 1, 'data'), Interval(10, 20), Interval(19.9, 20)]) + >>> tree.update([Interval(19.9, 20.1), Interval(20.1, 30)]) + >>> len(tree) + 5 + + Inserting the same Interval twice does nothing:: + >>> tree = IntervalTree() + >>> tree[-10:20] = "arbitrary data" + >>> tree[-10:20] = None # Note that this is also an insertion + >>> tree + IntervalTree([Interval(-10, 20), Interval(-10, 20, 'arbitrary data')]) + >>> tree[-10:20] = None # This won't change anything + >>> tree[-10:20] = "arbitrary data" # Neither will this + >>> len(tree) + 2 + + Deletions:: + >>> tree = IntervalTree(Interval(b, e) for b, e in [(-10, 10), (-20, -10), (10, 20)]) + >>> tree + IntervalTree([Interval(-20, -10), Interval(-10, 10), Interval(10, 20)]) + >>> tree.remove(Interval(-10, 10)) + >>> tree + IntervalTree([Interval(-20, -10), Interval(10, 20)]) + >>> tree.remove(Interval(-10, 10)) + Traceback (most recent call last): + ... + ValueError + >>> tree.discard(Interval(-10, 10)) # Same as remove, but no exception on failure + >>> tree + IntervalTree([Interval(-20, -10), Interval(10, 20)]) + + Delete intervals, overlapping a given point:: + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> tree.remove_overlap(1.1) + >>> tree + IntervalTree([Interval(-1.1, 1.1)]) + + Delete intervals, overlapping an interval:: + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> tree.remove_overlap(0, 0.5) + >>> tree + IntervalTree([Interval(0.5, 1.7)]) + >>> tree.remove_overlap(1.7, 1.8) + >>> tree + IntervalTree([Interval(0.5, 1.7)]) + >>> tree.remove_overlap(1.6, 1.6) # Null interval does nothing + >>> tree + IntervalTree([Interval(0.5, 1.7)]) + >>> tree.remove_overlap(1.6, 1.5) # Ditto + >>> tree + IntervalTree([Interval(0.5, 1.7)]) + + Delete intervals, enveloped in the range:: + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> tree.remove_envelop(-1.0, 1.5) + >>> tree + IntervalTree([Interval(-1.1, 1.1), Interval(0.5, 1.7)]) + >>> tree.remove_envelop(-1.1, 1.5) + >>> tree + IntervalTree([Interval(0.5, 1.7)]) + >>> tree.remove_envelop(0.5, 1.5) + >>> tree + IntervalTree([Interval(0.5, 1.7)]) + >>> tree.remove_envelop(0.5, 1.7) + >>> tree + IntervalTree() + + Point queries:: + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> assert tree[-1.1] == set([Interval(-1.1, 1.1)]) + >>> assert tree.at(1.1) == set([Interval(-0.5, 1.5), Interval(0.5, 1.7)]) # Same as tree[1.1] + >>> assert tree.at(1.5) == set([Interval(0.5, 1.7)]) # Same as tree[1.5] + + Interval overlap queries + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> assert tree.overlap(1.7, 1.8) == set() + >>> assert tree.overlap(1.5, 1.8) == set([Interval(0.5, 1.7)]) + >>> assert tree[1.5:1.8] == set([Interval(0.5, 1.7)]) # same as previous + >>> assert tree.overlap(1.1, 1.8) == set([Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> assert tree[1.1:1.8] == set([Interval(-0.5, 1.5), Interval(0.5, 1.7)]) # same as previous + + Interval envelop queries:: + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> assert tree.envelop(-0.5, 0.5) == set() + >>> assert tree.envelop(-0.5, 1.5) == set([Interval(-0.5, 1.5)]) + + Membership queries:: + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> Interval(-0.5, 0.5) in tree + False + >>> Interval(-1.1, 1.1) in tree + True + >>> Interval(-1.1, 1.1, "x") in tree + False + >>> tree.overlaps(-1.1) + True + >>> tree.overlaps(1.7) + False + >>> tree.overlaps(1.7, 1.8) + False + >>> tree.overlaps(-1.2, -1.1) + False + >>> tree.overlaps(-1.2, -1.0) + True + + Sizing:: + + >>> tree = IntervalTree([Interval(-1.1, 1.1), Interval(-0.5, 1.5), Interval(0.5, 1.7)]) + >>> len(tree) + 3 + >>> tree.is_empty() + False + >>> IntervalTree().is_empty() + True + >>> not tree + False + >>> not IntervalTree() + True + >>> print(tree.begin()) # using print() because of floats in Python 2.6 + -1.1 + >>> print(tree.end()) # ditto + 1.7 + + Iteration:: + + >>> tree = IntervalTree([Interval(-11, 11), Interval(-5, 15), Interval(5, 17)]) + >>> [iv.begin for iv in sorted(tree)] + [-11, -5, 5] + >>> assert tree.items() == set([Interval(-5, 15), Interval(-11, 11), Interval(5, 17)]) + + Copy- and typecasting, pickling:: + + >>> tree0 = IntervalTree([Interval(0, 1, "x"), Interval(1, 2, ["x"])]) + >>> tree1 = IntervalTree(tree0) # Shares Interval objects + >>> tree2 = tree0.copy() # Shallow copy (same as above, as Intervals are singletons) + >>> import pickle + >>> tree3 = pickle.loads(pickle.dumps(tree0)) # Deep copy + >>> list(tree0[1])[0].data[0] = "y" # affects shallow copies, but not deep copies + >>> tree0 + IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])]) + >>> tree1 + IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])]) + >>> tree2 + IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['y'])]) + >>> tree3 + IntervalTree([Interval(0, 1, 'x'), Interval(1, 2, ['x'])]) + + Equality testing:: + + >>> IntervalTree([Interval(0, 1)]) == IntervalTree([Interval(0, 1)]) + True + >>> IntervalTree([Interval(0, 1)]) == IntervalTree([Interval(0, 1, "x")]) + False + """ + @classmethod + def from_tuples(cls, tups): + """ + Create a new IntervalTree from an iterable of 2- or 3-tuples, + where the tuple lists begin, end, and optionally data. + """ + ivs = [Interval(*t) for t in tups] + return IntervalTree(ivs) + + def __init__(self, intervals=None): + """ + Set up a tree. If intervals is provided, add all the intervals + to the tree. + + Completes in O(n*log n) time. + """ + intervals = set(intervals) if intervals is not None else set() + for iv in intervals: + if iv.is_null(): + raise ValueError( + "IntervalTree: Null Interval objects not allowed in IntervalTree:" + " {0}".format(iv) + ) + self.all_intervals = intervals + self.top_node = Node.from_intervals(self.all_intervals) + self.boundary_table = SortedDict() + for iv in self.all_intervals: + self._add_boundaries(iv) + + def copy(self): + """ + Construct a new IntervalTree using shallow copies of the + intervals in the source tree. + + Completes in O(n*log n) time. + :rtype: IntervalTree + """ + return IntervalTree(iv.copy() for iv in self) + + def _add_boundaries(self, interval): + """ + Records the boundaries of the interval in the boundary table. + """ + begin = interval.begin + end = interval.end + if begin in self.boundary_table: + self.boundary_table[begin] += 1 + else: + self.boundary_table[begin] = 1 + + if end in self.boundary_table: + self.boundary_table[end] += 1 + else: + self.boundary_table[end] = 1 + + def _remove_boundaries(self, interval): + """ + Removes the boundaries of the interval from the boundary table. + """ + begin = interval.begin + end = interval.end + if self.boundary_table[begin] == 1: + del self.boundary_table[begin] + else: + self.boundary_table[begin] -= 1 + + if self.boundary_table[end] == 1: + del self.boundary_table[end] + else: + self.boundary_table[end] -= 1 + + def add(self, interval): + """ + Adds an interval to the tree, if not already present. + + Completes in O(log n) time. + """ + if interval in self: + return + + if interval.is_null(): + raise ValueError( + "IntervalTree: Null Interval objects not allowed in IntervalTree:" + " {0}".format(interval) + ) + + if not self.top_node: + self.top_node = Node.from_interval(interval) + else: + self.top_node = self.top_node.add(interval) + self.all_intervals.add(interval) + self._add_boundaries(interval) + append = add + + def addi(self, begin, end, data=None): + """ + Shortcut for add(Interval(begin, end, data)). + + Completes in O(log n) time. + """ + return self.add(Interval(begin, end, data)) + appendi = addi + + def update(self, intervals): + """ + Given an iterable of intervals, add them to the tree. + + Completes in O(m*log(n+m), where m = number of intervals to + add. + """ + for iv in intervals: + self.add(iv) + + def remove(self, interval): + """ + Removes an interval from the tree, if present. If not, raises + ValueError. + + Completes in O(log n) time. + """ + #self.verify() + if interval not in self: + #print(self.all_intervals) + raise ValueError + self.top_node = self.top_node.remove(interval) + self.all_intervals.remove(interval) + self._remove_boundaries(interval) + #self.verify() + + def removei(self, begin, end, data=None): + """ + Shortcut for remove(Interval(begin, end, data)). + + Completes in O(log n) time. + """ + return self.remove(Interval(begin, end, data)) + + def discard(self, interval): + """ + Removes an interval from the tree, if present. If not, does + nothing. + + Completes in O(log n) time. + """ + if interval not in self: + return + self.all_intervals.discard(interval) + self.top_node = self.top_node.discard(interval) + self._remove_boundaries(interval) + + def discardi(self, begin, end, data=None): + """ + Shortcut for discard(Interval(begin, end, data)). + + Completes in O(log n) time. + """ + return self.discard(Interval(begin, end, data)) + + def difference(self, other): + """ + Returns a new tree, comprising all intervals in self but not + in other. + """ + ivs = set() + for iv in self: + if iv not in other: + ivs.add(iv) + return IntervalTree(ivs) + + def difference_update(self, other): + """ + Removes all intervals in other from self. + """ + for iv in other: + self.discard(iv) + + def union(self, other): + """ + Returns a new tree, comprising all intervals from self + and other. + """ + return IntervalTree(set(self).union(other)) + + def intersection(self, other): + """ + Returns a new tree of all intervals common to both self and + other. + """ + ivs = set() + shorter, longer = sorted([self, other], key=len) + for iv in shorter: + if iv in longer: + ivs.add(iv) + return IntervalTree(ivs) + + def intersection_update(self, other): + """ + Removes intervals from self unless they also exist in other. + """ + ivs = list(self) + for iv in ivs: + if iv not in other: + self.remove(iv) + + def symmetric_difference(self, other): + """ + Return a tree with elements only in self or other but not + both. + """ + if not isinstance(other, set): other = set(other) + me = set(self) + ivs = me.difference(other).union(other.difference(me)) + return IntervalTree(ivs) + + def symmetric_difference_update(self, other): + """ + Throws out all intervals except those only in self or other, + not both. + """ + other = set(other) + ivs = list(self) + for iv in ivs: + if iv in other: + self.remove(iv) + other.remove(iv) + self.update(other) + + def remove_overlap(self, begin, end=None): + """ + Removes all intervals overlapping the given point or range. + + Completes in O((r+m)*log n) time, where: + * n = size of the tree + * m = number of matches + * r = size of the search range (this is 1 for a point) + """ + hitlist = self.at(begin) if end is None else self.overlap(begin, end) + for iv in hitlist: + self.remove(iv) + + def remove_envelop(self, begin, end): + """ + Removes all intervals completely enveloped in the given range. + + Completes in O((r+m)*log n) time, where: + * n = size of the tree + * m = number of matches + * r = size of the search range + """ + hitlist = self.envelop(begin, end) + for iv in hitlist: + self.remove(iv) + + def chop(self, begin, end, datafunc=None): + """ + Like remove_envelop(), but trims back Intervals hanging into + the chopped area so that nothing overlaps. + """ + insertions = set() + begin_hits = [iv for iv in self.at(begin) if iv.begin < begin] + end_hits = [iv for iv in self.at(end) if iv.end > end] + + if datafunc: + for iv in begin_hits: + insertions.add(Interval(iv.begin, begin, datafunc(iv, True))) + for iv in end_hits: + insertions.add(Interval(end, iv.end, datafunc(iv, False))) + else: + for iv in begin_hits: + insertions.add(Interval(iv.begin, begin, iv.data)) + for iv in end_hits: + insertions.add(Interval(end, iv.end, iv.data)) + + self.remove_envelop(begin, end) + self.difference_update(begin_hits) + self.difference_update(end_hits) + self.update(insertions) + + def slice(self, point, datafunc=None): + """ + Split Intervals that overlap point into two new Intervals. if + specified, uses datafunc(interval, islower=True/False) to + set the data field of the new Intervals. + :param point: where to slice + :param datafunc(interval, isupper): callable returning a new + value for the interval's data field + """ + hitlist = set(iv for iv in self.at(point) if iv.begin < point) + insertions = set() + if datafunc: + for iv in hitlist: + insertions.add(Interval(iv.begin, point, datafunc(iv, True))) + insertions.add(Interval(point, iv.end, datafunc(iv, False))) + else: + for iv in hitlist: + insertions.add(Interval(iv.begin, point, iv.data)) + insertions.add(Interval(point, iv.end, iv.data)) + self.difference_update(hitlist) + self.update(insertions) + + def clear(self): + """ + Empties the tree. + + Completes in O(1) tine. + """ + self.__init__() + + def find_nested(self): + """ + Returns a dictionary mapping parent intervals to sets of + intervals overlapped by and contained in the parent. + + Completes in O(n^2) time. + :rtype: dict of [Interval, set of Interval] + """ + result = {} + + def add_if_nested(): + if parent.contains_interval(child): + if parent not in result: + result[parent] = set() + result[parent].add(child) + + long_ivs = sorted(self.all_intervals, key=Interval.length, reverse=True) + for i, parent in enumerate(long_ivs): + for child in long_ivs[i + 1:]: + add_if_nested() + return result + + def overlaps(self, begin, end=None): + """ + Returns whether some interval in the tree overlaps the given + point or range. + + Completes in O(r*log n) time, where r is the size of the + search range. + :rtype: bool + """ + if end is not None: + return self.overlaps_range(begin, end) + elif isinstance(begin, Number): + return self.overlaps_point(begin) + else: + return self.overlaps_range(begin.begin, begin.end) + + def overlaps_point(self, p): + """ + Returns whether some interval in the tree overlaps p. + + Completes in O(log n) time. + :rtype: bool + """ + if self.is_empty(): + return False + return bool(self.top_node.contains_point(p)) + + def overlaps_range(self, begin, end): + """ + Returns whether some interval in the tree overlaps the given + range. Returns False if given a null interval over which to + test. + + Completes in O(r*log n) time, where r is the range length and n + is the table size. + :rtype: bool + """ + if self.is_empty(): + return False + elif begin >= end: + return False + elif self.overlaps_point(begin): + return True + return any( + self.overlaps_point(bound) + for bound in self.boundary_table + if begin < bound < end + ) + + def split_overlaps(self): + """ + Finds all intervals with overlapping ranges and splits them + along the range boundaries. + + Completes in worst-case O(n^2*log n) time (many interval + boundaries are inside many intervals), best-case O(n*log n) + time (small number of overlaps << n per interval). + """ + if not self: + return + if len(self.boundary_table) == 2: + return + + bounds = sorted(self.boundary_table) # get bound locations + + new_ivs = set() + for lbound, ubound in zip(bounds[:-1], bounds[1:]): + for iv in self[lbound]: + new_ivs.add(Interval(lbound, ubound, iv.data)) + + self.__init__(new_ivs) + + def merge_overlaps(self, data_reducer=None, data_initializer=None, strict=True): + """ + Finds all intervals with overlapping ranges and merges them + into a single interval. If provided, uses data_reducer and + data_initializer with similar semantics to Python's built-in + reduce(reducer_func[, initializer]), as follows: + + If data_reducer is set to a function, combines the data + fields of the Intervals with + current_reduced_data = data_reducer(current_reduced_data, new_data) + If data_reducer is None, the merged Interval's data + field will be set to None, ignoring all the data fields + of the merged Intervals. + + On encountering the first Interval to merge, if + data_initializer is None (default), uses the first + Interval's data field as the first value for + current_reduced_data. If data_initializer is not None, + current_reduced_data is set to a shallow copy of + data_initializer created with copy.copy(data_initializer). + + If strict is True (default), intervals are only merged if + their ranges actually overlap; adjacent, touching intervals + will not be merged. If strict is False, intervals are merged + even if they are only end-to-end adjacent. + + Completes in O(n*logn). + """ + if not self: + return + + sorted_intervals = sorted(self.all_intervals) # get sorted intervals + merged = [] + # use mutable object to allow new_series() to modify it + current_reduced = [None] + higher = None # iterating variable, which new_series() needs access to + + def new_series(): + if data_initializer is None: + current_reduced[0] = higher.data + merged.append(higher) + return + else: # data_initializer is not None + current_reduced[0] = copy(data_initializer) + current_reduced[0] = data_reducer(current_reduced[0], higher.data) + merged.append(Interval(higher.begin, higher.end, current_reduced[0])) + + for higher in sorted_intervals: + if merged: # series already begun + lower = merged[-1] + if (higher.begin < lower.end or + not strict and higher.begin == lower.end): # should merge + upper_bound = max(lower.end, higher.end) + if data_reducer is not None: + current_reduced[0] = data_reducer(current_reduced[0], higher.data) + else: # annihilate the data, since we don't know how to merge it + current_reduced[0] = None + merged[-1] = Interval(lower.begin, upper_bound, current_reduced[0]) + else: + new_series() + else: # not merged; is first of Intervals to merge + new_series() + + self.__init__(merged) + + def merge_equals(self, data_reducer=None, data_initializer=None): + """ + Finds all intervals with equal ranges and merges them + into a single interval. If provided, uses data_reducer and + data_initializer with similar semantics to Python's built-in + reduce(reducer_func[, initializer]), as follows: + + If data_reducer is set to a function, combines the data + fields of the Intervals with + current_reduced_data = data_reducer(current_reduced_data, new_data) + If data_reducer is None, the merged Interval's data + field will be set to None, ignoring all the data fields + of the merged Intervals. + + On encountering the first Interval to merge, if + data_initializer is None (default), uses the first + Interval's data field as the first value for + current_reduced_data. If data_initializer is not None, + current_reduced_data is set to a shallow copy of + data_initiazer created with + copy.copy(data_initializer). + + Completes in O(n*logn). + """ + if not self: + return + + sorted_intervals = sorted(self.all_intervals) # get sorted intervals + merged = [] + # use mutable object to allow new_series() to modify it + current_reduced = [None] + higher = None # iterating variable, which new_series() needs access to + + def new_series(): + if data_initializer is None: + current_reduced[0] = higher.data + merged.append(higher) + return + else: # data_initializer is not None + current_reduced[0] = copy(data_initializer) + current_reduced[0] = data_reducer(current_reduced[0], higher.data) + merged.append(Interval(higher.begin, higher.end, current_reduced[0])) + + for higher in sorted_intervals: + if merged: # series already begun + lower = merged[-1] + if higher.range_matches(lower): # should merge + upper_bound = max(lower.end, higher.end) + if data_reducer is not None: + current_reduced[0] = data_reducer(current_reduced[0], higher.data) + else: # annihilate the data, since we don't know how to merge it + current_reduced[0] = None + merged[-1] = Interval(lower.begin, upper_bound, current_reduced[0]) + else: + new_series() + else: # not merged; is first of Intervals to merge + new_series() + + self.__init__(merged) + + def items(self): + """ + Constructs and returns a set of all intervals in the tree. + + Completes in O(n) time. + :rtype: set of Interval + """ + return set(self.all_intervals) + + def is_empty(self): + """ + Returns whether the tree is empty. + + Completes in O(1) time. + :rtype: bool + """ + return 0 == len(self) + + def at(self, p): + """ + Returns the set of all intervals that contain p. + + Completes in O(m + log n) time, where: + * n = size of the tree + * m = number of matches + :rtype: set of Interval + """ + root = self.top_node + if not root: + return set() + return root.search_point(p, set()) + + def envelop(self, begin, end=None): + """ + Returns the set of all intervals fully contained in the range + [begin, end). + + Completes in O(m + k*log n) time, where: + * n = size of the tree + * m = number of matches + * k = size of the search range + :rtype: set of Interval + """ + root = self.top_node + if not root: + return set() + if end is None: + iv = begin + return self.envelop(iv.begin, iv.end) + elif begin >= end: + return set() + result = root.search_point(begin, set()) # bound_begin might be greater + boundary_table = self.boundary_table + bound_begin = boundary_table.bisect_left(begin) + bound_end = boundary_table.bisect_left(end) # up to, but not including end + result.update(root.search_overlap( + # slice notation is slightly slower + boundary_table.keys()[index] for index in xrange(bound_begin, bound_end) + )) + + # TODO: improve envelop() to use node info instead of less-efficient filtering + result = set( + iv for iv in result + if iv.begin >= begin and iv.end <= end + ) + return result + + def overlap(self, begin, end=None): + """ + Returns a set of all intervals overlapping the given range. + + Completes in O(m + k*log n) time, where: + * n = size of the tree + * m = number of matches + * k = size of the search range + :rtype: set of Interval + """ + root = self.top_node + if not root: + return set() + if end is None: + iv = begin + return self.overlap(iv.begin, iv.end) + elif begin >= end: + return set() + result = root.search_point(begin, set()) # bound_begin might be greater + boundary_table = self.boundary_table + bound_begin = boundary_table.bisect_left(begin) + bound_end = boundary_table.bisect_left(end) # up to, but not including end + result.update(root.search_overlap( + # slice notation is slightly slower + boundary_table.keys()[index] for index in xrange(bound_begin, bound_end) + )) + return result + + def begin(self): + """ + Returns the lower bound of the first interval in the tree. + + Completes in O(1) time. + """ + if not self.boundary_table: + return 0 + return self.boundary_table.keys()[0] + + def end(self): + """ + Returns the upper bound of the last interval in the tree. + + Completes in O(1) time. + """ + if not self.boundary_table: + return 0 + return self.boundary_table.keys()[-1] + + def range(self): + """ + Returns a minimum-spanning Interval that encloses all the + members of this IntervalTree. If the tree is empty, returns + null Interval. + :rtype: Interval + """ + return Interval(self.begin(), self.end()) + + def span(self): + """ + Returns the length of the minimum-spanning Interval that + encloses all the members of this IntervalTree. If the tree + is empty, return 0. + """ + if not self: + return 0 + return self.end() - self.begin() + + def print_structure(self, tostring=False): + """ + ## FOR DEBUGGING ONLY ## + Pretty-prints the structure of the tree. + If tostring is true, prints nothing and returns a string. + :rtype: None or str + """ + if self.top_node: + return self.top_node.print_structure(tostring=tostring) + else: + result = "" + if not tostring: + print(result) + else: + return result + + def verify(self): + """ + ## FOR DEBUGGING ONLY ## + Checks the table to ensure that the invariants are held. + """ + if self.all_intervals: + ## top_node.all_children() == self.all_intervals + try: + assert self.top_node.all_children() == self.all_intervals + except AssertionError as e: + print( + 'Error: the tree and the membership set are out of sync!' + ) + tivs = set(self.top_node.all_children()) + print('top_node.all_children() - all_intervals:') + try: + pprint + except NameError: + from pprint import pprint + pprint(tivs - self.all_intervals) + print('all_intervals - top_node.all_children():') + pprint(self.all_intervals - tivs) + raise e + + ## All members are Intervals + for iv in self: + assert isinstance(iv, Interval), ( + "Error: Only Interval objects allowed in IntervalTree:" + " {0}".format(iv) + ) + + ## No null intervals + for iv in self: + assert not iv.is_null(), ( + "Error: Null Interval objects not allowed in IntervalTree:" + " {0}".format(iv) + ) + + ## Reconstruct boundary_table + bound_check = {} + for iv in self: + if iv.begin in bound_check: + bound_check[iv.begin] += 1 + else: + bound_check[iv.begin] = 1 + if iv.end in bound_check: + bound_check[iv.end] += 1 + else: + bound_check[iv.end] = 1 + + ## Reconstructed boundary table (bound_check) ==? boundary_table + assert set(self.boundary_table.keys()) == set(bound_check.keys()),\ + 'Error: boundary_table is out of sync with ' \ + 'the intervals in the tree!' + + # For efficiency reasons this should be iteritems in Py2, but we + # don't care much for efficiency in debug methods anyway. + for key, val in self.boundary_table.items(): + assert bound_check[key] == val, \ + 'Error: boundary_table[{0}] should be {1},' \ + ' but is {2}!'.format( + key, bound_check[key], val) + + ## Internal tree structure + self.top_node.verify(set()) + else: + ## Verify empty tree + assert not self.boundary_table, \ + "Error: boundary table should be empty!" + assert self.top_node is None, \ + "Error: top_node isn't None!" + + def score(self, full_report=False): + """ + Returns a number between 0 and 1, indicating how suboptimal the tree + is. The lower, the better. Roughly, this number represents the + fraction of flawed Intervals in the tree. + :rtype: float + """ + if len(self) <= 2: + return 0.0 + + n = len(self) + m = self.top_node.count_nodes() + + def s_center_score(): + """ + Returns a normalized score, indicating roughly how many times + intervals share s_center with other intervals. Output is full-scale + from 0 to 1. + :rtype: float + """ + raw = n - m + maximum = n - 1 + return raw / float(maximum) + + report = { + "depth": self.top_node.depth_score(n, m), + "s_center": s_center_score(), + } + cumulative = max(report.values()) + report["_cumulative"] = cumulative + if full_report: + return report + return cumulative + + + def __getitem__(self, index): + """ + Returns a set of all intervals overlapping the given index or + slice. + + Completes in O(k * log(n) + m) time, where: + * n = size of the tree + * m = number of matches + * k = size of the search range (this is 1 for a point) + :rtype: set of Interval + """ + try: + start, stop = index.start, index.stop + if start is None: + start = self.begin() + if stop is None: + return set(self) + if stop is None: + stop = self.end() + return self.overlap(start, stop) + except AttributeError: + return self.at(index) + + def __setitem__(self, index, value): + """ + Adds a new interval to the tree. A shortcut for + add(Interval(index.start, index.stop, value)). + + If an identical Interval object with equal range and data + already exists, does nothing. + + Completes in O(log n) time. + """ + self.addi(index.start, index.stop, value) + + def __delitem__(self, point): + """ + Delete all items overlapping point. + """ + self.remove_overlap(point) + + def __contains__(self, item): + """ + Returns whether item exists as an Interval in the tree. + This method only returns True for exact matches; for + overlaps, see the overlaps() method. + + Completes in O(1) time. + :rtype: bool + """ + # Removed point-checking code; it might trick the user into + # thinking that this is O(1), which point-checking isn't. + #if isinstance(item, Interval): + return item in self.all_intervals + #else: + # return self.contains_point(item) + + def containsi(self, begin, end, data=None): + """ + Shortcut for (Interval(begin, end, data) in tree). + + Completes in O(1) time. + :rtype: bool + """ + return Interval(begin, end, data) in self + + def __iter__(self): + """ + Returns an iterator over all the intervals in the tree. + + Completes in O(1) time. + :rtype: collections.Iterable[Interval] + """ + return self.all_intervals.__iter__() + iter = __iter__ + + def __len__(self): + """ + Returns how many intervals are in the tree. + + Completes in O(1) time. + :rtype: int + """ + return len(self.all_intervals) + + def __eq__(self, other): + """ + Whether two IntervalTrees are equal. + + Completes in O(n) time if sizes are equal; O(1) time otherwise. + :rtype: bool + """ + return ( + isinstance(other, IntervalTree) and + self.all_intervals == other.all_intervals + ) + + def __repr__(self): + """ + :rtype: str + """ + ivs = sorted(self) + if not ivs: + return "IntervalTree()" + else: + return "IntervalTree({0})".format(ivs) + + __str__ = __repr__ + + def __reduce__(self): + """ + For pickle-ing. + :rtype: tuple + """ + return IntervalTree, (sorted(self.all_intervals),) + diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/node.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/node.py new file mode 100644 index 0000000..9cdbfbb --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/node.py @@ -0,0 +1,590 @@ +""" +intervaltree: A mutable, self-balancing interval tree for Python 2 and 3. +Queries may be by point, by range overlap, or by range envelopment. + +Core logic: internal tree nodes. + +Copyright 2013-2018 Chaim Leib Halbert +Modifications Copyright 2014 Konstantin Tretyakov + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +from operator import attrgetter +from math import floor, log + + +def l2(num): + """ + log base 2 + :rtype real + """ + return log(num, 2) + + +class Node(object): + def __init__(self, + x_center=None, + s_center=set(), + left_node=None, + right_node=None): + self.x_center = x_center + self.s_center = set(s_center) + self.left_node = left_node + self.right_node = right_node + self.depth = 0 # will be set when rotated + self.balance = 0 # ditto + self.rotate() + + @classmethod + def from_interval(cls, interval): + """ + :rtype : Node + """ + center = interval.begin + return Node(center, [interval]) + + @classmethod + def from_intervals(cls, intervals): + """ + :rtype : Node + """ + if not intervals: + return None + node = Node() + node = node.init_from_sorted(sorted(intervals)) + return node + + def init_from_sorted(self, intervals): + # assumes that intervals is a non-empty collection. + # Else, next line raises IndexError + center_iv = intervals[len(intervals) // 2] + self.x_center = center_iv.begin + self.s_center = set() + s_left = [] + s_right = [] + for k in intervals: + if k.end <= self.x_center: + s_left.append(k) + elif k.begin > self.x_center: + s_right.append(k) + else: + self.s_center.add(k) + self.left_node = Node.from_intervals(s_left) + self.right_node = Node.from_intervals(s_right) + return self.rotate() + + def center_hit(self, interval): + """Returns whether interval overlaps self.x_center.""" + return interval.contains_point(self.x_center) + + def hit_branch(self, interval): + """ + Assuming not center_hit(interval), return which branch + (left=0, right=1) interval is in. + """ + return interval.begin > self.x_center + + def refresh_balance(self): + """ + Recalculate self.balance and self.depth based on child node values. + """ + left_depth = self.left_node.depth if self.left_node else 0 + right_depth = self.right_node.depth if self.right_node else 0 + self.depth = 1 + max(left_depth, right_depth) + self.balance = right_depth - left_depth + + def compute_depth(self): + """ + Recursively computes true depth of the subtree. Should only + be needed for debugging. Unless something is wrong, the + depth field should reflect the correct depth of the subtree. + """ + left_depth = self.left_node.compute_depth() if self.left_node else 0 + right_depth = self.right_node.compute_depth() if self.right_node else 0 + return 1 + max(left_depth, right_depth) + + def rotate(self): + """ + Does rotating, if necessary, to balance this node, and + returns the new top node. + """ + self.refresh_balance() + if abs(self.balance) < 2: + return self + # balance > 0 is the heavy side + my_heavy = self.balance > 0 + child_heavy = self[my_heavy].balance > 0 + if my_heavy == child_heavy or self[my_heavy].balance == 0: + ## Heavy sides same + # self save + # save -> 1 self + # 1 + # + ## Heavy side balanced + # self save save + # save -> 1 self -> 1 self.rot() + # 1 2 2 + return self.srotate() + else: + return self.drotate() + + def srotate(self): + """Single rotation. Assumes that balance is +-2.""" + # self save save + # save 3 -> 1 self -> 1 self.rot() + # 1 2 2 3 + # + # self save save + # 3 save -> self 1 -> self.rot() 1 + # 2 1 3 2 + + #assert(self.balance != 0) + heavy = self.balance > 0 + light = not heavy + save = self[heavy] + #print("srotate: bal={},{}".format(self.balance, save.balance)) + #self.print_structure() + self[heavy] = save[light] # 2 + #assert(save[light]) + save[light] = self.rotate() # Needed to ensure the 2 and 3 are balanced under new subnode + + # Some intervals may overlap both self.x_center and save.x_center + # Promote those to the new tip of the tree + promotees = [iv for iv in save[light].s_center if save.center_hit(iv)] + if promotees: + for iv in promotees: + save[light] = save[light].remove(iv) # may trigger pruning + # TODO: Use Node.add() here, to simplify future balancing improvements. + # For now, this is the same as augmenting save.s_center, but that may + # change. + save.s_center.update(promotees) + save.refresh_balance() + return save + + def drotate(self): + # First rotation + my_heavy = self.balance > 0 + self[my_heavy] = self[my_heavy].srotate() + self.refresh_balance() + + # Second rotation + result = self.srotate() + + return result + + def add(self, interval): + """ + Returns self after adding the interval and balancing. + """ + if self.center_hit(interval): + self.s_center.add(interval) + return self + else: + direction = self.hit_branch(interval) + if not self[direction]: + self[direction] = Node.from_interval(interval) + self.refresh_balance() + return self + else: + self[direction] = self[direction].add(interval) + return self.rotate() + + def remove(self, interval): + """ + Returns self after removing the interval and balancing. + + If interval is not present, raise ValueError. + """ + # since this is a list, called methods can set this to [1], + # making it true + done = [] + return self.remove_interval_helper(interval, done, should_raise_error=True) + + def discard(self, interval): + """ + Returns self after removing interval and balancing. + + If interval is not present, do nothing. + """ + done = [] + return self.remove_interval_helper(interval, done, should_raise_error=False) + + def remove_interval_helper(self, interval, done, should_raise_error): + """ + Returns self after removing interval and balancing. + If interval doesn't exist, raise ValueError. + + This method may set done to [1] to tell all callers that + rebalancing has completed. + + See Eternally Confuzzled's jsw_remove_r function (lines 1-32) + in his AVL tree article for reference. + """ + #trace = interval.begin == 347 and interval.end == 353 + #if trace: print('\nRemoving from {} interval {}'.format( + # self.x_center, interval)) + if self.center_hit(interval): + #if trace: print('Hit at {}'.format(self.x_center)) + if not should_raise_error and interval not in self.s_center: + done.append(1) + #if trace: print('Doing nothing.') + return self + try: + # raises error if interval not present - this is + # desired. + self.s_center.remove(interval) + except: + self.print_structure() + raise KeyError(interval) + if self.s_center: # keep this node + done.append(1) # no rebalancing necessary + #if trace: print('Removed, no rebalancing.') + return self + + # If we reach here, no intervals are left in self.s_center. + # So, prune self. + return self.prune() + else: # interval not in s_center + direction = self.hit_branch(interval) + + if not self[direction]: + if should_raise_error: + raise ValueError + done.append(1) + return self + + #if trace: + # print('Descending to {} branch'.format( + # ['left', 'right'][direction] + # )) + self[direction] = self[direction].remove_interval_helper(interval, done, should_raise_error) + + # Clean up + if not done: + #if trace: + # print('Rotating {}'.format(self.x_center)) + # self.print_structure() + return self.rotate() + return self + + def search_overlap(self, point_list): + """ + Returns all intervals that overlap the point_list. + """ + result = set() + for j in point_list: + self.search_point(j, result) + return result + + def search_point(self, point, result): + """ + Returns all intervals that contain point. + """ + for k in self.s_center: + if k.begin <= point < k.end: + result.add(k) + if point < self.x_center and self[0]: + return self[0].search_point(point, result) + elif point > self.x_center and self[1]: + return self[1].search_point(point, result) + return result + + def prune(self): + """ + On a subtree where the root node's s_center is empty, + return a new subtree with no empty s_centers. + """ + if not self[0] or not self[1]: # if I have an empty branch + direction = not self[0] # graft the other branch here + #if trace: + # print('Grafting {} branch'.format( + # 'right' if direction else 'left')) + + result = self[direction] + #if result: result.verify() + return result + else: + # Replace the root node with the greatest predecessor. + heir, self[0] = self[0].pop_greatest_child() + #if trace: + # print('Replacing {} with {}.'.format( + # self.x_center, heir.x_center + # )) + # print('Removed greatest predecessor:') + # self.print_structure() + + #if self[0]: self[0].verify() + #if self[1]: self[1].verify() + + # Set up the heir as the new root node + (heir[0], heir[1]) = (self[0], self[1]) + #if trace: print('Setting up the heir:') + #if trace: heir.print_structure() + + # popping the predecessor may have unbalanced this node; + # fix it + heir.refresh_balance() + heir = heir.rotate() + #heir.verify() + #if trace: print('Rotated the heir:') + #if trace: heir.print_structure() + return heir + + def pop_greatest_child(self): + """ + Used when pruning a node with both a left and a right branch. + Returns (greatest_child, node), where: + * greatest_child is a new node to replace the removed node. + * node is the subtree after: + - removing the greatest child + - balancing + - moving overlapping nodes into greatest_child + + Assumes that self.s_center is not empty. + + See Eternally Confuzzled's jsw_remove_r function (lines 34-54) + in his AVL tree article for reference. + """ + #print('Popping from {}'.format(self.x_center)) + if not self.right_node: # This node is the greatest child. + # To reduce the chances of an overlap with a parent, return + # a child node containing the smallest possible number of + # intervals, as close as possible to the maximum bound. + ivs = sorted(self.s_center, key=attrgetter('end', 'begin')) + max_iv = ivs.pop() + new_x_center = self.x_center + while ivs: + next_max_iv = ivs.pop() + if next_max_iv.end == max_iv.end: continue + new_x_center = max(new_x_center, next_max_iv.end) + def get_new_s_center(): + for iv in self.s_center: + if iv.contains_point(new_x_center): yield iv + + # Create a new node with the largest x_center possible. + child = Node(new_x_center, get_new_s_center()) + self.s_center -= child.s_center + + #print('Pop hit! Returning child = {}'.format( + # child.print_structure(tostring=True) + # )) + #assert not child[0] + #assert not child[1] + + if self.s_center: + #print(' and returning newnode = {}'.format( self )) + #self.verify() + return child, self + else: + #print(' and returning newnode = {}'.format( self[0] )) + #if self[0]: self[0].verify() + return child, self[0] # Rotate left child up + + else: + #print('Pop descent to {}'.format(self[1].x_center)) + (greatest_child, self[1]) = self[1].pop_greatest_child() + + # Move any overlaps into greatest_child + for iv in set(self.s_center): + if iv.contains_point(greatest_child.x_center): + self.s_center.remove(iv) + greatest_child.add(iv) + + #print('Pop Returning child = {}'.format( + # greatest_child.print_structure(tostring=True) + # )) + if self.s_center: + #print('and returning newnode = {}'.format( + # new_self.print_structure(tostring=True) + # )) + #new_self.verify() + self.refresh_balance() + new_self = self.rotate() + return greatest_child, new_self + else: + new_self = self.prune() + #print('and returning prune = {}'.format( + # new_self.print_structure(tostring=True) + # )) + #if new_self: new_self.verify() + return greatest_child, new_self + + def contains_point(self, p): + """ + Returns whether this node or a child overlaps p. + """ + for iv in self.s_center: + if iv.contains_point(p): + return True + branch = self[p > self.x_center] + return branch and branch.contains_point(p) + + def all_children(self): + return self.all_children_helper(set()) + + def all_children_helper(self, result): + result.update(self.s_center) + if self[0]: + self[0].all_children_helper(result) + if self[1]: + self[1].all_children_helper(result) + return result + + def verify(self, parents=set()): + """ + ## DEBUG ONLY ## + Recursively ensures that the invariants of an interval subtree + hold. + """ + assert(isinstance(self.s_center, set)) + + bal = self.balance + assert abs(bal) < 2, \ + "Error: Rotation should have happened, but didn't! \n{}".format( + self.print_structure(tostring=True) + ) + self.refresh_balance() + assert bal == self.balance, \ + "Error: self.balance not set correctly! \n{}".format( + self.print_structure(tostring=True) + ) + + assert self.s_center, \ + "Error: s_center is empty! \n{}".format( + self.print_structure(tostring=True) + ) + for iv in self.s_center: + assert hasattr(iv, 'begin') + assert hasattr(iv, 'end') + assert iv.begin < iv.end + assert iv.overlaps(self.x_center) + for parent in sorted(parents): + assert not iv.contains_point(parent), \ + "Error: Overlaps ancestor ({})! \n{}\n\n{}".format( + parent, iv, self.print_structure(tostring=True) + ) + if self[0]: + assert self[0].x_center < self.x_center, \ + "Error: Out-of-order left child! {}".format(self.x_center) + self[0].verify(parents.union([self.x_center])) + if self[1]: + assert self[1].x_center > self.x_center, \ + "Error: Out-of-order right child! {}".format(self.x_center) + self[1].verify(parents.union([self.x_center])) + + def __getitem__(self, index): + """ + Returns the left child if input is equivalent to False, or + the right side otherwise. + """ + if index: + return self.right_node + else: + return self.left_node + + def __setitem__(self, key, value): + """Sets the left (0) or right (1) child.""" + if key: + self.right_node = value + else: + self.left_node = value + + def __str__(self): + """ + Shows info about this node. + + Since Nodes are internal data structures not revealed to the + user, I'm not bothering to make this copy-paste-executable as a + constructor. + """ + return "Node<{0}, depth={1}, balance={2}>".format( + self.x_center, + self.depth, + self.balance + ) + #fieldcount = 'c_count,has_l,has_r = <{}, {}, {}>'.format( + # len(self.s_center), + # bool(self.left_node), + # bool(self.right_node) + #) + #fields = [self.x_center, self.balance, fieldcount] + #return "Node({}, b={}, {})".format(*fields) + + def count_nodes(self): + """ + Count the number of Nodes in this subtree. + :rtype: int + """ + count = 1 + if self.left_node: + count += self.left_node.count_nodes() + if self.right_node: + count += self.right_node.count_nodes() + return count + + def depth_score(self, n, m): + """ + Calculates flaws in balancing the tree. + :param n: size of tree + :param m: number of Nodes in tree + :rtype: real + """ + if n == 0: + return 0.0 + + # dopt is the optimal maximum depth of the tree + dopt = 1 + int(floor(l2(m))) + f = 1 / float(1 + n - dopt) + return f * self.depth_score_helper(1, dopt) + + def depth_score_helper(self, d, dopt): + """ + Gets a weighted count of the number of Intervals deeper than dopt. + :param d: current depth, starting from 0 + :param dopt: optimal maximum depth of a leaf Node + :rtype: real + """ + # di is how may levels deeper than optimal d is + di = d - dopt + if di > 0: + count = di * len(self.s_center) + else: + count = 0 + if self.right_node: + count += self.right_node.depth_score_helper(d + 1, dopt) + if self.left_node: + count += self.left_node.depth_score_helper(d + 1, dopt) + return count + + def print_structure(self, indent=0, tostring=False): + """ + For debugging. + """ + nl = '\n' + sp = indent * ' ' + + rlist = [str(self) + nl] + if self.s_center: + for iv in sorted(self.s_center): + rlist.append(sp + ' ' + repr(iv) + nl) + if self.left_node: + rlist.append(sp + '<: ') # no CR + rlist.append(self.left_node.print_structure(indent + 1, True)) + if self.right_node: + rlist.append(sp + '>: ') # no CR + rlist.append(self.right_node.print_structure(indent + 1, True)) + result = ''.join(rlist) + if tostring: + return result + else: + print(result) diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__init__.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__init__.py new file mode 100644 index 0000000..a141dd1 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__init__.py @@ -0,0 +1,74 @@ +"""Sorted Containers -- Sorted List, Sorted Dict, Sorted Set + +Sorted Containers is an Apache2 licensed containers library, written in +pure-Python, and fast as C-extensions. + +Python's standard library is great until you need a sorted collections +type. Many will attest that you can get really far without one, but the moment +you **really need** a sorted list, dict, or set, you're faced with a dozen +different implementations, most using C-extensions without great documentation +and benchmarking. + +In Python, we can do better. And we can do it in pure-Python! + +:: + + >>> from sortedcontainers import SortedList + >>> sl = SortedList(['e', 'a', 'c', 'd', 'b']) + >>> sl + SortedList(['a', 'b', 'c', 'd', 'e']) + >>> sl *= 1000000 + >>> sl.count('c') + 1000000 + >>> sl[-3:] + ['e', 'e', 'e'] + >>> from sortedcontainers import SortedDict + >>> sd = SortedDict({'c': 3, 'a': 1, 'b': 2}) + >>> sd + SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> sd.popitem(index=-1) + ('c', 3) + >>> from sortedcontainers import SortedSet + >>> ss = SortedSet('abracadabra') + >>> ss + SortedSet(['a', 'b', 'c', 'd', 'r']) + >>> ss.bisect_left('c') + 2 + +Sorted Containers takes all of the work out of Python sorted types - making +your deployment and use of Python easy. There's no need to install a C compiler +or pre-build and distribute custom extensions. Performance is a feature and +testing has 100% coverage with unit tests and hours of stress. + +:copyright: (c) 2014-2019 by Grant Jenks. +:license: Apache 2.0, see LICENSE for more details. + +""" + + +from .sortedlist import SortedList, SortedKeyList, SortedListWithKey +from .sortedset import SortedSet +from .sorteddict import ( + SortedDict, + SortedKeysView, + SortedItemsView, + SortedValuesView, +) + +__all__ = [ + 'SortedList', + 'SortedKeyList', + 'SortedListWithKey', + 'SortedDict', + 'SortedKeysView', + 'SortedItemsView', + 'SortedValuesView', + 'SortedSet', +] + +__title__ = 'sortedcontainers' +__version__ = '2.4.0' +__build__ = 0x020400 +__author__ = 'Grant Jenks' +__license__ = 'Apache 2.0' +__copyright__ = '2014-2019, Grant Jenks' diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/__init__.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..c8088d3 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/__init__.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sorteddict.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sorteddict.cpython-39.pyc new file mode 100644 index 0000000..6abe457 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sorteddict.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sortedlist.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sortedlist.cpython-39.pyc new file mode 100644 index 0000000..6f33cc6 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sortedlist.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sortedset.cpython-39.pyc b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sortedset.cpython-39.pyc new file mode 100644 index 0000000..a4ab239 Binary files /dev/null and b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/__pycache__/sortedset.cpython-39.pyc differ diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sorteddict.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sorteddict.py new file mode 100644 index 0000000..910f260 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sorteddict.py @@ -0,0 +1,812 @@ +"""Sorted Dict +============== + +:doc:`Sorted Containers` is an Apache2 licensed Python sorted +collections library, written in pure-Python, and fast as C-extensions. The +:doc:`introduction` is the best way to get started. + +Sorted dict implementations: + +.. currentmodule:: sortedcontainers + +* :class:`SortedDict` +* :class:`SortedKeysView` +* :class:`SortedItemsView` +* :class:`SortedValuesView` + +""" + +import sys +import warnings + +from itertools import chain + +from .sortedlist import SortedList, recursive_repr +from .sortedset import SortedSet + +############################################################################### +# BEGIN Python 2/3 Shims +############################################################################### + +try: + from collections.abc import ( + ItemsView, KeysView, Mapping, ValuesView, Sequence + ) +except ImportError: + from collections import ItemsView, KeysView, Mapping, ValuesView, Sequence + +############################################################################### +# END Python 2/3 Shims +############################################################################### + + +class SortedDict(dict): + """Sorted dict is a sorted mutable mapping. + + Sorted dict keys are maintained in sorted order. The design of sorted dict + is simple: sorted dict inherits from dict to store items and maintains a + sorted list of keys. + + Sorted dict keys must be hashable and comparable. The hash and total + ordering of keys must not change while they are stored in the sorted dict. + + Mutable mapping methods: + + * :func:`SortedDict.__getitem__` (inherited from dict) + * :func:`SortedDict.__setitem__` + * :func:`SortedDict.__delitem__` + * :func:`SortedDict.__iter__` + * :func:`SortedDict.__len__` (inherited from dict) + + Methods for adding items: + + * :func:`SortedDict.setdefault` + * :func:`SortedDict.update` + + Methods for removing items: + + * :func:`SortedDict.clear` + * :func:`SortedDict.pop` + * :func:`SortedDict.popitem` + + Methods for looking up items: + + * :func:`SortedDict.__contains__` (inherited from dict) + * :func:`SortedDict.get` (inherited from dict) + * :func:`SortedDict.peekitem` + + Methods for views: + + * :func:`SortedDict.keys` + * :func:`SortedDict.items` + * :func:`SortedDict.values` + + Methods for miscellany: + + * :func:`SortedDict.copy` + * :func:`SortedDict.fromkeys` + * :func:`SortedDict.__reversed__` + * :func:`SortedDict.__eq__` (inherited from dict) + * :func:`SortedDict.__ne__` (inherited from dict) + * :func:`SortedDict.__repr__` + * :func:`SortedDict._check` + + Sorted list methods available (applies to keys): + + * :func:`SortedList.bisect_left` + * :func:`SortedList.bisect_right` + * :func:`SortedList.count` + * :func:`SortedList.index` + * :func:`SortedList.irange` + * :func:`SortedList.islice` + * :func:`SortedList._reset` + + Additional sorted list methods available, if key-function used: + + * :func:`SortedKeyList.bisect_key_left` + * :func:`SortedKeyList.bisect_key_right` + * :func:`SortedKeyList.irange_key` + + Sorted dicts may only be compared for equality and inequality. + + """ + def __init__(self, *args, **kwargs): + """Initialize sorted dict instance. + + Optional key-function argument defines a callable that, like the `key` + argument to the built-in `sorted` function, extracts a comparison key + from each dictionary key. If no function is specified, the default + compares the dictionary keys directly. The key-function argument must + be provided as a positional argument and must come before all other + arguments. + + Optional iterable argument provides an initial sequence of pairs to + initialize the sorted dict. Each pair in the sequence defines the key + and corresponding value. If a key is seen more than once, the last + value associated with it is stored in the new sorted dict. + + Optional mapping argument provides an initial mapping of items to + initialize the sorted dict. + + If keyword arguments are given, the keywords themselves, with their + associated values, are added as items to the dictionary. If a key is + specified both in the positional argument and as a keyword argument, + the value associated with the keyword is stored in the + sorted dict. + + Sorted dict keys must be hashable, per the requirement for Python's + dictionaries. Keys (or the result of the key-function) must also be + comparable, per the requirement for sorted lists. + + >>> d = {'alpha': 1, 'beta': 2} + >>> SortedDict([('alpha', 1), ('beta', 2)]) == d + True + >>> SortedDict({'alpha': 1, 'beta': 2}) == d + True + >>> SortedDict(alpha=1, beta=2) == d + True + + """ + if args and (args[0] is None or callable(args[0])): + _key = self._key = args[0] + args = args[1:] + else: + _key = self._key = None + + self._list = SortedList(key=_key) + + # Reaching through ``self._list`` repeatedly adds unnecessary overhead + # so cache references to sorted list methods. + + _list = self._list + self._list_add = _list.add + self._list_clear = _list.clear + self._list_iter = _list.__iter__ + self._list_reversed = _list.__reversed__ + self._list_pop = _list.pop + self._list_remove = _list.remove + self._list_update = _list.update + + # Expose some sorted list methods publicly. + + self.bisect_left = _list.bisect_left + self.bisect = _list.bisect_right + self.bisect_right = _list.bisect_right + self.index = _list.index + self.irange = _list.irange + self.islice = _list.islice + self._reset = _list._reset + + if _key is not None: + self.bisect_key_left = _list.bisect_key_left + self.bisect_key_right = _list.bisect_key_right + self.bisect_key = _list.bisect_key + self.irange_key = _list.irange_key + + self._update(*args, **kwargs) + + + @property + def key(self): + """Function used to extract comparison key from keys. + + Sorted dict compares keys directly when the key function is none. + + """ + return self._key + + + @property + def iloc(self): + """Cached reference of sorted keys view. + + Deprecated in version 2 of Sorted Containers. Use + :func:`SortedDict.keys` instead. + + """ + # pylint: disable=attribute-defined-outside-init + try: + return self._iloc + except AttributeError: + warnings.warn( + 'sorted_dict.iloc is deprecated.' + ' Use SortedDict.keys() instead.', + DeprecationWarning, + stacklevel=2, + ) + _iloc = self._iloc = SortedKeysView(self) + return _iloc + + + def clear(self): + + """Remove all items from sorted dict. + + Runtime complexity: `O(n)` + + """ + dict.clear(self) + self._list_clear() + + + def __delitem__(self, key): + """Remove item from sorted dict identified by `key`. + + ``sd.__delitem__(key)`` <==> ``del sd[key]`` + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> del sd['b'] + >>> sd + SortedDict({'a': 1, 'c': 3}) + >>> del sd['z'] + Traceback (most recent call last): + ... + KeyError: 'z' + + :param key: `key` for item lookup + :raises KeyError: if key not found + + """ + dict.__delitem__(self, key) + self._list_remove(key) + + + def __iter__(self): + """Return an iterator over the keys of the sorted dict. + + ``sd.__iter__()`` <==> ``iter(sd)`` + + Iterating the sorted dict while adding or deleting items may raise a + :exc:`RuntimeError` or fail to iterate over all keys. + + """ + return self._list_iter() + + + def __reversed__(self): + """Return a reverse iterator over the keys of the sorted dict. + + ``sd.__reversed__()`` <==> ``reversed(sd)`` + + Iterating the sorted dict while adding or deleting items may raise a + :exc:`RuntimeError` or fail to iterate over all keys. + + """ + return self._list_reversed() + + + def __setitem__(self, key, value): + """Store item in sorted dict with `key` and corresponding `value`. + + ``sd.__setitem__(key, value)`` <==> ``sd[key] = value`` + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict() + >>> sd['c'] = 3 + >>> sd['a'] = 1 + >>> sd['b'] = 2 + >>> sd + SortedDict({'a': 1, 'b': 2, 'c': 3}) + + :param key: key for item + :param value: value for item + + """ + if key not in self: + self._list_add(key) + dict.__setitem__(self, key, value) + + _setitem = __setitem__ + + + def __or__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + items = chain(self.items(), other.items()) + return self.__class__(self._key, items) + + + def __ror__(self, other): + if not isinstance(other, Mapping): + return NotImplemented + items = chain(other.items(), self.items()) + return self.__class__(self._key, items) + + + def __ior__(self, other): + self._update(other) + return self + + + def copy(self): + """Return a shallow copy of the sorted dict. + + Runtime complexity: `O(n)` + + :return: new sorted dict + + """ + return self.__class__(self._key, self.items()) + + __copy__ = copy + + + @classmethod + def fromkeys(cls, iterable, value=None): + """Return a new sorted dict initailized from `iterable` and `value`. + + Items in the sorted dict have keys from `iterable` and values equal to + `value`. + + Runtime complexity: `O(n*log(n))` + + :return: new sorted dict + + """ + return cls((key, value) for key in iterable) + + + def keys(self): + """Return new sorted keys view of the sorted dict's keys. + + See :class:`SortedKeysView` for details. + + :return: new sorted keys view + + """ + return SortedKeysView(self) + + + def items(self): + """Return new sorted items view of the sorted dict's items. + + See :class:`SortedItemsView` for details. + + :return: new sorted items view + + """ + return SortedItemsView(self) + + + def values(self): + """Return new sorted values view of the sorted dict's values. + + See :class:`SortedValuesView` for details. + + :return: new sorted values view + + """ + return SortedValuesView(self) + + + if sys.hexversion < 0x03000000: + def __make_raise_attributeerror(original, alternate): + # pylint: disable=no-self-argument + message = ( + 'SortedDict.{original}() is not implemented.' + ' Use SortedDict.{alternate}() instead.' + ).format(original=original, alternate=alternate) + def method(self): + # pylint: disable=missing-docstring,unused-argument + raise AttributeError(message) + method.__name__ = original # pylint: disable=non-str-assignment-to-dunder-name + method.__doc__ = message + return property(method) + + iteritems = __make_raise_attributeerror('iteritems', 'items') + iterkeys = __make_raise_attributeerror('iterkeys', 'keys') + itervalues = __make_raise_attributeerror('itervalues', 'values') + viewitems = __make_raise_attributeerror('viewitems', 'items') + viewkeys = __make_raise_attributeerror('viewkeys', 'keys') + viewvalues = __make_raise_attributeerror('viewvalues', 'values') + + + class _NotGiven(object): + # pylint: disable=too-few-public-methods + def __repr__(self): + return '' + + __not_given = _NotGiven() + + def pop(self, key, default=__not_given): + """Remove and return value for item identified by `key`. + + If the `key` is not found then return `default` if given. If `default` + is not given then raise :exc:`KeyError`. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> sd.pop('c') + 3 + >>> sd.pop('z', 26) + 26 + >>> sd.pop('y') + Traceback (most recent call last): + ... + KeyError: 'y' + + :param key: `key` for item + :param default: `default` value if key not found (optional) + :return: value for item + :raises KeyError: if `key` not found and `default` not given + + """ + if key in self: + self._list_remove(key) + return dict.pop(self, key) + else: + if default is self.__not_given: + raise KeyError(key) + return default + + + def popitem(self, index=-1): + """Remove and return ``(key, value)`` pair at `index` from sorted dict. + + Optional argument `index` defaults to -1, the last item in the sorted + dict. Specify ``index=0`` for the first item in the sorted dict. + + If the sorted dict is empty, raises :exc:`KeyError`. + + If the `index` is out of range, raises :exc:`IndexError`. + + Runtime complexity: `O(log(n))` + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> sd.popitem() + ('c', 3) + >>> sd.popitem(0) + ('a', 1) + >>> sd.popitem(100) + Traceback (most recent call last): + ... + IndexError: list index out of range + + :param int index: `index` of item (default -1) + :return: key and value pair + :raises KeyError: if sorted dict is empty + :raises IndexError: if `index` out of range + + """ + if not self: + raise KeyError('popitem(): dictionary is empty') + + key = self._list_pop(index) + value = dict.pop(self, key) + return (key, value) + + + def peekitem(self, index=-1): + """Return ``(key, value)`` pair at `index` in sorted dict. + + Optional argument `index` defaults to -1, the last item in the sorted + dict. Specify ``index=0`` for the first item in the sorted dict. + + Unlike :func:`SortedDict.popitem`, the sorted dict is not modified. + + If the `index` is out of range, raises :exc:`IndexError`. + + Runtime complexity: `O(log(n))` + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> sd.peekitem() + ('c', 3) + >>> sd.peekitem(0) + ('a', 1) + >>> sd.peekitem(100) + Traceback (most recent call last): + ... + IndexError: list index out of range + + :param int index: index of item (default -1) + :return: key and value pair + :raises IndexError: if `index` out of range + + """ + key = self._list[index] + return key, self[key] + + + def setdefault(self, key, default=None): + """Return value for item identified by `key` in sorted dict. + + If `key` is in the sorted dict then return its value. If `key` is not + in the sorted dict then insert `key` with value `default` and return + `default`. + + Optional argument `default` defaults to none. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict() + >>> sd.setdefault('a', 1) + 1 + >>> sd.setdefault('a', 10) + 1 + >>> sd + SortedDict({'a': 1}) + + :param key: key for item + :param default: value for item (default None) + :return: value for item identified by `key` + + """ + if key in self: + return self[key] + dict.__setitem__(self, key, default) + self._list_add(key) + return default + + + def update(self, *args, **kwargs): + """Update sorted dict with items from `args` and `kwargs`. + + Overwrites existing items. + + Optional arguments `args` and `kwargs` may be a mapping, an iterable of + pairs or keyword arguments. See :func:`SortedDict.__init__` for + details. + + :param args: mapping or iterable of pairs + :param kwargs: keyword arguments mapping + + """ + if not self: + dict.update(self, *args, **kwargs) + self._list_update(dict.__iter__(self)) + return + + if not kwargs and len(args) == 1 and isinstance(args[0], dict): + pairs = args[0] + else: + pairs = dict(*args, **kwargs) + + if (10 * len(pairs)) > len(self): + dict.update(self, pairs) + self._list_clear() + self._list_update(dict.__iter__(self)) + else: + for key in pairs: + self._setitem(key, pairs[key]) + + _update = update + + + def __reduce__(self): + """Support for pickle. + + The tricks played with caching references in + :func:`SortedDict.__init__` confuse pickle so customize the reducer. + + """ + items = dict.copy(self) + return (type(self), (self._key, items)) + + + @recursive_repr() + def __repr__(self): + """Return string representation of sorted dict. + + ``sd.__repr__()`` <==> ``repr(sd)`` + + :return: string representation + + """ + _key = self._key + type_name = type(self).__name__ + key_arg = '' if _key is None else '{0!r}, '.format(_key) + item_format = '{0!r}: {1!r}'.format + items = ', '.join(item_format(key, self[key]) for key in self._list) + return '{0}({1}{{{2}}})'.format(type_name, key_arg, items) + + + def _check(self): + """Check invariants of sorted dict. + + Runtime complexity: `O(n)` + + """ + _list = self._list + _list._check() + assert len(self) == len(_list) + assert all(key in self for key in _list) + + +def _view_delitem(self, index): + """Remove item at `index` from sorted dict. + + ``view.__delitem__(index)`` <==> ``del view[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> view = sd.keys() + >>> del view[0] + >>> sd + SortedDict({'b': 2, 'c': 3}) + >>> del view[-1] + >>> sd + SortedDict({'b': 2}) + >>> del view[:] + >>> sd + SortedDict({}) + + :param index: integer or slice for indexing + :raises IndexError: if index out of range + + """ + _mapping = self._mapping + _list = _mapping._list + dict_delitem = dict.__delitem__ + if isinstance(index, slice): + keys = _list[index] + del _list[index] + for key in keys: + dict_delitem(_mapping, key) + else: + key = _list.pop(index) + dict_delitem(_mapping, key) + + +class SortedKeysView(KeysView, Sequence): + """Sorted keys view is a dynamic view of the sorted dict's keys. + + When the sorted dict's keys change, the view reflects those changes. + + The keys view implements the set and sequence abstract base classes. + + """ + __slots__ = () + + + @classmethod + def _from_iterable(cls, it): + return SortedSet(it) + + + def __getitem__(self, index): + """Lookup key at `index` in sorted keys views. + + ``skv.__getitem__(index)`` <==> ``skv[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> skv = sd.keys() + >>> skv[0] + 'a' + >>> skv[-1] + 'c' + >>> skv[:] + ['a', 'b', 'c'] + >>> skv[100] + Traceback (most recent call last): + ... + IndexError: list index out of range + + :param index: integer or slice for indexing + :return: key or list of keys + :raises IndexError: if index out of range + + """ + return self._mapping._list[index] + + + __delitem__ = _view_delitem + + +class SortedItemsView(ItemsView, Sequence): + """Sorted items view is a dynamic view of the sorted dict's items. + + When the sorted dict's items change, the view reflects those changes. + + The items view implements the set and sequence abstract base classes. + + """ + __slots__ = () + + + @classmethod + def _from_iterable(cls, it): + return SortedSet(it) + + + def __getitem__(self, index): + """Lookup item at `index` in sorted items view. + + ``siv.__getitem__(index)`` <==> ``siv[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> siv = sd.items() + >>> siv[0] + ('a', 1) + >>> siv[-1] + ('c', 3) + >>> siv[:] + [('a', 1), ('b', 2), ('c', 3)] + >>> siv[100] + Traceback (most recent call last): + ... + IndexError: list index out of range + + :param index: integer or slice for indexing + :return: item or list of items + :raises IndexError: if index out of range + + """ + _mapping = self._mapping + _mapping_list = _mapping._list + + if isinstance(index, slice): + keys = _mapping_list[index] + return [(key, _mapping[key]) for key in keys] + + key = _mapping_list[index] + return key, _mapping[key] + + + __delitem__ = _view_delitem + + +class SortedValuesView(ValuesView, Sequence): + """Sorted values view is a dynamic view of the sorted dict's values. + + When the sorted dict's values change, the view reflects those changes. + + The values view implements the sequence abstract base class. + + """ + __slots__ = () + + + def __getitem__(self, index): + """Lookup value at `index` in sorted values view. + + ``siv.__getitem__(index)`` <==> ``siv[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sd = SortedDict({'a': 1, 'b': 2, 'c': 3}) + >>> svv = sd.values() + >>> svv[0] + 1 + >>> svv[-1] + 3 + >>> svv[:] + [1, 2, 3] + >>> svv[100] + Traceback (most recent call last): + ... + IndexError: list index out of range + + :param index: integer or slice for indexing + :return: value or list of values + :raises IndexError: if index out of range + + """ + _mapping = self._mapping + _mapping_list = _mapping._list + + if isinstance(index, slice): + keys = _mapping_list[index] + return [_mapping[key] for key in keys] + + key = _mapping_list[index] + return _mapping[key] + + + __delitem__ = _view_delitem diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sortedlist.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sortedlist.py new file mode 100644 index 0000000..e3b58eb --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sortedlist.py @@ -0,0 +1,2646 @@ +"""Sorted List +============== + +:doc:`Sorted Containers` is an Apache2 licensed Python sorted +collections library, written in pure-Python, and fast as C-extensions. The +:doc:`introduction` is the best way to get started. + +Sorted list implementations: + +.. currentmodule:: sortedcontainers + +* :class:`SortedList` +* :class:`SortedKeyList` + +""" +# pylint: disable=too-many-lines +from __future__ import print_function + +import sys +import traceback + +from bisect import bisect_left, bisect_right, insort +from itertools import chain, repeat, starmap +from math import log +from operator import add, eq, ne, gt, ge, lt, le, iadd +from textwrap import dedent + +############################################################################### +# BEGIN Python 2/3 Shims +############################################################################### + +try: + from collections.abc import Sequence, MutableSequence +except ImportError: + from collections import Sequence, MutableSequence + +from functools import wraps +from sys import hexversion + +if hexversion < 0x03000000: + from itertools import imap as map # pylint: disable=redefined-builtin + from itertools import izip as zip # pylint: disable=redefined-builtin + try: + from thread import get_ident + except ImportError: + from dummy_thread import get_ident +else: + from functools import reduce + try: + from _thread import get_ident + except ImportError: + from _dummy_thread import get_ident + + +def recursive_repr(fillvalue='...'): + "Decorator to make a repr function return fillvalue for a recursive call." + # pylint: disable=missing-docstring + # Copied from reprlib in Python 3 + # https://hg.python.org/cpython/file/3.6/Lib/reprlib.py + + def decorating_function(user_function): + repr_running = set() + + @wraps(user_function) + def wrapper(self): + key = id(self), get_ident() + if key in repr_running: + return fillvalue + repr_running.add(key) + try: + result = user_function(self) + finally: + repr_running.discard(key) + return result + + return wrapper + + return decorating_function + +############################################################################### +# END Python 2/3 Shims +############################################################################### + + +class SortedList(MutableSequence): + """Sorted list is a sorted mutable sequence. + + Sorted list values are maintained in sorted order. + + Sorted list values must be comparable. The total ordering of values must + not change while they are stored in the sorted list. + + Methods for adding values: + + * :func:`SortedList.add` + * :func:`SortedList.update` + * :func:`SortedList.__add__` + * :func:`SortedList.__iadd__` + * :func:`SortedList.__mul__` + * :func:`SortedList.__imul__` + + Methods for removing values: + + * :func:`SortedList.clear` + * :func:`SortedList.discard` + * :func:`SortedList.remove` + * :func:`SortedList.pop` + * :func:`SortedList.__delitem__` + + Methods for looking up values: + + * :func:`SortedList.bisect_left` + * :func:`SortedList.bisect_right` + * :func:`SortedList.count` + * :func:`SortedList.index` + * :func:`SortedList.__contains__` + * :func:`SortedList.__getitem__` + + Methods for iterating values: + + * :func:`SortedList.irange` + * :func:`SortedList.islice` + * :func:`SortedList.__iter__` + * :func:`SortedList.__reversed__` + + Methods for miscellany: + + * :func:`SortedList.copy` + * :func:`SortedList.__len__` + * :func:`SortedList.__repr__` + * :func:`SortedList._check` + * :func:`SortedList._reset` + + Sorted lists use lexicographical ordering semantics when compared to other + sequences. + + Some methods of mutable sequences are not supported and will raise + not-implemented error. + + """ + DEFAULT_LOAD_FACTOR = 1000 + + + def __init__(self, iterable=None, key=None): + """Initialize sorted list instance. + + Optional `iterable` argument provides an initial iterable of values to + initialize the sorted list. + + Runtime complexity: `O(n*log(n))` + + >>> sl = SortedList() + >>> sl + SortedList([]) + >>> sl = SortedList([3, 1, 2, 5, 4]) + >>> sl + SortedList([1, 2, 3, 4, 5]) + + :param iterable: initial values (optional) + + """ + assert key is None + self._len = 0 + self._load = self.DEFAULT_LOAD_FACTOR + self._lists = [] + self._maxes = [] + self._index = [] + self._offset = 0 + + if iterable is not None: + self._update(iterable) + + + def __new__(cls, iterable=None, key=None): + """Create new sorted list or sorted-key list instance. + + Optional `key`-function argument will return an instance of subtype + :class:`SortedKeyList`. + + >>> sl = SortedList() + >>> isinstance(sl, SortedList) + True + >>> sl = SortedList(key=lambda x: -x) + >>> isinstance(sl, SortedList) + True + >>> isinstance(sl, SortedKeyList) + True + + :param iterable: initial values (optional) + :param key: function used to extract comparison key (optional) + :return: sorted list or sorted-key list instance + + """ + # pylint: disable=unused-argument + if key is None: + return object.__new__(cls) + else: + if cls is SortedList: + return object.__new__(SortedKeyList) + else: + raise TypeError('inherit SortedKeyList for key argument') + + + @property + def key(self): # pylint: disable=useless-return + """Function used to extract comparison key from values. + + Sorted list compares values directly so the key function is none. + + """ + return None + + + def _reset(self, load): + """Reset sorted list load factor. + + The `load` specifies the load-factor of the list. The default load + factor of 1000 works well for lists from tens to tens-of-millions of + values. Good practice is to use a value that is the cube root of the + list size. With billions of elements, the best load factor depends on + your usage. It's best to leave the load factor at the default until you + start benchmarking. + + See :doc:`implementation` and :doc:`performance-scale` for more + information. + + Runtime complexity: `O(n)` + + :param int load: load-factor for sorted list sublists + + """ + values = reduce(iadd, self._lists, []) + self._clear() + self._load = load + self._update(values) + + + def clear(self): + """Remove all values from sorted list. + + Runtime complexity: `O(n)` + + """ + self._len = 0 + del self._lists[:] + del self._maxes[:] + del self._index[:] + self._offset = 0 + + _clear = clear + + + def add(self, value): + """Add `value` to sorted list. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList() + >>> sl.add(3) + >>> sl.add(1) + >>> sl.add(2) + >>> sl + SortedList([1, 2, 3]) + + :param value: value to add to sorted list + + """ + _lists = self._lists + _maxes = self._maxes + + if _maxes: + pos = bisect_right(_maxes, value) + + if pos == len(_maxes): + pos -= 1 + _lists[pos].append(value) + _maxes[pos] = value + else: + insort(_lists[pos], value) + + self._expand(pos) + else: + _lists.append([value]) + _maxes.append(value) + + self._len += 1 + + + def _expand(self, pos): + """Split sublists with length greater than double the load-factor. + + Updates the index when the sublist length is less than double the load + level. This requires incrementing the nodes in a traversal from the + leaf node to the root. For an example traversal see + ``SortedList._loc``. + + """ + _load = self._load + _lists = self._lists + _index = self._index + + if len(_lists[pos]) > (_load << 1): + _maxes = self._maxes + + _lists_pos = _lists[pos] + half = _lists_pos[_load:] + del _lists_pos[_load:] + _maxes[pos] = _lists_pos[-1] + + _lists.insert(pos + 1, half) + _maxes.insert(pos + 1, half[-1]) + + del _index[:] + else: + if _index: + child = self._offset + pos + while child: + _index[child] += 1 + child = (child - 1) >> 1 + _index[0] += 1 + + + def update(self, iterable): + """Update sorted list by adding all values from `iterable`. + + Runtime complexity: `O(k*log(n))` -- approximate. + + >>> sl = SortedList() + >>> sl.update([3, 1, 2]) + >>> sl + SortedList([1, 2, 3]) + + :param iterable: iterable of values to add + + """ + _lists = self._lists + _maxes = self._maxes + values = sorted(iterable) + + if _maxes: + if len(values) * 4 >= self._len: + _lists.append(values) + values = reduce(iadd, _lists, []) + values.sort() + self._clear() + else: + _add = self.add + for val in values: + _add(val) + return + + _load = self._load + _lists.extend(values[pos:(pos + _load)] + for pos in range(0, len(values), _load)) + _maxes.extend(sublist[-1] for sublist in _lists) + self._len = len(values) + del self._index[:] + + _update = update + + + def __contains__(self, value): + """Return true if `value` is an element of the sorted list. + + ``sl.__contains__(value)`` <==> ``value in sl`` + + Runtime complexity: `O(log(n))` + + >>> sl = SortedList([1, 2, 3, 4, 5]) + >>> 3 in sl + True + + :param value: search for value in sorted list + :return: true if `value` in sorted list + + """ + _maxes = self._maxes + + if not _maxes: + return False + + pos = bisect_left(_maxes, value) + + if pos == len(_maxes): + return False + + _lists = self._lists + idx = bisect_left(_lists[pos], value) + + return _lists[pos][idx] == value + + + def discard(self, value): + """Remove `value` from sorted list if it is a member. + + If `value` is not a member, do nothing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList([1, 2, 3, 4, 5]) + >>> sl.discard(5) + >>> sl.discard(0) + >>> sl == [1, 2, 3, 4] + True + + :param value: `value` to discard from sorted list + + """ + _maxes = self._maxes + + if not _maxes: + return + + pos = bisect_left(_maxes, value) + + if pos == len(_maxes): + return + + _lists = self._lists + idx = bisect_left(_lists[pos], value) + + if _lists[pos][idx] == value: + self._delete(pos, idx) + + + def remove(self, value): + """Remove `value` from sorted list; `value` must be a member. + + If `value` is not a member, raise ValueError. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList([1, 2, 3, 4, 5]) + >>> sl.remove(5) + >>> sl == [1, 2, 3, 4] + True + >>> sl.remove(0) + Traceback (most recent call last): + ... + ValueError: 0 not in list + + :param value: `value` to remove from sorted list + :raises ValueError: if `value` is not in sorted list + + """ + _maxes = self._maxes + + if not _maxes: + raise ValueError('{0!r} not in list'.format(value)) + + pos = bisect_left(_maxes, value) + + if pos == len(_maxes): + raise ValueError('{0!r} not in list'.format(value)) + + _lists = self._lists + idx = bisect_left(_lists[pos], value) + + if _lists[pos][idx] == value: + self._delete(pos, idx) + else: + raise ValueError('{0!r} not in list'.format(value)) + + + def _delete(self, pos, idx): + """Delete value at the given `(pos, idx)`. + + Combines lists that are less than half the load level. + + Updates the index when the sublist length is more than half the load + level. This requires decrementing the nodes in a traversal from the + leaf node to the root. For an example traversal see + ``SortedList._loc``. + + :param int pos: lists index + :param int idx: sublist index + + """ + _lists = self._lists + _maxes = self._maxes + _index = self._index + + _lists_pos = _lists[pos] + + del _lists_pos[idx] + self._len -= 1 + + len_lists_pos = len(_lists_pos) + + if len_lists_pos > (self._load >> 1): + _maxes[pos] = _lists_pos[-1] + + if _index: + child = self._offset + pos + while child > 0: + _index[child] -= 1 + child = (child - 1) >> 1 + _index[0] -= 1 + elif len(_lists) > 1: + if not pos: + pos += 1 + + prev = pos - 1 + _lists[prev].extend(_lists[pos]) + _maxes[prev] = _lists[prev][-1] + + del _lists[pos] + del _maxes[pos] + del _index[:] + + self._expand(prev) + elif len_lists_pos: + _maxes[pos] = _lists_pos[-1] + else: + del _lists[pos] + del _maxes[pos] + del _index[:] + + + def _loc(self, pos, idx): + """Convert an index pair (lists index, sublist index) into a single + index number that corresponds to the position of the value in the + sorted list. + + Many queries require the index be built. Details of the index are + described in ``SortedList._build_index``. + + Indexing requires traversing the tree from a leaf node to the root. The + parent of each node is easily computable at ``(pos - 1) // 2``. + + Left-child nodes are always at odd indices and right-child nodes are + always at even indices. + + When traversing up from a right-child node, increment the total by the + left-child node. + + The final index is the sum from traversal and the index in the sublist. + + For example, using the index from ``SortedList._build_index``:: + + _index = 14 5 9 3 2 4 5 + _offset = 3 + + Tree:: + + 14 + 5 9 + 3 2 4 5 + + Converting an index pair (2, 3) into a single index involves iterating + like so: + + 1. Starting at the leaf node: offset + alpha = 3 + 2 = 5. We identify + the node as a left-child node. At such nodes, we simply traverse to + the parent. + + 2. At node 9, position 2, we recognize the node as a right-child node + and accumulate the left-child in our total. Total is now 5 and we + traverse to the parent at position 0. + + 3. Iteration ends at the root. + + The index is then the sum of the total and sublist index: 5 + 3 = 8. + + :param int pos: lists index + :param int idx: sublist index + :return: index in sorted list + + """ + if not pos: + return idx + + _index = self._index + + if not _index: + self._build_index() + + total = 0 + + # Increment pos to point in the index to len(self._lists[pos]). + + pos += self._offset + + # Iterate until reaching the root of the index tree at pos = 0. + + while pos: + + # Right-child nodes are at odd indices. At such indices + # account the total below the left child node. + + if not pos & 1: + total += _index[pos - 1] + + # Advance pos to the parent node. + + pos = (pos - 1) >> 1 + + return total + idx + + + def _pos(self, idx): + """Convert an index into an index pair (lists index, sublist index) + that can be used to access the corresponding lists position. + + Many queries require the index be built. Details of the index are + described in ``SortedList._build_index``. + + Indexing requires traversing the tree to a leaf node. Each node has two + children which are easily computable. Given an index, pos, the + left-child is at ``pos * 2 + 1`` and the right-child is at ``pos * 2 + + 2``. + + When the index is less than the left-child, traversal moves to the + left sub-tree. Otherwise, the index is decremented by the left-child + and traversal moves to the right sub-tree. + + At a child node, the indexing pair is computed from the relative + position of the child node as compared with the offset and the remaining + index. + + For example, using the index from ``SortedList._build_index``:: + + _index = 14 5 9 3 2 4 5 + _offset = 3 + + Tree:: + + 14 + 5 9 + 3 2 4 5 + + Indexing position 8 involves iterating like so: + + 1. Starting at the root, position 0, 8 is compared with the left-child + node (5) which it is greater than. When greater the index is + decremented and the position is updated to the right child node. + + 2. At node 9 with index 3, we again compare the index to the left-child + node with value 4. Because the index is the less than the left-child + node, we simply traverse to the left. + + 3. At node 4 with index 3, we recognize that we are at a leaf node and + stop iterating. + + 4. To compute the sublist index, we subtract the offset from the index + of the leaf node: 5 - 3 = 2. To compute the index in the sublist, we + simply use the index remaining from iteration. In this case, 3. + + The final index pair from our example is (2, 3) which corresponds to + index 8 in the sorted list. + + :param int idx: index in sorted list + :return: (lists index, sublist index) pair + + """ + if idx < 0: + last_len = len(self._lists[-1]) + + if (-idx) <= last_len: + return len(self._lists) - 1, last_len + idx + + idx += self._len + + if idx < 0: + raise IndexError('list index out of range') + elif idx >= self._len: + raise IndexError('list index out of range') + + if idx < len(self._lists[0]): + return 0, idx + + _index = self._index + + if not _index: + self._build_index() + + pos = 0 + child = 1 + len_index = len(_index) + + while child < len_index: + index_child = _index[child] + + if idx < index_child: + pos = child + else: + idx -= index_child + pos = child + 1 + + child = (pos << 1) + 1 + + return (pos - self._offset, idx) + + + def _build_index(self): + """Build a positional index for indexing the sorted list. + + Indexes are represented as binary trees in a dense array notation + similar to a binary heap. + + For example, given a lists representation storing integers:: + + 0: [1, 2, 3] + 1: [4, 5] + 2: [6, 7, 8, 9] + 3: [10, 11, 12, 13, 14] + + The first transformation maps the sub-lists by their length. The + first row of the index is the length of the sub-lists:: + + 0: [3, 2, 4, 5] + + Each row after that is the sum of consecutive pairs of the previous + row:: + + 1: [5, 9] + 2: [14] + + Finally, the index is built by concatenating these lists together:: + + _index = [14, 5, 9, 3, 2, 4, 5] + + An offset storing the start of the first row is also stored:: + + _offset = 3 + + When built, the index can be used for efficient indexing into the list. + See the comment and notes on ``SortedList._pos`` for details. + + """ + row0 = list(map(len, self._lists)) + + if len(row0) == 1: + self._index[:] = row0 + self._offset = 0 + return + + head = iter(row0) + tail = iter(head) + row1 = list(starmap(add, zip(head, tail))) + + if len(row0) & 1: + row1.append(row0[-1]) + + if len(row1) == 1: + self._index[:] = row1 + row0 + self._offset = 1 + return + + size = 2 ** (int(log(len(row1) - 1, 2)) + 1) + row1.extend(repeat(0, size - len(row1))) + tree = [row0, row1] + + while len(tree[-1]) > 1: + head = iter(tree[-1]) + tail = iter(head) + row = list(starmap(add, zip(head, tail))) + tree.append(row) + + reduce(iadd, reversed(tree), self._index) + self._offset = size * 2 - 1 + + + def __delitem__(self, index): + """Remove value at `index` from sorted list. + + ``sl.__delitem__(index)`` <==> ``del sl[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList('abcde') + >>> del sl[2] + >>> sl + SortedList(['a', 'b', 'd', 'e']) + >>> del sl[:2] + >>> sl + SortedList(['d', 'e']) + + :param index: integer or slice for indexing + :raises IndexError: if index out of range + + """ + if isinstance(index, slice): + start, stop, step = index.indices(self._len) + + if step == 1 and start < stop: + if start == 0 and stop == self._len: + return self._clear() + elif self._len <= 8 * (stop - start): + values = self._getitem(slice(None, start)) + if stop < self._len: + values += self._getitem(slice(stop, None)) + self._clear() + return self._update(values) + + indices = range(start, stop, step) + + # Delete items from greatest index to least so + # that the indices remain valid throughout iteration. + + if step > 0: + indices = reversed(indices) + + _pos, _delete = self._pos, self._delete + + for index in indices: + pos, idx = _pos(index) + _delete(pos, idx) + else: + pos, idx = self._pos(index) + self._delete(pos, idx) + + + def __getitem__(self, index): + """Lookup value at `index` in sorted list. + + ``sl.__getitem__(index)`` <==> ``sl[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList('abcde') + >>> sl[1] + 'b' + >>> sl[-1] + 'e' + >>> sl[2:5] + ['c', 'd', 'e'] + + :param index: integer or slice for indexing + :return: value or list of values + :raises IndexError: if index out of range + + """ + _lists = self._lists + + if isinstance(index, slice): + start, stop, step = index.indices(self._len) + + if step == 1 and start < stop: + # Whole slice optimization: start to stop slices the whole + # sorted list. + + if start == 0 and stop == self._len: + return reduce(iadd, self._lists, []) + + start_pos, start_idx = self._pos(start) + start_list = _lists[start_pos] + stop_idx = start_idx + stop - start + + # Small slice optimization: start index and stop index are + # within the start list. + + if len(start_list) >= stop_idx: + return start_list[start_idx:stop_idx] + + if stop == self._len: + stop_pos = len(_lists) - 1 + stop_idx = len(_lists[stop_pos]) + else: + stop_pos, stop_idx = self._pos(stop) + + prefix = _lists[start_pos][start_idx:] + middle = _lists[(start_pos + 1):stop_pos] + result = reduce(iadd, middle, prefix) + result += _lists[stop_pos][:stop_idx] + + return result + + if step == -1 and start > stop: + result = self._getitem(slice(stop + 1, start + 1)) + result.reverse() + return result + + # Return a list because a negative step could + # reverse the order of the items and this could + # be the desired behavior. + + indices = range(start, stop, step) + return list(self._getitem(index) for index in indices) + else: + if self._len: + if index == 0: + return _lists[0][0] + elif index == -1: + return _lists[-1][-1] + else: + raise IndexError('list index out of range') + + if 0 <= index < len(_lists[0]): + return _lists[0][index] + + len_last = len(_lists[-1]) + + if -len_last < index < 0: + return _lists[-1][len_last + index] + + pos, idx = self._pos(index) + return _lists[pos][idx] + + _getitem = __getitem__ + + + def __setitem__(self, index, value): + """Raise not-implemented error. + + ``sl.__setitem__(index, value)`` <==> ``sl[index] = value`` + + :raises NotImplementedError: use ``del sl[index]`` and + ``sl.add(value)`` instead + + """ + message = 'use ``del sl[index]`` and ``sl.add(value)`` instead' + raise NotImplementedError(message) + + + def __iter__(self): + """Return an iterator over the sorted list. + + ``sl.__iter__()`` <==> ``iter(sl)`` + + Iterating the sorted list while adding or deleting values may raise a + :exc:`RuntimeError` or fail to iterate over all values. + + """ + return chain.from_iterable(self._lists) + + + def __reversed__(self): + """Return a reverse iterator over the sorted list. + + ``sl.__reversed__()`` <==> ``reversed(sl)`` + + Iterating the sorted list while adding or deleting values may raise a + :exc:`RuntimeError` or fail to iterate over all values. + + """ + return chain.from_iterable(map(reversed, reversed(self._lists))) + + + def reverse(self): + """Raise not-implemented error. + + Sorted list maintains values in ascending sort order. Values may not be + reversed in-place. + + Use ``reversed(sl)`` for an iterator over values in descending sort + order. + + Implemented to override `MutableSequence.reverse` which provides an + erroneous default implementation. + + :raises NotImplementedError: use ``reversed(sl)`` instead + + """ + raise NotImplementedError('use ``reversed(sl)`` instead') + + + def islice(self, start=None, stop=None, reverse=False): + """Return an iterator that slices sorted list from `start` to `stop`. + + The `start` and `stop` index are treated inclusive and exclusive, + respectively. + + Both `start` and `stop` default to `None` which is automatically + inclusive of the beginning and end of the sorted list. + + When `reverse` is `True` the values are yielded from the iterator in + reverse order; `reverse` defaults to `False`. + + >>> sl = SortedList('abcdefghij') + >>> it = sl.islice(2, 6) + >>> list(it) + ['c', 'd', 'e', 'f'] + + :param int start: start index (inclusive) + :param int stop: stop index (exclusive) + :param bool reverse: yield values in reverse order + :return: iterator + + """ + _len = self._len + + if not _len: + return iter(()) + + start, stop, _ = slice(start, stop).indices(self._len) + + if start >= stop: + return iter(()) + + _pos = self._pos + + min_pos, min_idx = _pos(start) + + if stop == _len: + max_pos = len(self._lists) - 1 + max_idx = len(self._lists[-1]) + else: + max_pos, max_idx = _pos(stop) + + return self._islice(min_pos, min_idx, max_pos, max_idx, reverse) + + + def _islice(self, min_pos, min_idx, max_pos, max_idx, reverse): + """Return an iterator that slices sorted list using two index pairs. + + The index pairs are (min_pos, min_idx) and (max_pos, max_idx), the + first inclusive and the latter exclusive. See `_pos` for details on how + an index is converted to an index pair. + + When `reverse` is `True`, values are yielded from the iterator in + reverse order. + + """ + _lists = self._lists + + if min_pos > max_pos: + return iter(()) + + if min_pos == max_pos: + if reverse: + indices = reversed(range(min_idx, max_idx)) + return map(_lists[min_pos].__getitem__, indices) + + indices = range(min_idx, max_idx) + return map(_lists[min_pos].__getitem__, indices) + + next_pos = min_pos + 1 + + if next_pos == max_pos: + if reverse: + min_indices = range(min_idx, len(_lists[min_pos])) + max_indices = range(max_idx) + return chain( + map(_lists[max_pos].__getitem__, reversed(max_indices)), + map(_lists[min_pos].__getitem__, reversed(min_indices)), + ) + + min_indices = range(min_idx, len(_lists[min_pos])) + max_indices = range(max_idx) + return chain( + map(_lists[min_pos].__getitem__, min_indices), + map(_lists[max_pos].__getitem__, max_indices), + ) + + if reverse: + min_indices = range(min_idx, len(_lists[min_pos])) + sublist_indices = range(next_pos, max_pos) + sublists = map(_lists.__getitem__, reversed(sublist_indices)) + max_indices = range(max_idx) + return chain( + map(_lists[max_pos].__getitem__, reversed(max_indices)), + chain.from_iterable(map(reversed, sublists)), + map(_lists[min_pos].__getitem__, reversed(min_indices)), + ) + + min_indices = range(min_idx, len(_lists[min_pos])) + sublist_indices = range(next_pos, max_pos) + sublists = map(_lists.__getitem__, sublist_indices) + max_indices = range(max_idx) + return chain( + map(_lists[min_pos].__getitem__, min_indices), + chain.from_iterable(sublists), + map(_lists[max_pos].__getitem__, max_indices), + ) + + + def irange(self, minimum=None, maximum=None, inclusive=(True, True), + reverse=False): + """Create an iterator of values between `minimum` and `maximum`. + + Both `minimum` and `maximum` default to `None` which is automatically + inclusive of the beginning and end of the sorted list. + + The argument `inclusive` is a pair of booleans that indicates whether + the minimum and maximum ought to be included in the range, + respectively. The default is ``(True, True)`` such that the range is + inclusive of both minimum and maximum. + + When `reverse` is `True` the values are yielded from the iterator in + reverse order; `reverse` defaults to `False`. + + >>> sl = SortedList('abcdefghij') + >>> it = sl.irange('c', 'f') + >>> list(it) + ['c', 'd', 'e', 'f'] + + :param minimum: minimum value to start iterating + :param maximum: maximum value to stop iterating + :param inclusive: pair of booleans + :param bool reverse: yield values in reverse order + :return: iterator + + """ + _maxes = self._maxes + + if not _maxes: + return iter(()) + + _lists = self._lists + + # Calculate the minimum (pos, idx) pair. By default this location + # will be inclusive in our calculation. + + if minimum is None: + min_pos = 0 + min_idx = 0 + else: + if inclusive[0]: + min_pos = bisect_left(_maxes, minimum) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_left(_lists[min_pos], minimum) + else: + min_pos = bisect_right(_maxes, minimum) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_right(_lists[min_pos], minimum) + + # Calculate the maximum (pos, idx) pair. By default this location + # will be exclusive in our calculation. + + if maximum is None: + max_pos = len(_maxes) - 1 + max_idx = len(_lists[max_pos]) + else: + if inclusive[1]: + max_pos = bisect_right(_maxes, maximum) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_lists[max_pos]) + else: + max_idx = bisect_right(_lists[max_pos], maximum) + else: + max_pos = bisect_left(_maxes, maximum) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_lists[max_pos]) + else: + max_idx = bisect_left(_lists[max_pos], maximum) + + return self._islice(min_pos, min_idx, max_pos, max_idx, reverse) + + + def __len__(self): + """Return the size of the sorted list. + + ``sl.__len__()`` <==> ``len(sl)`` + + :return: size of sorted list + + """ + return self._len + + + def bisect_left(self, value): + """Return an index to insert `value` in the sorted list. + + If the `value` is already present, the insertion point will be before + (to the left of) any existing values. + + Similar to the `bisect` module in the standard library. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList([10, 11, 12, 13, 14]) + >>> sl.bisect_left(12) + 2 + + :param value: insertion index of value in sorted list + :return: index + + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_left(_maxes, value) + + if pos == len(_maxes): + return self._len + + idx = bisect_left(self._lists[pos], value) + return self._loc(pos, idx) + + + def bisect_right(self, value): + """Return an index to insert `value` in the sorted list. + + Similar to `bisect_left`, but if `value` is already present, the + insertion point will be after (to the right of) any existing values. + + Similar to the `bisect` module in the standard library. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList([10, 11, 12, 13, 14]) + >>> sl.bisect_right(12) + 3 + + :param value: insertion index of value in sorted list + :return: index + + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_right(_maxes, value) + + if pos == len(_maxes): + return self._len + + idx = bisect_right(self._lists[pos], value) + return self._loc(pos, idx) + + bisect = bisect_right + _bisect_right = bisect_right + + + def count(self, value): + """Return number of occurrences of `value` in the sorted list. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList([1, 2, 2, 3, 3, 3, 4, 4, 4, 4]) + >>> sl.count(3) + 3 + + :param value: value to count in sorted list + :return: count + + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos_left = bisect_left(_maxes, value) + + if pos_left == len(_maxes): + return 0 + + _lists = self._lists + idx_left = bisect_left(_lists[pos_left], value) + pos_right = bisect_right(_maxes, value) + + if pos_right == len(_maxes): + return self._len - self._loc(pos_left, idx_left) + + idx_right = bisect_right(_lists[pos_right], value) + + if pos_left == pos_right: + return idx_right - idx_left + + right = self._loc(pos_right, idx_right) + left = self._loc(pos_left, idx_left) + return right - left + + + def copy(self): + """Return a shallow copy of the sorted list. + + Runtime complexity: `O(n)` + + :return: new sorted list + + """ + return self.__class__(self) + + __copy__ = copy + + + def append(self, value): + """Raise not-implemented error. + + Implemented to override `MutableSequence.append` which provides an + erroneous default implementation. + + :raises NotImplementedError: use ``sl.add(value)`` instead + + """ + raise NotImplementedError('use ``sl.add(value)`` instead') + + + def extend(self, values): + """Raise not-implemented error. + + Implemented to override `MutableSequence.extend` which provides an + erroneous default implementation. + + :raises NotImplementedError: use ``sl.update(values)`` instead + + """ + raise NotImplementedError('use ``sl.update(values)`` instead') + + + def insert(self, index, value): + """Raise not-implemented error. + + :raises NotImplementedError: use ``sl.add(value)`` instead + + """ + raise NotImplementedError('use ``sl.add(value)`` instead') + + + def pop(self, index=-1): + """Remove and return value at `index` in sorted list. + + Raise :exc:`IndexError` if the sorted list is empty or index is out of + range. + + Negative indices are supported. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList('abcde') + >>> sl.pop() + 'e' + >>> sl.pop(2) + 'c' + >>> sl + SortedList(['a', 'b', 'd']) + + :param int index: index of value (default -1) + :return: value + :raises IndexError: if index is out of range + + """ + if not self._len: + raise IndexError('pop index out of range') + + _lists = self._lists + + if index == 0: + val = _lists[0][0] + self._delete(0, 0) + return val + + if index == -1: + pos = len(_lists) - 1 + loc = len(_lists[pos]) - 1 + val = _lists[pos][loc] + self._delete(pos, loc) + return val + + if 0 <= index < len(_lists[0]): + val = _lists[0][index] + self._delete(0, index) + return val + + len_last = len(_lists[-1]) + + if -len_last < index < 0: + pos = len(_lists) - 1 + loc = len_last + index + val = _lists[pos][loc] + self._delete(pos, loc) + return val + + pos, idx = self._pos(index) + val = _lists[pos][idx] + self._delete(pos, idx) + return val + + + def index(self, value, start=None, stop=None): + """Return first index of value in sorted list. + + Raise ValueError if `value` is not present. + + Index must be between `start` and `stop` for the `value` to be + considered present. The default value, None, for `start` and `stop` + indicate the beginning and end of the sorted list. + + Negative indices are supported. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> sl = SortedList('abcde') + >>> sl.index('d') + 3 + >>> sl.index('z') + Traceback (most recent call last): + ... + ValueError: 'z' is not in list + + :param value: value in sorted list + :param int start: start index (default None, start of sorted list) + :param int stop: stop index (default None, end of sorted list) + :return: index of value + :raises ValueError: if value is not present + + """ + _len = self._len + + if not _len: + raise ValueError('{0!r} is not in list'.format(value)) + + if start is None: + start = 0 + if start < 0: + start += _len + if start < 0: + start = 0 + + if stop is None: + stop = _len + if stop < 0: + stop += _len + if stop > _len: + stop = _len + + if stop <= start: + raise ValueError('{0!r} is not in list'.format(value)) + + _maxes = self._maxes + pos_left = bisect_left(_maxes, value) + + if pos_left == len(_maxes): + raise ValueError('{0!r} is not in list'.format(value)) + + _lists = self._lists + idx_left = bisect_left(_lists[pos_left], value) + + if _lists[pos_left][idx_left] != value: + raise ValueError('{0!r} is not in list'.format(value)) + + stop -= 1 + left = self._loc(pos_left, idx_left) + + if start <= left: + if left <= stop: + return left + else: + right = self._bisect_right(value) - 1 + + if start <= right: + return start + + raise ValueError('{0!r} is not in list'.format(value)) + + + def __add__(self, other): + """Return new sorted list containing all values in both sequences. + + ``sl.__add__(other)`` <==> ``sl + other`` + + Values in `other` do not need to be in sorted order. + + Runtime complexity: `O(n*log(n))` + + >>> sl1 = SortedList('bat') + >>> sl2 = SortedList('cat') + >>> sl1 + sl2 + SortedList(['a', 'a', 'b', 'c', 't', 't']) + + :param other: other iterable + :return: new sorted list + + """ + values = reduce(iadd, self._lists, []) + values.extend(other) + return self.__class__(values) + + __radd__ = __add__ + + + def __iadd__(self, other): + """Update sorted list with values from `other`. + + ``sl.__iadd__(other)`` <==> ``sl += other`` + + Values in `other` do not need to be in sorted order. + + Runtime complexity: `O(k*log(n))` -- approximate. + + >>> sl = SortedList('bat') + >>> sl += 'cat' + >>> sl + SortedList(['a', 'a', 'b', 'c', 't', 't']) + + :param other: other iterable + :return: existing sorted list + + """ + self._update(other) + return self + + + def __mul__(self, num): + """Return new sorted list with `num` shallow copies of values. + + ``sl.__mul__(num)`` <==> ``sl * num`` + + Runtime complexity: `O(n*log(n))` + + >>> sl = SortedList('abc') + >>> sl * 3 + SortedList(['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c']) + + :param int num: count of shallow copies + :return: new sorted list + + """ + values = reduce(iadd, self._lists, []) * num + return self.__class__(values) + + __rmul__ = __mul__ + + + def __imul__(self, num): + """Update the sorted list with `num` shallow copies of values. + + ``sl.__imul__(num)`` <==> ``sl *= num`` + + Runtime complexity: `O(n*log(n))` + + >>> sl = SortedList('abc') + >>> sl *= 3 + >>> sl + SortedList(['a', 'a', 'a', 'b', 'b', 'b', 'c', 'c', 'c']) + + :param int num: count of shallow copies + :return: existing sorted list + + """ + values = reduce(iadd, self._lists, []) * num + self._clear() + self._update(values) + return self + + + def __make_cmp(seq_op, symbol, doc): + "Make comparator method." + def comparer(self, other): + "Compare method for sorted list and sequence." + if not isinstance(other, Sequence): + return NotImplemented + + self_len = self._len + len_other = len(other) + + if self_len != len_other: + if seq_op is eq: + return False + if seq_op is ne: + return True + + for alpha, beta in zip(self, other): + if alpha != beta: + return seq_op(alpha, beta) + + return seq_op(self_len, len_other) + + seq_op_name = seq_op.__name__ + comparer.__name__ = '__{0}__'.format(seq_op_name) + doc_str = """Return true if and only if sorted list is {0} `other`. + + ``sl.__{1}__(other)`` <==> ``sl {2} other`` + + Comparisons use lexicographical order as with sequences. + + Runtime complexity: `O(n)` + + :param other: `other` sequence + :return: true if sorted list is {0} `other` + + """ + comparer.__doc__ = dedent(doc_str.format(doc, seq_op_name, symbol)) + return comparer + + + __eq__ = __make_cmp(eq, '==', 'equal to') + __ne__ = __make_cmp(ne, '!=', 'not equal to') + __lt__ = __make_cmp(lt, '<', 'less than') + __gt__ = __make_cmp(gt, '>', 'greater than') + __le__ = __make_cmp(le, '<=', 'less than or equal to') + __ge__ = __make_cmp(ge, '>=', 'greater than or equal to') + __make_cmp = staticmethod(__make_cmp) + + + def __reduce__(self): + values = reduce(iadd, self._lists, []) + return (type(self), (values,)) + + + @recursive_repr() + def __repr__(self): + """Return string representation of sorted list. + + ``sl.__repr__()`` <==> ``repr(sl)`` + + :return: string representation + + """ + return '{0}({1!r})'.format(type(self).__name__, list(self)) + + + def _check(self): + """Check invariants of sorted list. + + Runtime complexity: `O(n)` + + """ + try: + assert self._load >= 4 + assert len(self._maxes) == len(self._lists) + assert self._len == sum(len(sublist) for sublist in self._lists) + + # Check all sublists are sorted. + + for sublist in self._lists: + for pos in range(1, len(sublist)): + assert sublist[pos - 1] <= sublist[pos] + + # Check beginning/end of sublists are sorted. + + for pos in range(1, len(self._lists)): + assert self._lists[pos - 1][-1] <= self._lists[pos][0] + + # Check _maxes index is the last value of each sublist. + + for pos in range(len(self._maxes)): + assert self._maxes[pos] == self._lists[pos][-1] + + # Check sublist lengths are less than double load-factor. + + double = self._load << 1 + assert all(len(sublist) <= double for sublist in self._lists) + + # Check sublist lengths are greater than half load-factor for all + # but the last sublist. + + half = self._load >> 1 + for pos in range(0, len(self._lists) - 1): + assert len(self._lists[pos]) >= half + + if self._index: + assert self._len == self._index[0] + assert len(self._index) == self._offset + len(self._lists) + + # Check index leaf nodes equal length of sublists. + + for pos in range(len(self._lists)): + leaf = self._index[self._offset + pos] + assert leaf == len(self._lists[pos]) + + # Check index branch nodes are the sum of their children. + + for pos in range(self._offset): + child = (pos << 1) + 1 + if child >= len(self._index): + assert self._index[pos] == 0 + elif child + 1 == len(self._index): + assert self._index[pos] == self._index[child] + else: + child_sum = self._index[child] + self._index[child + 1] + assert child_sum == self._index[pos] + except: + traceback.print_exc(file=sys.stdout) + print('len', self._len) + print('load', self._load) + print('offset', self._offset) + print('len_index', len(self._index)) + print('index', self._index) + print('len_maxes', len(self._maxes)) + print('maxes', self._maxes) + print('len_lists', len(self._lists)) + print('lists', self._lists) + raise + + +def identity(value): + "Identity function." + return value + + +class SortedKeyList(SortedList): + """Sorted-key list is a subtype of sorted list. + + The sorted-key list maintains values in comparison order based on the + result of a key function applied to every value. + + All the same methods that are available in :class:`SortedList` are also + available in :class:`SortedKeyList`. + + Additional methods provided: + + * :attr:`SortedKeyList.key` + * :func:`SortedKeyList.bisect_key_left` + * :func:`SortedKeyList.bisect_key_right` + * :func:`SortedKeyList.irange_key` + + Some examples below use: + + >>> from operator import neg + >>> neg + + >>> neg(1) + -1 + + """ + def __init__(self, iterable=None, key=identity): + """Initialize sorted-key list instance. + + Optional `iterable` argument provides an initial iterable of values to + initialize the sorted-key list. + + Optional `key` argument defines a callable that, like the `key` + argument to Python's `sorted` function, extracts a comparison key from + each value. The default is the identity function. + + Runtime complexity: `O(n*log(n))` + + >>> from operator import neg + >>> skl = SortedKeyList(key=neg) + >>> skl + SortedKeyList([], key=) + >>> skl = SortedKeyList([3, 1, 2], key=neg) + >>> skl + SortedKeyList([3, 2, 1], key=) + + :param iterable: initial values (optional) + :param key: function used to extract comparison key (optional) + + """ + self._key = key + self._len = 0 + self._load = self.DEFAULT_LOAD_FACTOR + self._lists = [] + self._keys = [] + self._maxes = [] + self._index = [] + self._offset = 0 + + if iterable is not None: + self._update(iterable) + + + def __new__(cls, iterable=None, key=identity): + return object.__new__(cls) + + + @property + def key(self): + "Function used to extract comparison key from values." + return self._key + + + def clear(self): + """Remove all values from sorted-key list. + + Runtime complexity: `O(n)` + + """ + self._len = 0 + del self._lists[:] + del self._keys[:] + del self._maxes[:] + del self._index[:] + + _clear = clear + + + def add(self, value): + """Add `value` to sorted-key list. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList(key=neg) + >>> skl.add(3) + >>> skl.add(1) + >>> skl.add(2) + >>> skl + SortedKeyList([3, 2, 1], key=) + + :param value: value to add to sorted-key list + + """ + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + + key = self._key(value) + + if _maxes: + pos = bisect_right(_maxes, key) + + if pos == len(_maxes): + pos -= 1 + _lists[pos].append(value) + _keys[pos].append(key) + _maxes[pos] = key + else: + idx = bisect_right(_keys[pos], key) + _lists[pos].insert(idx, value) + _keys[pos].insert(idx, key) + + self._expand(pos) + else: + _lists.append([value]) + _keys.append([key]) + _maxes.append(key) + + self._len += 1 + + + def _expand(self, pos): + """Split sublists with length greater than double the load-factor. + + Updates the index when the sublist length is less than double the load + level. This requires incrementing the nodes in a traversal from the + leaf node to the root. For an example traversal see + ``SortedList._loc``. + + """ + _lists = self._lists + _keys = self._keys + _index = self._index + + if len(_keys[pos]) > (self._load << 1): + _maxes = self._maxes + _load = self._load + + _lists_pos = _lists[pos] + _keys_pos = _keys[pos] + half = _lists_pos[_load:] + half_keys = _keys_pos[_load:] + del _lists_pos[_load:] + del _keys_pos[_load:] + _maxes[pos] = _keys_pos[-1] + + _lists.insert(pos + 1, half) + _keys.insert(pos + 1, half_keys) + _maxes.insert(pos + 1, half_keys[-1]) + + del _index[:] + else: + if _index: + child = self._offset + pos + while child: + _index[child] += 1 + child = (child - 1) >> 1 + _index[0] += 1 + + + def update(self, iterable): + """Update sorted-key list by adding all values from `iterable`. + + Runtime complexity: `O(k*log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList(key=neg) + >>> skl.update([3, 1, 2]) + >>> skl + SortedKeyList([3, 2, 1], key=) + + :param iterable: iterable of values to add + + """ + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + values = sorted(iterable, key=self._key) + + if _maxes: + if len(values) * 4 >= self._len: + _lists.append(values) + values = reduce(iadd, _lists, []) + values.sort(key=self._key) + self._clear() + else: + _add = self.add + for val in values: + _add(val) + return + + _load = self._load + _lists.extend(values[pos:(pos + _load)] + for pos in range(0, len(values), _load)) + _keys.extend(list(map(self._key, _list)) for _list in _lists) + _maxes.extend(sublist[-1] for sublist in _keys) + self._len = len(values) + del self._index[:] + + _update = update + + + def __contains__(self, value): + """Return true if `value` is an element of the sorted-key list. + + ``skl.__contains__(value)`` <==> ``value in skl`` + + Runtime complexity: `O(log(n))` + + >>> from operator import neg + >>> skl = SortedKeyList([1, 2, 3, 4, 5], key=neg) + >>> 3 in skl + True + + :param value: search for value in sorted-key list + :return: true if `value` in sorted-key list + + """ + _maxes = self._maxes + + if not _maxes: + return False + + key = self._key(value) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return False + + _lists = self._lists + _keys = self._keys + + idx = bisect_left(_keys[pos], key) + + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + return False + if _lists[pos][idx] == value: + return True + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + return False + len_sublist = len(_keys[pos]) + idx = 0 + + + def discard(self, value): + """Remove `value` from sorted-key list if it is a member. + + If `value` is not a member, do nothing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList([5, 4, 3, 2, 1], key=neg) + >>> skl.discard(1) + >>> skl.discard(0) + >>> skl == [5, 4, 3, 2] + True + + :param value: `value` to discard from sorted-key list + + """ + _maxes = self._maxes + + if not _maxes: + return + + key = self._key(value) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return + + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + return + if _lists[pos][idx] == value: + self._delete(pos, idx) + return + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + return + len_sublist = len(_keys[pos]) + idx = 0 + + + def remove(self, value): + """Remove `value` from sorted-key list; `value` must be a member. + + If `value` is not a member, raise ValueError. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList([1, 2, 3, 4, 5], key=neg) + >>> skl.remove(5) + >>> skl == [4, 3, 2, 1] + True + >>> skl.remove(0) + Traceback (most recent call last): + ... + ValueError: 0 not in list + + :param value: `value` to remove from sorted-key list + :raises ValueError: if `value` is not in sorted-key list + + """ + _maxes = self._maxes + + if not _maxes: + raise ValueError('{0!r} not in list'.format(value)) + + key = self._key(value) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + raise ValueError('{0!r} not in list'.format(value)) + + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + raise ValueError('{0!r} not in list'.format(value)) + if _lists[pos][idx] == value: + self._delete(pos, idx) + return + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + raise ValueError('{0!r} not in list'.format(value)) + len_sublist = len(_keys[pos]) + idx = 0 + + + def _delete(self, pos, idx): + """Delete value at the given `(pos, idx)`. + + Combines lists that are less than half the load level. + + Updates the index when the sublist length is more than half the load + level. This requires decrementing the nodes in a traversal from the + leaf node to the root. For an example traversal see + ``SortedList._loc``. + + :param int pos: lists index + :param int idx: sublist index + + """ + _lists = self._lists + _keys = self._keys + _maxes = self._maxes + _index = self._index + keys_pos = _keys[pos] + lists_pos = _lists[pos] + + del keys_pos[idx] + del lists_pos[idx] + self._len -= 1 + + len_keys_pos = len(keys_pos) + + if len_keys_pos > (self._load >> 1): + _maxes[pos] = keys_pos[-1] + + if _index: + child = self._offset + pos + while child > 0: + _index[child] -= 1 + child = (child - 1) >> 1 + _index[0] -= 1 + elif len(_keys) > 1: + if not pos: + pos += 1 + + prev = pos - 1 + _keys[prev].extend(_keys[pos]) + _lists[prev].extend(_lists[pos]) + _maxes[prev] = _keys[prev][-1] + + del _lists[pos] + del _keys[pos] + del _maxes[pos] + del _index[:] + + self._expand(prev) + elif len_keys_pos: + _maxes[pos] = keys_pos[-1] + else: + del _lists[pos] + del _keys[pos] + del _maxes[pos] + del _index[:] + + + def irange(self, minimum=None, maximum=None, inclusive=(True, True), + reverse=False): + """Create an iterator of values between `minimum` and `maximum`. + + Both `minimum` and `maximum` default to `None` which is automatically + inclusive of the beginning and end of the sorted-key list. + + The argument `inclusive` is a pair of booleans that indicates whether + the minimum and maximum ought to be included in the range, + respectively. The default is ``(True, True)`` such that the range is + inclusive of both minimum and maximum. + + When `reverse` is `True` the values are yielded from the iterator in + reverse order; `reverse` defaults to `False`. + + >>> from operator import neg + >>> skl = SortedKeyList([11, 12, 13, 14, 15], key=neg) + >>> it = skl.irange(14.5, 11.5) + >>> list(it) + [14, 13, 12] + + :param minimum: minimum value to start iterating + :param maximum: maximum value to stop iterating + :param inclusive: pair of booleans + :param bool reverse: yield values in reverse order + :return: iterator + + """ + min_key = self._key(minimum) if minimum is not None else None + max_key = self._key(maximum) if maximum is not None else None + return self._irange_key( + min_key=min_key, max_key=max_key, + inclusive=inclusive, reverse=reverse, + ) + + + def irange_key(self, min_key=None, max_key=None, inclusive=(True, True), + reverse=False): + """Create an iterator of values between `min_key` and `max_key`. + + Both `min_key` and `max_key` default to `None` which is automatically + inclusive of the beginning and end of the sorted-key list. + + The argument `inclusive` is a pair of booleans that indicates whether + the minimum and maximum ought to be included in the range, + respectively. The default is ``(True, True)`` such that the range is + inclusive of both minimum and maximum. + + When `reverse` is `True` the values are yielded from the iterator in + reverse order; `reverse` defaults to `False`. + + >>> from operator import neg + >>> skl = SortedKeyList([11, 12, 13, 14, 15], key=neg) + >>> it = skl.irange_key(-14, -12) + >>> list(it) + [14, 13, 12] + + :param min_key: minimum key to start iterating + :param max_key: maximum key to stop iterating + :param inclusive: pair of booleans + :param bool reverse: yield values in reverse order + :return: iterator + + """ + _maxes = self._maxes + + if not _maxes: + return iter(()) + + _keys = self._keys + + # Calculate the minimum (pos, idx) pair. By default this location + # will be inclusive in our calculation. + + if min_key is None: + min_pos = 0 + min_idx = 0 + else: + if inclusive[0]: + min_pos = bisect_left(_maxes, min_key) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_left(_keys[min_pos], min_key) + else: + min_pos = bisect_right(_maxes, min_key) + + if min_pos == len(_maxes): + return iter(()) + + min_idx = bisect_right(_keys[min_pos], min_key) + + # Calculate the maximum (pos, idx) pair. By default this location + # will be exclusive in our calculation. + + if max_key is None: + max_pos = len(_maxes) - 1 + max_idx = len(_keys[max_pos]) + else: + if inclusive[1]: + max_pos = bisect_right(_maxes, max_key) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_keys[max_pos]) + else: + max_idx = bisect_right(_keys[max_pos], max_key) + else: + max_pos = bisect_left(_maxes, max_key) + + if max_pos == len(_maxes): + max_pos -= 1 + max_idx = len(_keys[max_pos]) + else: + max_idx = bisect_left(_keys[max_pos], max_key) + + return self._islice(min_pos, min_idx, max_pos, max_idx, reverse) + + _irange_key = irange_key + + + def bisect_left(self, value): + """Return an index to insert `value` in the sorted-key list. + + If the `value` is already present, the insertion point will be before + (to the left of) any existing values. + + Similar to the `bisect` module in the standard library. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList([5, 4, 3, 2, 1], key=neg) + >>> skl.bisect_left(1) + 4 + + :param value: insertion index of value in sorted-key list + :return: index + + """ + return self._bisect_key_left(self._key(value)) + + + def bisect_right(self, value): + """Return an index to insert `value` in the sorted-key list. + + Similar to `bisect_left`, but if `value` is already present, the + insertion point will be after (to the right of) any existing values. + + Similar to the `bisect` module in the standard library. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedList([5, 4, 3, 2, 1], key=neg) + >>> skl.bisect_right(1) + 5 + + :param value: insertion index of value in sorted-key list + :return: index + + """ + return self._bisect_key_right(self._key(value)) + + bisect = bisect_right + + + def bisect_key_left(self, key): + """Return an index to insert `key` in the sorted-key list. + + If the `key` is already present, the insertion point will be before (to + the left of) any existing keys. + + Similar to the `bisect` module in the standard library. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList([5, 4, 3, 2, 1], key=neg) + >>> skl.bisect_key_left(-1) + 4 + + :param key: insertion index of key in sorted-key list + :return: index + + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return self._len + + idx = bisect_left(self._keys[pos], key) + + return self._loc(pos, idx) + + _bisect_key_left = bisect_key_left + + + def bisect_key_right(self, key): + """Return an index to insert `key` in the sorted-key list. + + Similar to `bisect_key_left`, but if `key` is already present, the + insertion point will be after (to the right of) any existing keys. + + Similar to the `bisect` module in the standard library. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedList([5, 4, 3, 2, 1], key=neg) + >>> skl.bisect_key_right(-1) + 5 + + :param key: insertion index of key in sorted-key list + :return: index + + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + pos = bisect_right(_maxes, key) + + if pos == len(_maxes): + return self._len + + idx = bisect_right(self._keys[pos], key) + + return self._loc(pos, idx) + + bisect_key = bisect_key_right + _bisect_key_right = bisect_key_right + + + def count(self, value): + """Return number of occurrences of `value` in the sorted-key list. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList([4, 4, 4, 4, 3, 3, 3, 2, 2, 1], key=neg) + >>> skl.count(2) + 2 + + :param value: value to count in sorted-key list + :return: count + + """ + _maxes = self._maxes + + if not _maxes: + return 0 + + key = self._key(value) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + return 0 + + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + total = 0 + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + return total + if _lists[pos][idx] == value: + total += 1 + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + return total + len_sublist = len(_keys[pos]) + idx = 0 + + + def copy(self): + """Return a shallow copy of the sorted-key list. + + Runtime complexity: `O(n)` + + :return: new sorted-key list + + """ + return self.__class__(self, key=self._key) + + __copy__ = copy + + + def index(self, value, start=None, stop=None): + """Return first index of value in sorted-key list. + + Raise ValueError if `value` is not present. + + Index must be between `start` and `stop` for the `value` to be + considered present. The default value, None, for `start` and `stop` + indicate the beginning and end of the sorted-key list. + + Negative indices are supported. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> from operator import neg + >>> skl = SortedKeyList([5, 4, 3, 2, 1], key=neg) + >>> skl.index(2) + 3 + >>> skl.index(0) + Traceback (most recent call last): + ... + ValueError: 0 is not in list + + :param value: value in sorted-key list + :param int start: start index (default None, start of sorted-key list) + :param int stop: stop index (default None, end of sorted-key list) + :return: index of value + :raises ValueError: if value is not present + + """ + _len = self._len + + if not _len: + raise ValueError('{0!r} is not in list'.format(value)) + + if start is None: + start = 0 + if start < 0: + start += _len + if start < 0: + start = 0 + + if stop is None: + stop = _len + if stop < 0: + stop += _len + if stop > _len: + stop = _len + + if stop <= start: + raise ValueError('{0!r} is not in list'.format(value)) + + _maxes = self._maxes + key = self._key(value) + pos = bisect_left(_maxes, key) + + if pos == len(_maxes): + raise ValueError('{0!r} is not in list'.format(value)) + + stop -= 1 + _lists = self._lists + _keys = self._keys + idx = bisect_left(_keys[pos], key) + len_keys = len(_keys) + len_sublist = len(_keys[pos]) + + while True: + if _keys[pos][idx] != key: + raise ValueError('{0!r} is not in list'.format(value)) + if _lists[pos][idx] == value: + loc = self._loc(pos, idx) + if start <= loc <= stop: + return loc + elif loc > stop: + break + idx += 1 + if idx == len_sublist: + pos += 1 + if pos == len_keys: + raise ValueError('{0!r} is not in list'.format(value)) + len_sublist = len(_keys[pos]) + idx = 0 + + raise ValueError('{0!r} is not in list'.format(value)) + + + def __add__(self, other): + """Return new sorted-key list containing all values in both sequences. + + ``skl.__add__(other)`` <==> ``skl + other`` + + Values in `other` do not need to be in sorted-key order. + + Runtime complexity: `O(n*log(n))` + + >>> from operator import neg + >>> skl1 = SortedKeyList([5, 4, 3], key=neg) + >>> skl2 = SortedKeyList([2, 1, 0], key=neg) + >>> skl1 + skl2 + SortedKeyList([5, 4, 3, 2, 1, 0], key=) + + :param other: other iterable + :return: new sorted-key list + + """ + values = reduce(iadd, self._lists, []) + values.extend(other) + return self.__class__(values, key=self._key) + + __radd__ = __add__ + + + def __mul__(self, num): + """Return new sorted-key list with `num` shallow copies of values. + + ``skl.__mul__(num)`` <==> ``skl * num`` + + Runtime complexity: `O(n*log(n))` + + >>> from operator import neg + >>> skl = SortedKeyList([3, 2, 1], key=neg) + >>> skl * 2 + SortedKeyList([3, 3, 2, 2, 1, 1], key=) + + :param int num: count of shallow copies + :return: new sorted-key list + + """ + values = reduce(iadd, self._lists, []) * num + return self.__class__(values, key=self._key) + + + def __reduce__(self): + values = reduce(iadd, self._lists, []) + return (type(self), (values, self.key)) + + + @recursive_repr() + def __repr__(self): + """Return string representation of sorted-key list. + + ``skl.__repr__()`` <==> ``repr(skl)`` + + :return: string representation + + """ + type_name = type(self).__name__ + return '{0}({1!r}, key={2!r})'.format(type_name, list(self), self._key) + + + def _check(self): + """Check invariants of sorted-key list. + + Runtime complexity: `O(n)` + + """ + try: + assert self._load >= 4 + assert len(self._maxes) == len(self._lists) == len(self._keys) + assert self._len == sum(len(sublist) for sublist in self._lists) + + # Check all sublists are sorted. + + for sublist in self._keys: + for pos in range(1, len(sublist)): + assert sublist[pos - 1] <= sublist[pos] + + # Check beginning/end of sublists are sorted. + + for pos in range(1, len(self._keys)): + assert self._keys[pos - 1][-1] <= self._keys[pos][0] + + # Check _keys matches _key mapped to _lists. + + for val_sublist, key_sublist in zip(self._lists, self._keys): + assert len(val_sublist) == len(key_sublist) + for val, key in zip(val_sublist, key_sublist): + assert self._key(val) == key + + # Check _maxes index is the last value of each sublist. + + for pos in range(len(self._maxes)): + assert self._maxes[pos] == self._keys[pos][-1] + + # Check sublist lengths are less than double load-factor. + + double = self._load << 1 + assert all(len(sublist) <= double for sublist in self._lists) + + # Check sublist lengths are greater than half load-factor for all + # but the last sublist. + + half = self._load >> 1 + for pos in range(0, len(self._lists) - 1): + assert len(self._lists[pos]) >= half + + if self._index: + assert self._len == self._index[0] + assert len(self._index) == self._offset + len(self._lists) + + # Check index leaf nodes equal length of sublists. + + for pos in range(len(self._lists)): + leaf = self._index[self._offset + pos] + assert leaf == len(self._lists[pos]) + + # Check index branch nodes are the sum of their children. + + for pos in range(self._offset): + child = (pos << 1) + 1 + if child >= len(self._index): + assert self._index[pos] == 0 + elif child + 1 == len(self._index): + assert self._index[pos] == self._index[child] + else: + child_sum = self._index[child] + self._index[child + 1] + assert child_sum == self._index[pos] + except: + traceback.print_exc(file=sys.stdout) + print('len', self._len) + print('load', self._load) + print('offset', self._offset) + print('len_index', len(self._index)) + print('index', self._index) + print('len_maxes', len(self._maxes)) + print('maxes', self._maxes) + print('len_keys', len(self._keys)) + print('keys', self._keys) + print('len_lists', len(self._lists)) + print('lists', self._lists) + raise + + +SortedListWithKey = SortedKeyList diff --git a/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sortedset.py b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sortedset.py new file mode 100644 index 0000000..be2b899 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/intervaltree/sortedcontainers/sortedset.py @@ -0,0 +1,733 @@ +"""Sorted Set +============= + +:doc:`Sorted Containers` is an Apache2 licensed Python sorted +collections library, written in pure-Python, and fast as C-extensions. The +:doc:`introduction` is the best way to get started. + +Sorted set implementations: + +.. currentmodule:: sortedcontainers + +* :class:`SortedSet` + +""" + +from itertools import chain +from operator import eq, ne, gt, ge, lt, le +from textwrap import dedent + +from .sortedlist import SortedList, recursive_repr + +############################################################################### +# BEGIN Python 2/3 Shims +############################################################################### + +try: + from collections.abc import MutableSet, Sequence, Set +except ImportError: + from collections import MutableSet, Sequence, Set + +############################################################################### +# END Python 2/3 Shims +############################################################################### + + +class SortedSet(MutableSet, Sequence): + """Sorted set is a sorted mutable set. + + Sorted set values are maintained in sorted order. The design of sorted set + is simple: sorted set uses a set for set-operations and maintains a sorted + list of values. + + Sorted set values must be hashable and comparable. The hash and total + ordering of values must not change while they are stored in the sorted set. + + Mutable set methods: + + * :func:`SortedSet.__contains__` + * :func:`SortedSet.__iter__` + * :func:`SortedSet.__len__` + * :func:`SortedSet.add` + * :func:`SortedSet.discard` + + Sequence methods: + + * :func:`SortedSet.__getitem__` + * :func:`SortedSet.__delitem__` + * :func:`SortedSet.__reversed__` + + Methods for removing values: + + * :func:`SortedSet.clear` + * :func:`SortedSet.pop` + * :func:`SortedSet.remove` + + Set-operation methods: + + * :func:`SortedSet.difference` + * :func:`SortedSet.difference_update` + * :func:`SortedSet.intersection` + * :func:`SortedSet.intersection_update` + * :func:`SortedSet.symmetric_difference` + * :func:`SortedSet.symmetric_difference_update` + * :func:`SortedSet.union` + * :func:`SortedSet.update` + + Methods for miscellany: + + * :func:`SortedSet.copy` + * :func:`SortedSet.count` + * :func:`SortedSet.__repr__` + * :func:`SortedSet._check` + + Sorted list methods available: + + * :func:`SortedList.bisect_left` + * :func:`SortedList.bisect_right` + * :func:`SortedList.index` + * :func:`SortedList.irange` + * :func:`SortedList.islice` + * :func:`SortedList._reset` + + Additional sorted list methods available, if key-function used: + + * :func:`SortedKeyList.bisect_key_left` + * :func:`SortedKeyList.bisect_key_right` + * :func:`SortedKeyList.irange_key` + + Sorted set comparisons use subset and superset relations. Two sorted sets + are equal if and only if every element of each sorted set is contained in + the other (each is a subset of the other). A sorted set is less than + another sorted set if and only if the first sorted set is a proper subset + of the second sorted set (is a subset, but is not equal). A sorted set is + greater than another sorted set if and only if the first sorted set is a + proper superset of the second sorted set (is a superset, but is not equal). + + """ + def __init__(self, iterable=None, key=None): + """Initialize sorted set instance. + + Optional `iterable` argument provides an initial iterable of values to + initialize the sorted set. + + Optional `key` argument defines a callable that, like the `key` + argument to Python's `sorted` function, extracts a comparison key from + each value. The default, none, compares values directly. + + Runtime complexity: `O(n*log(n))` + + >>> ss = SortedSet([3, 1, 2, 5, 4]) + >>> ss + SortedSet([1, 2, 3, 4, 5]) + >>> from operator import neg + >>> ss = SortedSet([3, 1, 2, 5, 4], neg) + >>> ss + SortedSet([5, 4, 3, 2, 1], key=) + + :param iterable: initial values (optional) + :param key: function used to extract comparison key (optional) + + """ + self._key = key + + # SortedSet._fromset calls SortedSet.__init__ after initializing the + # _set attribute. So only create a new set if the _set attribute is not + # already present. + + if not hasattr(self, '_set'): + self._set = set() + + self._list = SortedList(self._set, key=key) + + # Expose some set methods publicly. + + _set = self._set + self.isdisjoint = _set.isdisjoint + self.issubset = _set.issubset + self.issuperset = _set.issuperset + + # Expose some sorted list methods publicly. + + _list = self._list + self.bisect_left = _list.bisect_left + self.bisect = _list.bisect + self.bisect_right = _list.bisect_right + self.index = _list.index + self.irange = _list.irange + self.islice = _list.islice + self._reset = _list._reset + + if key is not None: + self.bisect_key_left = _list.bisect_key_left + self.bisect_key_right = _list.bisect_key_right + self.bisect_key = _list.bisect_key + self.irange_key = _list.irange_key + + if iterable is not None: + self._update(iterable) + + + @classmethod + def _fromset(cls, values, key=None): + """Initialize sorted set from existing set. + + Used internally by set operations that return a new set. + + """ + sorted_set = object.__new__(cls) + sorted_set._set = values + sorted_set.__init__(key=key) + return sorted_set + + + @property + def key(self): + """Function used to extract comparison key from values. + + Sorted set compares values directly when the key function is none. + + """ + return self._key + + + def __contains__(self, value): + """Return true if `value` is an element of the sorted set. + + ``ss.__contains__(value)`` <==> ``value in ss`` + + Runtime complexity: `O(1)` + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> 3 in ss + True + + :param value: search for value in sorted set + :return: true if `value` in sorted set + + """ + return value in self._set + + + def __getitem__(self, index): + """Lookup value at `index` in sorted set. + + ``ss.__getitem__(index)`` <==> ``ss[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> ss = SortedSet('abcde') + >>> ss[2] + 'c' + >>> ss[-1] + 'e' + >>> ss[2:5] + ['c', 'd', 'e'] + + :param index: integer or slice for indexing + :return: value or list of values + :raises IndexError: if index out of range + + """ + return self._list[index] + + + def __delitem__(self, index): + """Remove value at `index` from sorted set. + + ``ss.__delitem__(index)`` <==> ``del ss[index]`` + + Supports slicing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> ss = SortedSet('abcde') + >>> del ss[2] + >>> ss + SortedSet(['a', 'b', 'd', 'e']) + >>> del ss[:2] + >>> ss + SortedSet(['d', 'e']) + + :param index: integer or slice for indexing + :raises IndexError: if index out of range + + """ + _set = self._set + _list = self._list + if isinstance(index, slice): + values = _list[index] + _set.difference_update(values) + else: + value = _list[index] + _set.remove(value) + del _list[index] + + + def __make_cmp(set_op, symbol, doc): + "Make comparator method." + def comparer(self, other): + "Compare method for sorted set and set." + if isinstance(other, SortedSet): + return set_op(self._set, other._set) + elif isinstance(other, Set): + return set_op(self._set, other) + return NotImplemented + + set_op_name = set_op.__name__ + comparer.__name__ = '__{0}__'.format(set_op_name) + doc_str = """Return true if and only if sorted set is {0} `other`. + + ``ss.__{1}__(other)`` <==> ``ss {2} other`` + + Comparisons use subset and superset semantics as with sets. + + Runtime complexity: `O(n)` + + :param other: `other` set + :return: true if sorted set is {0} `other` + + """ + comparer.__doc__ = dedent(doc_str.format(doc, set_op_name, symbol)) + return comparer + + + __eq__ = __make_cmp(eq, '==', 'equal to') + __ne__ = __make_cmp(ne, '!=', 'not equal to') + __lt__ = __make_cmp(lt, '<', 'a proper subset of') + __gt__ = __make_cmp(gt, '>', 'a proper superset of') + __le__ = __make_cmp(le, '<=', 'a subset of') + __ge__ = __make_cmp(ge, '>=', 'a superset of') + __make_cmp = staticmethod(__make_cmp) + + + def __len__(self): + """Return the size of the sorted set. + + ``ss.__len__()`` <==> ``len(ss)`` + + :return: size of sorted set + + """ + return len(self._set) + + + def __iter__(self): + """Return an iterator over the sorted set. + + ``ss.__iter__()`` <==> ``iter(ss)`` + + Iterating the sorted set while adding or deleting values may raise a + :exc:`RuntimeError` or fail to iterate over all values. + + """ + return iter(self._list) + + + def __reversed__(self): + """Return a reverse iterator over the sorted set. + + ``ss.__reversed__()`` <==> ``reversed(ss)`` + + Iterating the sorted set while adding or deleting values may raise a + :exc:`RuntimeError` or fail to iterate over all values. + + """ + return reversed(self._list) + + + def add(self, value): + """Add `value` to sorted set. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> ss = SortedSet() + >>> ss.add(3) + >>> ss.add(1) + >>> ss.add(2) + >>> ss + SortedSet([1, 2, 3]) + + :param value: value to add to sorted set + + """ + _set = self._set + if value not in _set: + _set.add(value) + self._list.add(value) + + _add = add + + + def clear(self): + """Remove all values from sorted set. + + Runtime complexity: `O(n)` + + """ + self._set.clear() + self._list.clear() + + + def copy(self): + """Return a shallow copy of the sorted set. + + Runtime complexity: `O(n)` + + :return: new sorted set + + """ + return self._fromset(set(self._set), key=self._key) + + __copy__ = copy + + + def count(self, value): + """Return number of occurrences of `value` in the sorted set. + + Runtime complexity: `O(1)` + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> ss.count(3) + 1 + + :param value: value to count in sorted set + :return: count + + """ + return 1 if value in self._set else 0 + + + def discard(self, value): + """Remove `value` from sorted set if it is a member. + + If `value` is not a member, do nothing. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> ss.discard(5) + >>> ss.discard(0) + >>> ss == set([1, 2, 3, 4]) + True + + :param value: `value` to discard from sorted set + + """ + _set = self._set + if value in _set: + _set.remove(value) + self._list.remove(value) + + _discard = discard + + + def pop(self, index=-1): + """Remove and return value at `index` in sorted set. + + Raise :exc:`IndexError` if the sorted set is empty or index is out of + range. + + Negative indices are supported. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> ss = SortedSet('abcde') + >>> ss.pop() + 'e' + >>> ss.pop(2) + 'c' + >>> ss + SortedSet(['a', 'b', 'd']) + + :param int index: index of value (default -1) + :return: value + :raises IndexError: if index is out of range + + """ + # pylint: disable=arguments-differ + value = self._list.pop(index) + self._set.remove(value) + return value + + + def remove(self, value): + """Remove `value` from sorted set; `value` must be a member. + + If `value` is not a member, raise :exc:`KeyError`. + + Runtime complexity: `O(log(n))` -- approximate. + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> ss.remove(5) + >>> ss == set([1, 2, 3, 4]) + True + >>> ss.remove(0) + Traceback (most recent call last): + ... + KeyError: 0 + + :param value: `value` to remove from sorted set + :raises KeyError: if `value` is not in sorted set + + """ + self._set.remove(value) + self._list.remove(value) + + + def difference(self, *iterables): + """Return the difference of two or more sets as a new sorted set. + + The `difference` method also corresponds to operator ``-``. + + ``ss.__sub__(iterable)`` <==> ``ss - iterable`` + + The difference is all values that are in this sorted set but not the + other `iterables`. + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> ss.difference([4, 5, 6, 7]) + SortedSet([1, 2, 3]) + + :param iterables: iterable arguments + :return: new sorted set + + """ + diff = self._set.difference(*iterables) + return self._fromset(diff, key=self._key) + + __sub__ = difference + + + def difference_update(self, *iterables): + """Remove all values of `iterables` from this sorted set. + + The `difference_update` method also corresponds to operator ``-=``. + + ``ss.__isub__(iterable)`` <==> ``ss -= iterable`` + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> _ = ss.difference_update([4, 5, 6, 7]) + >>> ss + SortedSet([1, 2, 3]) + + :param iterables: iterable arguments + :return: itself + + """ + _set = self._set + _list = self._list + values = set(chain(*iterables)) + if (4 * len(values)) > len(_set): + _set.difference_update(values) + _list.clear() + _list.update(_set) + else: + _discard = self._discard + for value in values: + _discard(value) + return self + + __isub__ = difference_update + + + def intersection(self, *iterables): + """Return the intersection of two or more sets as a new sorted set. + + The `intersection` method also corresponds to operator ``&``. + + ``ss.__and__(iterable)`` <==> ``ss & iterable`` + + The intersection is all values that are in this sorted set and each of + the other `iterables`. + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> ss.intersection([4, 5, 6, 7]) + SortedSet([4, 5]) + + :param iterables: iterable arguments + :return: new sorted set + + """ + intersect = self._set.intersection(*iterables) + return self._fromset(intersect, key=self._key) + + __and__ = intersection + __rand__ = __and__ + + + def intersection_update(self, *iterables): + """Update the sorted set with the intersection of `iterables`. + + The `intersection_update` method also corresponds to operator ``&=``. + + ``ss.__iand__(iterable)`` <==> ``ss &= iterable`` + + Keep only values found in itself and all `iterables`. + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> _ = ss.intersection_update([4, 5, 6, 7]) + >>> ss + SortedSet([4, 5]) + + :param iterables: iterable arguments + :return: itself + + """ + _set = self._set + _list = self._list + _set.intersection_update(*iterables) + _list.clear() + _list.update(_set) + return self + + __iand__ = intersection_update + + + def symmetric_difference(self, other): + """Return the symmetric difference with `other` as a new sorted set. + + The `symmetric_difference` method also corresponds to operator ``^``. + + ``ss.__xor__(other)`` <==> ``ss ^ other`` + + The symmetric difference is all values tha are in exactly one of the + sets. + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> ss.symmetric_difference([4, 5, 6, 7]) + SortedSet([1, 2, 3, 6, 7]) + + :param other: `other` iterable + :return: new sorted set + + """ + diff = self._set.symmetric_difference(other) + return self._fromset(diff, key=self._key) + + __xor__ = symmetric_difference + __rxor__ = __xor__ + + + def symmetric_difference_update(self, other): + """Update the sorted set with the symmetric difference with `other`. + + The `symmetric_difference_update` method also corresponds to operator + ``^=``. + + ``ss.__ixor__(other)`` <==> ``ss ^= other`` + + Keep only values found in exactly one of itself and `other`. + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> _ = ss.symmetric_difference_update([4, 5, 6, 7]) + >>> ss + SortedSet([1, 2, 3, 6, 7]) + + :param other: `other` iterable + :return: itself + + """ + _set = self._set + _list = self._list + _set.symmetric_difference_update(other) + _list.clear() + _list.update(_set) + return self + + __ixor__ = symmetric_difference_update + + + def union(self, *iterables): + """Return new sorted set with values from itself and all `iterables`. + + The `union` method also corresponds to operator ``|``. + + ``ss.__or__(iterable)`` <==> ``ss | iterable`` + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> ss.union([4, 5, 6, 7]) + SortedSet([1, 2, 3, 4, 5, 6, 7]) + + :param iterables: iterable arguments + :return: new sorted set + + """ + return self.__class__(chain(iter(self), *iterables), key=self._key) + + __or__ = union + __ror__ = __or__ + + + def update(self, *iterables): + """Update the sorted set adding values from all `iterables`. + + The `update` method also corresponds to operator ``|=``. + + ``ss.__ior__(iterable)`` <==> ``ss |= iterable`` + + >>> ss = SortedSet([1, 2, 3, 4, 5]) + >>> _ = ss.update([4, 5, 6, 7]) + >>> ss + SortedSet([1, 2, 3, 4, 5, 6, 7]) + + :param iterables: iterable arguments + :return: itself + + """ + _set = self._set + _list = self._list + values = set(chain(*iterables)) + if (4 * len(values)) > len(_set): + _list = self._list + _set.update(values) + _list.clear() + _list.update(_set) + else: + _add = self._add + for value in values: + _add(value) + return self + + __ior__ = update + _update = update + + + def __reduce__(self): + """Support for pickle. + + The tricks played with exposing methods in :func:`SortedSet.__init__` + confuse pickle so customize the reducer. + + """ + return (type(self), (self._set, self._key)) + + + @recursive_repr() + def __repr__(self): + """Return string representation of sorted set. + + ``ss.__repr__()`` <==> ``repr(ss)`` + + :return: string representation + + """ + _key = self._key + key = '' if _key is None else ', key={0!r}'.format(_key) + type_name = type(self).__name__ + return '{0}({1!r}{2})'.format(type_name, list(self), key) + + + def _check(self): + """Check invariants of sorted set. + + Runtime complexity: `O(n)` + + """ + _set = self._set + _list = self._list + _list._check() + assert len(_set) == len(_list) + assert all(value in _set for value in _list) diff --git a/benchmarks/nn-variant/Clair3/shared/param_f.py b/benchmarks/nn-variant/Clair3/shared/param_f.py new file mode 100644 index 0000000..bfd56b2 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/param_f.py @@ -0,0 +1,55 @@ +# Clair3 full alignment parameters +REPO_NAME = "Clair3" +from itertools import accumulate + +zstd='zstd' +default_optimizer = "Radam" +default_loss_function = "FocalLoss" +min_af = 0.08 +min_af_dict = {'ont':0.15, 'hifi':min_af, 'ilmn':min_af } +matrix_depth_dict = {'ont': 89, 'hifi': 55, 'ilmn': 55} +max_depth = 144 +maximum_variant_length_that_need_infer = 50 +maximum_variant_length_that_need_infer_include_long_indel = 100000 +cal_precise_long_indel_af = False +long_indel_distance_proportion = 0.1 +min_mq = 5 +min_bq = 0 +min_coverage = 2 + +# Full alignment input feature list +channel = ( +'reference_base', 'alternative_base', 'mapping_quality', 'base_quality', 'strand_info', 'variant_type', 'insert_base', +'phasing_info') # phasing info if add_phasing +channel_size = len(channel) +flankingBaseNum = 16 +no_of_positions = 2 * flankingBaseNum + 1 +input_shape = [matrix_depth_dict['hifi'], no_of_positions, channel_size] +ont_input_shape = [matrix_depth_dict['ont'], no_of_positions, channel_size] +label_shape = [21, 3, no_of_positions, no_of_positions] +label_size = sum(label_shape) +label_shape_cum = list(accumulate(label_shape)) +expandReferenceRegion = 1000 +SAMTOOLS_VIEW_FILTER_FLAG = 2316 +NORMALIZE_NUM = 100 + +# Realignment parameters +partition_size = 500000 +realign_chunk_size = 5000 +phasing_window_size = 100000 +illumina_phasing_window_size = 10000 +max_phasing_depth = 15 +min_phasing_read_coverage = 2 +split_region_size = 1000 +extend_bp = 10 + +# Training hyperparameters +chunk_size = 200 +trainBatchSize = 2000 +predictBatchSize = 200 +initialLearningRate = 1e-3 +l2RegularizationLambda = 1e-7 +trainingDatasetPercentage = 0.9 +maxEpoch = 30 +OPERATION_SEED = None +RANDOM_SEED = None diff --git a/benchmarks/nn-variant/Clair3/shared/param_p.py b/benchmarks/nn-variant/Clair3/shared/param_p.py new file mode 100644 index 0000000..26bf3a5 --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/param_p.py @@ -0,0 +1,54 @@ +#Clair3 pileup parameters +REPO_NAME="Clair3" +import re +from itertools import accumulate + +zstd='zstd' +default_optimizer = "Radam" +default_loss_function = "FocalLoss" +support_platform = {'ont', 'hifi','ilmn'} +min_af = 0.08 +min_af_dict = {'ont':0.15, 'hifi':min_af, 'ilmn':min_af } +#as three platform training data vary in depth distribution, we recommend below max_depth base on max training data depth for calling +max_depth = 144 +max_depth_dict = {'ont':max_depth, 'hifi':max_depth, 'ilmn':max_depth} +maximum_variant_length_that_need_infer = 50 +maximum_variant_length_that_need_infer_include_long_indel = 100000 +cal_precise_long_indel_af = False +long_indel_distance_proportion = 0.1 +min_mq = 5 +min_bq = 0 +min_coverage = 2 +tensorflow_threads = 4 + +#GVCF parameters +base_err = 0.001 +gq_bin_size = 5 + +#Pileup input feature list +# 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 +channel = ('A', 'C', 'G', 'T', 'I', 'I1', 'D', 'D1', '*', 'a', 'c', 'g','t', 'i', 'i1','d', 'd1','#') +channel_size = len(channel) +flankingBaseNum = 16 +no_of_positions = 2 * flankingBaseNum + 1 +ont_input_shape = input_shape = [no_of_positions, channel_size] +label_shape = [21, 3, no_of_positions, no_of_positions] +label_size = sum(label_shape) +label_shape_cum = list(accumulate(label_shape)) +expandReferenceRegion = 1000 +SAMTOOLS_VIEW_FILTER_FLAG = 2316 +partition_size = 500000 +region_size =1000 +phasing_window_size = 30000 +extend_bp=10 + +#Training hyperparameters +chunk_size = 250 +trainBatchSize = 2000 +predictBatchSize = 200 +initialLearningRate = 1e-3 +trainingDatasetPercentage = 0.90 +l2RegularizationLambda = 0.0001 +maxEpoch = 30 +RANDOM_SEED = None +OPERATION_SEED = None diff --git a/benchmarks/nn-variant/Clair3/shared/utils.py b/benchmarks/nn-variant/Clair3/shared/utils.py new file mode 100644 index 0000000..730ef0f --- /dev/null +++ b/benchmarks/nn-variant/Clair3/shared/utils.py @@ -0,0 +1,227 @@ +import os +import sys +from os.path import isfile, abspath +from sys import exit, stderr +from subprocess import check_output, PIPE, Popen +import argparse +import shlex +from subprocess import PIPE +from os.path import isfile, isdir +# A->A +# C->C +# G->G +# T or U->T +# R->A or G +# Y->C or T +# S->G or C +# W->A or T +# K->G or T +# M->A or C +# B->C or G or T +# D->A or G or T +# H->A or C or T +# V->A or C or G +IUPAC_base_to_ACGT_base_dict = dict(zip( + "ACGTURYSWKMBDHVN", + ("A", "C", "G", "T", "T", "A", "C", "C", "A", "G", "A", "C", "A", "A", "A", "A") +)) + +IUPAC_base_to_num_dict = dict(zip( + "ACGTURYSWKMBDHVN", + (0, 1, 2, 3, 3, 0, 1, 1, 0, 2, 0, 1, 0, 0, 0, 0) +)) +BASIC_BASES = set("ACGTU") + +WARNING = '\033[93m' +ERROR = '\033[91m' +ENDC = '\033[0m' + +def log_error(log): + return ERROR + log + ENDC + +def log_warning(log): + return WARNING + log + ENDC + +def is_file_exists(file_name, suffix=""): + if not isinstance(file_name, str) or not isinstance(suffix, str): + return False + return isfile(file_name + suffix) + +def is_folder_exists(folder_name, suffix=""): + if not isinstance(folder_name, str) or not isinstance(suffix, str): + return False + return isdir(folder_name + suffix) + + +def legal_range_from(param_name, x, min_num=None, max_num=None, exit_out_of_range=False): + + if min_num is not None and x < min_num and exit_out_of_range: + exit(log_error("[ERROR] parameter --{}={} (minimum {}) out of range".format(param_name, x, min_num))) + if max_num is not None and x > max_num and exit_out_of_range: + exit(log_error("[ERROR] parameter --{}={} (maximum:{}) out of range".format(param_name, x, max_num))) + return + +def file_path_from(file_name, suffix="", exit_on_not_found=False, sep=""): + if is_file_exists(file_name, suffix): + return abspath(file_name + suffix) + #allow fn.bam.bai->fn.bai fn.fa.fai->fn.fai + elif sep != "" and len(sep) == 1: + file_name_remove_suffix = sep.join(file_name.split(sep)[:-1]) + if is_file_exists(file_name_remove_suffix, suffix): + return abspath(file_name_remove_suffix + suffix) + if exit_on_not_found: + exit(log_error("[ERROR] file %s not found" % (file_name + suffix))) + return None + +def folder_path_from(folder_name, create_not_found=True, exit_on_not_found=False): + if is_folder_exists(folder_name): + return abspath(folder_name) + if exit_on_not_found: + exit(log_error("[ERROR] folder %s not found" % (folder_name))) + if create_not_found: + if not os.path.exists(folder_name): + os.makedirs(abspath(folder_name)) + print("[INFO] Create folder %s" % (folder_name), file=stderr) + return abspath(folder_name) + return None + + +def is_command_exists(command): + if not isinstance(command, str): + return False + + try: + check_output("which %s" % (command), shell=True) + return True + except: + return False + + +def executable_command_string_from(command_to_execute, exit_on_not_found=False): + if is_command_exists(command_to_execute): + return command_to_execute + if exit_on_not_found: + exit(log_error("[ERROR] %s executable not found" % (command_to_execute))) + return None + + +def subprocess_popen(args, stdin=None, stdout=PIPE, stderr=stderr, bufsize=8388608): + return Popen(args, stdin=stdin, stdout=stdout, stderr=stderr, bufsize=bufsize, universal_newlines=True) + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def region_from(ctg_name, ctg_start=None, ctg_end=None): + """ + 1-based region string [start, end] + """ + if ctg_name is None: + return "" + if (ctg_start is None) != (ctg_end is None): + return "" + + if ctg_start is None and ctg_end is None: + return "{}".format(ctg_name) + return "{}:{}-{}".format(ctg_name, ctg_start, ctg_end) + +def reference_sequence_from(samtools_execute_command, fasta_file_path, regions): + refernce_sequences = [] + region_value_for_faidx = " ".join(regions) + + samtools_faidx_process = subprocess_popen( + shlex.split("{} faidx {} {}".format(samtools_execute_command, fasta_file_path, region_value_for_faidx)) + ) + while True: + row = samtools_faidx_process.stdout.readline() + is_finish_reading_output = row == '' and samtools_faidx_process.poll() is not None + if is_finish_reading_output: + break + if row: + refernce_sequences.append(row.rstrip()) + + # first line is reference name ">xxxx", need to be ignored + reference_sequence = "".join(refernce_sequences[1:]) + + # uppercase for masked sequences + reference_sequence = reference_sequence.upper() + + samtools_faidx_process.stdout.close() + samtools_faidx_process.wait() + if samtools_faidx_process.returncode != 0: + return None + + return reference_sequence + +def vcf_candidates_from(vcf_fn, contig_name=None): + + known_variants_set = set() + unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (vcf_fn))) + + start_pos, end_pos = float('inf'), 0 + for row in unzip_process.stdout: + if row[0] == '#': + continue + columns = row.strip().split(maxsplit=3) + ctg_name = columns[0] + + if contig_name and ctg_name != contig_name: + continue + center_pos = int(columns[1]) + known_variants_set.add(center_pos) + start_pos = min(start_pos, center_pos) + end_pos = max(center_pos, end_pos) + + known_variants_list = sorted(list(known_variants_set)) + return known_variants_list + +def candidate_position_generator_from( + candidate, + flanking_base_num, + begin_to_end +): + for position in candidate: + for i in range(position - (flanking_base_num + 1), position + (flanking_base_num + 1)): + if i not in begin_to_end: + begin_to_end[i] = [(position + (flanking_base_num + 1), position)] + else: + begin_to_end[i].append((position + (flanking_base_num + 1), position)) + yield position + yield -1 + + +def samtools_mpileup_generator_from( + candidate, + flanking_base_num, + begin_to_end +): + for position in candidate: + for i in range(position - (flanking_base_num + 1), position + (flanking_base_num + 1)): + if i not in begin_to_end: + begin_to_end[i] = [(position + (flanking_base_num + 1), position)] + else: + begin_to_end[i].append((position + (flanking_base_num + 1), position)) + yield position + yield -1 + +def samtools_view_process_from( + ctg_name, + ctg_start, + ctg_end, + samtools, + bam_file_path +): + have_start_and_end_position = ctg_start != None and ctg_end != None + region_str = ("%s:%d-%d" % (ctg_name, ctg_start, ctg_end)) if have_start_and_end_position else ctg_name + + return subprocess_popen( + shlex.split("%s view -F 2318 %s %s" % (samtools, bam_file_path, region_str)) + ) + diff --git a/benchmarks/nn-variant/README.md b/benchmarks/nn-variant/README.md index 8e60dc6..da5fe6e 100644 --- a/benchmarks/nn-variant/README.md +++ b/benchmarks/nn-variant/README.md @@ -1,17 +1,50 @@ -`nn-variant` uses the same license as [Clair](https://github.com/HKU-BAL/Clair). +# Neural Network-based Variant Calling (NN-VARIANT) -If you find `nn-variant` useful, please cite: +## Based on [Clair3](https://github.com/HKU-BAL/Clair3) +Symphonizing pileup and full-alignment for high-performance long-read variant calling + +If you find `NN-VARIANT` useful, please cite: ``` -@article{luo2020exploring, - title={Exploring the limit of using a deep neural network on pileup data for germline variant calling}, - author={Luo, Ruibang and Wong, Chak-Lim and Wong, Yat-Sing and Tang, Chi-Ian and Liu, Chi-Man and Leung, Chi-Ming and Lam, Tak-Wah}, - journal={Nature Machine Intelligence}, - volume={2}, - number={4}, - pages={220--227}, - year={2020}, - publisher={Nature Publishing Group} +@article{Zheng2021, + doi = {10.1101/2021.12.29.474431}, + url = {https://doi.org/10.1101/2021.12.29.474431}, + year = {2021}, + month = dec, + publisher = {Cold Spring Harbor Laboratory}, + author = {Zhenxian Zheng and Shumin Li and Junhao Su and Amy Wing-Sze Leung and Tak-Wah Lam and Ruibang Luo}, + title = {Symphonizing pileup and full-alignment for deep learning-based long-read variant calling} } ``` +## Installation + +NN-Variant requires the following Python3 packages (requirements.txt): + +| Package | Version | +| ----------------- | -------- | +| Python | 3.6.10 | +| TF | 2.1.0 | +| pypy | 3.6 | +| intervaltree | 3.0.2 | +| mpmath | 1.2.1 | +| tensorflow-addons | 0.11.2 | +| tables | 3.6.1 | +| pigz | 4.4 | +| parallel | 20191122 | +| zstd | 1.4.4 | +| samtools | 1.1.0 | +| whatshap | 1.0 | +| ensurepip | | + +## Execution + +``` +./Clair3/callVar.sh --bam_fn= --ref_fn= --threads= --platform={ont,hifi,ilmn}| \ + --model_path=\"r941_prom_hac_g360+g422\" \ + --bed_fn=region.bed \ + --output=. \ + --chunk_size=\$((${contig_len}/\${OMP_NUM_THREADS}+1)) \ + \$PYPY" +``` + diff --git a/benchmarks/nn-variant/clair/__pycache__/__init__.cpython-37.pyc b/benchmarks/nn-variant/clair/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 88260b1..0000000 Binary files a/benchmarks/nn-variant/clair/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/__pycache__/model.cpython-37.pyc b/benchmarks/nn-variant/clair/__pycache__/model.cpython-37.pyc deleted file mode 100644 index 1a9f63a..0000000 Binary files a/benchmarks/nn-variant/clair/__pycache__/model.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/__pycache__/selu.cpython-37.pyc b/benchmarks/nn-variant/clair/__pycache__/selu.cpython-37.pyc deleted file mode 100644 index 7b999fa..0000000 Binary files a/benchmarks/nn-variant/clair/__pycache__/selu.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/callVarBam.py b/benchmarks/nn-variant/clair/callVarBam.py deleted file mode 100644 index aa00c1f..0000000 --- a/benchmarks/nn-variant/clair/callVarBam.py +++ /dev/null @@ -1,332 +0,0 @@ -import sys -import shlex -import subprocess -import multiprocessing -import signal -import random -from os.path import dirname -from time import sleep -from argparse import ArgumentParser - -from shared.command_options import ( - CommandOption, - CommandOptionWithNoValue, - ExecuteCommand, - command_string_from, - command_option_from -) -from shared.utils import file_path_from, executable_command_string_from, subprocess_popen - - -class InstancesClass(object): - def __init__(self): - self.extract_variant_candidate = None - self.create_tensor = None - self.call_variant = None - - def poll(self): - self.extract_variant_candidate.poll() - self.create_tensor.poll() - self.call_variant.poll() - - -c = InstancesClass() - - -def check_return_code(signum, frame): - c.poll() - #print >> sys.stderr, c.extract_variant_candidate.returncode, c.create_tensor.returncode, c.call_variant.returncode - if c.extract_variant_candidate.returncode != None and c.extract_variant_candidate.returncode != 0: - c.create_tensor.kill() - c.call_variant.kill() - sys.exit("ExtractVariantCandidates.py or GetTruth.py exited with exceptions. Exiting...") - - if c.create_tensor.returncode != None and c.create_tensor.returncode != 0: - c.extract_variant_candidate.kill() - c.call_variant.kill() - sys.exit("CreateTensor.py exited with exceptions. Exiting...") - - if c.call_variant.returncode != None and c.call_variant.returncode != 0: - c.extract_variant_candidate.kill() - c.create_tensor.kill() - sys.exit("call_variant.py exited with exceptions. Exiting...") - - if ( - c.extract_variant_candidate.returncode == None or - c.create_tensor.returncode == None or - c.call_variant.returncode == None - ): - signal.alarm(5) - - -def Run(args): - basedir = dirname(__file__) - EVCBin = basedir + "/../clair.py ExtractVariantCandidates" - GTBin = basedir + "/../clair.py GetTruth" - CTBin = basedir + "/../clair.py CreateTensor" - CVBin = basedir + "/../clair.py call_var" - - pypyBin = executable_command_string_from(args.pypy, exit_on_not_found=True) - samtoolsBin = executable_command_string_from(args.samtools, exit_on_not_found=True) - - chkpnt_fn = file_path_from(args.chkpnt_fn, suffix=".meta", exit_on_not_found=True) - bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True) - ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True) - vcf_fn = file_path_from(args.vcf_fn) - bed_fn = file_path_from(args.bed_fn) - - dcov = args.dcov - call_fn = args.call_fn - af_threshold = args.threshold - minCoverage = int(args.minCoverage) - sampleName = args.sampleName - ctgName = args.ctgName - if ctgName is None: - sys.exit("--ctgName must be specified. You can call variants on multiple chromosomes simultaneously.") - - stop_consider_left_edge = command_option_from(args.stop_consider_left_edge, 'stop_consider_left_edge') - log_path = command_option_from(args.log_path, 'log_path', option_value=args.log_path) - pysam_for_all_indel_bases = command_option_from(args.pysam_for_all_indel_bases, 'pysam_for_all_indel_bases') - haploid_precision_mode = command_option_from(args.haploid_precision, 'haploid_precision') - haploid_sensitive_mode = command_option_from(args.haploid_sensitive, 'haploid_sensitive') - output_for_ensemble = command_option_from(args.output_for_ensemble, 'output_for_ensemble') - debug = command_option_from(args.debug, 'debug') - qual = command_option_from(args.qual, 'qual', option_value=args.qual) - fast_plotting = command_option_from(args.fast_plotting, 'fast_plotting') - - ctgStart = None - ctgEnd = None - if args.ctgStart is not None and args.ctgEnd is not None and int(args.ctgStart) <= int(args.ctgEnd): - ctgStart = CommandOption('ctgStart', args.ctgStart) - ctgEnd = CommandOption('ctgEnd', args.ctgEnd) - - if args.threads is None: - numCpus = multiprocessing.cpu_count() - else: - numCpus = args.threads if args.threads < multiprocessing.cpu_count() else multiprocessing.cpu_count() - - maxCpus = multiprocessing.cpu_count() - _cpuSet = ",".join(str(x) for x in random.sample(range(0, maxCpus), numCpus)) - - taskSet = "taskset -c %s" % (_cpuSet) - try: - subprocess.check_output("which %s" % ("taskset"), shell=True) - except: - taskSet = "" - - if args.delay > 0: - delay = random.randrange(0, args.delay) - print("Delay %d seconds before starting variant calling ..." % (delay), file=sys.stderr) - sleep(delay) - - extract_variant_candidate_command_options = [ - pypyBin, - EVCBin, - CommandOption('bam_fn', bam_fn), - CommandOption('ref_fn', ref_fn), - CommandOption('bed_fn', bed_fn), - CommandOption('ctgName', ctgName), - ctgStart, - ctgEnd, - CommandOption('threshold', af_threshold), - CommandOption('minCoverage', minCoverage), - CommandOption('samtools', samtoolsBin) - ] - get_truth_command_options = [ - pypyBin, - GTBin, - CommandOption('vcf_fn', vcf_fn), - CommandOption('ctgName', ctgName), - ctgStart, - ctgEnd - ] - - create_tensor_command_options = [ - pypyBin, - CTBin, - CommandOption('bam_fn', bam_fn), - CommandOption('ref_fn', ref_fn), - CommandOption('ctgName', ctgName), - ctgStart, - ctgEnd, - stop_consider_left_edge, - CommandOption('samtools', samtoolsBin), - CommandOption('dcov', dcov) - ] - - call_variant_command_options = [ - taskSet, - ExecuteCommand('python', CVBin), - CommandOption('chkpnt_fn', chkpnt_fn), - CommandOption('call_fn', call_fn), - CommandOption('bam_fn', bam_fn), - CommandOption('sampleName', sampleName), - CommandOption('threads', numCpus), - CommandOption('ref_fn', ref_fn), - pysam_for_all_indel_bases, - haploid_precision_mode, - haploid_sensitive_mode, - output_for_ensemble, - qual, - debug - ] - call_variant_with_activation_command_options = [ - CommandOptionWithNoValue('activation_only'), - log_path, - CommandOption('max_plot', args.max_plot), - CommandOption('parallel_level', args.parallel_level), - CommandOption('workers', args.workers), - fast_plotting, - ] if args.activation_only else [] - - is_true_variant_call = vcf_fn is not None - try: - c.extract_variant_candidate = subprocess_popen( - shlex.split(command_string_from( - get_truth_command_options if is_true_variant_call else extract_variant_candidate_command_options - )) - ) - - c.create_tensor = subprocess_popen( - shlex.split(command_string_from(create_tensor_command_options)), - stdin=c.extract_variant_candidate.stdout - ) - - c.call_variant = subprocess_popen( - shlex.split(command_string_from( - call_variant_command_options + call_variant_with_activation_command_options - )), - stdin=c.create_tensor.stdout, stdout=sys.stderr - ) - except Exception as e: - print(e, file=sys.stderr) - sys.exit("Failed to start required processes. Exiting...") - - signal.signal(signal.SIGALRM, check_return_code) - signal.alarm(2) - - try: - c.call_variant.wait() - c.create_tensor.stdout.close() - c.create_tensor.wait() - c.extract_variant_candidate.stdout.close() - c.extract_variant_candidate.wait() - except KeyboardInterrupt as e: - print("KeyboardInterrupt received when waiting at CallVarBam, terminating all scripts.") - try: - c.call_variant.terminate() - c.create_tensor.terminate() - c.extract_variant_candidate.terminate() - except Exception as e: - print(e) - - raise KeyboardInterrupt - except Exception as e: - print("Exception received when waiting at CallVarBam, terminating all scripts.") - print(e) - try: - c.call_variant.terminate() - c.create_tensor.terminate() - c.extract_variant_candidate.terminate() - except Exception as e: - print(e) - - raise e - - -def main(): - parser = ArgumentParser(description="Call variants using a trained model and a BAM file") - - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a model") - - parser.add_argument('--ref_fn', type=str, default="ref.fa", - help="Reference fasta file input, default: %(default)s") - - parser.add_argument('--bed_fn', type=str, default=None, - help="Call variant only in these regions, works in intersection with ctgName, ctgStart and ctgEnd, optional, default: as defined by ctgName, ctgStart and ctgEnd") - - parser.add_argument('--bam_fn', type=str, default="bam.bam", - help="BAM file input, default: %(default)s") - - parser.add_argument('--call_fn', type=str, default=None, - help="Output variant predictions") - - parser.add_argument('--vcf_fn', type=str, default=None, - help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") - - parser.add_argument('--threshold', type=float, default=0.125, - help="Minimum allele frequence of the 1st non-reference allele for a site to be considered as a condidate site, default: %(default)f") - - parser.add_argument('--minCoverage', type=float, default=4, - help="Minimum coverage required to call a variant, default: %(default)d") - - parser.add_argument('--qual', type=int, default=None, - help="If set, variant with equal or higher quality will be marked PASS, or LowQual otherwise, optional") - - parser.add_argument('--sampleName', type=str, default="SAMPLE", - help="Define the sample name to be shown in the VCF file") - - parser.add_argument('--ctgName', type=str, default=None, - help="The name of sequence to be processed, default: %(default)s") - parser.add_argument('--ctgStart', type=int, default=None, - help="The 1-based starting position of the sequence to be processed") - parser.add_argument('--ctgEnd', type=int, default=None, - help="The 1-based inclusive ending position of the sequence to be processed") - - parser.add_argument('--stop_consider_left_edge', action='store_true', - help="If not set, would consider left edge only. That is, count the left-most base-pairs of a read for coverage even if the starting position of a read is after the starting position of a tensor") - - parser.add_argument('--dcov', type=int, default=250, - help="Cap depth per position at %(default)s") - - parser.add_argument('--samtools', type=str, default="samtools", - help="Path to the 'samtools', default: %(default)s") - - parser.add_argument('--pypy', type=str, default="pypy3", - help="Path to the 'pypy', default: %(default)s") - - parser.add_argument('--threads', type=int, default=None, - help="Number of threads, optional") - - parser.add_argument('--delay', type=int, default=10, - help="Wait a short while for no more than %(default)s to start the job. This is to avoid starting multiple jobs simultaneously that might use up the maximum number of threads allowed, because Tensorflow will create more threads than needed at the beginning of running the program.") - - parser.add_argument('--debug', action='store_true', - help="Debug mode, optional") - - parser.add_argument('--pysam_for_all_indel_bases', action='store_true', - help="Always using pysam for outputting indel bases, optional") - - parser.add_argument('--haploid_precision', action='store_true', - help="call haploid instead of diploid (output homo-variant only)") - parser.add_argument('--haploid_sensitive', action='store_true', - help="call haploid instead of diploid (output non-multi-variant only)") - - parser.add_argument('--activation_only', action='store_true', - help="Output activation only, no prediction") - parser.add_argument('--max_plot', type=int, default=10, - help="The maximum number of plots output, negative number means no limit (plot all), default: %(default)s") - parser.add_argument('--log_path', type=str, nargs='?', default=None, - help="The path for tensorflow logging, default: %(default)s") - parser.add_argument('-p', '--parallel_level', type=int, default=2, - help="The level of parallelism in plotting (currently available: 0, 2), default: %(default)s") - parser.add_argument('--fast_plotting', action='store_true', - help="Enable fast plotting.") - parser.add_argument('-w', '--workers', type=int, default=8, - help="The number of workers in plotting, default: %(default)s") - - parser.add_argument('--output_for_ensemble', action='store_true', - help="Output for ensemble") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - Run(args) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/callVarBamParallel.py b/benchmarks/nn-variant/clair/callVarBamParallel.py deleted file mode 100644 index 19c3d6e..0000000 --- a/benchmarks/nn-variant/clair/callVarBamParallel.py +++ /dev/null @@ -1,219 +0,0 @@ -import os -import sys -import argparse - -from shared.command_options import ( - CommandOption, - CommandOptionWithNoValue, - ExecuteCommand, - command_string_from, - command_option_from -) -from shared.interval_tree import bed_tree_from, is_region_in -from shared.utils import file_path_from, executable_command_string_from - -major_contigs = {"chr"+str(a) for a in list(range(1, 23))+["X", "Y"]}.union({str(a) for a in list(range(1, 23))+["X", "Y"]}) - - -def Run(args): - basedir = os.path.dirname(__file__) - - callVarBamBin = basedir + "/../clair.py callVarBam" - pypyBin = executable_command_string_from(args.pypy, exit_on_not_found=True) - samtoolsBin = executable_command_string_from(args.samtools, exit_on_not_found=True) - - chkpnt_fn = file_path_from(args.chkpnt_fn, suffix=".meta", exit_on_not_found=True) - bam_fn = file_path_from(args.bam_fn, exit_on_not_found=True) - ref_fn = file_path_from(args.ref_fn, exit_on_not_found=True) - fai_fn = file_path_from(args.ref_fn + ".fai", exit_on_not_found=True) - bed_fn = file_path_from(args.bed_fn) - vcf_fn = file_path_from(args.vcf_fn) - - output_prefix = args.output_prefix - af_threshold = args.threshold - - tree = bed_tree_from(bed_file_path=bed_fn) - - minCoverage = args.minCoverage - sampleName = args.sampleName - delay = args.delay - threads = args.tensorflowThreads - qual = args.qual - is_include_all_contigs = args.includingAllContigs - region_chunk_size = args.refChunkSize - - stop_consider_left_edge = command_option_from(args.stop_consider_left_edge, 'stop_consider_left_edge') - log_path = command_option_from(args.log_path, 'log_path', option_value=args.log_path) - pysam_for_all_indel_bases = command_option_from(args.pysam_for_all_indel_bases, 'pysam_for_all_indel_bases') - haploid_precision_mode = command_option_from(args.haploid_precision, 'haploid_precision') - haploid_sensitive_mode = command_option_from(args.haploid_sensitive, 'haploid_sensitive') - output_for_ensemble = command_option_from(args.output_for_ensemble, 'output_for_ensemble') - debug = command_option_from(args.debug, 'debug') - qual = command_option_from(args.qual, 'qual', option_value=args.qual) - fast_plotting = command_option_from(args.fast_plotting, 'fast_plotting') - - call_var_bam_command_options = [ - ExecuteCommand('python', callVarBamBin), - CommandOption('chkpnt_fn', chkpnt_fn), - CommandOption('ref_fn', ref_fn), - CommandOption('bam_fn', bam_fn), - CommandOption('threshold', af_threshold), - CommandOption('minCoverage', minCoverage), - CommandOption('pypy', pypyBin), - CommandOption('samtools', samtoolsBin), - CommandOption('delay', delay), - CommandOption('threads', threads), - CommandOption('sampleName', sampleName), - # optional command options - CommandOption('vcf_fn', vcf_fn) if vcf_fn is not None else None, - qual, - stop_consider_left_edge, - debug, - pysam_for_all_indel_bases, - haploid_precision_mode, - haploid_sensitive_mode, - output_for_ensemble, - ] - - activation_only_command_options = [ - CommandOptionWithNoValue('activation_only'), - log_path, - CommandOption('max_plot', args.max_plot), - CommandOption('parallel_level', args.parallel_level), - CommandOption('workers', args.workers), - fast_plotting, - ] if args.activation_only else [] - - is_bed_file_provided = bed_fn is not None - command_string = command_string_from(call_var_bam_command_options + activation_only_command_options) - - with open(fai_fn, 'r') as fai_fp: - for row in fai_fp: - columns = row.strip().split("\t") - - contig_name = columns[0] - if not is_include_all_contigs and str(contig_name) not in major_contigs: - continue - - region_start, region_end = 0, 0 - contig_length = int(columns[1]) - while region_end < contig_length: - region_start = region_end - region_end = region_start + region_chunk_size - if region_end > contig_length: - region_end = contig_length - output_fn = "%s.%s_%d_%d.vcf" % (output_prefix, contig_name, region_start, region_end) - - is_region_in_bed = is_bed_file_provided and is_region_in(tree, contig_name, region_start, region_end) - need_output_command = not is_bed_file_provided or is_region_in_bed - if not need_output_command: - continue - - additional_command_options = [ - CommandOption('ctgName', contig_name), - CommandOption('ctgStart', region_start), - CommandOption('ctgEnd', region_end), - CommandOption('call_fn', output_fn), - CommandOption('bed_fn', bed_fn) if is_region_in_bed else None - ] - print(command_string + " " + command_string_from(additional_command_options)) - - -def main(): - parser = argparse.ArgumentParser( - description="Create commands for calling variants in parallel using a trained model and a BAM file") - - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a model") - - parser.add_argument('--ref_fn', type=str, default="ref.fa", - help="Reference fasta file input, default: %(default)s") - - parser.add_argument('--bed_fn', type=str, default=None, - help="Call variant only in these regions, optional, default: whole genome") - - parser.add_argument('--refChunkSize', type=int, default=10000000, - help="Divide job with smaller genome chunk size for parallelism, default: %(default)s") - - parser.add_argument('--bam_fn', type=str, default="bam.bam", - help="BAM file input, default: %(default)s") - - parser.add_argument('--vcf_fn', type=str, default=None, - help="Candidate sites VCF file input, if provided, variants will only be called at the sites in the VCF file, default: %(default)s") - - parser.add_argument('--output_prefix', type=str, default=None, - help="Output prefix") - - parser.add_argument('--includingAllContigs', action='store_true', - help="Call variants on all contigs, default: chr{1..22,X,Y,M,MT} and {1..22,X,Y,MT}") - - parser.add_argument('--tensorflowThreads', type=int, default=4, - help="Number of threads per tensorflow job, default: %(default)s") - - parser.add_argument('--threshold', type=float, default=0.2, - help="Minimum allele frequence of the 1st non-reference allele for a site to be considered as a condidate site, default: %(default)f") - - parser.add_argument('--minCoverage', type=float, default=4, - help="Minimum coverage required to call a variant, default: %(default)d") - - parser.add_argument('--qual', type=int, default=None, - help="If set, variant with equal or higher quality will be marked PASS, or LowQual otherwise, optional") - - parser.add_argument('--sampleName', type=str, default="SAMPLE", - help="Define the sample name to be shown in the VCF file") - - parser.add_argument('--stop_consider_left_edge', action='store_true', - help="If not set, would consider left edge only. That is, count the left-most base-pairs of a read for coverage even if the starting position of a read is after the starting position of a tensor") - - parser.add_argument('--samtools', type=str, default="samtools", - help="Path to the 'samtools', default: %(default)s") - - parser.add_argument('--pypy', type=str, default="pypy3", - help="Path to the 'pypy', default: %(default)s") - - parser.add_argument('--delay', type=int, default=10, - help="Wait a short while for no more than %(default)s to start the job. This is to avoid starting multiple jobs simultaneously that might use up the maximum number of threads allowed, because Tensorflow will create more threads than needed at the beginning of running the program.") - - parser.add_argument('--debug', action='store_true', - help="Debug mode, optional") - - parser.add_argument('--pysam_for_all_indel_bases', action='store_true', - help="Always using pysam for outputting indel bases, optional") - - parser.add_argument('--haploid_precision', action='store_true', - help="call haploid instead of diploid (output homo-variant only)") - parser.add_argument('--haploid_sensitive', action='store_true', - help="call haploid instead of diploid (output non-multi-variant only)") - - parser.add_argument('--activation_only', action='store_true', - help="Output activation only, no prediction") - parser.add_argument('--max_plot', type=int, default=10, - help="The maximum number of plots output, negative number means no limit (plot all), default: %(default)s") - parser.add_argument('--log_path', type=str, nargs='?', default=None, - help="The path for tensorflow logging, default: %(default)s") - parser.add_argument('-p', '--parallel_level', type=int, default=2, - help="The level of parallelism in plotting (currently available: 0, 2), default: %(default)s") - parser.add_argument('-w', '--workers', type=int, default=8, - help="The number of workers in plotting, default: %(default)s") - parser.add_argument('--fast_plotting', action='store_true', - help="Enable fast plotting.") - - parser.add_argument('--output_for_ensemble', action='store_true', - help="Output for ensemble") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - if not args.includingAllContigs: - print("echo \"[INFO] --includingAllContigs not enabled, use chr{1..22,X,Y,M,MT} and {1..22,X,Y,MT} by default\"\n") - else: - print("echo \"[INFO] --includingAllContigs enabled\"\n") - - Run(args) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/call_var.py b/benchmarks/nn-variant/clair/call_var.py deleted file mode 100644 index 36f466f..0000000 --- a/benchmarks/nn-variant/clair/call_var.py +++ /dev/null @@ -1,1439 +0,0 @@ -import sys -import os -import logging -import numpy as np -import pysam -from time import time -from argparse import ArgumentParser -from threading import Thread -from math import log, e -from enum import IntEnum -from collections import namedtuple, defaultdict - - -import clair.utils as utils -from clair.model import Clair -from clair.task.gt21 import ( - GT21_Type, gt21_enum_from_label, gt21_enum_from, - HOMO_SNP_GT21, HOMO_SNP_LABELS, - HETERO_SNP_GT21, HETERO_SNP_LABELS -) -from clair.task.genotype import Genotype, genotype_string_from, genotype_enum_from, genotype_enum_for_task -from clair.task.variant_length import VariantLength -from shared.utils import IUPAC_base_to_num_dict as BASE2NUM, IUPAC_base_to_ACGT_base_dict as BASE2ACGT, BASIC_BASES -import shared.param as param - - -logging.basicConfig(format='%(message)s', level=logging.INFO) -num2base = dict(zip((0, 1, 2, 3), "ACGT")) -minimum_variant_length_that_need_infer = VariantLength.max -maximum_variant_length_that_need_infer = 50 -inferred_indel_length_minimum_allele_frequency = 0.125 -flanking_base_number = param.flankingBaseNum - -OutputConfig = namedtuple('OutputConfig', [ - 'is_show_reference', - 'is_debug', - 'is_haploid_precision_mode_enabled', - 'is_haploid_sensitive_mode_enabled', - 'is_output_for_ensemble', - 'quality_score_for_pass', -]) -OutputUtilities = namedtuple('OutputUtilities', [ - 'print_debug_message', - 'insertion_bases_using', - 'deletion_bases_using', - 'insertion_bases_using_pysam_using', - 'output', - 'output_header', - 'close_opened_files', -]) - - -class Channel(IntEnum): - reference = 0 - insert = 1 - delete = 2 - SNP = 3 - - -def homo_SNP_bases_from(gt21_probabilities): - output_bases = HOMO_SNP_LABELS[np.argmax([gt21_probabilities[gt21_enum] for gt21_enum in HOMO_SNP_GT21])] - return output_bases[0], output_bases[1] - - -def hetero_SNP_bases_from(gt21_probabilities): - output_bases = HETERO_SNP_LABELS[np.argmax([gt21_probabilities[gt21_enum] for gt21_enum in HETERO_SNP_GT21])] - return output_bases[0], output_bases[1] - - -def filtration_value_from(quality_score_for_pass, quality_score): - if quality_score_for_pass is None: - return "." - if quality_score >= quality_score_for_pass: - return "PASS" - return "LowQual" - - -def pileup(sam_file, contig, position_start, position_end, func): - """ - Pileup using pysam - - sam_file: pysam.AlignmentFile for pileup - contig: chromosome name or contig name - position_start: start position. 0-based. Inclusive. - position_end: ending position. 0-based. Exclusive. - func: callback for pileup_column - """ - try: - for pileup_column in sam_file.pileup( - contig, - start=position_start, - stop=position_end, - flag_filter=param.SAMTOOLS_VIEW_FILTER_FLAG, - min_base_quality=0, - max_depth=250 - ): - func(pileup_column) - except AssertionError: - pass - - -def insertion_bases_using_pysam_from( - sam_file, - contig, - position, - minimum_insertion_length=1, - maximum_insertion_length=maximum_variant_length_that_need_infer, - insertion_bases_to_ignore="" -): - insertion_bases_dict = defaultdict(lambda: 0) - - def lambda_function(pileup_column): - if pileup_column.reference_pos != position - 1: - return - - for sequence in pileup_column.get_query_sequences(mark_matches=False, mark_ends=False, add_indels=True): - # minimum sequence needed: A+1A, and "+" for insertion - if len(sequence) < 4 or sequence[1] != "+": - continue - - no_of_insertion_bases = 0 - for (string_index, c) in enumerate(sequence[2:]): - if not c.isdigit(): - insertion_bases = sequence[string_index+2:].upper() - break - no_of_insertion_bases = no_of_insertion_bases * 10 + int(c) - - if ( - minimum_insertion_length <= no_of_insertion_bases <= maximum_insertion_length and - insertion_bases != insertion_bases_to_ignore - ): - insertion_bases_dict[insertion_bases] = insertion_bases_dict[insertion_bases] + 1 - pileup(sam_file, contig, position, position+1, func=lambda_function) - - return max(insertion_bases_dict, key=insertion_bases_dict.get) if len(insertion_bases_dict) > 0 else "" - - -def deletion_bases_using_pysam_from( - sam_file, - fasta_file, - contig, - position, - minimum_deletion_length=1, - maximum_deletion_length=maximum_variant_length_that_need_infer -): - deletion_bases_dict = defaultdict(lambda: 0) - - def lambda_function(pileup_column): - if pileup_column.reference_pos != position - 1: - return - - for sequence in pileup_column.get_query_sequences(mark_matches=False, mark_ends=False, add_indels=True): - # minimum sequence needed: A-1A, and "-" for deletion - if len(sequence) < 4 or sequence[1] != "-": - continue - - no_of_deletion_bases = 0 - for c in sequence[2:]: - if not c.isdigit(): - deletion_bases = fasta_file.fetch( - reference=contig, start=position, end=position + no_of_deletion_bases - ) - break - no_of_deletion_bases = no_of_deletion_bases * 10 + int(c) - - if minimum_deletion_length <= no_of_deletion_bases <= maximum_deletion_length: - deletion_bases_dict[deletion_bases] = deletion_bases_dict[deletion_bases] + 1 - pileup(sam_file, contig, position, position+1, func=lambda_function) - - return max(deletion_bases_dict, key=deletion_bases_dict.get) if len(deletion_bases_dict) > 0 else "" - - -def Run(args): - utils.setup_environment() - - os.environ["OMP_NUM_THREADS"] = "1" - os.environ["OPENBLAS_NUM_THREADS"] = "1" - os.environ["MKL_NUM_THREADS"] = "1" - os.environ["MKL_NUM_THREADS"] = "1" - os.environ["NUMEXPR_NUM_THREADS"] = "1" - - if args.threads == None: - if args.tensor_fn == "PIPE": - param.NUM_THREADS = 4 - else: - param.NUM_THREADS = args.threads - param.NUM_THREADS -= 1 - if param.NUM_THREADS < 1: - param.NUM_THREADS = 1 - - output_config = OutputConfig( - is_show_reference=args.showRef, - is_debug=args.debug, - is_haploid_precision_mode_enabled=args.haploid_precision, - is_haploid_sensitive_mode_enabled=args.haploid_sensitive, - is_output_for_ensemble=args.output_for_ensemble, - quality_score_for_pass=args.qual, - ) - output_utilities = output_utilties_from( - sample_name=args.sampleName, - is_debug=args.debug, - is_output_for_ensemble=args.output_for_ensemble, - is_using_pysam_for_all_indel_bases_output=args.pysam_for_all_indel_bases, - reference_file_path=args.ref_fn, - bam_file_path=args.bam_fn, - output_file_path=args.call_fn, - ) - - if args.input_probabilities: - call_variants_with_probabilities_input(args, output_config, output_utilities) - return - - m = Clair() - m.init() - m.restore_parameters(os.path.abspath(args.chkpnt_fn)) - - if args.activation_only: - log_activation(args, m) - else: - call_variants(args, m, output_config, output_utilities) - - -def output_utilties_from( - sample_name, - is_debug, - is_output_for_ensemble, - is_using_pysam_for_all_indel_bases_output, - bam_file_path, - reference_file_path, - output_file_path, -): - fasta_file = pysam.FastaFile(filename=reference_file_path) if reference_file_path else None - sam_file = pysam.AlignmentFile(bam_file_path, mode="rb") - output_file = open(output_file_path, "w") - - def output(string_value): - print(string_value, file=output_file) - - def print_debug_message( - chromosome, - position, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - extra_infomation_string="" - ): - if not is_debug: - return - - output("{}\t{}\t{}\t{}\t{}\t{}\t{}".format( - chromosome, - position, - ["{:0.8f}".format(x) for x in gt21_probabilities], - ["{:0.8f}".format(x) for x in genotype_probabilities], - ["{:0.8f}".format(x) for x in variant_length_probabilities_1], - ["{:0.8f}".format(x) for x in variant_length_probabilities_2], - extra_infomation_string - )) - - def insertion_bases_using(tensor_input, variant_length, contig, position): - return insertion_bases_from( - sam_file=sam_file, - tensor_input=tensor_input, - variant_length=variant_length, - contig=contig, - position=position, - is_using_pysam_for_all_indel_bases_output=is_using_pysam_for_all_indel_bases_output - ) - - def deletion_bases_using(tensor_input, variant_length, contig, position, reference_sequence): - return deletion_bases_from( - tensor_input=tensor_input, - variant_length=variant_length, - sam_file=sam_file, - fasta_file=fasta_file, - contig=contig, - position=position, - reference_sequence=reference_sequence, - is_using_pysam_for_all_indel_bases_output=is_using_pysam_for_all_indel_bases_output - ) - - def insertion_bases_using_pysam_using( - contig, - position, - minimum_insertion_length, - maximum_insertion_length, - insertion_bases_to_ignore - ): - return insertion_bases_using_pysam_from( - sam_file=sam_file, - contig=contig, - position=position, - minimum_insertion_length=minimum_insertion_length, - maximum_insertion_length=maximum_insertion_length, - insertion_bases_to_ignore=insertion_bases_to_ignore - ) - - def close_opened_files(): - sam_file.close() - fasta_file.close() - output_file.close() - - def output_header(): - if is_output_for_ensemble: - return - - from textwrap import dedent - output(dedent("""\ - ##fileformat=VCFv4.1 - ##FILTER= - ##FILTER= - ##ALT= - ##ALT= - ##INFO= - ##INFO= - ##FORMAT= - ##FORMAT= - ##FORMAT= - ##FORMAT=""" - )) - - if reference_file_path is not None: - reference_index_file_path = reference_file_path + ".fai" - with open(reference_index_file_path, "r") as fai_fp: - for row in fai_fp: - columns = row.strip().split("\t") - contig_name, contig_size = columns[0], columns[1] - output("##contig=" % (contig_name, contig_size)) - - output('#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\t%s' % (sample_name)) - - return OutputUtilities( - print_debug_message, - insertion_bases_using, - deletion_bases_using, - insertion_bases_using_pysam_using, - output, - output_header, - close_opened_files, - ) - - -def homo_Ins_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): - return [( - i, - variant_length_probabilities_1[i + VariantLength.index_offset] * - variant_length_probabilities_2[i + VariantLength.index_offset] * extra_probability - ) for i in range(1, VariantLength.max + 1)] - - -def hetero_Ins_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): - return [( - i, - max( - variant_length_probabilities_1[0 + VariantLength.index_offset] * - variant_length_probabilities_2[i + VariantLength.index_offset], - variant_length_probabilities_1[i + VariantLength.index_offset] * - variant_length_probabilities_2[0 + VariantLength.index_offset], - ) - ) for i in range(1, VariantLength.max + 1)] - - -def hetero_InsIns_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): - probabilities = [] - for i in range(1, VariantLength.max + 1): - for j in range(1, VariantLength.max + 1): - # note: one kind of InsIns is same # of insertion bases but different kind of ACGT - probabilities.append(( - (i, j) if i <= j else (j, i), - variant_length_probabilities_1[i + VariantLength.index_offset] * - variant_length_probabilities_2[j + VariantLength.index_offset] * extra_probability - )) - return probabilities - - -def homo_Del_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): - return [( - i, - variant_length_probabilities_1[-i + VariantLength.index_offset] * - variant_length_probabilities_2[-i + VariantLength.index_offset] * extra_probability - ) for i in range(1, VariantLength.max + 1)] - - -def hetero_Del_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): - return [( - i, - max( - variant_length_probabilities_1[0 + VariantLength.index_offset] * - variant_length_probabilities_2[-i + VariantLength.index_offset], - variant_length_probabilities_1[-i + VariantLength.index_offset] * - variant_length_probabilities_2[0 + VariantLength.index_offset], - ) - ) for i in range(1, VariantLength.max + 1)] - - -def hetero_DelDel_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): - probabilities = [] - for i in range(1, VariantLength.max + 1): - for j in range(1, VariantLength.max + 1): - if i == j: - continue - probabilities.append(( - (i, j) if i < j else (j, i), - variant_length_probabilities_1[-i + VariantLength.index_offset] * - variant_length_probabilities_2[-j + VariantLength.index_offset] * extra_probability - )) - return probabilities - - -def hetero_InsDel_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2, extra_probability): - probabilities = [] - for i in range(1, VariantLength.max + 1): - for j in range(1, VariantLength.max + 1): - probabilities.append(( - (j, i), - variant_length_probabilities_1[i + VariantLength.index_offset] * - variant_length_probabilities_2[-j + VariantLength.index_offset] * extra_probability - )) - probabilities.append(( - (i, j), - variant_length_probabilities_1[-i + VariantLength.index_offset] * - variant_length_probabilities_2[j + VariantLength.index_offset] * extra_probability - )) - return probabilities - - -def inferred_insertion_bases_from(tensor_input): - insertion_bases = "" - for position in range(flanking_base_number + 1, 2 * flanking_base_number + 1): - reference_tensor = tensor_input[position, :, Channel.reference] - insertion_tensor = np.copy(tensor_input[position, :, Channel.insert]) - for base_index in range(0, 4): - insertion_tensor[base_index] = insertion_tensor[base_index] + insertion_tensor[base_index + 4] - insertion_tensor[base_index + 4] = 0 - insertion_tensor[base_index] -= ( - tensor_input[position, base_index, Channel.SNP] + tensor_input[position, base_index + 4, Channel.SNP] - ) - - if ( - position < (flanking_base_number + minimum_variant_length_that_need_infer) or - sum(insertion_tensor) >= inferred_indel_length_minimum_allele_frequency * sum(reference_tensor) - ): - insertion_bases += num2base[np.argmax(insertion_tensor) % 4] - else: - break - return insertion_bases - - -def inferred_deletion_length_from(tensor_input): - deletion_length = 0 - for position in range(flanking_base_number + 1, 2 * flanking_base_number + 1): - reference_tensor = tensor_input[position, :, Channel.reference] - deletion_tensor = tensor_input[position, :, Channel.delete] - if ( - position < (flanking_base_number + minimum_variant_length_that_need_infer) or - sum(deletion_tensor) >= inferred_indel_length_minimum_allele_frequency * sum(reference_tensor) - ): - deletion_length += 1 - else: - break - return deletion_length - - -def insertion_bases_using_tensor(tensor_input, variant_length): - insertion_bases = "" - for position in range(flanking_base_number + 1, flanking_base_number + variant_length + 1): - insertion_tensor = np.copy(tensor_input[position, :, Channel.insert]) - for base_index in range(0, 4): - insertion_tensor[base_index] = insertion_tensor[base_index] + insertion_tensor[base_index + 4] - insertion_tensor[base_index + 4] = 0 - insertion_tensor[base_index] -= ( - tensor_input[position, base_index, Channel.SNP] + tensor_input[position, base_index + 4, Channel.SNP] - ) - - insertion_bases += num2base[np.argmax(insertion_tensor) % 4] - return insertion_bases - - -def maximum_variant_length_from(variant_length): - if variant_length >= minimum_variant_length_that_need_infer: - return maximum_variant_length_that_need_infer - else: - return variant_length - - -def insertion_bases_from( - tensor_input, - variant_length, - sam_file, - contig, - position, - is_using_pysam_for_all_indel_bases_output -): - """ - Return (insertion_bases, insertion bases length) tuple - """ - if is_using_pysam_for_all_indel_bases_output: - insertion_bases = insertion_bases_using_pysam_from( - sam_file=sam_file, - contig=contig, - position=position, - minimum_insertion_length=variant_length, - maximum_insertion_length=maximum_variant_length_from(variant_length) - ) - return insertion_bases, len(insertion_bases) - - need_inferred_variant_length = variant_length >= minimum_variant_length_that_need_infer - if not need_inferred_variant_length: - insertion_bases = insertion_bases_using_tensor(tensor_input, variant_length) - return insertion_bases, len(insertion_bases) - - insertion_bases = insertion_bases_using_pysam_from( - sam_file=sam_file, - contig=contig, - position=position, - minimum_insertion_length=minimum_variant_length_that_need_infer - ) - insertion_length = len(insertion_bases) - if insertion_length > 0: - return insertion_bases, insertion_length - else: - insertion_bases = inferred_insertion_bases_from(tensor_input) - return insertion_bases, len(insertion_bases) - - -def deletion_bases_from( - tensor_input, - variant_length, - sam_file, - fasta_file, - contig, - position, - reference_sequence, - is_using_pysam_for_all_indel_bases_output -): - """ - Return (deletion_bases, deletion bases length) tuple - """ - if is_using_pysam_for_all_indel_bases_output: - deletion_bases = deletion_bases_using_pysam_from( - sam_file=sam_file, - fasta_file=fasta_file, - contig=contig, - position=position, - minimum_deletion_length=variant_length, - maximum_deletion_length=maximum_variant_length_from(variant_length) - ) - return deletion_bases, len(deletion_bases) - - deletion_bases = "" - need_inferred_variant_length = variant_length >= minimum_variant_length_that_need_infer - if need_inferred_variant_length: - deletion_bases = deletion_bases_using_pysam_from( - sam_file=sam_file, - fasta_file=fasta_file, - contig=contig, - position=position, - minimum_deletion_length=minimum_variant_length_that_need_infer - ) - - have_long_deletion_bases = need_inferred_variant_length and len(deletion_bases) >= flanking_base_number - if not have_long_deletion_bases: - deletion_bases = reference_sequence[flanking_base_number + 1:flanking_base_number + variant_length + 1] - return deletion_bases, len(deletion_bases) - - -def quality_score_from( - reference, - alternate, - genotype_string, - gt21_probabilities, - genotype_probabilities, -): - genotype_1, genotype_2 = int(genotype_string[0]), int(genotype_string[2]) - - gt21 = gt21_enum_from(reference, alternate, genotype_1, genotype_2) - genotype = genotype_enum_for_task(genotype_enum_from(genotype_1, genotype_2)) - - p = gt21_probabilities[gt21] * genotype_probabilities[genotype] - tmp = max( - (-10 * log(e, 10)) * log(((1.0 - p) + 1e-300) / (p + 1e-300)) + 16, - 0 - ) - - return int(round(tmp * tmp)) - - -def possible_outcome_probabilites_from( - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - reference_base, -): - homo_reference_probability = genotype_probabilities[Genotype.homo_reference] - homo_variant_probability = genotype_probabilities[Genotype.homo_variant] - hetero_variant_probability = genotype_probabilities[Genotype.hetero_variant] - variant_length_0_probability = ( - variant_length_probabilities_1[0 + VariantLength.index_offset] * - variant_length_probabilities_2[0 + VariantLength.index_offset] - ) - - reference_gt21 = gt21_enum_from_label(reference_base + reference_base) - homo_Ref_probability = ( - variant_length_0_probability * homo_reference_probability * gt21_probabilities[reference_gt21] - ) - - homo_SNP_probabilities = [( - variant_length_0_probability * homo_variant_probability * gt21_probabilities[gt21] - ) for gt21 in HOMO_SNP_GT21] - hetero_SNP_probabilities = [( - variant_length_0_probability * hetero_variant_probability * gt21_probabilities[gt21] - ) for gt21 in HETERO_SNP_GT21] - - # Insertion - homo_Ins_lengths, homo_Ins_probabilities = zip(*homo_Ins_tuples_from( - variant_length_probabilities_1, variant_length_probabilities_2, - homo_variant_probability * gt21_probabilities[GT21_Type.InsIns] - )) - homo_Ins_lengths, homo_Ins_probabilities = list(homo_Ins_lengths), list(homo_Ins_probabilities) - hetero_InsIns_length_tuples, hetero_InsIns_probabilities = zip(*hetero_InsIns_tuples_from( - variant_length_probabilities_1, variant_length_probabilities_2, - hetero_variant_probability * gt21_probabilities[GT21_Type.InsIns] - )) - hetero_InsIns_length_tuples, hetero_InsIns_probabilities = ( - list(hetero_InsIns_length_tuples), list(hetero_InsIns_probabilities) - ) - hetero_ACGT_Ins_tuples = [] - gt21_base_tuples = [(GT21_Type.AIns, "A"), (GT21_Type.CIns, "C"), (GT21_Type.GIns, "G"), (GT21_Type.TIns, "T")] - for length_tuples, p in hetero_Ins_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): - for gt21, hetero_base in gt21_base_tuples: - hetero_ACGT_Ins_tuples.append(( - hetero_base, - length_tuples, - p * gt21_probabilities[gt21] * hetero_variant_probability - )) - hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities = zip(*hetero_ACGT_Ins_tuples) - hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities = ( - list(hetero_ACGT_Ins_bases), list(hetero_ACGT_Ins_lengths), list(hetero_ACGT_Ins_probabilities) - ) - - # Deletion - homo_Del_lengths, homo_Del_probabilities = zip(*homo_Del_tuples_from( - variant_length_probabilities_1, variant_length_probabilities_2, - homo_variant_probability * gt21_probabilities[GT21_Type.DelDel] - )) - homo_Del_lengths, homo_Del_probabilities = list(homo_Del_lengths), list(homo_Del_probabilities) - hetero_DelDel_length_tuples, hetero_DelDel_probabilities = zip(*hetero_DelDel_tuples_from( - variant_length_probabilities_1, variant_length_probabilities_2, - hetero_variant_probability * gt21_probabilities[GT21_Type.DelDel] - )) - hetero_DelDel_length_tuples, hetero_DelDel_probabilities = ( - list(hetero_DelDel_length_tuples), list(hetero_DelDel_probabilities) - ) - hetero_ACGT_Del_tuples = [] - gt21_base_tuples = [(GT21_Type.ADel, "A"), (GT21_Type.CDel, "C"), (GT21_Type.GDel, "G"), (GT21_Type.TDel, "T")] - for length_tuples, p in hetero_Del_tuples_from(variant_length_probabilities_1, variant_length_probabilities_2): - for gt21, hetero_base in gt21_base_tuples: - hetero_ACGT_Del_tuples.append(( - hetero_base, - length_tuples, - p * gt21_probabilities[gt21] * hetero_variant_probability - )) - hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities = zip(*hetero_ACGT_Del_tuples) - hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities = ( - list(hetero_ACGT_Del_bases), list(hetero_ACGT_Del_lengths), list(hetero_ACGT_Del_probabilities) - ) - - # InsDel - hetero_InsDel_length_tuples, hetero_InsDel_probabilities = zip(*hetero_InsDel_tuples_from( - variant_length_probabilities_1, variant_length_probabilities_2, - hetero_variant_probability * gt21_probabilities[GT21_Type.InsDel] - )) - hetero_InsDel_length_tuples, hetero_InsDel_probabilities = ( - list(hetero_InsDel_length_tuples), list(hetero_InsDel_probabilities) - ) - - return ( - homo_Ref_probability, - homo_SNP_probabilities, - hetero_SNP_probabilities, - homo_Ins_lengths, homo_Ins_probabilities, - hetero_InsIns_length_tuples, hetero_InsIns_probabilities, - hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities, - homo_Del_lengths, homo_Del_probabilities, - hetero_DelDel_length_tuples, hetero_DelDel_probabilities, - hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities, - hetero_InsDel_length_tuples, hetero_InsDel_probabilities, - ) - - -def output_from( - x, - reference_sequence, - contig, - position, - tensor_position_center, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - output_config, - output_utilities, -): - insertion_bases_using, deletion_bases_using, insertion_bases_using_pysam_using = ( - output_utilities.insertion_bases_using, - output_utilities.deletion_bases_using, - output_utilities.insertion_bases_using_pysam_using, - ) - - reference_base_ACGT = BASE2ACGT[reference_sequence[tensor_position_center]] - ( - homo_Ref_probability, - homo_SNP_probabilities, - hetero_SNP_probabilities, - homo_Ins_lengths, homo_Ins_probabilities, - hetero_InsIns_length_tuples, hetero_InsIns_probabilities, - hetero_ACGT_Ins_bases, hetero_ACGT_Ins_lengths, hetero_ACGT_Ins_probabilities, - homo_Del_lengths, homo_Del_probabilities, - hetero_DelDel_length_tuples, hetero_DelDel_probabilities, - hetero_ACGT_Del_bases, hetero_ACGT_Del_lengths, hetero_ACGT_Del_probabilities, - hetero_InsDel_length_tuples, hetero_InsDel_probabilities, - ) = possible_outcome_probabilites_from( - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - reference_base=reference_base_ACGT, - ) - - reference_base, alternate_base = None, None - while reference_base is None or alternate_base is None: - maximum_probability = max( - homo_Ref_probability, - max(homo_SNP_probabilities), - max(hetero_SNP_probabilities), - max(homo_Ins_probabilities) if len(homo_Ins_probabilities) else 0, - max(homo_Del_probabilities) if len(homo_Del_probabilities) else 0, - max(hetero_ACGT_Ins_probabilities) if len(hetero_ACGT_Ins_probabilities) else 0, - max(hetero_InsIns_probabilities) if len(hetero_InsIns_probabilities) else 0, - max(hetero_ACGT_Del_probabilities) if len(hetero_ACGT_Del_probabilities) else 0, - max(hetero_DelDel_probabilities) if len(hetero_DelDel_probabilities) else 0, - max(hetero_InsDel_probabilities) if len(hetero_InsDel_probabilities) else 0, - ) - - is_reference = maximum_probability == homo_Ref_probability - if is_reference: - return ( - (True, False, False, False, False, False, False, False, False, False), - (reference_base_ACGT, reference_base_ACGT) - ) - - is_homo_SNP = maximum_probability in homo_SNP_probabilities - is_hetero_SNP = maximum_probability in hetero_SNP_probabilities - is_homo_insertion = maximum_probability in homo_Ins_probabilities - is_hetero_ACGT_Ins = maximum_probability in hetero_ACGT_Ins_probabilities - is_hetero_InsIns = maximum_probability in hetero_InsIns_probabilities - is_homo_deletion = maximum_probability in homo_Del_probabilities - is_hetero_ACGT_Del = maximum_probability in hetero_ACGT_Del_probabilities - is_hetero_DelDel = maximum_probability in hetero_DelDel_probabilities - is_insertion_and_deletion = maximum_probability in hetero_InsDel_probabilities - - if is_homo_SNP: - base1, base2 = homo_SNP_bases_from(gt21_probabilities) - reference_base = reference_sequence[tensor_position_center] - alternate_base = base1 if base1 != reference_base else base2 - - elif is_hetero_SNP: - base1, base2 = hetero_SNP_bases_from(gt21_probabilities) - reference_base = reference_sequence[tensor_position_center] - is_multi = base1 != reference_base and base2 != reference_base - if is_multi: - alternate_base = "{},{}".format(base1, base2) - else: - alternate_base = base1 if base1 != reference_base else base2 - - elif is_homo_insertion: - idx = homo_Ins_probabilities.index(maximum_probability) - variant_length = homo_Ins_lengths[idx] - del homo_Ins_probabilities[idx] - del homo_Ins_lengths[idx] - - insertion_bases, insertion_length = insertion_bases_using( - tensor_input=x, variant_length=variant_length, contig=contig, position=position - ) - if insertion_length == 0: - continue - reference_base = reference_sequence[tensor_position_center] - alternate_base = reference_base + insertion_bases - - elif is_hetero_ACGT_Ins: - idx = hetero_ACGT_Ins_probabilities.index(maximum_probability) - variant_length = hetero_ACGT_Ins_lengths[idx] - hetero_Ins_base = hetero_ACGT_Ins_bases[idx] - del hetero_ACGT_Ins_probabilities[idx] - del hetero_ACGT_Ins_lengths[idx] - del hetero_ACGT_Ins_bases[idx] - - insertion_bases, insertion_length = insertion_bases_using( - tensor_input=x, variant_length=variant_length, contig=contig, position=position - ) - if insertion_length == 0: - continue - reference_base = reference_sequence[tensor_position_center] - alternate_base = reference_base + insertion_bases - - is_SNP_Ins_multi = hetero_Ins_base != reference_base - if is_SNP_Ins_multi: - alternate_base = "{},{}".format(hetero_Ins_base, alternate_base) - - elif is_hetero_InsIns: - idx = hetero_InsIns_probabilities.index(maximum_probability) - variant_length_1, variant_length_2 = hetero_InsIns_length_tuples[idx] - del hetero_InsIns_probabilities[idx] - del hetero_InsIns_length_tuples[idx] - - insertion_bases, insertion_length = insertion_bases_using( - tensor_input=x, variant_length=variant_length_2, contig=contig, position=position - ) - if insertion_length == 0: - continue - reference_base = reference_sequence[tensor_position_center] - alternate_base = reference_base + insertion_bases - - another_insertion_bases = ( - insertion_bases_using_pysam_using( - contig=contig, - position=position, - minimum_insertion_length=variant_length_1, - maximum_insertion_length=maximum_variant_length_from(variant_length_1), - insertion_bases_to_ignore=insertion_bases - ) or - insertion_bases[0:variant_length_1] - ) - alternate_base_1 = reference_base + another_insertion_bases - alternate_base_2 = alternate_base - if alternate_base_1 != alternate_base_2: - alternate_base = "{},{}".format(alternate_base_1, alternate_base_2) - else: - reference_base, alternate_base = None, None - - elif is_homo_deletion: - idx = homo_Del_probabilities.index(maximum_probability) - variant_length = homo_Del_lengths[idx] - del homo_Del_probabilities[idx] - del homo_Del_lengths[idx] - - deletion_bases, deletion_length = deletion_bases_using( - tensor_input=x, - variant_length=variant_length, - contig=contig, - position=position, - reference_sequence=reference_sequence, - ) - if deletion_length == 0: - continue - reference_base = reference_sequence[tensor_position_center] + deletion_bases - alternate_base = reference_base[0] - - elif is_hetero_ACGT_Del: - idx = hetero_ACGT_Del_probabilities.index(maximum_probability) - variant_length = hetero_ACGT_Del_lengths[idx] - hetero_Del_base = hetero_ACGT_Del_bases[idx] - del hetero_ACGT_Del_probabilities[idx] - del hetero_ACGT_Del_lengths[idx] - del hetero_ACGT_Del_bases[idx] - - deletion_bases, deletion_length = deletion_bases_using( - tensor_input=x, - variant_length=variant_length, - contig=contig, - position=position, - reference_sequence=reference_sequence, - ) - if deletion_length == 0: - continue - reference_base = reference_sequence[tensor_position_center] + deletion_bases - alternate_base = reference_base[0] - - is_SNP_Del_multi = hetero_Del_base != reference_base[0] - if is_SNP_Del_multi: - alternate_base_1 = alternate_base - alternate_base_2 = hetero_Del_base + reference_base[1:] - alternate_base = "{},{}".format(alternate_base_1, alternate_base_2) - - elif is_hetero_DelDel: - idx = hetero_DelDel_probabilities.index(maximum_probability) - variant_length_1, variant_length_2 = hetero_DelDel_length_tuples[idx] - del hetero_DelDel_probabilities[idx] - del hetero_DelDel_length_tuples[idx] - - deletion_bases, deletion_length = deletion_bases_using( - tensor_input=x, - variant_length=variant_length_2, - contig=contig, - position=position, - reference_sequence=reference_sequence, - ) - if deletion_length == 0: - continue - reference_base = reference_sequence[tensor_position_center] + deletion_bases - alternate_base = reference_base[0] - - alternate_base_1 = alternate_base - alternate_base_2 = reference_base[0] + reference_base[variant_length_1 + 1:] - if ( - alternate_base_1 != alternate_base_2 and - reference_base != alternate_base_1 and reference_base != alternate_base_2 - ): - alternate_base = "{},{}".format(alternate_base_1, alternate_base_2) - else: - reference_base, alternate_base = None, None - - elif is_insertion_and_deletion: - idx = hetero_InsDel_probabilities.index(maximum_probability) - variant_length_1, variant_length_2 = hetero_InsDel_length_tuples[idx] - del hetero_InsDel_probabilities[idx] - del hetero_InsDel_length_tuples[idx] - - insertion_bases, insertion_length = insertion_bases_using( - tensor_input=x, variant_length=variant_length_2, contig=contig, position=position - ) - deletion_bases, deletion_length = deletion_bases_using( - tensor_input=x, - variant_length=variant_length_1, - contig=contig, - position=position, - reference_sequence=reference_sequence, - ) - if insertion_length == 0 or deletion_length == 0: - continue - reference_base = reference_sequence[tensor_position_center] + deletion_bases - alternate_base = "{},{}".format( - reference_base[0], - reference_base[0] + insertion_bases + reference_base[1:] - ) - - return ( - ( - is_reference, is_homo_SNP, is_hetero_SNP, - is_homo_insertion, is_hetero_ACGT_Ins, is_hetero_InsIns, - is_homo_deletion, is_hetero_ACGT_Del, is_hetero_DelDel, - is_insertion_and_deletion - ), - (reference_base, alternate_base) - ) - - -def batch_output_for_ensemble(mini_batch, batch_Y, output_config, output_utilities): - X, batch_chr_pos_seq = mini_batch - batch_size = len(batch_chr_pos_seq) - - batch_gt21_probabilities, batch_genotype_probabilities, \ - batch_variant_length_probabilities_1, batch_variant_length_probabilities_2 = batch_Y - - if len(batch_gt21_probabilities) != batch_size: - sys.exit( - "Inconsistent shape between input tensor and output predictions %d/%d" % - (batch_size, len(batch_gt21_probabilities)) - ) - - tensor_position_center = flanking_base_number - - for ( - x, - chr_pos_seq, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2 - ) in zip( - X, - batch_chr_pos_seq, - batch_gt21_probabilities, - batch_genotype_probabilities, - batch_variant_length_probabilities_1, - batch_variant_length_probabilities_2 - ): - chromosome, position, reference_sequence = chr_pos_seq - - if reference_sequence[tensor_position_center] not in BASIC_BASES: - continue - - tensor = x.flatten().astype(int).astype(str) - - output_utilities.output( - "\t".join( - [ - chromosome, - position, - reference_sequence, - ] + - list(tensor) + - ["{:0.6f}".format(p) for p in list(gt21_probabilities)] + - ["{:0.6f}".format(p) for p in list(genotype_probabilities)] + - ["{:0.6f}".format(p) for p in list(variant_length_probabilities_1)] + - ["{:0.6f}".format(p) for p in list(variant_length_probabilities_2)] - ) - ) - -def output_with( - x, - chr_pos_seq, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - output_config, - output_utilities -): - chromosome, position, reference_sequence = chr_pos_seq - position = int(position) - - tensor_position_center = flanking_base_number - information_string = "." - - if reference_sequence[tensor_position_center] not in BASIC_BASES: - return - - # read depth - read_depth = sum( - x[tensor_position_center, :, Channel.delete] + x[tensor_position_center, :, Channel.reference] - ) - if read_depth == 0: - output_utilities.print_debug_message( - chromosome, - position, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - "Read Depth is zero" - ) - return - - ( - is_reference, is_homo_SNP, is_hetero_SNP, - is_homo_insertion, is_hetero_ACGT_Ins, is_hetero_InsIns, - is_homo_deletion, is_hetero_ACGT_Del, is_hetero_DelDel, - is_insertion_and_deletion - ), (reference_base, alternate_base) = output_from( - x, - reference_sequence, - chromosome, - position, - tensor_position_center, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - output_config, - output_utilities, - ) - - if not output_config.is_debug and ( - (not output_config.is_show_reference and is_reference) or - (not is_reference and reference_base == alternate_base) - ): - return - - if reference_base is None or alternate_base is None: - output_utilities.print_debug_message( - chromosome, - position, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - "no reference base / alternate base prediction" - ) - return - - is_multi = "," in str(alternate_base) - - # haploid (precision mode) - if output_config.is_haploid_precision_mode_enabled and ( - is_hetero_SNP or is_hetero_ACGT_Ins or is_hetero_InsIns or - is_hetero_ACGT_Del or is_hetero_DelDel or is_insertion_and_deletion - ): - return - # haploid (sensitive mode) - elif output_config.is_haploid_sensitive_mode_enabled and is_multi: - return - - # geno type string - if is_reference: - genotype_string = genotype_string_from(Genotype.homo_reference) - elif is_homo_SNP or is_homo_insertion or is_homo_deletion: - genotype_string = genotype_string_from(Genotype.homo_variant) - elif is_hetero_SNP or is_hetero_ACGT_Ins or is_hetero_InsIns or is_hetero_ACGT_Del or is_hetero_DelDel: - genotype_string = genotype_string_from(Genotype.hetero_variant) - if is_multi: - genotype_string = genotype_string_from(Genotype.hetero_variant_multi) - - # allele frequency / supported reads - supported_reads_count = 0 - if is_reference: - supported_reads_count = ( - x[tensor_position_center, BASE2NUM[reference_base], Channel.reference] + - x[tensor_position_center, BASE2NUM[reference_base]+4, Channel.reference] - ) - elif is_homo_SNP or is_hetero_SNP: - for base in str(alternate_base): - if base == ',': - continue - supported_reads_count += ( - x[tensor_position_center, BASE2NUM[base], Channel.SNP] + - x[tensor_position_center, BASE2NUM[base]+4, Channel.SNP] + - x[tensor_position_center, BASE2NUM[base], Channel.reference] + - x[tensor_position_center, BASE2NUM[base]+4, Channel.reference] - ) - elif is_homo_insertion or is_hetero_InsIns: - supported_reads_count = ( - sum(x[tensor_position_center+1, :, Channel.insert]) - - sum(x[tensor_position_center+1, :, Channel.SNP]) - ) - elif is_hetero_ACGT_Ins: - is_SNP_Ins_multi = is_multi - SNP_base = alternate_base.split(",")[0][0] if is_SNP_Ins_multi else None - supported_reads_for_SNP = ( - x[tensor_position_center, BASE2NUM[SNP_base], Channel.SNP] + - x[tensor_position_center, BASE2NUM[SNP_base]+4, Channel.SNP] + - x[tensor_position_center, BASE2NUM[SNP_base], Channel.reference] + - x[tensor_position_center, BASE2NUM[SNP_base]+4, Channel.reference] - ) if is_SNP_Ins_multi else 0 - - supported_reads_count = ( - sum(x[tensor_position_center+1, :, Channel.insert]) - - sum(x[tensor_position_center+1, :, Channel.SNP]) - ) + supported_reads_for_SNP - elif is_homo_deletion or is_hetero_DelDel: - supported_reads_count = sum(x[tensor_position_center+1, :, Channel.delete]) - elif is_hetero_ACGT_Del: - is_SNP_Del_multi = is_multi - SNP_base = alternate_base.split(",")[1][0] if is_SNP_Del_multi else None - supported_reads_for_SNP = ( - x[tensor_position_center, BASE2NUM[SNP_base], Channel.SNP] + - x[tensor_position_center, BASE2NUM[SNP_base]+4, Channel.SNP] + - x[tensor_position_center, BASE2NUM[SNP_base], Channel.reference] + - x[tensor_position_center, BASE2NUM[SNP_base]+4, Channel.reference] - ) if is_SNP_Del_multi else 0 - - supported_reads_count = sum(x[tensor_position_center+1, :, Channel.delete]) + supported_reads_for_SNP - elif is_insertion_and_deletion: - supported_reads_count = ( - sum(x[tensor_position_center+1, :, Channel.insert]) + - sum(x[tensor_position_center+1, :, Channel.delete]) - - sum(x[tensor_position_center+1, :, Channel.SNP]) - ) - allele_frequency = ((supported_reads_count + 0.0) / read_depth) if read_depth != 0 else 0.0 - if allele_frequency > 1: - allele_frequency = 1 - - # quality score - quality_score = quality_score_from( - reference_base, - alternate_base, - genotype_string, - gt21_probabilities, - genotype_probabilities, - ) - - # replace genotype string if any haploid mode enabled - if output_config.is_haploid_precision_mode_enabled or output_config.is_haploid_sensitive_mode_enabled: - genotype_string = "1" if "1" in genotype_string else "0" - - # filtration value - filtration_value = filtration_value_from( - quality_score_for_pass=output_config.quality_score_for_pass, quality_score=quality_score - ) - - if output_config.is_debug: - output_utilities.print_debug_message( - chromosome, - position, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - "Normal output" if not is_reference else "Reference" - ) - else: - output_utilities.output("%s\t%d\t.\t%s\t%s\t%d\t%s\t%s\tGT:GQ:DP:AF\t%s:%d:%d:%.4f" % ( - chromosome, - position, - reference_base, - alternate_base, - quality_score, - filtration_value, - information_string, - genotype_string, - quality_score, - read_depth, - allele_frequency - )) - - -def batch_output(mini_batch, batch_Y, output_config, output_utilities): - X, batch_chr_pos_seq = mini_batch - batch_size = len(batch_chr_pos_seq) - - batch_gt21_probabilities, batch_genotype_probabilities, \ - batch_variant_length_probabilities_1, batch_variant_length_probabilities_2 = batch_Y - - if len(batch_gt21_probabilities) != batch_size: - sys.exit( - "Inconsistent shape between input tensor and output predictions %d/%d" % - (batch_size, len(batch_gt21_probabilities)) - ) - - for ( - x, - chr_pos_seq, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2 - ) in zip( - X, - batch_chr_pos_seq, - batch_gt21_probabilities, - batch_genotype_probabilities, - batch_variant_length_probabilities_1, - batch_variant_length_probabilities_2 - ): - output_with( - x, - chr_pos_seq, - gt21_probabilities, - genotype_probabilities, - variant_length_probabilities_1, - variant_length_probabilities_2, - output_config, - output_utilities, - ) - - -def log_activation(args, m): - if args.log_path is None: - return - - summary_writer = m.get_summary_file_writer(args.log_path) - if summary_writer is None: - return - - tensor_generator = utils.tensor_generator_from(args.tensor_fn, param.predictBatchSize) - logging.info("Plotting activations ...") - - num_plotted = 0 - while num_plotted < args.max_plot or args.max_plot < 0: - print("Getting next batch") - try: - batch_X, batch_chr_pos_seq = next(tensor_generator) - except StopIteration: - break - batch_size = len(batch_chr_pos_seq) - print("Batch generation complete %d" % batch_size) - # strip away the reference string, keeping the chr and coor only - batch_chr_pos_seq = [chr+":"+pos for chr, pos, _ in batch_chr_pos_seq] - summaries = m.get_activation_summary( - batch_X, - operations=m.layers, - batch_item_suffixes=batch_chr_pos_seq, - max_plot_in_batch=args.max_plot - num_plotted if args.max_plot >= 0 else batch_size, - parallel_level=args.parallel_level, - num_workers=args.workers, - fast_plotting=args.fast_plotting - ) - for summary in summaries: - summary_writer.add_summary(summary) - num_plotted += min(batch_size, args.max_plot - num_plotted if args.max_plot >= 0 else batch_size) - print("Finished plotting %d" % num_plotted) - - -def call_variants_with_probabilities_input(args, output_config, output_utilities): - output_utilities.output_header() - logging.info("Output variants ...") - variant_call_start_time = time() - - tensor_dimensions = (2*param.flankingBaseNum+1, param.matrixRow, param.matrixNum) - no_of_tensor_values = tensor_dimensions[0] * tensor_dimensions[1] * tensor_dimensions[2] - - for row in sys.stdin: - columns = row.split("\t") - - chromosome = columns[0] - position = columns[1] - sequence = columns[2] - x = np.reshape(np.array(columns[3:3 + no_of_tensor_values], dtype=np.float32), tensor_dimensions) - probabilities = np.array(columns[3+no_of_tensor_values:], dtype=np.float32) - gt21_probabilities = probabilities[0:21] - genotype_probabilities = probabilities[21:21+3] - variant_length_1_probabilities = probabilities[21+3:21+3+tensor_dimensions[0]] - variant_length_2_probabilities = probabilities[21+3+tensor_dimensions[0]:] - - output_with( - x, - (chromosome, position, sequence), - gt21_probabilities, - genotype_probabilities, - variant_length_1_probabilities, - variant_length_2_probabilities, - output_config, - output_utilities, - ) - - logging.info("Total time elapsed: %.2f s" % (time() - variant_call_start_time)) - output_utilities.close_opened_files() - - -def call_variants(args, m, output_config, output_utilities): - output_utilities.output_header() - - tensor_generator = utils.tensor_generator_from(args.tensor_fn, param.predictBatchSize) - logging.info("Calling variants ...") - variant_call_start_time = time() - - is_finish_loaded_all_mini_batches = False - batch_output_method = batch_output_for_ensemble if output_config.is_output_for_ensemble else batch_output - mini_batches_loaded = [] - mini_batches_to_predict = [] - mini_batches_to_output = [] - - def load_mini_batch(): - try: - mini_batches_loaded.append(next(tensor_generator)) - except StopIteration: - return - - while True: - thread_pool = [] - - if len(mini_batches_to_output) > 0: - mini_batch = mini_batches_to_output.pop(0) - thread_pool.append(Thread( - target=batch_output_method, args=(mini_batch, m.prediction, output_config, output_utilities) - )) - - if len(mini_batches_to_predict) > 0: - mini_batch = mini_batches_to_predict.pop(0) - X, _ = mini_batch - thread_pool.append(Thread(target=m.predict, kwargs={"batchX":X})) - mini_batches_to_output.append(mini_batch) - - if not is_finish_loaded_all_mini_batches: - thread_pool.append(Thread(target=load_mini_batch)) - - for t in thread_pool: - t.start() - for t in thread_pool: - t.join() - - is_finish_loaded_all_mini_batches = len(mini_batches_loaded) == 0 - while len(mini_batches_loaded) > 0: - mini_batch = mini_batches_loaded.pop(0) - mini_batches_to_predict.append(mini_batch) - - is_nothing_to_predict_and_output = ( - len(thread_pool) <= 0 and len(mini_batches_to_predict) <= 0 and len(mini_batches_to_output) <= 0 - ) - if is_finish_loaded_all_mini_batches and is_nothing_to_predict_and_output: - break - - logging.info("Total time elapsed: %.2f s" % (time() - variant_call_start_time)) - - output_utilities.close_opened_files() - - -def main(): - parser = ArgumentParser(description="Call variants using a trained model and tensors of candididate variants") - - parser.add_argument('--tensor_fn', type=str, default="PIPE", - help="Tensor input, use PIPE for standard input") - - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a checkpoint for testing") - - parser.add_argument('--call_fn', type=str, default=None, - help="Output variant predictions") - - parser.add_argument('--bam_fn', type=str, default="bam.bam", - help="BAM file input, default: %(default)s") - - parser.add_argument('--qual', type=int, default=None, - help="If set, variant with equal or higher quality will be marked PASS, or LowQual otherwise, optional") - - parser.add_argument('--sampleName', type=str, default="SAMPLE", - help="Define the sample name to be shown in the VCF file") - - parser.add_argument('--showRef', action='store_true', - help="Show reference calls, optional") - - parser.add_argument('--debug', action='store_true', - help="Debug mode, optional") - - parser.add_argument('--ref_fn', type=str, default=None, - help="Reference fasta file input, optional, print contig tags in the VCF header if set") - - parser.add_argument('--threads', type=int, default=None, - help="Number of threads, optional") - - parser.add_argument('--activation_only', action='store_true', - help="Output activation only, no prediction") - parser.add_argument('--max_plot', type=int, default=10, - help="The maximum number of plots output, negative number means no limit (plot all), default: %(default)s") - parser.add_argument('--log_path', type=str, nargs='?', default=None, - help="The path for tensorflow logging, default: %(default)s") - parser.add_argument('-p', '--parallel_level', type=int, default=2, - help="The level of parallelism in plotting (currently available: 0, 2), default: %(default)s") - parser.add_argument('--fast_plotting', action='store_true', - help="Enable fast plotting.") - parser.add_argument('-w', '--workers', type=int, default=8, - help="The number of workers in plotting, default: %(default)s") - - parser.add_argument('--pysam_for_all_indel_bases', action='store_true', - help="Always using pysam for outputting indel bases, optional") - - parser.add_argument('--haploid_precision', action='store_true', - help="call haploid instead of diploid (output homo-variant only)") - parser.add_argument('--haploid_sensitive', action='store_true', - help="call haploid instead of diploid (output non-multi-variant only)") - - parser.add_argument('--input_probabilities', action='store_true', - help="Accept probabilities as input, using those probabilities to call variant") - parser.add_argument('--output_for_ensemble', action='store_true', - help="Output for ensemble") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - Run(args) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/evaluate.py b/benchmarks/nn-variant/clair/evaluate.py deleted file mode 100644 index ff0b909..0000000 --- a/benchmarks/nn-variant/clair/evaluate.py +++ /dev/null @@ -1,218 +0,0 @@ -import sys -import logging -import numpy as np -from os.path import abspath -from argparse import ArgumentParser -from time import time - - -from clair.model import Clair -import clair.utils as utils -from clair.task.main import GT21, GENOTYPE, VARIANT_LENGTH_1, VARIANT_LENGTH_2 -import shared.param as param - - -logging.basicConfig(format='%(message)s', level=logging.INFO) - - -def f1_score(confusion_matrix): - column_sum = confusion_matrix.sum(axis=0) - row_sum = confusion_matrix.sum(axis=1) - - f1_score_array = np.array([]) - matrix_size = confusion_matrix.shape[0] - epsilon = 1e-15 - for i in range(matrix_size): - TP = confusion_matrix[i][i] + 0.0 - precision = TP / (column_sum[i] + epsilon) - recall = TP / (row_sum[i] + epsilon) - f1_score_array = np.append(f1_score_array, (2.0 * precision * recall) / (precision + recall + epsilon)) - - return f1_score_array - - -def new_confusion_matrix_with_dimension(size): - return np.zeros((size, size), dtype=np.int) - - -def evaluate_model(m, dataset_info): - dataset_size = dataset_info.dataset_size - x_array_compressed = dataset_info.x_array_compressed - y_array_compressed = dataset_info.y_array_compressed - - logging.info("[INFO] Testing on the training and validation dataset ...") - prediction_start_time = time() - prediction_batch_size = param.predictBatchSize - - no_of_training_examples = ( - dataset_info.no_of_training_examples_from_train_binary or int(dataset_size * param.trainingDatasetPercentage) - ) - no_of_blosc_blocks = utils.no_of_blosc_blocks_from( - dataset_info=dataset_info, - no_of_training_examples=no_of_training_examples, - blosc_block_size=param.bloscBlockSize - ) - - blosc_index = 0 - first_blosc_block_data_index = 0 - - confusion_matrix_gt21 = new_confusion_matrix_with_dimension(GT21.output_label_count) - confusion_matrix_genotype = new_confusion_matrix_with_dimension(GENOTYPE.output_label_count) - confusion_matrix_indel_length_1 = new_confusion_matrix_with_dimension(VARIANT_LENGTH_1.output_label_count) - confusion_matrix_indel_length_2 = new_confusion_matrix_with_dimension(VARIANT_LENGTH_2.output_label_count) - - all_gt21_count = top_1_count = top_2_count = 0 - - while True: - x_batch, next_x_first_blosc_block_data_index, next_x_blosc_index = utils.decompress_array( - array=x_array_compressed, - blosc_start_index=blosc_index, - first_blosc_block_data_index=first_blosc_block_data_index, - no_of_data_rows_to_retrieve=prediction_batch_size, - no_of_blosc_blocks=no_of_blosc_blocks, - ) - y_batch, _next_y_first_blosc_block_data_index, _next_y_blosc_index = utils.decompress_array( - array=y_array_compressed, - blosc_start_index=blosc_index, - first_blosc_block_data_index=first_blosc_block_data_index, - no_of_data_rows_to_retrieve=prediction_batch_size, - no_of_blosc_blocks=no_of_blosc_blocks, - ) - minibatch_gt21_prediction, minibatch_genotype_prediction, \ - minibatch_indel_length_prediction_1, minibatch_indel_length_prediction_2 = m.predict(x_batch) - - blosc_index = next_x_blosc_index - first_blosc_block_data_index = next_x_first_blosc_block_data_index - - # update confusion matrix for gt21 prediction - for gt21_prediction, gt21_label in zip( - minibatch_gt21_prediction, - y_batch[:, GT21.y_start_index:GT21.y_end_index] - ): - true_label_index = np.argmax(gt21_label) - predict_label_index = np.argmax(gt21_prediction) - confusion_matrix_gt21[true_label_index][predict_label_index] += 1 - - all_gt21_count += 1 - indexes_with_sorted_prediction_probability = gt21_prediction.argsort()[::-1] - if true_label_index == indexes_with_sorted_prediction_probability[0]: - top_1_count += 1 - top_2_count += 1 - elif true_label_index == indexes_with_sorted_prediction_probability[1]: - top_2_count += 1 - - # update confusion matrix for genotype - for genotype_prediction, true_genotype_label in zip( - minibatch_genotype_prediction, - y_batch[:, GENOTYPE.y_start_index:GENOTYPE.y_end_index] - ): - confusion_matrix_genotype[np.argmax(true_genotype_label)][np.argmax(genotype_prediction)] += 1 - - # update confusion matrix for indel length 1 and 2 - for indel_length_prediction_1, true_indel_length_label_1, indel_length_prediction_2, true_indel_length_label_2 in zip( - minibatch_indel_length_prediction_1, - y_batch[:, VARIANT_LENGTH_1.y_start_index:VARIANT_LENGTH_1.y_end_index], - minibatch_indel_length_prediction_2, - y_batch[:, VARIANT_LENGTH_2.y_start_index:VARIANT_LENGTH_2.y_end_index] - ): - true_label_index_1 = np.argmax(true_indel_length_label_1) - true_label_index_2 = np.argmax(true_indel_length_label_2) - predict_label_index_1 = np.argmax(indel_length_prediction_1) - predict_label_index_2 = np.argmax(indel_length_prediction_2) - - if true_label_index_1 > true_label_index_2: - true_label_index_1, true_label_index_2 = true_label_index_2, true_label_index_1 - if predict_label_index_1 > predict_label_index_2: - predict_label_index_1, predict_label_index_2 = predict_label_index_2, predict_label_index_1 - - confusion_matrix_indel_length_1[true_label_index_1][predict_label_index_1] += 1 - confusion_matrix_indel_length_2[true_label_index_2][predict_label_index_2] += 1 - - if not (next_x_first_blosc_block_data_index >= 0 and next_x_blosc_index >= 0): - break - - logging.info("[INFO] Prediciton time elapsed: %.2f s" % (time() - prediction_start_time)) - - print("[INFO] Evaluation on gt21:") - print("[INFO] all/top1/top2/top1p/top2p: %d/%d/%d/%.2f/%.2f" % - (all_gt21_count, top_1_count, top_2_count, - float(top_1_count)/all_gt21_count*100, float(top_2_count)/all_gt21_count*100)) - for i in range(GT21.output_label_count): - print("\t".join([str(confusion_matrix_gt21[i][j]) for j in range(GT21.output_label_count)])) - gt21_f_measure = f1_score(confusion_matrix_gt21) - print("[INFO] f-measure: ", gt21_f_measure) - - print("\n[INFO] Evaluation on Genotype:") - for i in range(GENOTYPE.output_label_count): - print("\t".join([str(confusion_matrix_genotype[i][j]) for j in range(GENOTYPE.output_label_count)])) - genotype_f_measure = f1_score(confusion_matrix_genotype) - print("[INFO] f-measure: ", genotype_f_measure) - - print("\n[INFO] evaluation on indel length 1:") - for i in range(VARIANT_LENGTH_1.output_label_count): - print("\t".join([str(confusion_matrix_indel_length_1[i][j]) - for j in range(VARIANT_LENGTH_1.output_label_count)])) - indel_length_f_measure_1 = f1_score(confusion_matrix_indel_length_1) - print("[INFO] f-measure: ", indel_length_f_measure_1) - - print("\n[INFO] evaluation on indel length 2:") - for i in range(VARIANT_LENGTH_2.output_label_count): - print("\t".join([str(confusion_matrix_indel_length_2[i][j]) - for j in range(VARIANT_LENGTH_2.output_label_count)])) - indel_length_f_measure_2 = f1_score(confusion_matrix_indel_length_2) - print("[INFO] f-measure: ", indel_length_f_measure_2) - - -def main(): - parser = ArgumentParser(description="Evaluate trained model") - - parser.add_argument('--bin_fn', type=str, default=None, - help="Binary tensor input generated by tensor2Bin.py, tensor_fn, var_fn and bed_fn will be ignored") - parser.add_argument('--train_bin_fn', type=str, default=None, - help="Train Binary, used together with --validation_bin_fn (would ignore: bin_fn, tensor_fn, var_fn, bed_fn)") - parser.add_argument('--validation_bin_fn', type=str, default=None, - help="Validation Binary, used together with --train_bin_fn (would ignore: bin_fn, tensor_fn, var_fn, bed_fn)") - - parser.add_argument('--tensor_fn', type=str, default="vartensors", - help="Tensor input") - - parser.add_argument('--var_fn', type=str, default="truthvars", - help="Truth variants list input") - - parser.add_argument('--bed_fn', type=str, default=None, - help="High confident genome regions input in the BED format") - - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a checkpoint for testing, REQUIRED") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - # initialize - logging.info("[INFO] Loading model ...") - utils.setup_environment() - - m = Clair() - m.init() - - dataset_info = utils.dataset_info_from( - binary_file_path=args.bin_fn, - tensor_file_path=args.tensor_fn, - variant_file_path=args.var_fn, - bed_file_path=args.bed_fn, - train_binary_file_path=args.train_bin_fn, - validation_binary_file_path=args.validation_bin_fn, - ) - - model_initalization_file_path = args.chkpnt_fn - m.restore_parameters(abspath(model_initalization_file_path)) - - # start evaluation - evaluate_model(m, dataset_info) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/learning_rate_finder.py b/benchmarks/nn-variant/clair/learning_rate_finder.py deleted file mode 100644 index fed4a17..0000000 --- a/benchmarks/nn-variant/clair/learning_rate_finder.py +++ /dev/null @@ -1,348 +0,0 @@ -import sys -import logging -import random -import numpy as np -import pandas as pd -from os.path import abspath -from time import time -from argparse import ArgumentParser -from threading import Thread - - -import clair.evaluate as evaluate -from clair.model import Clair -import clair.utils as utils -from clair.task.main import GT21, GENOTYPE, VARIANT_LENGTH_1, VARIANT_LENGTH_2 -import shared.param as param - -logging.basicConfig(format='%(message)s', level=logging.INFO) - - -def accuracy(y_pred, y_true): - gt21, genotype, indel_length_1, indel_length_2 = y_pred - batch_size = len(gt21) + 0.0 - gt21_TP = 0 - genotype_TP = 0 - indel1_TP = 0 - indel2_TP = 0 - - for gt21_prediction, gt21_true_label in zip( - gt21, - y_true[:, GT21.y_start_index:GT21.y_end_index] - ): - true_label_index = np.argmax(gt21_true_label) - predict_label_index = np.argmax(gt21_prediction) - if true_label_index == predict_label_index: - gt21_TP += 1 - - for genotype_prediction, true_genotype_label in zip( - genotype, - y_true[:, GENOTYPE.y_start_index:GENOTYPE.y_end_index] - ): - true_label_index = np.argmax(true_genotype_label) - predict_label_index = np.argmax(genotype_prediction) - if true_label_index == predict_label_index: - genotype_TP += 1 - - for indel_length_prediction_1, true_indel_length_label_1, indel_length_prediction_2, true_indel_length_label_2 in zip( - indel_length_1, - y_true[:, VARIANT_LENGTH_1.y_start_index:VARIANT_LENGTH_1.y_end_index], - indel_length_2, - y_true[:, VARIANT_LENGTH_2.y_start_index:VARIANT_LENGTH_2.y_end_index] - ): - true_label_index_1 = np.argmax(true_indel_length_label_1) - true_label_index_2 = np.argmax(true_indel_length_label_2) - predict_label_index_1 = np.argmax(indel_length_prediction_1) - predict_label_index_2 = np.argmax(indel_length_prediction_2) - - if true_label_index_1 > true_label_index_2: - true_label_index_1, true_label_index_2 = true_label_index_2, true_label_index_1 - if predict_label_index_1 > predict_label_index_2: - predict_label_index_1, predict_label_index_2 = predict_label_index_2, predict_label_index_1 - - if true_label_index_1 == predict_label_index_1: - indel1_TP += 1 - if true_label_index_2 == predict_label_index_2: - indel2_TP += 1 - - gt21_acc = gt21_TP / batch_size - genotype_acc = genotype_TP / batch_size - indel1_acc = indel1_TP / batch_size - indel2_acc = indel2_TP / batch_size - acc = (gt21_acc + genotype_acc + indel1_acc + indel2_acc) / 4 - return acc - - -def lr_finder(lr_accuracy): - df = pd.DataFrame(lr_accuracy, columns=["lr", "accuracy", "loss"]) - df['diff'] = df['accuracy'].diff() - df = df.dropna().reset_index(drop=True) - minimum_lr = df[df['diff'] == max(df['diff'])]['lr'].sort_values(ascending=False).item() - maximum_lr = df[df['diff'] == min(df['diff'])]['lr'].sort_values(ascending=True).item() - if minimum_lr > maximum_lr: - minimum_lr, maximum_lr = maximum_lr, minimum_lr - return minimum_lr, maximum_lr, df - - -logging.basicConfig(format='%(message)s', level=logging.INFO) - - -def shuffle_first_n_items(array, n): - if len(array) <= n: - np.random.shuffle(array) - return array - # pylint: disable=unbalanced-tuple-unpacking - a1, a2 = np.split(array, [n]) - np.random.shuffle(a1) - return np.append(a1, a2) - - -def train_model(m, training_config): - learning_rate = param.min_lr - l2_regularization_lambda = training_config.l2_regularization_lambda - output_file_path_prefix = training_config.output_file_path_prefix - summary_writer = training_config.summary_writer - model_initalization_file_path = training_config.model_initalization_file_path - - dataset_info = training_config.dataset_info - dataset_size = dataset_info.dataset_size - - training_losses = [] - validation_losses = [] - lr_accuracy = [] - - if model_initalization_file_path is not None: - m.restore_parameters(abspath(model_initalization_file_path)) - - logging.info("[INFO] Start training...") - logging.info("[INFO] Learning rate: %.2e" % m.set_learning_rate(learning_rate)) - logging.info("[INFO] L2 regularization lambda: %.2e" % m.set_l2_regularization_lambda(l2_regularization_lambda)) - - # Model Constants - training_start_time = time() - no_of_training_examples = ( - dataset_info.no_of_training_examples_from_train_binary or int(dataset_size * param.trainingDatasetPercentage) - ) - no_of_validation_examples = dataset_info.dataset_size - no_of_training_examples - no_of_blosc_blocks = utils.no_of_blosc_blocks_from( - dataset_info=dataset_info, - no_of_training_examples=no_of_training_examples, - blosc_block_size=param.bloscBlockSize - ) - no_of_training_blosc_blocks = int(no_of_training_examples / param.bloscBlockSize) - tensor_block_index_list = np.arange(no_of_blosc_blocks, dtype=int) - - total_numbers_of_iterations = np.ceil(no_of_training_examples / param.trainBatchSize+1) - step_size = param.stepsizeConstant * total_numbers_of_iterations - - # Initialize variables - epoch_count = 1 - if model_initalization_file_path is not None: - epoch_count = int(model_initalization_file_path[-param.parameterOutputPlaceHolder:])+1 - - global_step = 0 - - mini_batches_loaded = [] - - def load_mini_batch(data_index, blosc_index, first_blosc_block_data_index, tensor_block_index_list): - mini_batch = utils.new_mini_batch( - data_index=data_index, - blosc_start_index=blosc_index, - first_blosc_block_data_index=first_blosc_block_data_index, - no_of_training_examples=no_of_training_examples, - no_of_blosc_blocks=no_of_blosc_blocks, - dataset_info=dataset_info, - tensor_block_index_list=tensor_block_index_list, - ) - _, _, next_first_blosc_block_data_index, next_blosc_start_index = mini_batch - if next_first_blosc_block_data_index < 0 or next_blosc_start_index < 0: - return - mini_batches_loaded.append(mini_batch) - - while epoch_count <= param.lr_finder_max_epoch: - # init variables for process one epoch - epoch_start_time = time() - training_loss_sum = 0 - validation_loss_sum = 0 - data_index = 0 - blosc_index = 0 - first_blosc_block_data_index = 0 - x_batch, y_batch = None, None - - gt21_loss_sum = 0 - genotype_loss_sum = 0 - indel_length_loss_sum_1 = 0 - indel_length_loss_sum_2 = 0 - l2_loss_sum = 0 - - while True: - is_with_batch_data = x_batch is not None and y_batch is not None - is_training = is_with_batch_data and data_index < no_of_training_examples - is_validation = is_with_batch_data and not is_training - - thread_pool = [] - if is_training: - thread_pool.append(Thread(target=m.train, args=(x_batch, y_batch))) - elif is_validation: - thread_pool.append(Thread(target=m.validate, args=(x_batch, y_batch))) - thread_pool.append( - Thread( - target=load_mini_batch, - args=(data_index, blosc_index, first_blosc_block_data_index, tensor_block_index_list) - ) - ) - - for t in thread_pool: - t.start() - for t in thread_pool: - t.join() - - # add training loss or validation loss - if is_training: - training_loss_sum += m.training_loss_on_one_batch - batch_acc = accuracy(y_pred=m.prediction, y_true=y_batch) - lr_accuracy.append((learning_rate, batch_acc, m.training_loss_on_one_batch)) - if summary_writer is not None: - summary = m.training_summary_on_one_batch - summary_writer.add_summary(summary, epoch_count) - elif is_validation: - validation_loss_sum += m.validation_loss_on_one_batch - - gt21_loss_sum += m.gt21_loss - genotype_loss_sum += m.genotype_loss - indel_length_loss_sum_1 += m.indel_length_loss_1 - indel_length_loss_sum_2 += m.indel_length_loss_2 - l2_loss_sum += m.l2_loss - - if is_with_batch_data: - data_index += np.shape(x_batch)[0] - - have_next_mini_batch = len(mini_batches_loaded) > 0 - is_processed_a_mini_batch = len(thread_pool) > 0 - - if have_next_mini_batch: - x_batch, y_batch, first_blosc_block_data_index, blosc_index = mini_batches_loaded.pop(0) - learning_rate, global_step, _max_learning_rate = m.clr( - global_step, step_size, param.max_lr, "tri" - ) - if not have_next_mini_batch and not is_processed_a_mini_batch: - break - - logging.info( - " ".join([str(epoch_count), "Training loss:", str(training_loss_sum/no_of_training_examples)]) - ) - logging.info( - "\t".join([ - "{} Validation loss (Total/Base/Genotype/Indel_1_2):".format(epoch_count), - str(validation_loss_sum/no_of_validation_examples), - str(gt21_loss_sum/no_of_validation_examples), - str(genotype_loss_sum/no_of_validation_examples), - str(indel_length_loss_sum_1/no_of_validation_examples), - str(indel_length_loss_sum_2/no_of_validation_examples) - ]) - ) - - logging.info("[INFO] Epoch time elapsed: %.2f s" % (time() - epoch_start_time)) - training_losses.append((training_loss_sum, epoch_count)) - validation_losses.append((validation_loss_sum, epoch_count)) - - # Output the model - if output_file_path_prefix != None: - parameter_output_path = "%s-%%0%dd" % (output_file_path_prefix, param.parameterOutputPlaceHolder) - m.save_parameters(abspath(parameter_output_path % epoch_count)) - - # variables update per epoch - epoch_count += 1 - minimum_lr, maximum_lr, df = lr_finder(lr_accuracy) - logging.info("[INFO] min_lr: %g, max_lr: %g" % (minimum_lr, maximum_lr)) - df.to_csv("lr_finder.txt", sep=',', index=False) - - # shuffle data on each epoch - tensor_block_index_list = shuffle_first_n_items(tensor_block_index_list, no_of_training_blosc_blocks) - logging.info("[INFO] Shuffled: " + ' '.join( - [str(x) for x in np.append(tensor_block_index_list[:5], tensor_block_index_list[-5:])] - )) - - logging.info("[INFO] Training time elapsed: %.2f s" % (time() - training_start_time)) - return training_losses, validation_losses - - -if __name__ == "__main__": - - random.seed(param.RANDOM_SEED) - np.random.seed(param.RANDOM_SEED) - - parser = ArgumentParser(description="Learning rate finder") - - # binary file path - parser.add_argument('--bin_fn', type=str, default=None, - help="Binary tensor input generated by tensor2Bin.py, tensor_fn, var_fn and bed_fn will be ignored") - - # tensor file path - parser.add_argument('--tensor_fn', type=str, default="vartensors", help="Tensor input") - - # variant file path - parser.add_argument('--var_fn', type=str, default="truthvars", help="Truth variants list input") - - # bed file path - parser.add_argument('--bed_fn', type=str, default=None, - help="High confident genome regions input in the BED format") - - # checkpoint file path - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a checkpoint for testing or continue training") - - # learning rate, with default value stated in param - parser.add_argument('--learning_rate', type=float, default=param.initialLearningRate, - help="Set the initial learning rate, default: %(default)s") - - # l2 regularization - parser.add_argument('--lambd', type=float, default=param.l2RegularizationLambda, - help="Set the l2 regularization lambda, default: %(default)s") - - # output checkpint file path prefix - parser.add_argument('--ochk_prefix', type=str, default=None, - help="Prefix for checkpoint outputs at each learning rate change, REQUIRED") - - parser.add_argument('--olog_dir', type=str, default=None, - help="Directory for tensorboard log outputs, optional") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - # initialize - logging.info("[INFO] Initializing") - utils.setup_environment() - m = Clair() - m.init() - - dataset_info = utils.dataset_info_from( - binary_file_path=args.bin_fn, - tensor_file_path=args.tensor_fn, - variant_file_path=args.var_fn, - bed_file_path=args.bed_fn - ) - training_config = utils.TrainingConfig( - dataset_info=dataset_info, - learning_rate=args.learning_rate, - l2_regularization_lambda=args.lambd, - output_file_path_prefix=args.ochk_prefix, - model_initalization_file_path=args.chkpnt_fn, - summary_writer=m.get_summary_file_writer(args.olog_dir) if args.olog_dir != None else None, - ) - - _training_losses, validation_losses = train_model(m, training_config) - - # show the parameter set with the smallest validation loss - validation_losses.sort() - best_validation_epoch = validation_losses[0][1] - logging.info("[INFO] Best validation loss at epoch: %d" % best_validation_epoch) - - # load best validation model and evaluate it - model_file_path = "%s-%%0%dd" % (training_config.output_file_path_prefix, param.parameterOutputPlaceHolder) - best_validation_model_file_path = model_file_path % best_validation_epoch - m.restore_parameters(abspath(best_validation_model_file_path)) - evaluate.evaluate_model(m, dataset_info) diff --git a/benchmarks/nn-variant/clair/model.py b/benchmarks/nn-variant/clair/model.py deleted file mode 100644 index 4f777e8..0000000 --- a/benchmarks/nn-variant/clair/model.py +++ /dev/null @@ -1,1225 +0,0 @@ -import warnings -with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=DeprecationWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - from tensorflow.python.util import deprecation - deprecation._PRINT_DEPRECATION_WARNINGS = False - import tensorflow as tf - from tensorflow.python.client import device_lib - from tensorflow.python.ops import array_ops - -import numpy as np -import re -import multiprocessing -from sys import exit -from os.path import abspath -from argparse import ArgumentParser -from collections import defaultdict - -from clair.task.main import GT21, GENOTYPE, VARIANT_LENGTH_1, VARIANT_LENGTH_2 -import clair.selu as selu -import shared.param as param - - -class Clair(object): - """ - Keywords arguments: - float_type: The type of float to be used for tensorflow, default tf.float64 - input_shape: Shpae of the input tensor, a tuple or list of 3 integers - task_loss_weights: - The weights of different tasks in the calculation of total loss, list of 5 integers in order - (gt21, genotype, indel length, L2 regularization) - structure: The name of the structure, supporting "FC_L3_narrow_legacy_0.1, 2BiLST, CNN1D_L6, Res1D_L9M" - output_gt21_shape: The number of classes in the output of gt21 (alternate base) prediction - output_genotype_shape: The number of classes in the output of genotype prediction - output_indel_length_shape_1: The number of output values in the output of indel length prediction 1 - output_indel_length_shape_2: The number of output values in the output of indel length prediction 2 - output_weight_enabled: True enables per class weights speficied in output_*_entropy_weights (Slower) - output_gt21_entropy_weights: - A list of (output_gt21_shape) integers specifying the weights of different classes in - the calculation of entropy loss (Only used when output_weight_enabled is set to True) - output_genotype_entropy_weights: similar to output_gt21_entropy_weights - L1_num_units: Number of units in L1 - tensor_transform_function: - the function (callable) for transforming the input tensors to match the model, takes in - X_tensor, Y_tensor, and stage text ("train" or "predict") and - return the pair (transformed_X, transformed_Y) - i.e. type: tensor -> tensor -> str -> (tensor, tensor) - default: lambda X, Y, phase: (X, Y) (identity ignoring stage text), which is equivalent to def f(X, Y, phase): return (X, Y) - """ - COLORS_RGB = dict( - RED=[1.0, 0.0, 0.0], - GREEN=[0.0, 1.0, 0.0], - BLUE=[0.0, 0.0, 1.0], - WHITE=[1.0, 1.0, 1.0], - BLACK=[0.0, 0.0, 0.0] - ) - - def __init__(self, **kwargs): - - # Default params dictionary - params = dict( - float_type=tf.float64, - input_shape=(2 * param.flankingBaseNum + 1, param.matrixRow, param.matrixNum), - task_loss_weights=[ - 1, # gt21 - 1, # genotype - 1, # variant/indel length 0 - 1, # variant/indel length 1 - 1 # l2 loss - ], - structure="2BiLSTM", - output_gt21_shape=GT21.output_label_count, - output_genotype_shape=GENOTYPE.output_label_count, - output_indel_length_shape_1=VARIANT_LENGTH_1.output_label_count, - output_indel_length_shape_2=VARIANT_LENGTH_2.output_label_count, - output_gt21_entropy_weights=[1] * GT21.output_label_count, - output_genotype_entropy_weights=[1] * GENOTYPE.output_label_count, - output_indel_length_entropy_weights_1=[1] * VARIANT_LENGTH_1.output_label_count, - output_indel_length_entropy_weights_2=[1] * VARIANT_LENGTH_2.output_label_count, - L1_num_units=30, - L2_num_units=30, - L4_num_units=192, - L4_dropout_rate=0.5, - L5_1_num_units=96, - L5_1_dropout_rate=0.2, - L5_2_num_units=96, - L5_2_dropout_rate=0.2, - L5_3_num_units=96, - L5_3_dropout_rate=0.2, - L5_4_num_units=96, - L5_4_dropout_rate=0.2, - LSTM1_num_units=128, - LSTM2_num_units=128, - LSTM3_num_units=128, - LSTM1_dropout_rate=0, - LSTM2_dropout_rate=0.5, - LSTM3_dropout_rate=0.5, - initial_learning_rate=param.initialLearningRate, - learning_rate_decay=param.learningRateDecay, - l2_regularization_lambda=param.l2RegularizationLambda, - l2_regularization_lambda_decay_rate=param.l2RegularizationLambdaDecay, - tensor_transform_function=lambda X, Y, phase: (X, Y), - optimizer_name=param.default_optimizer, - loss_function=param.default_loss_function, - ) - - # Update params dictionary from the param.py file - params_from_file = param.get_model_parameters() - params.update(params_from_file) - - # Update params dictionary from kwargs - for key, value in kwargs.items(): - if key in params.keys(): - params[key] = value - else: - print("Info: the parameter %s, with value %s is not supported" % (key, value)) - - # Extract the values from the params dictionary - self.input_shape = params['input_shape'] - self.tensor_transform_function = params['tensor_transform_function'] - self.output_gt21_shape = params['output_gt21_shape'] - self.output_genotype_shape = params['output_genotype_shape'] - self.output_indel_length_shape_1 = params['output_indel_length_shape_1'] - self.output_indel_length_shape_2 = params['output_indel_length_shape_2'] - - self.task_loss_weights = np.array(params['task_loss_weights'], dtype=float) - - self.output_gt21_entropy_weights = np.array(params['output_gt21_entropy_weights'], dtype=float) - self.output_genotype_entropy_weights = np.array(params['output_genotype_entropy_weights'], dtype=float) - self.output_indel_length_entropy_weights_1 = np.array( - params['output_indel_length_entropy_weights_1'], dtype=float - ) - self.output_indel_length_entropy_weights_2 = np.array( - params['output_indel_length_entropy_weights_2'], dtype=float - ) - - self.L1_num_units = params['L1_num_units'] - self.L2_num_units = params['L2_num_units'] - self.L4_num_units = params['L4_num_units'] - self.L4_dropout_rate = params['L4_dropout_rate'] - self.L5_1_num_units = params['L5_1_num_units'] - self.L5_1_dropout_rate = params['L5_1_dropout_rate'] - self.L5_2_num_units = params['L5_2_num_units'] - self.L5_2_dropout_rate = params['L5_2_dropout_rate'] - self.L5_3_num_units = params['L5_3_num_units'] - self.L5_3_dropout_rate = params['L5_3_dropout_rate'] - self.L5_4_num_units = params['L5_4_num_units'] - self.L5_4_dropout_rate = params['L5_4_dropout_rate'] - - self.LSTM1_num_units = params['LSTM1_num_units'] - self.LSTM2_num_units = params['LSTM2_num_units'] - self.LSTM3_num_units = params['LSTM3_num_units'] - self.LSTM1_dropout_rate = params['LSTM1_dropout_rate'] - self.LSTM2_dropout_rate = params['LSTM2_dropout_rate'] - self.LSTM3_dropout_rate = params['LSTM3_dropout_rate'] - - self.learning_rate_value = params['initial_learning_rate'] - self.learning_rate_decay_rate = params['learning_rate_decay'] - self.l2_regularization_lambda_value = params['l2_regularization_lambda'] - self.l2_regularization_lambda_decay_rate = params['l2_regularization_lambda_decay_rate'] - self.structure = params['structure'] - self.optimizer_name = params['optimizer_name'] - self.loss_function = params['loss_function'] - - # Ensure the appropriate float datatype is used for Convolutional / Recurrent networks, - # which does not support tf.float64 - if 'CNN' in self.structure or 'Res' in self.structure or 'LSTM' in self.structure or 'GRU' in self.structure: - self.float_type = tf.float32 - else: - self.float_type = params['float_type'] - - # Specify the way to split the output ground truth label - self.output_label_split = [ - self.output_gt21_shape, - self.output_genotype_shape, - self.output_indel_length_shape_1, - self.output_indel_length_shape_2 - ] - - tf.set_random_seed(param.RANDOM_SEED) - self.g = tf.Graph() - self._build_graph() - - print("[INFO] Using %d CPU threads" % (param.NUM_THREADS)) - self.netcfg = tf.ConfigProto() - self.netcfg.intra_op_parallelism_threads = param.NUM_THREADS - self.netcfg.inter_op_parallelism_threads = param.NUM_THREADS - - self.session = tf.Session( - graph=self.g, - config=self.netcfg - ) - - @staticmethod - def get_available_gpus(): - """ - Return the names of gpu units available on the system - """ - local_device_protos = device_lib.list_local_devices() - return [x.name for x in local_device_protos if x.device_type == 'GPU'] - - def get_structure_dict(self, phase='train'): - """ - A function for getting the appropriate values for placeholders, based on whether the phase is "train" or not - Return: - A dictionary containing values for the placeholders - """ - if phase == 'train': - return { - self.L4_dropout_rate_placeholder: self.L4_dropout_rate, - self.L5_1_dropout_rate_placeholder: self.L5_1_dropout_rate, - self.L5_2_dropout_rate_placeholder: self.L5_2_dropout_rate, - self.L5_3_dropout_rate_placeholder: self.L5_3_dropout_rate, - self.L5_4_dropout_rate_placeholder: self.L5_4_dropout_rate - } - else: - return { - self.L4_dropout_rate_placeholder: 0.0, - self.L5_1_dropout_rate_placeholder: 0.0, - self.L5_2_dropout_rate_placeholder: 0.0, - self.L5_3_dropout_rate_placeholder: 0.0, - self.L5_4_dropout_rate_placeholder: 0.0 - } - - @staticmethod - def slice_dense_layer(inputs, units, slice_dimension, name="slice_dense", **kwargs): - """ - Specify a slice dense layer, which unpacks along the specified dimension and connects each position to another layer by full connections - e.g. A tensor of shape [4, 5] would be unpacked to 4 tensors with shape [5], and each of the tensor with shape [5] is fully connected - to another tensor with [units], and restacked to give a tensor with output shape [4, units] - inputs: The input tensor - units: The number of units for each position - slice_dimension: The index of the dimension to be sliced, following the order of tensor.shape - name: The name of the operation (variable scope) - **kwargs: Other parameters to be passed to the tf.layers.dense() function - """ - with tf.variable_scope(name): - sliced = tf.unstack(inputs, axis=slice_dimension, name=name + "Unstack") - slice_dense = tf.stack( - [tf.layers.dense(v, units=units, name="Unit_" + str(i), **kwargs) for i, v in enumerate(sliced)], - axis=slice_dimension, - name="Stacked" - ) - return slice_dense - - @staticmethod - def weighted_cross_entropy(softmax_prediction, labels, weights, epsilon, name): - """ - Compute cross entropy with per class weights - softmax_prediction: The softmaxed tensor produced by the model, should have shape (batch, number of output classes) - labels: The output labels in one-hot encoding - weights: The weights for each class, must have same shape as the number of classes in the output, i.e. the output shape - Return: - Tensor representing the weighted cross entropy, having shape of (batch size, ) - """ - return -tf.reduce_sum( - tf.multiply( - labels * tf.log(softmax_prediction + epsilon), - weights - ), - reduction_indices=[1], - name=name - ) - - @staticmethod - def adaptive_LSTM_layer(inputs, num_units, name="adaptive_LSTM", direction="bidirectional", num_layers=1, cudnn_gpu_available=False): - """ - A wrapper function for selecting the appropriate LSTM layer to use depending on whether cudnn compatible gpu is available - Args: - inputs: Tensor, The input tensor to the LSTM layer, time-major (i.e. in shape (time-steps, batch, sequence)) - num_units: int, The number of units in each direction (i.e. will have a total of 2 * num_units outputs for bidirectional LSTM) - direction: str, "bidirectional" for bidirectional LSTM, unidirectional otherwise - num_layers: int, the number of layers stacked together, each having the same number of units - cudnn_gpu_available: bool, if True, the Cudnn enabled version will be used, otherwise, a compatible version is used - Return: (outputs, output_states) - outputs: Tensor, containing the output of the LSTM - output_states: A tuple of two Tensors for bidirectional LSTM, the first one being the final state for the forward LSTM, and the second one is backward - If unidirectional, contains only a single Tensor for the final state of the LSTM - """ - with tf.variable_scope(name): - if cudnn_gpu_available: - lstm = tf.contrib.cudnn_rnn.CudnnLSTM( - num_layers=num_layers, - num_units=num_units, - direction=direction, - dtype=tf.float32, - kernel_initializer=tf.contrib.layers.variance_scaling_initializer( - factor=1.0, - mode='FAN_IN', - seed=param.OPERATION_SEED - ), - seed=param.OPERATION_SEED - ) - lstm.build(inputs.get_shape()) - outputs, output_states = lstm(inputs) - return outputs, output_states - - # print("[INFO] GPU not available") - if direction == "bidirectional": - def single_cell_generator(): - return tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) - # , reuse=tf.get_variable_scope().reuse - lstm_fw_cells = [single_cell_generator() for _ in range(num_layers)] - lstm_bw_cells = [single_cell_generator() for _ in range(num_layers)] - (outputs, output_state_fw, output_state_bw) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn( - lstm_fw_cells, - lstm_bw_cells, - inputs, - dtype=tf.float32, - time_major=True - ) - return outputs, (output_state_fw, output_state_bw) - else: - def single_cell_generator(): - return tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(num_units) - # NOTE: Even if there's only one layer, the cell needs to be wrapped in MultiRNNCell. - cell = tf.nn.rnn_cell.MultiRNNCell([single_cell_generator() for _ in range(num_layers)]) - # Leave the scope arg unset. - outputs, final_state = tf.nn.dynamic_rnn( - cell, - inputs, - dtype=tf.float32, - time_major=True - ) - return outputs, final_state - - def _build_graph(self): - """ - Build the computation graph for the model - """ - - self.graph = self.g - self.layers = [] # A list used to contain meaningful intermediate layers - with self.graph.as_default(): - tf.set_random_seed(param.RANDOM_SEED) - - # Conversion to tensors for some values - self.epsilon = tf.constant(value=1e-10, dtype=self.float_type) - - # dimensions: batch size, # of bases (33), ACGTacgt (8), # of Channels (4) (reference, insertion, deletion, SNP) - self.input_shape_tf = (None, self.input_shape[0], self.input_shape[1], self.input_shape[2]) - self.output_shape_tf = ( - None, - self.output_gt21_shape + - self.output_genotype_shape + - self.output_indel_length_shape_1 + - self.output_indel_length_shape_2 - ) - - # Place holders - self.X_placeholder = tf.placeholder( - dtype=self.float_type, shape=self.input_shape_tf, name='X_placeholder' - ) - self.Y_placeholder = tf.placeholder( - dtype=self.float_type, shape=self.output_shape_tf, name='Y_placeholder' - ) - - # first layer, X_placeholder - self.layers.append(self.X_placeholder) - - self.learning_rate_placeholder = tf.placeholder( - dtype=self.float_type, shape=[], name='learning_rate_placeholder' - ) - self.phase_placeholder = tf.placeholder( - dtype=tf.bool, shape=[], name='phase_placeholder' - ) - self.regularization_L2_lambda_placeholder = tf.placeholder( - dtype=self.float_type, shape=[], name='regularization_L2_lambda_placeholder' - ) - self.task_loss_weights_placeholder = tf.placeholder( - dtype=self.float_type, shape=self.task_loss_weights.shape, name='task_loss_weights_placeholder' - ) - self.output_gt21_entropy_weights_placeholder = tf.placeholder( - dtype=self.float_type, - shape=self.output_gt21_entropy_weights.shape, - name='output_gt21_entropy_weights_placeholder' - ) - self.output_genotype_entropy_weights_placeholder = tf.placeholder( - dtype=self.float_type, - shape=self.output_genotype_entropy_weights.shape, - name='output_genotype_entropy_weights_placeholder' - ) - self.output_indel_length_entropy_weights_placeholder_1 = tf.placeholder( - dtype=self.float_type, - shape=self.output_indel_length_entropy_weights_1.shape, - name='output_indel_length_entropy_weights_placeholder_1' - ) - self.output_indel_length_entropy_weights_placeholder_2 = tf.placeholder( - dtype=self.float_type, - shape=self.output_indel_length_entropy_weights_2.shape, - name='output_indel_length_entropy_weights_placeholder_2' - ) - - he_initializer = tf.contrib.layers.variance_scaling_initializer( - factor=1.0, - mode='FAN_IN', - seed=param.OPERATION_SEED - ) - - if self.structure == "2BiLSTM": - # Flatten 2nd layer ACGTacgt (8), - # and 3rd layer # of Channels (4) (reference, insertion, deletion, SNP) - self.X_flattened_2D = tf.reshape( - tensor=self.X_placeholder, - shape=( - tf.shape(self.X_placeholder)[0], - self.input_shape_tf[1], - self.input_shape_tf[2] * self.input_shape_tf[3] - ), - name="X_flattened_2D" - ) - self.layers.append(self.X_flattened_2D) - - # the input shape in adaptive LSTM layer should be in shape (time-steps, batch_size, sequence) - # that is: (# of bases, batch_size, (# of ACGTacgt) * (# of channels)) - self.X_flattened_2D_transposed = tf.transpose( - self.X_flattened_2D, perm=[1, 0, 2], name="X_flattened_2D_transposed" - ) - - is_gpu_available = len(Clair.get_available_gpus()) > 0 - - # LSTM Layer (Layer 1) - self.LSTM1, self.LSTM1_state = Clair.adaptive_LSTM_layer( - inputs=self.X_flattened_2D_transposed, - num_units=self.LSTM1_num_units, - name="LSTM1", - direction="bidirectional", - num_layers=1, - cudnn_gpu_available=is_gpu_available - ) - self.layers.append(self.LSTM1) - - # print(self.LSTM1, self.LSTM1_state) - self.LSTM1_dropout = tf.layers.dropout( - inputs=self.LSTM1, - rate=self.LSTM1_dropout_rate, - training=self.phase_placeholder, - name="LSTM1_dropout", - seed=param.OPERATION_SEED - ) - - # LSTM Layer (Layer 2) - self.LSTM2, _ = Clair.adaptive_LSTM_layer( - inputs=self.LSTM1_dropout, - num_units=self.LSTM2_num_units, - name="LSTM2", - direction="bidirectional", - num_layers=1, - cudnn_gpu_available=is_gpu_available - ) - self.layers.append(self.LSTM2) - - self.LSTM2_dropout = tf.layers.dropout( - inputs=self.LSTM2, - rate=self.LSTM2_dropout_rate, - training=self.phase_placeholder, - name="LSTM2_dropout", - seed=param.OPERATION_SEED - ) - # revert the shape to (batch_size, # of bases, self.LSTM2_num_units * 2) - self.LSTM2_transposed = tf.transpose(self.LSTM2_dropout, [1, 0, 2], name="LSTM2_transposed") - - # Slice dense layer (Layer 3) - self.L3 = Clair.slice_dense_layer( - inputs=self.LSTM2_transposed, - units=self.L2_num_units, - slice_dimension=2, - name="L3", - activation=selu.selu, - kernel_initializer=he_initializer - ) - self.layers.append(self.L3) - - self.L3_flattened = tf.reshape( - self.L3, - shape=(tf.shape(self.L3)[0], self.L2_num_units * self.LSTM2_num_units * 2), - name="L3_flattened" - ) - self.layers.append(self.L3_flattened) - - # Dense layer (Layer 4) - self.L4 = tf.layers.dense( - inputs=self.L3_flattened, - units=self.L4_num_units, - name="L4", - activation=selu.selu, - kernel_initializer=he_initializer - ) - self.layers.append(self.L4) - - self.L4_dropout_rate_placeholder = tf.placeholder( - self.float_type, shape=[], name='L4_dropout_rate_placeholder' - ) - - self.L4_dropout = selu.dropout_selu( - x=self.L4, - rate=self.L4_dropout_rate_placeholder, - training=self.phase_placeholder, - name="L4_dropout", - seed=param.OPERATION_SEED - ) - self.layers.append(self.L4_dropout) - - self.L5_1_dropout_rate_placeholder = tf.placeholder( - self.float_type, shape=[], name='L5_1_dropout_rate_placeholder' - ) - self.L5_1 = tf.layers.dense( - inputs=self.L4_dropout, - units=self.L5_1_num_units, - name="L5_1", - activation=selu.selu, - kernel_initializer=he_initializer - ) - self.L5_1_dropout = selu.dropout_selu( - x=self.L5_1, - rate=self.L5_1_dropout_rate_placeholder, - training=self.phase_placeholder, - name="L5_1_dropout", - seed=param.OPERATION_SEED - ) - self.layers.append(self.L5_1_dropout) - - self.L5_2_dropout_rate_placeholder = tf.placeholder( - self.float_type, shape=[], name='L5_2_dropout_rate_placeholder' - ) - self.L5_2 = tf.layers.dense( - inputs=self.L4_dropout, - units=self.L5_2_num_units, - name="L5_2", - activation=selu.selu, - kernel_initializer=he_initializer - ) - self.L5_2_dropout = selu.dropout_selu( - x=self.L5_2, - rate=self.L5_2_dropout_rate_placeholder, - training=self.phase_placeholder, - name="L5_2_dropout", - seed=param.OPERATION_SEED - ) - self.layers.append(self.L5_2_dropout) - - self.L5_3_dropout_rate_placeholder = tf.placeholder( - self.float_type, shape=[], name='L5_3_dropout_rate_placeholder' - ) - self.L5_3 = tf.layers.dense( - inputs=self.L4_dropout, - units=self.L5_3_num_units, - name="L5_3", - activation=selu.selu, - kernel_initializer=he_initializer - ) - self.L5_3_dropout = selu.dropout_selu( - x=self.L5_3, - rate=self.L5_3_dropout_rate_placeholder, - training=self.phase_placeholder, - name="L5_3_dropout", - seed=param.OPERATION_SEED - ) - self.layers.append(self.L5_3_dropout) - - self.L5_4_dropout_rate_placeholder = tf.placeholder( - self.float_type, shape=[], name='L5_4_dropout_rate_placeholder' - ) - self.L5_4 = tf.layers.dense( - inputs=self.L4_dropout, - units=self.L5_4_num_units, - name="L5_4", - activation=selu.selu, - kernel_initializer=he_initializer - ) - self.L5_4_dropout = selu.dropout_selu( - x=self.L5_4, - rate=self.L5_4_dropout_rate_placeholder, - training=self.phase_placeholder, - name="L5_4_dropout", - seed=param.OPERATION_SEED - ) - self.layers.append(self.L5_4_dropout) - - # Output layer - with tf.variable_scope("Prediction"): - self.Y_gt21_logits = tf.layers.dense( - inputs=self.L5_1_dropout, - units=self.output_gt21_shape, - kernel_initializer=he_initializer, - activation=selu.selu, - name='Y_base_change_logits' - ) - self.Y_gt21 = tf.nn.softmax(self.Y_gt21_logits, name='Y_base_change') - self.layers.append(self.Y_gt21) - - self.Y_genotype_logits = tf.layers.dense( - inputs=self.L5_2_dropout, - units=self.output_genotype_shape, - kernel_initializer=he_initializer, - activation=selu.selu, - name='Y_genotype_logits' - ) - self.Y_genotype = tf.nn.softmax(self.Y_genotype_logits, name='Y_genotype') - self.layers.append(self.Y_genotype) - - self.Y_indel_length_logits_1 = tf.layers.dense( - inputs=self.L5_3_dropout, - units=self.output_indel_length_shape_1, - kernel_initializer=he_initializer, - activation=selu.selu, - name='Y_indel_length_logits_1' - ) - self.Y_indel_length_1 = tf.nn.softmax(self.Y_indel_length_logits_1, name='Y_indel_length_1') - self.layers.append(self.Y_indel_length_logits_1) - - self.Y_indel_length_logits_2 = tf.layers.dense( - inputs=self.L5_4_dropout, - units=self.output_indel_length_shape_2, - kernel_initializer=he_initializer, - activation=selu.selu, - name='Y_indel_length_logits_2' - ) - self.Y_indel_length_2 = tf.nn.softmax(self.Y_indel_length_logits_2, name='Y_indel_length_2') - self.layers.append(self.Y_indel_length_logits_2) - - self.Y = [self.Y_gt21, self.Y_genotype, self.Y_indel_length_1, self.Y_indel_length_2] - - # Extract the truth labels by output ratios - with tf.variable_scope("Loss"): - Y_gt21_label, Y_genotype_label, Y_indel_length_label_1, Y_indel_length_label_2 = tf.split( - self.Y_placeholder, self.output_label_split, axis=1, name="label_split" - ) - - if self.loss_function == "CrossEntropy": - self.Y_gt21_cross_entropy = Clair.weighted_cross_entropy( - softmax_prediction=self.Y_gt21, - labels=Y_gt21_label, - weights=self.output_gt21_entropy_weights_placeholder, - epsilon=self.epsilon, - name="Y_base_change_cross_entropy" - ) - self.Y_gt21_loss = tf.reduce_sum(self.Y_gt21_cross_entropy, name="Y_gt21_loss") - - self.Y_genotype_cross_entropy = Clair.weighted_cross_entropy( - softmax_prediction=self.Y_genotype, - labels=Y_genotype_label, - weights=self.output_genotype_entropy_weights_placeholder, - epsilon=self.epsilon, - name="Y_genotype_cross_entropy" - ) - self.Y_genotype_loss = tf.reduce_sum(self.Y_genotype_cross_entropy, name="Y_genotype_loss") - - self.Y_indel_length_cross_entropy_1 = Clair.weighted_cross_entropy( - softmax_prediction=self.Y_indel_length_1, - labels=Y_indel_length_label_1, - weights=self.output_indel_length_entropy_weights_placeholder_1, - epsilon=self.epsilon, - name="Y_indel_length_cross_entropy_1" - ) - self.Y_indel_length_loss_1 = tf.reduce_sum( - self.Y_indel_length_cross_entropy_1, name="Y_indel_length_loss_1" - ) - - self.Y_indel_length_cross_entropy_2 = Clair.weighted_cross_entropy( - softmax_prediction=self.Y_indel_length_2, - labels=Y_indel_length_label_2, - weights=self.output_indel_length_entropy_weights_placeholder_2, - epsilon=self.epsilon, - name="Y_indel_length_cross_entropy_2" - ) - self.Y_indel_length_loss_2 = tf.reduce_sum( - self.Y_indel_length_cross_entropy_2, name="Y_indel_length_loss_2" - ) - - else: - self.Y_gt21_loss = Clair.focal_loss( - prediction_tensor=self.Y_gt21_logits, - target_tensor=Y_gt21_label, - ) - self.Y_genotype_loss = Clair.focal_loss( - prediction_tensor=self.Y_genotype_logits, - target_tensor=Y_genotype_label, - ) - self.Y_indel_length_loss_1 = Clair.focal_loss( - prediction_tensor=self.Y_indel_length_logits_1, - target_tensor=Y_indel_length_label_1, - ) - self.Y_indel_length_loss_2 = Clair.focal_loss( - prediction_tensor=self.Y_indel_length_logits_2, - target_tensor=Y_indel_length_label_2, - ) - - self.regularization_L2_loss_without_lambda = tf.add_n([ - tf.nn.l2_loss(v) for v in tf.trainable_variables() if 'bias' not in v.name - ]) - self.regularization_L2_loss = ( - self.regularization_L2_loss_without_lambda * self.regularization_L2_lambda_placeholder - ) - - # Weighted average of losses - self.total_loss = tf.reduce_sum( - tf.multiply( - self.task_loss_weights_placeholder, - tf.stack([ - self.Y_gt21_loss, - self.Y_genotype_loss, - self.Y_indel_length_loss_1, - self.Y_indel_length_loss_2, - self.regularization_L2_loss - ]) - ), - name="Total_loss" - ) - - # Create the saver for the model - self.saver = tf.train.Saver(max_to_keep=1000000,) - - # Include gradient clipping if RNN architectures are used - if "RNN" in self.structure or "LSTM" in self.structure: - with tf.variable_scope("Training_Operation"): - if self.optimizer_name == "Adam": - self.optimizer = tf.train.AdamOptimizer( - learning_rate=self.learning_rate_placeholder - ) - elif self.optimizer_name == "SGDM": - self.optimizer = tf.train.MomentumOptimizer( - learning_rate=self.learning_rate_placeholder, - momentum=param.momentum - ) - gradients, variables = list(zip(*self.optimizer.compute_gradients(self.total_loss))) - gradients, _ = tf.clip_by_global_norm(gradients, 5.0) - self.training_op = self.optimizer.apply_gradients(list(zip(gradients, variables))) - else: - if self.optimizer_name == "Adam": - self.training_op = tf.train.AdamOptimizer( - learning_rate=self.learning_rate_placeholder - ).minimize(self.total_loss) - elif self.optimizer_name == "SGDM": - self.optimizer = tf.train.MomentumOptimizer( - learning_rate=self.learning_rate_placeholder, - momentum=param.momentum - ).minimize(self.total_loss) - - self.init_op = tf.global_variables_initializer() - - # Summary logging - self.training_summary_op = tf.summary.merge([ - tf.summary.scalar('learning_rate', self.learning_rate_placeholder), - tf.summary.scalar('l2_Lambda', self.regularization_L2_lambda_placeholder), - tf.summary.scalar("Y_gt21_loss", self.Y_gt21_loss), - tf.summary.scalar("Y_genotype_loss", self.Y_genotype_loss), - tf.summary.scalar("Y_indel_length_loss_1", self.Y_indel_length_loss_1), - tf.summary.scalar("Y_indel_length_loss_2", self.Y_indel_length_loss_2), - tf.summary.scalar("Regularization_loss", self.regularization_L2_loss), - tf.summary.scalar("Total_loss", self.total_loss) - ]) - - # For report or debug. Fetching histogram summary is slow, GPU utilization will be low if enabled. - # for var in tf.trainable_variables(): - # tf.summary.histogram(var.op.name, var) - # self.merged_summary_op = tf.summary.merge_all() - - # Aliasing - self.loss = self.total_loss - - # Getting the total number of traininable parameters - # total_parameters = 0 - # for variable in tf.trainable_variables(): - # shape is an array of tf.Dimension - # shape = variable.get_shape() - # print(variable.name, shape) - # print(len(shape)) - # variable_parameters = 1 - # try: - # for dim in shape: - # print(dim) - # variable_parameters *= dim.value - # total_parameters += variable_parameters - # except ValueError as ve: - # if the shape cannot be obtained, (e.g. opaque operators) - # print("Variable {:s} has unknown shape.".format(variable.name)) - # print(ve.message) - # print(variable_parameters) - # - # print("Total Trainable Parameters: " + str(total_parameters)) - - @staticmethod - def focal_loss(prediction_tensor, target_tensor, alpha=0.25, gamma=2): - softmax_p = tf.nn.softmax(prediction_tensor) - - # array_ops.zeros_like(tensor, dtype): - # create a tensor with all elements set to zero, with the same shape as tensor - zeros = array_ops.zeros_like(softmax_p, dtype=softmax_p.dtype) - - # For positive prediction, only need consider front part loss, back part is 0; - # target_tensor > zeros <=> z=1, so positive coefficient = z - p. - # - # array_ops.where(condition, x, y): - # return the elements, either from x or y, depending on the condition - pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - softmax_p, zeros) - - # For negative prediction, only need consider back part loss, front part is 0; - # target_tensor > zeros <=> z=1, so negative coefficient = 0. - neg_p_sub = array_ops.where(target_tensor > zeros, zeros, softmax_p) - per_entry_cross_ent = -( - (pos_p_sub ** gamma) * tf.log(tf.clip_by_value(0.0 + softmax_p, 1e-8, 1.0)) + - (neg_p_sub ** gamma) * tf.log(tf.clip_by_value(1.0 - softmax_p, 1e-8, 1.0)) - ) - return tf.reduce_sum(per_entry_cross_ent) - - def init(self): - """ - Initialize the model by running the init_op and create the summary writer - """ - # self.current_summary_writer = tf.summary.FileWriter('logs', self.session.graph) - # print("Preparing to run init") - self.session.run(self.init_op) - - def get_summary_op_factory(self, render_function, name="Render", *args_factory, **kwargs_factory): - """ - (Experimental, unstable when using with matplotlib) - Wrap the rendering function as a tensorflow operation - """ - - def _get_tensor_render_op(in_tensor, *args_func, **kwargs_func): - - def _render_function_wrap(in_tensor, *args): - img_arrays = [render_function(matrix, *args, **kwargs_func) for matrix in in_tensor] - return np.stack(img_arrays, axis=0) - # img_array = render_function(*args, **kwargs_func) - # return np.expand_dims(img_array, axis=0) - return tf.py_func(_render_function_wrap, [in_tensor] + list(args_func), Tout=tf.uint8, name="Plot") - - def _summary_op_factory(summary_name, in_tensor, *args_call, **kwargs_call): - tf_render_op = _get_tensor_render_op(in_tensor, *args_call, **kwargs_call) - - # with tf.name_scope("Batch_Plot"): - # unstack_layer = tf.unstack(in_tensor, name='unstack') - # tensor_render_ops = [_get_tensor_render_op(slice_tensor, *args_call, **kwargs_call) for slice_tensor in unstack_layer] - # image_stacked = tf.stack(tensor_render_ops, name='stack_images') - - return tf.summary.image(summary_name, tf_render_op, - max_outputs=kwargs_call.pop('max_outputs', 3), - collections=kwargs_call.pop('collections', None), - ) - return _summary_op_factory - - @staticmethod - def recursive_process_tensor(tensor, apply_function, recursion_text="", target_ndim=2, last_first=False, sparator="-", *args, **kwargs): - """ - A general function processing tensors if they have larger dimension than the target ndim, calling the apply_function for each sliced tensor and - group the output in a list - Arguments: - tensor: Numpy Array, the tensor to be processed - apply_function: a function to be called for a tensor with the target_ndim - recursion_text: str, used internally, where in each round, the position of the corresponding matrix is appending to this string, together with a separator - e.g. a seed of "ABC" will become "ABC-2" in the next layer for position 2 and separator - - target_ndim: int, the target number of dimensions to stop the recursion and call the function - last_first: bool, expand the last dimension first - sparator: str, for appending when each dimension is processed - *args, **kwargs: other arguments to be passed to the function "apply_function" - Returns: - A list containing all the results from apply_function(sliced_tensor) - """ - - if tensor.ndim <= target_ndim: - return [apply_function(tensor, recursion_text, *args, **kwargs)] - else: - if last_first: - rolled_tensor = np.rollaxis(tensor, -1) - recursion_text += sparator - processed = [Clair.recursive_process_tensor(subtensor, apply_function, recursion_text + str(i), target_ndim=target_ndim, last_first=last_first, *args, **kwargs) - for i, subtensor in enumerate(rolled_tensor)] - return [item for sublist in processed for item in sublist] - - def close(self): - """ - Closes the current tf session - """ - self.session.close() - - def lr_train(self, batchX, batchY): - """ - Train the model in batch with input tensor batchX and truth tensor batchY - The tensor transform function is applied prior to training - Returns: - prediction: predictions from the model in batch - training_loss: training loss from the batch - summary: tf.summary of the training - """ - transformed_batch_X, transformed_batch_Y = self.tensor_transform_function(batchX, batchY, "train") - - input_dictionary = { - self.X_placeholder: transformed_batch_X, - self.Y_placeholder: transformed_batch_Y, - self.learning_rate_placeholder: self.learning_rate_value, - self.phase_placeholder: True, - self.regularization_L2_lambda_placeholder: self.l2_regularization_lambda_value, - self.task_loss_weights_placeholder: self.task_loss_weights, - self.output_gt21_entropy_weights_placeholder: self.output_gt21_entropy_weights, - self.output_genotype_entropy_weights_placeholder: self.output_genotype_entropy_weights, - self.output_indel_length_entropy_weights_placeholder_1: self.output_indel_length_entropy_weights_1, - self.output_indel_length_entropy_weights_placeholder_2: self.output_indel_length_entropy_weights_2, - } - input_dictionary.update(self.get_structure_dict(phase='train')) - - prediction, training_loss, _, summary = self.session.run( - (self.Y, self.loss, self.training_op, self.training_summary_op), - feed_dict=input_dictionary - ) - self.prediction = prediction - self.training_loss_on_one_batch = training_loss - self.training_summary_on_one_batch = summary - - return prediction, training_loss, summary - - def train(self, batchX, batchY): - """ - Train the model in batch with input tensor batchX and truth tensor batchY - The tensor transform function is applied prior to training - Returns: - training_loss: training loss value from the batch - summary: tf.summary of the training - """ - transformed_batch_X, transformed_batch_Y = self.tensor_transform_function(batchX, batchY, "train") - - input_dictionary = { - self.X_placeholder: transformed_batch_X, - self.Y_placeholder: transformed_batch_Y, - self.learning_rate_placeholder: self.learning_rate_value, - self.phase_placeholder: True, - self.regularization_L2_lambda_placeholder: self.l2_regularization_lambda_value, - self.task_loss_weights_placeholder: self.task_loss_weights, - self.output_gt21_entropy_weights_placeholder: self.output_gt21_entropy_weights, - self.output_genotype_entropy_weights_placeholder: self.output_genotype_entropy_weights, - self.output_indel_length_entropy_weights_placeholder_1: self.output_indel_length_entropy_weights_1, - self.output_indel_length_entropy_weights_placeholder_2: self.output_indel_length_entropy_weights_2, - } - input_dictionary.update(self.get_structure_dict(phase='train')) - - training_loss, _, summary = self.session.run( - (self.loss, self.training_op, self.training_summary_op), - feed_dict=input_dictionary - ) - self.training_loss_on_one_batch = training_loss - self.training_summary_on_one_batch = summary - - return training_loss, summary - - def predict(self, batchX): - """ - Predict using model in batch with input tensor batchX, - The tensor transform function is applied prior to prediction - Returns: - prediction: predictions from the model in batch - """ - transformed_batch_X, _ = self.tensor_transform_function(batchX, None, "predict") - - input_dictionary = { - self.X_placeholder: transformed_batch_X, - self.learning_rate_placeholder: 0.0, - self.phase_placeholder: False, - self.regularization_L2_lambda_placeholder: 0.0 - } - input_dictionary.update(self.get_structure_dict(phase='predict')) - - prediction = self.session.run(self.Y, feed_dict=input_dictionary) - self.prediction = prediction - - return prediction - - def validate(self, batchX, batchY): - """ - Getting the loss using model in batch with input tensor batchX and truth tensor batchY - The tensor transform function is applied prior to getting loss - Returns: - loss: The loss value for this batch - """ - transformed_batch_X, transformed_batch_Y = self.tensor_transform_function(batchX, batchY, "predict") - input_dictionary = { - self.X_placeholder: transformed_batch_X, - self.Y_placeholder: transformed_batch_Y, - self.learning_rate_placeholder: 0.0, - self.phase_placeholder: False, - self.regularization_L2_lambda_placeholder: 0.0, - self.task_loss_weights_placeholder: self.task_loss_weights, - self.output_gt21_entropy_weights_placeholder: self.output_gt21_entropy_weights, - self.output_genotype_entropy_weights_placeholder: self.output_genotype_entropy_weights, - self.output_indel_length_entropy_weights_placeholder_1: self.output_indel_length_entropy_weights_1, - self.output_indel_length_entropy_weights_placeholder_2: self.output_indel_length_entropy_weights_2, - } - input_dictionary.update(self.get_structure_dict(phase='predict')) - - loss, gt21_loss, genotype_loss, indel_length_loss_1, indel_length_loss_2, l2_loss = self.session.run([ - self.loss, - self.Y_gt21_loss, - self.Y_genotype_loss, - self.Y_indel_length_loss_1, - self.Y_indel_length_loss_2, - self.regularization_L2_loss_without_lambda - ], feed_dict=input_dictionary) - - self.validation_loss_on_one_batch = loss - - self.gt21_loss = gt21_loss - self.genotype_loss = genotype_loss - self.indel_length_loss = indel_length_loss_1 + indel_length_loss_2 - self.indel_length_loss_1 = indel_length_loss_1 - self.indel_length_loss_2 = indel_length_loss_2 - self.l2_loss = l2_loss * param.l2RegularizationLambda - - return loss - - def save_parameters(self, file_name): - """ - Save the parameters (weights) to the specific file (file_name) - """ - self.saver.save(self.session, file_name) - - def restore_parameters(self, file_name): - """ - Restore the parameters (weights) from the specific file (file_name) - """ - self.saver.restore(self.session, file_name) - - def get_variable_objects(self, regular_expression): - """ - Get all variable objects from the graph matching the regular expression - Returns: - variable_list: list of tf variable objects - """ - regex = re.compile(regular_expression) - variable_list = [] - with self.graph.as_default(): - tf.set_random_seed(param.RANDOM_SEED) - for variable in tf.trainable_variables(): - if regex.match(variable.name): - variable_list.append(variable) - return variable_list - - def get_operation_objects(self, regular_expression, exclude_expression=".*(grad|tags|Adam).*"): - """ - Get all operation objects from the graph matching the regular expression, but not the exclude_expression - Returns: - operation_list: list of tf operation objects - """ - regex = re.compile(regular_expression) - regex_exclude = re.compile(exclude_expression) - operation_list = [] - - for op in self.graph.get_operations(): - if regex.match(op.name) and not regex_exclude.match(op.name): - print(op.name) - operation_list.append(op) - return operation_list - - def get_summary_file_writer(self, logs_path): - """ - Generate a new tf summary File writer with the specified log path - returns: A tf.summary.FileWriter object - """ - # if hasattr(self, "current_summary_writer"): - # self.current_summary_writer.close() - # self.current_summary_writer = tf.summary.FileWriter(logs_path, graph=self.graph) - # return self.current_summary_writer - return None - - def set_task_loss_weights(self, task_loss_weights=[1, 1, 1, 1, 1]): - """ - Assign a set new task loss weights for training - Arguments: - task_loss_weights: A list of numbers specifying the weights to the tasks - """ - self.task_loss_weights = np.array(task_loss_weights, dtype=float) - - def set_learning_rate(self, learning_rate): - """ - Assign a new learning rate - """ - self.learning_rate_value = learning_rate - return self.learning_rate_value - - def decay_learning_rate(self): - """ - Decay the learning rate by the predefined decay rate - """ - self.learning_rate_value = self.learning_rate_value * self.learning_rate_decay_rate - return self.learning_rate_value - - def clr(self, global_step, step_size, max_lr, mode="tri"): - """ - Cyclical Learning Rate - """ - global_step += 1 - cycle = 1 + global_step / (2 * step_size) - if cycle > 2: - global_step = 0 - if mode == "exp": - max_lr = max_lr * param.clrGamma ** (1) - elif mode == "tri2": - max_lr = max_lr / 2 - x = global_step / step_size - if x <= 1: - self.learning_rate_value = param.clr_min_lr + (max_lr - param.clr_min_lr) * np.maximum(0, x) - else: - self.learning_rate_value = param.clr_min_lr + (max_lr - param.clr_min_lr) * np.maximum(0, (2 - x)) - return self.learning_rate_value, global_step, max_lr - - def set_l2_regularization_lambda(self, l2_regularization_lambda): - """ - Assign a new l2_regularization_lambda value - """ - self.l2_regularization_lambda_value = l2_regularization_lambda - return self.l2_regularization_lambda_value - - def decay_l2_regularization_lambda(self): - """ - Decay the l2_regularization_lambda value by the predefined decay rate - """ - self.l2_regularization_lambda_value = self.l2_regularization_lambda_value * self.l2_regularization_lambda_decay_rate - return self.l2_regularization_lambda_value - - def pretty_print_variables(self, regular_expression): - variable_list = self.get_variable_objects(regular_expression) - result_string_list = [] - for v in variable_list: - variable_value = self.session.run(v) - result_string_list.append(v.name) - result_string_list.append(Clair.pretty_print_np_tensor(variable_value) + '\n') - return '\n'.join(result_string_list) - - @staticmethod - def pretty_print_np_tensor(tensor, element_separator='\t'): - """ - Print a numpy array (tensor) formatted with [], new lines and the element_separator - Returns: - A string containing the formatted tensor - """ - if tensor.ndim == 1: - return element_separator.join(('%.16f') % value for value in tensor) - elif tensor.ndim == 2: - return_list = [] - for row in tensor: - return_list.append(Clair.pretty_print_np_tensor(row, element_separator=element_separator)) - return '\n'.join(return_list) - else: - return_list = [] - for sub_tensor in tensor: - return_list.append('[\n' + Clair.pretty_print_np_tensor(sub_tensor, - element_separator=element_separator) + '\n]') - return '\n'.join(return_list) - - def __del__(self): - # if hasattr(self, "current_summary_writer"): - # self.current_summary_writer.close() - self.session.close() - - -class FunctionCallConsumer(multiprocessing.Process): - """ - A class implementing thread safe consumer which does a function call for each task - Init Arguments: - target_function: callable, when a task is obtained from the task_queue, the fucntion is called in the args and kwargs from the queue - task_queue: the task queue, recommend using multiprocessing.JoinableQueue(), each object put into this queue should be a tuple of size 3: - (identity, args, kwargs). The identity is only used for identifying the result of the task, and won't be passed to the function - result_dict: The result dictionary, where the result is put as result_dict[identity] = f(*args, **kwargs) - name: name of the consumer, for printing - verbose: printing out message if True - """ - - def __init__(self, target_function, task_queue, result_dict, name="c", verbose=False): - multiprocessing.Process.__init__(self) - self.target_function = target_function - self.task_queue = task_queue - self.result_dict = result_dict - self.name = name - self.verbose = verbose - - def run(self): - """ - Start the consumer, the consumer stops whenever a None value is put into the queue - """ - if self.verbose: - print("Consumer {:s} is starting.".format(self.name)) - while True: - next_task = self.task_queue.get() - if next_task is None: - self.task_queue.task_done() - break - - # identity, f, args, kwargs = next_task["identity"], next_task["f"], next_task["args"], next_task["kwargs"] - # identity, f, args, kwargs = next_task - identity, args, kwargs = next_task - # answer = f(*args, **kwargs) - answer = self.target_function(*args, **kwargs) - self.task_queue.task_done() - # self.result_queue.put((identity, answer)) - self.result_dict[identity] = (identity, answer) - if self.verbose: - print("Consumer {:s} finished".format(self.name), identity) - if self.verbose: - print("Consumer {:s} is terminating.".format(self.name)) - return - - -if __name__ == "__main__": - parser = ArgumentParser(description="Model") - - parser.add_argument('-v', '--variables', type=str, default=None, - help="Print variables matching the regular expression. default: %(default)s") - - parser.add_argument('-r', '--restore_model', type=str, default=None, - help="The path to the model to be restored. default: %(default)s") - - parser.add_argument('-s', '--save_model', type=str, default=None, - help="The path where the model is to be saved. default: %(default)s") - - parser.add_argument('-l', '--log_dir', type=str, default="logs", - help="The path to the log directory. default: %(default)s") - - args = parser.parse_args() - - if args.variables is not None: - q = Clair() - q.init() - if args.restore_model is not None: - q.restore_parameters(abspath(args.restore_model)) - print(q.pretty_print_variables(args.variables)) - exit(0) diff --git a/benchmarks/nn-variant/clair/plot_tensor.py b/benchmarks/nn-variant/clair/plot_tensor.py deleted file mode 100644 index 1011532..0000000 --- a/benchmarks/nn-variant/clair/plot_tensor.py +++ /dev/null @@ -1,108 +0,0 @@ -import sys -import os -import numpy as np -import matplotlib -matplotlib.use('Agg') -from matplotlib import pyplot as plt -from argparse import ArgumentParser - -from clair.utils import setup_environment - -def plot_tensor(ofn, XArray): - plot = plt.figure(figsize=(15, 8)) - - plot_min = -30 - plot_max = 30 - plot_arr = ["A+", "C+", "G+", "T+", "A-", "C-", "G-", "T-"] - - plt.subplot(4, 1, 1) - plt.xticks(np.arange(0, 33, 1)) - plt.yticks(np.arange(0, 8, 1), plot_arr) - plt.imshow(XArray[0, :, :, 0].transpose(), vmin=0, vmax=plot_max, interpolation="nearest", cmap=plt.cm.hot) - plt.colorbar() - - plt.subplot(4, 1, 2) - plt.xticks(np.arange(0, 33, 1)) - plt.yticks(np.arange(0, 8, 1), plot_arr) - plt.imshow(XArray[0, :, :, 1].transpose(), vmin=plot_min, vmax=plot_max, interpolation="nearest", cmap=plt.cm.bwr) - plt.colorbar() - - plt.subplot(4, 1, 3) - plt.xticks(np.arange(0, 33, 1)) - plt.yticks(np.arange(0, 8, 1), plot_arr) - plt.imshow(XArray[0, :, :, 2].transpose(), vmin=plot_min, vmax=plot_max, interpolation="nearest", cmap=plt.cm.bwr) - plt.colorbar() - - plt.subplot(4, 1, 4) - plt.xticks(np.arange(0, 33, 1)) - plt.yticks(np.arange(0, 8, 1), plot_arr) - plt.imshow(XArray[0, :, :, 3].transpose(), vmin=plot_min, vmax=plot_max, interpolation="nearest", cmap=plt.cm.bwr) - plt.colorbar() - - plot.savefig(ofn, dpi=300, transparent=True, bbox_inches='tight') - plt.close(plot) - - -def create_png(args): - f = open(args.array_fn, 'r') - array = f.read() - f.close() - import re - array = re.split("\n", array) - array = [x for x in array if x] - print(array) - - splitted_array = [] - for i in range(len(array)): - splitted_array += re.split(",", array[i]) - - print("splitted array length") - print(len(splitted_array)) - print(splitted_array[0]) - # for i in range(len(splitted_array)): - # splitted_array[i] = int(splitted_array[i]) - - XArray = np.array(splitted_array, dtype=np.float32).reshape((-1, 33, 8, 4)) - XArray[0, :, :, 1] -= XArray[0, :, :, 0] - XArray[0, :, :, 2] -= XArray[0, :, :, 0] - XArray[0, :, :, 3] -= XArray[0, :, :, 0] - - _YArray = np.zeros((1, 16)) - varName = args.name - print("Plotting %s..." % (varName), file=sys.stderr) - - # Create folder - if not os.path.exists(varName): - os.makedirs(varName) - - # Plot tensors - plot_tensor(varName+"/tensor.png", XArray) - - -def ParseArgs(): - parser = ArgumentParser( - description="Visualize tensors and hidden layers in PNG") - - parser.add_argument('--array_fn', type=str, default="vartensors", - help="Array input") - - parser.add_argument('--name', type=str, default=None, - help="output name") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - return args - - -def main(): - args = ParseArgs() - setup_environment() - create_png(args) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/post_processing/ensemble.py b/benchmarks/nn-variant/clair/post_processing/ensemble.py deleted file mode 100644 index 7870f58..0000000 --- a/benchmarks/nn-variant/clair/post_processing/ensemble.py +++ /dev/null @@ -1,109 +0,0 @@ -from sys import stdin, stderr, argv, exit -from collections import namedtuple, defaultdict -from argparse import ArgumentParser - -EnsembleConfig = namedtuple('EnsembleConfig', [ - 'minimum_count_to_output', -]) - - -def dicts_from_stdin(): - counter = defaultdict(lambda: 0) - - sequence_dict = {} - tensor_dict = {} - probabilities_dict = {} - - for row in stdin.readlines(): - columns = row.split(sep="\t") - - chromosome, position, sequence = columns[0], columns[1], columns[2] - - key = (chromosome, position) - - counter[key] = counter[key] + 1 - - if not key in sequence_dict: - sequence_dict[key] = sequence - - if not key in tensor_dict: - tensor = [int(str_value) for str_value in columns[3:3 + 33*8*4]] - tensor_dict[key] = tensor - - if not key in probabilities_dict: - probabilities = [float(no) for no in columns[3+ 33*8*4:]] - probabilities_dict[key] = probabilities - else: - probabilities_from_input = [float(no) for no in columns[3 + 33*8*4:]] - - probabilities = list.copy(probabilities_dict[key]) - for index, probability in enumerate(probabilities): - probabilities[index] = probability + probabilities_from_input[index] - - probabilities_dict[key] = probabilities - - return counter, sequence_dict, tensor_dict, probabilities_dict - - -def output_with( - counter, - sequence_dict, - tensor_dict, - probabilities_dict, - ensemble_config, -): - minimum_count_to_output = ensemble_config.minimum_count_to_output - - for key, count in counter.items(): - if count < minimum_count_to_output: - continue - - chromosome, position = key - sequence = sequence_dict[key] - tensor = tensor_dict[key] - probabilities = probabilities_dict[key] - - tensor_str = "\t".join([str(int_value) for int_value in tensor]) - probabilities_str = "\t".join(["{:.6f}".format(probability / count) for probability in probabilities]) - - print("\t".join([ - chromosome, - position, - sequence, - tensor_str, - probabilities_str, - ])) - - -def run_pipeline(ensemble_config): - counter, sequence_dict, tensor_dict, probabilities_dict = dicts_from_stdin() - - output_with( - counter, - sequence_dict, - tensor_dict, - probabilities_dict, - ensemble_config, - ) - - -def main(): - parser = ArgumentParser(description="Call variants using a trained model and tensors of candididate variants") - - parser.add_argument('--minimum_count_to_output', type=int, default=0, - help="minimum # of calls to output the probabilities") - - args = parser.parse_args() - - if len(argv[1:]) == 0: - parser.print_help() - exit(1) - - ensemble_config = EnsembleConfig( - minimum_count_to_output=args.minimum_count_to_output - ) - run_pipeline(ensemble_config=ensemble_config) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/post_processing/overlap_variant.py b/benchmarks/nn-variant/clair/post_processing/overlap_variant.py deleted file mode 100644 index 4cfd2fc..0000000 --- a/benchmarks/nn-variant/clair/post_processing/overlap_variant.py +++ /dev/null @@ -1,285 +0,0 @@ -from sys import stdin, stderr -from collections import namedtuple - -Variant = namedtuple('Variant', [ - 'chromosome', - 'position', - 'reference_base', - 'alternate_base', - 'alternate_base_multi', - 'quality_score', - 'genotype', - 'depth', - 'allele_frequency', -]) - -VariantIntervals = namedtuple('VariantIntervals', [ - 'snp_interval', - 'deletion_interval', - 'insertion_intervals', -]) - - -EMPTY_INTERVAL = (-1, -1) - - -DEBUG_OVERLAPPED_VARIANT = False - - -def maximum_deletion_length_of(variant): - return len(variant.reference_base) - min( - len(variant.alternate_base), - 1024 if variant.alternate_base_multi is None else len(variant.alternate_base_multi), - ) - - -def snp_interval_from(variant): - # need to handle the case like [ACGT]Del / [ACGT]Ins - is_snp = ( - len(variant.reference_base) == len(variant.alternate_base) or - ( - False if variant.alternate_base_multi is None else len( - variant.reference_base) == len(variant.alternate_base_multi) - ) - ) - return EMPTY_INTERVAL if not is_snp else (variant.position - 1, variant.position) - - -def deletion_interval_from(variant): - maximum_deletion_length = maximum_deletion_length_of(variant) - is_deletion = maximum_deletion_length > 0 - - return EMPTY_INTERVAL if not is_deletion else (variant.position - 1, variant.position + maximum_deletion_length) - - -def insertion_intervals_from(variant): - insertion_intervals = [] - - if len(variant.alternate_base) > len(variant.reference_base): - insertion_intervals.append( - ( - variant.position - 1, - variant.position + len(variant.alternate_base) - len(variant.reference_base) - ) - ) - else: - insertion_intervals.append(EMPTY_INTERVAL) - - if ( - variant.alternate_base_multi is not None and - len(variant.alternate_base_multi) > len(variant.reference_base) - ): - insertion_intervals.append( - ( - variant.position - 1, - variant.position + len(variant.alternate_base_multi) - len(variant.reference_base) - ) - ) - else: - insertion_intervals.append(EMPTY_INTERVAL) - - return insertion_intervals - - -# all intervals is suppose to be zero-base and [start, end) half open interval -def variant_intervals_from(variant): - return VariantIntervals( - snp_interval=snp_interval_from(variant), - deletion_interval=deletion_interval_from(variant), - insertion_intervals=insertion_intervals_from(variant), - ) - - -def is_two_intervals_overlap(interval1, interval2): - if interval1 is EMPTY_INTERVAL or interval2 is EMPTY_INTERVAL: - return False - - begin1, end1 = interval1 - begin2, _ = interval2 - # return begin1 <= begin2 <= end1 or begin2 <= end1 <= end2 - return begin1 <= begin2 < end1 - - -def is_two_intervals_overlap_for_ins_snp(insertion_interval, snp_interval): - if insertion_interval is EMPTY_INTERVAL or snp_interval is EMPTY_INTERVAL: - return False - - insert_begin, insert_end = insertion_interval - _, snp_end = snp_interval - return insert_end - insert_begin == 2 and insert_end == snp_end - - -# for insertion intervals overlap, current implementation needs with the same ending position -def is_two_intervals_overlap_for_ins_ins(interval1, interval2): - if interval1 is EMPTY_INTERVAL or interval2 is EMPTY_INTERVAL: - return False - - _, end1 = interval1 - _, end2 = interval2 - return end1 == end2 - - -def is_two_variants_overlap(variant1, variant2): - if variant1.chromosome != variant2.chromosome: - return False - if variant1.position > variant2.position: - return is_two_variants_overlap(variant2, variant1) - - intervals_1 = variant_intervals_from(variant1) - intervals_2 = variant_intervals_from(variant2) - - # return ( - # is_two_intervals_overlap(intervals_1.deletion_interval, intervals_2.snp_interval) or - # is_two_intervals_overlap(intervals_1.deletion_interval, intervals_2.deletion_interval) or - # is_two_intervals_overlap_for_ins_snp(intervals_1.insertion_intervals[0], intervals_2.snp_interval) or - # is_two_intervals_overlap_for_ins_snp(intervals_1.insertion_intervals[1], intervals_2.snp_interval) or - # is_two_intervals_overlap_for_ins_ins(intervals_1.insertion_intervals[0], intervals_2.insertion_intervals[0]) or - # is_two_intervals_overlap_for_ins_ins(intervals_1.insertion_intervals[0], intervals_2.insertion_intervals[1]) or - # is_two_intervals_overlap_for_ins_ins(intervals_1.insertion_intervals[1], intervals_2.insertion_intervals[0]) or - # is_two_intervals_overlap_for_ins_ins(intervals_1.insertion_intervals[1], intervals_2.insertion_intervals[1]) - # ) - - # return ( - # is_two_intervals_overlap(intervals_1.deletion_interval, intervals_2.snp_interval) or - # is_two_intervals_overlap(intervals_1.deletion_interval, intervals_2.deletion_interval) or - # is_two_intervals_overlap_for_ins_snp(intervals_1.insertion_intervals[0], intervals_2.snp_interval) or - # is_two_intervals_overlap_for_ins_snp(intervals_1.insertion_intervals[1], intervals_2.snp_interval) - # ) - - return ( - is_two_intervals_overlap(intervals_1.deletion_interval, intervals_2.snp_interval) or - is_two_intervals_overlap(intervals_1.deletion_interval, intervals_2.deletion_interval) - ) - - -def variant_from(variant_row): - if variant_row[0] == "#": - return - - columns = str(variant_row).split("\t") - chromosome = columns[0] - position = int(columns[1]) - - reference_base = columns[3] - alternates = columns[4].split(",") - alternate_base = alternates[0] - alternate_base_multi = None if len(alternates) == 1 else alternates[1] - - quality_score = int(float(columns[5])) - - last_column = columns[-1] - last_columns = last_column.split(":") - genotype = last_columns[0] - depth = last_columns[2] - allele_frequency = last_columns[3] - - return Variant( - chromosome=chromosome, - position=position, - reference_base=reference_base, - alternate_base=alternate_base, - alternate_base_multi=alternate_base_multi, - quality_score=quality_score, - genotype=genotype, - depth=depth, - allele_frequency=allele_frequency, - ) - - -def variant_row_from(variant): - alternates = ",".join( - [variant.alternate_base] + - ([] if variant.alternate_base_multi is None else [variant.alternate_base_multi]) - ) - quality_score_str = str(variant.quality_score) - last_column = ":".join([ - variant.genotype, - quality_score_str, - variant.depth, - variant.allele_frequency, - ]) - - return "\t".join([ - variant.chromosome, - str(variant.position), - ".", - variant.reference_base, - alternates, - str(variant.quality_score), - ".", - ".", - "GT:GQ:DP:AF", - last_column, - ]) - - -def header_and_variant_rows_from_stdin(): - header_rows = [] - variant_rows = [] - for row in stdin.readlines(): - if row[0] == "#": - header_rows.append(row[:-1]) - else: - variant_rows.append(row[:-1]) - - return header_rows, variant_rows - - -def variant_to_output_for(variant1, variant2): - # return variant1 if variant1.quality_score > variant2.quality_score else variant2 - score1 = variant1.quality_score - score2 = variant2.quality_score - # score1 = variant1.quality_score * float(variant1.allele_frequency) - # score2 = variant2.quality_score * float(variant2.allele_frequency) - return variant1 if score1 > score2 else variant2 - - -def filter_variants_with(variants): - filtered_variants = [] - - overlapped_variants_count = 0 - - for variant in variants: - if len(filtered_variants) == 0: - filtered_variants.append(variant) - continue - - last_variant = filtered_variants[-1] - if not is_two_variants_overlap(last_variant, variant): - filtered_variants.append(variant) - continue - - if DEBUG_OVERLAPPED_VARIANT: - overlapped_variants_count += 1 - print("\n[INFO] variants overlapped.", file=stderr) - print(variant_row_from(last_variant), file=stderr) - print(variant_row_from(variant), file=stderr) - - # variant_to_append = last_variant if last_variant.quality_score >= variant.quality_score else variant - variant_to_append = variant_to_output_for(last_variant, variant) - if variant_to_append != last_variant: - filtered_variants.pop() - filtered_variants.append(variant) - - if DEBUG_OVERLAPPED_VARIANT: - print("[INFO] {} variants overlapped.".format(overlapped_variants_count), file=stderr) - - return filtered_variants - - -def output(header_rows, variants): - for header_row in header_rows: - print(header_row) - for variant in variants: - print(variant_row_from(variant)) - - -def main(): - header_rows, variant_rows = header_and_variant_rows_from_stdin() - variants = [variant_from(variant_row) for variant_row in variant_rows] - filtered_variants = filter_variants_with(variants) - output(header_rows, filtered_variants) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/selu.py b/benchmarks/nn-variant/clair/selu.py deleted file mode 100644 index 3e6d696..0000000 --- a/benchmarks/nn-variant/clair/selu.py +++ /dev/null @@ -1,74 +0,0 @@ -''' -Tensorflow Implementation of the Scaled ELU function and Dropout -''' -import warnings -with warnings.catch_warnings(): - warnings.filterwarnings('ignore', category=DeprecationWarning) - warnings.filterwarnings("ignore", category=FutureWarning) - from tensorflow.python.util import deprecation - deprecation._PRINT_DEPRECATION_WARNINGS = False - import tensorflow as tf - from tensorflow.contrib import layers - from tensorflow.python.framework import ops - from tensorflow.python.framework import tensor_shape - from tensorflow.python.framework import tensor_util - from tensorflow.python.ops import math_ops - from tensorflow.python.ops import random_ops - from tensorflow.python.ops import array_ops - from tensorflow.contrib.layers.python.layers import utils -import numbers - - -# (1) scale inputs to zero mean and unit variance - - -# (2) use SELUs -def selu(x): - with ops.name_scope('elu') as scope: - alpha = 1.6732632423543772848170429916717 - scale = 1.0507009873554804934193349852946 - return scale*tf.where(x>=0.0, x, alpha*tf.nn.elu(x)) - - -# (3) initialize weights with stddev sqrt(1/n) -# e.g. use: -initializer = layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN') - - -# (4) use this dropout -def dropout_selu(x, rate, alpha= -1.7580993408473766, fixedPointMean=0.0, fixedPointVar=1.0, - noise_shape=None, seed=None, name=None, training=False): - """Dropout to a value with rescaling.""" - - def dropout_selu_impl(x, rate, alpha, noise_shape, seed, name): - keep_prob = 1.0 - rate - x = ops.convert_to_tensor(x, name="x") - if isinstance(keep_prob, numbers.Real) and not 0 < keep_prob <= 1: - raise ValueError("keep_prob must be a scalar tensor or a float in the " - "range (0, 1], got %g" % keep_prob) - keep_prob = ops.convert_to_tensor(keep_prob, dtype=x.dtype, name="keep_prob") - keep_prob.get_shape().assert_is_compatible_with(tensor_shape.scalar()) - - alpha = ops.convert_to_tensor(alpha, dtype=x.dtype, name="alpha") - alpha.get_shape().assert_is_compatible_with(tensor_shape.scalar()) - - if tensor_util.constant_value(keep_prob) == 1: - return x - - noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x) - random_tensor = keep_prob - random_tensor += random_ops.random_uniform(noise_shape, seed=seed, dtype=x.dtype) - binary_tensor = math_ops.floor(random_tensor) - ret = x * binary_tensor + alpha * (1-binary_tensor) - - a = math_ops.sqrt(fixedPointVar / (keep_prob *((1-keep_prob) * math_ops.pow(alpha-fixedPointMean,2) + fixedPointVar))) - - b = fixedPointMean - a * (keep_prob * fixedPointMean + (1 - keep_prob) * alpha) - ret = a * ret + b - ret.set_shape(x.get_shape()) - return ret - - with ops.name_scope(name, "dropout", [x]) as name: - return utils.smart_cond(training, - lambda: dropout_selu_impl(x, rate, alpha, noise_shape, seed, name), - lambda: array_ops.identity(x)) diff --git a/benchmarks/nn-variant/clair/task/__pycache__/__init__.cpython-37.pyc b/benchmarks/nn-variant/clair/task/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index c78079f..0000000 Binary files a/benchmarks/nn-variant/clair/task/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/task/__pycache__/genotype.cpython-37.pyc b/benchmarks/nn-variant/clair/task/__pycache__/genotype.cpython-37.pyc deleted file mode 100644 index 1dbdc86..0000000 Binary files a/benchmarks/nn-variant/clair/task/__pycache__/genotype.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/task/__pycache__/gt21.cpython-37.pyc b/benchmarks/nn-variant/clair/task/__pycache__/gt21.cpython-37.pyc deleted file mode 100644 index c58fe9a..0000000 Binary files a/benchmarks/nn-variant/clair/task/__pycache__/gt21.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/task/__pycache__/main.cpython-37.pyc b/benchmarks/nn-variant/clair/task/__pycache__/main.cpython-37.pyc deleted file mode 100644 index e5d0898..0000000 Binary files a/benchmarks/nn-variant/clair/task/__pycache__/main.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/task/__pycache__/variant_length.cpython-37.pyc b/benchmarks/nn-variant/clair/task/__pycache__/variant_length.cpython-37.pyc deleted file mode 100644 index d76741e..0000000 Binary files a/benchmarks/nn-variant/clair/task/__pycache__/variant_length.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/clair/train.py b/benchmarks/nn-variant/clair/train.py deleted file mode 100644 index 123ea58..0000000 --- a/benchmarks/nn-variant/clair/train.py +++ /dev/null @@ -1,376 +0,0 @@ -import sys -import os -import logging -import random -import numpy as np -from time import time -from argparse import ArgumentParser -from threading import Thread - -from clair.model import Clair -import clair.utils as utils -import clair.evaluate as evaluate -import shared.param as param - -logging.basicConfig(format='%(message)s', level=logging.INFO) - - -def is_last_five_epoch_approaches_minimum(validation_losses): - if len(validation_losses) <= 5: - return True - - minimum_validation_loss = min(np.asarray(validation_losses)[:, 0]) - return ( - validation_losses[-5][0] == minimum_validation_loss or - validation_losses[-4][0] == minimum_validation_loss or - validation_losses[-3][0] == minimum_validation_loss or - validation_losses[-2][0] == minimum_validation_loss or - validation_losses[-1][0] == minimum_validation_loss - ) - - -def is_validation_loss_goes_up_and_down(validation_losses): - if len(validation_losses) <= 6: - return False - - return ( - validation_losses[-6][0] > validation_losses[-5][0] and - validation_losses[-5][0] < validation_losses[-4][0] and - validation_losses[-4][0] > validation_losses[-3][0] and - validation_losses[-3][0] < validation_losses[-2][0] and - validation_losses[-2][0] > validation_losses[-1][0] - ) or ( - validation_losses[-6][0] < validation_losses[-5][0] and - validation_losses[-5][0] > validation_losses[-4][0] and - validation_losses[-4][0] < validation_losses[-3][0] and - validation_losses[-3][0] > validation_losses[-2][0] and - validation_losses[-2][0] < validation_losses[-1][0] - ) - - -def is_validation_losses_keep_increasing(validation_losses): - if len(validation_losses) <= 6: - return False - - minimum_validation_loss = min(np.asarray(validation_losses)[:, 0]) - return ( - validation_losses[-5][0] > minimum_validation_loss and - validation_losses[-4][0] > minimum_validation_loss and - validation_losses[-3][0] > minimum_validation_loss and - validation_losses[-2][0] > minimum_validation_loss and - validation_losses[-1][0] > minimum_validation_loss - ) - - -def shuffle_first_n_items(array, n): - """ - Shuffle first n items on given array. - """ - if len(array) <= n: - np.random.shuffle(array) - return array - # pylint: disable=unbalanced-tuple-unpacking - a1, a2 = np.split(array, [n]) - np.random.shuffle(a1) - return np.append(a1, a2) - - -def train_model(m, training_config): - learning_rate = training_config.learning_rate - l2_regularization_lambda = training_config.l2_regularization_lambda - output_file_path_prefix = training_config.output_file_path_prefix - summary_writer = training_config.summary_writer - model_initalization_file_path = training_config.model_initalization_file_path - - dataset_info = training_config.dataset_info - dataset_size = dataset_info.dataset_size - - training_losses = [] - validation_losses = [] - - if model_initalization_file_path is not None: - m.restore_parameters(os.path.abspath(model_initalization_file_path)) - - logging.info("[INFO] Start training...") - logging.info("[INFO] Learning rate: %.2e" % m.set_learning_rate(learning_rate)) - logging.info("[INFO] L2 regularization lambda: %.2e" % m.set_l2_regularization_lambda(l2_regularization_lambda)) - - # Model Constants - training_start_time = time() - learning_rate_switch_count = param.maxLearningRateSwitch - no_of_training_examples = ( - dataset_info.no_of_training_examples_from_train_binary or int(dataset_size * param.trainingDatasetPercentage) - ) - no_of_validation_examples = dataset_info.dataset_size - no_of_training_examples - no_of_blosc_blocks = utils.no_of_blosc_blocks_from( - dataset_info=dataset_info, - no_of_training_examples=no_of_training_examples, - blosc_block_size=param.bloscBlockSize - ) - no_of_training_blosc_blocks = int(no_of_training_examples / param.bloscBlockSize) - tensor_block_index_list = np.arange(no_of_blosc_blocks, dtype=int) - - # Initialize variables - epoch_count = 1 - if model_initalization_file_path is not None: - epoch_count = int(model_initalization_file_path[-param.parameterOutputPlaceHolder:]) + 1 - - epoch_start_time = time() - training_loss_sum = 0 - validation_loss_sum = 0 - no_of_epochs_with_current_learning_rate = 0 # Variables for learning rate decay - data_index = 0 - blosc_index = 0 - first_blosc_block_data_index = 0 - x_batch = None - y_batch = None - - gt21_loss_sum = 0 - genotype_loss_sum = 0 - indel_length_loss_sum_1 = 0 - indel_length_loss_sum_2 = 0 - l2_loss_sum = 0 - - while True: - is_training = data_index < no_of_training_examples - is_validation = not is_training - is_with_batch_data = x_batch is not None and y_batch is not None - # logging.info("{} {} {} {} {}".format("TRAIN" if is_training else "VALID", data_index, first_blosc_block_data_index, blosc_index, no_of_training_examples)) - - # threads for either train or validation - thread_pool = [] - if is_with_batch_data and is_training: - thread_pool.append(Thread(target=m.train, args=(x_batch, y_batch))) - elif is_with_batch_data and is_validation: - thread_pool.append(Thread(target=m.validate, args=(x_batch, y_batch))) - for t in thread_pool: - t.start() - - next_x_batch, next_y_batch, next_first_blosc_block_data_index, next_blosc_start_index = utils.new_mini_batch( - data_index=data_index, - blosc_start_index=blosc_index, - first_blosc_block_data_index=first_blosc_block_data_index, - no_of_training_examples=no_of_training_examples, - no_of_blosc_blocks=no_of_blosc_blocks, - dataset_info=dataset_info, - tensor_block_index_list=tensor_block_index_list, - ) - - # wait until loaded next mini batch & finished training/validation with current mini batch - for t in thread_pool: - t.join() - - # add training loss or validation loss - if is_with_batch_data and is_training: - training_loss_sum += m.training_loss_on_one_batch - if summary_writer is not None: - summary = m.training_summary_on_one_batch - summary_writer.add_summary(summary, epoch_count) - elif is_with_batch_data and is_validation: - validation_loss_sum += m.validation_loss_on_one_batch - - gt21_loss_sum += m.gt21_loss - genotype_loss_sum += m.genotype_loss - indel_length_loss_sum_1 += m.indel_length_loss_1 - indel_length_loss_sum_2 += m.indel_length_loss_2 - l2_loss_sum += m.l2_loss - - batch_size = np.shape(next_x_batch)[0] - data_index += batch_size - blosc_index = next_blosc_start_index - first_blosc_block_data_index = next_first_blosc_block_data_index - - # if not go through whole dataset yet, continue the process - if next_first_blosc_block_data_index >= 0 and next_blosc_start_index >= 0: - x_batch = next_x_batch - y_batch = next_y_batch - continue - - # logging.info("{} {} {} {} {}".format("END", data_index, first_blosc_block_data_index, blosc_index, no_of_training_examples)) - - logging.info( - " ".join([str(epoch_count), "Training loss:", str(training_loss_sum/no_of_training_examples)]) - ) - logging.info( - "\t".join([ - "{} Validation loss (Total/Base/Genotype/Indel_1_2):".format(epoch_count), - str(validation_loss_sum/no_of_validation_examples), - str(gt21_loss_sum/no_of_validation_examples), - str(genotype_loss_sum/no_of_validation_examples), - str(indel_length_loss_sum_1/no_of_validation_examples), - str(indel_length_loss_sum_2/no_of_validation_examples) - ]) - ) - - logging.info("[INFO] Epoch time elapsed: %.2f s" % (time() - epoch_start_time)) - training_losses.append((training_loss_sum, epoch_count)) - validation_losses.append((validation_loss_sum, epoch_count)) - - # Output the model - if output_file_path_prefix is not None: - parameter_output_path = "%s-%%0%dd" % (output_file_path_prefix, param.parameterOutputPlaceHolder) - m.save_parameters(os.path.abspath(parameter_output_path % epoch_count)) - - # Adaptive learning rate decay - no_of_epochs_with_current_learning_rate += 1 - - need_learning_rate_update = ( - ( - no_of_epochs_with_current_learning_rate >= 6 and - not is_last_five_epoch_approaches_minimum(validation_losses) and - is_validation_loss_goes_up_and_down(validation_losses) - ) or - ( - no_of_epochs_with_current_learning_rate >= 8 and - is_validation_losses_keep_increasing(validation_losses) - ) - ) - - if need_learning_rate_update: - learning_rate_switch_count -= 1 - if learning_rate_switch_count == 0: - break - logging.info("[INFO] New learning rate: %.2e" % m.decay_learning_rate()) - logging.info("[INFO] New L2 regularization lambda: %.2e" % m.decay_l2_regularization_lambda()) - no_of_epochs_with_current_learning_rate = 0 - - # variables update per epoch - epoch_count += 1 - - epoch_start_time = time() - training_loss_sum = 0 - validation_loss_sum = 0 - data_index = 0 - blosc_index = 0 - first_blosc_block_data_index = 0 - x_batch = None - y_batch = None - - gt21_loss_sum = 0 - genotype_loss_sum = 0 - indel_length_loss_sum_1 = 0 - indel_length_loss_sum_2 = 0 - l2_loss_sum = 0 - - # shuffle data on each epoch - tensor_block_index_list = shuffle_first_n_items(tensor_block_index_list, no_of_training_blosc_blocks) - logging.info("[INFO] Shuffled: " + ' '.join( - [str(x) for x in np.append(tensor_block_index_list[:5], tensor_block_index_list[-5:])] - )) - - logging.info("[INFO] Training time elapsed: %.2f s" % (time() - training_start_time)) - - return training_losses, validation_losses - - -def main(): - random.seed(param.RANDOM_SEED) - np.random.seed(param.RANDOM_SEED) - - parser = ArgumentParser(description="Train model") - - # optimizer - parser.add_argument('--SGDM', action='store_true', - help="Use Stochastic Gradient Descent with momentum as optimizer") - parser.add_argument('--Adam', action='store_true', - help="Use Adam as optimizer") - - # loss function - parser.add_argument('--cross_entropy', action='store_true', - help="Use Cross Entropy as loss function") - parser.add_argument('--focal_loss', action='store_true', - help="Use Focal Loss as loss function") - - # binary file path - parser.add_argument('--bin_fn', type=str, default=None, - help="Binary tensor input generated by tensor2Bin.py, tensor_fn, var_fn and bed_fn will be ignored") - parser.add_argument('--train_bin_fn', type=str, default=None, - help="Train Binary, used together with --validation_bin_fn (would ignore: bin_fn, tensor_fn, var_fn, bed_fn)") - parser.add_argument('--validation_bin_fn', type=str, default=None, - help="Validation Binary, used together with --train_bin_fn (would ignore: bin_fn, tensor_fn, var_fn, bed_fn)") - - # tensor file path - parser.add_argument('--tensor_fn', type=str, default="vartensors", help="Tensor input") - - # variant file path - parser.add_argument('--var_fn', type=str, default="truthvars", help="Truth variants list input") - - # bed file path - parser.add_argument('--bed_fn', type=str, default=None, - help="High confident genome regions input in the BED format") - - # checkpoint file path - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a checkpoint for testing or continue training") - - # learning rate, with default value stated in param - parser.add_argument('--learning_rate', type=float, default=param.initialLearningRate, - help="Set the initial learning rate, default: %(default)s") - - # l2 regularization - parser.add_argument('--lambd', type=float, default=param.l2RegularizationLambda, - help="Set the l2 regularization lambda, default: %(default)s") - - # output checkpint file path prefix - parser.add_argument('--ochk_prefix', type=str, default=None, - help="Prefix for checkpoint outputs at each learning rate change, REQUIRED") - - parser.add_argument('--olog_dir', type=str, default=None, - help="Directory for tensorboard log outputs, optional") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - # initialize - logging.info("[INFO] Initializing") - utils.setup_environment() - - optimizer = "SGDM" if args.SGDM else ("Adam" if args.Adam else param.default_optimizer) - loss_function = ( - "FocalLoss" if args.focal_loss else ("CrossEntropy" if args.cross_entropy else param.default_loss_function) - ) - logging.info("[INFO] Optimizer: {}".format(optimizer)) - logging.info("[INFO] Loss Function: {}".format(loss_function)) - - m = Clair( - optimizer_name=optimizer, - loss_function=loss_function - ) - m.init() - - dataset_info = utils.dataset_info_from( - binary_file_path=args.bin_fn, - tensor_file_path=args.tensor_fn, - variant_file_path=args.var_fn, - bed_file_path=args.bed_fn, - train_binary_file_path=args.train_bin_fn, - validation_binary_file_path=args.validation_bin_fn, - ) - training_config = utils.TrainingConfig( - dataset_info=dataset_info, - learning_rate=args.learning_rate, - l2_regularization_lambda=args.lambd, - output_file_path_prefix=args.ochk_prefix, - model_initalization_file_path=args.chkpnt_fn, - summary_writer=m.get_summary_file_writer(args.olog_dir) if args.olog_dir != None else None, - ) - - _training_losses, validation_losses = train_model(m, training_config) - - # show the parameter set with the smallest validation loss - validation_losses.sort() - best_validation_epoch = validation_losses[0][1] - logging.info("[INFO] Best validation loss at epoch: %d" % best_validation_epoch) - - # load best validation model and evaluate it - model_file_path = "%s-%%0%dd" % (training_config.output_file_path_prefix, param.parameterOutputPlaceHolder) - best_validation_model_file_path = model_file_path % best_validation_epoch - m.restore_parameters(os.path.abspath(best_validation_model_file_path)) - evaluate.evaluate_model(m, dataset_info) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/train_clr.py b/benchmarks/nn-variant/clair/train_clr.py deleted file mode 100644 index 05e64ee..0000000 --- a/benchmarks/nn-variant/clair/train_clr.py +++ /dev/null @@ -1,312 +0,0 @@ -import sys -import os -import logging -import random -import numpy as np -from time import time -from argparse import ArgumentParser -from threading import Thread - -from clair.model import Clair -import clair.utils as utils -import clair.evaluate as evaluate -import shared.param as param - -logging.basicConfig(format='%(message)s', level=logging.INFO) - - -def shuffle_first_n_items(array, n): - if len(array) <= n: - np.random.shuffle(array) - return array - # pylint: disable=unbalanced-tuple-unpacking - a1, a2 = np.split(array, [n]) - np.random.shuffle(a1) - return np.append(a1, a2) - - -def train_model(m, training_config, clr_mode): - learning_rate = training_config.learning_rate - max_learning_rate = param.clr_max_lr - l2_regularization_lambda = training_config.l2_regularization_lambda - output_file_path_prefix = training_config.output_file_path_prefix - summary_writer = training_config.summary_writer - model_initalization_file_path = training_config.model_initalization_file_path - - dataset_info = training_config.dataset_info - dataset_size = dataset_info.dataset_size - - training_losses = [] - validation_losses = [] - - if model_initalization_file_path is not None: - m.restore_parameters(os.path.abspath(model_initalization_file_path)) - - logging.info("[INFO] Start training...") - logging.info("[INFO] Learning rate: %.2e" % m.set_learning_rate(learning_rate)) - logging.info("[INFO] L2 regularization lambda: %.2e" % m.set_l2_regularization_lambda(l2_regularization_lambda)) - - # Model Constants - training_start_time = time() - no_of_training_examples = ( - dataset_info.no_of_training_examples_from_train_binary or int(dataset_size * param.trainingDatasetPercentage) - ) - no_of_validation_examples = dataset_info.dataset_size - no_of_training_examples - no_of_blosc_blocks = utils.no_of_blosc_blocks_from( - dataset_info=dataset_info, - no_of_training_examples=no_of_training_examples, - blosc_block_size=param.bloscBlockSize - ) - no_of_training_blosc_blocks = int(no_of_training_examples / param.bloscBlockSize) - tensor_block_index_list = np.arange(no_of_blosc_blocks, dtype=int) - - total_numbers_of_iterations = np.ceil(no_of_training_examples / param.trainBatchSize+1) + \ - np.ceil(no_of_validation_examples/param.predictBatchSize+1) - step_size = param.stepsizeConstant * total_numbers_of_iterations - - # Initialize variables - epoch_count = 1 - if model_initalization_file_path != None: - epoch_count = int(model_initalization_file_path[-param.parameterOutputPlaceHolder:])+1 - - epoch_start_time = time() - training_loss_sum = 0 - validation_loss_sum = 0 - data_index = 0 - blosc_index = 0 - first_blosc_block_data_index = 0 - x_batch = None - y_batch = None - global_step = 0 - - gt21_loss_sum = 0 - genotype_loss_sum = 0 - indel_length_loss_sum_1 = 0 - indel_length_loss_sum_2 = 0 - l2_loss_sum = 0 - - while epoch_count <= param.maxEpoch: - is_training = data_index < no_of_training_examples - is_validation = not is_training - is_with_batch_data = x_batch is not None and y_batch is not None - # logging.info("{} {} {} {} {}".format("TRAIN" if is_training else "VALID", data_index, first_blosc_block_data_index, blosc_index, no_of_training_examples)) - - # threads for either train or validation - thread_pool = [] - if is_with_batch_data and is_training: - thread_pool.append(Thread(target=m.train, args=(x_batch, y_batch))) - elif is_with_batch_data and is_validation: - thread_pool.append(Thread(target=m.validate, args=(x_batch, y_batch))) - for t in thread_pool: - t.start() - - next_x_batch, next_y_batch, next_first_blosc_block_data_index, next_blosc_start_index = utils.new_mini_batch( - data_index=data_index, - blosc_start_index=blosc_index, - first_blosc_block_data_index=first_blosc_block_data_index, - no_of_training_examples=no_of_training_examples, - no_of_blosc_blocks=no_of_blosc_blocks, - dataset_info=dataset_info, - tensor_block_index_list=tensor_block_index_list, - ) - - # wait until loaded next mini batch & finished training/validation with current mini batch - for t in thread_pool: - t.join() - - # add training loss or validation loss - if is_with_batch_data and is_training: - training_loss_sum += m.training_loss_on_one_batch - if summary_writer is not None: - summary = m.training_summary_on_one_batch - summary_writer.add_summary(summary, epoch_count) - elif is_with_batch_data and is_validation: - validation_loss_sum += m.validation_loss_on_one_batch - - gt21_loss_sum += m.gt21_loss - genotype_loss_sum += m.genotype_loss - indel_length_loss_sum_1 += m.indel_length_loss_1 - indel_length_loss_sum_2 += m.indel_length_loss_2 - l2_loss_sum += m.l2_loss - - batch_size = np.shape(next_x_batch)[0] - data_index += batch_size - blosc_index = next_blosc_start_index - first_blosc_block_data_index = next_first_blosc_block_data_index - - # if not go through whole dataset yet, continue the process - if next_first_blosc_block_data_index >= 0 and next_blosc_start_index >= 0: - x_batch = next_x_batch - y_batch = next_y_batch - learning_rate, global_step, max_learning_rate = m.clr( - global_step, step_size, max_learning_rate, clr_mode - ) - continue - - # logging.info("{} {} {} {} {}".format("END", data_index, first_blosc_block_data_index, blosc_index, no_of_training_examples)) - logging.info( - " ".join([str(epoch_count), "Training loss:", str(training_loss_sum/no_of_training_examples)]) - ) - logging.info( - "\t".join([ - "{} Validation loss (Total/Base/Genotype/Indel_1_2):".format(epoch_count), - str(validation_loss_sum/no_of_validation_examples), - str(gt21_loss_sum/no_of_validation_examples), - str(genotype_loss_sum/no_of_validation_examples), - str(indel_length_loss_sum_1/no_of_validation_examples), - str(indel_length_loss_sum_2/no_of_validation_examples) - ]) - ) - - logging.info("[INFO] Epoch time elapsed: %.2f s" % (time() - epoch_start_time)) - training_losses.append((training_loss_sum, epoch_count)) - validation_losses.append((validation_loss_sum, epoch_count)) - - # Output the model - if output_file_path_prefix != None: - parameter_output_path = "%s-%%0%dd" % (output_file_path_prefix, param.parameterOutputPlaceHolder) - m.save_parameters(os.path.abspath(parameter_output_path % epoch_count)) - - # variables update per epoch - epoch_count += 1 - - epoch_start_time = time() - training_loss_sum = 0 - validation_loss_sum = 0 - data_index = 0 - blosc_index = 0 - first_blosc_block_data_index = 0 - x_batch = None - y_batch = None - - gt21_loss_sum = 0 - genotype_loss_sum = 0 - indel_length_loss_sum_1 = 0 - indel_length_loss_sum_2 = 0 - l2_loss_sum = 0 - - # shuffle data on each epoch - tensor_block_index_list = shuffle_first_n_items(tensor_block_index_list, no_of_training_blosc_blocks) - logging.info("[INFO] Shuffled: " + ' '.join( - [str(x) for x in np.append(tensor_block_index_list[:5], tensor_block_index_list[-5:])] - )) - - logging.info("[INFO] Training time elapsed: %.2f s" % (time() - training_start_time)) - return training_losses, validation_losses - - -def main(): - random.seed(param.RANDOM_SEED) - np.random.seed(param.RANDOM_SEED) - - parser = ArgumentParser(description="Train model") - - # clr mode - parser.add_argument('--clr_mode', type=str, default="exp", - help="clr modes: tri, tri2, exp") - - # optimizer - parser.add_argument('--SGDM', action='store_true', - help="Use Stochastic Gradient Descent with momentum as optimizer") - parser.add_argument('--Adam', action='store_true', - help="Use Adam as optimizer") - - # loss function - parser.add_argument('--cross_entropy', action='store_true', - help="Use Cross Entropy as loss function") - parser.add_argument('--focal_loss', action='store_true', - help="Use Focal Loss as loss function") - - # binary file path - parser.add_argument('--bin_fn', type=str, default=None, - help="Binary tensor input generated by tensor2Bin.py, tensor_fn, var_fn and bed_fn will be ignored") - parser.add_argument('--train_bin_fn', type=str, default=None, - help="Train Binary, used together with --validation_bin_fn (would ignore: bin_fn, tensor_fn, var_fn, bed_fn)") - parser.add_argument('--validation_bin_fn', type=str, default=None, - help="Validation Binary, used together with --train_bin_fn (would ignore: bin_fn, tensor_fn, var_fn, bed_fn)") - - # tensor file path - parser.add_argument('--tensor_fn', type=str, default="vartensors", help="Tensor input") - - # variant file path - parser.add_argument('--var_fn', type=str, default="truthvars", help="Truth variants list input") - - # bed file path - parser.add_argument('--bed_fn', type=str, default=None, - help="High confident genome regions input in the BED format") - - # checkpoint file path - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a checkpoint for testing or continue training") - - # learning rate, with default value stated in param - parser.add_argument('--learning_rate', type=float, default=param.clr_min_lr, - help="Set the initial learning rate, default: %(default)s") - - # l2 regularization - parser.add_argument('--lambd', type=float, default=param.l2RegularizationLambda, - help="Set the l2 regularization lambda, default: %(default)s") - - # output checkpint file path prefix - parser.add_argument('--ochk_prefix', type=str, default=None, - help="Prefix for checkpoint outputs at each learning rate change, REQUIRED") - - parser.add_argument('--olog_dir', type=str, default=None, - help="Directory for tensorboard log outputs, optional") - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - # initialize - logging.info("[INFO] Initializing") - utils.setup_environment() - - optimizer = "SGDM" if args.SGDM else ("Adam" if args.Adam else param.default_optimizer) - loss_function = ( - "FocalLoss" if args.focal_loss else ("CrossEntropy" if args.cross_entropy else param.default_loss_function) - ) - logging.info("[INFO] Optimizer: {}".format(optimizer)) - logging.info("[INFO] Loss Function: {}".format(loss_function)) - - m = Clair( - optimizer_name=optimizer, - loss_function=loss_function - ) - m.init() - - dataset_info = utils.dataset_info_from( - binary_file_path=args.bin_fn, - tensor_file_path=args.tensor_fn, - variant_file_path=args.var_fn, - bed_file_path=args.bed_fn, - train_binary_file_path=args.train_bin_fn, - validation_binary_file_path=args.validation_bin_fn, - ) - training_config = utils.TrainingConfig( - dataset_info=dataset_info, - learning_rate=args.learning_rate, - l2_regularization_lambda=args.lambd, - output_file_path_prefix=args.ochk_prefix, - model_initalization_file_path=args.chkpnt_fn, - summary_writer=m.get_summary_file_writer(args.olog_dir) if args.olog_dir != None else None, - ) - - _training_losses, validation_losses = train_model(m, training_config, clr_mode=args.clr_mode) - - # show the parameter set with the smallest validation loss - validation_losses.sort() - best_validation_epoch = validation_losses[0][1] - logging.info("[INFO] Best validation loss at epoch: %d" % best_validation_epoch) - - # load best validation model and evaluate it - model_file_path = "%s-%%0%dd" % (training_config.output_file_path_prefix, param.parameterOutputPlaceHolder) - best_validation_model_file_path = model_file_path % best_validation_epoch - m.restore_parameters(os.path.abspath(best_validation_model_file_path)) - evaluate.evaluate_model(m, dataset_info) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/clair/utils.py b/benchmarks/nn-variant/clair/utils.py deleted file mode 100644 index e6560e3..0000000 --- a/benchmarks/nn-variant/clair/utils.py +++ /dev/null @@ -1,377 +0,0 @@ -from __future__ import print_function - -import sys -import gc -import shlex -import logging -import pickle -import numpy as np -import blosc -from os import environ -from enum import IntEnum -from collections import namedtuple - -from clair.task.main import output_labels_from_reference, output_labels_from_vcf_columns -import shared.param as param -from shared.interval_tree import bed_tree_from, is_region_in -from shared.utils import subprocess_popen, IUPAC_base_to_num_dict as BASE2NUM, IUPAC_base_to_ACGT_base_dict as BASE2ACGT, BASIC_BASES - -PREFIX_CHAR_STR = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" - -DatasetInfo = namedtuple('DatasetInfo', [ - 'dataset_size', - 'x_array_compressed', - 'y_array_compressed', - 'position_array_compressed', - 'no_of_training_examples_from_train_binary', - 'is_separated_train_and_validation_binary', -]) -TrainingConfig = namedtuple('TrainingConfig', [ - 'dataset_info', - 'learning_rate', - 'l2_regularization_lambda', - 'output_file_path_prefix', - 'model_initalization_file_path', - 'summary_writer' -]) - - -def setup_environment(): - environ["CXX"] = "g++" - environ['TF_CPP_MIN_LOG_LEVEL'] = '3' - - blosc.set_nthreads(4) - gc.enable() - - -def blosc_pack_array(array): - return blosc.pack_array(array, cname='lz4hc', clevel=9, shuffle=blosc.NOSHUFFLE) - - -def unpack_a_tensor_record(a, b, c, *d): - return a, b, c, np.array(d, dtype=np.float32) - - -def batches_from(iterable, item_from, batch_size=1): - iterable = iter(iterable) - while True: - chunk = [] - for _ in range(batch_size): - try: - chunk.append(item_from(next(iterable))) - except StopIteration: - yield chunk - return - yield chunk - - -no_of_positions, matrix_row, matrix_num = 2 * param.flankingBaseNum + 1, param.matrixRow, param.matrixNum -input_tensor_size = no_of_positions * matrix_row * matrix_num - - -def tensor_generator_from(tensor_file_path, batch_size): - if tensor_file_path != "PIPE": - f = subprocess_popen(shlex.split("gzip -fdc %s" % (tensor_file_path))) - fo = f.stdout - else: - fo = sys.stdin - - processed_tensors = 0 - - def item_from(row): - columns = row.split() - return (columns[:-input_tensor_size], np.array(columns[-input_tensor_size:], dtype=np.float32)) - - for batch in batches_from(fo, item_from=item_from, batch_size=batch_size): - tensors = np.empty((batch_size, input_tensor_size), dtype=np.float32) - non_tensor_infos = [] - for non_tensor_info, tensor in batch: - _, _, sequence = non_tensor_info - if sequence[param.flankingBaseNum] not in BASE2NUM: - continue - tensors[len(non_tensor_infos)] = tensor - non_tensor_infos.append(non_tensor_info) - - current_batch_size = len(non_tensor_infos) - X = np.reshape(tensors, (batch_size, no_of_positions, matrix_row, matrix_num)) - for i in range(1, matrix_num): - X[:current_batch_size, :, :, i] -= X[:current_batch_size, :, :, 0] - - processed_tensors += current_batch_size - print("Processed %d tensors" % processed_tensors, file=sys.stderr) - - if current_batch_size <= 0: - continue - yield X[:current_batch_size], non_tensor_infos[:current_batch_size] - - if tensor_file_path != "PIPE": - fo.close() - f.wait() - - -def variant_map_from(var_fn, tree, is_tree_empty): - Y = {} - if var_fn is None: - return Y - - f = subprocess_popen(shlex.split("gzip -fdc %s" % (var_fn))) - for row in f.stdout: - columns = row.split() - ctg_name, position_str = columns[0], columns[1] - - if not (is_tree_empty or is_region_in(tree, ctg_name, int(position_str))): - continue - - key = ctg_name + ":" + position_str - Y[key] = output_labels_from_vcf_columns(columns) - - f.stdout.close() - f.wait() - return Y - - -def get_training_array(tensor_fn, var_fn, bed_fn, shuffle=True, is_allow_duplicate_chr_pos=False): - tree = bed_tree_from(bed_file_path=bed_fn) - is_tree_empty = len(tree.keys()) == 0 - - Y = variant_map_from(var_fn, tree, is_tree_empty) - - X = {} - f = subprocess_popen(shlex.split("gzip -fdc %s" % (tensor_fn))) - total = 0 - mat = np.empty(input_tensor_size, dtype=np.float32) - for row in f.stdout: - chrom, coord, seq, mat = unpack_a_tensor_record(*(row.split())) - if not (is_tree_empty or is_region_in(tree, chrom, int(coord))): - continue - seq = seq.upper() - if seq[param.flankingBaseNum] not in BASIC_BASES: - continue - key = chrom + ":" + coord - - x = np.reshape(mat, (no_of_positions, matrix_row, matrix_num)) - for i in range(1, matrix_num): - x[:, :, i] -= x[:, :, 0] - - if key not in X: - X[key] = np.copy(x) - elif is_allow_duplicate_chr_pos: - new_key = "" - for character in PREFIX_CHAR_STR: - tmp_key = character + key - if tmp_key not in X: - new_key = tmp_key - break - if len(new_key) > 0: - X[new_key] = np.copy(x) - - is_reference = key not in Y - if is_reference: - Y[key] = output_labels_from_reference(BASE2ACGT[seq[param.flankingBaseNum]]) - - total += 1 - if total % 100000 == 0: - print("Processed %d tensors" % total, file=sys.stderr) - f.stdout.close() - f.wait() - - # print "[INFO] size of X: {}, size of Y: {}".format(len(X), len(Y)) - - all_chr_pos = sorted(X.keys()) - if shuffle == True: - np.random.shuffle(all_chr_pos) - - X_compressed, Y_compressed, pos_compressed = [], [], [] - X_array, Y_array, pos_array = [], [], [] - count = 0 - total = 0 - for key in all_chr_pos: - total += 1 - - X_array.append(X[key]) - del X[key] - - if key in Y: - Y_array.append(Y[key]) - pos_array.append(key) - if not is_allow_duplicate_chr_pos: - del Y[key] - elif is_allow_duplicate_chr_pos: - tmp_key = key[1:] - Y_array.append(Y[tmp_key]) - pos_array.append(tmp_key) - - count += 1 - if count == param.bloscBlockSize: - X_compressed.append(blosc_pack_array(np.array(X_array))) - Y_compressed.append(blosc_pack_array(np.array(Y_array))) - pos_compressed.append(blosc_pack_array(np.array(pos_array))) - X_array, Y_array, pos_array = [], [], [] - count = 0 - - if total % 50000 == 0: - print("Compressed %d/%d tensor" % (total, len(all_chr_pos)), file=sys.stderr) - - if count > 0: - X_compressed.append(blosc_pack_array(np.array(X_array))) - Y_compressed.append(blosc_pack_array(np.array(Y_array))) - pos_compressed.append(blosc_pack_array(np.array(pos_array))) - - return total, X_compressed, Y_compressed, pos_compressed - - -def decompress_array( - array, - blosc_start_index, - first_blosc_block_data_index, - no_of_data_rows_to_retrieve, - no_of_blosc_blocks, - read_index_list=None -): - """ - Return: - data_rows, next_first_blosc_block_data_index and next_blosc_start_index - - Note: - blosc_start_index, next_first_blosc_block_data_index and next_blosc_start_index is inclusive. - """ - data_rows = [] - no_of_data_rows = 0 - for i in range(blosc_start_index, no_of_blosc_blocks): - new_data_rows = blosc.unpack_array(array[i if read_index_list is None else read_index_list[i]]) - data_rows.append(new_data_rows) - no_of_data_rows += len(new_data_rows) - - if i == blosc_start_index and first_blosc_block_data_index > 0: - return np.concatenate(data_rows[:])[first_blosc_block_data_index:], 0, i+1 - - if no_of_data_rows >= no_of_data_rows_to_retrieve: - extra_no_of_data_rows = no_of_data_rows % no_of_data_rows_to_retrieve - next_blosc_start_index = i+1 if extra_no_of_data_rows == 0 else i - next_first_blosc_block_data_index = ( - 0 if extra_no_of_data_rows == 0 else (len(new_data_rows) - extra_no_of_data_rows) - ) - return ( - np.concatenate(data_rows[:])[0:no_of_data_rows_to_retrieve], - next_first_blosc_block_data_index if next_blosc_start_index < no_of_blosc_blocks else -1, - next_blosc_start_index if next_blosc_start_index < no_of_blosc_blocks else -1 - ) - - if no_of_data_rows <= 0: - return None, -1, -1 - return np.concatenate(data_rows[:]), -1, -1 - - -def dataset_info_from( - binary_file_path, - tensor_file_path=None, - variant_file_path=None, - bed_file_path=None, - train_binary_file_path=None, - validation_binary_file_path=None, -): - logging.info("[INFO] Loading dataset...") - no_of_training_examples_from_train_binary = None - - if train_binary_file_path is not None and validation_binary_file_path is not None: - logging.info("[INFO] Loading compressed data from train and validation binary file path") - with open(train_binary_file_path, "rb") as fh: - dataset_size = pickle.load(fh) - x_array_compressed = pickle.load(fh) - y_array_compressed = pickle.load(fh) - position_array_compressed = pickle.load(fh) - no_of_training_examples_from_train_binary = dataset_size - with open(validation_binary_file_path, "rb") as fh: - dataset_size += pickle.load(fh) - x_array_compressed += pickle.load(fh) - y_array_compressed += pickle.load(fh) - position_array_compressed += pickle.load(fh) - - elif binary_file_path != None: - logging.info("[INFO] Loading compressed data from binary file path") - with open(binary_file_path, "rb") as fh: - dataset_size = pickle.load(fh) - x_array_compressed = pickle.load(fh) - y_array_compressed = pickle.load(fh) - position_array_compressed = pickle.load(fh) - else: - logging.info("[INFO] Loading compressed data from utils get training array") - dataset_size, x_array_compressed, y_array_compressed, position_array_compressed = \ - get_training_array(tensor_file_path, variant_file_path, bed_file_path) - - logging.info("[INFO] The size of dataset: {}".format(dataset_size)) - - return DatasetInfo( - dataset_size=dataset_size, - x_array_compressed=x_array_compressed, - y_array_compressed=y_array_compressed, - position_array_compressed=position_array_compressed, - no_of_training_examples_from_train_binary=no_of_training_examples_from_train_binary, - is_separated_train_and_validation_binary=no_of_training_examples_from_train_binary is not None, - ) - - -def new_mini_batch( - data_index, - blosc_start_index, - first_blosc_block_data_index, - no_of_training_examples, - no_of_blosc_blocks, - dataset_info, - tensor_block_index_list -): - """ - Return: - x_batch, y_batch, next_first_blosc_block_data_index, next_blosc_index - """ - if blosc_start_index >= no_of_blosc_blocks: - return None, None, -1, -1 - - x_array_compressed = dataset_info.x_array_compressed - y_array_compressed = dataset_info.y_array_compressed - training_batch_size = param.trainBatchSize - validation_batch_size = param.predictBatchSize - is_training = data_index < no_of_training_examples - is_validation = not is_training - - # calculate new batch size according to dataset index - # train: 0 - validation_data_start_index - 1, validation: validation_data_start_index - dataset_size - if is_training and (no_of_training_examples - data_index) < training_batch_size: - batch_size = no_of_training_examples - data_index - elif is_training: - batch_size = training_batch_size - elif is_validation: - batch_size = validation_batch_size - - def decompress_array_from(array): - return decompress_array( - array=array, - blosc_start_index=blosc_start_index, - first_blosc_block_data_index=first_blosc_block_data_index, - no_of_data_rows_to_retrieve=batch_size, - no_of_blosc_blocks=no_of_blosc_blocks, - read_index_list=tensor_block_index_list - ) - x_batch, next_x_first_blosc_block_data_index, next_x_blosc_index = decompress_array_from(x_array_compressed) - y_batch, _next_y_first_blosc_block_data_index, next_y_blosc_index = decompress_array_from(y_array_compressed) - - x_batch_size, y_batch_size = np.shape(x_batch)[0], np.shape(y_batch)[0] - x_end_flag, y_end_flag = next_x_blosc_index == -1, next_y_blosc_index == -1 - if x_batch_size != y_batch_size or x_end_flag != y_end_flag: - sys.exit("[ERROR] Inconsistency between decompressed arrays: %d/%d" % (x_batch_size, y_batch_size)) - - return x_batch, y_batch, next_x_first_blosc_block_data_index, next_x_blosc_index - - -def no_of_blosc_blocks_from( - dataset_info, - no_of_training_examples, - blosc_block_size, -): - if dataset_info.is_separated_train_and_validation_binary: - no_of_validation_examples = dataset_info.dataset_size - no_of_training_examples - no_of_training_blocks = int(np.ceil(float(no_of_training_examples) / blosc_block_size)) - no_of_validation_blocks = int(np.ceil(float(no_of_validation_examples) / blosc_block_size)) - return no_of_training_blocks + no_of_validation_blocks - - return int(np.ceil(float(dataset_info.dataset_size) / blosc_block_size)) diff --git a/benchmarks/nn-variant/prediction.py b/benchmarks/nn-variant/prediction.py deleted file mode 100644 index 5f6c294..0000000 --- a/benchmarks/nn-variant/prediction.py +++ /dev/null @@ -1,121 +0,0 @@ -import os -import sys -from time import time -import numpy as np -import deepdish as dd -import shared.param as param -from clair.model import Clair -from argparse import ArgumentParser - - -def prediction(args, m): - - print("Begin predicting...") - prediction_output = [] - input_mini_match = dd.io.load(args.input_fn) - output_mini_match = dd.io.load(args.output_fn) - time_counter = {"Load_mini_batch": [], - "Model_prediction": [], - "Write_batch_to_output": []} - - begin_time = time() - for i in range(len(input_mini_match)): - mini_batch = input_mini_match[i] - X, _ = mini_batch - tmp_time = time() - m.predict(X) - cost_time = time() - tmp_time - #print(cost_time) - time_counter["Model_prediction"].append(round(cost_time, 4)) - prediction_output.append(m.prediction) - - end_time = time() - begin_time - - comp = [] - #for i in range(len(input_mini_match)): - # print(prediction_output[i][0], output_mini_match[i][0]) - # comp.append(np.all(np.round(prediction_output[i][0], 3) == np.round(output_mini_match[i][0], 3))) - - #print(comp) - #if False not in comp: - # print("My_prediction function is correct, which takes %.4f s" % end_time) - #else: - # print("My_prediction function is wrong, which takes %.4f s" % end_time) - #dd.io.save("time_counter_my_prediction.h5", time_counter) - print("Time taken: %.4f s" % end_time) - -def Run(args): - - os.environ["OMP_NUM_THREADS"] = "1" - os.environ["OPENBLAS_NUM_THREADS"] = "1" - os.environ["MKL_NUM_THREADS"] = "1" - os.environ["MKL_NUM_THREADS"] = "1" - os.environ["NUMEXPR_NUM_THREADS"] = "1" - - if args.threads is None: - if args.tensor_fn == "PIPE": - param.NUM_THREADS = 4 - else: - param.NUM_THREADS = args.threads - param.NUM_THREADS -= 1 - if param.NUM_THREADS < 1: - param.NUM_THREADS = 1 - - m = Clair() - m.init() - m.restore_parameters(os.path.abspath(args.chkpnt_fn)) - - prediction(args, m) - - -def main(): - parser = ArgumentParser(description="Call variants using a trained model and tensors of candididate variants") - - parser.add_argument('--input_fn', type=str, default="prediction_input.h5", - help="input file") - - parser.add_argument('--output_fn', type=str, default="prediction_output.h5", - help="output file") - - parser.add_argument('--tensor_fn', type=str, default="PIPE", - help="Tensor input, use PIPE for standard input") - - parser.add_argument('--chkpnt_fn', type=str, default=None, - help="Input a checkpoint for testing") - - parser.add_argument('--call_fn', type=str, default=None, - help="Output variant predictions") - - parser.add_argument('--bam_fn', type=str, default="bam.bam", - help="BAM file input, default: %(default)s") - - parser.add_argument('--qual', type=int, default=None, - help="If set, variant with equal or higher quality will be marked PASS, or LowQual otherwise, optional") - - parser.add_argument('--sampleName', type=str, default="SAMPLE", - help="Define the sample name to be shown in the VCF file") - - parser.add_argument('--showRef', action='store_true', - help="Show reference calls, optional") - - parser.add_argument('--debug', action='store_true', - help="Debug mode, optional") - - parser.add_argument('--ref_fn', type=str, default=None, - help="Reference fasta file input, optional, print contig tags in the VCF header if set") - - parser.add_argument('--threads', type=int, default=None, - help="Number of threads, optional") - - - args = parser.parse_args() - - if len(sys.argv[1:]) == 0: - parser.print_help() - sys.exit(1) - - Run(args) - - -if __name__ == "__main__": - main() diff --git a/benchmarks/nn-variant/shared/__pycache__/__init__.cpython-37.pyc b/benchmarks/nn-variant/shared/__pycache__/__init__.cpython-37.pyc deleted file mode 100644 index 6c5bc65..0000000 Binary files a/benchmarks/nn-variant/shared/__pycache__/__init__.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/shared/__pycache__/param.cpython-37.pyc b/benchmarks/nn-variant/shared/__pycache__/param.cpython-37.pyc deleted file mode 100644 index f18aaa0..0000000 Binary files a/benchmarks/nn-variant/shared/__pycache__/param.cpython-37.pyc and /dev/null differ diff --git a/benchmarks/nn-variant/shared/interval_tree.py b/benchmarks/nn-variant/shared/interval_tree.py deleted file mode 100644 index 4a9c2fd..0000000 --- a/benchmarks/nn-variant/shared/interval_tree.py +++ /dev/null @@ -1,56 +0,0 @@ -import shlex -from intervaltree import IntervalTree - -from shared.utils import subprocess_popen - - -def bed_tree_from(bed_file_path): - """ - 0-based interval tree [start, end) - """ - - tree = {} - if bed_file_path is None: - return tree - - unzip_process = subprocess_popen(shlex.split("gzip -fdc %s" % (bed_file_path))) - while True: - row = unzip_process.stdout.readline() - is_finish_reading_output = row == '' and unzip_process.poll() is not None - if is_finish_reading_output: - break - - if row: - columns = row.strip().split() - - ctg_name = columns[0] - if ctg_name not in tree: - tree[ctg_name] = IntervalTree() - - ctg_start, ctg_end = int(columns[1]), int(columns[2]) - if ctg_start == ctg_end: - ctg_end += 1 - - tree[ctg_name].addi(ctg_start, ctg_end) - - unzip_process.stdout.close() - unzip_process.wait() - - return tree - - -def is_region_in(tree, contig_name, region_start=None, region_end=None): - if (contig_name is None) or (contig_name not in tree): - return False - - interval_tree = tree[contig_name] - is_interval_tree_version_3 = hasattr(interval_tree, 'at') - if is_interval_tree_version_3: - return len( - interval_tree.at(region_start) - if region_end is None else - interval_tree.overlap(begin=region_start, end=region_end) - ) > 0 - - # interval tree version 2 - return len(interval_tree.search(begin=region_start, end=region_end, strict=False)) > 0 diff --git a/benchmarks/nn-variant/shared/param.py b/benchmarks/nn-variant/shared/param.py deleted file mode 100644 index 329a1c0..0000000 --- a/benchmarks/nn-variant/shared/param.py +++ /dev/null @@ -1,56 +0,0 @@ -REPO_NAME="Clair" - -NUM_THREADS = 12 -parameterOutputPlaceHolder = 6 -expandReferenceRegion = 1000000 -SAMTOOLS_VIEW_FILTER_FLAG = 2316 - -# Tensor related parameters, please use the same values for creating tensor, model training and variant calling -flankingBaseNum = 16 -matrixRow = 8 -matrixNum = 4 -bloscBlockSize = 500 - -# Model hyperparameters -trainBatchSize = 10000 -predictBatchSize = 1000 -initialLearningRate = 1e-3 -learningRateDecay = 0.1 -maxLearningRateSwitch = 3 -trainingDatasetPercentage = 0.9 - -# other hyperparameters -l2RegularizationLambda = 0.005 -l2RegularizationLambdaDecay = 1 -dropoutRateFC4 = 0.5 -dropoutRateFC5 = 0.0 -dropoutRate = 0.05 -default_optimizer = "Adam" # Adam / SGDM -default_loss_function = "FocalLoss" # CrossEntropy / FocalLoss - -# Cyclical learning rate param(s) -clr_max_lr = 3e-2 -clr_min_lr = 1e-4 -stepsizeConstant = 1 -clrGamma = 0.95 -momentum = 0.9 -maxEpoch = 30 - -# Cyclical learning rate finder param(s) -min_lr = 1e-6 -max_lr = 1e-1 -lr_finder_max_epoch = 1 - -# random seed (None to make it random for every run) -# set to None because cuDNN may introduce additional sources of randomness -# https://machinelearningmastery.com/reproducible-results-neural-networks-keras/ -RANDOM_SEED = None -OPERATION_SEED = None - - -def get_model_parameters(): - return dict( - flankingBaseNum=flankingBaseNum, - matrixNum=matrixNum, - expandReferenceRegion=expandReferenceRegion, - ) diff --git a/benchmarks/nn-variant/shared/utils.py b/benchmarks/nn-variant/shared/utils.py deleted file mode 100644 index 12f6139..0000000 --- a/benchmarks/nn-variant/shared/utils.py +++ /dev/null @@ -1,65 +0,0 @@ -from os.path import isfile, abspath -from sys import exit, stderr -from subprocess import check_output, PIPE, Popen - -# A->A -# C->C -# G->G -# T or U->T -# R->A or G -# Y->C or T -# S->G or C -# W->A or T -# K->G or T -# M->A or C -# B->C or G or T -# D->A or G or T -# H->A or C or T -# V->A or C or G -IUPAC_base_to_ACGT_base_dict = dict(zip( - "ACGTURYSWKMBDHVN", - ("A", "C", "G", "T", "T", "A", "C", "C", "A", "G", "A", "C", "A", "A", "A", "A") -)) - -IUPAC_base_to_num_dict = dict(zip( - "ACGTURYSWKMBDHVN", - (0, 1, 2, 3, 3, 0, 1, 1, 0, 2, 0, 1, 0, 0, 0, 0) -)) - -BASIC_BASES = set("ACGTU") - -def is_file_exists(file_name, suffix=""): - if not isinstance(file_name, str) or not isinstance(suffix, str): - return False - return isfile(file_name + suffix) - - -def file_path_from(file_name, suffix="", exit_on_not_found=False): - if is_file_exists(file_name, suffix): - return abspath(file_name) - if exit_on_not_found: - exit("[ERROR] file %s not found" % (file_name + suffix)) - return None - - -def is_command_exists(command): - if not isinstance(command, str): - return False - - try: - check_output("which %s" % (command), shell=True) - return True - except: - return False - - -def executable_command_string_from(command_to_execute, exit_on_not_found=False): - if is_command_exists(command_to_execute): - return command_to_execute - if exit_on_not_found: - exit("[ERROR] %s executable not found" % (command_to_execute)) - return None - - -def subprocess_popen(args, stdin=None, stdout=PIPE, stderr=stderr, bufsize=8388608): - return Popen(args, stdin=stdin, stdout=stdout, stderr=stderr, bufsize=bufsize, universal_newlines=True) diff --git a/benchmarks/wfa/LICENSE b/benchmarks/wfa/LICENSE new file mode 100644 index 0000000..69b13f7 --- /dev/null +++ b/benchmarks/wfa/LICENSE @@ -0,0 +1,23 @@ +MIT License + +Copyright (c) 2017 Santiago Marco-Sola + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. + +AUTHOR/CONTACT: Santiago Marco-Sola diff --git a/benchmarks/wfa/Makefile b/benchmarks/wfa/Makefile new file mode 100644 index 0000000..f3a2c29 --- /dev/null +++ b/benchmarks/wfa/Makefile @@ -0,0 +1,80 @@ +############################################################################### +# Flags & Folders +############################################################################### +FOLDER_BIN:=bin +FOLDER_BUILD:=build + +# UNAME=$(shell uname) + +CC=gcc +CXX=g++ + +ifeq ($(arch),sse41) + ARCH_FLAGS=-msse4.1 +else ifeq ($(arch),avx2) + ifeq ($(CXX), icpc) + ARCH_FLAGS=-march=core-avx2 #-xCORE-AVX2 + else + ARCH_FLAGS=-mavx2 + endif +else ifeq ($(arch),avx512) + ifeq ($(CXX), icpc) + ARCH_FLAGS=-xCORE-AVX512 + else + ARCH_FLAGS=-mavx512bw + endif +else ifeq ($(arch),native) + ARCH_FLAGS=-march=native +else ifneq ($(arch),) + ## To provide a different architecture flag like -march=core-avx2. + ARCH_FLAGS=$(arch) +endif + +LD_FLAGS=-lm +CC_FLAGS=-Wall -fopenmp $(ARCH_FLAGS) # -g +# ifeq ($(UNAME), Linux) +# LD_FLAGS+=-lrt +# endif + +AR=ar +AR_FLAGS=-rsc + +############################################################################### +# Compile rules +############################################################################### +SUBDIRS=gap_affine \ + utils + +LIB_WFA=$(FOLDER_BUILD)/libwfa.a + +all: CC_FLAGS+=-O3 +all: MODE=all +all: setup +all: $(SUBDIRS) tools $(LIB_WFA) + +debug: setup +debug: MODE=all +debug: $(SUBDIRS) tools $(LIB_WFA) + +$(LIB_WFA): .FORCE + $(AR) $(AR_FLAGS) $(LIB_WFA) $(FOLDER_BUILD)/*.o 2> /dev/null + +setup: + @mkdir -p $(FOLDER_BIN) $(FOLDER_BUILD) + +clean: + rm -rf $(FOLDER_BIN) $(FOLDER_BUILD) + +############################################################################### +# Subdir rule +############################################################################### +export +$(SUBDIRS): + $(MAKE) --directory=$@ all + +tools: + $(MAKE) --directory=$@ $(MODE) + +.PHONY: $(SUBDIRS) tools +.FORCE: + diff --git a/benchmarks/wfa/README.md b/benchmarks/wfa/README.md new file mode 100644 index 0000000..887bd39 --- /dev/null +++ b/benchmarks/wfa/README.md @@ -0,0 +1,80 @@ +# Wavefront Alignment (WFA) + +## 1. INTRODUCTION + +### 1.1 What is WFA? + +The wavefront alignment (WFA) algorithm is an exact gap-affine algorithm that takes advantage of +homologous regions between the sequences to accelerate the alignment process. As opposed to +traditional dynamic programming algorithms that run in quadratic time, the WFA runs in time O(ns), +proportional to the read length n and the alignment score s, using O(s^2) memory. Moreover, the WFA +exhibits simple data dependencies that can be easily vectorized, even by the automatic features of +modern compilers, for different architectures, without the need to adapt the code. + +This library implements the WFA and the WFA-Adapt algorithms for gap-affine penalties. It also +provides support functions to display and verify the results. Moreover, it implements a benchmarking +tool that serves to evaluate the performance of these two algorithms, together with other +high-performance alignment methods (checkout branch `benchmark`). The library can be executed +through the benchmarking tool for evaluation purposes or can be integrated into your code by calling +the WFA functions. + +If you are interested in benchmarking WFA with other algorithms implemented or integrated into the +WFA library, checkout branch `benchmark`. + +### 1.2 Introduction to benchmarking WFA. Simple tests + +The WFA includes the benchmarking tool *align-benchmark* to test performance of. This tool takes as +input a dataset containing pairs of sequences (i.e., pattern and text) to align. Patterns are +preceded by the '>' symbol and texts by the '<' symbol. Example: + +``` +>ATTGGAAAATAGGATTGGGGTTTGTTTATATTTGGGTTGAGGGATGTCCCACCTTCGTCGTCCTTACGTTTCCGGAAGGGAGTGGTTAGCTCGAAGCCCA +CCGTAGAGTTAGACACTCGACCGTGGTGAATCCGCGACCACCGCTTTGACGGGCGCTCTACGGTATCCCGCGATTTGTGTACGTGAAGCAGTGATTAAAC + ./bin/generate_dataset -n 5000000 -l 100 -e 0.05 -o sample.dataset.seq +``` + +Once you have the dataset ready, you can run the *align-benchmark* tool to benchmark the performance +of a specific pairwise alignment method. For example, the WFA algorithm: + +``` +$> ./bin/align_benchmark -i sample.dataset.seq -a gap-affine-wfa +...processed 10000 reads (benchmark=125804.398 reads/s;alignment=188049.469 reads/s) +...processed 20000 reads (benchmark=117722.406 reads/s;alignment=180925.031 reads/s) +[...] +...processed 5000000 reads (benchmark=113844.039 reads/s;alignment=177325.281 reads/s) +[Benchmark] +=> Total.reads 5000000 +=> Time.Benchmark 43.92 s ( 1 call, 43.92 s/call {min43.92s,Max43.92s}) + => Time.Alignment 28.20 s ( 64.20 %) ( 5 Mcalls, 5.64 us/call {min438ns,Max47.05ms}) +``` + +The *align-benchmark* tool will finish and report overall benchmark time (including reading the +input, setup, checking, etc.) and the time taken by the algorithm (i.e., *Time.Alignment*). + +## 2. AUTHORS + + Santiago Marco-Sola \- santiagomsola@gmail.com + +## 3. REPORTING BUGS + +Feedback and bug reporting it's highly appreciated. +Please report any issue or suggestion on github, or by email to the main developer (santiagomsola@gmail.com). + +## 4. LICENSE + +`WFA` uses the same license as [WFA](https://github.com/smarco/WFA). + +## 5. CITATION + +**Santiago Marco-Sola, Juan Carlos Moure, Miquel Moreto, Antonio Espinosa**. ["Fast gap-affine pairwise alignment using the wavefront algorithm."](https://doi.org/10.1093/bioinformatics/btaa777) Bioinformatics, 2020. diff --git a/benchmarks/wfa/VERSION b/benchmarks/wfa/VERSION new file mode 100644 index 0000000..6b3126c --- /dev/null +++ b/benchmarks/wfa/VERSION @@ -0,0 +1 @@ +v1.0 diff --git a/benchmarks/wfa/gap_affine/Makefile b/benchmarks/wfa/gap_affine/Makefile new file mode 100644 index 0000000..22b9619 --- /dev/null +++ b/benchmarks/wfa/gap_affine/Makefile @@ -0,0 +1,43 @@ +############################################################################### +# Definitions +############################################################################### +FOLDER_ROOT:=.. +FOLDER_BUILD_PATH:=$(FOLDER_ROOT)/$(FOLDER_BUILD) + +############################################################################### +# Modules +############################################################################### +MODULES=affine_wavefront \ + affine_wavefront_align \ + affine_wavefront_backtrace \ + affine_wavefront_extend \ + affine_wavefront_penalties \ + affine_wavefront_reduction \ + affine_wavefront_utils \ + edit_cigar + +SRCS=$(addsuffix .c, $(MODULES)) +OBJS=$(addprefix $(FOLDER_BUILD_PATH)/, $(SRCS:.c=.o)) + +CC_XFLAGS=$(ARCH_FLAGS) + +############################################################################### +# Rules +############################################################################### + +all: $(OBJS) + +$(FOLDER_BUILD_PATH)/affine_wavefront.o : affine_wavefront.c + $(CC) $(CC_FLAGS) $(CC_XFLAGS) -I$(FOLDER_ROOT) -c $< -o $@ + +$(FOLDER_BUILD_PATH)/affine_wavefront_align.o : affine_wavefront_align.c + $(CC) $(CC_FLAGS) $(CC_XFLAGS) -I$(FOLDER_ROOT) -c $< -o $@ + +$(FOLDER_BUILD_PATH)/affine_wavefront_extend.o : affine_wavefront_extend.c + $(CC) $(CC_FLAGS) $(CC_XFLAGS) -I$(FOLDER_ROOT) -c $< -o $@ + +# General building rule +$(FOLDER_BUILD_PATH)/%.o : %.c + $(CC) $(CC_FLAGS) -I$(FOLDER_ROOT) -c $< -o $@ + + diff --git a/benchmarks/wfa/gap_affine/affine_wavefront.c b/benchmarks/wfa/gap_affine/affine_wavefront.c new file mode 100644 index 0000000..1ce68ae --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront.c @@ -0,0 +1,208 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Wavefront alignment algorithm for pairwise gap-affine + * alignment (Main module) + */ + +#include "affine_wavefront.h" +#include "affine_wavefront_backtrace.h" +#include "affine_wavefront_utils.h" + +/* + * Setup + */ +void affine_wavefronts_allocate_wavefront_null( + affine_wavefronts_t* const affine_wavefronts) { + // Allocate null wavefront + const int wavefront_length = affine_wavefronts->pattern_length + affine_wavefronts->text_length + 1; + awf_offset_t* const offsets_null = mm_allocator_calloc( + affine_wavefronts->mm_allocator,wavefront_length,awf_offset_t,false); + // Initialize + affine_wavefronts->wavefront_null.null = true; + affine_wavefronts->wavefront_null.lo = 1; + affine_wavefronts->wavefront_null.hi = -1; + affine_wavefronts->wavefront_null.lo_base = 1; + affine_wavefronts->wavefront_null.hi_base = -1; + affine_wavefronts->wavefront_null.offsets = offsets_null + affine_wavefronts->pattern_length; // Center at k=0 + int i; + for (i=0;imm_allocator; + // Initialize wavefronts + const int num_wavefronts = affine_wavefronts->num_wavefronts; + affine_wavefronts->mwavefronts = + mm_allocator_calloc(mm_allocator,num_wavefronts,affine_wavefront_t*,true); + affine_wavefronts->iwavefronts = + mm_allocator_calloc(mm_allocator,num_wavefronts,affine_wavefront_t*,true); + affine_wavefronts->dwavefronts = + mm_allocator_calloc(mm_allocator,num_wavefronts,affine_wavefront_t*,true); + // Allocate bulk-memory (for all wavefronts) + affine_wavefront_t* const wavefronts_mem = + mm_allocator_calloc(mm_allocator,3*num_wavefronts,affine_wavefront_t,false); + affine_wavefronts->wavefronts_mem = wavefronts_mem; + affine_wavefronts->wavefronts_current = wavefronts_mem; +} +affine_wavefronts_t* affine_wavefronts_new( + const int pattern_length, + const int text_length, + affine_penalties_t* const penalties, + const wavefronts_penalties_strategy penalties_strategy, + mm_allocator_t* const mm_allocator) { + // Handler + affine_wavefronts_t* const affine_wavefronts = mm_allocator_alloc(mm_allocator,affine_wavefronts_t); + // MM + affine_wavefronts->mm_allocator = mm_allocator; + affine_wavefronts->mm_stack = mm_stack_new(BUFFER_SIZE_8M); + // Dimensions + const int max_score_misms = MIN(pattern_length,text_length) * penalties->mismatch; + const int max_score_indel = penalties->gap_opening + ABS(pattern_length-text_length) * penalties->gap_extension; + const int num_wavefronts = max_score_misms + max_score_indel; + affine_wavefronts->pattern_length = pattern_length; + affine_wavefronts->text_length = text_length; + affine_wavefronts->num_wavefronts = num_wavefronts; + affine_wavefronts->max_allocated_wavefront = 0; + // Limits + const int single_gap_penalty = penalties->gap_opening + penalties->gap_extension; + const int max_penalty = MAX(penalties->mismatch,single_gap_penalty); + affine_wavefronts->max_penalty = max_penalty; + // Penalties + affine_wavefronts_penalties_init(&affine_wavefronts->penalties,penalties,penalties_strategy); + // Allocate wavefronts + affine_wavefronts_allocate_wavefront_components(affine_wavefronts); + affine_wavefronts_allocate_wavefront_null(affine_wavefronts); + // CIGAR + edit_cigar_allocate(&affine_wavefronts->edit_cigar,pattern_length,text_length,mm_allocator); + // Return + return affine_wavefronts; +} +void affine_wavefronts_clear( + affine_wavefronts_t* const affine_wavefronts) { + // Clear wavefronts memory + const int num_wavefronts = MIN(affine_wavefronts->max_allocated_wavefront,affine_wavefronts->num_wavefronts); + memset(affine_wavefronts->mwavefronts,0,num_wavefronts*sizeof(affine_wavefront_t*)); + memset(affine_wavefronts->iwavefronts,0,num_wavefronts*sizeof(affine_wavefront_t*)); + memset(affine_wavefronts->dwavefronts,0,num_wavefronts*sizeof(affine_wavefront_t*)); + mm_stack_clear(affine_wavefronts->mm_stack); + // Clear CIGAR + edit_cigar_clear(&affine_wavefronts->edit_cigar); + // Clear wavefronts-ptr + affine_wavefronts->wavefronts_current = affine_wavefronts->wavefronts_mem; +} +void affine_wavefronts_delete( + affine_wavefronts_t* const affine_wavefronts) { + // Parameters + mm_allocator_t* const mm_allocator = affine_wavefronts->mm_allocator; + // Free MID-Wavefronts + mm_allocator_free(mm_allocator,affine_wavefronts->mwavefronts); + mm_allocator_free(mm_allocator,affine_wavefronts->iwavefronts); + mm_allocator_free(mm_allocator,affine_wavefronts->dwavefronts); + mm_allocator_free(mm_allocator,affine_wavefronts->wavefront_null.offsets - affine_wavefronts->pattern_length); + // Free bulk memory + mm_allocator_free(mm_allocator,affine_wavefronts->wavefronts_mem); + // CIGAR + edit_cigar_free(&affine_wavefronts->edit_cigar,mm_allocator); + // MM + mm_stack_delete(affine_wavefronts->mm_stack); + // Handler + mm_allocator_free(mm_allocator,affine_wavefronts); +} +/* + * Setup WF-modes + */ +affine_wavefronts_t* affine_wavefronts_new_complete( + const int pattern_length, + const int text_length, + affine_penalties_t* const penalties, + mm_allocator_t* const mm_allocator) { + // Create new + affine_wavefronts_t* const affine_wavefronts = + affine_wavefronts_new( + pattern_length,text_length, + penalties,wavefronts_penalties_force_zero_match,mm_allocator); + // Limits + affine_wavefronts->max_k = text_length; + affine_wavefronts->min_k = -pattern_length; + // Reduction + affine_wavefronts_reduction_set_none(&affine_wavefronts->reduction); + // Return + return affine_wavefronts; +} +affine_wavefronts_t* affine_wavefronts_new_reduced( + const int pattern_length, + const int text_length, + affine_penalties_t* const penalties, + const int min_wavefront_length, + const int max_distance_threshold, + mm_allocator_t* const mm_allocator) { + // Create new + affine_wavefronts_t* const affine_wavefronts = + affine_wavefronts_new( + pattern_length,text_length, + penalties,wavefronts_penalties_force_zero_match,mm_allocator); + // Limits + affine_wavefronts->max_k = text_length; + affine_wavefronts->min_k = -pattern_length; + // Reduction + affine_wavefronts_reduction_set_dynamic( + &affine_wavefronts->reduction,min_wavefront_length,max_distance_threshold); + // Return + return affine_wavefronts; +} +/* + * Allocate individual wavefront + */ +affine_wavefront_t* affine_wavefronts_allocate_wavefront( + affine_wavefronts_t* const affine_wavefronts, + const int lo_base, + const int hi_base) { + // Compute limits + const int wavefront_length = hi_base - lo_base + 2; // (+1) for k=0 + // Allocate wavefront + affine_wavefront_t* const wavefront = affine_wavefronts->wavefronts_current; + ++(affine_wavefronts->wavefronts_current); // Next + // Configure offsets + wavefront->null = false; + wavefront->lo = lo_base; + wavefront->hi = hi_base; + wavefront->lo_base = lo_base; + wavefront->hi_base = hi_base; + // Allocate offsets + awf_offset_t* const offsets_mem = mm_stack_calloc( + affine_wavefronts->mm_stack,wavefront_length,awf_offset_t,false); + awf_offset_t* const offsets = offsets_mem - lo_base; // Center at k=0 + wavefront->offsets = offsets; + // Return + return wavefront; +} + diff --git a/benchmarks/wfa/gap_affine/affine_wavefront.h b/benchmarks/wfa/gap_affine/affine_wavefront.h new file mode 100644 index 0000000..07178f0 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront.h @@ -0,0 +1,169 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Wavefront alignment algorithm for pairwise gap-affine + * alignment (Main module) + */ + +#ifndef AFFINE_WAVEFRONT_H_ +#define AFFINE_WAVEFRONT_H_ + + +#include "utils/commons.h" +#include "utils/mm_allocator.h" +#include "utils/mm_stack.h" + +#include "affine_wavefront_penalties.h" +#include "affine_wavefront_reduction.h" +#include "edit_cigar.h" + +/* + * Constants + */ +#define AFFINE_WAVEFRONT_OFFSET_NULL (-10) +#define AFFINE_WAVEFRONT_K_NULL (INT_MAX/2) + +/* + * Translate k and offset to coordinates h,v + */ +#define AFFINE_WAVEFRONT_V(k,offset) ((offset)-(k)) +#define AFFINE_WAVEFRONT_H(k,offset) (offset) + +#define AFFINE_WAVEFRONT_DIAGONAL(h,v) ((h)-(v)) +#define AFFINE_WAVEFRONT_OFFSET(h,v) (h) + +/* + * Offset size + */ +//#define AFFINE_WAVEFRONT_W8 +//#define AFFINE_WAVEFRONT_W16 +#define AFFINE_WAVEFRONT_W32 + +#ifdef AFFINE_WAVEFRONT_W8 + typedef int8_t awf_offset_t; +#else + #ifdef AFFINE_WAVEFRONT_W16 + typedef int16_t awf_offset_t; + #else // AFFINE_WAVEFRONT_W32 + typedef int32_t awf_offset_t; + #endif +#endif + +/* + * Wavefront + */ +typedef struct { + // Range + bool null; // Is null interval? + int lo; // Effective lowest diagonal (inclusive) + int hi; // Effective highest diagonal (inclusive) + int lo_base; // Lowest diagonal before reduction (inclusive) + int hi_base; // Highest diagonal before reduction (inclusive) + // Offsets + awf_offset_t* offsets; // Offsets +} affine_wavefront_t; + +/* + * Gap-Affine Wavefronts + */ +typedef struct { + // Dimensions + int pattern_length; // Pattern length + int text_length; // Text length + int num_wavefronts; // Total number of allocatable wavefronts + int max_allocated_wavefront; // Maximum index/score of allocated wavefront + // Limits + int max_penalty; // MAX(mismatch_penalty,single_gap_penalty) + int max_k; // Maximum diagonal k (used for null-wf, display, and banding) + int min_k; // Maximum diagonal k (used for null-wf, display, and banding) + // Wavefronts + affine_wavefront_t** mwavefronts; // M-wavefronts + affine_wavefront_t** iwavefronts; // I-wavefronts + affine_wavefront_t** dwavefronts; // D-wavefronts + affine_wavefront_t wavefront_null; // Null wavefront (used to gain orthogonality) + // Reduction + affine_wavefronts_reduction_t reduction; // Reduction parameters + // Penalties + affine_wavefronts_penalties_t penalties; // Penalties parameters + // CIGAR + edit_cigar_t edit_cigar; // Alignment CIGAR + // MM + mm_allocator_t* mm_allocator; // MM-Allocator (General memory allocator) + mm_stack_t* mm_stack; // MM-Stack (Specific fast malloc/free wavefronts' memory) + affine_wavefront_t* wavefronts_mem; // MM-Slab (Specific fast malloc/free wavefronts-ptr => base) + affine_wavefront_t* wavefronts_current; // MM-Slab (Specific fast malloc/free wavefronts-ptr => next) +} affine_wavefronts_t; + +/* + * SWF Wavefront Computation Set + */ +typedef struct { + /* In Wavefronts*/ + affine_wavefront_t* in_mwavefront_sub; + affine_wavefront_t* in_mwavefront_gap; + affine_wavefront_t* in_iwavefront_ext; + affine_wavefront_t* in_dwavefront_ext; + /* Out Wavefronts */ + affine_wavefront_t* out_mwavefront; + affine_wavefront_t* out_iwavefront; + affine_wavefront_t* out_dwavefront; +} affine_wavefront_set; + +/* + * Setup + */ +void affine_wavefronts_clear( + affine_wavefronts_t* const affine_wavefronts); +void affine_wavefronts_delete( + affine_wavefronts_t* const affine_wavefronts); + +/* + * Setup WF-modes + */ +affine_wavefronts_t* affine_wavefronts_new_complete( + const int pattern_length, + const int text_length, + affine_penalties_t* const penalties, + mm_allocator_t* const mm_allocator); +affine_wavefronts_t* affine_wavefronts_new_reduced( + const int pattern_length, + const int text_length, + affine_penalties_t* const penalties, + const int min_wavefront_length, + const int max_distance_threshold, + mm_allocator_t* const mm_allocator); + +/* + * Allocate individual wavefront + */ +affine_wavefront_t* affine_wavefronts_allocate_wavefront( + affine_wavefronts_t* const affine_wavefronts, + const int lo_base, + const int hi_base); + +#endif /* AFFINE_WAVEFRONT_H_ */ diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_align.c b/benchmarks/wfa/gap_affine/affine_wavefront_align.c new file mode 100644 index 0000000..430c28c --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_align.c @@ -0,0 +1,362 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA main algorithm + */ + +#include "affine_wavefront_align.h" +#include "gap_affine/affine_wavefront_backtrace.h" +#include "gap_affine/affine_wavefront_extend.h" +#include "gap_affine/affine_wavefront_utils.h" +#include "utils/string_padded.h" + +/* + * Fetch & allocate wavefronts + */ +void affine_wavefronts_fetch_wavefronts( + affine_wavefronts_t* const affine_wavefronts, + affine_wavefront_set* const wavefront_set, + const int score) { + // Compute scores + const affine_penalties_t* const wavefront_penalties = &(affine_wavefronts->penalties.wavefront_penalties); + const int mismatch_score = score - wavefront_penalties->mismatch; + const int gap_open_score = score - wavefront_penalties->gap_opening - wavefront_penalties->gap_extension; + const int gap_extend_score = score - wavefront_penalties->gap_extension; + // Fetch wavefronts + wavefront_set->in_mwavefront_sub = affine_wavefronts_get_source_mwavefront(affine_wavefronts,mismatch_score); + wavefront_set->in_mwavefront_gap = affine_wavefronts_get_source_mwavefront(affine_wavefronts,gap_open_score); + wavefront_set->in_iwavefront_ext = affine_wavefronts_get_source_iwavefront(affine_wavefronts,gap_extend_score); + wavefront_set->in_dwavefront_ext = affine_wavefronts_get_source_dwavefront(affine_wavefronts,gap_extend_score); +} +void affine_wavefronts_allocate_wavefronts( + affine_wavefronts_t* const affine_wavefronts, + affine_wavefront_set* const wavefront_set, + const int score, + const int lo_effective, + const int hi_effective) { + // Allocate M-Wavefront + wavefront_set->out_mwavefront = + affine_wavefronts_allocate_wavefront(affine_wavefronts,lo_effective,hi_effective); + affine_wavefronts->mwavefronts[score] = wavefront_set->out_mwavefront; + // Allocate I-Wavefront + if (!wavefront_set->in_mwavefront_gap->null || !wavefront_set->in_iwavefront_ext->null) { + wavefront_set->out_iwavefront = + affine_wavefronts_allocate_wavefront(affine_wavefronts,lo_effective,hi_effective); + affine_wavefronts->iwavefronts[score] = wavefront_set->out_iwavefront; + } else { + wavefront_set->out_iwavefront = NULL; + } + // Allocate D-Wavefront + if (!wavefront_set->in_mwavefront_gap->null || !wavefront_set->in_dwavefront_ext->null) { + wavefront_set->out_dwavefront = + affine_wavefronts_allocate_wavefront(affine_wavefronts,lo_effective,hi_effective); + affine_wavefronts->dwavefronts[score] = wavefront_set->out_dwavefront; + } else { + wavefront_set->out_dwavefront = NULL; + } + // Increase max-wavefront + affine_wavefronts->max_allocated_wavefront = MAX(affine_wavefronts->max_allocated_wavefront,score); +} +void affine_wavefronts_compute_limits( + affine_wavefronts_t* const affine_wavefronts, + const affine_wavefront_set* const wavefront_set, + const int score, + int* const lo_effective, + int* const hi_effective) { + // Set limits (min_lo) + int lo = wavefront_set->in_mwavefront_sub->lo; + if (lo > wavefront_set->in_mwavefront_gap->lo) lo = wavefront_set->in_mwavefront_gap->lo; + if (lo > wavefront_set->in_iwavefront_ext->lo) lo = wavefront_set->in_iwavefront_ext->lo; + if (lo > wavefront_set->in_dwavefront_ext->lo) lo = wavefront_set->in_dwavefront_ext->lo; + --lo; + // Set limits (max_hi) + int hi = wavefront_set->in_mwavefront_sub->hi; + if (hi < wavefront_set->in_mwavefront_gap->hi) hi = wavefront_set->in_mwavefront_gap->hi; + if (hi < wavefront_set->in_iwavefront_ext->hi) hi = wavefront_set->in_iwavefront_ext->hi; + if (hi < wavefront_set->in_dwavefront_ext->hi) hi = wavefront_set->in_dwavefront_ext->hi; + ++hi; + // Set effective limits values + *hi_effective = hi; + *lo_effective = lo; +} +/* + * Compute wavefront offsets + */ +#define AFFINE_WAVEFRONT_DECLARE(wavefront,prefix) \ + awf_offset_t* prefix ## _offsets = wavefront->offsets; \ + const int prefix ## _hi = wavefront->hi; \ + const int prefix ## _lo = wavefront->lo +#define AFFINE_WAVEFRONT_COND_FETCH(prefix,index,value) \ + (prefix ## _lo <= (index) && (index) <= prefix ## _hi) ? (value) : AFFINE_WAVEFRONT_OFFSET_NULL +/* + * Compute wavefront offsets + */ +void affine_wavefronts_compute_next( + affine_wavefronts_t* const affine_wavefronts, + const affine_wavefront_set* const wavefront_set, + const int lo, + const int hi) { + // Parameters + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_mwavefront_sub,m_sub); + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_mwavefront_gap,m_gap); + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_iwavefront_ext,i_ext); + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_dwavefront_ext,d_ext); + awf_offset_t* const out_ioffsets = wavefront_set->out_iwavefront->offsets; + awf_offset_t* const out_doffsets = wavefront_set->out_dwavefront->offsets; + awf_offset_t* const out_moffsets = wavefront_set->out_mwavefront->offsets; + // Compute loop peeling offset (min_hi) + int min_hi = wavefront_set->in_mwavefront_sub->hi; + if (!wavefront_set->in_mwavefront_gap->null && min_hi > wavefront_set->in_mwavefront_gap->hi-1) min_hi = wavefront_set->in_mwavefront_gap->hi-1; + if (!wavefront_set->in_iwavefront_ext->null && min_hi > wavefront_set->in_iwavefront_ext->hi+1) min_hi = wavefront_set->in_iwavefront_ext->hi+1; + if (!wavefront_set->in_dwavefront_ext->null && min_hi > wavefront_set->in_dwavefront_ext->hi-1) min_hi = wavefront_set->in_dwavefront_ext->hi-1; + // Compute loop peeling offset (max_lo) + int max_lo = wavefront_set->in_mwavefront_sub->lo; + if (!wavefront_set->in_mwavefront_gap->null && max_lo < wavefront_set->in_mwavefront_gap->lo+1) max_lo = wavefront_set->in_mwavefront_gap->lo+1; + if (!wavefront_set->in_iwavefront_ext->null && max_lo < wavefront_set->in_iwavefront_ext->lo+1) max_lo = wavefront_set->in_iwavefront_ext->lo+1; + if (!wavefront_set->in_dwavefront_ext->null && max_lo < wavefront_set->in_dwavefront_ext->lo-1) max_lo = wavefront_set->in_dwavefront_ext->lo-1; + // Compute score wavefronts (prologue) + int k; + for (k=lo;kin_mwavefront_sub,m_sub); + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_mwavefront_gap,m_gap); + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_iwavefront_ext,i_ext); + awf_offset_t* const out_ioffsets = wavefront_set->out_iwavefront->offsets; + awf_offset_t* const out_moffsets = wavefront_set->out_mwavefront->offsets; + // Compute score wavefronts + int k; +#if defined(__GNUC__) || defined(__GNUG__) + #pragma GCC ivdep +#else + #pragma ivdep +#endif + for (k=lo;k<=hi;++k) { + // Update I + const awf_offset_t ins_g = AFFINE_WAVEFRONT_COND_FETCH(m_gap,k-1,m_gap_offsets[k-1]); + const awf_offset_t ins_i = AFFINE_WAVEFRONT_COND_FETCH(i_ext,k-1,i_ext_offsets[k-1]); + const awf_offset_t ins = MAX(ins_g,ins_i) + 1; + out_ioffsets[k] = ins; + // Update M + const awf_offset_t sub = AFFINE_WAVEFRONT_COND_FETCH(m_sub,k,m_sub_offsets[k]+1); + out_moffsets[k] = MAX(ins,sub); + } +} +void affine_wavefronts_compute_offsets_dm( + affine_wavefronts_t* const affine_wavefronts, + const affine_wavefront_set* const wavefront_set, + const int lo, + const int hi) { + // Parameters + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_mwavefront_sub,m_sub); + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_mwavefront_gap,m_gap); + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_dwavefront_ext,d_ext); + awf_offset_t* const out_doffsets = wavefront_set->out_dwavefront->offsets; + awf_offset_t* const out_moffsets = wavefront_set->out_mwavefront->offsets; + // Compute score wavefronts + int k; +#if defined(__GNUC__) || defined(__GNUG__) + #pragma GCC ivdep +#else + #pragma ivdep +#endif + for (k=lo;k<=hi;++k) { + // Update D + const awf_offset_t del_g = AFFINE_WAVEFRONT_COND_FETCH(m_gap,k+1,m_gap_offsets[k+1]); + const awf_offset_t del_d = AFFINE_WAVEFRONT_COND_FETCH(d_ext,k+1,d_ext_offsets[k+1]); + const awf_offset_t del = MAX(del_g,del_d); + out_doffsets[k] = del; + // Update M + const awf_offset_t sub = AFFINE_WAVEFRONT_COND_FETCH(m_sub,k,m_sub_offsets[k]+1); + out_moffsets[k] = MAX(del,sub); + } +} +void affine_wavefronts_compute_offsets_m( + affine_wavefronts_t* const affine_wavefronts, + const affine_wavefront_set* const wavefront_set, + const int lo, + const int hi) { + // Parameters + AFFINE_WAVEFRONT_DECLARE(wavefront_set->in_mwavefront_sub,m_sub); + awf_offset_t* const out_moffsets = wavefront_set->out_mwavefront->offsets; + // Compute score wavefronts + int k; +#if defined(__GNUC__) || defined(__GNUG__) + #pragma GCC ivdep +#else + #pragma ivdep +#endif + for (k=lo;k<=hi;++k) { + // Update M + out_moffsets[k] = AFFINE_WAVEFRONT_COND_FETCH(m_sub,k,m_sub_offsets[k]+1); + } +} +/* + * Compute wavefront + */ +void affine_wavefronts_compute_wavefront( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int score) { + // Select wavefronts + affine_wavefront_set wavefront_set; + affine_wavefronts_fetch_wavefronts(affine_wavefronts,&wavefront_set,score); + // Check null wavefronts + if (wavefront_set.in_mwavefront_sub->null && + wavefront_set.in_mwavefront_gap->null && + wavefront_set.in_iwavefront_ext->null && + wavefront_set.in_dwavefront_ext->null) { + return; + } + // Set limits + int hi, lo; + affine_wavefronts_compute_limits(affine_wavefronts,&wavefront_set,score,&lo,&hi); + // Allocate score-wavefronts + affine_wavefronts_allocate_wavefronts(affine_wavefronts,&wavefront_set,score,lo,hi); + // Compute WF + const int kernel = ((wavefront_set.out_iwavefront!=NULL) << 1) | (wavefront_set.out_dwavefront!=NULL); + switch (kernel) { + case 3: // 11b + affine_wavefronts_compute_next(affine_wavefronts,&wavefront_set,lo,hi); + break; + case 2: // 10b + affine_wavefronts_compute_offsets_im(affine_wavefronts,&wavefront_set,lo,hi); + break; + case 1: // 01b + affine_wavefronts_compute_offsets_dm(affine_wavefronts,&wavefront_set,lo,hi); + break; + case 0: // 00b + affine_wavefronts_compute_offsets_m(affine_wavefronts,&wavefront_set,lo,hi); + break; + } +} +/* + * Computation using Wavefronts + */ +void affine_wavefronts_align( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length) { + // Init padded strings + strings_padded_t* const strings_padded = + strings_padded_new_rhomb( + pattern,pattern_length,text,text_length, + AFFINE_WAVEFRONT_PADDING,affine_wavefronts->mm_allocator); + // Initialize wavefront + affine_wavefront_initialize(affine_wavefronts); + // Compute wavefronts for increasing score + int score = 0; + while (true) { + // Exact extend s-wavefront + affine_wavefronts_extend_wavefront_packed( + affine_wavefronts,strings_padded->pattern_padded,pattern_length, + strings_padded->text_padded,text_length,score); + // Exit condition + if (affine_wavefront_end_reached(affine_wavefronts,pattern_length,text_length,score)) { + // Backtrace & check alignment reached + affine_wavefronts_backtrace( + affine_wavefronts,strings_padded->pattern_padded,pattern_length, + strings_padded->text_padded,text_length,score); + break; + } + // Update all wavefronts + ++score; // Increase score + affine_wavefronts_compute_wavefront( + affine_wavefronts,strings_padded->pattern_padded,pattern_length, + strings_padded->text_padded,text_length,score); + } + // Free + strings_padded_delete(strings_padded); +} + diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_align.h b/benchmarks/wfa/gap_affine/affine_wavefront_align.h new file mode 100644 index 0000000..4dc26a3 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_align.h @@ -0,0 +1,48 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA main algorithm + */ + +#ifndef AFFINE_WAVEFRONT_ALIGN_H_ +#define AFFINE_WAVEFRONT_ALIGN_H_ + +#include "gap_affine/affine_wavefront.h" +#include "utils/commons.h" + +/* + * Computation using Wavefronts + */ +void affine_wavefronts_align( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length); + +#endif /* AFFINE_WAVEFRONT_ALIGN_H_ */ diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_backtrace.c b/benchmarks/wfa/gap_affine/affine_wavefront_backtrace.c new file mode 100644 index 0000000..e866a6f --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_backtrace.c @@ -0,0 +1,387 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA extend backtrace component + */ + +#include "gap_affine/affine_wavefront_backtrace.h" + +/* + * Backtrace Detect Limits + */ +bool affine_wavefronts_valid_location( + const int k, + const awf_offset_t offset, + const int pattern_length, + const int text_length) { + // Locate offset (remember that backtrace is always +1 offset ahead) + const int v = AFFINE_WAVEFRONT_V(k,offset); + const int h = AFFINE_WAVEFRONT_H(k,offset); + return (v > 0 && v <= pattern_length && + h > 0 && h <= text_length); +} +void affine_wavefronts_offset_add_trailing_gap( + edit_cigar_t* const edit_cigar, + const int k, + const int alignment_k) { + // Parameters + char* const operations = edit_cigar->operations; + int op_sentinel = edit_cigar->begin_offset; + // Add trailing gap + int i; + if (k < alignment_k) { + for (i=k;i alignment_k) { + for (i=alignment_k;ibegin_offset = op_sentinel; +} +/* + * Backtrace Paths Offsets + */ +awf_offset_t backtrace_wavefront_trace_deletion_open_offset( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return AFFINE_WAVEFRONT_OFFSET_NULL; + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + if (mwavefront != NULL && + mwavefront->lo_base <= k+1 && + k+1 <= mwavefront->hi_base) { + return mwavefront->offsets[k+1]; + } else { + return AFFINE_WAVEFRONT_OFFSET_NULL; + } +} +awf_offset_t backtrace_wavefront_trace_deletion_extend_offset( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return AFFINE_WAVEFRONT_OFFSET_NULL; + affine_wavefront_t* const dwavefront = affine_wavefronts->dwavefronts[score]; + if (dwavefront != NULL && + dwavefront->lo_base <= k+1 && + k+1 <= dwavefront->hi_base) { + return dwavefront->offsets[k+1]; + } else { + return AFFINE_WAVEFRONT_OFFSET_NULL; + } +} +awf_offset_t backtrace_wavefront_trace_insertion_open_offset( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return AFFINE_WAVEFRONT_OFFSET_NULL; + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + if (mwavefront != NULL && + mwavefront->lo_base <= k-1 && + k-1 <= mwavefront->hi_base) { + return mwavefront->offsets[k-1] + 1; + } else { + return AFFINE_WAVEFRONT_OFFSET_NULL; + } +} +awf_offset_t backtrace_wavefront_trace_insertion_extend_offset( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return AFFINE_WAVEFRONT_OFFSET_NULL; + affine_wavefront_t* const iwavefront = affine_wavefronts->iwavefronts[score]; + if (iwavefront != NULL && + iwavefront->lo_base <= k-1 && + k-1 <= iwavefront->hi_base) { + return iwavefront->offsets[k-1] + 1; + } else { + return AFFINE_WAVEFRONT_OFFSET_NULL; + } +} +awf_offset_t backtrace_wavefront_trace_mismatch_offset( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return AFFINE_WAVEFRONT_OFFSET_NULL; + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + if (mwavefront != NULL && + mwavefront->lo_base <= k && + k <= mwavefront->hi_base) { + return mwavefront->offsets[k] + 1; + } else { + return AFFINE_WAVEFRONT_OFFSET_NULL; + } +} +/* + * Backtrace Paths Conditions + */ +bool backtrace_wavefront_trace_deletion( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return false; + affine_wavefront_t* const dwavefront = affine_wavefronts->dwavefronts[score]; + return (dwavefront != NULL && + dwavefront->lo_base <= k && + k <= dwavefront->hi_base && + offset == dwavefront->offsets[k]); +} +bool backtrace_wavefront_trace_deletion_open( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return false; + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + return (mwavefront != NULL && + mwavefront->lo_base <= k+1 && + k+1 <= mwavefront->hi_base && + offset == mwavefront->offsets[k+1]); +} +bool backtrace_wavefront_trace_deletion_extend( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return false; + affine_wavefront_t* const dwavefront = affine_wavefronts->dwavefronts[score]; + return (dwavefront != NULL && + dwavefront->lo_base <= k+1 && + k+1 <= dwavefront->hi_base && + offset == dwavefront->offsets[k+1]); +} +bool backtrace_wavefront_trace_insertion( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return false; + affine_wavefront_t* const iwavefront = affine_wavefronts->iwavefronts[score]; + return (iwavefront != NULL && + iwavefront->lo_base <= k && + k <= iwavefront->hi_base && + offset == iwavefront->offsets[k]); +} +bool backtrace_wavefront_trace_insertion_open( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return false; + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + return (mwavefront != NULL && + mwavefront->lo_base <= k-1 && + k-1 <= mwavefront->hi_base && + offset == mwavefront->offsets[k-1]+1); +} +bool backtrace_wavefront_trace_insertion_extend( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return false; + affine_wavefront_t* const iwavefront = affine_wavefronts->iwavefronts[score]; + return (iwavefront != NULL && + iwavefront->lo_base <= k-1 && + k-1 <= iwavefront->hi_base && + offset == iwavefront->offsets[k-1]+1); +} +bool backtrace_wavefront_trace_mismatch( + affine_wavefronts_t* const affine_wavefronts, + const int score, + const int k, + const awf_offset_t offset) { + if (score < 0) return false; + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + return (mwavefront != NULL && + mwavefront->lo_base <= k && + k <= mwavefront->hi_base && + offset == mwavefront->offsets[k]+1); +} +/* + * Backtrace Operations + */ +int affine_wavefronts_backtrace_compute_max_matches( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const char* const text, + const int k, + awf_offset_t offset) { + // Locate position + int v = AFFINE_WAVEFRONT_V(k,offset); + int h = AFFINE_WAVEFRONT_H(k,offset); + // Check matches + int num_matches = 0; + while (v>0 && h>0 && pattern[--v] == text[--h]) ++num_matches; + // Return max-matches + return num_matches; +} +void affine_wavefronts_backtrace_matches__check( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const char* const text, + const int k, + awf_offset_t offset, + const bool valid_location, + const int num_matches, + edit_cigar_t* const edit_cigar) { + int i; + for (i=0;ioperations[(edit_cigar->begin_offset)--] = 'M'; + // Update state + --offset; + } +} +void affine_wavefronts_backtrace_matches( + edit_cigar_t* const edit_cigar, + const int num_matches) { + // Set Matches + int i; + for (i=0;ioperations[(edit_cigar->begin_offset)--] = 'M'; + } +} +/* + * Backtrace (single solution) + */ +void affine_wavefronts_backtrace( + affine_wavefronts_t* const affine_wavefronts, + char* const pattern, + const int pattern_length, + char* const text, + const int text_length, + const int alignment_score) { + // Parameters + const affine_penalties_t* const wavefront_penalties = + &(affine_wavefronts->penalties.wavefront_penalties); + edit_cigar_t* const cigar = &affine_wavefronts->edit_cigar; + const int alignment_k = AFFINE_WAVEFRONT_DIAGONAL(text_length,pattern_length); + // Compute starting location + int score = alignment_score; + int k = alignment_k; + awf_offset_t offset = affine_wavefronts->mwavefronts[alignment_score]->offsets[k]; + bool valid_location = affine_wavefronts_valid_location(k,offset,pattern_length,text_length); + // Trace the alignment back + backtrace_wavefront_type backtrace_type = backtrace_wavefront_M; + int v = AFFINE_WAVEFRONT_V(k,offset); + int h = AFFINE_WAVEFRONT_H(k,offset); + while (v > 0 && h > 0 && score > 0) { + // Check location + if (!valid_location) { + valid_location = affine_wavefronts_valid_location(k,offset,pattern_length,text_length); + if (valid_location) { + affine_wavefronts_offset_add_trailing_gap(cigar,k,alignment_k); + } + } + // Compute scores + const int gap_open_score = score - wavefront_penalties->gap_opening - wavefront_penalties->gap_extension; + const int gap_extend_score = score - wavefront_penalties->gap_extension; + const int mismatch_score = score - wavefront_penalties->mismatch; + // Compute source offsets + const awf_offset_t del_ext = (backtrace_type == backtrace_wavefront_I) ? AFFINE_WAVEFRONT_OFFSET_NULL: + backtrace_wavefront_trace_deletion_extend_offset(affine_wavefronts,gap_extend_score,k,offset); + const awf_offset_t del_open = (backtrace_type == backtrace_wavefront_I) ? AFFINE_WAVEFRONT_OFFSET_NULL: + backtrace_wavefront_trace_deletion_open_offset(affine_wavefronts,gap_open_score,k,offset); + const awf_offset_t ins_ext = (backtrace_type == backtrace_wavefront_D) ? AFFINE_WAVEFRONT_OFFSET_NULL: + backtrace_wavefront_trace_insertion_extend_offset(affine_wavefronts,gap_extend_score,k,offset); + const awf_offset_t ins_open = (backtrace_type == backtrace_wavefront_D) ? AFFINE_WAVEFRONT_OFFSET_NULL: + backtrace_wavefront_trace_insertion_open_offset(affine_wavefronts,gap_open_score,k,offset); + const awf_offset_t misms = (backtrace_type != backtrace_wavefront_M) ? AFFINE_WAVEFRONT_OFFSET_NULL: + backtrace_wavefront_trace_mismatch_offset(affine_wavefronts,mismatch_score,k,offset); + // Compute maximum offset + const awf_offset_t max_del = MAX(del_ext,del_open); + const awf_offset_t max_ins = MAX(ins_ext,ins_open); + const awf_offset_t max_all = MAX(misms,MAX(max_ins,max_del)); + // Traceback Matches + if (backtrace_type == backtrace_wavefront_M) { + const int num_matches = offset - max_all; + affine_wavefronts_backtrace_matches__check(affine_wavefronts, + pattern,text,k,offset,valid_location,num_matches,cigar); + offset = max_all; + } + // Traceback Operation + if (max_all == del_ext) { + // Add Deletion + if (valid_location) cigar->operations[(cigar->begin_offset)--] = 'D'; + // Update state + score = gap_extend_score; + ++k; + backtrace_type = backtrace_wavefront_D; + } else if (max_all == del_open) { + // Add Deletion + if (valid_location) cigar->operations[(cigar->begin_offset)--] = 'D'; + // Update state + score = gap_open_score; + ++k; + backtrace_type = backtrace_wavefront_M; + } else if (max_all == ins_ext) { + // Add Insertion + if (valid_location) cigar->operations[(cigar->begin_offset)--] = 'I'; + // Update state + score = gap_extend_score; + --k; + --offset; + backtrace_type = backtrace_wavefront_I; + } else if (max_all == ins_open) { + // Add Insertion + if (valid_location) cigar->operations[(cigar->begin_offset)--] = 'I'; + // Update state + score = gap_open_score; + --k; + --offset; + backtrace_type = backtrace_wavefront_M; + } else if (max_all == misms) { + // Add Mismatch + if (valid_location) cigar->operations[(cigar->begin_offset)--] = 'X'; + // Update state + score = mismatch_score; + --offset; + } else { + fprintf(stderr,"Backtrace error: No link found during backtrace\n"); + exit(1); + } + // Update coordinates + v = AFFINE_WAVEFRONT_V(k,offset); + h = AFFINE_WAVEFRONT_H(k,offset); + } + // Account for last operations + if (score == 0) { + // Account for last stroke of matches + affine_wavefronts_backtrace_matches__check(affine_wavefronts, + pattern,text,k,offset,valid_location,offset,cigar); + } else { + // Account for last stroke of insertion/deletion + while (v > 0) {cigar->operations[(cigar->begin_offset)--] = 'D'; --v;}; + while (h > 0) {cigar->operations[(cigar->begin_offset)--] = 'I'; --h;}; + } + ++(cigar->begin_offset); // Set CIGAR length +} diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_backtrace.h b/benchmarks/wfa/gap_affine/affine_wavefront_backtrace.h new file mode 100644 index 0000000..f62e07f --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_backtrace.h @@ -0,0 +1,67 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA extend backtrace component + */ + +#ifndef AFFINE_WAVEFRONT_BACKTRACE_H_ +#define AFFINE_WAVEFRONT_BACKTRACE_H_ + +#include "gap_affine/affine_wavefront.h" + +/* + * Sequences DTO + */ +typedef struct { + char* pattern; + int pattern_length; + char* text; + int text_length; +} alignment_sequences_t; + +/* + * WF type + */ +typedef enum { + backtrace_wavefront_M = 0, + backtrace_wavefront_I = 1, + backtrace_wavefront_D = 2 +} backtrace_wavefront_type; + +/* + * Backtrace + */ +void affine_wavefronts_backtrace( + affine_wavefronts_t* const affine_wavefronts, + char* const pattern, + const int pattern_length, + char* const text, + const int text_length, + const int alignment_score); + +#endif /* AFFINE_WAVEFRONT_BACKTRACE_H_ */ diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_extend.c b/benchmarks/wfa/gap_affine/affine_wavefront_extend.c new file mode 100644 index 0000000..cf80be7 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_extend.c @@ -0,0 +1,276 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA extend exact-matches component + */ + +#include "gap_affine/affine_wavefront_extend.h" +#include "gap_affine/affine_wavefront_reduction.h" +#include "gap_affine/affine_wavefront_utils.h" +#include "utils/string_padded.h" + +#ifdef __ARM_FEATURE_SVE + #include +#endif + +#if 0 +void print_mask (svbool_t mask) { + // Create a "1" vector with inactive lanes to zero + svuint32_t res = svdup_u32_z(mask, 1U); + + uint32_t v[16] = {0}; + svst1_u32(svptrue_b32(), v, res); + for (int i=0; i<16; i++) { + fprintf(stderr, "%u ", v[i]); + } + fprintf(stderr, "\n"); +} + +void print_svint32 (svint32_t _v) { + int32_t v[16] = {0}; + svst1_s32(svptrue_b32(), v, _v); + for (int i=0; i<16; i++) { + fprintf(stderr, "%d ", v[i]); + } + fprintf(stderr, "\n"); +} + +void print_svuint32 (svuint32_t _v) { + uint32_t v[16] = {0}; + svst1_u32(svptrue_b32(), v, _v); + for (int i=0; i<16; i++) { + fprintf(stderr, "%u ", v[i]); + } + fprintf(stderr, "\n"); +} + +void print_svuint32_hex (svuint32_t _v) { + uint32_t v[16] = {0}; + svst1_u32(svptrue_b32(), v, _v); + for (int i=0; i<16; i++) { + fprintf(stderr, "0x%x ", v[i]); + } + fprintf(stderr, "\n"); +} +#endif + +/* + * Reduce wavefront + */ +void affine_wavefronts_reduce_wavefront_offsets( + affine_wavefronts_t* const affine_wavefronts, + affine_wavefront_t* const wavefront, + const int pattern_length, + const int text_length, + const int min_distance, + const int max_distance_threshold, + const int alignment_k) { + // Parameters + const awf_offset_t* const offsets = wavefront->offsets; + int k; + // Reduce from bottom + const int top_limit = MIN(alignment_k-1,wavefront->hi); + for (k=wavefront->lo;klo); + } + // Reduce from top + const int botton_limit = MAX(alignment_k+1,wavefront->lo); + for (k=wavefront->hi;k>botton_limit;--k) { + const int distance = affine_wavefronts_compute_distance(pattern_length,text_length,offsets[k],k); + if (distance - min_distance <= max_distance_threshold) break; + --(wavefront->hi); + } + // Check hi/lo range + if (wavefront->lo > wavefront->hi) { + wavefront->null = true; + } +} +void affine_wavefronts_reduce_wavefronts( + affine_wavefronts_t* const affine_wavefronts, + const int pattern_length, + const int text_length, + const int score) { + // Parameters + const int min_wavefront_length = affine_wavefronts->reduction.min_wavefront_length; + const int max_distance_threshold = affine_wavefronts->reduction.max_distance_threshold; + const int alignment_k = AFFINE_WAVEFRONT_DIAGONAL(text_length,pattern_length); + // Fetch m-wavefront + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + if (mwavefront==NULL) return; + if ((mwavefront->hi - mwavefront->lo + 1) < min_wavefront_length) return; + // Compute min-distance + const awf_offset_t* const offsets = mwavefront->offsets; + int min_distance = MAX(pattern_length,text_length); + int k; + for (k=mwavefront->lo;k<=mwavefront->hi;++k) { + const int distance = affine_wavefronts_compute_distance(pattern_length,text_length,offsets[k],k); + min_distance = MIN(min_distance,distance); + } + // Reduce m-wavefront + affine_wavefronts_reduce_wavefront_offsets( + affine_wavefronts,mwavefront,pattern_length,text_length, + min_distance,max_distance_threshold,alignment_k); + // Reduce i-wavefront + affine_wavefront_t* const iwavefront = affine_wavefronts->iwavefronts[score]; + if (iwavefront!=NULL) { + if (mwavefront->lo > iwavefront->lo) iwavefront->lo = mwavefront->lo; + if (mwavefront->hi < iwavefront->hi) iwavefront->hi = mwavefront->hi; + if (iwavefront->lo > iwavefront->hi) iwavefront->null = true; + } + // Reduce d-wavefront + affine_wavefront_t* const dwavefront = affine_wavefronts->dwavefronts[score]; + if (dwavefront!=NULL) { + if (mwavefront->lo > dwavefront->lo) dwavefront->lo = mwavefront->lo; + if (mwavefront->hi < dwavefront->hi) dwavefront->hi = mwavefront->hi; + if (dwavefront->lo > dwavefront->hi) dwavefront->null = true; + } +} + +/* + * Wavefront offset extension comparing characters + */ +void affine_wavefronts_extend( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int score) { + + // Fetch m-wavefront + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + if (mwavefront==NULL) return; + // Extend diagonally each wavefront point + awf_offset_t* const offsets = mwavefront->offsets; + +#if defined(__ARM_FEATURE_SVE) + #pragma message("affine_wavefront_extend: ARM-SVE version") + + const int k_min = mwavefront->lo; + const int k_max = mwavefront->hi; + + uint64_t num_elems = svcntw(); + + svbool_t mask; + + // Extend diagonally each wavefront point + for (int k=k_min;k<=k_max;k+=num_elems) { + // Get number of elements that will be computed in this iteration + //int active_elements = wf_length - k; + svint32_t ks = svindex_s32(k, 1); + mask = svcmple_s32(svptrue_b32(), ks, svdup_s32(k_max)); + svbool_t original_mask = mask; + + svint32_t sv_offsets = svld1(mask, &offsets[k]); + + // h = o + svint32_t h = sv_offsets; + // v = o - k + svint32_t v = svsub_z(mask, sv_offsets, ks); + + bool svtest = svptest_any(svptrue_b32(), mask); + + while (svtest) { + svuint32_t bases_pattern = svld1_gather_s32offset_u32(mask, (uint32_t*)pattern, v); + bases_pattern = svrevb_u32_z(mask, bases_pattern); + svuint32_t bases_text = svld1_gather_s32offset_u32(mask, (uint32_t*)text, h); + bases_text = svrevb_u32_z(mask, bases_text); + + svuint32_t xor_result = sveor_u32_z(mask, bases_pattern, bases_text); + svuint32_t clz_res = svclz_u32_z(mask, xor_result); + svint32_t Eq = svreinterpret_s32(svlsr_u32_z(mask, clz_res, svdup_u32(3U))); + + // Make sure we don't count beyond the sequence + svint32_t remaining_v = svsub_s32_z(mask, svdup_s32(pattern_length), v); + svint32_t remaining_h = svsub_s32_z(mask, svdup_s32(text_length), h); + Eq = svmin_s32_z(mask, Eq, remaining_v); + Eq = svmin_s32_z(mask, Eq, remaining_h); + + sv_offsets = svadd_s32_m(mask, sv_offsets, Eq); + + // Only diagonals that have 4 elements equal (so they have not finished) will continue + mask = svcmpgt_n_s32(mask, Eq, 3U); + + // v < pattern_length + svbool_t mask_v = svcmplt_n_s32(mask, v, pattern_length); + mask = svand_b_z(mask, mask, mask_v); + // h < text_length + svbool_t mask_h = svcmplt_n_s32(mask, h, text_length); + mask = svand_b_z(mask, mask, mask_h); + + // v++ and h++ + v = svadd_n_s32_z(mask, v, 4); + h = svadd_n_s32_z(mask, h, 4); + + svtest = svptest_any(svptrue_b32(), mask); + } + svst1_s32(original_mask, &offsets[k], sv_offsets); + } +#else + #pragma message("affine_wavefront_extend: SCALAR version") + // // Fetch m-wavefront + // affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + // if (mwavefront==NULL) return; + // // Extend diagonally each wavefront point + // awf_offset_t* const offsets = mwavefront->offsets; + int k; + for (k=mwavefront->lo;k<=mwavefront->hi;++k) { + // Exact extend + const awf_offset_t offset = offsets[k]; + int v = AFFINE_WAVEFRONT_V(k,offset); + int h = AFFINE_WAVEFRONT_H(k,offset); + while (pattern[v++]==text[h++]) { + ++(offsets[k]); + } + } +#endif +} + +/* + * Gap-Affine Wavefront exact extension + */ +void affine_wavefronts_extend_wavefront_packed( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int score) { + // Extend wavefront + affine_wavefronts_extend( + affine_wavefronts,pattern,pattern_length, + text,text_length,score); + // Reduce wavefront dynamically + if (affine_wavefronts->reduction.reduction_strategy == wavefronts_reduction_dynamic) { + affine_wavefronts_reduce_wavefronts( + affine_wavefronts,pattern_length, + text_length,score); + } +} diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_extend.h b/benchmarks/wfa/gap_affine/affine_wavefront_extend.h new file mode 100644 index 0000000..64d32c6 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_extend.h @@ -0,0 +1,53 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA extend exact-matches component + */ + +#ifndef AFFINE_WAVEFRONT_EXTEND_H_ +#define AFFINE_WAVEFRONT_EXTEND_H_ + +#include "gap_affine/affine_wavefront.h" + +/* + * Constants + */ +#define AFFINE_WAVEFRONT_PADDING 10 // (-AFFINE_WAVEFRONT_OFFSET_NULL) + +/* + * Gap-Affine Wavefront exact extension + */ +void affine_wavefronts_extend_wavefront_packed( + affine_wavefronts_t* const affine_wavefronts, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int score); + +#endif /* AFFINE_WAVEFRONT_EXTEND_H_ */ diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_penalties.c b/benchmarks/wfa/gap_affine/affine_wavefront_penalties.c new file mode 100644 index 0000000..b76ef05 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_penalties.c @@ -0,0 +1,125 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA support functions for handling penalties scores + */ + +#include "gap_affine/affine_wavefront_penalties.h" + +/* + * Setup + */ +void affine_wavefronts_penalties_init( + affine_wavefronts_penalties_t* const wavefronts_penalties, + affine_penalties_t* const penalties, + const wavefronts_penalties_strategy penalties_strategy) { + wavefronts_penalties->base_penalties = *penalties; + wavefronts_penalties->penalties_strategy = + (penalties->match==0) ? wavefronts_penalties_match_zero : penalties_strategy; + switch (wavefronts_penalties->penalties_strategy) { + case wavefronts_penalties_match_zero: + case wavefronts_penalties_force_zero_match: + affine_penalties_mzero(penalties,&(wavefronts_penalties->wavefront_penalties)); + break; + case wavefronts_penalties_shifted_penalties: + affine_penalties_shift(penalties,&(wavefronts_penalties->wavefront_penalties),false); + break; + case wavefronts_penalties_odd_pair_penalties: + affine_penalties_shift(penalties,&(wavefronts_penalties->wavefront_penalties),true); + break; + default: + break; + } +} +/* + * Score Adjustment + */ +void affine_penalties_mzero( + affine_penalties_t* const base_penalties, + affine_penalties_t* const shifted_penalties) { + // Check base penalties + if (base_penalties->match > 0) { + fprintf(stderr,"Match score must be negative or zero (M=%d)\n",base_penalties->match); + exit(1); + } + if (base_penalties->mismatch <= 0 || + base_penalties->gap_opening <= 0 || + base_penalties->gap_extension <= 0) { + fprintf(stderr,"Mismatch/Gap scores must be strictly positive (X=%d,O=%d,E=%d)\n", + base_penalties->mismatch,base_penalties->gap_opening,base_penalties->gap_extension); + exit(1); + } + // Copy base penalties + *shifted_penalties = *base_penalties; + // Zero match score + shifted_penalties->match = 0; +} +void affine_penalties_shift( + affine_penalties_t* const base_penalties, + affine_penalties_t* const shifted_penalties, + const bool pair_odd_heuristic) { + // Check base penalties + if (base_penalties->match > 0) { + fprintf(stderr,"Match score must be negative (M=%d)\n",base_penalties->match); + exit(1); + } + if (base_penalties->mismatch <= 0 || + base_penalties->gap_opening <= 0 || + base_penalties->gap_extension <= 0) { + fprintf(stderr,"Mismatch/Gap scores must be strictly positive (X=%d,O=%d,E=%d)\n", + base_penalties->mismatch,base_penalties->gap_opening,base_penalties->gap_extension); + exit(1); + } + // Copy base penalties + *shifted_penalties = *base_penalties; + // Shift to zero match score + shifted_penalties->match = 0; + shifted_penalties->mismatch -= base_penalties->match; + shifted_penalties->gap_opening -= base_penalties->match; + shifted_penalties->gap_extension -= base_penalties->match; + // Odd/Pair shift heuristic + if (pair_odd_heuristic) { + const bool is_mismatch_pair = ((shifted_penalties->mismatch%2)==0); + const bool is_gap_opening_pair = ((shifted_penalties->gap_opening%2)==0); + const bool is_gap_extension_pair = ((shifted_penalties->gap_extension%2)==0); + const int total_odd = !is_mismatch_pair + !is_gap_opening_pair + !is_gap_extension_pair; + const int total_pair = is_mismatch_pair + is_gap_opening_pair + is_gap_extension_pair; + if (total_odd > total_pair) { + // Shift all to odd + if (is_mismatch_pair) ++(shifted_penalties->mismatch); + if (is_gap_opening_pair) ++(shifted_penalties->gap_opening); + if (is_gap_extension_pair) ++(shifted_penalties->gap_extension); + } else { + // Shift all to pair + if (!is_mismatch_pair) ++(shifted_penalties->mismatch); + if (!is_gap_opening_pair) ++(shifted_penalties->gap_opening); + if (!is_gap_extension_pair) ++(shifted_penalties->gap_extension); + } + } +} + diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_penalties.h b/benchmarks/wfa/gap_affine/affine_wavefront_penalties.h new file mode 100644 index 0000000..0864c00 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_penalties.h @@ -0,0 +1,86 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA support functions for handling penalties scores + */ + +#ifndef AFFINE_WAVEFRONT_PENALTIES_H_ +#define AFFINE_WAVEFRONT_PENALTIES_H_ + +#include "utils/commons.h" + +/* + * Penalties + */ +typedef struct { + int match; // (Penalty representation; usually M <= 0) + int mismatch; // (Penalty representation; usually X > 0) + int gap_opening; // (Penalty representation; usually O > 0) + int gap_extension; // (Penalty representation; usually E > 0) +} affine_penalties_t; + +/* + * Wavefront Strategy + */ +typedef enum { + wavefronts_penalties_match_zero, + wavefronts_penalties_force_zero_match, + wavefronts_penalties_shifted_penalties, + wavefronts_penalties_odd_pair_penalties +} wavefronts_penalties_strategy; + +/* + * Wavefront Penalties + */ +typedef struct { + affine_penalties_t base_penalties; // Input base Gap-Affine penalties + affine_penalties_t wavefront_penalties; // Wavefront Gap-Affine penalties + wavefronts_penalties_strategy penalties_strategy; // Penalties adaptation strategy +} affine_wavefronts_penalties_t; + +/* + * Setup + */ +void affine_wavefronts_penalties_init( + affine_wavefronts_penalties_t* const wavefronts_penalties, + affine_penalties_t* const penalties, + const wavefronts_penalties_strategy penalties_strategy); + +/* + * Score Adjustment + */ +void affine_penalties_mzero( + affine_penalties_t* const base_penalties, + affine_penalties_t* const shifted_penalties); +void affine_penalties_shift( + affine_penalties_t* const base_penalties, + affine_penalties_t* const shifted_penalties, + const bool pair_odd_heuristic); + + +#endif /* AFFINE_WAVEFRONT_PENALTIES_H_ */ diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_reduction.c b/benchmarks/wfa/gap_affine/affine_wavefront_reduction.c new file mode 100644 index 0000000..c5d5ad1 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_reduction.c @@ -0,0 +1,49 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA support functions for wavefront-reduction + * strategies (adaptive or banded strategies) + */ + +#include "gap_affine/affine_wavefront_reduction.h" + +/* + * Setup + */ +void affine_wavefronts_reduction_set_none( + affine_wavefronts_reduction_t* const wavefronts_reduction) { + wavefronts_reduction->reduction_strategy = wavefronts_reduction_none; +} +void affine_wavefronts_reduction_set_dynamic( + affine_wavefronts_reduction_t* const wavefronts_reduction, + const int min_wavefront_length, + const int max_distance_threshold) { + wavefronts_reduction->reduction_strategy = wavefronts_reduction_dynamic; + wavefronts_reduction->min_wavefront_length = min_wavefront_length; + wavefronts_reduction->max_distance_threshold = max_distance_threshold; +} diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_reduction.h b/benchmarks/wfa/gap_affine/affine_wavefront_reduction.h new file mode 100644 index 0000000..9a5cbf7 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_reduction.h @@ -0,0 +1,65 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA support functions for wavefront-reduction + * strategies (adaptive or banded strategies) + */ + +#ifndef AFFINE_WAVEFRONT_REDUCTION_H_ +#define AFFINE_WAVEFRONT_REDUCTION_H_ + +#include "utils/commons.h" + +/* + * Wavefront Reduction + */ +typedef enum { + wavefronts_reduction_none, + wavefronts_reduction_dynamic, +} wavefront_reduction_type; + +/* + * Wavefront Penalties + */ +typedef struct { + wavefront_reduction_type reduction_strategy; // Reduction strategy + int min_wavefront_length; // Dynamic: Minimum wavefronts length to reduce + int max_distance_threshold; // Dynamic: Maximum distance between offsets allowed +} affine_wavefronts_reduction_t; + +/* + * Setup + */ +void affine_wavefronts_reduction_set_none( + affine_wavefronts_reduction_t* const wavefronts_reduction); +void affine_wavefronts_reduction_set_dynamic( + affine_wavefronts_reduction_t* const wavefronts_reduction, + const int min_wavefront_length, + const int max_distance_threshold); + +#endif /* AFFINE_WAVEFRONT_REDUCTION_H_ */ diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_utils.c b/benchmarks/wfa/gap_affine/affine_wavefront_utils.c new file mode 100644 index 0000000..63cd220 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_utils.c @@ -0,0 +1,102 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA support utilities + */ + +#include "gap_affine/affine_wavefront_utils.h" + +/* + * Accessors + */ +affine_wavefront_t* affine_wavefronts_get_source_mwavefront( + affine_wavefronts_t* const affine_wavefronts, + const int score) { + return (score < 0 || affine_wavefronts->mwavefronts[score] == NULL) ? + &affine_wavefronts->wavefront_null : affine_wavefronts->mwavefronts[score]; +} +affine_wavefront_t* affine_wavefronts_get_source_iwavefront( + affine_wavefronts_t* const affine_wavefronts, + const int score) { + return (score < 0 || affine_wavefronts->iwavefronts[score] == NULL) ? + &affine_wavefronts->wavefront_null : affine_wavefronts->iwavefronts[score]; +} +affine_wavefront_t* affine_wavefronts_get_source_dwavefront( + affine_wavefronts_t* const affine_wavefronts, + const int score) { + return (score < 0 || affine_wavefronts->dwavefronts[score] == NULL) ? + &affine_wavefronts->wavefront_null : affine_wavefronts->dwavefronts[score]; +} +int affine_wavefronts_diagonal_length( + affine_wavefronts_t* const affine_wavefronts, + const int k) { + if (k >= 0) { + return MIN(affine_wavefronts->text_length-k,affine_wavefronts->pattern_length); + } else { + return MIN(affine_wavefronts->pattern_length+k,affine_wavefronts->text_length); + } +} +int affine_wavefronts_compute_distance( + const int pattern_length, + const int text_length, + const awf_offset_t offset, + const int k) { + const int v = AFFINE_WAVEFRONT_V(k,offset); + const int h = AFFINE_WAVEFRONT_H(k,offset); + const int left_v = pattern_length - v; + const int left_h = text_length - h; + return MAX(left_v,left_h); +} +/* + * Initial Conditions and finalization + */ +void affine_wavefront_initialize( + affine_wavefronts_t* const affine_wavefronts) { + affine_wavefronts->mwavefronts[0] = affine_wavefronts_allocate_wavefront(affine_wavefronts,0,0); + affine_wavefronts->mwavefronts[0]->offsets[0] = 0; +} +bool affine_wavefront_end_reached( + affine_wavefronts_t* const affine_wavefronts, + const int pattern_length, + const int text_length, + const int score) { + // Parameters + const int alignment_k = AFFINE_WAVEFRONT_DIAGONAL(text_length,pattern_length); + const int alignment_offset = AFFINE_WAVEFRONT_OFFSET(text_length,pattern_length); + // Fetch wavefront and check termination + affine_wavefront_t* const mwavefront = affine_wavefronts->mwavefronts[score]; + if (mwavefront!=NULL) { + awf_offset_t* const offsets = mwavefront->offsets; + if (mwavefront->lo <= alignment_k && + alignment_k <= mwavefront->hi && + offsets[alignment_k] >= alignment_offset) { + return true; + } + } + return false; +} diff --git a/benchmarks/wfa/gap_affine/affine_wavefront_utils.h b/benchmarks/wfa/gap_affine/affine_wavefront_utils.h new file mode 100644 index 0000000..3979113 --- /dev/null +++ b/benchmarks/wfa/gap_affine/affine_wavefront_utils.h @@ -0,0 +1,69 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: WFA support utilities + */ + +#ifndef AFFINE_WAVEFRONT_UTILS_H_ +#define AFFINE_WAVEFRONT_UTILS_H_ + +#include "gap_affine/affine_wavefront.h" + +/* + * Accessors + */ +affine_wavefront_t* affine_wavefronts_get_source_mwavefront( + affine_wavefronts_t* const affine_wavefronts, + const int score); +affine_wavefront_t* affine_wavefronts_get_source_iwavefront( + affine_wavefronts_t* const affine_wavefronts, + const int score); +affine_wavefront_t* affine_wavefronts_get_source_dwavefront( + affine_wavefronts_t* const affine_wavefronts, + const int score); +int affine_wavefronts_diagonal_length( + affine_wavefronts_t* const affine_wavefronts, + const int k); +int affine_wavefronts_compute_distance( + const int pattern_length, + const int text_length, + const awf_offset_t offset, + const int k); + +/* + * Initial Conditions and finalization + */ +void affine_wavefront_initialize( + affine_wavefronts_t* const affine_wavefronts); +bool affine_wavefront_end_reached( + affine_wavefronts_t* const affine_wavefronts, + const int pattern_length, + const int text_length, + const int score); + +#endif /* AFFINE_WAVEFRONT_UTILS_H_ */ diff --git a/benchmarks/wfa/gap_affine/edit_cigar.c b/benchmarks/wfa/gap_affine/edit_cigar.c new file mode 100644 index 0000000..69fd5a9 --- /dev/null +++ b/benchmarks/wfa/gap_affine/edit_cigar.c @@ -0,0 +1,282 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Edit cigar data-structure (match/mismatch/insertion/deletion) + */ + +#include "edit_cigar.h" + +/* + * Setup + */ +void edit_cigar_allocate( + edit_cigar_t* const edit_cigar, + const int pattern_length, + const int text_length, + mm_allocator_t* const mm_allocator) { + edit_cigar->max_operations = pattern_length+text_length; + edit_cigar->operations = mm_allocator_malloc(mm_allocator,edit_cigar->max_operations); + edit_cigar->begin_offset = edit_cigar->max_operations - 1; + edit_cigar->end_offset = edit_cigar->max_operations; + edit_cigar->score = INT32_MIN; +} +void edit_cigar_clear( + edit_cigar_t* const edit_cigar) { + edit_cigar->begin_offset = edit_cigar->max_operations - 1; + edit_cigar->end_offset = edit_cigar->max_operations; + edit_cigar->score = INT32_MIN; +} +void edit_cigar_free( + edit_cigar_t* const edit_cigar, + mm_allocator_t* const mm_allocator) { + mm_allocator_free(mm_allocator,edit_cigar->operations); +} +/* + * Score + */ +int edit_cigar_score_edit( + edit_cigar_t* const edit_cigar) { + int score = 0, i; + for (i=edit_cigar->begin_offset;iend_offset;++i) { + switch (edit_cigar->operations[i]) { + case 'M': break; + case 'X': + case 'D': + case 'I': ++score; break; + default: return INT_MIN; + } + } + return score; +} +int edit_cigar_score_gap_affine( + edit_cigar_t* const edit_cigar, + affine_penalties_t* const penalties) { + char last_op = '\0'; + int score = 0, i; + for (i=edit_cigar->begin_offset;iend_offset;++i) { + switch (edit_cigar->operations[i]) { + case 'M': + score -= penalties->match; + last_op = 'M'; + break; + case 'X': + score -= penalties->mismatch; + last_op = 'X'; + break; + case 'D': + score -= penalties->gap_extension + ((last_op=='D') ? 0 : penalties->gap_opening); + last_op = 'D'; + break; + case 'I': + score -= penalties->gap_extension + ((last_op=='I') ? 0 : penalties->gap_opening); + last_op = 'I'; + break; + default: + fprintf(stderr,"Computing CIGAR score: Unknown operation\n"); + exit(1); + } + } + return score; +} +/* + * Utils + */ +bool edit_cigar_check_alignment( + FILE* const stream, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + edit_cigar_t* const edit_cigar, + const bool verbose) { + // Parameters + char* const operations = edit_cigar->operations; + // Traverse CIGAR + int pattern_pos=0, text_pos=0, i; + for (i=edit_cigar->begin_offset;iend_offset;++i) { + switch (operations[i]) { + case 'M': + // Check match + if (pattern[pattern_pos] != text[text_pos]) { + if (verbose) { + fprintf(stream, + "Align Check. Alignment not matching (pattern[%d]=%c != text[%d]=%c)\n", + pattern_pos,pattern[pattern_pos],text_pos,text[text_pos]); + } + return false; + } + ++pattern_pos; + ++text_pos; + break; + case 'X': + // Check mismatch + if (pattern[pattern_pos] == text[text_pos]) { + if (verbose) { + fprintf(stream, + "Align Check. Alignment not mismatching (pattern[%d]=%c == text[%d]=%c)\n", + pattern_pos,pattern[pattern_pos],text_pos,text[text_pos]); + } + return false; + } + ++pattern_pos; + ++text_pos; + break; + case 'I': + ++text_pos; + break; + case 'D': + ++pattern_pos; + break; + default: + fprintf(stderr,"CIGAR check. Unknown edit operation '%c'\n",operations[i]); + exit(1); + break; + } + } + // Check alignment length + if (pattern_pos != pattern_length) { + if (verbose) { + fprintf(stream, + "Align Check. Alignment incorrect length (pattern-aligned=%d,pattern-length=%d)\n", + pattern_pos,pattern_length); + } + return false; + } + if (text_pos != text_length) { + if (verbose) { + fprintf(stream, + "Align Check. Alignment incorrect length (text-aligned=%d,text-length=%d)\n", + text_pos,text_length); + } + return false; + } + // OK + return true; +} +/* + * Display + */ +void edit_cigar_print( + FILE* const stream, + edit_cigar_t* const edit_cigar) { + char last_op = edit_cigar->operations[edit_cigar->begin_offset]; + int last_op_length = 1; + int i; + for (i=edit_cigar->begin_offset+1;iend_offset;++i) { + if (edit_cigar->operations[i]==last_op) { + ++last_op_length; + } else { + fprintf(stream,"%d%c",last_op_length,last_op); + last_op = edit_cigar->operations[i]; + last_op_length = 1; + } + } + fprintf(stream,"%d%c",last_op_length,last_op); +} +void edit_cigar_print_pretty( + FILE* const stream, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + edit_cigar_t* const edit_cigar, + mm_allocator_t* const mm_allocator) { + // Parameters + char* const operations = edit_cigar->operations; + // Allocate alignment buffers + const int max_buffer_length = text_length+pattern_length+1; + char* const pattern_alg = mm_allocator_calloc(mm_allocator,max_buffer_length,char,true); + char* const ops_alg = mm_allocator_calloc(mm_allocator,max_buffer_length,char,true); + char* const text_alg = mm_allocator_calloc(mm_allocator,max_buffer_length,char,true); + // Compute alignment buffers + int i, alg_pos = 0, pattern_pos = 0, text_pos = 0; + for (i=edit_cigar->begin_offset;iend_offset;++i) { + switch (operations[i]) { + case 'M': + if (pattern[pattern_pos] != text[text_pos]) { + pattern_alg[alg_pos] = pattern[pattern_pos]; + ops_alg[alg_pos] = 'X'; + text_alg[alg_pos++] = text[text_pos]; + } else { + pattern_alg[alg_pos] = pattern[pattern_pos]; + ops_alg[alg_pos] = '|'; + text_alg[alg_pos++] = text[text_pos]; + } + pattern_pos++; text_pos++; + break; + case 'X': + if (pattern[pattern_pos] != text[text_pos]) { + pattern_alg[alg_pos] = pattern[pattern_pos++]; + ops_alg[alg_pos] = ' '; + text_alg[alg_pos++] = text[text_pos++]; + } else { + pattern_alg[alg_pos] = pattern[pattern_pos++]; + ops_alg[alg_pos] = 'X'; + text_alg[alg_pos++] = text[text_pos++]; + } + break; + case 'I': + pattern_alg[alg_pos] = '-'; + ops_alg[alg_pos] = ' '; + text_alg[alg_pos++] = text[text_pos++]; + break; + case 'D': + pattern_alg[alg_pos] = pattern[pattern_pos++]; + ops_alg[alg_pos] = ' '; + text_alg[alg_pos++] = '-'; + break; + default: + break; + } + } + i=0; + while (pattern_pos < pattern_length) { + pattern_alg[alg_pos+i] = pattern[pattern_pos++]; + ops_alg[alg_pos+i] = '?'; + ++i; + } + i=0; + while (text_pos < text_length) { + text_alg[alg_pos+i] = text[text_pos++]; + ops_alg[alg_pos+i] = '?'; + ++i; + } + // Print alignment pretty + fprintf(stream," PRETTY.ALIGNMENT\t"); + edit_cigar_print(stderr,edit_cigar); + fprintf(stream,"\n"); + fprintf(stream," PATTERN %s\n",pattern_alg); + fprintf(stream," %s\n",ops_alg); + fprintf(stream," TEXT %s\n",text_alg); + // Free + mm_allocator_free(mm_allocator,pattern_alg); + mm_allocator_free(mm_allocator,ops_alg); + mm_allocator_free(mm_allocator,text_alg); +} + + diff --git a/benchmarks/wfa/gap_affine/edit_cigar.h b/benchmarks/wfa/gap_affine/edit_cigar.h new file mode 100644 index 0000000..d607958 --- /dev/null +++ b/benchmarks/wfa/gap_affine/edit_cigar.h @@ -0,0 +1,109 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Edit cigar data-structure (match/mismatch/insertion/deletion) + */ + +#ifndef EDIT_CIGAR_H_ +#define EDIT_CIGAR_H_ + +#include "utils/commons.h" +#include "utils/mm_allocator.h" +#include "gap_affine/affine_wavefront_penalties.h" + +/* + * CIGAR + */ +typedef struct { + char* operations; + int max_operations; + int begin_offset; + int end_offset; + int score; +} edit_cigar_t; + +/* + * Distance metrics + */ +typedef enum { + edit, + gap_lineal, + gap_affine +} distance_metric_t; + +/* + * Setup + */ +void edit_cigar_allocate( + edit_cigar_t* const edit_cigar, + const int pattern_length, + const int text_length, + mm_allocator_t* const mm_allocator); +void edit_cigar_clear( + edit_cigar_t* const edit_cigar); +void edit_cigar_free( + edit_cigar_t* const edit_cigar, + mm_allocator_t* const mm_allocator); + +/* + * Score + */ +int edit_cigar_score_edit( + edit_cigar_t* const edit_cigar); +int edit_cigar_score_gap_affine( + edit_cigar_t* const edit_cigar, + affine_penalties_t* const penalties); + +/* + * Utils + */ +bool edit_cigar_check_alignment( + FILE* const stream, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + edit_cigar_t* const edit_cigar, + const bool verbose); + +/* + * Display + */ +void edit_cigar_print( + FILE* const stream, + edit_cigar_t* const edit_cigar); +void edit_cigar_print_pretty( + FILE* const stream, + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + edit_cigar_t* const edit_cigar, + mm_allocator_t* const mm_allocator); + +#endif /* EDIT_CIGAR_H_ */ diff --git a/benchmarks/wfa/tools/Makefile b/benchmarks/wfa/tools/Makefile new file mode 100644 index 0000000..6b87f61 --- /dev/null +++ b/benchmarks/wfa/tools/Makefile @@ -0,0 +1,40 @@ +############################################################################### +# Definitions +############################################################################### +FOLDER_ROOT:=.. +FOLDER_BUILD_PATH:=$(FOLDER_ROOT)/$(FOLDER_BUILD) +FOLDER_BIN_PATH:=$(FOLDER_ROOT)/$(FOLDER_BIN) + +############################################################################### +# Tools +############################################################################### +TOOLS=generate_dataset align_benchmark +TOOLS_SRC=$(addsuffix .c, $(TOOLS)) + +############################################################################### +# Profiling +############################################################################### + +# Intel VTune Profiler +VTUNE_ANALYSIS=0 + +ifneq ($(VTUNE_HOME),) +VTUNE_ANALYSIS=1 +INCLUDES+= -I${VTUNE_HOME}/include +LIBS+=-L${VTUNE_HOME}/lib64 -littnotify +endif + +############################################################################### +# Rules +############################################################################### +OBJS=$(FOLDER_BUILD_PATH)/*.o + +all: LIBS+=$(LD_FLAGS) +all: FLAGS=$(CC_FLAGS) +all: $(TOOLS) + +align_benchmark: $(FOLDER_BUILD_PATH)/*.o align_benchmark.c + $(CC) $(FLAGS) -I$(FOLDER_ROOT) $(INCLUDES) -DVTUNE_ANALYSIS=$(VTUNE_ANALYSIS) align_benchmark.c $(OBJS) -o $(FOLDER_BIN_PATH)/align_benchmark $(LIBS) + +generate_dataset: generate_dataset.c + $(CC) $(CC_FLAGS) generate_dataset.c -o $(FOLDER_BIN_PATH)/generate_dataset $(LD_FLAGS) diff --git a/benchmarks/wfa/tools/align_benchmark.c b/benchmarks/wfa/tools/align_benchmark.c new file mode 100644 index 0000000..107feb0 --- /dev/null +++ b/benchmarks/wfa/tools/align_benchmark.c @@ -0,0 +1,445 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Wavefront Alignment benchmarking tool + */ + +#if VTUNE_ANALYSIS + #include +#endif + +#include "utils/commons.h" +#include "gap_affine/affine_wavefront.h" +#include "gap_affine/affine_wavefront_align.h" + +#include "omp.h" + +/* + * Parameters + */ +#define MAX_SEQUENCE_LENGTH 100000 + +/* + * Generic parameters + */ +typedef struct { + // I/O + char *input; + char *output; + // Penalties + affine_penalties_t affine_penalties; + int min_wavefront_length; + int max_distance_threshold; + // Misc + int nthreads; + int progress; + bool verbose; +} benchmark_args; +benchmark_args parameters = { + // Input + .input=NULL, + .output=NULL, + // Penalties + .affine_penalties = { + .match = 0, + .mismatch = 4, + .gap_opening = 6, + .gap_extension = 2, + }, + .min_wavefront_length = -1, // 10, + .max_distance_threshold = -1, //50, + // Misc + .nthreads = 1, + .progress = 10000, + .verbose = false +}; + +/* + * Parsing Input + */ + +typedef struct { + int id; + char* pattern; + int pattern_length; + char* text; + int text_length; +} input_pair_sequences_t; +input_pair_sequences_t** parse_input_sequences( + FILE* const input_file, + int** const total_sequences) { + + *total_sequences = malloc(parameters.nthreads * sizeof((*total_sequences)[0])); + input_pair_sequences_t **input_buffers = malloc(parameters.nthreads * sizeof(input_pair_sequences_t*)); + + char more_seqs = 1; + int id = 0; + #pragma omp parallel num_threads(parameters.nthreads) + { + // Init allocation + int thread_id = omp_get_thread_num(); + int total_allocated = 10000; + input_pair_sequences_t* input_buffer = malloc(total_allocated * sizeof(input_pair_sequences_t)); + + // Input reading loop + int total_parsed = 0; + while(more_seqs) { + // Avoid that some threads exit the while loop if another thread enters the + // critical section before the threads check the while condition. + #pragma omp barrier + #pragma omp for schedule(static, 1) + for (size_t i = 0; i < omp_get_num_threads(); ++i) { + #pragma omp critical + { + char *foo_line = NULL; + char *line1 = NULL; + char *line2 = NULL; + + size_t foo_line_allocated = 0; + size_t line1_allocated = 0; + size_t line2_allocated = 0; + + // Read BSW score (unused). + getline(&foo_line, &foo_line_allocated, input_file); + free(foo_line); + + // Read queries + int line1_length = getline(&line1, &line1_allocated, input_file); + int line2_length = getline(&line2, &line2_allocated, input_file); + + if (line1_length == -1 || line2_length == -1) { + free(line1); + free(line2); + + // OMP: Implicit flush of more_seqs at the end of the for loop. + more_seqs = 0; + } + else { + // Process input + char* const pattern = line1; + const int pattern_length = line1_length - 1; + pattern[pattern_length] = '\0'; + + char* const text = line2; + const int text_length = line2_length - 1; + text[text_length] = '\0'; + + // Allocate + if (total_parsed + 1 >= total_allocated) { + total_allocated += 10000; + input_buffer = realloc(input_buffer, total_allocated * sizeof(input_pair_sequences_t)); + } + + // Copy Pattern + input_buffer[total_parsed].id = id; + ++id; // OMP: Implicit flush of id on entry and exit of the critical section. + input_buffer[total_parsed].pattern = malloc(pattern_length + 1); + strncpy(input_buffer[total_parsed].pattern, pattern, pattern_length + 1); + input_buffer[total_parsed].pattern_length = pattern_length; + + // Copy Text + input_buffer[total_parsed].text = malloc(text_length + 1); + strncpy(input_buffer[total_parsed].text, text, text_length + 1); + input_buffer[total_parsed].text_length = text_length; + + free(line1); + free(line2); + + ++total_parsed; + } + } + } + } + input_buffers[thread_id] = input_buffer; + (*total_sequences)[thread_id] = total_parsed; + } + + return input_buffers; +} +/* + * Generic Menu + */ +void usage() { + fprintf(stderr, + "USE: ./align_benchmark -i [-o ] \n" + " Options:: \n" + " [I/O] \n" + " --input|i \n" + " --output|o \n" + " [Penalties] \n" + " --affine-penalties|g M,X,O,E \n" + " [Specifics] \n" + " --minimum-wavefront-length \n" + " --maximum-difference-distance \n" + " [Misc] \n" + " --nthreads|t [default=1] \n" + " --progress|P \n" + " --verbose|v \n" + " --help|h \n"); +} +void parse_arguments(int argc,char** argv) { + struct option long_options[] = { + /* I/O */ + { "input", required_argument, 0, 'i' }, + { "output", required_argument, 0, 'o' }, + /* Penalties */ + { "affine-penalties", required_argument, 0, 'p' }, + /* Specifics */ + { "minimum-wavefront-length", required_argument, 0, 1000 }, + { "maximum-difference-distance", required_argument, 0, 1001 }, + /* Misc */ + { "nthreads", required_argument, 0, 't' }, + { "progress", required_argument, 0, 'P' }, + { "verbose", no_argument, 0, 'v' }, + { "help", no_argument, 0, 'h' }, + { 0, 0, 0, 0 } }; + int c,option_index; + if (argc <= 1) { + usage(); + exit(0); + } + while (1) { + c=getopt_long(argc,argv,"i:o:p:t:P:vh",long_options,&option_index); + if (c==-1) break; + switch (c) { + /* + * I/O + */ + case 'i': + parameters.input = optarg; + break; + case 'o': + parameters.output = optarg; + break; + /* + * Penalties + */ + case 'p': { // --affine-penalties + char* sentinel = strtok(optarg,","); + parameters.affine_penalties.match = atoi(sentinel); + sentinel = strtok(NULL,","); + parameters.affine_penalties.mismatch = atoi(sentinel); + sentinel = strtok(NULL,","); + parameters.affine_penalties.gap_opening = atoi(sentinel); + sentinel = strtok(NULL,","); + parameters.affine_penalties.gap_extension = atoi(sentinel); + break; + } + /* + * Specific parameters + */ + case 1000: // --minimum-wavefront-length + parameters.min_wavefront_length = atoi(optarg); + break; + case 1001: // --maximum-difference-distance + parameters.max_distance_threshold = atoi(optarg); + break; + /* + * Misc + */ + case 't': + parameters.nthreads = atoi(optarg); + break; + case 'P': + parameters.progress = atoi(optarg); + break; + case 'v': + parameters.verbose = true; + break; + case 'h': + usage(); + exit(1); + // Other + case '?': default: + fprintf(stderr,"Option not recognized \n"); + exit(1); + } + } + // Checks + if (parameters.input == NULL) { + fprintf(stderr,"Option --input is required \n"); + exit(1); + } +} +int main(int argc,char* argv[]) { +#if VTUNE_ANALYSIS + __itt_pause(); +#endif + // Parsing command-line options + parse_arguments(argc,argv); + // Parameters + FILE *input_file = NULL; + FILE *output_file = NULL; + // Init I/O files + input_file = fopen(parameters.input, "r"); + if (input_file == NULL) { + fprintf(stderr,"Input file '%s' couldn't be opened\n",parameters.input); + exit(1); + } + if (parameters.output != NULL) { + output_file = fopen(parameters.output, "w"); + if (output_file == NULL) { + fprintf(stderr,"Output file '%s' couldn't be opened\n",parameters.output); + exit(1); + } + } + + struct timeval benchmark_start; + gettimeofday(&benchmark_start, NULL); + + // Parse input file + int *total_sequences; + input_pair_sequences_t** const input_buffers = + parse_input_sequences(input_file, &total_sequences); + + struct timeval alignment_start; + struct timeval alignment_end; + + int progress_mod = 0; + #pragma omp parallel num_threads(parameters.nthreads) + { + // Init MM-allocator + mm_allocator_t* const mm_allocator = mm_allocator_new(BUFFER_SIZE_8M); + + // Init wavefront + affine_wavefronts_t* affine_wavefronts; + if (parameters.min_wavefront_length < 0) { + affine_wavefronts = affine_wavefronts_new_complete( + MAX_SEQUENCE_LENGTH,MAX_SEQUENCE_LENGTH, + ¶meters.affine_penalties,mm_allocator); + } else { + affine_wavefronts = affine_wavefronts_new_reduced( + MAX_SEQUENCE_LENGTH,MAX_SEQUENCE_LENGTH, + ¶meters.affine_penalties,parameters.min_wavefront_length, + parameters.max_distance_threshold,mm_allocator); + } + + #pragma omp barrier + #pragma omp master + { +#if VTUNE_ANALYSIS + __itt_resume(); +#endif + gettimeofday(&alignment_start, NULL); + } + + // Pointer to thread private data. + int thread_id = omp_get_thread_num(); + input_pair_sequences_t *input_buffer = input_buffers[thread_id]; + int thread_total_sequences = total_sequences[thread_id]; + + edit_cigar_t *edit_cigars = malloc(thread_total_sequences * sizeof(edit_cigars[0])); + + // Read-align loop + for (int i = 0; i < thread_total_sequences; ++i) { + // Align + affine_wavefronts_clear(affine_wavefronts); + affine_wavefronts_align(affine_wavefronts, + input_buffer[i].pattern, input_buffer[i].pattern_length, + input_buffer[i].text, input_buffer[i].text_length); + + // Store output + // Alloc new operations array. + int noperations = + affine_wavefronts->edit_cigar.end_offset - + affine_wavefronts->edit_cigar.begin_offset; + edit_cigars[i].operations = mm_allocator_malloc(mm_allocator, noperations); + // Copy operations. + memcpy(edit_cigars[i].operations, + &affine_wavefronts->edit_cigar.operations[affine_wavefronts->edit_cigar.begin_offset], + noperations * sizeof(edit_cigars->operations[0])); + + // Copy the rest of the fields. + edit_cigars[i].max_operations = affine_wavefronts->edit_cigar.max_operations; + edit_cigars[i].score = affine_wavefronts->edit_cigar.score; + edit_cigars[i].begin_offset = 0; + edit_cigars[i].end_offset = noperations; + + // Update progress + if (parameters.verbose) { + #pragma omp critical + { + ++progress_mod; + if (progress_mod % parameters.progress == 0) { + fprintf(stderr,"...processed %d reads \n", progress_mod); + } + } + } + } + + #pragma omp barrier + #pragma omp master + { + gettimeofday(&alignment_end, NULL); +#if VTUNE_ANALYSIS + __itt_pause(); +#endif + } + + // Print the output. + #pragma omp critical + { + if (!parameters.verbose) { + progress_mod += thread_total_sequences; + } + if (output_file != NULL) { + for (int i = 0; i < thread_total_sequences; ++i) { + fprintf(output_file, "id=%d ", input_buffer[i].id); + edit_cigar_print(output_file, &edit_cigars[i]); + fprintf(output_file, "\n"); + } + } + } + + // Free + for (int i = 0; i < thread_total_sequences; ++i) { + mm_allocator_free(mm_allocator, edit_cigars[i].operations); + free(input_buffer[i].pattern); + free(input_buffer[i].text); + } + free(edit_cigars); + + free(input_buffer); + + affine_wavefronts_delete(affine_wavefronts); + mm_allocator_delete(mm_allocator); + } + free(total_sequences); + free(input_buffers); + fclose(input_file); + if (output_file != NULL) fclose(output_file); + + struct timeval benchmark_end; + gettimeofday(&benchmark_end, NULL); + + printf("Total.reads: %d\n", progress_mod); + printf("Time.Benchmark: %f s\n", (benchmark_end.tv_sec - benchmark_start.tv_sec) + + (benchmark_end.tv_usec - benchmark_start.tv_usec) * 1E-6); + printf("Time.Alignment: %f s\n", (alignment_end.tv_sec - alignment_start.tv_sec) + + (alignment_end.tv_usec - alignment_start.tv_usec) * 1E-6); +} diff --git a/benchmarks/wfa/tools/generate_dataset.c b/benchmarks/wfa/tools/generate_dataset.c new file mode 100644 index 0000000..3688fc6 --- /dev/null +++ b/benchmarks/wfa/tools/generate_dataset.c @@ -0,0 +1,245 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Sequence Generator for benchmarking pairwise algorithms + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +/* + * DNA Alphabet + */ +//#define ALPHABET_SIZE 26 +//char alphabet[] = { +// 'A','B','C','D','E', +// 'F','G','H','I','J', +// 'K','L','M','N','O', +// 'P','Q','R','S','T', +// 'U','V','W','X','Y', +// 'Z' +//}; +#define ALPHABET_SIZE 4 +char alphabet[] = { + 'A','C','G','T' +}; + +/* + * Random number generator + */ +uint64_t rand_iid(const uint64_t min,const uint64_t max) { + int n_rand = rand(); // [0, RAND_MAX] + const uint64_t range = max - min; + const uint64_t rem = RAND_MAX % range; + const uint64_t sample = RAND_MAX / range; + // Consider the small interval within remainder of RAND_MAX + if (n_rand < RAND_MAX - rem) { + return min + n_rand/sample; + } else { + return rand_iid(min,max); + } +} +/* + * Generate pattern + */ +void generate_pattern( + char* const pattern, + const uint64_t length) { + // Generate random characters + uint64_t i; + for (i=0;iposition;--i) { + candidate_text[i] = candidate_text[i-1]; + } + *candidate_length = new_candidate_length; + // Insert random character + candidate_text[position] = alphabet[rand_iid(0,ALPHABET_SIZE)]; +} +char* generate_candidate_text( + char* const pattern, + const uint64_t pattern_length, + const float error_degree) { + // Compute nominal number of errors + const uint64_t num_errors = ceil(pattern_length * error_degree); + // Allocate & init-by-copy candidate text + char* const candidate_text = malloc(pattern_length+num_errors); + uint64_t candidate_length = pattern_length; + memcpy(candidate_text,pattern,pattern_length); + // Generate random errors + int i; + for (i=0;i\n" + " --num-patterns|n \n" + " --length|l \n" + " --error|e \n" + " --help|h\n"); +} +void parse_arguments(int argc,char** argv) { + struct option long_options[] = { + { "num-patterns", required_argument, 0, 'n' }, + { "output", required_argument, 0, 'o' }, + { "length", required_argument, 0, 'l' }, + { "error", required_argument, 0, 'e' }, + { "help", no_argument, 0, 'h' }, + { 0, 0, 0, 0 } }; + int c,option_index; + if (argc <= 1) { + usage(); + exit(0); + } + while (1) { + c=getopt_long(argc,argv,"n:o:l:e:h",long_options,&option_index); + if (c==-1) break; + switch (c) { + case 'n': + parameters.num_reads = atoi(optarg); + break; + case 'o': + parameters.output = optarg; + break; + case 'l': + parameters.length = atoi(optarg); + break; + case 'e': + parameters.error_degree = atof(optarg); + break; + case 'h': + usage(); + exit(1); + case '?': default: + fprintf(stderr, "Option not recognized \n"); exit(1); + } + } +} +int main(int argc,char* argv[]) { + // Parsing command-line options + parse_arguments(argc,argv); + // Parameters + FILE *output_file; + // Open files + output_file = fopen(parameters.output,"w"); + // Allocate + char* const pattern = malloc(parameters.length+1); + const int pattern_length = parameters.length; + // Read-align loop + srand(time(0)); + int i; + for (i=0;i%s\n",pattern); + // Generate candidate-text + char* const candidate_text = generate_candidate_text( + pattern,pattern_length,parameters.error_degree); + // Print candidate-text + fprintf(output_file,"<%s\n",candidate_text); + free(candidate_text); + } + // Close files & free + fclose(output_file); + free(pattern); +} diff --git a/benchmarks/wfa/utils/Makefile b/benchmarks/wfa/utils/Makefile new file mode 100644 index 0000000..dd72a8c --- /dev/null +++ b/benchmarks/wfa/utils/Makefile @@ -0,0 +1,26 @@ +############################################################################### +# Definitions +############################################################################### +FOLDER_ROOT:=.. +FOLDER_BUILD_PATH:=$(FOLDER_ROOT)/$(FOLDER_BUILD) + +############################################################################### +# Modules +############################################################################### +MODULES=commons \ + mm_allocator \ + mm_stack \ + string_padded \ + vector + +SRCS=$(addsuffix .c, $(MODULES)) +OBJS=$(addprefix $(FOLDER_BUILD_PATH)/, $(SRCS:.c=.o)) + +############################################################################### +# Rules +############################################################################### +all: $(OBJS) + +# General building rule +$(FOLDER_BUILD_PATH)/%.o : %.c + $(CC) $(CC_FLAGS) -I$(FOLDER_ROOT) -c $< -o $@ diff --git a/benchmarks/wfa/utils/commons.c b/benchmarks/wfa/utils/commons.c new file mode 100644 index 0000000..832815d --- /dev/null +++ b/benchmarks/wfa/utils/commons.c @@ -0,0 +1,48 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Common functions/utilities and headers for C development + */ + +#include "commons.h" + +/* + * Random number generator + */ +uint64_t rand_iid(const uint64_t min,const uint64_t max) { + int n_rand = rand(); // [0, RAND_MAX] + const uint64_t range = max - min; + const uint64_t rem = RAND_MAX % range; + const uint64_t sample = RAND_MAX / range; + // Consider the small interval within remainder of RAND_MAX + if (n_rand < RAND_MAX - rem) { + return min + n_rand/sample; + } else { + return rand_iid(min,max); + } +} diff --git a/benchmarks/wfa/utils/commons.h b/benchmarks/wfa/utils/commons.h new file mode 100644 index 0000000..11293eb --- /dev/null +++ b/benchmarks/wfa/utils/commons.h @@ -0,0 +1,238 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Common functions/utilities and headers for C development + */ + +#ifndef COMMONS_H_ +#define COMMONS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +/* + * Macro Utils (Stringify) + */ +#define QUOTE(value) #value +#define SWAP(a,b) do {__typeof__(a) aux = a; a = b; b = aux;} while (0) + +/* + * Special Characters + */ +#define EOS '\0' +#define EOL '\n' +#define TAB '\t' +#define DOS_EOL '\r' +#define PLUS '+' +#define MINUS '-' +#define FORMAT '%' +#define SPACE ' ' +#define SLASH '/' +#define STAR '*' +#define DOT '.' +//#define EQUAL '=' +#define COMA ',' +#define SEMICOLON ';' +#define COLON ':' +#define HASH '#' +#define UNDERSCORE '_' + +/* + * Metric Factors + */ +#define METRIC_FACTOR_1K (1000ul) +#define METRIC_FACTOR_1M (1000000ul) +#define METRIC_FACTOR_1G (1000000000ul) + +/* + * Number of lines + */ +#define NUM_LINES_1K (1000ul) +#define NUM_LINES_2K (2000ul) +#define NUM_LINES_5K (5000ul) +#define NUM_LINES_10K (10000ul) +#define NUM_LINES_20K (20000ul) +#define NUM_LINES_50K (50000ul) +#define NUM_LINES_100K (100000ul) +#define NUM_LINES_200K (200000ul) +#define NUM_LINES_500K (500000ul) +#define NUM_LINES_1M (1000000ul) +#define NUM_LINES_2M (2000000ul) +#define NUM_LINES_5M (5000000ul) +#define NUM_LINES_10M (10000000ul) +#define NUM_LINES_20M (20000000ul) +#define NUM_LINES_50M (50000000ul) + +/* + * Buffer sizes + */ +#define BUFFER_SIZE_1K (1ul<<10) +#define BUFFER_SIZE_2K (1ul<<11) +#define BUFFER_SIZE_4K (1ul<<12) +#define BUFFER_SIZE_8K (1ul<<13) +#define BUFFER_SIZE_16K (1ul<<14) +#define BUFFER_SIZE_32K (1ul<<15) +#define BUFFER_SIZE_64K (1ul<<16) +#define BUFFER_SIZE_128K (1ul<<17) +#define BUFFER_SIZE_256K (1ul<<18) +#define BUFFER_SIZE_512K (1ul<<19) +#define BUFFER_SIZE_1M (1ul<<20) +#define BUFFER_SIZE_2M (1ul<<21) +#define BUFFER_SIZE_4M (1ul<<22) +#define BUFFER_SIZE_8M (1ul<<23) +#define BUFFER_SIZE_16M (1ul<<24) +#define BUFFER_SIZE_32M (1ul<<25) +#define BUFFER_SIZE_64M (1ul<<26) +#define BUFFER_SIZE_128M (1ul<<27) +#define BUFFER_SIZE_256M (1ul<<28) +#define BUFFER_SIZE_512M (1ul<<29) +#define BUFFER_SIZE_1G (1ul<<30) +#define BUFFER_SIZE_2G (1ul<<31) +#define BUFFER_SIZE_4G (1ul<<32) +#define BUFFER_SIZE_8G (1ul<<33) +#define BUFFER_SIZE_16G (1ul<<34) +#define BUFFER_SIZE_32G (1ul<<35) +#define BUFFER_SIZE_64G (1ul<<36) +#define BUFFER_SIZE_128G (1ul<<37) +#define BUFFER_SIZE_256G (1ul<<38) +// Conversion utils +#define CONVERT_B_TO_KB(number) ((number)/(1024)) +#define CONVERT_B_TO_MB(number) ((number)/(1024*1024)) +#define CONVERT_B_TO_GB(number) ((number)/(1024*1024*1024)) + +/* + * BM sizes + */ +#define UINT512_LENGTH 512 +#define UINT512_SIZE 64 +#define UINT256_LENGTH 256 +#define UINT256_SIZE 32 +#define UINT128_LENGTH 128 +#define UINT128_SIZE 16 +#define UINT64_LENGTH 64 +#define UINT64_SIZE 8 +#define UINT32_LENGTH 32 +#define UINT32_SIZE 4 +#define UINT16_LENGTH 16 +#define UINT16_SIZE 2 +#define UINT8_LENGTH 8 +#define UINT8_SIZE 1 + +/* + * Common Masks + */ +#define UINT64_ZEROS 0x0000000000000000ull +#define UINT64_ONES 0xFFFFFFFFFFFFFFFFull +#define UINT32_ZEROS 0x00000000ul +#define UINT32_ONES 0xFFFFFFFFul +// Extraction masks +#define UINT64_ONE_MASK 0x0000000000000001ull +#define UINT64_ZERO_MASK 0xFFFFFFFFFFFFFFFEull +#define UINT64_ONE_LAST_MASK 0x8000000000000000ull +#define UINT64_ZERO_LAST_MASK 0x7FFFFFFFFFFFFFFFull +#define UINT32_ONE_MASK 0x00000001ul +#define UINT32_ZERO_MASK 0xFFFFFFFEul +#define UINT32_ONE_LAST_MASK 0x80000000ul +#define UINT32_ZERO_LAST_MASK 0x7FFFFFFFul + +/* + * Common numerical data processing/formating + */ +#define MIN(a,b) (((a)<=(b))?(a):(b)) +#define MAX(a,b) (((a)>=(b))?(a):(b)) +#define ABS(a) (((a)>=0)?(a):-(a)) + +/* + * Pseudo-Random number generator + */ +#define rand_init() srand(time(0)) +#define rand_i(min,max) ( min + ( rand()%(max-min+1) ) ) +#define rand_f(min,max) ( min + ((double)rand()/(double)(RAND_MAX+1)) * (max-min+1) ) +uint64_t rand_iid(const uint64_t min,const uint64_t max); + +/* + * Parsing + */ +#define IS_NUMBER(character) ('0' <= (character) && (character) <= '9') +#define IS_DIGIT(character) IS_NUMBER(character) +#define IS_LETTER(character) (('a' <= (character) && (character) <= 'z') || ('A' <= (character) && (character) <= 'Z')) +#define IS_ALPHANUMERIC(character) (IS_NUMBER(character) || IS_LETTER(character)) +#define IS_BETWEEN(number,a,b) ((a)<=(number) && (number)<=(b)) + +#define IS_EOL(character) ((character)==EOL) +#define IS_ANY_EOL(character) ((character)==EOL || (character)==DOS_EOL) +#define IS_HEX_DIGIT(character) (IS_NUMBER(character) || ('a' <= (character) && (character) <= 'f') || ('A' <= (character) && (character) <= 'F')) + +#define IS_END_OF_RECORD(character) ( (character)==EOL || (character)==EOS ) +#define IS_END_OF_FIELD(character) ( IS_END_OF_RECORD(character) || (character)==SPACE || (character)==TAB ) + +#define GET_DIGIT(character) ((character) - '0') +#define GET_HEX_DIGIT(character) (IS_NUMBER(character) ? GET_DIGIT(character) : (toupper(character) - 'A' + 10)) + +/* + * Math + */ +#define BOUNDED_SUBTRACTION(minuend,subtrahend,limit) (((minuend)>((limit)+(subtrahend))) ? (minuend)-(subtrahend):(limit)) +#define BOUNDED_ADDITION(summand_A,summand_B,limit) ((((summand_A)+(summand_B))<(limit)) ? (summand_A)+(summand_B):(limit)) + +#define PERCENTAGE(AMOUNT,TOTAL) ((TOTAL)?100.0*(float)(AMOUNT)/(float)(TOTAL):0.0) +#define DIV_FLOOR(NUMERATOR,DENOMINATOR) ((NUMERATOR)/(DENOMINATOR)) +#define DIV_CEIL(NUMERATOR,DENOMINATOR) (((NUMERATOR)+((DENOMINATOR)-1))/(DENOMINATOR)) +#define DIVC_FLOOR(NUMERATOR,DENOMINATOR) ((DENOMINATOR) ? DIV_FLOOR(NUMERATOR,DENOMINATOR) :(0)) +#define DIVC_CEIL(NUMERATOR,DENOMINATOR) ((DENOMINATOR) ? DIV_CEIL(NUMERATOR,DENOMINATOR) :(0)) + +#endif /* COMMONS_H_ */ diff --git a/benchmarks/wfa/utils/mm_allocator.c b/benchmarks/wfa/utils/mm_allocator.c new file mode 100644 index 0000000..c765b5d --- /dev/null +++ b/benchmarks/wfa/utils/mm_allocator.c @@ -0,0 +1,571 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Simple managed-memory allocator that reduces the overhead + * of using malloc/calloc/free functions by allocating slabs of memory + * and dispatching memory segments in order. + */ + +#include "mm_allocator.h" + +/* + * Debug + */ +//#define MM_ALLOCATOR_FORCE_MALLOC + +/* + * Constants + */ +#define MM_ALLOCATOR_SEGMENT_INITIAL_REQUESTS 10000 +#define MM_ALLOCATOR_INITIAL_SEGMENTS 10 +#define MM_ALLOCATOR_INITIAL_MALLOC_REQUESTS 10 +#define MM_ALLOCATOR_INITIAL_STATES 10 + +/* + * Allocator Segments Freed Cond + */ +#define MM_ALLOCATOR_FREED_FLAG 0x80000000ul +#define MM_ALLOCATOR_REQUEST_IS_FREE(request) ((request)->size & MM_ALLOCATOR_FREED_FLAG) +#define MM_ALLOCATOR_REQUEST_SET_FREE(request) ((request)->size |= MM_ALLOCATOR_FREED_FLAG) +#define MM_ALLOCATOR_REQUEST_SIZE(request) ((request)->size & ~(MM_ALLOCATOR_FREED_FLAG)) + +/* + * Memory Request + */ +typedef struct { + // Request + uint32_t offset; + uint32_t size; + // Log +#ifdef MM_ALLOCATOR_LOG + uint64_t timestamp; + char* func_name; + uint64_t line_no; +#endif +} mm_allocator_request_t; +typedef struct { + // Request + void* mem; + uint64_t size; + // Log +#ifdef MM_ALLOCATOR_LOG + uint64_t timestamp; + char* func_name; + uint64_t line_no; +#endif +} mm_malloc_request_t; +/* + * Memory Segments + */ +typedef struct { + // Index (ID) + uint64_t idx; // Index in the segments vector + // Memory + uint64_t size; // Total memory available + void* memory; // Memory + uint64_t used; // Bytes used (offset to memory next free byte) + // Requests + vector_t* requests; // Memory requests (mm_allocator_request_t) +} mm_allocator_segment_t; +/* + * Reference (Header of every memory allocated) + */ +typedef struct { + uint32_t segment_idx; + uint32_t request_idx; +} mm_allocator_reference_t; + +/* + * Segments + */ +mm_allocator_segment_t* mm_allocator_segment_new( + mm_allocator_t* const mm_allocator) { + // Allocate handler + mm_allocator_segment_t* const segment = malloc(sizeof(mm_allocator_segment_t)); + // Index + const uint64_t segment_idx = vector_get_used(mm_allocator->segments); + segment->idx = segment_idx; + // Memory + segment->size = mm_allocator->segment_size; + segment->memory = malloc(mm_allocator->segment_size); + segment->used = 0; + // Requests + segment->requests = vector_new(MM_ALLOCATOR_SEGMENT_INITIAL_REQUESTS,mm_allocator_request_t); + // Add to segments + vector_insert(mm_allocator->segments,segment,mm_allocator_segment_t*); + // Return + return segment; +} +void mm_allocator_segment_clear( + mm_allocator_segment_t* const segment) { + segment->used = 0; + vector_clear(segment->requests); +} +void mm_allocator_segment_delete( + mm_allocator_segment_t* const segment) { + vector_delete(segment->requests); + free(segment->memory); + free(segment); +} +mm_allocator_request_t* mm_allocator_segment_get_request( + mm_allocator_segment_t* const segment, + const uint64_t request_idx) { + return vector_get_elm(segment->requests,request_idx,mm_allocator_request_t); +} +uint64_t mm_allocator_segment_get_num_requests( + mm_allocator_segment_t* const segment) { + return vector_get_used(segment->requests); +} +/* + * Setup + */ +mm_allocator_t* mm_allocator_new( + const uint64_t segment_size) { + // Allocate handler + mm_allocator_t* const mm_allocator = malloc(sizeof(mm_allocator_t)); + mm_allocator->request_ticker = 0; + // Segments + mm_allocator->segment_size = segment_size; + mm_allocator->segments = vector_new(MM_ALLOCATOR_INITIAL_SEGMENTS,mm_allocator_segment_t*); + mm_allocator->segments_free = vector_new(MM_ALLOCATOR_INITIAL_SEGMENTS,mm_allocator_segment_t*); + // Allocate an initial segment + mm_allocator_segment_new(mm_allocator); + mm_allocator->current_segment_idx = 0; + // Malloc Memory + mm_allocator->malloc_requests = vector_new(MM_ALLOCATOR_INITIAL_MALLOC_REQUESTS,mm_malloc_request_t); + mm_allocator->malloc_requests_freed = 0; + // Return + return mm_allocator; +} +void mm_allocator_clear( + mm_allocator_t* const mm_allocator) { + // Clear segments + vector_clear(mm_allocator->segments_free); + VECTOR_ITERATE(mm_allocator->segments,segment_ptr,p,mm_allocator_segment_t*) { + mm_allocator_segment_clear(*segment_ptr); // Clear segment + vector_insert(mm_allocator->segments_free,*segment_ptr,mm_allocator_segment_t*); // Add to free segments + } + mm_allocator->current_segment_idx = 0; + // Clear malloc memory + VECTOR_ITERATE(mm_allocator->malloc_requests,malloc_request,m,mm_malloc_request_t) { + if (malloc_request->size > 0) free(malloc_request->mem); // Free malloc requests + } + vector_clear(mm_allocator->malloc_requests); + mm_allocator->malloc_requests_freed = 0; +} +void mm_allocator_delete( + mm_allocator_t* const mm_allocator) { + // Free segments + VECTOR_ITERATE(mm_allocator->segments,segment_ptr,p,mm_allocator_segment_t*) { + mm_allocator_segment_delete(*segment_ptr); + } + vector_delete(mm_allocator->segments); + vector_delete(mm_allocator->segments_free); + // Free malloc memory + VECTOR_ITERATE(mm_allocator->malloc_requests,malloc_request,m,mm_malloc_request_t) { + if (malloc_request->size > 0) free(malloc_request->mem); // Free malloc requests + } + vector_delete(mm_allocator->malloc_requests); + // Free handler + free(mm_allocator); +} +/* + * Accessors + */ +mm_allocator_segment_t* mm_allocator_get_segment( + mm_allocator_t* const mm_allocator, + const uint64_t segment_idx) { + return *(vector_get_elm(mm_allocator->segments,segment_idx,mm_allocator_segment_t*)); +} +uint64_t mm_allocator_get_num_segments( + mm_allocator_t* const mm_allocator) { + return vector_get_used(mm_allocator->segments); +} +/* + * Allocate + */ +mm_allocator_segment_t* mm_allocator_fetch_segment( + mm_allocator_t* const mm_allocator, + const uint64_t num_bytes) { + // Fetch current segment + mm_allocator_segment_t* const curr_segment = + mm_allocator_get_segment(mm_allocator,mm_allocator->current_segment_idx); + // Check available segment size + if (curr_segment->used + num_bytes <= curr_segment->size) { + return curr_segment; + } + // Check overall segment size + if (num_bytes > curr_segment->size) { + return NULL; // Memory request over segment size + } + // Get free segment + const uint64_t free_segments = vector_get_used(mm_allocator->segments_free); + if (free_segments > 0) { + mm_allocator_segment_t* const segment = + mm_allocator_get_segment(mm_allocator,free_segments-1); + vector_dec_used(mm_allocator->segments_free); + mm_allocator->current_segment_idx = segment->idx; + return segment; + } + // Allocate new segment + mm_allocator_segment_t* const segment = mm_allocator_segment_new(mm_allocator); + mm_allocator->current_segment_idx = segment->idx; + return segment; +} +void* mm_allocator_allocate( + mm_allocator_t* const mm_allocator, + const uint64_t num_bytes, + const bool zero_mem, + const uint64_t align_bytes +#ifdef MM_ALLOCATOR_LOG + ,const char* func_name, + uint64_t line_no +#endif + ) { + // Zero check + if (num_bytes == 0) { + fprintf(stderr,"MMAllocator error. Zero bytes requested\n"); + exit(1); + } + // Add payload + const uint64_t num_bytes_allocated = num_bytes + sizeof(mm_allocator_reference_t) + align_bytes; + // Fetch segment +#ifdef MM_ALLOCATOR_FORCE_MALLOC + mm_allocator_segment_t* const segment = NULL; // Force malloc memory +#else + mm_allocator_segment_t* const segment = mm_allocator_fetch_segment(mm_allocator,num_bytes_allocated); +#endif + if (segment != NULL) { + // Allocate memory + void* const memory_base = segment->memory + segment->used; + if (zero_mem) memset(memory_base,0,num_bytes_allocated); // Set zero + // Compute aligned memory + void* memory_aligned = memory_base + sizeof(mm_allocator_reference_t) + align_bytes; + if (align_bytes > 0) { + memory_aligned = memory_aligned - ((uintptr_t)memory_aligned % align_bytes); + } + // Set mm_reference + mm_allocator_reference_t* const mm_reference = memory_aligned - sizeof(mm_allocator_reference_t); + mm_reference->segment_idx = segment->idx; + mm_reference->request_idx = mm_allocator_segment_get_num_requests(segment); + // Add request + mm_allocator_request_t* request; + vector_alloc_new(segment->requests,mm_allocator_request_t,request); + request->offset = segment->used; + request->size = num_bytes_allocated; +#ifdef MM_ALLOCATOR_LOG + request->timestamp = (mm_allocator->request_ticker)++; + request->func_name = (char*)func_name; + request->line_no = line_no; +#endif + // Update segment + segment->used += num_bytes_allocated; + // Return memory + return memory_aligned; + } else { + // Malloc memory + void* const memory_base = malloc(num_bytes_allocated); + if (zero_mem) memset(memory_base,0,num_bytes_allocated); // Set zero + // Compute aligned memory + void* memory_aligned = memory_base + sizeof(mm_allocator_reference_t) + align_bytes; + if (align_bytes > 0) { + memory_aligned = memory_aligned - ((uintptr_t)memory_aligned % align_bytes); + } + // Set reference + mm_allocator_reference_t* const mm_reference = memory_aligned - sizeof(mm_allocator_reference_t); + mm_reference->segment_idx = UINT32_MAX; + mm_reference->request_idx = vector_get_used(mm_allocator->malloc_requests); + // Add malloc-request + mm_malloc_request_t* request; + vector_alloc_new(mm_allocator->malloc_requests,mm_malloc_request_t,request); + request->mem = memory_base; + request->size = num_bytes_allocated; +#ifdef MM_ALLOCATOR_LOG + request->timestamp = (mm_allocator->request_ticker)++; + request->func_name = (char*)func_name; + request->line_no = line_no; +#endif + // Return memory + return memory_aligned; + } +} +/* + * Allocator Free + */ +void mm_allocator_free_malloc_request( + mm_allocator_t* const mm_allocator, + mm_allocator_reference_t* const mm_reference) { + // Fetch request + mm_malloc_request_t* const request = + vector_get_elm(mm_allocator->malloc_requests,mm_reference->request_idx,mm_malloc_request_t); + // Check double-free + if (request->size == 0) { + fprintf(stderr,"MMAllocator error: double free\n"); + exit(1); + } + // Free request + request->size = 0; + free(request->mem); + ++(mm_allocator->malloc_requests_freed); + // Check number of freed requests + if (mm_allocator->malloc_requests_freed >= 1000) { + // Remove freed requests + const uint64_t num_requests = vector_get_used(mm_allocator->malloc_requests); + mm_malloc_request_t* const requests = vector_get_mem(mm_allocator->malloc_requests,mm_malloc_request_t); + uint64_t i, busy_requests = 0; + for (i=0;i 0) { + requests[busy_requests] = requests[i]; + ++busy_requests; + } + } + vector_set_used(mm_allocator->malloc_requests,busy_requests); + } +} +void mm_allocator_free_allocator_request( + mm_allocator_t* const mm_allocator, + mm_allocator_reference_t* const mm_reference) { + // Fetch segment and request + mm_allocator_segment_t* const segment = + mm_allocator_get_segment(mm_allocator,mm_reference->segment_idx); + mm_allocator_request_t* const request = + mm_allocator_segment_get_request(segment,mm_reference->request_idx); + // Check double-free + if (MM_ALLOCATOR_REQUEST_IS_FREE(request)) { + fprintf(stderr,"MMAllocator error: double free\n"); + exit(1); + } + // Free request + MM_ALLOCATOR_REQUEST_SET_FREE(request); + // Free contiguous request(s) at the end of the segment + uint64_t num_requests = mm_allocator_segment_get_num_requests(segment); + if (mm_reference->request_idx == num_requests-1) { // Is the last request? + --num_requests; + mm_allocator_request_t* request = + vector_get_mem(segment->requests,mm_allocator_request_t) + (num_requests-1); + while (num_requests>0 && MM_ALLOCATOR_REQUEST_IS_FREE(request)) { + --num_requests; // Free request + --request; + } + // Update segment used + if (num_requests > 0) { + segment->used = request->offset + request->size; + vector_set_used(segment->requests,num_requests); + } else { + // Segment fully freed + mm_allocator_segment_clear(segment); // Clear + // Add to free segments (if it is not the current segment) + if (segment->idx != mm_allocator->current_segment_idx) { + vector_insert(mm_allocator->segments_free,segment,mm_allocator_segment_t*); + } + } + } +} +void mm_allocator_free( + mm_allocator_t* const mm_allocator, + void* const memory) { + // Get reference + void* const effective_memory = memory - sizeof(mm_allocator_reference_t); + mm_allocator_reference_t* const mm_reference = effective_memory; + if (mm_reference->segment_idx == UINT32_MAX) { + // Free malloc memory + mm_allocator_free_malloc_request(mm_allocator,mm_reference); + } else { + // Free allocator memory + mm_allocator_free_allocator_request(mm_allocator,mm_reference); + } +} +/* + * Utils + */ +void mm_allocator_get_occupation( + mm_allocator_t* const mm_allocator, + uint64_t* const bytes_used_malloc, + uint64_t* const bytes_used_allocator, + uint64_t* const bytes_free_available, + uint64_t* const bytes_free_fragmented) { + // Init + *bytes_used_malloc = 0; + *bytes_used_allocator = 0; + *bytes_free_available = 0; + *bytes_free_fragmented = 0; + // Check allocator memory + const uint64_t num_segments = mm_allocator_get_num_segments(mm_allocator); + int64_t segment_idx, request_idx; + for (segment_idx=0;segment_idx=0;--request_idx) { + mm_allocator_request_t* const request = mm_allocator_segment_get_request(segment,request_idx); + const uint64_t size = MM_ALLOCATOR_REQUEST_SIZE(request); + if (MM_ALLOCATOR_REQUEST_IS_FREE(request)) { + if (free_memory) { + *bytes_free_available += size; + } else { + *bytes_free_fragmented += size; + } + } else { + free_memory = false; + *bytes_used_allocator += size; + } + } + // Account for free space at the end of the segment + if (num_requests > 0) { + mm_allocator_request_t* const request = mm_allocator_segment_get_request(segment,num_requests-1); + *bytes_free_available += segment->used - (request->offset+request->size); + } + } + // Check malloc memory + const uint64_t num_requests = vector_get_used(mm_allocator->malloc_requests); + mm_malloc_request_t* const requests = vector_get_mem(mm_allocator->malloc_requests,mm_malloc_request_t); + uint64_t i; + for (i=0;ioffset, + (uint64_t)MM_ALLOCATOR_REQUEST_SIZE(request) +#ifdef MM_ALLOCATOR_LOG + ,request->func_name, + request->line_no, + request->timestamp +#endif + ); +} +void mm_allocator_print_malloc_request( + FILE* const stream, + mm_malloc_request_t* const request) { + fprintf(stream," [@%p" PRIu64 "\t(%" PRIu64 " Bytes)" +#ifdef MM_ALLOCATOR_LOG + "\t%s:%" PRIu64 "\t{ts=%" PRIu64 "}" +#endif + "\n", + request->mem, + request->size +#ifdef MM_ALLOCATOR_LOG + ,request->func_name, + request->line_no, + request->timestamp +#endif + ); +} +void mm_allocator_print_allocator_requests( + FILE* const stream, + mm_allocator_t* const mm_allocator, + const bool compact_free) { + // Print allocator memory + uint64_t segment_idx, request_idx; + uint64_t free_block = 0; + bool has_requests = false; + fprintf(stream," => MMAllocator.requests\n"); + const uint64_t num_segments = mm_allocator_get_num_segments(mm_allocator); + for (segment_idx=0;segment_idx 0) { + fprintf(stream," [n/a\tFree] \t(%" PRIu64 " Bytes)\n",free_block); + free_block = 0; + } + mm_allocator_print_allocator_request(stream,request,segment_idx,request_idx); + has_requests = true; + } + } else { + mm_allocator_print_allocator_request(stream,request,segment_idx,request_idx); + has_requests = true; + } + } + } + if (!has_requests) { + fprintf(stream," -- No requests --\n"); + } + // Print malloc memory + fprintf(stream," => MMMalloc.requests\n"); + const uint64_t num_requests = vector_get_used(mm_allocator->malloc_requests); + mm_malloc_request_t* const requests = vector_get_mem(mm_allocator->malloc_requests,mm_malloc_request_t); + uint64_t i; + for (i=0;i 0) { + mm_allocator_print_malloc_request(stream,requests+i); + } + } + if (num_requests == 0) { + fprintf(stream," -- No requests --\n"); + } +} +void mm_allocator_print( + FILE* const stream, + mm_allocator_t* const mm_allocator, + const bool display_requests) { + // Print header + fprintf(stream,"MMAllocator.report\n"); + // Print segment information + const uint64_t num_segments = mm_allocator_get_num_segments(mm_allocator); + const uint64_t segment_size = mm_allocator->segment_size; + fprintf(stream," => Segments.allocated %" PRIu64 "\n",num_segments); + fprintf(stream," => Segments.size %" PRIu64 " MB\n",segment_size/(1024*1024)); + fprintf(stream," => Memory.available %" PRIu64 " MB\n",num_segments*(segment_size/(1024*1024))); + // Print memory information + uint64_t bytes_used_malloc, bytes_used_allocator; + uint64_t bytes_free_available, bytes_free_fragmented; + mm_allocator_get_occupation(mm_allocator,&bytes_used_malloc,&bytes_used_allocator,&bytes_free_available,&bytes_free_fragmented); + fprintf(stream," => Memory.used %" PRIu64 "\n",bytes_used_allocator); + fprintf(stream," => Memory.free %" PRIu64 "\n",bytes_free_available+bytes_free_fragmented); + fprintf(stream," => Memory.free.available %" PRIu64 "\n",bytes_free_available); + fprintf(stream," => Memory.free.fragmented %" PRIu64 "\n",bytes_free_fragmented); + fprintf(stream," => Memory.malloc %" PRIu64 "\n",bytes_used_malloc); + // Print memory requests + if (display_requests) { + mm_allocator_print_allocator_requests(stream,mm_allocator,false); + } +} + + + diff --git a/benchmarks/wfa/utils/mm_allocator.h b/benchmarks/wfa/utils/mm_allocator.h new file mode 100644 index 0000000..47a307b --- /dev/null +++ b/benchmarks/wfa/utils/mm_allocator.h @@ -0,0 +1,131 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Simple managed-memory allocator that reduces the overhead + * of using malloc/calloc/free functions by allocating slabs of memory + * and dispatching memory segments in order. + */ + +#ifndef MM_ALLOCATOR_H_ +#define MM_ALLOCATOR_H_ + +#include "utils/vector.h" + +/* + * Configuration + */ +//#define MM_ALLOCATOR_LOG +#define MM_ALLOCATOR_ALIGNMENT 8 // 64bits + +/* + * MM-Allocator + */ +typedef struct { + // Metadata + uint64_t request_ticker; // Request ticker + // Memory segments + uint64_t segment_size; // Memory segment size (bytes) + vector_t* segments; // Memory segments (mm_allocator_segment_t*) + vector_t* segments_free; // Completely free segments (mm_allocator_segment_t*) + uint64_t current_segment_idx; // Current segment being used (serving memory) + // Malloc memory + vector_t* malloc_requests; // Malloc requests (mm_malloc_request_t) + uint64_t malloc_requests_freed; // Total malloc request freed and still in vector +} mm_allocator_t; + +/* + * Setup + */ +mm_allocator_t* mm_allocator_new( + const uint64_t segment_size); +void mm_allocator_clear( + mm_allocator_t* const mm_allocator); +void mm_allocator_delete( + mm_allocator_t* const mm_allocator); + +/* + * Allocator + */ +void* mm_allocator_allocate( + mm_allocator_t* const mm_allocator, + const uint64_t num_bytes, + const bool zero_mem, + const uint64_t align_bytes +#ifdef MM_ALLOCATOR_LOG + ,const char* func_name, + uint64_t line_no +#endif + ); + +#ifdef MM_ALLOCATOR_LOG +#define mm_allocator_alloc(mm_allocator,type) \ + ((type*)mm_allocator_allocate(mm_allocator,sizeof(type),false,MM_ALLOCATOR_ALIGNMENT,__func__,(uint64_t)__LINE__)) +#define mm_allocator_malloc(mm_allocator,num_bytes) \ + (mm_allocator_allocate(mm_allocator,num_bytes,false,MM_ALLOCATOR_ALIGNMENT,__func__,(uint64_t)__LINE__)) +#define mm_allocator_calloc(mm_allocator,num_elements,type,clear_mem) \ + ((type*)mm_allocator_allocate(mm_allocator,(num_elements)*sizeof(type),clear_mem,MM_ALLOCATOR_ALIGNMENT,__func__,(uint64_t)__LINE__)) +#else +#define mm_allocator_alloc(mm_allocator,type) \ + ((type*)mm_allocator_allocate(mm_allocator,sizeof(type),false,MM_ALLOCATOR_ALIGNMENT)) +#define mm_allocator_malloc(mm_allocator,num_bytes) \ + (mm_allocator_allocate(mm_allocator,num_bytes,false,MM_ALLOCATOR_ALIGNMENT)) +#define mm_allocator_calloc(mm_allocator,num_elements,type,clear_mem) \ + ((type*)mm_allocator_allocate(mm_allocator,(num_elements)*sizeof(type),clear_mem,MM_ALLOCATOR_ALIGNMENT)) +#endif + +#define mm_allocator_uint64(mm_allocator) mm_allocator_malloc(mm_allocator,sizeof(uint64_t)) +#define mm_allocator_uint32(mm_allocator) mm_allocator_malloc(mm_allocator,sizeof(uint32_t)) +#define mm_allocator_uint16(mm_allocator) mm_allocator_malloc(mm_allocator,sizeof(uint16_t)) +#define mm_allocator_uint8(mm_allocator) mm_allocator_malloc(mm_allocator,sizeof(uint8_t)) + +/* + * Free + */ +void mm_allocator_free( + mm_allocator_t* const mm_allocator, + void* const memory); + +/* + * Utils + */ +void mm_allocator_get_occupation( + mm_allocator_t* const mm_allocator, + uint64_t* const bytes_used_malloc, + uint64_t* const bytes_used_allocator, + uint64_t* const bytes_free_available, + uint64_t* const bytes_free_fragmented); + +/* + * Display + */ +void mm_allocator_print( + FILE* const stream, + mm_allocator_t* const mm_allocator, + const bool display_requests); + +#endif /* MM_ALLOCATOR_H_ */ diff --git a/benchmarks/wfa/utils/mm_stack.c b/benchmarks/wfa/utils/mm_stack.c new file mode 100644 index 0000000..b0b4c29 --- /dev/null +++ b/benchmarks/wfa/utils/mm_stack.c @@ -0,0 +1,259 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + */ + +#include "mm_stack.h" + +/* + * Debug + */ +//#define MM_STACK_FORCE_MALLOC + +/* + * Constants + */ +#define MM_STACK_INITIAL_SEGMENTS 10 +#define MM_STACK_INITIAL_MALLOC_REQUESTS 10 +#define MM_STACK_INITIAL_STATES 10 + +/* + * Stack state + */ +typedef struct { + uint64_t segment_idx; + uint64_t segment_used; + uint64_t num_malloc_requests; +} mm_stack_state_t; + +/* + * Memory Segments + */ +typedef struct { + uint64_t size; // Total memory available + void* memory; // Memory + uint64_t used; // Bytes used (offset to memory next free byte) +} mm_stack_segment_t; + +/* + * Segments + */ +mm_stack_segment_t* mm_stack_segment_new( + mm_stack_t* const mm_stack) { + // Allocate handler + mm_stack_segment_t* const segment = malloc(sizeof(mm_stack_segment_t)); + // Memory + segment->size = mm_stack->segment_size; + segment->memory = malloc(mm_stack->segment_size); + segment->used = 0; + // Add to segments + vector_insert(mm_stack->segments,segment,mm_stack_segment_t*); + // Return + return segment; +} +void mm_stack_segment_clear( + mm_stack_segment_t* const segment) { + segment->used = 0; +} +void mm_stack_segment_delete( + mm_stack_segment_t* const segment) { + free(segment->memory); + free(segment); +} +/* + * Setup + */ +mm_stack_t* mm_stack_new( + const uint64_t segment_size) { + // Allocate handler + mm_stack_t* const mm_stack = malloc(sizeof(mm_stack_t)); + // Memory segments + mm_stack->segments = vector_new(MM_STACK_INITIAL_SEGMENTS,mm_stack_segment_t*); + mm_stack->segment_size = segment_size; + mm_stack_segment_new(mm_stack); + mm_stack->current_segment_idx = 0; + // Malloc memory + mm_stack->malloc_requests = vector_new(MM_STACK_INITIAL_MALLOC_REQUESTS,void*); + // Stack states + mm_stack->states = vector_new(MM_STACK_INITIAL_STATES,mm_stack_state_t); + // Return + return mm_stack; +} +void mm_stack_clear( + mm_stack_t* const mm_stack) { + // Clear first memory segment and discard the rest + mm_stack_segment_t* const segment = *vector_get_elm(mm_stack->segments,0,mm_stack_segment_t*); + mm_stack_segment_clear(segment); + mm_stack->current_segment_idx = 0; + // Free malloc memory + VECTOR_ITERATE(mm_stack->malloc_requests,mem_ptr,m,void*) { + free(*mem_ptr); + } + vector_clear(mm_stack->malloc_requests); + // Clear states + vector_clear(mm_stack->states); +} +void mm_stack_delete( + mm_stack_t* const mm_stack) { + // Delete memory segments + VECTOR_ITERATE(mm_stack->segments,segment_ptr,p,mm_stack_segment_t*) { + mm_stack_segment_delete(*segment_ptr); + } + vector_delete(mm_stack->segments); + // Free malloc memory + VECTOR_ITERATE(mm_stack->malloc_requests,mem_ptr,m,void*) { + free(*mem_ptr); + } + vector_delete(mm_stack->malloc_requests); + // Clear states + vector_delete(mm_stack->states); + // Free handler + free(mm_stack); +} +/* + * Allocator + */ +mm_stack_segment_t* mm_stack_fetch_segment( + mm_stack_t* const mm_stack, + const uint64_t num_bytes) { + // Fetch current segment + mm_stack_segment_t* const curr_segment = + *vector_get_elm(mm_stack->segments,mm_stack->current_segment_idx,mm_stack_segment_t*); + // Check available segment size + if (curr_segment->used + num_bytes <= curr_segment->size) { + return curr_segment; + } + // Check overall segment size + if (num_bytes > curr_segment->size) { + return NULL; // Memory request over segment size + } + // Get free segment + const uint64_t num_segments = vector_get_used(mm_stack->segments); + ++(mm_stack->current_segment_idx); + if (mm_stack->current_segment_idx < num_segments) { + // Get next segment + mm_stack_segment_t* const segment = + *vector_get_elm(mm_stack->segments,mm_stack->current_segment_idx,mm_stack_segment_t*); + // Clear + mm_stack_segment_clear(segment); + // Return + return segment; + } + // Add new segment + return mm_stack_segment_new(mm_stack); +} +void* mm_stack_allocate( + mm_stack_t* const mm_stack, + const uint64_t num_bytes, + const bool zero_mem, + const uint64_t align_bytes) { + // Zero check + if (num_bytes == 0) { + fprintf(stderr,"MMStack error. Zero bytes requested\n"); + exit(1); + } + // Add payload + const uint64_t num_bytes_allocated = num_bytes + align_bytes; + // Fetch segment +#ifdef MM_STACK_FORCE_MALLOC + mm_stack_segment_t* const segment = NULL; // Force malloc memory +#else + mm_stack_segment_t* const segment = mm_stack_fetch_segment(mm_stack,num_bytes_allocated); +#endif + // Allocate memory + void* memory_base ; + if (segment != NULL) { + // Segment-memory + memory_base = segment->memory + segment->used; + if (zero_mem) memset(memory_base,0,num_bytes_allocated); // Set zero + segment->used += num_bytes_allocated; // Update segment + } else { + // Malloc-memory + memory_base = malloc(num_bytes_allocated); + if (zero_mem) memset(memory_base,0,num_bytes_allocated); // Set zero + // Add malloc-request + vector_insert(mm_stack->malloc_requests,memory_base,void*); + } + // Check alignment + if (align_bytes == 0) return memory_base; + // Align memory request + void* memory_aligned = memory_base + align_bytes; + memory_aligned = memory_aligned - ((uintptr_t)memory_aligned % align_bytes); + return memory_aligned; +} +/* + * Push/pop states + */ +void mm_stack_push( + mm_stack_t* const mm_stack) { + // Get new stack-state + mm_stack_state_t* stack_state; + vector_alloc_new(mm_stack->states,mm_stack_state_t,stack_state); + // Store current state + mm_stack_segment_t* const current_segment = + *vector_get_elm(mm_stack->segments,mm_stack->current_segment_idx,mm_stack_segment_t*); + stack_state->segment_idx = mm_stack->current_segment_idx; + stack_state->segment_used = current_segment->used; + stack_state->num_malloc_requests = vector_get_used(mm_stack->malloc_requests); +} +void mm_stack_pop( + mm_stack_t* const mm_stack) { + // Get last stack-state + mm_stack_state_t* const stack_state = vector_get_last_elm(mm_stack->states,mm_stack_state_t); + vector_dec_used(mm_stack->states); + // Restore segment-memory state + mm_stack->current_segment_idx = stack_state->segment_idx; + mm_stack_segment_t* const current_segment = + *(vector_get_elm(mm_stack->segments,stack_state->segment_idx,mm_stack_segment_t*)); + current_segment->used = stack_state->segment_used; + // Restore malloc-memory state (free requests) + const uint64_t total_malloc_requests = vector_get_used(mm_stack->malloc_requests); + void** const malloc_requests = vector_get_mem(mm_stack->malloc_requests,void*); + uint64_t i; + for (i=stack_state->num_malloc_requests;imalloc_requests,stack_state->num_malloc_requests); +} +/* + * Display + */ +void mm_stack_print( + FILE* const stream, + mm_stack_t* const mm_stack) { + // Print header + fprintf(stream,"MMStack.report\n"); + // Print segment information + const uint64_t num_segments = vector_get_used(mm_stack->segments); + const uint64_t segment_size = mm_stack->segment_size; + fprintf(stream," => Segments.allocated %" PRIu64 "\n",num_segments); + fprintf(stream," => Segments.size %" PRIu64 " MB\n",segment_size/(1024*1024)); + fprintf(stream," => Memory.available %" PRIu64 " MB\n",num_segments*(segment_size/(1024*1024))); +} + + diff --git a/benchmarks/wfa/utils/mm_stack.h b/benchmarks/wfa/utils/mm_stack.h new file mode 100644 index 0000000..b2c999a --- /dev/null +++ b/benchmarks/wfa/utils/mm_stack.h @@ -0,0 +1,101 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + */ + +#ifndef MM_STACK_H_ +#define MM_STACK_H_ + +#include "utils/vector.h" + +/* + * Configuration + */ +#define MM_STACK_ALIGNMENT 8 // 64bits + +/* + * MM-Allocator + */ +typedef struct { + // Memory segments + uint64_t segment_size; // Memory segment size (bytes) + vector_t* segments; // Memory segments (mm_stack_segment_t*) + uint64_t current_segment_idx; // Current segment being used (serving memory) + // Malloc memory + vector_t* malloc_requests; // Malloc requests (void*) + // Stack states + vector_t* states; // Stack saved states (mm_stack_state_t) +} mm_stack_t; + +/* + * Setup + */ +mm_stack_t* mm_stack_new( + const uint64_t segment_size); +void mm_stack_clear( + mm_stack_t* const mm_stack); +void mm_stack_delete( + mm_stack_t* const mm_stack); + +/* + * Allocator + */ +void* mm_stack_allocate( + mm_stack_t* const mm_stack, + const uint64_t num_bytes, + const bool zero_mem, + const uint64_t align_bytes); + +#define mm_stack_alloc(mm_stack,type) \ + ((type*)mm_stack_allocate(mm_stack,sizeof(type),false,MM_STACK_ALIGNMENT)) +#define mm_stack_malloc(mm_stack,num_bytes) \ + (mm_stack_allocate(mm_stack,num_bytes,false,MM_STACK_ALIGNMENT)) +#define mm_stack_calloc(mm_stack,num_elements,type,clear_mem) \ + ((type*)mm_stack_allocate(mm_stack,(num_elements)*sizeof(type),clear_mem,MM_STACK_ALIGNMENT)) + +#define mm_stack_uint64(mm_stack) mm_stack_malloc(mm_stack,sizeof(uint64_t)) +#define mm_stack_uint32(mm_stack) mm_stack_malloc(mm_stack,sizeof(uint32_t)) +#define mm_stack_uint16(mm_stack) mm_stack_malloc(mm_stack,sizeof(uint16_t)) +#define mm_stack_uint8(mm_stack) mm_stack_malloc(mm_stack,sizeof(uint8_t)) + +/* + * Push/pop states + */ +void mm_stack_push( + mm_stack_t* const mm_stack); +void mm_stack_pop( + mm_stack_t* const mm_stack); + +/* + * Display + */ +void mm_stack_print( + FILE* const stream, + mm_stack_t* const mm_stack); + +#endif /* MM_STACK_H_ */ diff --git a/benchmarks/wfa/utils/string_padded.c b/benchmarks/wfa/utils/string_padded.c new file mode 100644 index 0000000..7536fdc --- /dev/null +++ b/benchmarks/wfa/utils/string_padded.c @@ -0,0 +1,122 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Padded string module to avoid handling corner conditions + */ + +#include "utils/string_padded.h" +#include "utils/mm_allocator.h" + +/* + * Strings (text/pattern) padded + */ +void strings_padded_add_padding( + const char* const buffer, + const int buffer_length, + const int begin_padding_length, + const int end_padding_length, + const char padding_value, + char** const buffer_padded, + char** const buffer_padded_begin, + mm_allocator_t* const mm_allocator) { + // Allocate + const int buffer_padded_length = begin_padding_length + buffer_length + end_padding_length; + *buffer_padded = mm_allocator_malloc(mm_allocator,buffer_padded_length); + // Add begin padding + memset(*buffer_padded,padding_value,begin_padding_length); + // Copy buffer + *buffer_padded_begin = *buffer_padded + begin_padding_length; + memcpy(*buffer_padded_begin,buffer,buffer_length); + // Add end padding + memset(*buffer_padded_begin+buffer_length,padding_value,end_padding_length); +} +strings_padded_t* strings_padded_new( + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int padding_length, + mm_allocator_t* const mm_allocator) { + // Allocate + strings_padded_t* const strings_padded = + mm_allocator_alloc(mm_allocator,strings_padded_t); + strings_padded->mm_allocator = mm_allocator; + // Compute padding dimensions + const int pattern_begin_padding_length = 0; + const int pattern_end_padding_length = padding_length; + const int text_begin_padding_length = 0; + const int text_end_padding_length = padding_length; + // Add padding + strings_padded_add_padding( + pattern,pattern_length, + pattern_begin_padding_length,pattern_end_padding_length,'X', + &(strings_padded->pattern_padded_buffer), + &(strings_padded->pattern_padded),mm_allocator); + strings_padded_add_padding( + text,text_length, + text_begin_padding_length,text_end_padding_length,'Y', + &(strings_padded->text_padded_buffer), + &(strings_padded->text_padded),mm_allocator); + // Return + return strings_padded; +} +strings_padded_t* strings_padded_new_rhomb( + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int padding_length, + mm_allocator_t* const mm_allocator) { + // Allocate + strings_padded_t* const strings_padded = + mm_allocator_alloc(mm_allocator,strings_padded_t); + strings_padded->mm_allocator = mm_allocator; + // Compute padding dimensions + const int pattern_begin_padding_length = text_length + padding_length; + const int pattern_end_padding_length = pattern_length + text_length + padding_length; + const int text_begin_padding_length = padding_length; + const int text_end_padding_length = text_length + padding_length; + // Add padding + strings_padded_add_padding( + pattern,pattern_length, + pattern_begin_padding_length,pattern_end_padding_length,'X', + &(strings_padded->pattern_padded_buffer), + &(strings_padded->pattern_padded),mm_allocator); + strings_padded_add_padding( + text,text_length, + text_begin_padding_length,text_end_padding_length,'Y', + &(strings_padded->text_padded_buffer), + &(strings_padded->text_padded),mm_allocator); + // Return + return strings_padded; +} +void strings_padded_delete(strings_padded_t* const strings_padded) { + mm_allocator_free(strings_padded->mm_allocator,strings_padded->pattern_padded_buffer); + mm_allocator_free(strings_padded->mm_allocator,strings_padded->text_padded_buffer); + mm_allocator_free(strings_padded->mm_allocator,strings_padded); +} diff --git a/benchmarks/wfa/utils/string_padded.h b/benchmarks/wfa/utils/string_padded.h new file mode 100644 index 0000000..32c01a9 --- /dev/null +++ b/benchmarks/wfa/utils/string_padded.h @@ -0,0 +1,74 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Padded string module to avoid handling corner conditions + */ + +#ifndef STRING_PADDED_H +#define STRING_PADDED_H + +/* + * Includes + */ +#include "utils/commons.h" +#include "utils/mm_allocator.h" + +/* + * Strings Padded + */ +typedef struct { + // Strings + char* pattern_padded_buffer; + char* pattern_padded; + char* text_padded_buffer; + char* text_padded; + // MM + mm_allocator_t* mm_allocator; +} strings_padded_t; + +/* + * Strings (text/pattern) padded + */ +strings_padded_t* strings_padded_new( + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int padding_length, + mm_allocator_t* const mm_allocator); +strings_padded_t* strings_padded_new_rhomb( + const char* const pattern, + const int pattern_length, + const char* const text, + const int text_length, + const int padding_length, + mm_allocator_t* const mm_allocator); +void strings_padded_delete( + strings_padded_t* const strings_padded); + +#endif /* STRING_PADDED_H */ diff --git a/benchmarks/wfa/utils/vector.c b/benchmarks/wfa/utils/vector.c new file mode 100644 index 0000000..98299ba --- /dev/null +++ b/benchmarks/wfa/utils/vector.c @@ -0,0 +1,126 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Simple linear vector for generic type elements + */ + +#include "vector.h" + +/* + * Constants + */ +#define VECTOR_EXPAND_FACTOR (3.0/2.0) + +/* + * Setup + */ +vector_t* vector_new_(const uint64_t num_initial_elements,const uint64_t element_size) { + vector_t* const vector_buffer = malloc(sizeof(vector_t)); + vector_buffer->element_size = element_size; + vector_buffer->elements_allocated = num_initial_elements; + vector_buffer->memory = malloc(num_initial_elements*element_size); + if (!vector_buffer->memory) { + fprintf(stderr,"Could not create new vector (%"PRIu64" bytes requested)", + num_initial_elements*element_size); + exit(1); + } + vector_buffer->used = 0; + return vector_buffer; +} +void vector_reserve(vector_t* const vector,const uint64_t num_elements,const bool zero_mem) { + if (vector->elements_allocated < num_elements) { + const uint64_t proposed=(float)vector->elements_allocated*VECTOR_EXPAND_FACTOR; + vector->elements_allocated = num_elements>proposed?num_elements:proposed; + vector->memory = realloc(vector->memory,vector->elements_allocated*vector->element_size); + if (!vector->memory) { + fprintf(stderr,"Could not reserve vector (%"PRIu64" bytes requested)", + vector->elements_allocated*vector->element_size); + exit(1); + } + } + if (zero_mem) { + memset(vector->memory+vector->used*vector->element_size,0, + (vector->elements_allocated-vector->used)*vector->element_size); + } +} +void vector_resize__clear(vector_t* const vector,const uint64_t num_elements) { + if (vector->elements_allocated < num_elements) { + const uint64_t proposed = (float)vector->elements_allocated*VECTOR_EXPAND_FACTOR; + vector->elements_allocated = (num_elements>proposed)?num_elements:proposed; + // Free previous chunk (no need to pay the cost of reallocating memory) + free(vector->memory); + // Allocate new block of memory + vector->memory = malloc(vector->elements_allocated*vector->element_size); + if (!vector->memory) { + fprintf(stderr,"Could not reserve vector (%"PRIu64" bytes requested)", + vector->elements_allocated*vector->element_size); + exit(1); + } + } + vector->used=0; +} +void vector_cast__clear_(vector_t* const vector,const uint64_t element_size) { + vector->elements_allocated = (vector->elements_allocated*vector->element_size)/element_size; + vector->element_size = element_size; + vector->used = 0; +} +void vector_delete(vector_t* const vector) { + free(vector->memory); + free(vector); +} +/* + * Accessors + */ +#ifdef VECTOR_DEBUG +void* vector_get_mem_element(vector_t* const vector,const uint64_t position,const uint64_t element_size) { + if (position >= (vector)->used) { + fprintf(stderr,"Vector position out-of-range [0,%"PRIu64")",(vector)->used); + exit(1); + } + return vector->memory + (position*element_size); +} +#endif +/* + * Miscellaneous + */ +void vector_copy(vector_t* const vector_to,vector_t* const vector_from) { + // Prepare + vector_cast__clear_(vector_to,vector_from->element_size); + vector_reserve(vector_to,vector_from->used,false); + // Copy + vector_set_used(vector_to,vector_from->used); + memcpy(vector_to->memory,vector_from->memory,vector_from->used*vector_from->element_size); +} +vector_t* vector_dup(vector_t* const vector_src) { + vector_t* const vector_cpy = vector_new_(vector_src->used,vector_src->element_size); + // Copy + vector_set_used(vector_cpy,vector_src->used); + memcpy(vector_cpy->memory,vector_src->memory,vector_src->used*vector_src->element_size); + return vector_cpy; +} + diff --git a/benchmarks/wfa/utils/vector.h b/benchmarks/wfa/utils/vector.h new file mode 100644 index 0000000..0473bb5 --- /dev/null +++ b/benchmarks/wfa/utils/vector.h @@ -0,0 +1,141 @@ +/* + * The MIT License + * + * Wavefront Alignments Algorithms + * Copyright (c) 2017 by Santiago Marco-Sola + * + * This file is part of Wavefront Alignments Algorithms. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * PROJECT: Wavefront Alignments Algorithms + * AUTHOR(S): Santiago Marco-Sola + * DESCRIPTION: Simple linear vector for generic type elements + */ + +#ifndef VECTOR_H_ +#define VECTOR_H_ + +#include "commons.h" + +/* + * Checkers + */ +//#define VECTOR_DEBUG + +/* + * Data Structures + */ +typedef struct { + void* memory; + uint64_t used; + uint64_t element_size; + uint64_t elements_allocated; +} vector_t; + +/* + * Vector Setup (Initialization & Allocation) + */ +#define vector_new(num_initial_elements,type) vector_new_(num_initial_elements,sizeof(type)) +vector_t* vector_new_(const uint64_t num_initial_elements,const uint64_t element_size); +void vector_reserve(vector_t* const vector,const uint64_t num_elements,const bool zero_mem); +void vector_resize__clear(vector_t* const vector,const uint64_t num_elements); +#define vector_cast__clear(vector,type) vector_cast__clear_s(vector,sizeof(type)) +void vector_cast__clear_(vector_t* const vector,const uint64_t element_size); +#define vector_clear(vector) (vector)->used=0 +void vector_delete(vector_t* const vector); +#define vector_is_empty(vector) (vector_get_used(vector)==0) +#define vector_reserve_additional(vector,additional) vector_reserve(vector,vector_get_used(vector)+additional,false) +#define vector_prepare(vector,num_elements,type) \ + vector_cast__clear(vector,sizeof(type)); \ + vector_reserve(vector,num_elements,false); + +/* + * Element Getters/Setters + */ +#define vector_get_mem(vector,type) ((type*)((vector)->memory)) +#define vector_get_last_elm(vector,type) (vector_get_mem(vector,type)+(vector)->used-1) +#define vector_get_free_elm(vector,type) (vector_get_mem(vector,type)+(vector)->used) +#define vector_set_elm(vector,position,type,elm) *vector_get_elm(vector,position,type) = elm +#ifndef VECTOR_DEBUG + #define vector_get_elm(vector,position,type) (vector_get_mem(vector,type)+position) +#else + void* vector_get_mem_element(vector_t* const vector,const uint64_t position,const uint64_t element_size); + #define vector_get_elm(vector,position,type) ((type*)vector_get_mem_element(vector,position,sizeof(type))) +#endif + +/* + * Used elements Getters/Setters + */ +#define vector_get_used(vector) ((vector)->used) +#define vector_set_used(vector,total_used) (vector)->used=(total_used) +#define vector_inc_used(vector) (++((vector)->used)) +#define vector_dec_used(vector) (--((vector)->used)) +#define vector_add_used(vector,additional) vector_set_used(vector,vector_get_used(vector)+additional) +#define vector_update_used(vector,pointer_to_next_free_element) \ + (vector)->used = (pointer_to_next_free_element) - ((__typeof__(pointer_to_next_free_element))((vector)->memory)) + + +/* + * Vector Allocate/Insert (Get a new element or Add an element to the end of the vector) + */ +#define vector_alloc_new(vector,type,return_element_pointer) { \ + vector_reserve_additional(vector,1); \ + return_element_pointer = vector_get_free_elm(vector,type); \ + vector_inc_used(vector); \ +} +#define vector_insert(vector,element,type) { \ + vector_reserve_additional(vector,1); \ + *(vector_get_free_elm(vector,type))=element; \ + vector_inc_used(vector); \ +} + +/* + * Macro generic iterator + * VECTOR_ITERATE(vector_of_ints,elm_iterator,elm_counter,int) { + * ..code.. + * } + */ +#define VECTOR_ITERATE(vector,element,counter,type) \ + const uint64_t vector_##element##_used = vector_get_used(vector); \ + type* element = vector_get_mem(vector,type); \ + uint64_t counter; \ + for (counter=0;counter $INPUTS_DIR/nn-base/small/out-small.fastq echo "Running nn-variant" - python ../benchmarks/nn-variant/prediction.py --chkpnt_fn $INPUTS_DIR/nn-variant/model --sampleName chr20 --threads 1 --qual 100 --input_fn $INPUTS_DIR/nn-variant/small/prediction_input.h5 --output_fn $INPUTS_DIR/nn-variant/small/prediction_output.h5 + ../benchmarks/nn-variant/Clair3/callVar.sh --bam_fn="$INPUTS_DIR/nn-variant/HG002_GRCh38_ONT-UL_GIAB_20200122_chr20_0_10000000.phased.bam" --ref_fn="$INPUTS_DIR/nn-variant/hg38_chr20.fa" --threads=1 --platform="ont" --model_path="$INPUTS_DIR/nn-variant/models/r941_prom_hac_g360+g422" --bed_fn="$INPUTS_DIR/nn-variant/small/region.bed" --output="$INPUTS_DIR/nn-variant/output-small" echo "Running abea" ../benchmarks/abea/f5c eventalign -b $INPUTS_DIR/abea/small/1000reads.bam -g $INPUTS_DIR/abea/humangenome.fa -r $INPUTS_DIR/abea/1000reads.fastq -B 3.7M > $INPUTS_DIR/abea/small/events.tsv @@ -37,9 +37,9 @@ else echo "Running nn-base" python ../benchmarks/nn-base/bonito/basecall.py ../benchmarks/nn-base/models/bonito_dna_r941 $INPUTS_DIR/nn-base/large --device cuda:0 --fastq > $INPUTS_DIR/nn-base/large/out-large.fastq - + echo "Running nn-variant" - python ../benchmarks/nn-variant/prediction.py --chkpnt_fn $INPUTS_DIR/nn-variant/model --sampleName chr20 --threads 1 --qual 100 --input_fn $INPUTS_DIR/nn-variant/large/prediction_input.h5 --output_fn $INPUTS_DIR/nn-variant/large/prediction_output.h5 + ../benchmarks/nn-variant/Clair3/callVar.sh --bam_fn="$INPUTS_DIR/nn-variant/HG002_GRCh38_ONT-UL_GIAB_20200122_chr20_0_10000000.phased.bam" --ref_fn="$INPUTS_DIR/nn-variant/hg38_chr20.fa" --threads=1 --platform="ont" --model_path="$INPUTS_DIR/nn-variant/models/r941_prom_hac_g360+g422" --bed_fn="$INPUTS_DIR/nn-variant/large/region.bed" --output="$INPUTS_DIR/nn-variant/output-large" echo "Running abea" ../benchmarks/abea/f5c eventalign -b $INPUTS_DIR/abea/large/10000reads.bam -g $INPUTS_DIR/abea/humangenome.fa -r $INPUTS_DIR/abea/10000reads.fastq -B 3.7M > $INPUTS_DIR/abea/large/events.tsv diff --git a/scripts/vtune.pc.sh b/scripts/vtune.pc.sh index ba69a5f..3a56f42 100644 --- a/scripts/vtune.pc.sh +++ b/scripts/vtune.pc.sh @@ -42,6 +42,9 @@ vtune_pc $OUTPUTS_DIR/dbg_pc "../benchmarks/dbg/dbg $INPUTS_DIR/dbg/large/ERR194 echo "Running chain" vtune_pc $OUTPUTS_DIR/chain_pc "../benchmarks/chain/chain -i $INPUTS_DIR/chain/large/c_elegans_40x.10k.in -o $INPUTS_DIR/chain/large/c_elegans_40x.10k.out" +echo "Running fast-chain" +vtune_pc $OUTPUTS_DIR/fast-chain_pc "../benchmarks/fast-chain/chain -i $INPUTS_DIR/chain/large/c_elegans_40x.10k.in -o $INPUTS_DIR/fast-chain/large/c_elegans_40x.10k.out" + echo "Running poa" vtune_pc $OUTPUTS_DIR/poa_pc "../benchmarks/poa/poa -s $INPUTS_DIR/poa/large/input.fasta -t 1" @@ -55,3 +58,6 @@ echo "Running grm" export LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/2021.1.1/lib/intel64:/opt/intel/oneapi/compiler/2021.1.2/linux/compiler/lib/intel64_lin:$LD_LIBRARY_PATH vtune_pc $OUTPUTS_DIR/grm_pc "../benchmarks/grm/2.0/build_dynamic/plink2 --maf 0.01 --pgen $INPUTS_DIR/grm/large/chr1_phase3.pgen --pvar $INPUTS_DIR/grm/large/chr1_phase3.pvar --psam $INPUTS_DIR/grm/large/phase3_corrected.psam --make-grm-bin --out $INPUTS_DIR/grm/large/grm --threads 1" +echo "Running wfa" +vtune_pc $OUTPUTS_DIR/wfa_pc "../benchmarks/wfa/bin/align_benchmark -i $INPUTS_DIR/bsw/large/banded_SRR7733443_1m_input.txt -o checksum.file -t 1" + diff --git a/scripts/vtune.uarch.sh b/scripts/vtune.uarch.sh index 22af637..51e4b64 100644 --- a/scripts/vtune.uarch.sh +++ b/scripts/vtune.uarch.sh @@ -42,6 +42,9 @@ vtune_uarch $OUTPUTS_DIR/dbg_uarch "../benchmarks/dbg/dbg $INPUTS_DIR/dbg/large/ echo "Running chain" vtune_uarch $OUTPUTS_DIR/chain_uarch "../benchmarks/chain/chain -i $INPUTS_DIR/chain/large/c_elegans_40x.10k.in -o $INPUTS_DIR/chain/large/c_elegans_40x.10k.out" +echo "Running fast-chain" +vtune_uarch $OUTPUTS_DIR/fast-chain_uarch "../benchmarks/fast-chain/chain -i $INPUTS_DIR/chain/large/c_elegans_40x.10k.in -o $INPUTS_DIR/fast-chain/large/c_elegans_40x.10k.out" + echo "Running poa" vtune_uarch $OUTPUTS_DIR/poa_uarch "../benchmarks/poa/poa -s $INPUTS_DIR/poa/large/input.fasta -t 1" @@ -55,3 +58,6 @@ echo "Running grm" export LD_LIBRARY_PATH=/opt/intel/oneapi/mkl/2021.1.1/lib/intel64:/opt/intel/oneapi/compiler/2021.1.2/linux/compiler/lib/intel64_lin:$LD_LIBRARY_PATH vtune_uarch $OUTPUTS_DIR/grm_uarch "../benchmarks/grm/2.0/build_dynamic/plink2 --maf 0.01 --pgen $INPUTS_DIR/grm/large/chr1_phase3.pgen --pvar $INPUTS_DIR/grm/large/chr1_phase3.pvar --psam $INPUTS_DIR/grm/large/phase3_corrected.psam --make-grm-bin --out $INPUTS_DIR/grm/large/grm --threads 1" +echo "Running wfa" +vtune_uarch $OUTPUTS_DIR/wfa_uarch "../benchmarks/wfa/bin/align_benchmark -i $INPUTS_DIR/bsw/large/banded_SRR7733443_1m_input.txt -o checksum.file -t 1" +