Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed `ogbn_train_cugraph` example for distributed cuGraph ([#10439](https://github.com/pyg-team/pytorch_geometric/pull/10439))
- Added `safe_onnx_export` function with workarounds for `onnx_ir.serde.SerdeError` issues in ONNX export ([#10422](https://github.com/pyg-team/pytorch_geometric/pull/10422))
- Fixed importing PyTorch Lightning in `torch_geometric.graphgym` and `torch_geometric.data.lightning` when using `lightning` instead of `pytorch-lightning` ([#10404](https://github.com/pyg-team/pytorch_geometric/pull/10404), [#10417](https://github.com/pyg-team/pytorch_geometric/pull/10417)))
- Fixed `detach()` warnings in example scripts involving tensor conversions ([#10357](https://github.com/pyg-team/pytorch_geometric/pull/10357))
- Fixed non-tuple indexing to resolve PyTorch deprecation warning ([#10389](https://github.com/pyg-team/pytorch_geometric/pull/10389))
Expand Down
20 changes: 4 additions & 16 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,10 @@ def test_packaging():
@onlyLinux
@withPackage('torch>=2.6.0')
@withPackage('onnx', 'onnxruntime', 'onnxscript')
def test_onnx(tmp_path: str) -> None:
def test_onnx(tmp_path):
import onnx
import onnxruntime as ort

from torch_geometric import safe_onnx_export

warnings.filterwarnings('ignore', '.*tensor to a Python boolean.*')
warnings.filterwarnings('ignore', '.*shape inference of prim::Constant.*')

Expand All @@ -278,27 +276,17 @@ def forward(self, x, edge_index):
assert expected.size() == (3, 16)

path = osp.join(tmp_path, 'model.onnx')
success = safe_onnx_export(
torch.onnx.export(
model,
(x, edge_index),
path,
input_names=('x', 'edge_index'),
opset_version=18,
dynamo=True, # False is deprecated by PyTorch
skip_on_error=True, # Skip gracefully in CI if upstream issue occurs
)

if not success:
# ONNX export was skipped due to known upstream issue
# This allows CI to pass while the upstream bug exists
warnings.warn(
"ONNX export test skipped due to known upstream onnx_ir issue. "
"This is expected and does not indicate a problem with PyTorch "
"Geometric.", UserWarning, stacklevel=2)
return

onnx_model = onnx.load(path)
onnx.checker.check_model(onnx_model)
model = onnx.load(path)
onnx.checker.check_model(model)

providers = ['CPUExecutionProvider']
ort_session = ort.InferenceSession(path, providers=providers)
Expand Down
Loading
Loading