## Step by step pipeline
- Change `crop_and_align_frames.py` to `split_splotlight_behavior_video_to_frames.py`
- Run `python src/poseforge/neuromechfly/scripts/copy_kinematic_recording.py`
  - This script scans the data from Aymanns et al. (2022) on the NeLy lab server (also publicly available on Harvard Dataverse: https://doi.org/10.7910/DVN/QQMNQK), extracts key kinematic data, and saves it as pickle files.
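If you want to sanity-check the extracted kinematics, the pickle files can be inspected directly in Python. A minimal sketch is shown below; the file path is hypothetical, so point it at one of the files written by `copy_kinematic_recording.py`.

```python
import pickle
from pathlib import Path

# Hypothetical path; substitute one of the pickle files written by
# copy_kinematic_recording.py.
recording_path = Path("path/to/extracted_kinematics.pkl")

with recording_path.open("rb") as f:
    recording = pickle.load(f)

# Print the top-level structure to see which kinematic variables were extracted.
print(type(recording))
if isinstance(recording, dict):
    for key, value in recording.items():
        print(key, getattr(value, "shape", type(value)))
```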
- Run `python src/poseforge/neuromechfly/scripts/run_simulation.py`
  - This script selects non-resting segments from the kinematics recorded by Aymanns et al., simulates the selected segments using NeuroMechFly, and saves the visual renderings.
  - Because Aymanns et al. report tethered fly behaviors, replaying them on flat terrain can fail (e.g. the fly flipping upside down). This script filters out such periods and further splits each segment into several (though typically just one) subsegments.
- Run `python src/poseforge/spotlight/scripts/crop_and_align_frames.py`
  - This script processes each Spotlight experimental trial by extracting, aligning, and cropping frames from the behavior video and saving the processed frames as individual images in an output directory.
> [!NOTE]
> The following step is only for training the flip detection model. It should not be used during production.
- Run `python src/poseforge/spotlight/scripts/train_flip_detection_model.py`
  - This script trains a binary image classifier that detects whether the fly is flipped in the Spotlight arena.
  - Prerequisite: manual labels of whether the fly is flipped must be supplied. To do this, create a `manual_label/` subdirectory under the directory containing the extracted frames of each Spotlight experimental trial, create `manual_label/flipped` and `manual_label/not_flipped` subdirectories within it, and copy each extracted frame into the appropriate folder (see the layout sketch below).
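For reference, the expected layout might look like the following. The trial directory and frame file names are hypothetical; only the `manual_label/flipped` and `manual_label/not_flipped` subdirectories are required by the training script.

```text
<extracted_frames_dir>/<some_trial>/   # hypothetical trial directory
├── frame_000000.png
├── frame_000001.png
├── ...
└── manual_label/
    ├── flipped/                       # copies of frames in which the fly is flipped
    │   └── frame_000123.png
    └── not_flipped/                   # copies of frames in which the fly is upright
        └── frame_000000.png
```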
- Run `python src/poseforge/spotlight/scripts/detect_flipped_flies.py`
  - This script generates a label file indicating whether the fly is flipped in each extracted Spotlight behavioral frame. Frames in which the fly is flipped are excluded in subsequent steps.
- Run `python src/poseforge/style_transfer/scripts/extract_dataset.py`
  - This script randomly extracts subsets of NeuroMechFly rendering frames and Spotlight recording frames for training the style transfer model.
- Run `bash src/poseforge/style_transfer/scripts/train_cut_model_caller.sh`
  - This script trains a Contrastive Unpaired Translation (CUT) model using a demonstrative set of hyperparameters.
  - The shell script calls the CLI from `src/poseforge/style_transfer/scripts/train_cut_model.py`. The advantage of wrapping the Python CLI in a shell script is that hyperparameters can be changed simply by passing them to the training script from the shell script(s), instead of maintaining multiple copies of the Python training script. This is handy for hyperparameter tuning.
  - Hyperparameters can be selected by training many models with different hyperparameters on a cluster (e.g. SCITAS). See `scripts_on_cluster/style_transfer_training/` for an example pipeline that machine-generates a batch of `*.run` scripts that can be submitted to the Slurm scheduler (a minimal sketch of the idea is shown after this list item).
  - In the training procedure, we use Weights and Biases to simplify monitoring the training runs and visualizing their results.
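For orientation, the sketch below shows what machine-generating Slurm `*.run` scripts over a hyperparameter grid can look like. It is an illustrative assumption, not the repository's actual generator: the hyperparameter names, `#SBATCH` settings, and CLI flags are hypothetical, so check `scripts_on_cluster/style_transfer_training/` and `train_cut_model.py -h` for the real interface.

```python
# Illustrative sketch only: write one Slurm *.run script per hyperparameter
# combination. Flag names and #SBATCH settings are hypothetical.
from itertools import product
from pathlib import Path

out_dir = Path("generated_runs")
out_dir.mkdir(exist_ok=True)

ngf_values = [32, 64]           # hypothetical generator width settings
lambda_gan_values = [0.1, 1.0]  # hypothetical GAN loss weights

for ngf, lambda_gan in product(ngf_values, lambda_gan_values):
    name = f"ngf{ngf}_lambGAN{lambda_gan}"
    script = (
        "#!/bin/bash\n"
        f"#SBATCH --job-name={name}\n"
        "#SBATCH --time=24:00:00\n"
        "#SBATCH --gres=gpu:1\n"
        "\n"
        "# Hypothetical flags; check `train_cut_model.py -h` for the real CLI.\n"
        "python src/poseforge/style_transfer/scripts/train_cut_model.py \\\n"
        f"    --name {name} --ngf {ngf} --lambda_GAN {lambda_gan}\n"
    )
    run_path = out_dir / f"{name}.run"
    run_path.write_text(script)
    print(f"Wrote {run_path}")
```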
> [!NOTE]
> The following steps are only for evaluating trained models and selecting the best one(s). They should not be used during inference.
- Run `python src/poseforge/style_transfer/scripts/test_trained_models.py`
  - This script runs inference on a manually selected, representative set of simulation data, using checkpoints from different training stages of each training run (e.g. one checkpoint every 20 epochs).
  - The user must manually specify a set of simulation data to use for testing and a set of model checkpoints to test. To do so, edit the parameters in the `__main__` section of the script.
- Run `python src/poseforge/style_transfer/scripts/visualize_inference_results.py`
  - For each training run, this script merges the videos of its inference results at different stages of training into a single summary video for easier comparison. The original NeuroMechFly simulation rendering is also included in the summary video.
- Manually generate a `bulk_data/style_transfer/synthetic_output/summary_videos/quality_assessment/human_annotated_scores.csv` file with the following columns:
  - `run`: Name of the training run, e.g. "ngf32_netGsmallstylegan2_batsize4_lambGAN0.1"
  - `best_epoch`: Epoch number of the best model in this run
  - `score`: Human-annotated score for the best model (1-5, higher is better)
  - `note`: Optional note about the run
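For reference, a minimal `human_annotated_scores.csv` could look like the example below. The first run name is taken from the column description above; the second run name, the epoch numbers, the scores, and the notes are purely illustrative.

```csv
run,best_epoch,score,note
ngf32_netGsmallstylegan2_batsize4_lambGAN0.1,180,4,"crisp legs; slight background flicker"
ngf64_netGsmallstylegan2_batsize4_lambGAN1.0,120,2,"artifacts around the tether"
```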
- Run `python src/poseforge/style_transfer/scripts/visualize_human_annotated_scores.py`
  - This script generates visualizations intended to help refine the model hyperparameters and iteratively retrain the models.
- Run `python src/poseforge/style_transfer/scripts/run_inference.py`
  - This script uses a selected trained style transfer model to translate all NeuroMechFly rendering data into the domain of the Spotlight behavior recordings.
- Pre-shuffle the synthetic (and experimental) dataset using `python src/poseforge/pose/contrast/scripts/preextract_atomic_batches.py`. This will save small "atomic batches" of data that are always used together during training.
  - The Python file above is a CLI (run it with `-h` to see the help message). An example call of the CLI is included in the `__main__` section of the script. Alternatively, one can import the `extract_atomic_batches` function from this file and use it natively within Python (an example is included in the `__main__` section).
  - To run this on the SCITAS cluster (Jed), see `scripts_on_cluster/atomic_batch_extraction`.
> [!NOTE]
> The following step is only for pretraining the feature extractor via contrastive learning. It does not need to be rerun during production.
- Pretrain a ResNet18 feature extractor using `python src/poseforge/pose/contrast/scripts/run_contrastive_pretraining.py`.
  - The Python file above is a CLI (run it with `-h` to see the help message). An example call of the CLI is included in the `__main__` section of the script. Alternatively, invoke training natively within Python by uncommenting the example code in the `__main__` section.
  - To train the model on the SCITAS cluster (Kuma), see `scripts_on_cluster/contrastive_pretraining_training`.
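If the phrase "ResNet18 feature extractor" is unfamiliar, the minimal sketch below (not the repository's actual model code) shows the standard torchvision way of turning a ResNet18 into a backbone that emits a 512-dimensional embedding per frame.

```python
# Minimal sketch, not the repo's model definition: a ResNet18 backbone with the
# ImageNet classification head removed, so it outputs a 512-D feature per image.
import torch
import torchvision

backbone = torchvision.models.resnet18(weights=None)  # randomly initialized
backbone.fc = torch.nn.Identity()                     # drop the 1000-class head

dummy_frames = torch.randn(4, 3, 224, 224)            # batch of 4 RGB frames
with torch.no_grad():
    features = backbone(dummy_frames)
print(features.shape)  # torch.Size([4, 512])
```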
> [!NOTE]
> The following steps are only for sanity-checking and visualizing the outcome of the contrastive pretraining step above. They do not need to be rerun during production. At inference time, the feature extractor is used as part of the pose estimation model.
- Run inference using `python src/poseforge/pose/contrast/scripts/run_feature_extractor_inference.py`.
  - The Python file above is a CLI (run it with `-h` to see the help message). An example call of the CLI is included in the `__main__` section of the script. Alternatively, invoke inference natively within Python by uncommenting the example code in the `__main__` section.
  - To run inference on the SCITAS cluster (Kuma), see `scripts_on_cluster/contrastive_pretraining_inference`. Note that running inference on all data will produce 200–300 GB of output. For quick inspection, it probably suffices to run inference on only one trial of one fly (e.g. `fly5_trial005`, which is reserved for testing).
- Run `python src/poseforge/pose/contrast/scripts/visualize_latents.py` to generate videos showing the latent-space trajectories of selected behavior snippets.
> [!NOTE]
> The following steps are only for training the model and visualizing its performance on synthetic data. They do not need to be rerun during production.
- Train the 3D keypoint detection model using `python src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_training.py`.
  - This is a CLI (run `python run_keypoints3d_training.py -h` to see usage). However, the `__main__` section of this script also includes a commented-out example of how to run training directly within Python.
  - See `scripts_on_cluster/keypoints3d_training/` for running on the SCITAS cluster.
- Visualize the performance of the model on synthetic data using `python src/poseforge/pose/keypoints3d/scripts/test_keypoints3d_models.py`. Note that you must select a particular model checkpoint file; it does not necessarily have to be the final model after the last epoch (inspect the validation loss to help decide which epoch to use).
- Run inference on Spotlight data by running `python src/poseforge/pose/keypoints3d/scripts/run_keypoints3d_inference.py`. This script runs prediction using the model state at the end of every other epoch. Combined with the next step, this is meant to help select the best checkpoint to use in production.
- Optionally, if you wish to visualize the output of the 3D keypoint detection model, run `python src/poseforge/pose/keypoints3d/scripts/visualize_production_keypoints3d.py`. Use the output to decide which checkpoint to use for production-time inference.
- Run inverse kinematics by running `python src/poseforge/pose/keypoints3d/scripts/run_inverse_kinematics.py`. After inferring joint angles via IK, this script also runs forward kinematics to determine a new set of body-size-constrained 3D keypoint positions.
> [!NOTE]
> The following step is only for training the model; it does not have to be rerun in production.
- Train the model: `python src/poseforge/pose/bodyseg/scripts/run_bodyseg_training.py`
  - See `scripts_on_cluster/bodyseg_training` for running on the SCITAS cluster.
- Run inference using the trained model: `python src/poseforge/pose/bodyseg/scripts/run_bodyseg_inference.py`
  - Note: you must first select a checkpoint to use (for example, by inspecting the logs). Specify the checkpoint in the `if __name__ == "__main__"` section of this script.
- Optionally, visualize the results using `python src/poseforge/pose/bodyseg/scripts/visualize_bodyseg_predictions.py`. Similarly, you must specify a checkpoint to use; see the end of this script.
- Run `python -u src/poseforge/spotlight/scripts/map_segmentation_to_muscle.py`. This will generate an output H5 file as well as debug plots and a time series plot showing kinematics and muscle activation.
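To check what was written to the output H5 file, you can list its contents with `h5py`, as in the sketch below. The file path is hypothetical; point it at the file produced by `map_segmentation_to_muscle.py`.

```python
import h5py

# Hypothetical path; substitute the H5 file written by map_segmentation_to_muscle.py.
output_path = "path/to/map_segmentation_to_muscle_output.h5"

with h5py.File(output_path, "r") as f:
    # Print every group and dataset, with dataset shapes and dtypes.
    def describe(name, obj):
        if isinstance(obj, h5py.Dataset):
            print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")
        else:
            print(f"{name}/")

    f.visititems(describe)
```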