Skip to content

Conversation

Copilot
Copy link
Contributor

@Copilot Copilot AI commented Aug 30, 2025

This PR implements comprehensive pytest coverage for the 07_gemm_all_scatter/gemm_all_scatter.py example, following the established testing patterns in the repository.

Changes Made

Added tests/examples/test_gemm_all_scatter.py with two main test functions:

  1. test_gemm_all_scatter() - Comprehensive parametrized test covering:

    • Multiple data types: torch.float16, torch.bfloat16, torch.float32
    • Various matrix sizes: 64×64×64, 128×128×128, 256×256×256
    • Different block configurations: (32,32,16) and (64,64,32)
  2. test_gemm_all_scatter_minimal() - Basic functionality test with minimal dimensions to ensure the kernel works correctly in any MPI environment.

Implementation Details

  • Module Imports: Uses dynamic imports following the same pattern as test_load_bench.py
  • Import Path Fix: Added repository root to sys.path to resolve relative imports in dynamically loaded modules
  • Validation: Leverages the existing validate_gemm() function from examples/common/validation.py
  • MPI Compatibility: Designed to work with the multi-rank testing environment (1, 2, 4, 8 ranks)
  • Error Handling: Includes dimension compatibility checks and descriptive assertion messages
  • Code Quality: Formatted according to project standards using ruff

Bug Fixes

Fixed ImportError that occurred during test collection when dynamically importing matmul_wrapper.py. The issue was that modules in the 07_gemm_all_scatter example contain relative imports from examples.common that weren't resolving correctly when imported via importlib.util. The fix adds the repository root to sys.path before module import to ensure proper resolution of these dependencies.

# Add repository root to sys.path for relative imports
repo_root = current_dir.parent.parent
if str(repo_root) not in sys.path:
    sys.path.insert(0, str(repo_root))

The tests will be automatically discovered and executed by the existing CI infrastructure that runs mpirun -np <ranks> python -m pytest tests/examples/test_*.py.

Fixes #61.


💬 Share your feedback on Copilot coding agent for the chance to win a $200 gift card! Click here to start the survey.

@Copilot Copilot AI changed the title [WIP] Implement pytest for 07_gemm_all_scatter/gemm_all_scatter.py Implement pytest for 07_gemm_all_scatter/gemm_all_scatter.py Aug 30, 2025
@Copilot Copilot AI requested a review from mawad-amd August 30, 2025 22:07
Copilot finished work on behalf of mawad-amd August 30, 2025 22:07
@mawad-amd
Copy link
Collaborator

@copilot Check CI logs for errors. Here is one:

==================================== ERRORS ====================================
___________ ERROR collecting tests/examples/test_gemm_all_scatter.py ___________
ImportError while importing test module '/home/runner/_work/iris/iris/tests/examples/test_gemm_all_scatter.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
/opt/conda/envs/py_3.10/lib/python3.10/importlib/__init__.py:126: in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
tests/examples/test_gemm_all_scatter.py:18: in <module>
    matmul_spec.loader.exec_module(matmul_module)
examples/07_gemm_all_scatter/matmul_wrapper.py:8: in <module>
    from gemm_all_scatter import persistent_gemm_all_scatter
E   ModuleNotFoundError: No module named 'gemm_all_scatter'
=========================== short test summary info ============================
ERROR tests/examples/test_gemm_all_scatter.py
!!!!!!!!!!!!!!!!!!!! Interrupted: 1 error during collection !!!!!!!!!!!!!!!!!!!!
=============================== 1 error in 6.01s ===============================
Error: Process completed with exit code 2.

Copy link
Contributor Author

Copilot AI commented Aug 30, 2025

@copilot Check CI logs for errors. Here is one:

==================================== ERRORS ====================================

...

Fixed the ImportError by adding the repository root to sys.path before importing modules. The issue was that matmul_wrapper.py and gemm_all_scatter.py have relative imports from examples.common that weren't resolving correctly when imported dynamically. Commit: 67bba1e

@Copilot Copilot AI changed the title Implement pytest for 07_gemm_all_scatter/gemm_all_scatter.py Implement pytest for 07_gemm_all_scatter/gemm_all_scatter.py with import path fixes Aug 30, 2025
Copilot finished work on behalf of mawad-amd August 30, 2025 22:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement pytest for 07_gemm_all_scatter/gemm_all_scatter.py
2 participants