Add comments #1646

Open · wants to merge 2 commits into main
2 changes: 1 addition & 1 deletion mmengine/model/base_model/base_model.py
@@ -110,7 +110,7 @@ def train_step(self, data: Union[dict, tuple, list],
"""
# Enable automatic mixed precision training context.
with optim_wrapper.optim_context(self):
- data = self.data_preprocessor(data, True)
+ data = self.data_preprocessor(data, True)  # ! data preprocessing (subtract mean, divide by std)
losses = self._run_forward(data, mode='loss') # type: ignore
parsed_losses, log_vars = self.parse_losses(losses) # type: ignore
optim_wrapper.update_params(parsed_losses)
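For readers following the annotated line above: data_preprocessor is where per-batch normalization (the "subtract mean, divide by std" noted in the comment) happens before the forward pass. Below is a minimal, self-contained sketch of the same train_step flow; the SimpleModel class, tensors and mean/std values are made up for illustration and are not mmengine's actual implementation.

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    """Toy model mirroring the train_step flow shown in the diff (illustrative only)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3 * 8 * 8, 10)
        # Assumed normalization statistics; real values come from the data_preprocessor config.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def data_preprocessor(self, images):
        # Subtract the mean and divide by the std, as the annotated comment describes.
        return (images - self.mean) / self.std

    def train_step(self, images, labels, optimizer):
        images = self.data_preprocessor(images)              # data preprocessing
        logits = self.net(images.flatten(1))                 # forward in 'loss' mode
        loss = nn.functional.cross_entropy(logits, labels)   # parse_losses would sum a dict of losses
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()                                     # update_params
        return {'loss': loss.item()}

model = SimpleModel()
optim = torch.optim.SGD(model.parameters(), lr=0.01)
log_vars = model.train_step(torch.rand(4, 3, 8, 8), torch.randint(0, 10, (4,)), optim)
print(log_vars)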
2 changes: 1 addition & 1 deletion mmengine/runner/base_loop.py
@@ -23,7 +23,7 @@ def __init__(self, runner, dataloader: Union[DataLoader, Dict]) -> None:
# Determine whether or not different ranks use different seed.
diff_rank_seed = runner._randomness_cfg.get(
'diff_rank_seed', False)
- self.dataloader = runner.build_dataloader(
+ self.dataloader = runner.build_dataloader(  # build the dataloader
dataloader, seed=runner.seed, diff_rank_seed=diff_rank_seed)
else:
self.dataloader = dataloader
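The annotated line sits in a branch that accepts either a ready-made DataLoader or a config dict, and only builds the latter. A minimal sketch of that pattern, with a hypothetical build_dataloader helper standing in for runner.build_dataloader:

from typing import Dict, Union
import torch
from torch.utils.data import DataLoader, TensorDataset

def build_dataloader(cfg: Dict, seed: int = 0) -> DataLoader:
    # Hypothetical builder: a real runner would resolve dataset/sampler/collate_fn from the config.
    torch.manual_seed(seed)
    dataset = TensorDataset(torch.arange(cfg['num_samples'], dtype=torch.float32))
    return DataLoader(dataset, batch_size=cfg['batch_size'], shuffle=True)

def resolve_dataloader(dataloader: Union[DataLoader, Dict], seed: int = 0) -> DataLoader:
    if isinstance(dataloader, dict):
        return build_dataloader(dataloader, seed=seed)  # build the dataloader from config
    return dataloader                                   # already a DataLoader; use it as-is

loader = resolve_dataloader({'num_samples': 10, 'batch_size': 4})
print(len(loader))  # 3 batches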
2 changes: 1 addition & 1 deletion mmengine/runner/loops.py
@@ -102,7 +102,7 @@ def run(self) -> torch.nn.Module:
and self._epoch >= self.val_begin
and (self._epoch % self.val_interval == 0
or self._epoch == self._max_epochs)):
- self.runner.val_loop.run()
+ self.runner.val_loop.run()  # ! validation

self.runner.call_hook('after_train')
return self.runner.model
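The annotated val_loop.run() call is guarded by the epoch-based schedule visible in the context lines above: validation starts at val_begin, repeats every val_interval epochs, and always runs at the final epoch. A minimal sketch of that scheduling logic (the val_begin/val_interval values below are made up):

def should_validate(epoch: int, max_epochs: int, val_begin: int, val_interval: int) -> bool:
    # Mirrors the condition in the diff: past val_begin, and either on the interval
    # or at the final epoch.
    return (epoch >= val_begin
            and (epoch % val_interval == 0 or epoch == max_epochs))

# Example schedule: 10 epochs, start validating at epoch 3, every 2 epochs.
for epoch in range(1, 11):
    if should_validate(epoch, max_epochs=10, val_begin=3, val_interval=2):
        print(f'epoch {epoch}: run validation')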
21 changes: 11 additions & 10 deletions mmengine/runner/runner.py
@@ -1364,10 +1364,10 @@ def build_dataloader(dataloader: Union[DataLoader, Dict],

dataloader_cfg = copy.deepcopy(dataloader)

- # build dataset
+ # build dataset (construct the dataset here)
dataset_cfg = dataloader_cfg.pop('dataset')
if isinstance(dataset_cfg, dict):
- dataset = DATASETS.build(dataset_cfg)
+ dataset = DATASETS.build(dataset_cfg)  # build the dataset class named by the type field of dataset_cfg
if hasattr(dataset, 'full_init'):
dataset.full_init()
else:
@@ -1473,7 +1473,7 @@ def build_dataloader(dataloader: Union[DataLoader, Dict],
raise TypeError(
'collate_fn should be a dict or callable object, but got '
f'{collate_fn_cfg}')
- data_loader = DataLoader(
+ data_loader = DataLoader(  # finally build the PyTorch DataLoader
dataset=dataset,
sampler=sampler if batch_sampler is None else None,
batch_sampler=batch_sampler,
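The two annotated lines above mark the two halves of build_dataloader: instantiating the dataset class named by the 'type' field of dataset_cfg, then wrapping everything in a regular PyTorch DataLoader. A stripped-down sketch of that registry-then-DataLoader pattern; the toy registry and ToyDataset below are illustrative, not mmengine's DATASETS registry:

import torch
from torch.utils.data import DataLoader, Dataset

# Toy registry mapping 'type' strings to dataset classes.
DATASET_REGISTRY = {}

def register(cls):
    DATASET_REGISTRY[cls.__name__] = cls
    return cls

@register
class ToyDataset(Dataset):
    def __init__(self, num_samples: int):
        self.data = torch.randn(num_samples, 3)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def build_dataloader(dataloader_cfg: dict) -> DataLoader:
    cfg = dict(dataloader_cfg)
    dataset_cfg = dict(cfg.pop('dataset'))
    dataset_cls = DATASET_REGISTRY[dataset_cfg.pop('type')]  # pick the class by its 'type' name
    dataset = dataset_cls(**dataset_cfg)                     # build the dataset
    return DataLoader(dataset, **cfg)                        # finally build the PyTorch DataLoader

loader = build_dataloader(dict(dataset=dict(type='ToyDataset', num_samples=8), batch_size=4))
for batch in loader:
    print(batch.shape)  # torch.Size([4, 3])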
@@ -1724,7 +1724,7 @@ def train(self) -> nn.Module:
'method. Please provide `train_dataloader`, `train_cfg`, '
'`optimizer` and `param_scheduler` arguments when '
'initializing runner.')

+ # ! build the training loop
self._train_loop = self.build_train_loop(
self._train_loop) # type: ignore

@@ -1742,10 +1742,10 @@ def train(self) -> nn.Module:
self._val_loop = self.build_val_loop(
self._val_loop) # type: ignore
# TODO: add a contextmanager to avoid calling `before_run` many times
- self.call_hook('before_run')
+ self.call_hook('before_run')  # hooks before the run

# initialize the model weights
- self._init_model_weights()
+ self._init_model_weights()  # initialize the model

# try to enable activation_checkpointing feature
modules = self.cfg.get('activation_checkpointing', None)
@@ -1773,9 +1773,10 @@ def train(self) -> nn.Module:
# Maybe compile the model according to options in self.cfg.compile
# This must be called **AFTER** model has been wrapped.
self._maybe_compile('train_step')


+ # ! start training the model
model = self.train_loop.run() # type: ignore
- self.call_hook('after_run')
+ self.call_hook('after_run')  # ! hooks after the run
return model

def val(self) -> dict:
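Taken together, the comments added to train() trace its sequence: build the training loop from the config, fire the before_run hooks, initialize the model weights, run the loop, then fire the after_run hooks. Below is a condensed sketch of the hook-and-run part of that sequence with a toy Runner; the class and hook names are illustrative, not mmengine's public API.

class LoggingHook:
    def before_run(self, runner):
        print('before_run')

    def after_run(self, runner):
        print('after_run')

class ToyRunner:
    def __init__(self, max_epochs: int):
        self.max_epochs = max_epochs
        self.hooks = [LoggingHook()]

    def call_hook(self, fn_name: str):
        # Dispatch the named stage to every registered hook that implements it.
        for hook in self.hooks:
            getattr(hook, fn_name, lambda runner: None)(self)

    def init_model_weights(self):
        print('init weights')

    def train(self):
        self.call_hook('before_run')       # hooks before the run
        self.init_model_weights()          # initialize the model
        for epoch in range(self.max_epochs):
            print(f'train epoch {epoch}')  # train_loop.run() would iterate the dataloader here
        self.call_hook('after_run')        # hooks after the run

ToyRunner(max_epochs=2).train()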
@@ -1874,7 +1875,7 @@ def register_hook(
if 'priority' in hook:
_priority = hook.pop('priority')

- hook_obj = HOOKS.build(hook)
+ hook_obj = HOOKS.build(hook)  # build the hook class
else:
hook_obj = hook

@@ -1963,7 +1964,7 @@ def register_default_hooks(
default_hooks[name] = hook

for hook in default_hooks.values():
- self.register_hook(hook)
+ self.register_hook(hook)  # register them one by one

def register_custom_hooks(self, hooks: List[Union[Hook, Dict]]) -> None:
"""Register custom hooks into hook list.
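The last two annotated lines concern how hooks come to exist: each hook config is built into a hook object (optionally carrying a priority), and the default hooks are then registered one by one. A small sketch of priority-ordered registration; the priority numbers and default hook set are invented for illustration, not mmengine's actual defaults:

class Hook:
    priority = 50  # a middle-of-the-road default

class CheckpointHook(Hook):
    priority = 70  # larger value here means it runs later

class LoggerHook(Hook):
    priority = 30  # smaller value here means it runs earlier

HOOK_REGISTRY = {'CheckpointHook': CheckpointHook, 'LoggerHook': LoggerHook}

def build_hook(cfg: dict) -> Hook:
    cfg = dict(cfg)
    priority = cfg.pop('priority', None)
    hook = HOOK_REGISTRY[cfg.pop('type')]()    # build the hook class from its 'type'
    if priority is not None:
        hook.priority = priority
    return hook

def register_default_hooks(default_hooks: dict, registered: list):
    for cfg in default_hooks.values():
        hook = build_hook(cfg)
        registered.append(hook)                # register them one by one
    registered.sort(key=lambda h: h.priority)  # keep the list ordered by priority

hooks: list = []
register_default_hooks(
    dict(checkpoint=dict(type='CheckpointHook'),
         logger=dict(type='LoggerHook', priority=10)),
    hooks)
print([type(h).__name__ for h in hooks])  # ['LoggerHook', 'CheckpointHook']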