from os.path import basename

import torch
-from commode_utils.callback import UploadCheckpointCallback, PrintEpochResultCallback
+from commode_utils.callback import PrintEpochResultCallback, ModelCheckpointWithUpload
from omegaconf import DictConfig
from pytorch_lightning import seed_everything, Trainer, LightningModule, LightningDataModule
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
+from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger


@@ -18,15 +18,14 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
    wandb_logger = WandbLogger(project=f"{model_name} -- {dataset_name}", log_model=False, offline=config.log_offline)

    # define model checkpoint callback
-    checkpoint_callback = ModelCheckpoint(
+    checkpoint_callback = ModelCheckpointWithUpload(
        dirpath=wandb_logger.experiment.dir,
        filename="{epoch:02d}-val_loss={val/loss:.4f}",
        monitor="val/loss",
        every_n_epochs=params.save_every_epoch,
        save_top_k=-1,
        auto_insert_metric_name=False,
    )
-    upload_checkpoint_callback = UploadCheckpointCallback(wandb_logger.experiment.dir)
    # define early stopping callback
    early_stopping_callback = EarlyStopping(patience=params.patience, monitor="val/loss", verbose=True, mode="min")
    # define callback for printing intermediate result
@@ -48,7 +47,6 @@ def train(model: LightningModule, data_module: LightningDataModule, config: Dict
            lr_logger,
            early_stopping_callback,
            checkpoint_callback,
-            upload_checkpoint_callback,
            print_epoch_result_callback,
        ],
        resume_from_checkpoint=config.get("checkpoint", None),
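
For context, this change collapses the `ModelCheckpoint` + `UploadCheckpointCallback` pair into the single `ModelCheckpointWithUpload` callback from `commode_utils`, which both saves checkpoints and uploads them. Below is a minimal sketch of that pattern only, not the actual `commode_utils` implementation: the class name `CheckpointWithUploadSketch`, the `_save_checkpoint` override point (a private PyTorch Lightning hook whose signature varies across versions), and the `wandb.save` upload step are all assumptions for illustration.

```python
import os

import wandb
from pytorch_lightning.callbacks import ModelCheckpoint


class CheckpointWithUploadSketch(ModelCheckpoint):
    """Sketch: save checkpoints as ModelCheckpoint does, then upload each file.

    Hypothetical example; commode_utils' ModelCheckpointWithUpload may differ.
    """

    def _save_checkpoint(self, trainer, filepath: str) -> None:
        # Let the stock ModelCheckpoint write the checkpoint file first.
        super()._save_checkpoint(trainer, filepath)
        # Then push the saved file to the active W&B run, if any.
        if wandb.run is not None:
            wandb.save(filepath, base_path=os.path.dirname(filepath))
```

Folding the upload into the checkpoint callback keeps save and upload atomic per file, so the `Trainer`'s callback list needs one entry instead of two.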