diff --git a/detectron2/engine/train_loop.py b/detectron2/engine/train_loop.py
index 738a69de94..06731b7565 100644
--- a/detectron2/engine/train_loop.py
+++ b/detectron2/engine/train_loop.py
@@ -469,9 +469,14 @@ def __init__(
         )
 
         if grad_scaler is None:
-            from torch.cuda.amp import GradScaler
+            if torch.__version__ >= "2.4.0":
+                from torch.amp import GradScaler
 
-            grad_scaler = GradScaler()
+                grad_scaler = GradScaler("cuda")
+            else:
+                from torch.cuda.amp import GradScaler
+
+                grad_scaler = GradScaler()
         self.grad_scaler = grad_scaler
         self.precision = precision
         self.log_grad_scaler = log_grad_scaler
@@ -482,7 +487,10 @@ def run_step(self):
         """
         assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
         assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
-        from torch.cuda.amp import autocast
+        if torch.__version__ >= "2.4.0":
+            from torch.amp import autocast
+        else:
+            from torch.cuda.amp import autocast
 
         start = time.perf_counter()
         data = next(self._data_loader_iter)
@@ -490,13 +498,22 @@ def run_step(self):
 
         if self.zero_grad_before_forward:
             self.optimizer.zero_grad()
-        with autocast(dtype=self.precision):
-            loss_dict = self.model(data)
-            if isinstance(loss_dict, torch.Tensor):
-                losses = loss_dict
-                loss_dict = {"total_loss": loss_dict}
-            else:
-                losses = sum(loss_dict.values())
+        if torch.__version__ >= "2.4.0":
+            with autocast("cuda", dtype=self.precision):
+                loss_dict = self.model(data)
+                if isinstance(loss_dict, torch.Tensor):
+                    losses = loss_dict
+                    loss_dict = {"total_loss": loss_dict}
+                else:
+                    losses = sum(loss_dict.values())
+        else:
+            with autocast(dtype=self.precision):
+                loss_dict = self.model(data)
+                if isinstance(loss_dict, torch.Tensor):
+                    losses = loss_dict
+                    loss_dict = {"total_loss": loss_dict}
+                else:
+                    losses = sum(loss_dict.values())
 
         if not self.zero_grad_before_forward:
             self.optimizer.zero_grad()
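
A minimal sketch of the same version gate factored into module-level helpers. The helper names (`make_grad_scaler`, `cuda_autocast`) are illustrative only and not part of detectron2, and the sketch assumes `packaging` is installed; its `Version` comparison avoids the pitfall of comparing `torch.__version__` as a string, where e.g. `"2.10.0" >= "2.4.0"` is False lexicographically:

```python
# Sketch only: hypothetical helpers mirroring the version gate in this diff,
# not part of detectron2. Assumes the `packaging` package is available.
import torch
from packaging.version import Version

# True on PyTorch 2.4+, where the device-agnostic torch.amp namespace is
# preferred over the deprecated torch.cuda.amp one.
_TORCH_2_4 = Version(torch.__version__.split("+")[0]) >= Version("2.4.0")


def make_grad_scaler():
    """Return a CUDA GradScaler from whichever AMP namespace this torch provides."""
    if _TORCH_2_4:
        from torch.amp import GradScaler

        return GradScaler("cuda")
    from torch.cuda.amp import GradScaler

    return GradScaler()


def cuda_autocast(dtype):
    """Return a CUDA autocast context manager for the given compute dtype."""
    if _TORCH_2_4:
        from torch.amp import autocast

        return autocast("cuda", dtype=dtype)
    from torch.cuda.amp import autocast

    return autocast(dtype=dtype)
```

With helpers like these, `run_step` could keep a single `with cuda_autocast(self.precision):` block instead of duplicating the loss-computation code in both branches of the version check.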