9
9
import json
10
10
import functools
11
11
import logging
12
+ import re
12
13
from functools import wraps
13
14
from abc import ABC
14
15
from typing import List , Dict , Any , Union , Type , Literal , Sequence , Tuple , get_args
19
20
from .models import SimpleInputModel , SimpleToolResponseModel
20
21
from .schema import NoTitleDescriptionJsonSchema
21
22
from .errors import SimpleToolError , ValidationError
23
+ import sys
24
+ from pathlib import Path
22
25
23
26
24
27
def get_valid_content_types () -> Tuple [Type , ...]:
@@ -59,8 +62,17 @@ async def wrapper(*args, **kwargs):
59
62
60
63
class SimpleTool (ABC ):
61
64
"""Base class for all simple tools. """
62
- name : str = Field (..., description = "Name of the tool" )
63
- description : str = Field ("This tool does not have a description" , description = "Description of the tool" )
65
+ name : str = Field (
66
+ ...,
67
+ description = "Name of the tool" ,
68
+ pattern = "^[a-zA-Z0-9_-]+$" ,
69
+ max_length = 64
70
+ )
71
+ description : str = Field (
72
+ ...,
73
+ description = "Description of the tool's functionality" ,
74
+ max_length = 1024
75
+ )
64
76
input_model : ClassVar [Type [SimpleInputModel ]] # Class-level input model
65
77
66
78
# Add default timeout configuration
@@ -70,6 +82,24 @@ def __init__(self):
70
82
"""
71
83
Initialize SimpleTool.
72
84
"""
85
+ # Validate name and description
86
+ if not hasattr (self , 'name' ) or not isinstance (self .name , str ):
87
+ raise ValidationError ("name" , "Tool must have a name attribute of type str" )
88
+
89
+ if not hasattr (self , 'description' ) or not isinstance (self .description , str ):
90
+ raise ValidationError ("description" , "Tool must have a description attribute of type str" )
91
+
92
+ # Validate name pattern and length
93
+
94
+ if not re .match ("^[a-zA-Z0-9_-]+$" , self .name ):
95
+ raise ValidationError ("name" , "Tool name must contain only alphanumeric characters, underscores, and hyphens" )
96
+ if len (self .name ) > 64 :
97
+ raise ValidationError ("name" , "Tool name cannot exceed 64 characters" )
98
+
99
+ # Validate description length
100
+ if len (self .description ) > 1024 :
101
+ raise ValidationError ("description" , "Tool description cannot exceed 1024 characters" )
102
+
73
103
# Validate input_model is defined at the class level
74
104
if not hasattr (self .__class__ , 'input_model' ) or not issubclass (self .__class__ .input_model , SimpleInputModel ):
75
105
raise ValidationError ("input_model" , f"Subclass { self .__class__ .__name__ } must define a class-level 'input_model' as a subclass of SimpleInputModel" )
@@ -196,8 +226,8 @@ def __str__(self) -> str:
196
226
"""Return a one-line JSON string representation of the tool."""
197
227
sorted_input_schema = self ._sort_input_schema (self .input_schema )
198
228
return json .dumps ({
199
- "name" : self .name ,
200
- "description" : self .description ,
229
+ "name" : str ( self .name ) ,
230
+ "description" : str ( self .description ) ,
201
231
"input_schema" : sorted_input_schema
202
232
}).encode ("utf-8" ).decode ("unicode_escape" )
203
233
@@ -434,3 +464,29 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
434
464
435
465
# Propagate any original exceptions
436
466
return False
467
+
468
+ def __reduce__ (self ):
469
+ """Make SimpleTool picklable by only serializing essential attributes."""
470
+ # Get module name from the module path
471
+ module = sys .modules .get (self .__class__ .__module__ )
472
+ module_file = getattr (module , '__file__' , None ) if module else None
473
+ if module_file and isinstance (module_file , str ):
474
+ # Use the actual module file name without .py extension
475
+ module_name = Path (module_file ).stem
476
+ self .__class__ .__module__ = module_name
477
+ else :
478
+ # Fallback to tool name only if we really have to
479
+ self .__class__ .__module__ = self .name
480
+
481
+ return (self .__class__ , (), {
482
+ 'name' : self .name ,
483
+ 'description' : self .description ,
484
+ 'input_schema' : getattr (self , 'input_schema' , None ),
485
+ 'output_schema' : getattr (self , 'output_schema' , None ),
486
+ '_timeout' : getattr (self , '_timeout' , self .DEFAULT_TIMEOUT )
487
+ })
488
+
489
+ def __setstate__ (self , state ):
490
+ """Restore state after unpickling."""
491
+ for key , value in state .items ():
492
+ setattr (self , key , value )
0 commit comments