diff --git a/blech_common_avg_reference.py b/blech_common_avg_reference.py index aec7fe8f..a355fb39 100644 --- a/blech_common_avg_reference.py +++ b/blech_common_avg_reference.py @@ -667,6 +667,8 @@ def plot_car_correlation_comparison(pre_corr_mat, post_corr_mat, parser.add_argument('--cluster_algo', type=str, choices=['kmeans', 'bgmm'], help='Clustering algorithm to use for auto CAR, BGMM tends to allow more clusters', default='kmeans') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') args = parser.parse_args() # Get name of directory with the data files @@ -676,7 +678,9 @@ def plot_car_correlation_comparison(pre_corr_mat, post_corr_mat, # Get directory name from metadata handler dir_name = metadata_handler.dir_name # Now create pipeline check with the correct dir_name - this_pipeline_check = pipeline_graph_check(dir_name) + overwrite_dependencies = args.overwrite_dependencies + this_pipeline_check = pipeline_graph_check( + dir_name, overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/blech_exp_info.py b/blech_exp_info.py index 9b4f6fd3..c2ce764f 100644 --- a/blech_exp_info.py +++ b/blech_exp_info.py @@ -120,6 +120,8 @@ def parse_arguments(): # Additional information parser.add_argument('--notes', help='Experiment notes') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') return parser.parse_args() @@ -722,7 +724,8 @@ def setup_experiment_info(): if not test_bool: script_path = os.path.abspath(__file__) - this_pipeline_check = pipeline_graph_check(dir_path) + this_pipeline_check = pipeline_graph_check( + dir_path, args.overwrite_dependencies) this_pipeline_check.write_to_log(script_path, 'attempted') # Set up cache diff --git a/blech_init.py b/blech_init.py index 41a8f54f..5630052a 100644 --- a/blech_init.py +++ b/blech_init.py @@ -187,8 +187,11 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram help='Directory name with data files') parser.add_argument('--force_run', action='store_true', help='Force run the script without asking user') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') args = parser.parse_args() force_run = args.force_run + overwrite_dependencies = args.overwrite_dependencies # Get name of directory with the data files metadata_handler = imp_metadata([[], args.dir_name]) @@ -199,7 +202,8 @@ def generate_processing_scripts(dir_name, blech_clust_dir, electrode_layout_fram dir_name = metadata_handler.dir_name # Now create pipeline check with the correct dir_name - this_pipeline_check = pipeline_graph_check(dir_name) + this_pipeline_check = pipeline_graph_check( + dir_name, overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') # If info_dict present but execution log is not diff --git a/blech_make_arrays.py b/blech_make_arrays.py index 14fba2d1..6575ba4e 100644 --- a/blech_make_arrays.py +++ b/blech_make_arrays.py @@ -125,15 +125,32 @@ def create_emg_trials_for_digin( test_bool = False + # Parse command line arguments + import argparse + parser = argparse.ArgumentParser( + description='Create spike trains and EMG trials from HDF5 file') + parser.add_argument('dir_name', type=str, nargs='?', + help='Directory name with data files') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') + args = parser.parse_args() + if test_bool: data_dir = '/media/storage/NM_resorted_data/NM43/NM43_500ms_160510_125413' metadata_handler = imp_metadata([[], data_dir]) + overwrite_dependencies = False else: - metadata_handler = imp_metadata(sys.argv) + # Pass dir_name if provided, otherwise use sys.argv + if args.dir_name: + metadata_handler = imp_metadata([[], args.dir_name]) + else: + metadata_handler = imp_metadata(sys.argv[1:]) + overwrite_dependencies = args.overwrite_dependencies # Perform pipeline graph check script_path = os.path.realpath(__file__) - this_pipeline_check = pipeline_graph_check(metadata_handler.dir_name) + this_pipeline_check = pipeline_graph_check( + metadata_handler.dir_name, overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/blech_post_process.py b/blech_post_process.py index 411597c4..25d1a2da 100644 --- a/blech_post_process.py +++ b/blech_post_process.py @@ -73,6 +73,8 @@ action='store_true') parser.add_argument('--delete-existing', help='Delete existing units', action='store_true') +parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') args = parser.parse_args() ############################################################ @@ -137,7 +139,8 @@ # Perform pipeline graph check script_path = os.path.realpath(__file__) -this_pipeline_check = pipeline_graph_check(dir_name) +this_pipeline_check = pipeline_graph_check( + dir_name, args.overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/blech_process.py b/blech_process.py index 8c80e897..befa0dc4 100644 --- a/blech_process.py +++ b/blech_process.py @@ -44,11 +44,14 @@ parser.add_argument('data_dir', type=str, help='Path to data directory') parser.add_argument('electrode_num', type=int, help='Electrode number to process') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') args = parser.parse_args() # Perform pipeline graph check script_path = os.path.realpath(__file__) - this_pipeline_check = pipeline_graph_check(args.data_dir) + this_pipeline_check = pipeline_graph_check( + args.data_dir, args.overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/blech_units_characteristics.py b/blech_units_characteristics.py index 28ab153d..d49f8901 100644 --- a/blech_units_characteristics.py +++ b/blech_units_characteristics.py @@ -11,6 +11,7 @@ - **Data Export**: Merges results into a single DataFrame and exports it to CSV and HDF5 formats. """ +import argparse import shutil from tqdm import tqdm import pingouin as pg @@ -38,17 +39,32 @@ # Get name of directory with the data files test_bool = False +# Parse command line arguments +parser = argparse.ArgumentParser( + description='Analyze unit characteristics') +parser.add_argument('dir_name', type=str, nargs='?', + help='Directory name with data files') +parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') +args = parser.parse_args() + if test_bool: # data_dir = '/media/storage/NM_resorted_data/NM43/NM43_500ms_160510_125413' data_dir = '/home/abuzarmahmood/projects/blech_clust/pipeline_testing/test_data_handling/test_data/KM45_5tastes_210620_113227_new' metadata_handler = imp_metadata([[], data_dir]) dir_name = metadata_handler.dir_name + overwrite_dependencies = False else: - metadata_handler = imp_metadata(sys.argv) + if args.dir_name: + metadata_handler = imp_metadata([[], args.dir_name]) + else: + metadata_handler = imp_metadata(sys.argv[1:]) dir_name = metadata_handler.dir_name + overwrite_dependencies = args.overwrite_dependencies # Perform pipeline graph check script_path = os.path.realpath(__file__) - this_pipeline_check = pipeline_graph_check(dir_name) + this_pipeline_check = pipeline_graph_check( + dir_name, overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/blech_units_plot.py b/blech_units_plot.py index 837d86e0..2f495bbc 100644 --- a/blech_units_plot.py +++ b/blech_units_plot.py @@ -27,13 +27,27 @@ def setup_environment(args): Returns: tuple: metadata_handler, dir_name, params_dict, layout_frame, pipeline_check """ + # Parse command line arguments + import argparse + parser = argparse.ArgumentParser( + description='Plot unit waveforms and ISI histograms') + parser.add_argument('dir_name', type=str, nargs='?', + help='Directory name with data files') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') + parsed_args = parser.parse_args(args[1:]) + # Get name of directory with the data files - metadata_handler = imp_metadata(args) + if parsed_args.dir_name: + metadata_handler = imp_metadata([[], parsed_args.dir_name]) + else: + metadata_handler = imp_metadata(args) dir_name = metadata_handler.dir_name # Perform pipeline graph check script_path = os.path.realpath(__file__) - this_pipeline_check = pipeline_graph_check(dir_name) + this_pipeline_check = pipeline_graph_check( + dir_name, parsed_args.overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/tests/test_pipeline_graph_check.py b/tests/test_pipeline_graph_check.py new file mode 100644 index 00000000..ca72e2c8 --- /dev/null +++ b/tests/test_pipeline_graph_check.py @@ -0,0 +1,121 @@ + +""" +Tests for the pipeline_graph_check functionality in blech_utils.py. +Verifies that --overwrite_dependencies argument has been added to all scripts. +""" +import os +import sys + +# Add the parent directory to the path so we can import the module +sys.path.insert(0, os.path.abspath( + os.path.join(os.path.dirname(__file__), '..'))) + + +class TestOverwriteDependenciesArgument: + """Test that scripts accept --overwrite_dependencies argument""" + + def test_blech_init_has_argument(self): + """Test that blech_init.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_init.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_blech_make_arrays_has_argument(self): + """Test that blech_make_arrays.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_make_arrays.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_blech_process_has_argument(self): + """Test that blech_process.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_process.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_blech_post_process_has_argument(self): + """Test that blech_post_process.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_post_process.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_blech_common_avg_reference_has_argument(self): + """Test that blech_common_avg_reference.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_common_avg_reference.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_blech_exp_info_has_argument(self): + """Test that blech_exp_info.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_exp_info.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_blech_units_plot_has_argument(self): + """Test that blech_units_plot.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_units_plot.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_blech_units_characteristics_has_argument(self): + """Test that blech_units_characteristics.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/blech_units_characteristics.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_utils_infer_rnn_rates_has_argument(self): + """Test that utils/infer_rnn_rates.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/utils/infer_rnn_rates.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_utils_qa_utils_drift_check_has_argument(self): + """Test that utils/qa_utils/drift_check.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/utils/qa_utils/drift_check.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_utils_qa_utils_elbo_drift_has_argument(self): + """Test that utils/qa_utils/elbo_drift.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/utils/qa_utils/elbo_drift.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + def test_utils_qa_utils_unit_similarity_has_argument(self): + """Test that utils/qa_utils/unit_similarity.py accepts --overwrite_dependencies argument""" + with open('/workspace/project/blech_clust/utils/qa_utils/unit_similarity.py', 'r') as f: + content = f.read() + + assert '--overwrite_dependencies' in content + + +class TestPipelineGraphCheckClass: + """Test that pipeline_graph_check class has the new functionality""" + + def test_blech_utils_has_overwrite_dependencies_in_init(self): + """Test that blech_utils.py pipeline_graph_check __init__ has overwrite_dependencies parameter""" + with open('/workspace/project/blech_clust/utils/blech_utils.py', 'r') as f: + content = f.read() + + # Check that the __init__ method has the parameter + assert 'def __init__(self, data_dir, overwrite_dependencies=False):' in content + assert 'self.overwrite_dependencies = overwrite_dependencies' in content + + def test_blech_utils_check_previous_has_override_logic(self): + """Test that check_previous method has the override logic""" + with open('/workspace/project/blech_clust/utils/blech_utils.py', 'r') as f: + content = f.read() + + # Check that the check_previous method has the override logic + assert 'if self.overwrite_dependencies:' in content + assert "input('Continue anyway? ([y]/n) :: ')" in content diff --git a/utils/blech_utils.py b/utils/blech_utils.py index f0a71c93..58e4bdf5 100644 --- a/utils/blech_utils.py +++ b/utils/blech_utils.py @@ -116,8 +116,9 @@ class pipeline_graph_check(): 4) If prior exeuction is not present or failed, generate warning, give user option to override, else exit """ - def __init__(self, data_dir): + def __init__(self, data_dir, overwrite_dependencies=False): self.data_dir = data_dir + self.overwrite_dependencies = overwrite_dependencies self.tee = Tee(data_dir) self.load_graph() self.get_git_info() @@ -208,6 +209,8 @@ def check_graph(self): def check_previous(self, script_path): """ Check that previous run script is present and executed successfully + If overwrite_dependencies is True, skip the check + If check fails and overwrite_dependencies is False, ask user interactively """ # Check that script_path is present in flat_graph if script_path in self.flat_graph.keys(): @@ -224,8 +227,45 @@ def check_previous(self, script_path): if any([x in log_dict['completed'].keys() for x in parent_script]): return True else: - raise ValueError( - f'Parent script [{parent_script}] not found in log') + # Parent script not found in log + if self.overwrite_dependencies: + print( + f'WARNING: Parent script [{parent_script}] not found in log') + print( + 'Overwriting dependencies as --overwrite-dependencies flag was provided') + return True + else: + # Ask user interactively whether to continue + print( + f'WARNING: Parent script [{parent_script}] not found in log') + response = input('Continue anyway? ([y]/n) :: ') + if response.lower() in ['', 'y', 'yes']: + print('Continuing with overwrite...') + return True + else: + print('Exiting...') + raise ValueError( + f'Parent script [{parent_script}] not found in log. User chose to exit.') + else: + # Log file doesn't exist + if self.overwrite_dependencies: + print( + f'WARNING: Execution log not found at {self.log_path}') + print( + 'Overwriting dependencies as --overwrite-dependencies flag was provided') + return True + else: + # Ask user interactively whether to continue + print( + f'WARNING: Execution log not found at {self.log_path}') + response = input('Continue anyway? ([y]/n) :: ') + if response.lower() in ['', 'y', 'yes']: + print('Continuing with overwrite...') + return True + else: + print('Exiting...') + raise ValueError( + f'Execution log not found at {self.log_path}. User chose to exit.') else: raise ValueError( f'Script path [{script_path}] not found in flat graph') diff --git a/utils/infer_rnn_rates.py b/utils/infer_rnn_rates.py index 03cd33db..910c3b8f 100644 --- a/utils/infer_rnn_rates.py +++ b/utils/infer_rnn_rates.py @@ -63,6 +63,8 @@ help='Time to forecast into the future (default: %(default)s)') parser.add_argument('--separate_tastes', action='store_true', help='Fit RNNs for each taste separately (default: %(default)s)') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') args = parser.parse_args() @@ -196,7 +198,8 @@ def parse_group_by(spikes_xr, group_by_list): if not test_mode: metadata_handler = imp_metadata([[], args.data_dir]) # Perform pipeline graph check - this_pipeline_check = pipeline_graph_check(args.data_dir) + this_pipeline_check = pipeline_graph_check( + args.data_dir, args.overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/utils/qa_utils/drift_check.py b/utils/qa_utils/drift_check.py index 30387848..d18855e3 100644 --- a/utils/qa_utils/drift_check.py +++ b/utils/qa_utils/drift_check.py @@ -12,6 +12,7 @@ """ # Import stuff! +import argparse import numpy as np import tables import sys @@ -75,12 +76,25 @@ def array_to_df(array, dim_names): ############################################################ # Initialize ############################################################ +# Parse command line arguments +parser = argparse.ArgumentParser( + description='Check for drift in firing rates') +parser.add_argument('dir_name', type=str, nargs='?', + help='Directory name with data files') +parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') +args = parser.parse_args() + # Get name of directory with the data files -metadata_handler = imp_metadata(sys.argv) +if args.dir_name: + metadata_handler = imp_metadata([[], args.dir_name]) +else: + metadata_handler = imp_metadata(sys.argv) dir_name = metadata_handler.dir_name +overwrite_dependencies = args.overwrite_dependencies # Perform pipeline graph check -this_pipeline_check = pipeline_graph_check(dir_name) +this_pipeline_check = pipeline_graph_check(dir_name, overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/utils/qa_utils/elbo_drift.py b/utils/qa_utils/elbo_drift.py index ad086735..13130d9e 100644 --- a/utils/qa_utils/elbo_drift.py +++ b/utils/qa_utils/elbo_drift.py @@ -27,6 +27,8 @@ parser = argparse.ArgumentParser() parser.add_argument('dir_name', type=str, help='Directory containing data') parser.add_argument('--force', action='store_true', help='Force re-fitting') +parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') args = parser.parse_args() @@ -133,7 +135,8 @@ def ridge_plot( dir_name = metadata_handler.dir_name # Perform pipeline graph check -this_pipeline_check = pipeline_graph_check(dir_name) +this_pipeline_check = pipeline_graph_check( + dir_name, args.overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted') diff --git a/utils/qa_utils/unit_similarity.py b/utils/qa_utils/unit_similarity.py index bf906f33..38afdaf8 100644 --- a/utils/qa_utils/unit_similarity.py +++ b/utils/qa_utils/unit_similarity.py @@ -200,12 +200,27 @@ def write_out_similarties(unique_pairs, unique_pairs_collisions, waveform_counts if __name__ == '__main__': + # Parse command line arguments + import argparse + parser = argparse.ArgumentParser( + description='Check unit similarity') + parser.add_argument('dir_name', type=str, nargs='?', + help='Directory name with data files') + parser.add_argument('--overwrite_dependencies', action='store_true', + help='Overwrite dependency check and continue even if previous script was not run') + args = parser.parse_args() + # Get name of directory with the data files - metadata_handler = imp_metadata(sys.argv) + if args.dir_name: + metadata_handler = imp_metadata([[], args.dir_name]) + else: + metadata_handler = imp_metadata(sys.argv) dir_name = metadata_handler.dir_name + overwrite_dependencies = args.overwrite_dependencies # Perform pipeline graph check - this_pipeline_check = pipeline_graph_check(dir_name) + this_pipeline_check = pipeline_graph_check( + dir_name, overwrite_dependencies) this_pipeline_check.check_previous(script_path) this_pipeline_check.write_to_log(script_path, 'attempted')