Skip to content

Commit a9572c6

Browse files
committed
Implement type checker compatible Input() function
Convert Input from dataclass to function with overloads following Pydantic's pattern. Maintains exact same syntax for developers while providing full type checker compatibility.
1 parent 5a780b2 commit a9572c6

File tree

9 files changed

+1639
-8
lines changed

9 files changed

+1639
-8
lines changed

internal/tests/input_function_test.go

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
package tests
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
"net/http"
7+
"strings"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
13+
"github.com/replicate/cog-runtime/internal/server"
14+
)
15+
16+
func TestInputFunctionSchemaGeneration(t *testing.T) {
17+
t.Parallel()
18+
if *legacyCog {
19+
t.Skip("Input generation tests coglet specific implementations.")
20+
}
21+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
22+
procedureMode: false,
23+
explicitShutdown: false,
24+
uploadURL: "",
25+
module: "input_function",
26+
predictorClass: "Predictor",
27+
})
28+
29+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
30+
31+
resp, err := http.Get(runtimeServer.URL + "/openapi.json")
32+
require.NoError(t, err)
33+
defer resp.Body.Close()
34+
35+
body, err := io.ReadAll(resp.Body)
36+
require.NoError(t, err)
37+
38+
var schema map[string]any
39+
err = json.Unmarshal(body, &schema)
40+
require.NoError(t, err)
41+
42+
assert.Contains(t, schema, "components")
43+
44+
components := schema["components"].(map[string]any)
45+
assert.Contains(t, components, "schemas")
46+
47+
schemas := components["schemas"].(map[string]any)
48+
assert.Contains(t, schemas, "Input")
49+
50+
inputSchema := schemas["Input"].(map[string]any)
51+
assert.Equal(t, "object", inputSchema["type"])
52+
assert.Contains(t, inputSchema, "properties")
53+
assert.Contains(t, inputSchema, "required")
54+
55+
properties := inputSchema["properties"].(map[string]any)
56+
required := inputSchema["required"].([]any)
57+
58+
assert.Contains(t, properties, "message")
59+
assert.Contains(t, required, "message")
60+
messageField := properties["message"].(map[string]any)
61+
assert.Equal(t, "string", messageField["type"])
62+
assert.Equal(t, "Message to process", messageField["description"])
63+
64+
assert.Contains(t, properties, "repeat_count")
65+
assert.NotContains(t, required, "repeat_count")
66+
repeatField := properties["repeat_count"].(map[string]any)
67+
assert.Equal(t, "integer", repeatField["type"])
68+
assert.Equal(t, float64(1), repeatField["default"]) //nolint:testifylint // Checking absolute value not delta
69+
assert.Equal(t, float64(1), repeatField["minimum"]) //nolint:testifylint // Checking absolute value not delta
70+
assert.Equal(t, float64(10), repeatField["maximum"]) //nolint:testifylint // Checking absolute value not delta
71+
72+
assert.Contains(t, properties, "prefix")
73+
prefixField := properties["prefix"].(map[string]any)
74+
assert.Equal(t, "string", prefixField["type"])
75+
assert.Equal(t, "Result: ", prefixField["default"])
76+
assert.Equal(t, float64(1), prefixField["minLength"]) //nolint:testifylint // Checking absolute value not delta
77+
assert.Equal(t, float64(20), prefixField["maxLength"]) //nolint:testifylint // Checking absolute value not delta
78+
79+
assert.Contains(t, properties, "deprecated_option")
80+
deprecatedField := properties["deprecated_option"].(map[string]any)
81+
assert.Equal(t, true, deprecatedField["deprecated"])
82+
}
83+
84+
func TestInputFunctionBasicPrediction(t *testing.T) {
85+
t.Parallel()
86+
if *legacyCog {
87+
t.Skip("Input generation tests coglet specific implementations.")
88+
}
89+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
90+
procedureMode: false,
91+
explicitShutdown: false,
92+
uploadURL: "",
93+
module: "input_function",
94+
predictorClass: "Predictor",
95+
})
96+
97+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
98+
99+
input := map[string]any{"message": "hello world"}
100+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
101+
102+
resp, err := http.DefaultClient.Do(req)
103+
require.NoError(t, err)
104+
defer resp.Body.Close()
105+
assert.Equal(t, http.StatusOK, resp.StatusCode)
106+
107+
body, err := io.ReadAll(resp.Body)
108+
require.NoError(t, err)
109+
110+
var prediction server.PredictionResponse
111+
err = json.Unmarshal(body, &prediction)
112+
require.NoError(t, err)
113+
114+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
115+
assert.Equal(t, "Result: hello world", prediction.Output)
116+
}
117+
118+
func TestInputFunctionComplexPrediction(t *testing.T) {
119+
t.Parallel()
120+
if *legacyCog {
121+
t.Skip("Input generation tests coglet specific implementations.")
122+
}
123+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
124+
procedureMode: false,
125+
explicitShutdown: false,
126+
uploadURL: "",
127+
module: "input_function",
128+
predictorClass: "Predictor",
129+
})
130+
131+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
132+
133+
input := map[string]any{
134+
"message": "test message",
135+
"repeat_count": 2,
136+
"format_type": "uppercase",
137+
"prefix": "Output: ",
138+
"suffix": " [END]",
139+
"deprecated_option": "custom",
140+
}
141+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
142+
143+
resp, err := http.DefaultClient.Do(req)
144+
require.NoError(t, err)
145+
defer resp.Body.Close()
146+
assert.Equal(t, http.StatusOK, resp.StatusCode)
147+
148+
body, err := io.ReadAll(resp.Body)
149+
require.NoError(t, err)
150+
151+
var prediction server.PredictionResponse
152+
err = json.Unmarshal(body, &prediction)
153+
require.NoError(t, err)
154+
155+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
156+
assert.Equal(t, "Output: TEST MESSAGE TEST MESSAGE [END]", prediction.Output)
157+
}
158+
159+
func TestInputFunctionConstraintViolations(t *testing.T) {
160+
t.Parallel()
161+
if *legacyCog {
162+
t.Skip("Input generation tests coglet specific implementations.")
163+
}
164+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
165+
procedureMode: false,
166+
explicitShutdown: false,
167+
uploadURL: "",
168+
module: "input_function",
169+
predictorClass: "Predictor",
170+
})
171+
172+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
173+
174+
testCases := []struct {
175+
name string
176+
input map[string]any
177+
errorMsg string
178+
}{
179+
{
180+
name: "repeat_count too low",
181+
input: map[string]any{"message": "test", "repeat_count": 0},
182+
errorMsg: "fails constraint >= 1",
183+
},
184+
{
185+
name: "repeat_count too high",
186+
input: map[string]any{"message": "test", "repeat_count": 11},
187+
errorMsg: "fails constraint <= 10",
188+
},
189+
{
190+
name: "invalid format_type choice",
191+
input: map[string]any{"message": "test", "format_type": "invalid"},
192+
errorMsg: "does not match choices",
193+
},
194+
{
195+
name: "prefix too short",
196+
input: map[string]any{"message": "test", "prefix": ""},
197+
errorMsg: "fails constraint len() >= 1",
198+
},
199+
{
200+
name: "prefix too long",
201+
input: map[string]any{"message": "test", "prefix": strings.Repeat("x", 21)},
202+
errorMsg: "fails constraint len() <= 20",
203+
},
204+
}
205+
206+
for _, tc := range testCases {
207+
t.Run(tc.name, func(t *testing.T) {
208+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: tc.input})
209+
210+
resp, err := http.DefaultClient.Do(req)
211+
require.NoError(t, err)
212+
defer resp.Body.Close()
213+
214+
body, err := io.ReadAll(resp.Body)
215+
require.NoError(t, err)
216+
217+
var errorResp server.PredictionResponse
218+
err = json.Unmarshal(body, &errorResp)
219+
require.NoError(t, err)
220+
221+
assert.Equal(t, server.PredictionFailed, errorResp.Status)
222+
assert.Contains(t, errorResp.Error, tc.errorMsg)
223+
})
224+
}
225+
}
226+
227+
func TestInputFunctionMissingRequired(t *testing.T) {
228+
t.Parallel()
229+
if *legacyCog {
230+
t.Skip("Input generation tests coglet specific implementations.")
231+
}
232+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
233+
procedureMode: false,
234+
explicitShutdown: false,
235+
uploadURL: "",
236+
module: "input_function",
237+
predictorClass: "Predictor",
238+
})
239+
240+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
241+
242+
input := map[string]any{"repeat_count": 2}
243+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
244+
245+
resp, err := http.DefaultClient.Do(req)
246+
require.NoError(t, err)
247+
defer resp.Body.Close()
248+
249+
body, err := io.ReadAll(resp.Body)
250+
require.NoError(t, err)
251+
252+
var errorResp server.PredictionResponse
253+
err = json.Unmarshal(body, &errorResp)
254+
require.NoError(t, err)
255+
256+
assert.Equal(t, server.PredictionFailed, errorResp.Status)
257+
assert.Contains(t, errorResp.Error, "missing required input field: message")
258+
}
259+
260+
func TestInputFunctionSimple(t *testing.T) {
261+
t.Parallel()
262+
if *legacyCog {
263+
t.Skip("Input generation tests coglet specific implementations.")
264+
}
265+
runtimeServer := setupCogRuntime(t, cogRuntimeServerConfig{
266+
procedureMode: false,
267+
explicitShutdown: false,
268+
uploadURL: "",
269+
module: "input_simple",
270+
predictorClass: "Predictor",
271+
})
272+
273+
waitForSetupComplete(t, runtimeServer, server.StatusReady, server.SetupSucceeded)
274+
275+
input := map[string]any{"message": "hello", "count": 3}
276+
req := httpPredictionRequest(t, runtimeServer, server.PredictionRequest{Input: input})
277+
278+
resp, err := http.DefaultClient.Do(req)
279+
require.NoError(t, err)
280+
defer resp.Body.Close()
281+
assert.Equal(t, http.StatusOK, resp.StatusCode)
282+
283+
body, err := io.ReadAll(resp.Body)
284+
require.NoError(t, err)
285+
286+
var prediction server.PredictionResponse
287+
err = json.Unmarshal(body, &prediction)
288+
require.NoError(t, err)
289+
290+
assert.Equal(t, server.PredictionSucceeded, prediction.Status)
291+
assert.Equal(t, "hellohellohello", prediction.Output)
292+
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ classifiers = [
1111
'Programming Language :: Python :: 3.12',
1212
'Programming Language :: Python :: 3.13',
1313
]
14-
dependencies = []
14+
dependencies = ["typing_extensions>=4.15"]
1515

1616
[project.optional-dependencies]
1717
dev = [

0 commit comments

Comments
 (0)