2525from memory_profiler import memory_usage
2626from numpy .typing import NDArray
2727import logging
28-
28+ from itertools import islice
29+ import pickle
2930from . import __version__ , __description__
3031
3132from .feature_extraction import (
3738from .nearest_neighbors import NNDescent_ava
3839from . import global_variables
3940from .custom_logging import logger , add_log_file
41+ from .align import (
42+ Seq ,
43+ get_overlap_candidates ,
44+ run_multiprocess_alignment_optimized ,
45+ cWeightedSemiglobalAligner ,
46+ AlignmentResult
47+ )
4048
4149
4250logger .setLevel (logging .DEBUG )
@@ -108,6 +116,7 @@ def parse_command_line_arguments():
108116 help = "Minimum allowed frequency of a k-mer in all reads." ,
109117 )
110118 parser .add_argument (
119+ "-t" ,
111120 "--threads" ,
112121 type = int ,
113122 required = False ,
@@ -118,6 +127,7 @@ def parse_command_line_arguments():
118127 type = int ,
119128 required = False ,
120129 default = 1000 ,
130+ help = "Number of reads to process in each chunk when generating the feature matrix." ,
121131 )
122132
123133 parser .add_argument (
@@ -150,7 +160,7 @@ def parse_command_line_arguments():
150160 "--save-feature-matrix" ,
151161 action = "store_true" ,
152162 default = False ,
153- help = "Save the feature matrix to a file." ,
163+ help = "Save the embedding feature matrix to a file." ,
154164 )
155165 parser .add_argument (
156166 "--keep-intermediates" ,
@@ -182,6 +192,7 @@ def get_neighbors_ava(
182192 logger .info (
183193 f"Using { global_variables .threads } threads"
184194 )
195+
185196 neighbor_indices , distances = NNDescent_ava ().get_neighbors (
186197 embedding_matrix ,
187198 metric = "cosine" ,
@@ -259,43 +270,6 @@ def get_metadata_table(
259270 metadata_df = pd .DataFrame (metadata )
260271 return metadata_df
261272
def get_output_dataframe(
    neighbor_matrix: NDArray,
    read_names: List[str],
    strands: List[int],
) -> pd.DataFrame:
    """Build the pairwise overlap table from the nearest-neighbor matrix.

    Each row of ``neighbor_matrix`` lists, for one query read, the indices of
    its nearest neighbors in similarity order.  One output row is emitted per
    (query, neighbor) pair.  A read's own index is skipped, but it still
    consumes its rank, so ``neighbor_rank`` is the position within the
    original neighbor row — TODO confirm this rank convention is what
    downstream consumers expect.

    Args:
        neighbor_matrix: ``(n_reads, k)`` array of neighbor read indices.
        read_names: Read name for each read index.
        strands: Strand flag per read; 0 maps to "+", 1 maps to "-".

    Returns:
        DataFrame with columns ``query_name``, ``query_orientation``,
        ``target_name``, ``target_orientation``, ``neighbor_rank``.
    """
    orientation = ("+", "-")
    rows = []
    for query_index in range(neighbor_matrix.shape[0]):
        query_name = read_names[query_index]
        query_orientation = orientation[strands[query_index]]
        for rank, target_index in enumerate(neighbor_matrix[query_index]):
            if target_index == query_index:
                # Never report a read as its own overlap partner.
                continue
            rows.append(
                (
                    query_name,
                    query_orientation,
                    read_names[target_index],
                    orientation[strands[target_index]],
                    rank,
                )
            )

    df = pd.DataFrame(
        rows,
        columns=[
            "query_name",
            "query_orientation",
            "target_name",
            "target_orientation",
            "neighbor_rank",
        ],
    )
    logger.debug(f"Output DataFrame shape: {df.shape}")
    return df
299273def run_fedrann_pipeline (
300274 * ,
301275 input_path : str ,
@@ -321,7 +295,6 @@ def run_fedrann_pipeline(
321295 sample_fraction = kmer_sample_fraction ,
322296 min_multiplicity = kmer_min_multiplicity
323297 )
324- logger .debug (f"kmer_searcher n_features: { n_features } " )
325298
326299 logger .info ("--- 2. Generate dimension reduction and IDF matrix ---" )
327300 fwd_kmer_library_path = join (global_variables .temp_dir , "fwd_kmer_library.fasta" )
@@ -330,56 +303,60 @@ def run_fedrann_pipeline(
330303 counter_file = fwd_kmer_library_path ,
331304 n_features = n_features
332305 )
333- logger .debug (f"get_precompute_matrix n_features: { n_features } " )
334-
306+
335307 logger .info ("--- 3. Generate feature matrix ---" )
336308 embedding_matrix = get_feature_matrix (
337309 ks_file = kmer_searcher_output_path ,
310+ fasta_file = input_path ,
338311 precompute_matrix = precompute_mat ,
339312 kmer_count = n_features ,
340313 read_count = read_count ,
341314 chunk_size = chunk_size
342- )
343-
344- # # Save metadata
345- # metadata_output_file = join(output_dir, "metadata.tsv")
346- # logger.info(f"Saved metadata table to {metadata_output_file}")
347- # metadata_df = get_metadata_table(
348- # read_names=read_names,
349- # strands=strands,
350- # )
351- # metadata_df.to_csv(metadata_output_file, sep="\t", index=False)
352- # del read_names, strands
353- # gc.collect()
315+ )
316+ encoded_reads = get_metadata (kmer_searcher_output_path ,input_path ,n_features )
354317
355-
318+ if save_feature_matrix :
319+ feature_matrix_file = join (output_dir , "embedding_feature_matrix.npy" )
320+ logger .debug (f"Saving feature matrix to { feature_matrix_file } " )
321+ np .save (feature_matrix_file , embedding_matrix )
322+
356323 # Nearest neighbors search
357324 logger .info ("--- 4. Nearest Neighbors Search ---" )
358325 neighbor_matrix , distances = get_neighbors_ava (
359326 embedding_matrix ,
360327 nndescent_n_trees = nndescent_n_trees ,
361328 nndescent_n_neighbors = nndescent_n_neighbors ,
362329 )
330+
363331 del embedding_matrix
364332 gc .collect ()
365333
366- # Save output
367- nbr_output_file = join ( output_dir , "overlaps.tsv" )
368- logger . debug ( "Saving overlap table to %s" , nbr_output_file )
369-
370- read_names , strands = get_metadata (
371- ks_file = kmer_searcher_output_path ,
372- kmer_count = n_features ,
373- )
334+ logger . info ( "--- 5. Align candidates ---" )
335+
336+ overlap_candidates = get_overlap_candidates ( neighbor_matrix , nndescent_n_neighbors )
337+
338+ ## For testing only: persist the raw overlap candidates for offline inspection
339+ overlap_candidates_file = join ( output_dir , "overlaps_candidates.pkl" )
340+ with open ( overlap_candidates_file , "wb" ) as f :
341+ pickle . dump ( overlap_candidates , f )
374342
375- df = get_output_dataframe (
376- neighbor_matrix = neighbor_matrix ,
377- read_names = read_names ,
378- strands = strands
343+ ####
344+ nbr_output_file = join (output_dir , "overlaps.paf" )
345+ run_multiprocess_alignment_optimized (
346+ overlap_candidates ,
347+ encoded_reads ,
348+ marker_weights = None ,
349+ kmer_size = kmer_size ,
350+ aligner = cWeightedSemiglobalAligner ,
351+ processes = global_variables .threads ,
352+ batch_size = 100 ,
353+ output_path = nbr_output_file ,
354+ max_total_wait_seconds = 600 ,
379355 )
356+ # Save output
357+
358+ logger .debug ("Saving overlap table to %s" , nbr_output_file )
380359
381- df .to_csv (nbr_output_file , sep = "\t " , index = False )
382-
383360 if not keep_intermediates :
384361 logger .debug ("Removing intermediate files" )
385362 rmtree (global_variables .temp_dir )
@@ -425,34 +402,32 @@ def main():
425402 save_feature_matrix = args .save_feature_matrix ,
426403 chunk_size = args .chunk_size
427404 )
405+
428406 if args .mprof :
429- logger .debug ("Attention: Memory profiling enabled. Running with memory profiler." )
430407 mprof_dir = join (output_dir , "mprof" )
431408 os .makedirs (mprof_dir , exist_ok = True )
432409 mprof_output_path = join (mprof_dir , "memory_profile.dat" )
433410
434- # 确保函数有足够的执行时间
435- @memory_usage (
436- backend = "psutil" ,
437- interval = 1 ,
438- multiprocess = True ,
439- include_children = True ,
440- timestamps = True ,
441- max_usage = False ,
442- stream = open (mprof_output_path , "wt" ) # 直接传入文件流
443- )
444- def profiled_function ():
445- return f ()
446-
447- # 执行并确保文件关闭
448- try :
449- profiled_function ()
450- finally :
451- # 确保文件正确关闭
452- if 'profiled_function' in locals ():
453- # 获取stream并关闭
454- pass
411+ with open (mprof_output_path , "wt" ) as f_stream :
412+ logger .debug (f"Profiling to { mprof_output_path } " )
413+
414+ # 1. Run the pipeline and collect the returned list of (memory, timestamp) samples
415+ mem_result = memory_usage (
416+ f ,
417+ backend = "psutil" ,
418+ interval = 1 ,
419+ multiprocess = True ,
420+ include_children = True ,
421+ timestamps = True
422+ )
423+
424+ # 2. Manually write the samples to the file (mimicking the mprof .dat format)
425+ f_stream .write ("MT 1.0\n ") # mprof file-format marker
426+ for mem , ts in mem_result :
427+ f_stream .write (f"MEM { mem :.6f} { ts :.6f} \n " )
428+ f_stream .flush () # force the buffered samples to disk
455429 else :
430+ # Normal execution (no memory profiling)
456431 f ()
457432
458433
0 commit comments