From dfea44c88eb74662b6b2f0fe5cdd9614cb8ce00e Mon Sep 17 00:00:00 2001
From: Adam Lerer
Date: Tue, 9 Jun 2020 23:39:29 -0400
Subject: [PATCH 1/3] Weight decay (latest)

---
 test/test_functional.py    |  2 ++
 torchbiggraph/config.py    | 11 +++++++++++
 torchbiggraph/model.py     | 14 ++++++++++++++
 torchbiggraph/train_cpu.py |  8 +++++---
 4 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/test/test_functional.py b/test/test_functional.py
index 1219ce5..131922f 100644
--- a/test/test_functional.py
+++ b/test/test_functional.py
@@ -351,6 +351,8 @@ def test_default(self):
             edge_paths=[],  # filled in later
             checkpoint_path=self.checkpoint_path.name,
             workers=2,
+            wd=0.01,
+            wd_interval=2,
         )
         dataset = generate_dataset(base_config, num_entities=100, fractions=[0.4, 0.2])
         self.addCleanup(dataset.cleanup)
diff --git a/torchbiggraph/config.py b/torchbiggraph/config.py
index 0feee42..229d3e2 100644
--- a/torchbiggraph/config.py
+++ b/torchbiggraph/config.py
@@ -228,6 +228,17 @@ class ConfigSchema(Schema):
     regularizer: str = attr.ib(
         default="N3", metadata={"help": "Type of regularization to be applied."}
     )
+
+    wd: float = attr.ib(
+        default=0,
+        validator=non_negative,
+        metadata={"help": "Simple (unweighted) weight decay"},
+    )
+    wd_interval: int = attr.ib(
+        default=100,
+        validator=non_negative,
+        metadata={"help": "Interval to amortize weight decay"},
+    )
 
     # data config
 
diff --git a/torchbiggraph/model.py b/torchbiggraph/model.py
index 4b30afa..5be14bd 100644
--- a/torchbiggraph/model.py
+++ b/torchbiggraph/model.py
@@ -404,6 +404,8 @@ def __init__(
         ],
         comparator: AbstractComparator,
         regularizer: AbstractRegularizer,
+        wd: float,
+        wd_interval: int,
         global_emb: bool = False,
         max_norm: Optional[float] = None,
         num_dynamic_rels: int = 0,
@@ -444,6 +446,8 @@ def __init__(
         self.max_norm: Optional[float] = max_norm
         self.half_precision = half_precision
         self.regularizer: Optional[AbstractRegularizer] = regularizer
+        self.wd = wd
+        self.wd_interval = wd_interval
 
     def set_embeddings(self, entity: str, side: Side, weights: nn.Parameter) -> None:
         if self.entities[entity].featurized:
@@ -762,6 +766,14 @@ def forward(self, edges: EdgeList) -> Scores:
             reg,
         )
 
+
+    def l2_norm(self):
+        ret = 0
+        for e in set(self.lhs_embs.values()) | set(self.rhs_embs.values()):
+            ret += e.weight.pow(2).sum()
+        return ret
+
+
     def forward_direction_agnostic(
         self,
         src: EntityList,
@@ -921,6 +933,8 @@ def make_model(config: ConfigSchema) -> MultiRelationEmbedder:
         rhs_operators=rhs_operators,
         comparator=comparator,
         regularizer=regularizer,
+        wd=config.wd,
+        wd_interval=config.wd_interval,
         global_emb=config.global_emb,
         max_norm=config.max_norm,
         num_dynamic_rels=num_dynamic_rels,
diff --git a/torchbiggraph/train_cpu.py b/torchbiggraph/train_cpu.py
index b678964..31d8074 100644
--- a/torchbiggraph/train_cpu.py
+++ b/torchbiggraph/train_cpu.py
@@ -9,6 +9,7 @@
 import logging
 import math
 import time
+import random
 from collections import defaultdict
 from functools import partial
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
@@ -100,9 +101,10 @@ def _process_one_batch(
             count=len(batch_edges),
         )
         if reg is not None:
-            (loss + reg).backward()
-        else:
-            loss.backward()
+            loss = loss + reg
+        if model.wd > 0 and random.random() < 1. / model.wd_interval:
+            loss = loss * model.wd * model.wd_interval * model.l2_norm()
+        loss.backward()
         self.model_optimizer.step(closure=None)
         for optimizer in self.unpartitioned_optimizers.values():
             optimizer.step(closure=None)
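The first commit amortizes weight decay: instead of adding a wd * ||w||^2 penalty on every step, which would densify the sparse embedding gradients each time, the penalty is added on roughly one step in wd_interval and scaled up by wd_interval so that its expected per-step contribution is unchanged. Below is a minimal standalone sketch of that idea, not part of the patch: the toy embedding, optimizer and task loss are placeholders, and it already uses the additive form that the second commit switches to.

# Sketch only: amortized L2 weight decay on a toy embedding model.
import random

import torch
from torch import nn

wd = 0.01          # plays the role of config.wd
wd_interval = 100  # plays the role of config.wd_interval

emb = nn.Embedding(1000, 16)                       # stand-in for PBG's embedding tables
optimizer = torch.optim.SGD(emb.parameters(), lr=0.1)

def process_one_batch(batch: torch.Tensor) -> None:
    optimizer.zero_grad()
    loss = emb(batch).pow(2).mean()                # placeholder training loss
    # Apply the penalty on ~1/wd_interval of the batches, scaled by wd_interval,
    # so the expected penalty per batch equals wd * ||w||^2.
    if wd > 0 and random.random() < 1.0 / wd_interval:
        loss = loss + wd * wd_interval * emb.weight.pow(2).sum()
    loss.backward()
    optimizer.step()

process_one_batch(torch.randint(0, 1000, (32,)))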
From 36ea50405b336ab88b9e76b76e6be0953a12e248 Mon Sep 17 00:00:00 2001
From: Adam Lerer
Date: Wed, 10 Jun 2020 13:56:51 -0700
Subject: [PATCH 2/3] Avoid persistent dense gradients from wd that slow down
 training

---
 torchbiggraph/train_cpu.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/torchbiggraph/train_cpu.py b/torchbiggraph/train_cpu.py
index 31d8074..9ac131b 100644
--- a/torchbiggraph/train_cpu.py
+++ b/torchbiggraph/train_cpu.py
@@ -87,7 +87,13 @@ def __init__(
     def _process_one_batch(
         self, model: MultiRelationEmbedder, batch_edges: EdgeList
     ) -> Stats:
-        model.zero_grad()
+        # Tricky: this is basically like calling `model.zero_grad()` except
+        # that `zero_grad` calls `p.grad.zero_()`. When we perform infrequent
+        # global L2 regularization, it converts the embedding gradients to dense,
+        # and then they can never convert back to sparse gradients unless we set
+        # them to `None` again here.
+        for p in model.parameters():
+            p.grad = None
 
         scores, reg = model(batch_edges)
 
@@ -103,7 +109,7 @@ def _process_one_batch(
         if reg is not None:
             loss = loss + reg
         if model.wd > 0 and random.random() < 1. / model.wd_interval:
-            loss = loss * model.wd * model.wd_interval * model.l2_norm()
+            loss = loss + model.wd * model.wd_interval * model.l2_norm()
         loss.backward()
         self.model_optimizer.step(closure=None)
         for optimizer in self.unpartitioned_optimizers.values():
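The second commit replaces model.zero_grad() with per-parameter p.grad = None because a gradient buffer that has become dense (after one of the occasional weight-decay steps) stays dense if it is merely zeroed, which slows down the sparse embedding updates. The snippet below is a standalone illustration of that behavior, not taken from the patch; newer PyTorch versions expose the same reset via zero_grad(set_to_none=True).

# Sketch only: why `p.grad = None` lets embedding gradients go back to sparse.
import torch
from torch import nn

emb = nn.Embedding(1000, 16, sparse=True)
idx = torch.tensor([1, 2, 3])

# An occasional global L2 term produces a *dense* gradient.
emb.weight.pow(2).sum().backward()
print(emb.weight.grad.layout)   # torch.strided (dense)

# Zeroing keeps the dense buffer: the next sparse gradient is accumulated
# into it, so the result stays dense forever.
emb.weight.grad.zero_()
emb(idx).sum().backward()
print(emb.weight.grad.layout)   # still torch.strided

# Dropping the buffer lets autograd allocate a fresh sparse gradient.
emb.weight.grad = None
emb(idx).sum().backward()
print(emb.weight.grad.layout)   # torch.sparse_coo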
From 8fc291326076590cd9ddd631bc48d8a47e8d27b0 Mon Sep 17 00:00:00 2001
From: Adam Lerer
Date: Wed, 10 Jun 2020 16:42:59 -0700
Subject: [PATCH 3/3] early stopping

---
 torchbiggraph/config.py    |  6 +++
 torchbiggraph/train_cpu.py | 83 +++++++++++++++++++++-----------------
 2 files changed, 53 insertions(+), 36 deletions(-)

diff --git a/torchbiggraph/config.py b/torchbiggraph/config.py
index 229d3e2..198a9ac 100644
--- a/torchbiggraph/config.py
+++ b/torchbiggraph/config.py
@@ -396,6 +396,12 @@ class ConfigSchema(Schema):
             "after each training step."
         },
     )
+    early_stopping: bool = attr.ib(
+        default=False,
+        metadata={
+            "help": "Stop training when validation loss increases."
+        }
+    )
 
     # expert options
 
diff --git a/torchbiggraph/train_cpu.py b/torchbiggraph/train_cpu.py
index 9ac131b..1b2d076 100644
--- a/torchbiggraph/train_cpu.py
+++ b/torchbiggraph/train_cpu.py
@@ -578,6 +578,7 @@ def train(self) -> None:
                 eval_stats_chunk_avg,
             )
 
+        last_chunk_loss = float("inf")
         for epoch_idx, edge_path_idx, edge_chunk_idx in iteration_manager:
             logger.info(
                 f"Starting epoch {epoch_idx + 1} / {iteration_manager.num_epochs}, "
@@ -729,10 +730,17 @@ def train(self) -> None:
 
             current_index = (iteration_manager.iteration_idx + 1) * total_buckets - 1
 
-            self._maybe_write_checkpoint(
+            all_stats_dicts = self._maybe_write_checkpoint(
                 epoch_idx, edge_path_idx, edge_chunk_idx, current_index
             )
 
+            if config.early_stopping:
+                assert iteration_manager.num_edge_paths == 1
+                chunk_loss = all_stats_dicts[-1]["eval_stats_chunk_avg"]["metrics"]["loss"]
+                if chunk_loss > last_chunk_loss:
+                    break
+                last_chunk_loss = chunk_loss
+
         # now we're sure that all partition files exist,
         # so be strict about loading them
         self.strict = True
@@ -922,7 +930,7 @@ def _maybe_write_checkpoint(
         edge_path_idx: int,
         edge_chunk_idx: int,
         current_index: int,
-    ) -> None:
+    ) -> List[Dict[str, Any]]:
 
         config = self.config
 
@@ -963,42 +971,43 @@ def _maybe_write_checkpoint(
                 state_dict, self.trainer.model_optimizer.state_dict()
             )
 
-            logger.info("Writing the training stats")
-            all_stats_dicts: List[Dict[str, Any]] = []
-            bucket_eval_stats_list = []
-            chunk_stats_dict = {
-                "epoch_idx": epoch_idx,
-                "edge_path_idx": edge_path_idx,
-                "edge_chunk_idx": edge_chunk_idx,
+        all_stats_dicts: List[Dict[str, Any]] = []
+        bucket_eval_stats_list = []
+        chunk_stats_dict = {
+            "epoch_idx": epoch_idx,
+            "edge_path_idx": edge_path_idx,
+            "edge_chunk_idx": edge_chunk_idx,
+        }
+        for stats in self.bucket_scheduler.get_stats_for_pass():
+            stats_dict = {
+                "lhs_partition": stats.lhs_partition,
+                "rhs_partition": stats.rhs_partition,
+                "index": stats.index,
+                "stats": stats.train.to_dict(),
             }
-            for stats in self.bucket_scheduler.get_stats_for_pass():
-                stats_dict = {
-                    "lhs_partition": stats.lhs_partition,
-                    "rhs_partition": stats.rhs_partition,
-                    "index": stats.index,
-                    "stats": stats.train.to_dict(),
-                }
-                if stats.eval_before is not None:
-                    stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
-                    bucket_eval_stats_list.append(stats.eval_before)
-
-                if stats.eval_after is not None:
-                    stats_dict["eval_stats_after"] = stats.eval_after.to_dict()
-
-                stats_dict.update(chunk_stats_dict)
-                all_stats_dicts.append(stats_dict)
-
-            if len(bucket_eval_stats_list) != 0:
-                eval_stats_chunk_avg = Stats.average_list(bucket_eval_stats_list)
-                self.stats_handler.on_stats(
-                    index=current_index, eval_stats_chunk_avg=eval_stats_chunk_avg
-                )
-                chunk_stats_dict["index"] = current_index
-                chunk_stats_dict[
-                    "eval_stats_chunk_avg"
-                ] = eval_stats_chunk_avg.to_dict()
-                all_stats_dicts.append(chunk_stats_dict)
+            if stats.eval_before is not None:
+                stats_dict["eval_stats_before"] = stats.eval_before.to_dict()
+                bucket_eval_stats_list.append(stats.eval_after)
+
+            if stats.eval_after is not None:
+                stats_dict["eval_stats_after"] = stats.eval_after.to_dict()
+            stats_dict.update(chunk_stats_dict)
+            all_stats_dicts.append(stats_dict)
+
+        if len(bucket_eval_stats_list) != 0:
+            eval_stats_chunk_avg = Stats.average_list(bucket_eval_stats_list)
+            chunk_stats_dict["index"] = current_index
+            chunk_stats_dict[
+                "eval_stats_chunk_avg"
+            ] = eval_stats_chunk_avg.to_dict()
+            all_stats_dicts.append(chunk_stats_dict)
+
+        if self.rank == 0:
+            logger.info("Writing the training stats")
+            self.stats_handler.on_stats(
+                index=current_index, eval_stats_chunk_avg=eval_stats_chunk_avg
+            )
         self.checkpoint_manager.append_stats(all_stats_dicts)
 
         logger.info("Writing the checkpoint")
@@ -1029,3 +1038,5 @@ def _maybe_write_checkpoint(
             self.checkpoint_manager.preserve_current_version(config, epoch_idx + 1)
         if not preserve_old_checkpoint:
             self.checkpoint_manager.remove_old_version(config)
+
+        return all_stats_dicts
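Taken together, the series adds three config knobs. Below is a hypothetical config sketch using them; only wd, wd_interval and early_stopping come from these patches (the test in the first commit sets wd=0.01, wd_interval=2), while all other keys, paths and names are placeholders. Note that as written, the early-stopping check asserts a single edge path and reads the chunk-averaged validation loss, so evaluation stats (for example a non-zero eval_fraction) must be available.

# Hypothetical config sketch; only wd, wd_interval and early_stopping are
# introduced by this patch series, everything else is placeholder.
def get_torchbiggraph_config():
    return dict(
        entity_path="data/example",
        edge_paths=["data/example/train_partitioned"],  # one path: early stopping asserts this
        checkpoint_path="model/example",
        entities={"all": {"num_partitions": 1}},
        relations=[{"name": "all_edges", "lhs": "all", "rhs": "all", "operator": "complex_diagonal"}],
        dimension=200,
        num_epochs=50,
        eval_fraction=0.05,   # keep some held-out edges so the per-chunk validation loss exists
        wd=0.01,              # unweighted weight decay strength
        wd_interval=100,      # amortization interval for the decay term
        early_stopping=True,  # stop once the chunk-averaged validation loss increases
    )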