- Environment
- Download, extract, and generate metadata for datasets
- Reproducing Paper Results
- Additional Support/Issues?
- Citation
Environment
We use Miniconda to manage the environment. Our Python version is 3.11.5. To create the environment, run the following command:
conda env create -f environment.yml -n mtl-group-robustness-env
To activate the environment, run the following command:
conda activate mtl-group-robustness-env
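As an optional sanity check after activating the environment, the sketch below compares the running interpreter with the Python version stated above (3.11.5). The helper name check_python_version is hypothetical and is not part of this repository. Passing a shorter tuple such as (3, 11) checks only the major and minor version.

```python
import sys


def check_python_version(expected=(3, 11, 5), actual=None):
    """Return True if the interpreter version starts with `expected`.

    Hypothetical helper (not part of this repository). `expected` may be
    (major, minor) or (major, minor, micro); only that many components
    are compared.
    """
    if actual is None:
        actual = sys.version_info[:3]
    n = len(expected)
    return tuple(actual[:n]) == tuple(expected)


if __name__ == "__main__":
    found = ".".join(str(part) for part in sys.version_info[:3])
    if check_python_version():
        print(f"Python {found} matches the expected 3.11.5")
    else:
        print(f"Warning: expected Python 3.11.5, found {found}")
```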
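Download, extract, and generate metadata for datasets
To download, extract, and format a dataset as expected by the code, run the following script, replacing dataset_name with the name of the dataset. The data and metadata are stored in the data folder, which already contains the civilcomments-small dataset.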
python3 ./src/setup_datasets.py dataset_name --download --data_path data
Reproducing Paper Results
The ./src/hparams.yaml file contains the best hyperparameters for each method on all five datasets. To get started, run the following command to generate the training scripts with these hyperparameters (shown here for the waterbirds dataset and the erm_mt_l1 method):
python3 ./src/generate_hyper_search_scripts.py --dataset waterbirds --method erm_mt_l1
This creates a .txt file in the hparams_files folder containing the Python training commands for five seeds, along with an executable bash script in the scripts folder. To start training, run the following command:
sbatch ./scripts/train_waterbirds_erm_mt_l1_hp.sh
This submits the job to SLURM and stores the best results for each run as a JSON file in the models_params folder.
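The sbatch command above assumes access to a SLURM cluster. On a single machine, one option is to run the commands from the generated .txt file in hparams_files one after another. The sketch below assumes, without confirmation from the repository, that this file holds one complete training command per line; the helper names parse_commands, read_commands, and run_commands are hypothetical.

```python
import shlex
import subprocess
from pathlib import Path


def parse_commands(text):
    """Return the non-empty, non-comment lines of `text` as commands.

    Assumes (hypothetically) one complete training command per line.
    """
    commands = []
    for line in text.splitlines():
        line = line.strip()
        if line and not line.startswith("#"):
            commands.append(line)
    return commands


def read_commands(path):
    """Read and parse a generated commands file from hparams_files."""
    return parse_commands(Path(path).read_text())


def run_commands(commands, dry_run=True):
    """Split each command into argv form and, unless dry_run, run them in order.

    check=True stops at the first command (seed) that exits with an error.
    """
    argvs = [shlex.split(cmd) for cmd in commands]
    if not dry_run:
        for argv in argvs:
            subprocess.run(argv, check=True)
    return argvs
```

For example, `run_commands(read_commands("hparams_files/<generated_file>.txt"), dry_run=False)` would train the seeds sequentially. Splitting with shlex avoids invoking a shell, so lines that rely on shell features such as `&&` or output redirection would need `subprocess.run(cmd, shell=True)` instead.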
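Because each configuration is trained with five seeds, a common way to summarize the JSON results in models_params is the mean and standard deviation across runs. The sketch below does this under stated assumptions: each file is a flat JSON object, and the metric key (for example a hypothetical "worst_group_acc") must be replaced with whatever key the training script actually writes. The helper names are hypothetical and not part of this repository.

```python
import json
import statistics
from pathlib import Path


def load_records(results_dir, pattern="*.json"):
    """Load every JSON file matching `pattern` in `results_dir`, sorted by path."""
    records = []
    for path in sorted(Path(results_dir).glob(pattern)):
        with path.open() as f:
            records.append(json.load(f))
    return records


def collect_metric(records, metric):
    """Extract `metric` from each record, skipping records that lack the key.

    Assumes flat JSON objects; nested result files would need a different lookup.
    """
    return [float(r[metric]) for r in records if isinstance(r, dict) and metric in r]


def summarize(values):
    """Return (mean, sample standard deviation, number of runs)."""
    if not values:
        raise ValueError("no values to summarize")
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    return statistics.mean(values), std, len(values)
```

For example, `summarize(collect_metric(load_records("models_params"), "worst_group_acc"))` returns the mean, standard deviation, and run count. If results are stored in per-run subfolders, `pattern="**/*.json"` searches recursively. Records missing the key are skipped rather than treated as errors, so the returned count shows how many runs actually contributed.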
Additional Support/Issues?
If you face any issues with our code or with reproducing our results, please raise a GitHub issue or contact Atharva Kulkarni.
Citation
@article{kulkarni2024multitask,
  title={Multitask Learning Can Improve Worst-Group Outcomes},
  author={Atharva Kulkarni and Lucio M. Dery and Amrith Setlur and Aditi Raghunathan and Ameet Talwalkar and Graham Neubig},
  journal={Transactions on Machine Learning Research},
  issn={2835-8856},
  year={2024},
  url={https://openreview.net/forum?id=sPlhAIp6mk},
  note={}
}