From 419494fae34f77018e1b3268904db7d2d31b1be2 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 5 Sep 2025 20:42:34 +0000 Subject: [PATCH] Revert "Fix/onnx serde error workaround (#10422)" This reverts commit dcc4a7aad2e91a76db37b49e9000912846ca9eb4. --- CHANGELOG.md | 1 - test/nn/models/test_basic_gnn.py | 20 +- test/test_onnx.py | 409 ------------------------------- torch_geometric/__init__.py | 3 +- torch_geometric/_onnx.py | 214 ---------------- 5 files changed, 5 insertions(+), 642 deletions(-) delete mode 100644 test/test_onnx.py diff --git a/CHANGELOG.md b/CHANGELOG.md index a590afe5e190..756bfcf6fc30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed `ogbn_train_cugraph` example for distributed cuGraph ([#10439](https://github.com/pyg-team/pytorch_geometric/pull/10439)) -- Added `safe_onnx_export` function with workarounds for `onnx_ir.serde.SerdeError` issues in ONNX export ([#10422](https://github.com/pyg-team/pytorch_geometric/pull/10422)) - Fixed importing PyTorch Lightning in `torch_geometric.graphgym` and `torch_geometric.data.lightning` when using `lightning` instead of `pytorch-lightning` ([#10404](https://github.com/pyg-team/pytorch_geometric/pull/10404), [#10417](https://github.com/pyg-team/pytorch_geometric/pull/10417))) - Fixed `detach()` warnings in example scripts involving tensor conversions ([#10357](https://github.com/pyg-team/pytorch_geometric/pull/10357)) - Fixed non-tuple indexing to resolve PyTorch deprecation warning ([#10389](https://github.com/pyg-team/pytorch_geometric/pull/10389)) diff --git a/test/nn/models/test_basic_gnn.py b/test/nn/models/test_basic_gnn.py index b7430e95994e..90f055bd7566 100644 --- a/test/nn/models/test_basic_gnn.py +++ b/test/nn/models/test_basic_gnn.py @@ -251,12 +251,10 @@ def test_packaging(): @onlyLinux @withPackage('torch>=2.6.0') @withPackage('onnx', 'onnxruntime', 'onnxscript') -def test_onnx(tmp_path: str) -> None: +def test_onnx(tmp_path): import onnx import onnxruntime as ort - from torch_geometric import safe_onnx_export - warnings.filterwarnings('ignore', '.*tensor to a Python boolean.*') warnings.filterwarnings('ignore', '.*shape inference of prim::Constant.*') @@ -278,27 +276,17 @@ def forward(self, x, edge_index): assert expected.size() == (3, 16) path = osp.join(tmp_path, 'model.onnx') - success = safe_onnx_export( + torch.onnx.export( model, (x, edge_index), path, input_names=('x', 'edge_index'), opset_version=18, dynamo=True, # False is deprecated by PyTorch - skip_on_error=True, # Skip gracefully in CI if upstream issue occurs ) - if not success: - # ONNX export was skipped due to known upstream issue - # This allows CI to pass while the upstream bug exists - warnings.warn( - "ONNX export test skipped due to known upstream onnx_ir issue. 
" - "This is expected and does not indicate a problem with PyTorch " - "Geometric.", UserWarning, stacklevel=2) - return - - onnx_model = onnx.load(path) - onnx.checker.check_model(onnx_model) + model = onnx.load(path) + onnx.checker.check_model(model) providers = ['CPUExecutionProvider'] ort_session = ort.InferenceSession(path, providers=providers) diff --git a/test/test_onnx.py b/test/test_onnx.py deleted file mode 100644 index 0c219e51ed52..000000000000 --- a/test/test_onnx.py +++ /dev/null @@ -1,409 +0,0 @@ -import os -import tempfile -import warnings -from typing import Any -from unittest.mock import patch - -import pytest -import torch - -from torch_geometric import is_in_onnx_export, safe_onnx_export - -# Global mock to prevent ANY real ONNX calls in tests -# This ensures no deprecation warnings or real ONNX issues -pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning") - - -class SimpleModel(torch.nn.Module): - """Simple model for testing ONNX export.""" - def __init__(self) -> None: - super().__init__() - self.linear = torch.nn.Linear(4, 2) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.linear(x) - - -def test_is_in_onnx_export() -> None: - """Test is_in_onnx_export function.""" - assert not is_in_onnx_export() - - -def test_safe_onnx_export_ci_resilient() -> None: - """Test safe_onnx_export handles CI environment issues gracefully.""" - model = SimpleModel() - x = torch.randn(3, 4) - - # Use mocking to prevent real ONNX calls and deprecation warnings - with patch('torch.onnx.export', return_value=None) as mock_export: - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - # Test with skip_on_error=True - should never fail - result = safe_onnx_export(model, (x, ), f.name, - skip_on_error=True) - # Should always succeed with mocking - assert result is True - - # Verify the mock was called correctly - mock_export.assert_called_once() - call_args = mock_export.call_args[0] - assert call_args[0] is model - assert isinstance(call_args[1], tuple) - assert call_args[2] == f.name - - finally: - if os.path.exists(f.name): - try: - os.unlink(f.name) - except (PermissionError, OSError): - pass # Ignore file lock issues - - -def test_safe_onnx_export_success() -> None: - """Test successful ONNX export with pure mocking.""" - model = SimpleModel() - x = torch.randn(3, 4) - - # Use comprehensive mocking to avoid any real ONNX calls - with patch('torch.onnx.export', return_value=None) as mock_export: - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - # Test with tuple args - should succeed with mock - result = safe_onnx_export(model, (x, ), f.name) - assert result is True - - # Verify torch.onnx.export was called with correct args - mock_export.assert_called() - call_args = mock_export.call_args[0] - assert call_args[0] is model # model - assert isinstance(call_args[1], tuple) # args as tuple - assert call_args[2] == f.name # file path - - # Reset mock for second test - mock_export.reset_mock() - - # Test with single tensor (should be converted to tuple) - result = safe_onnx_export(model, x, f.name) - assert result is True - - # Verify single tensor was converted to tuple - call_args = mock_export.call_args[0] - assert isinstance(call_args[1], tuple) - - finally: - if os.path.exists(f.name): - try: - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - except (PermissionError, OSError): - pass - - -def test_safe_onnx_export_with_skip_on_error() -> None: - """Test safe_onnx_export 
with skip_on_error=True.""" - model = SimpleModel() - x = torch.randn(3, 4) - - # Mock torch.onnx.export to raise SerdeError - with patch('torch.onnx.export') as mock_export: - mock_export.side_effect = Exception( - "onnx_ir.serde.SerdeError: allowzero") - - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - # Should return False instead of raising - result = safe_onnx_export(model, (x, ), f.name, - skip_on_error=True) - assert result is False - finally: - if os.path.exists(f.name): - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - - -def test_serde_error_patterns() -> None: - """Test detection of various SerdeError patterns.""" - model = SimpleModel() - x = torch.randn(3, 4) - - error_patterns = [ - "onnx_ir.serde.SerdeError: allowzero attribute", - "ValueError: Value out of range: 1", "serialize_model_into failed", - "serialize_attribute_into failed" - ] - - for error_msg in error_patterns: - # Use multiple patch targets to ensure comprehensive mocking - with patch('torch.onnx.export') as mock_export, \ - patch('torch_geometric._onnx.torch.onnx.export') as mock_export2: - - mock_export.side_effect = Exception(error_msg) - mock_export2.side_effect = Exception(error_msg) - - with tempfile.NamedTemporaryFile(suffix='.onnx', - delete=False) as f: - try: - result = safe_onnx_export(model, (x, ), f.name, - skip_on_error=True) - assert result is False - finally: - if os.path.exists(f.name): - try: - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - except (PermissionError, OSError): - pass # Ignore file lock issues - - -def test_non_serde_error_reraise() -> None: - """Test that non-SerdeError exceptions are re-raised.""" - model = SimpleModel() - x = torch.randn(3, 4) - - # Use comprehensive mocking to prevent real ONNX calls - with patch('torch.onnx.export') as mock_export: - mock_export.side_effect = ValueError("Some other error") - - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - with pytest.raises(ValueError, match="Some other error"): - safe_onnx_export(model, (x, ), f.name) - finally: - if os.path.exists(f.name): - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - - -def test_dynamo_fallback() -> None: - """Test dynamo=False fallback strategy.""" - model = SimpleModel() - x = torch.randn(3, 4) - - call_count = 0 - - def mock_export_side_effect(*_args: Any, **kwargs: Any) -> None: - nonlocal call_count - call_count += 1 - if call_count == 1: - # First call fails - raise Exception("onnx_ir.serde.SerdeError: allowzero") - elif call_count == 2 and not kwargs.get('dynamo', True): - # Second call succeeds with dynamo=False - return None - else: - raise Exception("Unexpected call") - - with patch('torch.onnx.export', side_effect=mock_export_side_effect): - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - result = safe_onnx_export(model, (x, ), f.name, dynamo=True) - assert result is True - assert call_count == 2 - finally: - if os.path.exists(f.name): - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - - -def test_opset_fallback() -> None: - """Test opset version fallback strategy.""" - model = SimpleModel() - x = torch.randn(3, 4) - - call_count = 0 - - def mock_export_side_effect(*_args: Any, **kwargs: Any) -> None: - nonlocal call_count - call_count += 1 - # Fail until we get to opset_version=17 - if kwargs.get('opset_version') == 17: - # This call succeeds - return None - else: - # All other calls 
fail - raise Exception("onnx_ir.serde.SerdeError: allowzero") - - with patch('torch.onnx.export', side_effect=mock_export_side_effect): - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - result = safe_onnx_export(model, (x, ), f.name, - opset_version=18) - # Should succeed when opset_version=17 is tried - assert result is True - finally: - if os.path.exists(f.name): - try: - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - except (PermissionError, OSError): - pass - - -def test_all_strategies_fail() -> None: - """Test when all workaround strategies fail.""" - model = SimpleModel() - x = torch.randn(3, 4) - - with patch('torch.onnx.export') as mock_export: - mock_export.side_effect = Exception( - "onnx_ir.serde.SerdeError: allowzero") - - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - # Should raise RuntimeError when skip_on_error=False - with pytest.raises(RuntimeError, - match="Failed to export model to ONNX"): - safe_onnx_export(model, (x, ), f.name, skip_on_error=False) - - # Should return False when skip_on_error=True - result = safe_onnx_export(model, (x, ), f.name, - skip_on_error=True) - assert result is False - finally: - if os.path.exists(f.name): - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - - -def test_pytest_environment_detection() -> None: - """Test pytest environment detection for better error messages.""" - model = SimpleModel() - x = torch.randn(3, 4) - - with patch('torch.onnx.export') as mock_export: - mock_export.side_effect = Exception( - "onnx_ir.serde.SerdeError: allowzero") - - # Set pytest environment variable - with patch.dict(os.environ, {'PYTEST_CURRENT_TEST': 'test_something'}): - with tempfile.NamedTemporaryFile(suffix='.onnx', - delete=False) as f: - try: - with pytest.raises(RuntimeError) as exc_info: - safe_onnx_export(model, (x, ), f.name, - skip_on_error=False) - - # Should contain pytest-specific guidance - assert "pytest environments" in str(exc_info.value) - assert "torch.jit.script()" in str(exc_info.value) - finally: - if os.path.exists(f.name): - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - - -def test_warnings_emitted() -> None: - """Test that appropriate warnings are emitted during workarounds.""" - model = SimpleModel() - x = torch.randn(3, 4) - - call_count = 0 - - def mock_export_side_effect(*_args: Any, **_kwargs: Any) -> None: - nonlocal call_count - call_count += 1 - if call_count == 1: - raise Exception("onnx_ir.serde.SerdeError: allowzero") - elif call_count == 2: - return None # Success on dynamo fallback - else: - raise Exception("Unexpected call") - - with patch('torch.onnx.export', side_effect=mock_export_side_effect): - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - result = safe_onnx_export(model, (x, ), f.name, - dynamo=True) - - assert result is True - assert len(w) >= 2 # Initial error + dynamo fallback - assert any("allowzero boolean attribute bug" in str( - warning.message) for warning in w) - assert any( - "dynamo=False as workaround" in str(warning.message) - for warning in w) - finally: - if os.path.exists(f.name): - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass - - -@pytest.mark.parametrize( - "args_input", - [ - torch.randn(3, 4), # Single tensor - (torch.randn(3, 4), ), # Tuple with one tensor - (torch.randn(3, 4), torch.randn(3, 
2)), # Tuple with multiple tensors - ]) -def test_args_conversion(args_input: Any) -> None: - """Test that args are properly converted to tuple format.""" - model = SimpleModel() - - with patch('torch.onnx.export') as mock_export: - mock_export.return_value = None - - with tempfile.NamedTemporaryFile(suffix='.onnx', delete=False) as f: - try: - result = safe_onnx_export(model, args_input, f.name) - assert result is True - - # Check that torch.onnx.export was called with tuple args - mock_export.assert_called_once() - call_args = mock_export.call_args[0] - assert isinstance(call_args[1], tuple) # args should be tuple - finally: - if os.path.exists(f.name): - try: - - os.unlink(f.name) - - except (PermissionError, OSError): - - pass diff --git a/torch_geometric/__init__.py b/torch_geometric/__init__.py index 0fd6ece95216..40aadb5a2268 100644 --- a/torch_geometric/__init__.py +++ b/torch_geometric/__init__.py @@ -4,7 +4,7 @@ import torch_geometric.typing from ._compile import compile, is_compiling -from ._onnx import is_in_onnx_export, safe_onnx_export +from ._onnx import is_in_onnx_export from .index import Index from .edge_index import EdgeIndex from .hash_tensor import HashTensor @@ -43,7 +43,6 @@ 'compile', 'is_compiling', 'is_in_onnx_export', - 'safe_onnx_export', 'is_mps_available', 'is_xpu_available', 'device', diff --git a/torch_geometric/_onnx.py b/torch_geometric/_onnx.py index 5560342976d3..fdfde0a0eac1 100644 --- a/torch_geometric/_onnx.py +++ b/torch_geometric/_onnx.py @@ -1,7 +1,3 @@ -import warnings -from os import PathLike -from typing import Any, Union - import torch from torch_geometric import is_compiling @@ -16,213 +12,3 @@ def is_in_onnx_export() -> bool: if torch.jit.is_scripting(): return False return torch.onnx.is_in_onnx_export() - - -def safe_onnx_export( - model: torch.nn.Module, - args: Union[torch.Tensor, tuple[Any, ...]], - f: Union[str, PathLike[Any], None], - skip_on_error: bool = False, - **kwargs: Any, -) -> bool: - r"""A safe wrapper around :meth:`torch.onnx.export` that handles known - ONNX serialization issues in PyTorch Geometric. - - This function provides workarounds for the ``onnx_ir.serde.SerdeError`` - with boolean ``allowzero`` attributes that occurs in certain environments. - - Args: - model (torch.nn.Module): The model to export. - args (torch.Tensor or tuple): The input arguments for the model. - f (str or PathLike): The file path to save the model. - skip_on_error (bool): If True, return False instead of raising when - workarounds fail. Useful for CI environments. - **kwargs: Additional arguments passed to :meth:`torch.onnx.export`. - - Returns: - bool: True if export succeeded, False if skipped due to known issues - (only when skip_on_error=True). - - Example: - >>> from torch_geometric.nn import SAGEConv - >>> from torch_geometric import safe_onnx_export - >>> - >>> class MyModel(torch.nn.Module): - ... def __init__(self): - ... super().__init__() - ... self.conv = SAGEConv(8, 16) - ... def forward(self, x, edge_index): - ... return self.conv(x, edge_index) - >>> - >>> model = MyModel() - >>> x = torch.randn(3, 8) - >>> edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]]) - >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx') - >>> - >>> # For CI environments: - >>> success = safe_onnx_export(model, (x, edge_index), 'model.onnx', - ... skip_on_error=True) - >>> if not success: - ... 
print("ONNX export skipped due to known upstream issue") - """ - # Convert single tensor to tuple for torch.onnx.export compatibility - if isinstance(args, torch.Tensor): - args = (args, ) - - try: - # First attempt: standard ONNX export - torch.onnx.export(model, args, f, **kwargs) - return True - - except Exception as e: - error_str = str(e) - error_type = type(e).__name__ - - # Check for the specific onnx_ir.serde.SerdeError patterns - is_allowzero_error = (('onnx_ir.serde.SerdeError' in error_str - and 'allowzero' in error_str) or - 'ValueError: Value out of range: 1' in error_str - or 'serialize_model_into' in error_str - or 'serialize_attribute_into' in error_str) - - if is_allowzero_error: - warnings.warn( - f"Encountered known ONNX serialization issue ({error_type}). " - "This is likely the allowzero boolean attribute bug. " - "Attempting workaround...", UserWarning, stacklevel=2) - - # Apply workaround strategies - return _apply_onnx_allowzero_workaround(model, args, f, - skip_on_error, **kwargs) - - else: - # Re-raise other errors - raise - - -def _apply_onnx_allowzero_workaround( - model: torch.nn.Module, - args: tuple[Any, ...], - f: Union[str, PathLike[Any], None], - skip_on_error: bool = False, - **kwargs: Any, -) -> bool: - r"""Apply workaround strategies for onnx_ir.serde.SerdeError with allowzero - attributes. - - Returns: - bool: True if export succeeded, False if skipped (when - skip_on_error=True). - """ - # Strategy 1: Try without dynamo if it was enabled - if kwargs.get('dynamo', False): - try: - kwargs_no_dynamo = kwargs.copy() - kwargs_no_dynamo['dynamo'] = False - - warnings.warn( - "Retrying ONNX export with dynamo=False as workaround", - UserWarning, stacklevel=3) - - torch.onnx.export(model, args, f, **kwargs_no_dynamo) - return True - - except Exception: - pass - - # Strategy 2: Try with different opset versions - original_opset = kwargs.get('opset_version', 18) - for opset_version in [17, 16, 15, 14, 13, 11]: - if opset_version != original_opset: - try: - kwargs_opset = kwargs.copy() - kwargs_opset['opset_version'] = opset_version - - warnings.warn( - f"Retrying ONNX export with opset_version={opset_version}", - UserWarning, stacklevel=3) - - torch.onnx.export(model, args, f, **kwargs_opset) - return True - - except Exception: - continue - - # Strategy 3: Try legacy export (non-dynamo with older opset) - try: - kwargs_legacy = kwargs.copy() - kwargs_legacy['dynamo'] = False - kwargs_legacy['opset_version'] = 11 - - warnings.warn( - "Retrying ONNX export with legacy settings " - "(dynamo=False, opset_version=11)", UserWarning, stacklevel=3) - - torch.onnx.export(model, args, f, **kwargs_legacy) - return True - - except Exception: - pass - - # Strategy 4: Try with minimal settings - try: - minimal_kwargs: dict[str, Any] = { - 'opset_version': 11, - 'dynamo': False, - } - # Add optional parameters if they exist - if kwargs.get('input_names') is not None: - minimal_kwargs['input_names'] = kwargs.get('input_names') - if kwargs.get('output_names') is not None: - minimal_kwargs['output_names'] = kwargs.get('output_names') - - warnings.warn( - "Retrying ONNX export with minimal settings as last resort", - UserWarning, stacklevel=3) - - torch.onnx.export(model, args, f, **minimal_kwargs) - return True - - except Exception: - pass - - # If all strategies fail, handle based on skip_on_error flag - import os - pytest_detected = 'PYTEST_CURRENT_TEST' in os.environ or 'pytest' in str(f) - - if skip_on_error: - # For CI environments: skip gracefully instead of failing - 
warnings.warn( - "ONNX export skipped due to known upstream issue " - "(onnx_ir.serde.SerdeError). " - "This is caused by a bug in the onnx_ir package where boolean " - "allowzero attributes cannot be serialized. All workarounds " - "failed. Consider updating packages: pip install --upgrade onnx " - "onnxscript " - "onnx_ir", UserWarning, stacklevel=3) - return False - - # For regular usage: provide detailed error message - error_msg = ( - "Failed to export model to ONNX due to known serialization issue. " - "This is caused by a bug in the onnx_ir package where boolean " - "allowzero attributes cannot be serialized. " - "Workarounds attempted: dynamo=False, multiple opset versions, " - "and legacy export. ") - - if pytest_detected: - error_msg += ( - "\n\nThis error commonly occurs in pytest environments. " - "Try one of these solutions:\n" - "1. Run the export outside of pytest (in a regular Python " - "script)\n" - "2. Update packages: pip install --upgrade onnx onnxscript " - "onnx_ir\n" - "3. Use torch.jit.script() instead of ONNX export for testing\n" - "4. Use safe_onnx_export(..., skip_on_error=True) to skip " - "gracefully in CI") - else: - error_msg += ("\n\nTry updating packages: pip install --upgrade onnx " - "onnxscript onnx_ir") - - raise RuntimeError(error_msg)
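
---

Reference sketch (illustrative only, not applied by this patch): with `safe_onnx_export` reverted, the export path exercised by the updated `test_basic_gnn.py` hunk above is the plain `torch.onnx.export` call. The snippet below is a hedged, standalone approximation of that path; it assumes torch>=2.6.0 with onnx, onnxruntime and onnxscript installed, and borrows the toy `SAGEConv` model from the removed `_onnx.py` docstring example rather than the exact model used in the test.

    import torch
    import onnx
    import onnxruntime as ort

    from torch_geometric.nn import SAGEConv


    class MyModel(torch.nn.Module):
        # Toy model mirroring the docstring example from the removed
        # torch_geometric/_onnx.py; any SAGEConv-based module works here.
        def __init__(self) -> None:
            super().__init__()
            self.conv = SAGEConv(8, 16)

        def forward(self, x: torch.Tensor,
                    edge_index: torch.Tensor) -> torch.Tensor:
            return self.conv(x, edge_index)


    model = MyModel()
    x = torch.randn(3, 8)
    edge_index = torch.tensor([[0, 1, 2], [1, 0, 2]])

    # Plain export, matching the arguments used in the updated test:
    torch.onnx.export(
        model,
        (x, edge_index),
        'model.onnx',
        input_names=('x', 'edge_index'),
        opset_version=18,
        dynamo=True,  # the non-dynamo exporter is deprecated by PyTorch
    )

    # Validate the exported graph and compare against eager execution.
    onnx.checker.check_model(onnx.load('model.onnx'))
    session = ort.InferenceSession('model.onnx',
                                   providers=['CPUExecutionProvider'])
    out = session.run(None, {
        'x': x.numpy(),
        'edge_index': edge_index.numpy(),
    })[0]
    assert out.shape == (3, 16)
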