A PyTorch-based GAN framework for bidirectional generation between neuroimaging data (fMRI, DTI) and behavioral measures, with a focus on robust training and reproducibility.
- Image → Behavior: Generate behavioral predictions from neuroimaging data (single or multiple modalities, voxelwise or ROI-wise)
- Behavior → Image: Generate synthetic neuroimaging data from behavioral scores or batteries of scores
- Optional self-attention mechanisms for capturing long-range dependencies
- Standard GAN, LSGAN, or WGAN loss functions (see the sketch after this list)
- Flexible network architecture with tunable hidden layers and other components
- Partial support for conditional generation (cGAN is a work in progress)
- Nested cross-validation with Optuna (Bayesian) hyperparameter optimization
- Learning rate scheduling with warmup periods and plateau reduction
- Early stopping with independently configurable patience for the generator and discriminator
- Gradient penalty and/or clipping for stability
- Multi-GPU support with parallel fold processing (falls back to CPU if no GPU is available)
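All three loss variants follow their standard formulations. The sketch below is illustrative only: the function names are assumptions, not the repo's API, and the canonical implementations live in `metrics.py`:

```python
import torch
import torch.nn.functional as F

def generator_loss(fake_scores: torch.Tensor, gan_type: str) -> torch.Tensor:
    """Generator objective for the three supported variants (illustrative names)."""
    if gan_type == "gan":    # non-saturating BCE loss
        return F.binary_cross_entropy_with_logits(
            fake_scores, torch.ones_like(fake_scores))
    if gan_type == "lsgan":  # least-squares loss (Mao et al., 2017)
        return torch.mean((fake_scores - 1.0) ** 2)
    if gan_type == "wgan":   # maximize the critic's score on fakes
        return -fake_scores.mean()
    raise ValueError(f"unknown gan_type: {gan_type}")

def discriminator_loss(real_scores, fake_scores, gan_type: str) -> torch.Tensor:
    """Matching discriminator/critic objective."""
    if gan_type == "gan":
        return (F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
                + F.binary_cross_entropy_with_logits(fake_scores, torch.zeros_like(fake_scores)))
    if gan_type == "lsgan":
        return 0.5 * (torch.mean((real_scores - 1.0) ** 2) + torch.mean(fake_scores ** 2))
    if gan_type == "wgan":   # negated Wasserstein estimate, so it can be minimized
        return fake_scores.mean() - real_scores.mean()
    raise ValueError(f"unknown gan_type: {gan_type}")
```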
```
pip install -r requirements.txt
```

Edit `gan_settings.py` before running!

```
python gan_train.py
```

```
├── 🎯 Core Files
│   ├── gan_settings.py                        # Global settings and hyperparameters (start here!)
│   ├── gan_train.py                           # Main training script with nested CV
│   ├── gan_arch.py                            # GAN architecture definitions used by gan_train.py
│   ├── gan_eval.py                            # Standalone evaluation script (not used during training)
│   └── gan_trainer.py                         # Training orchestration and management
│
├── 🧮 Data Processing
│   ├── data_utils.py                          # Data loading and preprocessing
│   ├── process_utils.py                       # GPU/device management and process orchestration
│   └── concat_mat_beh.py                      # MAT and behavioral data handling
│
├── 🧠 Training Support
│   ├── training_utils.py                      # Early stopping, LR scheduling
│   ├── model_utils.py                         # Model initialization, checkpointing, and validation
│   ├── metrics.py                             # Loss functions used by gan_trainer.py
│   ├── visualization.py                       # Plotting and visualization tools
│   └── seed_manager.py                        # Random seed management
│
├── 🎛️ Hyperparameter Optimization
│   ├── hyperparameter_optimization.py         # Optuna optimization used by gan_train.py
│   └── test_hyperparameter_optimization.py    # Optimization tests
│
├── 🧪 Testing
│   ├── test_gan.py                            # Core GAN functionality tests
│   ├── test_data_processing.py                # Data pipeline tests
│   └── test_gan_dimensions.py                 # Input/output dimension tests
│
├── 🧹 Utility
│   └── requirements.txt                       # Package dependencies
│
└── 📂 outputs/                                # Generated during training
    ├── cache/                                 # Intermediate results
    │   ├── data_cache.pkl                     # Preprocessed data
    │   └── transform_info.pkl                 # Transform metadata
    │
    ├── checkpoints/                           # Model states
    │   ├── outer_fold_{n}/
    │   │   ├── inner_fold_{m}/
    │   │   │   ├── generator_checkpoint.pt
    │   │   │   └── discriminator_checkpoint.pt
    │   │   └── best_model.pt
    │   └── best_model.pt
    │
    ├── logs/                                  # Training records
    │   ├── random_seed.json
    │   ├── settings.json
    │   └── training_log.txt
    │
    ├── plots/                                 # Visualizations
    │   ├── outer_fold_{n}/
    │   │   ├── loss_curves/
    │   │   └── predictions/
    │   └── final_results/
    │
    └── results/                               # Evaluation outputs
        ├── outer_fold_{n}/
        │   ├── predictions/
        │   │   ├── behavior_to_image/         # When MODE="behavior_to_image"
        │   │   │   ├── batch_{k}/
        │   │   │   │   ├── comparison_subject{i}.png
        │   │   │   │   └── generated_subject{i}.npy
        │   │   │   └── average_images/
        │   │   │       ├── average_comparison_sagittal.png
        │   │   │       ├── average_comparison_coronal.png
        │   │   │       └── average_comparison_axial.png
        │   │   └── image_to_behavior/         # When MODE="image_to_behavior"
        │   │       ├── predicted_scores.csv
        │   │       └── score_comparisons/
        │   │           ├── scatter_plots.png
        │   │           └── error_dist.png
        │   ├── true_data_fold{n}.npy
        │   └── generated_data_fold{n}.npy
        └── aggregate_results/
            ├── combined_true_data.npy
            └── combined_generated_data.npy
```
- 🏗️ Generator & Discriminator Foundation
  - Shared `BaseGANModule` with configurable dimensions (see the sketch below)
  - Normalization (batch/layer) and weight-initialization strategies selected dynamically based on the loss function
  - Dropout regularization (configurable rate)
  - Other configurable components (e.g., depth, activation functions, GAN type)
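A minimal sketch of what such a shared base module can look like; the class signature and the normalization rule are assumptions here, and `gan_arch.py` holds the real definition:

```python
import torch.nn as nn

class BaseGANModule(nn.Module):
    """Illustrative shared base for generator and discriminator MLPs."""

    def __init__(self, in_dim: int, hidden_dims: list, out_dim: int,
                 gan_type: str = "wgan", dropout: float = 0.2):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            # A WGAN-GP critic should avoid batch norm (it couples samples and
            # breaks the per-sample gradient penalty), so use layer norm there.
            layers.append(nn.LayerNorm(h) if gan_type == "wgan" else nn.BatchNorm1d(h))
            layers.append(nn.LeakyReLU(0.2))
            layers.append(nn.Dropout(dropout))
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
```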
- 🔍 Multi-Head Self-Attention
  - 4-head scaled dot-product attention (see the sketch below)
  - Learnable attention strength (γ parameter)
  - Strategic placement at network intervals
  - Dimension-scaled transformations (d⁻⁰·⁵)
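A compact illustration of such a gated attention block; the class name and the residual gating are assumptions, and PyTorch's `nn.MultiheadAttention` applies the d⁻⁰·⁵ scaling internally:

```python
import torch
import torch.nn as nn

class GatedSelfAttention(nn.Module):
    """Illustrative 4-head self-attention with a learnable γ gate."""

    def __init__(self, dim: int, heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        # γ starts at 0 so the block is initially an identity mapping;
        # the attention strength is then learned during training.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); attention logits are scaled by d**-0.5
        out, _ = self.attn(x, x, x)
        return x + self.gamma * out
```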
- 🛑 Early Stopping
  - Generator: discriminator validity ↑ (behavior_to_image); MSE ↓ (image_to_behavior)
  - Discriminator: Wasserstein distance / accuracy ↑
  - Auto-checkpoint at peak performance
  - Configurable minimum epochs and change thresholds
  - Validation metrics depend on the mode (see the sketch below):
    - behavior_to_image: uses the discriminator's real/fake validity scores during training and MS-SSIM for final evaluation
    - image_to_behavior: uses MSE between predicted and actual behavioral scores
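A minimal sketch of per-network early stopping with a minimum-epoch guard; names and defaults are illustrative, with the real logic in `training_utils.py`:

```python
class EarlyStopper:
    """Stop when a validation metric fails to improve for `patience` epochs."""

    def __init__(self, patience: int, min_epochs: int = 0,
                 min_delta: float = 0.0, mode: str = "min"):
        self.patience, self.min_epochs, self.min_delta = patience, min_epochs, min_delta
        self.sign = 1.0 if mode == "min" else -1.0  # "min" for MSE, "max" for validity
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, epoch: int, metric: float) -> bool:
        """Return True when training should stop."""
        value = self.sign * metric
        if value < self.best - self.min_delta:
            self.best, self.bad_epochs = value, 0  # improvement: checkpoint here
        else:
            self.bad_epochs += 1
        return epoch >= self.min_epochs and self.bad_epochs >= self.patience

# One stopper per network, so the generator and discriminator halt independently:
g_stop = EarlyStopper(patience=20, min_epochs=50, mode="min")  # e.g. validation MSE
d_stop = EarlyStopper(patience=10, min_epochs=50, mode="max")  # e.g. critic accuracy
```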
- 📉 Advanced Scheduling
  - Warmup → plateau reduction (see the sketch below)
  - Separate optimizer/scheduler settings for attention parameters
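One way to chain warmup into plateau reduction with stock PyTorch schedulers; the linear warmup and the constants are assumptions, with the actual policy in `training_utils.py`:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

# Stand-in model and optimizer; the real ones come from gan_arch.py.
model = torch.nn.Linear(10, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

warmup_epochs = 10
# Linear warmup: ramp the LR from 1/warmup_epochs of its value up to full.
warmup = LambdaLR(opt, lambda e: min(1.0, (e + 1) / warmup_epochs))
# After warmup, halve the LR whenever the validation loss plateaus.
plateau = ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=5)

for epoch in range(100):
    val_loss = 1.0 / (epoch + 1)  # placeholder for a real validation loss
    if epoch < warmup_epochs:
        warmup.step()
    else:
        plateau.step(val_loss)
```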
- 🔒 Training Safeguards
  - WGAN gradient penalty (see the sketch below)
  - Norm-based gradient clipping
  - Hybrid metric scoring (MSE + MS-SSIM)
  - Multiple generations per input for stable point estimates of performance
  - Nested cross-validation for hyperparameter optimization
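The penalty follows the standard WGAN-GP formulation (Gulrajani et al., 2017); this sketch assumes 2-D (subjects × features) inputs and is not necessarily the exact code in `metrics.py`:

```python
import torch

def gradient_penalty(critic, real, fake, lam: float = 10.0) -> torch.Tensor:
    """WGAN-GP: penalize critic gradient norms that deviate from 1."""
    # Random per-sample interpolation between real and generated batches.
    alpha = torch.rand(real.size(0), 1, device=real.device).expand_as(real)
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interp)
    grads, = torch.autograd.grad(
        outputs=scores, inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True)
    # Norm per sample, then squared distance from the Lipschitz target of 1.
    return lam * ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
```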
- 📊 Performance Tracking
  - MSE, RMSE, MAE → lower is better
  - MS-SSIM → higher is better
  - Multi-prediction correlation handling
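For concreteness, the error metrics and one plausible reading of multi-prediction correlation handling (average several generations per subject before correlating) reduce to a few lines; `metrics.py` is canonical:

```python
import torch

def regression_metrics(pred: torch.Tensor, true: torch.Tensor) -> dict:
    """Lower-is-better error metrics for image_to_behavior evaluation."""
    err = pred - true
    mse = err.pow(2).mean()
    return {"mse": mse.item(), "rmse": mse.sqrt().item(), "mae": err.abs().mean().item()}

def multi_prediction_corr(preds: torch.Tensor, true: torch.Tensor) -> float:
    """Pearson r between subject-wise mean predictions and true scores.

    preds: (n_generations, n_subjects); true: (n_subjects,).
    """
    mean_pred = preds.mean(dim=0)
    return torch.corrcoef(torch.stack([mean_pred, true]))[0, 1].item()
```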
- 📈 Analysis Pipeline
  - Cross-validation aggregation
  - Statistical analysis (correlation, MSE, MAE)
  - Modality-specific evaluations
  - Automatic downsampling of brain images
  - Masking out empty voxels across the analyzed group of subjects (see the sketch below)
  - Reconstruction of outputs in the original space
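A sketch of the group-masking and reconstruction steps, assuming flattened (subjects × voxels) arrays; the function names are illustrative:

```python
import numpy as np

def fit_group_mask(images: np.ndarray) -> np.ndarray:
    """Boolean mask of voxels that are nonzero in at least one subject.

    images: (n_subjects, n_voxels)
    """
    return np.any(images != 0, axis=0)

def reconstruct(flat: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Place masked model outputs back into the original voxel space."""
    full = np.zeros(mask.shape, dtype=flat.dtype)
    full[mask] = flat
    return full

images = np.random.rand(20, 1000) * (np.random.rand(1000) > 0.5)  # toy data
mask = fit_group_mask(images)
masked = images[:, mask]                 # features actually fed to the GAN
restored = reconstruct(masked[0], mask)  # back to the original 1000-voxel space
```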
- 📥 Input Processing
- Structured .mat parsing
- Automated normalization with robust scaling for behavioral scores
- Modality-based feature extraction
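A toy round-trip showing the kind of .mat parsing and robust behavioral scaling described above; the file and field names are hypothetical, with the real layout defined by `concat_mat_beh.py` and `data_utils.py`:

```python
import numpy as np
from scipy.io import loadmat, savemat
from sklearn.preprocessing import RobustScaler

# Write a toy file so the example is self-contained; field names are made up.
savemat("demo.mat", {"fmri": np.random.rand(20, 100),
                     "behavior": np.random.rand(20, 3)})

mat = loadmat("demo.mat")                         # structured .mat parsing
fmri = np.asarray(mat["fmri"], dtype=np.float32)  # (n_subjects, n_features)

# Robust scaling (median / IQR) tolerates outlying behavioral scores
# better than plain z-scoring.
behavior_scaled = RobustScaler().fit_transform(mat["behavior"])
```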
- 🎲 Randomization Control
- Global seed management
- Framework-wide determinism
- Consistent cross-validation
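A plausible shape for the global seeding routine, assuming the standard PyTorch/NumPy determinism switches; `seed_manager.py` defines the real one:

```python
import os
import random

import numpy as np
import torch

def set_global_seed(seed: int) -> None:
    """Seed every RNG the framework touches so CV folds are reproducible."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # no-op without a GPU
    # Trade some speed for deterministic cuDNN kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed(42)
```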