Skip to content

Commit 12ff796

Browse files
authored
feat: add a signature field to closures containing the function signature (#948)
Signed-off-by: Louis Mandel <[email protected]>
1 parent 497aa1b commit 12ff796

9 files changed

+481
-245
lines changed

pdl-live-react/src/pdl_ast.d.ts

Lines changed: 378 additions & 49 deletions
Large diffs are not rendered by default.

src/pdl/pdl-schema.json

Lines changed: 43 additions & 186 deletions
Large diffs are not rendered by default.

src/pdl/pdl_ast.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
BeforeValidator,
2020
ConfigDict,
2121
Field,
22+
Json,
2223
RootModel,
2324
TypeAdapter,
2425
)
@@ -236,15 +237,16 @@ class JsonSchemaTypePdlType(PdlType):
236237

237238

238239
class ObjPdlType(PdlType):
239-
"""Optional type."""
240+
"""Object type."""
240241

241242
obj: Optional[dict[str, "PdlTypeType"]]
242243

243244

244245
PdlTypeType = TypeAliasType(
245246
"PdlTypeType",
246247
Annotated[
247-
"Union[BasePdlType," # pyright: ignore
248+
"Union[None," # pyright: ignore
249+
" BasePdlType,"
248250
" EnumPdlType,"
249251
" StrPdlType,"
250252
" FloatPdlType,"
@@ -269,7 +271,7 @@ class Parser(BaseModel):
269271
description: Optional[str] = None
270272
"""Documentation associated to the parser.
271273
"""
272-
spec: Optional[PdlTypeType] = None
274+
spec: PdlTypeType = None
273275
"""Expected type of the parsed value.
274276
"""
275277

@@ -348,7 +350,7 @@ class Block(BaseModel):
348350
description: Optional[str] = None
349351
"""Documentation associated to the block.
350352
"""
351-
spec: Optional[PdlTypeType] = None
353+
spec: PdlTypeType = None
352354
"""Type specification of the result of the block.
353355
"""
354356
defs: dict[str, "BlockType"] = {}
@@ -416,8 +418,12 @@ class FunctionBlock(LeafBlock):
416418
"""Functions parameters with their types.
417419
"""
418420
returns: "BlockType" = Field(..., alias="return")
419-
"""Body of the function
421+
"""Body of the function.
422+
"""
423+
signature: Optional[Json] = None
424+
"""Function signature computed from the function definition.
420425
"""
426+
421427
# Field for internal use
422428
pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(default=None, repr=False)
423429

src/pdl/pdl_dumper.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,10 @@ def expr_to_dict(expr: ExpressionType, json_compatible: bool):
317317

318318

319319
def type_to_dict(t: PdlTypeType):
320-
d: str | list | dict
320+
d: None | str | list | dict
321321
match t:
322+
case None:
323+
d = None
322324
case "null" | "bool" | "str" | "float" | "int" | "list" | "obj":
323325
d = t
324326
case EnumPdlType():

src/pdl/pdl_interpreter.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402
101101
from .pdl_python_repl import PythonREPL # noqa: E402
102102
from .pdl_scheduler import yield_background, yield_result # noqa: E402
103+
from .pdl_schema_utils import get_json_schema # noqa: E402
103104
from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402
104105
from .pdl_utils import ( # noqa: E402
105106
GeneratorWrapper,
@@ -894,6 +895,16 @@ def process_block_body(
894895
if block.def_ is not None:
895896
scope = scope | {block.def_: closure}
896897
closure.pdl__scope = scope
898+
signature: dict[str, Any] = {"type": "function"}
899+
if block.def_ is not None:
900+
signature["name"] = block.def_
901+
if block.description is not None:
902+
signature["description"] = block.description
903+
if block.function is not None:
904+
signature["parameters"] = get_json_schema(block.function, False) or {}
905+
else:
906+
signature["parameters"] = {}
907+
closure.signature = signature
897908
result = PdlConst(closure)
898909
background = PdlList([])
899910
trace = closure.model_copy(update={})
@@ -976,6 +987,8 @@ def process_defs(
976987
state = state.with_iter(idx)
977988
state = state.with_yield_result(False)
978989
state = state.with_yield_background(False)
990+
if isinstance(block, FunctionBlock) and block.def_ is None:
991+
block = block.model_copy(update={"def_": x})
979992
result, _, _, block_trace = process_block(state, scope, block, newloc)
980993
scope = scope | PdlDict({x: result})
981994
defs_trace[x] = block_trace

src/pdl/pdl_llms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def generate_text_stream(
183183

184184

185185
def set_structured_decoding_parameters(
186-
spec: Optional[PdlTypeType],
186+
spec: PdlTypeType,
187187
parameters: Optional[dict[str, Any]],
188188
) -> dict[str, Any]:
189189
if parameters is None:

src/pdl/pdl_schema_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def convert_to_json_type(a_type):
4444

4545

4646
def pdltype_to_jsonschema(
47-
pdl_type: Optional[PdlTypeType], additional_properties: bool
47+
pdl_type: PdlTypeType, additional_properties: bool
4848
) -> dict[str, Any]:
4949
schema: dict[str, Any]
5050
match pdl_type:

src/pdl/pdl_schema_validator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99

1010
def type_check_args(
11-
args: Optional[dict[str, Any]], params: Optional[dict[str, Any]], loc
11+
args: Optional[dict[str, Any]],
12+
params: Optional[dict[str, PdlTypeType]],
13+
loc,
1214
) -> list[str]:
1315
if (args == {} or args is None) and (params is None or params == {}):
1416
return []
@@ -35,7 +37,7 @@ def type_check_args(
3537
return type_check(args_copy, schema, loc)
3638

3739

38-
def type_check_spec(result: Any, spec: Optional[PdlTypeType], loc) -> list[str]:
40+
def type_check_spec(result: Any, spec: PdlTypeType, loc) -> list[str]:
3941
schema = pdltype_to_jsonschema(spec, False)
4042
if schema is None:
4143
return ["Error obtaining a valid schema from spec"]

tests/test_function.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,17 @@ def test_function_call():
2020
assert text == "Hello world!"
2121

2222

23+
def test_hello_signature():
24+
result = exec_dict(hello_def, output="all")
25+
closure = result["scope"]["hello"]
26+
assert closure.signature == {
27+
"name": hello_def["def"],
28+
"description": hello_def["description"],
29+
"type": "function",
30+
"parameters": {},
31+
}
32+
33+
2334
hello_params = {
2435
"description": "Call hello",
2536
"text": [
@@ -39,6 +50,22 @@ def test_function_params():
3950
assert text == "Hello World!"
4051

4152

53+
def test_hello_params_signature():
54+
result = exec_dict(hello_params, output="all")
55+
closure = result["scope"]["hello"]
56+
assert closure.signature == {
57+
"name": hello_params["text"][0]["def"],
58+
"description": hello_params["text"][0]["description"],
59+
"type": "function",
60+
"parameters": {
61+
"type": "object",
62+
"properties": {"name": {"type": "string"}},
63+
"required": ["name"],
64+
"additionalProperties": False,
65+
},
66+
}
67+
68+
4269
hello_stutter = {
4370
"description": "Repeat the context",
4471
"text": [

0 commit comments

Comments
 (0)