Skip to content

Commit a8a2d45

Browse files
committed
[ENH]: Change attach_function API to be clearer
1 parent dae1739 commit a8a2d45

File tree

5 files changed

+44
-35
lines changed

5 files changed

+44
-35
lines changed

chromadb/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@
5353
# Reciprocal Rank Fusion for combining rankings
5454
Rrf,
5555
)
56+
57+
# Import attachable functions
58+
from chromadb.api.functions import (
59+
Function,
60+
StatisticsFunction,
61+
RecordCounterFunction,
62+
)
5663
from pathlib import Path
5764
import os
5865

@@ -97,6 +104,10 @@
97104
"IntInvertedIndexConfig",
98105
"FloatInvertedIndexConfig",
99106
"BoolInvertedIndexConfig",
107+
# Attachable Functions
108+
"Function",
109+
"StatisticsFunction",
110+
"RecordCounterFunction",
100111
]
101112

102113
from chromadb.types import CloudClientArg

chromadb/api/models/Collection.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Optional, Union, List, cast, Dict, Any
1+
from typing import TYPE_CHECKING, Optional, Union, List, cast
22

33
from chromadb.api.models.CollectionCommon import CollectionCommon
44
from chromadb.api.types import (
@@ -27,6 +27,7 @@
2727

2828
if TYPE_CHECKING:
2929
from chromadb.api.models.AttachedFunction import AttachedFunction
30+
from chromadb.api.functions import Function
3031

3132
logger = logging.getLogger(__name__)
3233

@@ -500,36 +501,34 @@ def delete(
500501

501502
def attach_function(
502503
self,
503-
function_id: str,
504+
function: "Function",
504505
name: str,
505506
output_collection: str,
506-
params: Optional[Dict[str, Any]] = None,
507507
) -> "AttachedFunction":
508508
"""Attach a function to this collection.
509509
510510
Args:
511-
function_id: Built-in function identifier (e.g., "record_counter")
511+
function: The function to attach (e.g. StatisticsFunction(), RecordCounterFunction())
512512
name: Unique name for this attached function
513513
output_collection: Name of the collection where function output will be stored
514-
params: Optional dictionary with function-specific parameters
515514
516515
Returns:
517516
AttachedFunction: Object representing the attached function
518517
519518
Example:
519+
>>> from chromadb.api.functions import StatisticsFunction
520520
>>> attached_fn = collection.attach_function(
521-
... function_id="record_counter",
521+
... function=StatisticsFunction(),
522522
... name="mycoll_stats_fn",
523-
... output_collection="mycoll_stats",
524-
... params={"threshold": 100}
523+
... output_collection="mycoll_stats"
525524
... )
526525
"""
527526
return self._client.attach_function(
528-
function_id=function_id,
527+
function_id=function.name,
529528
name=name,
530529
input_collection_id=self.id,
531530
output_collection=output_collection,
532-
params=params,
531+
params=function.params,
533532
tenant=self.tenant,
534533
database=self.database,
535534
)

chromadb/test/distributed/test_task_api.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import pytest
99
from chromadb.api.client import Client as ClientCreator
10+
from chromadb.api.functions import Function, RecordCounterFunction
1011
from chromadb.config import System
1112
from chromadb.errors import ChromaError, NotFoundError
1213
from chromadb.test.utils.wait_for_version_increase import (
@@ -29,10 +30,9 @@ def test_count_function_attach_and_detach(basic_http_client: System) -> None:
2930

3031
# Create a task that counts records in the collection
3132
attached_fn = collection.attach_function(
33+
function=RecordCounterFunction(),
3234
name="count_my_docs",
33-
function_id="record_counter", # Built-in operator that counts records
3435
output_collection="my_documents_counts",
35-
params=None,
3636
)
3737

3838
# Verify task creation succeeded
@@ -66,6 +66,14 @@ def test_count_function_attach_and_detach(basic_http_client: System) -> None:
6666
assert success is True
6767

6868

69+
class InvalidFunction(Function):
70+
"""A function with an invalid name for testing error handling."""
71+
72+
@property
73+
def name(self) -> str:
74+
return "nonexistent_function"
75+
76+
6977
def test_task_with_invalid_function(basic_http_client: System) -> None:
7078
"""Test that creating a task with an invalid function raises an error"""
7179
client = ClientCreator.from_system(basic_http_client)
@@ -77,10 +85,9 @@ def test_task_with_invalid_function(basic_http_client: System) -> None:
7785
# Attempting to attach a non-existent function should raise ChromaError
7886
with pytest.raises(ChromaError, match="function not found"):
7987
collection.attach_function(
88+
function=InvalidFunction(),
8089
name="invalid_task",
81-
function_id="nonexistent_function",
8290
output_collection="output_collection",
83-
params=None,
8491
)
8592

8693

@@ -94,10 +101,9 @@ def test_attach_function_returns_function_name(basic_http_client: System) -> Non
94101

95102
# Attach a function and verify function_name field in response
96103
attached_fn = collection.attach_function(
104+
function=RecordCounterFunction(),
97105
name="my_counter",
98-
function_id="record_counter",
99106
output_collection="output_collection",
100-
params=None,
101107
)
102108

103109
# Verify the attached function has function_name (not function_id UUID)
@@ -122,10 +128,9 @@ def test_function_multiple_collections(basic_http_client: System) -> None:
122128
collection1.add(ids=["id1", "id2"], documents=["doc1", "doc2"])
123129

124130
attached_fn1 = collection1.attach_function(
131+
function=RecordCounterFunction(),
125132
name="task_1",
126-
function_id="record_counter",
127133
output_collection="output_1",
128-
params=None,
129134
)
130135

131136
assert attached_fn1 is not None
@@ -135,10 +140,9 @@ def test_function_multiple_collections(basic_http_client: System) -> None:
135140
collection2.add(ids=["id3", "id4"], documents=["doc3", "doc4"])
136141

137142
attached_fn2 = collection2.attach_function(
143+
function=RecordCounterFunction(),
138144
name="task_2",
139-
function_id="record_counter",
140145
output_collection="output_2",
141-
params=None,
142146
)
143147

144148
assert attached_fn2 is not None
@@ -170,10 +174,9 @@ def test_functions_one_attached_function_per_collection(
170174

171175
# Create first task on the collection
172176
attached_fn1 = collection.attach_function(
177+
function=RecordCounterFunction(),
173178
name="task_1",
174-
function_id="record_counter",
175179
output_collection="output_1",
176-
params=None,
177180
)
178181

179182
assert attached_fn1 is not None
@@ -182,19 +185,17 @@ def test_functions_one_attached_function_per_collection(
182185
# (only one attached function allowed per collection)
183186
with pytest.raises(ChromaError, match="already has an attached function"):
184187
collection.attach_function(
188+
function=RecordCounterFunction(),
185189
name="task_2",
186-
function_id="record_counter",
187190
output_collection="output_2",
188-
params=None,
189191
)
190192

191193
# Attempting to attach a function with the same name but a different output collection should also fail
192194
with pytest.raises(ChromaError, match="already exists"):
193195
collection.attach_function(
196+
function=RecordCounterFunction(),
194197
name="task_1",
195-
function_id="record_counter",
196198
output_collection="output_different", # Different output collection
197-
params=None,
198199
)
199200

200201
# Detach the first function
@@ -205,10 +206,9 @@ def test_functions_one_attached_function_per_collection(
205206

206207
# Now we should be able to attach a new function
207208
attached_fn2 = collection.attach_function(
209+
function=RecordCounterFunction(),
208210
name="task_2",
209-
function_id="record_counter",
210211
output_collection="output_2",
211-
params=None,
212212
)
213213

214214
assert attached_fn2 is not None
@@ -229,10 +229,9 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None:
229229
collection = client.create_collection(name="test_collection")
230230
collection.add(ids=["id1"], documents=["test"])
231231
attached_fn = collection.attach_function(
232+
function=RecordCounterFunction(),
232233
name="test_function",
233-
function_id="record_counter",
234234
output_collection="output_collection",
235-
params=None,
236235
)
237236

238237
collection.detach_function(attached_fn.name, delete_output_collection=True)

chromadb/utils/statistics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from collections import defaultdict
3131

3232
from chromadb.api.types import Where
33+
from chromadb.api.functions import StatisticsFunction
3334

3435
if TYPE_CHECKING:
3536
from chromadb.api.models.Collection import Collection
@@ -75,10 +76,9 @@ def attach_statistics_function(
7576
stats_collection_name = f"{collection.name}_statistics"
7677

7778
return collection.attach_function(
79+
function=StatisticsFunction(),
7880
name=get_statistics_fn_name(collection),
79-
function_id="statistics",
8081
output_collection=stats_collection_name,
81-
params=None,
8282
)
8383

8484

examples/task_api_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import chromadb
10+
from chromadb import RecordCounterFunction
1011
import time
1112

1213
# Connect to Chroma server
@@ -37,10 +38,9 @@
3738
# Attach a function that counts records in the collection
3839
# The 'record_counter' function processes each record and outputs {"count": N}
3940
attached_fn = collection.attach_function(
40-
function_id="record_counter", # Built-in function that counts records
41+
function=RecordCounterFunction(),
4142
name="count_my_docs",
42-
output_collection="my_documents_counts", # Auto-created
43-
params=None, # No additional parameters needed
43+
output_collection="my_documents_counts",
4444
)
4545

4646
print("✅ Function attached successfully!")

0 commit comments

Comments
 (0)