Skip to content

Commit aaf3cc9

Browse files
committed
Implement type checker compatible Input() function
Convert Input from dataclass to function with overloads following Pydantic's pattern. Maintains exact same syntax for developers while providing full type checker compatibility. Key changes: - Add Input() function with @overload decorators returning Any for type checkers - Create FieldInfo dataclass to store field metadata at runtime - Add api.pyi stub with ParamSpec for BasePredictor signature flexibility - Update inspector.py to use FieldInfo instead of Input dataclass - Add comprehensive test suite (31 unit/functional tests + Go integration) - Fix BaseModel inheritance tests with proper type annotations and getattr usage
1 parent 5a780b2 commit aaf3cc9

File tree

10 files changed

+1753
-34
lines changed

10 files changed

+1753
-34
lines changed

internal/tests/input_function_test.go

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,274 @@
1+
package tests
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
"net/http"
7+
"strings"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/replicate/cog-runtime/internal/server"
14+
)
15+
16+
func TestInputFunctionSchemaGeneration(t *testing.T) {
17+
t.Parallel()
18+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
19+
procedureMode: false,
20+
explicitShutdown: false,
21+
uploadURL: "",
22+
module: "input_function_test",
23+
predictorClass: "Predictor",
24+
})
25+
26+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
27+
28+
resp, err := http.Get(runtimeServer.URL + "/openapi.json")
29+
require.NoError(t, err)
30+
defer resp.Body.Close()
31+
32+
body, err := io.ReadAll(resp.Body)
33+
require.NoError(t, err)
34+
35+
var schema map[string]any
36+
err = json.Unmarshal(body, &schema)
37+
require.NoError(t, err)
38+
39+
assert.Contains(t, schema, "components")
40+
41+
components := schema["components"].(map[string]any)
42+
assert.Contains(t, components, "schemas")
43+
44+
schemas := components["schemas"].(map[string]any)
45+
assert.Contains(t, schemas, "Input")
46+
47+
inputSchema := schemas["Input"].(map[string]any)
48+
assert.Equal(t, "object", inputSchema["type"])
49+
assert.Contains(t, inputSchema, "properties")
50+
assert.Contains(t, inputSchema, "required")
51+
52+
properties := inputSchema["properties"].(map[string]any)
53+
required := inputSchema["required"].([]any)
54+
55+
assert.Contains(t, properties, "message")
56+
assert.Contains(t, required, "message")
57+
messageField := properties["message"].(map[string]any)
58+
assert.Equal(t, "string", messageField["type"])
59+
assert.Equal(t, "Message to process", messageField["description"])
60+
61+
assert.Contains(t, properties, "repeat_count")
62+
assert.NotContains(t, required, "repeat_count")
63+
repeatField := properties["repeat_count"].(map[string]any)
64+
assert.Equal(t, "integer", repeatField["type"])
65+
assert.Equal(t, float64(1), repeatField["default"]) //nolint:testifylint // Checking absolute value not delta
66+
assert.Equal(t, float64(1), repeatField["minimum"]) //nolint:testifylint // Checking absolute value not delta
67+
assert.Equal(t, float64(10), repeatField["maximum"]) //nolint:testifylint // Checking absolute value not delta
68+
69+
assert.Contains(t, properties, "prefix")
70+
prefixField := properties["prefix"].(map[string]any)
71+
assert.Equal(t, "string", prefixField["type"])
72+
assert.Equal(t, "Result: ", prefixField["default"])
73+
assert.Equal(t, float64(1), prefixField["minLength"]) //nolint:testifylint // Checking absolute value not delta
74+
assert.Equal(t, float64(20), prefixField["maxLength"]) //nolint:testifylint // Checking absolute value not delta
75+
76+
assert.Contains(t, properties, "deprecated_option")
77+
deprecatedField := properties["deprecated_option"].(map[string]any)
78+
assert.Equal(t, true, deprecatedField["deprecated"])
79+
}
80+
81+
func TestInputFunctionBasicPrediction(t *testing.T) {
82+
t.Parallel()
83+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
84+
procedureMode: false,
85+
explicitShutdown: false,
86+
uploadURL: "",
87+
module: "input_function_test",
88+
predictorClass: "Predictor",
89+
})
90+
91+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
92+
93+
input := map[string]any{"message": "hello world"}
94+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
95+
96+
resp, err := http.DefaultClient.Do(req)
97+
require.NoError(t, err)
98+
defer resp.Body.Close()
99+
assert.Equal(t, http.StatusOK, resp.StatusCode)
100+
101+
body, err := io.ReadAll(resp.Body)
102+
require.NoError(t, err)
103+
104+
var prediction server.PredictionResponse
105+
err = json.Unmarshal(body, &prediction)
106+
require.NoError(t, err)
107+
108+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
109+
assert.Equal(t, "Result: hello world", prediction.Output)
110+
}
111+
112+
func TestInputFunctionComplexPrediction(t *testing.T) {
113+
t.Parallel()
114+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
115+
procedureMode: false,
116+
explicitShutdown: false,
117+
uploadURL: "",
118+
module: "input_function_test",
119+
predictorClass: "Predictor",
120+
})
121+
122+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
123+
124+
input := map[string]any{
125+
"message": "test message",
126+
"repeat_count": 2,
127+
"format_type": "uppercase",
128+
"prefix": "Output: ",
129+
"suffix": " [END]",
130+
"deprecated_option": "custom",
131+
}
132+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
133+
134+
resp, err := http.DefaultClient.Do(req)
135+
require.NoError(t, err)
136+
defer resp.Body.Close()
137+
assert.Equal(t, http.StatusOK, resp.StatusCode)
138+
139+
body, err := io.ReadAll(resp.Body)
140+
require.NoError(t, err)
141+
142+
var prediction server.PredictionResponse
143+
err = json.Unmarshal(body, &prediction)
144+
require.NoError(t, err)
145+
146+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
147+
assert.Equal(t, "Output: TEST MESSAGE TEST MESSAGE [END]", prediction.Output)
148+
}
149+
150+
func TestInputFunctionConstraintViolations(t *testing.T) {
151+
t.Parallel()
152+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
153+
procedureMode: false,
154+
explicitShutdown: false,
155+
uploadURL: "",
156+
module: "input_function_test",
157+
predictorClass: "Predictor",
158+
})
159+
160+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
161+
162+
testCases := []struct {
163+
name string
164+
input map[string]any
165+
errorMsg string
166+
}{
167+
{
168+
name: "repeat_count too low",
169+
input: map[string]any{"message": "test", "repeat_count": 0},
170+
errorMsg: "fails constraint >= 1",
171+
},
172+
{
173+
name: "repeat_count too high",
174+
input: map[string]any{"message": "test", "repeat_count": 11},
175+
errorMsg: "fails constraint <= 10",
176+
},
177+
{
178+
name: "invalid format_type choice",
179+
input: map[string]any{"message": "test", "format_type": "invalid"},
180+
errorMsg: "does not match choices",
181+
},
182+
{
183+
name: "prefix too short",
184+
input: map[string]any{"message": "test", "prefix": ""},
185+
errorMsg: "fails constraint len() >= 1",
186+
},
187+
{
188+
name: "prefix too long",
189+
input: map[string]any{"message": "test", "prefix": strings.Repeat("x", 21)},
190+
errorMsg: "fails constraint len() <= 20",
191+
},
192+
}
193+
194+
for _, tc := range testCases {
195+
t.Run(tc.name, func(t *testing.T) {
196+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: tc.input})
197+
198+
resp, err := http.DefaultClient.Do(req)
199+
require.NoError(t, err)
200+
defer resp.Body.Close()
201+
202+
body, err := io.ReadAll(resp.Body)
203+
require.NoError(t, err)
204+
205+
var errorResp server.PredictionResponse
206+
err = json.Unmarshal(body, &errorResp)
207+
require.NoError(t, err)
208+
209+
assert.Equal(t, server.PredictionFailed, errorResp.Status)
210+
assert.Contains(t, errorResp.Error, tc.errorMsg)
211+
})
212+
}
213+
}
214+
215+
func TestInputFunctionMissingRequired(t *testing.T) {
216+
t.Parallel()
217+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
218+
procedureMode: false,
219+
explicitShutdown: false,
220+
uploadURL: "",
221+
module: "input_function_test",
222+
predictorClass: "Predictor",
223+
})
224+
225+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
226+
227+
input := map[string]any{"repeat_count": 2}
228+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
229+
230+
resp, err := http.DefaultClient.Do(req)
231+
require.NoError(t, err)
232+
defer resp.Body.Close()
233+
234+
body, err := io.ReadAll(resp.Body)
235+
require.NoError(t, err)
236+
237+
var errorResp server.PredictionResponse
238+
err = json.Unmarshal(body, &errorResp)
239+
require.NoError(t, err)
240+
241+
assert.Equal(t, server.PredictionFailed, errorResp.Status)
242+
assert.Contains(t, errorResp.Error, "missing required input field: message")
243+
}
244+
245+
func TestInputFunctionSimple(t *testing.T) {
246+
t.Parallel()
247+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
248+
procedureMode: false,
249+
explicitShutdown: false,
250+
uploadURL: "",
251+
module: "input_simple_test",
252+
predictorClass: "Predictor",
253+
})
254+
255+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
256+
257+
input := map[string]any{"message": "hello", "count": 3}
258+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
259+
260+
resp, err := http.DefaultClient.Do(req)
261+
require.NoError(t, err)
262+
defer resp.Body.Close()
263+
assert.Equal(t, http.StatusOK, resp.StatusCode)
264+
265+
body, err := io.ReadAll(resp.Body)
266+
require.NoError(t, err)
267+
268+
var prediction server.PredictionResponse
269+
err = json.Unmarshal(body, &prediction)
270+
require.NoError(t, err)
271+
272+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
273+
assert.Equal(t, "hellohellohello", prediction.Output)
274+
}

0 commit comments

Comments
 (0)