2 changes: 1 addition & 1 deletion base/base_trainer.py
@@ -39,7 +39,7 @@ def __init__(self, model, criterion, metric_ftns, optimizer, config):
 
         self.checkpoint_dir = config.save_dir
 
-        # setup visualization writer instance
+        # setup visualization writer instance
         self.writer = TensorboardWriter(config.log_dir, self.logger, cfg_trainer['tensorboard'])
 
         if config.resume is not None:
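For orientation, the `cfg_trainer['tensorboard']` flag and `config.save_dir` used in this hunk are both driven by the `trainer` section of `config.json`. A minimal sketch of that section as a Python dict; only `save_dir` and `tensorboard` are actually referenced here, the remaining keys are assumptions about the template's usual layout:

# Assumed shape of config['trainer']; keys other than 'save_dir' and
# 'tensorboard' are illustrative, not taken from this diff.
trainer_cfg = {
    'epochs': 100,
    'save_dir': 'saved/',    # root under which the checkpoint directory is created
    'tensorboard': True,     # toggles the TensorboardWriter set up above
}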
2 changes: 1 addition & 1 deletion logger/__init__.py
@@ -1,2 +1,2 @@
 from .logger import *
-from .visualization import *
+from .visualization import *
4 changes: 2 additions & 2 deletions new_project.py
@@ -4,7 +4,7 @@
 
 
 # This script initializes new pytorch project with the template files.
-# Run `python3 new_project.py ../MyNewProject` then new project named
+# Run `python3 new_project.py ../MyNewProject` then new project named
 # MyNewProject will be made
 current_dir = Path()
 assert (current_dir / 'new_project.py').is_file(), 'Script should be executed in the pytorch-template directory'
@@ -15,4 +15,4 @@
 
 ignore = [".git", "data", "saved", "new_project.py", "LICENSE", ".flake8", "README.md", "__pycache__"]
 copytree(current_dir, target_dir, ignore=ignore_patterns(*ignore))
-print('New project initialized at', target_dir.absolute().resolve())
+print('New project initialized at', target_dir.absolute().resolve())
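For reference, `ignore_patterns` (from `shutil`) builds the filter that `copytree` consults in every directory it copies; names matching any of the glob patterns are skipped. A small self-contained sketch of that behaviour, with made-up file names:

from shutil import ignore_patterns

# ignore_patterns(*patterns) returns a callable taking (directory, names) and
# returning the subset of names to ignore, matched with glob-style patterns.
ignore = ignore_patterns('.git', '__pycache__', '*.pyc')
skipped = ignore('some/dir', ['model.py', 'model.pyc', '.git', 'data'])
print(skipped)   # a set such as {'.git', 'model.pyc'}; order is not guaranteed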
24 changes: 16 additions & 8 deletions parse_config.py
@@ -11,12 +11,14 @@
 class ConfigParser:
     def __init__(self, config, resume=None, modification=None, run_id=None):
         """
-        class to parse configuration json file. Handles hyperparameters for training, initializations of modules, checkpoint saving
-        and logging module.
-        :param config: Dict containing configurations, hyperparameters for training. contents of `config.json` file for example.
+        class to parse configuration json file. Handles hyperparameters for training, initializations of modules,
+        checkpoint saving and logging module.
+        :param config: Dict containing configurations, hyperparameters for training.
+            contents of `config.json` file for example.
         :param resume: String, path to the checkpoint being loaded.
         :param modification: Dict keychain:value, specifying position values to be replaced from config dict.
-        :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log. Timestamp is being used as default
+        :param run_id: Unique Identifier for training processes. Used to save checkpoints and training log.
+            Timestamp is being used as default
         """
         # load config file and apply modification
         self._config = _update_config(config, modification)
@@ -26,7 +28,7 @@ class to parse configuration json file. Handles hyperparameters for training, in
         save_dir = Path(self.config['trainer']['save_dir'])
 
         exper_name = self.config['name']
-        if run_id is None: # use timestamp as default run-id
+        if run_id is None:  # use timestamp as default run-id
             run_id = datetime.now().strftime(r'%m%d_%H%M%S')
         self._save_dir = save_dir / 'models' / exper_name / run_id
         self._log_dir = save_dir / 'log' / exper_name / run_id
@@ -67,14 +69,14 @@ def from_args(cls, args, options=''):
         assert args.config is not None, msg_no_cfg
         resume = None
         cfg_fname = Path(args.config)
 
         config = read_json(cfg_fname)
         if args.config and resume:
             # update new config for fine-tuning
             config.update(read_json(args.config))
 
         # parse custom cli options into dictionary
-        modification = {opt.target : getattr(args, _get_opt_name(opt.flags)) for opt in options}
+        modification = {opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options}
         return cls(config, resume, modification)
 
     def init_obj(self, name, module, *args, **kwargs):
@@ -112,7 +114,8 @@ def __getitem__(self, name):
         return self.config[name]
 
     def get_logger(self, name, verbosity=2):
-        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(verbosity, self.log_levels.keys())
+        msg_verbosity = 'verbosity option {} is invalid. Valid options are {}.'.format(
+            verbosity, self.log_levels.keys())
         assert verbosity in self.log_levels, msg_verbosity
         logger = logging.getLogger(name)
         logger.setLevel(self.log_levels[verbosity])
@@ -132,6 +135,8 @@ def log_dir(self):
         return self._log_dir
 
 # helper functions to update config dict with custom cli options
+
+
 def _update_config(config, modification):
     if modification is None:
         return config
@@ -141,17 +146,20 @@ def _update_config(config, modification):
         _set_by_path(config, k, v)
     return config
 
+
 def _get_opt_name(flags):
     for flg in flags:
         if flg.startswith('--'):
             return flg.replace('--', '')
     return flags[0].replace('--', '')
 
+
 def _set_by_path(tree, keys, value):
     """Set a value in a nested object in tree by sequence of keys."""
     keys = keys.split(';')
     _get_by_path(tree, keys[:-1])[keys[-1]] = value
 
+
 def _get_by_path(tree, keys):
     """Access a nested object in tree by sequence of keys."""
     return reduce(getitem, keys, tree)
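The helpers above are what make the `modification` dict built in `from_args` take effect: each key is a ';'-separated keychain into the nested config, and `_set_by_path` walks that chain with `reduce(getitem, ...)` before assigning the last key. A short worked example; the keychain and values are illustrative (in the spirit of the template's custom `--lr` option), not taken from this diff:

from functools import reduce
from operator import getitem

def _set_by_path(tree, keys, value):
    """Set a value in a nested dict addressed by a ';'-joined keychain."""
    keys = keys.split(';')
    reduce(getitem, keys[:-1], tree)[keys[-1]] = value

config = {'optimizer': {'type': 'Adam', 'args': {'lr': 0.001}}}

# e.g. what a hypothetical `--lr 0.01` CLI flag could translate to:
modification = {'optimizer;args;lr': 0.01}
for keychain, value in modification.items():
    if value is not None:          # options the user did not pass stay untouched
        _set_by_path(config, keychain, value)

print(config['optimizer']['args']['lr'])   # 0.01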
1 change: 1 addition & 0 deletions train.py
@@ -18,6 +18,7 @@
 torch.backends.cudnn.benchmark = False
 np.random.seed(SEED)
 
+
 def main(config):
     logger = config.get_logger('train')
 
8 changes: 6 additions & 2 deletions trainer/trainer.py
@@ -9,6 +9,7 @@ class Trainer(BaseTrainer):
     """
     Trainer class
     """
+
     def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                  data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
         super().__init__(model, criterion, metric_ftns, optimizer, config)
@@ -66,10 +67,13 @@ def _train_epoch(self, epoch):
 
         if self.do_validation:
             val_log = self._valid_epoch(epoch)
-            log.update(**{'val_'+k : v for k, v in val_log.items()})
+            log.update(**{'val_' + k: v for k, v in val_log.items()})
 
         if self.lr_scheduler is not None:
-            self.lr_scheduler.step()
+            if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
+                self.lr_scheduler.step(log[self.config['lr_scheduler']['metric']])
+            else:
+                self.lr_scheduler.step()
         return log
 
     def _valid_epoch(self, epoch):
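The new branch above exists because `ReduceLROnPlateau.step()` expects the value of the monitored metric, while the other `torch.optim.lr_scheduler` classes step on the epoch count alone; it also implies that the `lr_scheduler` entry of `config.json` now carries a `metric` key naming an entry of the epoch `log` (for instance `val_loss`, given the `val_` prefix added a few lines earlier). A hedged sketch of such a config fragment and the resulting dispatch; everything besides the `metric` key and the isinstance check is an assumption for illustration:

import torch

# Illustrative lr_scheduler config fragment; only the 'metric' key is implied
# by the diff above, the rest mirrors a typical type/args layout.
lr_scheduler_cfg = {
    'type': 'ReduceLROnPlateau',
    'args': {'mode': 'min', 'factor': 0.1, 'patience': 5},
    'metric': 'val_loss',          # must match a key of the epoch log dict
}

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **lr_scheduler_cfg['args'])

log = {'loss': 0.42, 'val_loss': 0.57}     # toy epoch log for illustration

# Same dispatch as in the hunk above:
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
    scheduler.step(log[lr_scheduler_cfg['metric']])
else:
    scheduler.step()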
5 changes: 5 additions & 0 deletions utils/util.py
@@ -11,21 +11,25 @@ def ensure_dir(dirname):
     if not dirname.is_dir():
         dirname.mkdir(parents=True, exist_ok=False)
 
+
 def read_json(fname):
     fname = Path(fname)
     with fname.open('rt') as handle:
         return json.load(handle, object_hook=OrderedDict)
 
+
 def write_json(content, fname):
     fname = Path(fname)
     with fname.open('wt') as handle:
         json.dump(content, handle, indent=4, sort_keys=False)
 
+
 def inf_loop(data_loader):
     ''' wrapper function for endless data loader. '''
     for loader in repeat(data_loader):
         yield from loader
 
+
 def prepare_device(n_gpu_use):
     """
     setup GPU device if available. get gpu device indices which are used for DataParallel
@@ -43,6 +47,7 @@ def prepare_device(n_gpu_use):
     list_ids = list(range(n_gpu_use))
     return device, list_ids
 
+
 class MetricTracker:
     def __init__(self, *keys, writer=None):
         self.writer = writer
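As the `prepare_device` docstring above says, the returned device and index list are meant for `DataParallel`. A short usage sketch; the placeholder model and the requested GPU count are made up, and the import path assumes the file keeps its `utils/util.py` location:

import torch
from utils.util import prepare_device

model = torch.nn.Linear(8, 2)              # placeholder model
device, device_ids = prepare_device(2)     # ask for up to 2 GPUs

model = model.to(device)
if len(device_ids) > 1:
    # replicate the model across the returned GPU indices
    model = torch.nn.DataParallel(model, device_ids=device_ids)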