30 changes: 24 additions & 6 deletions test/graphgym/test_graphgym.py
@@ -16,7 +16,7 @@
set_run_dir,
)
from torch_geometric.graphgym.loader import create_loader
-from torch_geometric.graphgym.logger import LoggerCallback, set_printing
+from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.models.gnn import FeatureEncoder, GNNStackStage
from torch_geometric.graphgym.models.head import GNNNodeHead
@@ -194,12 +194,30 @@ def test_train(destroy_process_group, tmp_path, capfd):
loaders = create_loader()
model = create_model()
cfg.params = params_count(model)

# --- minimal logger callback that collects logs ---
class LoggerCallback(pl.Callback):
def __init__(self):
super().__init__()
self.logged = []

def on_train_batch_end(self, trainer, pl_module, outputs, batch,
batch_idx):
self.logged.append({"type": "train", "step": trainer.global_step})

def on_validation_batch_end(self, trainer, pl_module, outputs, batch,
batch_idx, dataloader_idx=0):
self.logged.append({"type": "val", "step": trainer.global_step})

logger = LoggerCallback()
-trainer = pl.Trainer(max_epochs=1, max_steps=4, callbacks=logger,
-                     log_every_n_steps=1)
+trainer = pl.Trainer(max_epochs=1, max_steps=4, callbacks=[logger],
+                     log_every_n_steps=1, enable_progress_bar=False)
train_loader, val_loader = loaders[0], loaders[1]
trainer.fit(model, train_loader, val_loader)

out, err = capfd.readouterr()
assert 'Sanity Checking' in out
assert 'Epoch 0:' in out
assert not trainer.sanity_checking
Member:

If I understand this diff correctly, this assert not trainer.sanity_checking tests nothing. Previously, it tested that the code ran sanity checking, but with this PR, it doesn't. sanity_checking is only True when the trainer is running a few validation steps at the start of fit.

Member:

Can we just remove this assert? Testing whether it performed sanity checking doesn't make much sense as a test anyway.

Contributor Author @drivanov, Oct 3, 2025:

This line was suggested by ChatGPT. The following explanation is adapted from its reasoning:

     assert not trainer.sanity_checking

could fail (i.e., trainer.sanity_checking == True after trainer.fit() finishes).

What could cause it to fail:

If training never moved past the sanity check

  • Example: an error in the training loop caused Lightning to stop during sanity checking.
  • Then trainer.fit() would return early, and trainer.sanity_checking could remain True.

If sanity checking was the only phase run

  • If max_steps=0 or max_epochs=0, training effectively does nothing after the sanity check.
  • Depending on version, trainer.sanity_checking might stay True.

Misconfigured validation loop

  • If validation loader is empty or raises errors, Lightning can get stuck in a state where it ends sanity checking improperly.
  • Edge cases in earlier Lightning versions sometimes left the flag not reset.

Bug in Lightning

  • If the sanity_checking flag isn’t toggled back (it’s set True at the start of sanity check, False afterward).
  • Rare, but regressions could happen across versions.

Normal expectation

  • After a successful trainer.fit(), training always goes past sanity check, so trainer.sanity_checking must be False.
  • Therefore, this assertion is a robust guarantee that training ran at least one epoch/step beyond the pre-validation check.

So in short:

It would fail only if trainer.fit() exited abnormally (error, misconfig, or 0 training steps) or Lightning had a bug in resetting the flag.

As we all know, ChatGPT isn't always accurate. Unfortunately, in this case, I am unable to assess the accuracy of its claims. Please advise.
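If the goal is to verify that sanity checking actually ran (rather than asserting on a flag that is False after any normal `fit()`), recording Lightning's sanity-check callback hooks is more direct. A minimal sketch, assuming the `on_sanity_check_start`/`on_sanity_check_end` hooks of the Lightning `Callback` API; in the real test the class would subclass `pl.Callback`, but a plain class with the same hook names shows the idea:

```python
# Sketch: record whether the pre-fit sanity-check phase ran, instead of
# asserting `not trainer.sanity_checking` after fit(), which is False on any
# normal exit and therefore tests very little.
class SanityCheckRecorder:
    def __init__(self):
        self.started = False
        self.finished = False

    def on_sanity_check_start(self, trainer, pl_module):
        # Called once before the pre-fit validation steps begin.
        self.started = True

    def on_sanity_check_end(self, trainer, pl_module):
        # Called once after the pre-fit validation steps complete.
        self.finished = True

# Hypothetical usage with a real Trainer (num_sanity_val_steps defaults to 2):
#   recorder = SanityCheckRecorder()
#   trainer = pl.Trainer(max_epochs=1, max_steps=4, callbacks=[recorder])
#   trainer.fit(model, train_loader, val_loader)
#   assert recorder.started and recorder.finished
```

This replaces a post-hoc flag check with an explicit record that the phase was entered and left, which fails meaningfully if sanity checking is skipped or aborts.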

assert trainer.current_epoch >= 0
# ensure both train and val batches were seen
types = {entry["type"] for entry in logger.logged}
assert "val" in types, "Validation did not run"
assert "train" in types, "Training did not run"