Skip to content
This repository was archived by the owner on May 6, 2025. It is now read-only.

Commit 5bb274c

Browse files
committed
Bring back the default 10-cases per test setting for NT tests.
PiperOrigin-RevId: 440451840
1 parent 25903a6 commit 5bb274c

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ We happily welcome contributions!
3535

3636

3737

38+
3839
## Contents
3940
* [Colab Notebooks](#colab-notebooks)
4041
* [Installation](#installation)

tests/test_utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,34 @@
1414

1515
"""Utilities for testing."""
1616

17-
from typing import Any, Sequence, Dict
1817
import dataclasses
1918
import logging
19+
import os
20+
from typing import Dict, Sequence
2021

2122
from absl import flags
2223
from absl.testing import parameterized
23-
2424
import jax
2525
from jax import config
26+
from jax import dtypes as _dtypes
2627
from jax import jit
2728
from jax import vmap
2829
import jax.numpy as np
2930
import numpy as onp
30-
from jax import dtypes as _dtypes
3131

32-
flags.DEFINE_integer('num_generated_cases', 10000,
33-
'The maximum number of test cases in combinatorial tests.')
34-
flags.DEFINE_string('jax_test_dut', None,
35-
'')
32+
33+
flags.DEFINE_string(
34+
'jax_test_dut',
35+
'',
36+
help=
37+
'Describes the device under test in case special consideration is required.'
38+
)
39+
40+
flags.DEFINE_integer(
41+
'num_generated_cases',
42+
int(os.getenv('JAX_NUM_GENERATED_CASES', '10')),
43+
help='Number of generated cases to test'
44+
)
3645

3746
FLAGS = flags.FLAGS
3847

0 commit comments

Comments
 (0)