🧠 brainGAN

A PyTorch-based GAN framework for bidirectional generation between neuroimaging data (fMRI, DTI) and behavioral measures, with a focus on robust training and reproducibility.

✨ Key Features

🔄 Bidirectional Generation

  • Image → Behavior: Generate behavioral predictions from neuroimaging data (single or multiple modalities, voxelwise or ROI-wise)
  • Behavior → Image: Generate synthetic neuroimaging data from behavioral scores or batteries of scores

๐Ÿ—๏ธ Architecture

  • Optional self-attention mechanisms for capturing long-range dependencies
  • Standard GAN, LSGAN, and/or WGAN loss functions
  • Flexible network architecture with optimizable hidden layers and other components
  • Some support for conditional generation (cGAN is a work in progress)

🚂 Robust Training Pipeline

  • Nested cross-validation with Optuna (Bayesian) hyperparameter optimization
  • Smart learning rate scheduling with warmup periods and plateau reduction
  • Early stopping with independently customizable patience for the generator and discriminator
  • Gradient penalty and/or clipping for stability
  • Multi-GPU support with parallel fold processing (falls back to CPU if no GPU is available)

๐Ÿ› ๏ธ Quick Start

Installation

pip install -r requirements.txt

โš ๏ธ Important

Edit gan_settings.py before running!

Usage

python gan_train.py

๐Ÿ“ Project Structure

├── 🎯 Core Files
│   ├── gan_settings.py     # Global settings and hyperparameters (start here!)
│   ├── gan_train.py        # Main training script with nested CV
│   ├── gan_arch.py         # GAN architecture definitions used by gan_train.py
│   ├── gan_eval.py         # Standalone evaluation script (not used during training)
│   └── gan_trainer.py      # Training orchestration and management
│
├── 🧮 Data Processing
│   ├── data_utils.py       # Data loading and preprocessing
│   ├── process_utils.py    # GPU/device management and process orchestration
│   └── concat_mat_beh.py   # MAT and behavioral data handling
│
├── 🔧 Training Support
│   ├── training_utils.py   # Early stopping, LR scheduling
│   ├── model_utils.py      # Model initialization, checkpointing, and validation
│   ├── metrics.py          # Loss functions used by gan_trainer.py
│   ├── visualization.py    # Plotting and visualization tools
│   └── seed_manager.py     # Random seed management
│
├── 🎛️ Hyperparameter Optimization
│   ├── hyperparameter_optimization.py  # Optuna optimization used by gan_train.py
│   └── test_hyperparameter_optimization.py  # Optimization tests
│
├── 🧪 Testing
│   ├── test_gan.py              # Core GAN functionality tests
│   ├── test_data_processing.py  # Data pipeline tests
│   └── test_gan_dimensions.py   # Input/output dimension tests
│
├── 🧹 Utility
│   └── requirements.txt    # Package dependencies
│
└── 📂 outputs/             # Generated during training
    ├── cache/             # Intermediate results
    │   ├── data_cache.pkl     # Preprocessed data
    │   └── transform_info.pkl # Transform metadata
    │
    ├── checkpoints/      # Model states
    │   ├── outer_fold_{n}/
    │   │   ├── inner_fold_{m}/
    │   │   │   ├── generator_checkpoint.pt
    │   │   │   └── discriminator_checkpoint.pt
    │   │   └── best_model.pt
    │   └── best_model.pt
    │
    ├── logs/            # Training records
    │   ├── random_seed.json
    │   ├── settings.json
    │   └── training_log.txt
    │
    ├── plots/          # Visualizations
    │   ├── outer_fold_{n}/
    │   │   ├── loss_curves/
    │   │   └── predictions/
    │   └── final_results/
    │
    └── results/        # Evaluation outputs
        ├── outer_fold_{n}/
        │   ├── predictions/
        │   │   ├── behavior_to_image/    # When MODE="behavior_to_image"
        │   │   │   ├── batch_{k}/
        │   │   │   │   ├── comparison_subject{i}.png
        │   │   │   │   └── generated_subject{i}.npy
        │   │   │   └── average_images/
        │   │   │       ├── average_comparison_sagittal.png
        │   │   │       ├── average_comparison_coronal.png
        │   │   │       └── average_comparison_axial.png
        │   │   └── image_to_behavior/    # When MODE="image_to_behavior"
        │   │       ├── predicted_scores.csv
        │   │       └── score_comparisons/
        │   │           ├── scatter_plots.png
        │   │           └── error_dist.png
        │   ├── true_data_fold{n}.npy
        │   └── generated_data_fold{n}.npy
        └── aggregate_results/
            ├── combined_true_data.npy
            └── combined_generated_data.npy

1๏ธโƒฃ GAN Architecture (gan_arch.py)

Base Architecture
  • 🏗️ Generator & Discriminator Foundation
    • Shared BaseGANModule with configurable dimensions
    • Smart normalization selection (batch/layer) and dynamic weight initialization strategies depending on loss function
    • Dropout regularization (configurable rate)
    • Other configurable components (e.g., depth, activation functions, GAN type)
Attention Mechanism
  • 🔍 Multi-Head Self-Attention
    • 4-head scaled dot-product attention
    • Learnable attention strength (γ parameter)
    • Strategic placement at network intervals
    • Dimension-scaled transformations (d⁻⁰·⁵)
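
For orientation, here is a minimal PyTorch sketch of an attention block of this kind, assuming hidden features are grouped into tokens; the class name SelfAttentionBlock and the γ parameterization are illustrative and may not match gan_arch.py exactly.

```python
import torch
import torch.nn as nn

class SelfAttentionBlock(nn.Module):
    """Hypothetical 4-head self-attention with a learnable strength parameter (gamma)."""
    def __init__(self, dim, n_heads=4):
        super().__init__()
        # nn.MultiheadAttention uses scaled dot-product attention (scores scaled by d**-0.5)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=n_heads, batch_first=True)
        # gamma starts at 0 so the residual path dominates early in training
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # x: (batch, tokens, dim) -- e.g., hidden features chunked into tokens
        attended, _ = self.attn(x, x, x)
        return x + self.gamma * attended

block = SelfAttentionBlock(dim=256)
out = block(torch.randn(8, 16, 256))   # 8 samples, 16 feature tokens of width 256
```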

2๏ธโƒฃ Training Pipeline (training_utils.py)

Training Management
  • 🛑 Early Stopping
    Generator:     Discriminator Validity ↑ (behavior_to_image)
                   MSE ↓ (image_to_behavior)
    Discriminator: W-distance/accuracy ↑
    • Auto-checkpoint at peak performance
    • Configurable minimum epochs and change thresholds
    • Different validation metrics used based on mode:
      • behavior_to_image: Uses discriminator's real/fake validity scores during training, MS-SSIM for final evaluation
      • image_to_behavior: Uses MSE between predicted and actual behavioral scores
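
A minimal sketch of mode-aware early stopping with independent trackers for the generator and discriminator; EarlyStopper and its arguments are illustrative, not the exact training_utils.py API.

```python
class EarlyStopper:
    """Illustrative early stopping: tracks one metric; `mode` sets the improvement direction."""
    def __init__(self, patience=20, min_delta=1e-4, mode="min", min_epochs=50):
        self.patience, self.min_delta, self.mode, self.min_epochs = patience, min_delta, mode, min_epochs
        self.best = float("inf") if mode == "min" else float("-inf")
        self.bad_epochs = 0

    def step(self, value, epoch):
        improved = (value < self.best - self.min_delta) if self.mode == "min" \
                   else (value > self.best + self.min_delta)
        if improved:
            self.best, self.bad_epochs = value, 0
            return False                      # keep training; caller checkpoints here
        self.bad_epochs += 1
        return epoch >= self.min_epochs and self.bad_epochs >= self.patience

# Separate trackers so G and D can stop on different criteria and patience values
gen_stop = EarlyStopper(patience=30, mode="min")   # e.g., MSE in image_to_behavior mode
disc_stop = EarlyStopper(patience=15, mode="max")  # e.g., W-distance / accuracy
```
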
Learning Rate Control
  • 📈 Advanced Scheduling
    • Warmup → Plateau reduction
    • Independent attention-specific optimization
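
One way to realize the warmup-then-plateau schedule with standard PyTorch schedulers is sketched below; the epoch counts, base learning rate, and reduction factor are placeholders, not the repository's defaults.

```python
import torch
from torch.optim.lr_scheduler import LambdaLR, ReduceLROnPlateau

model = torch.nn.Linear(10, 10)                       # stand-in for the generator
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

warmup_epochs = 5
# Linear warmup: ramp the lr from 1/warmup_epochs of the base rate up to the full base rate
warmup = LambdaLR(optimizer, lr_lambda=lambda e: min(1.0, (e + 1) / warmup_epochs))
# After warmup: halve the lr whenever the validation metric plateaus
plateau = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=10)

for epoch in range(50):
    val_loss = 1.0 / (epoch + 1)                      # placeholder validation loss
    if epoch < warmup_epochs:
        warmup.step()
    else:
        plateau.step(val_loss)
```
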
Stability Measures
  • 🔒 Training Safeguards
    • WGAN gradient penalty
    • Norm-based gradient clipping
    • Hybrid metric scoring (MSE + MS-SSIM)
    • Multiple generations per input for point-estimates of performance
    • Nested cross-validation for hyperparameter optimization
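
For reference, a standard WGAN-GP gradient penalty sketch, assuming flattened 2-D (batch, features) tensors; the penalty weight of 10 and the clipping norm shown in the comments are illustrative, not necessarily the values used in gan_trainer.py.

```python
import torch

def gradient_penalty(discriminator, real, fake, device="cpu"):
    """Illustrative WGAN-GP penalty on interpolations between real and fake samples."""
    alpha = torch.rand(real.size(0), 1, device=device)           # per-sample mixing weights
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = discriminator(interp)
    grads = torch.autograd.grad(outputs=scores, inputs=interp,
                                grad_outputs=torch.ones_like(scores),
                                create_graph=True)[0]
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# During the discriminator update (names are illustrative):
# d_loss = fake_scores.mean() - real_scores.mean() + 10.0 * gradient_penalty(D, real, fake)
# torch.nn.utils.clip_grad_norm_(D.parameters(), max_norm=1.0)   # norm-based clipping
```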

3๏ธโƒฃ Evaluation System (gan_eval.py)

Metrics Suite
  • 📊 Performance Tracking
    MSE, RMSE, MAE    → Lower is better
    MS-SSIM           → Higher is better
    
    • Multi-prediction correlation handling
Results Processing
  • 📈 Analysis Pipeline
    • Cross-validation aggregation
    • Statistical analysis (correlation, MSE, MAE)
    • Modality-specific evaluations
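
A minimal sketch of the regression-style metrics used for image_to_behavior evaluation is shown below; MS-SSIM for generated images would typically come from a separate package (e.g., pytorch-msssim), which is an assumption here rather than a confirmed dependency.

```python
import numpy as np

def regression_metrics(y_true, y_pred):
    """Illustrative metric computation for predicted vs. true behavioral scores."""
    err = y_pred - y_true
    mse = np.mean(err ** 2)
    return {
        "MSE": mse,
        "RMSE": np.sqrt(mse),
        "MAE": np.mean(np.abs(err)),
        # Pearson correlation between predicted and true scores
        "r": np.corrcoef(y_true.ravel(), y_pred.ravel())[0, 1],
    }

print(regression_metrics(np.array([1.0, 2.0, 3.0]), np.array([1.1, 1.9, 3.2])))
```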

๐Ÿ” Data Processing

  • Automatic downsampling of brain images
  • Masking out empty voxels across the group of subjects analyzed
  • Reconstruction of outputs in original space
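
A NumPy sketch of group masking and reconstruction; variable names and volume sizes are placeholders for illustration only.

```python
import numpy as np

# Placeholder data: 5 subjects with tiny 4x4x4 volumes (real volumes would be far larger)
subject_volumes = [np.random.rand(4, 4, 4) for _ in range(5)]
for v in subject_volumes:
    v[:2, :, 0] = 0                      # simulate empty voxels shared across subjects

stack = np.stack(subject_volumes)        # (n_subjects, x, y, z)
mask = np.all(stack != 0, axis=0)        # keep voxels that are non-empty in every subject
features = stack[:, mask]                # (n_subjects, n_kept_voxels) -> flattened model input

# Reconstructing a generated sample back into the original volume space
generated_features = np.random.rand(mask.sum())   # stand-in for generator output
volume = np.zeros(mask.shape, dtype=np.float32)
volume[mask] = generated_features        # place generated values at the kept voxel locations
```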

1๏ธโƒฃ Data Integration

  • 📥 Input Processing
    • Structured .mat parsing
    • Automated normalization with robust scaling for behavioral scores
    • Modality-based feature extraction
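
A sketch of robust scaling for behavioral scores, here using scikit-learn's RobustScaler as an assumed stand-in for whatever data_utils.py actually implements; statistics are fit on training folds only to avoid leakage into held-out data.

```python
import numpy as np
from sklearn.preprocessing import RobustScaler

# Placeholder behavioral battery: 80 training and 20 held-out subjects, 6 measures each
train_scores = np.random.rand(80, 6)
test_scores = np.random.rand(20, 6)

scaler = RobustScaler()                        # centers on the median, scales by the IQR
train_scaled = scaler.fit_transform(train_scores)
test_scaled = scaler.transform(test_scores)    # reuse training-fold statistics on held-out data
```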

2๏ธโƒฃ Reproducibility

  • 🎲 Randomization Control
    • Global seed management
    • Framework-wide determinism
    • Consistent cross-validation
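
A typical global seeding routine of the kind seed_manager.py provides; the function name and exact set of calls here are illustrative.

```python
import os
import random
import numpy as np
import torch

def set_global_seed(seed: int = 42):
    """Illustrative framework-wide seeding across Python, NumPy, and PyTorch."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # Favor determinism over speed in cuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_global_seed(42)
```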

📫 Contact

[email protected]
