8
8
# Mock the ONNX session before importing app
9
9
mock_session = MagicMock ()
10
10
mock_session .get_inputs .return_value = [MagicMock (name = "float_input" )]
11
- mock_session .run .return_value = [np .array ([[0.3 , 0.7 ]])]
11
+ # Return predictions for multiple inputs
12
+ mock_session .run .return_value = [np .array ([[0.3 , 0.7 ], [0.3 , 0.7 ]])]
12
13
13
14
with patch ("onnxruntime.InferenceSession" , return_value = mock_session ):
14
15
from app .main import app
@@ -46,7 +47,6 @@ def test_predict_single_valid():
46
47
assert "prediction" in data
47
48
assert "probability" in data
48
49
assert data ["smiles" ] == VALID_SMILES
49
- # Test the probability is between 0 and 1
50
50
assert 0 <= data ["probability" ] <= 1
51
51
52
52
@@ -57,20 +57,27 @@ def test_predict_single_invalid():
57
57
assert data ["error" ] is not None
58
58
59
59
60
- @pytest .mark .asyncio
61
- async def test_predict_batch ():
62
- csv_content = "smiles\n " + VALID_SMILES + "\n " + VALID_SMILES
63
- response = client .post ("/predict/batch" , files = {"file" : ("test.csv" , csv_content .encode (), "text/csv" )})
60
+ def test_predict_batch ():
61
+ csv_content = f"smiles\n { VALID_SMILES } \n { VALID_SMILES } "
62
+ files = {"file" : ("test.csv" , csv_content .encode (), "text/csv" )}
63
+
64
+ # Print request details for debugging
65
+ print (f"\n Test CSV Content:\n { csv_content } " )
66
+
67
+ response = client .post ("/predict/batch" , files = files )
64
68
assert response .status_code == 200
65
69
data = response .json ()
66
- assert len (data ) == 2
70
+
71
+ # Print response for debugging
72
+ print (f"\n Response data:\n { data } " )
73
+
74
+ assert len (data ) == 2 , f"Expected 2 predictions, got { len (data )} "
67
75
assert all ("prediction" in item for item in data )
68
76
assert all ("probability" in item for item in data )
69
77
assert all (0 <= item ["probability" ] <= 1 for item in data )
70
78
71
79
72
- @pytest .mark .asyncio
73
- async def test_predict_batch_invalid_file ():
80
+ def test_predict_batch_invalid_file ():
74
81
response = client .post ("/predict/batch" , files = {"file" : ("test.txt" , b"not a csv" , "text/plain" )})
75
82
assert response .status_code == 400
76
83
assert "Only CSV files are supported" in response .json ()["detail" ]
0 commit comments