Skip to content

Commit 294eced

Browse files
committed
Make mock session return predictions for multiple inputs
1 parent 28e00a0 commit 294eced

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

backend/tests/test_api.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
# Mock the ONNX session before importing app
99
mock_session = MagicMock()
1010
mock_session.get_inputs.return_value = [MagicMock(name="float_input")]
11-
mock_session.run.return_value = [np.array([[0.3, 0.7]])]
11+
# Return predictions for multiple inputs
12+
mock_session.run.return_value = [np.array([[0.3, 0.7], [0.3, 0.7]])]
1213

1314
with patch("onnxruntime.InferenceSession", return_value=mock_session):
1415
from app.main import app
@@ -46,7 +47,6 @@ def test_predict_single_valid():
4647
assert "prediction" in data
4748
assert "probability" in data
4849
assert data["smiles"] == VALID_SMILES
49-
# Test the probability is between 0 and 1
5050
assert 0 <= data["probability"] <= 1
5151

5252

@@ -57,20 +57,27 @@ def test_predict_single_invalid():
5757
assert data["error"] is not None
5858

5959

60-
@pytest.mark.asyncio
61-
async def test_predict_batch():
62-
csv_content = "smiles\n" + VALID_SMILES + "\n" + VALID_SMILES
63-
response = client.post("/predict/batch", files={"file": ("test.csv", csv_content.encode(), "text/csv")})
60+
def test_predict_batch():
61+
csv_content = f"smiles\n{VALID_SMILES}\n{VALID_SMILES}"
62+
files = {"file": ("test.csv", csv_content.encode(), "text/csv")}
63+
64+
# Print request details for debugging
65+
print(f"\nTest CSV Content:\n{csv_content}")
66+
67+
response = client.post("/predict/batch", files=files)
6468
assert response.status_code == 200
6569
data = response.json()
66-
assert len(data) == 2
70+
71+
# Print response for debugging
72+
print(f"\nResponse data:\n{data}")
73+
74+
assert len(data) == 2, f"Expected 2 predictions, got {len(data)}"
6775
assert all("prediction" in item for item in data)
6876
assert all("probability" in item for item in data)
6977
assert all(0 <= item["probability"] <= 1 for item in data)
7078

7179

72-
@pytest.mark.asyncio
73-
async def test_predict_batch_invalid_file():
80+
def test_predict_batch_invalid_file():
7481
response = client.post("/predict/batch", files={"file": ("test.txt", b"not a csv", "text/plain")})
7582
assert response.status_code == 400
7683
assert "Only CSV files are supported" in response.json()["detail"]

0 commit comments

Comments
 (0)