diff --git a/chromadb/test/distributed/test_task_api.py b/chromadb/test/distributed/test_task_api.py index 663e21753f5..7ad86e3766a 100644 --- a/chromadb/test/distributed/test_task_api.py +++ b/chromadb/test/distributed/test_task_api.py @@ -227,3 +227,71 @@ def test_function_remove_nonexistent(basic_http_client: System) -> None: # Trying to detach this function again should raise NotFoundError with pytest.raises(NotFoundError, match="does not exist"): attached_fn.detach(delete_output_collection=True) + + +def test_count_function_attach_and_detach_attach_attach(basic_http_client: System) -> None: + """Test creating and removing a function with the record_counter operator""" + client = ClientCreator.from_system(basic_http_client) + client.reset() + + # Create a collection + collection = client.get_or_create_collection( + name="my_document", + metadata={"description": "Sample documents for task processing"}, + ) + + # Create a task that counts records in the collection + attached_fn = collection.attach_function( + name="count_my_docs", + function_id="record_counter", # Built-in operator that counts records + output_collection="my_documents_counts", + params=None, + ) + + # Verify task creation succeeded + assert attached_fn is not None + initial_version = get_collection_version(client, collection.name) + + # Add documents + collection.add( + ids=["doc_{}".format(i) for i in range(0, 300)], + documents=["test document"] * 300, + ) + + # Verify documents were added + assert collection.count() == 300 + + wait_for_version_increase(client, collection.name, initial_version) + # Give some time to invalidate the frontend query cache + sleep(60) + + result = client.get_collection("my_documents_counts").get("function_output") + assert result["metadatas"] is not None + assert result["metadatas"][0]["total_count"] == 300 + + # Remove the task + success = attached_fn.detach( + delete_output_collection=True, + ) + + # Verify task removal succeeded + assert success is True + + # Create a task that counts records in the collection + attached_fn = collection.attach_function( + name="count_my_docs", + function_id="record_counter", # Built-in operator that counts records + output_collection="my_documents_counts", + params=None, + ) + assert attached_fn is not None + + # Create a task that counts records in the collection + attached_fn = collection.attach_function( + name="count_my_docs", + function_id="record_counter", # Built-in operator that counts records + output_collection="my_documents_counts", + params=None, + ) + assert attached_fn is not None + diff --git a/go/pkg/sysdb/coordinator/create_task_test.go b/go/pkg/sysdb/coordinator/create_task_test.go index d3698576c5d..380b857862d 100644 --- a/go/pkg/sysdb/coordinator/create_task_test.go +++ b/go/pkg/sysdb/coordinator/create_task_test.go @@ -69,7 +69,7 @@ func (suite *AttachFunctionTestSuite) setupAttachFunctionMocks(ctx context.Conte // Phase 1: Create attached function in transaction // Check if any attached function exists for this collection suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID). + suite.mockAttachedFunctionDb.On("GetReadyOrNotReadyByCollectionID", inputCollectionID). Return([]*dbmodel.AttachedFunction{}, nil).Once() suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() @@ -167,7 +167,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_SuccessfulCreation() { // Setup mocks that will be called within the transaction (using mock.Anything for context) // Check if any attached function exists for this collection suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID). + suite.mockAttachedFunctionDb.On("GetReadyOrNotReadyByCollectionID", inputCollectionID). Return([]*dbmodel.AttachedFunction{}, nil).Once() // Look up database @@ -285,7 +285,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Alrea // Inside transaction: check for existing attached functions suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID). + suite.mockAttachedFunctionDb.On("GetReadyOrNotReadyByCollectionID", inputCollectionID). Return([]*dbmodel.AttachedFunction{existingAttachedFunction}, nil).Once() // Validate function by ID @@ -364,7 +364,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() { // Phase 1: Create attached function in transaction suite.mockMetaDomain.On("AttachedFunctionDb", mock.Anything).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID). + suite.mockAttachedFunctionDb.On("GetReadyOrNotReadyByCollectionID", inputCollectionID). Return([]*dbmodel.AttachedFunction{}, nil).Once() suite.mockMetaDomain.On("DatabaseDb", mock.Anything).Return(suite.mockDatabaseDb).Once() @@ -420,7 +420,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_RecoveryFlow() { // Inside transaction: check for existing attached functions suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID). + suite.mockAttachedFunctionDb.On("GetReadyOrNotReadyByCollectionID", inputCollectionID). Return([]*dbmodel.AttachedFunction{incompleteAttachedFunction}, nil).Once() // Validate function matches @@ -511,7 +511,7 @@ func (suite *AttachFunctionTestSuite) TestAttachFunction_IdempotentRequest_Param // Inside transaction: check for existing attached functions suite.mockMetaDomain.On("AttachedFunctionDb", txCtx).Return(suite.mockAttachedFunctionDb).Once() - suite.mockAttachedFunctionDb.On("GetByCollectionID", inputCollectionID). + suite.mockAttachedFunctionDb.On("GetReadyOrNotReadyByCollectionID", inputCollectionID). Return([]*dbmodel.AttachedFunction{existingAttachedFunction}, nil).Once() // Validate function - returns DIFFERENT function name diff --git a/go/pkg/sysdb/coordinator/task.go b/go/pkg/sysdb/coordinator/task.go index d6b387b52d7..737e0c594b1 100644 --- a/go/pkg/sysdb/coordinator/task.go +++ b/go/pkg/sysdb/coordinator/task.go @@ -92,7 +92,7 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att err := s.catalog.txImpl.Transaction(ctx, func(txCtx context.Context) error { // Check if there's any active (ready, non-deleted) attached function for this collection // We only allow one active attached function per collection - existingAttachedFunctions, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetByCollectionID(req.InputCollectionId) + existingAttachedFunctions, err := s.catalog.metaDomain.AttachedFunctionDb(txCtx).GetReadyOrNotReadyByCollectionID(req.InputCollectionId) if err != nil { log.Error("AttachFunction: failed to check for existing attached function", zap.Error(err)) return err @@ -194,7 +194,9 @@ func (s *Coordinator) AttachFunction(ctx context.Context, req *coordinatorpb.Att } err = s.catalog.metaDomain.AttachedFunctionDb(txCtx).Insert(attachedFunction) - if err != nil { + if err == common.ErrAttachedFunctionAlreadyExists { + // idempotent fall through + } else if err != nil { log.Error("AttachFunction: failed to insert attached function", zap.Error(err)) return err } diff --git a/go/pkg/sysdb/metastore/db/dao/task.go b/go/pkg/sysdb/metastore/db/dao/task.go index bb80ab7b947..d4d18b58942 100644 --- a/go/pkg/sysdb/metastore/db/dao/task.go +++ b/go/pkg/sysdb/metastore/db/dao/task.go @@ -153,6 +153,23 @@ func (s *attachedFunctionDb) GetByCollectionID(inputCollectionID string) ([]*dbm return attachedFunctions, nil } +// Returns the non-deleted functions, without regard for `is_ready`. Deleted functions will still +// be excluded. +func (s *attachedFunctionDb) GetReadyOrNotReadyByCollectionID(inputCollectionID string) ([]*dbmodel.AttachedFunction, error) { + var attachedFunctions []*dbmodel.AttachedFunction + err := s.db. + Where("input_collection_id = ?", inputCollectionID). + Where("is_deleted = ?", false). + Find(&attachedFunctions).Error + + if err != nil { + log.Error("GetReadyOrNotReadyByCollectionID failed", zap.Error(err), zap.String("input_collection_id", inputCollectionID)) + return nil, err + } + + return attachedFunctions, nil +} + func (s *attachedFunctionDb) SoftDelete(inputCollectionID string, name string) error { // Update name and is_deleted in a single query // Format: _deleted__ diff --git a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IAttachedFunctionDb.go b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IAttachedFunctionDb.go index a43da87d507..84ccd0be373 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/mocks/IAttachedFunctionDb.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/mocks/IAttachedFunctionDb.go @@ -142,6 +142,36 @@ func (_m *IAttachedFunctionDb) GetByCollectionID(inputCollectionID string) ([]*d return r0, r1 } +// GetReadyOrNotReadyByCollectionID provides a mock function with given fields: inputCollectionID +func (_m *IAttachedFunctionDb) GetReadyOrNotReadyByCollectionID(inputCollectionID string) ([]*dbmodel.AttachedFunction, error) { + ret := _m.Called(inputCollectionID) + + if len(ret) == 0 { + panic("no return value specified for GetReadyOrNotReadyByCollectionID") + } + + var r0 []*dbmodel.AttachedFunction + var r1 error + if rf, ok := ret.Get(0).(func(string) ([]*dbmodel.AttachedFunction, error)); ok { + return rf(inputCollectionID) + } + if rf, ok := ret.Get(0).(func(string) []*dbmodel.AttachedFunction); ok { + r0 = rf(inputCollectionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*dbmodel.AttachedFunction) + } + } + + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(inputCollectionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // GetByName provides a mock function with given fields: inputCollectionID, attachedFunctionName func (_m *IAttachedFunctionDb) GetByName(inputCollectionID string, attachedFunctionName string) (*dbmodel.AttachedFunction, error) { ret := _m.Called(inputCollectionID, attachedFunctionName) diff --git a/go/pkg/sysdb/metastore/db/dbmodel/task.go b/go/pkg/sysdb/metastore/db/dbmodel/task.go index 1a10d2671aa..f95e0f08970 100644 --- a/go/pkg/sysdb/metastore/db/dbmodel/task.go +++ b/go/pkg/sysdb/metastore/db/dbmodel/task.go @@ -42,6 +42,7 @@ type IAttachedFunctionDb interface { GetByID(id uuid.UUID) (*AttachedFunction, error) GetAnyByID(id uuid.UUID) (*AttachedFunction, error) // TODO(tanujnay112): Consolidate all the getters. GetByCollectionID(inputCollectionID string) ([]*AttachedFunction, error) + GetReadyOrNotReadyByCollectionID(inputCollectionID string) ([]*AttachedFunction, error) Update(attachedFunction *AttachedFunction) error Finish(id uuid.UUID) error SoftDelete(inputCollectionID string, name string) error