diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 96d726f9..2e4a8b8c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,6 +12,7 @@ name: "CI" env: UV_FROZEN: "1" + TEST__LOCAL_DB: "1" # signals test fixtures to not use testcontainers jobs: rebase-checker: diff --git a/pyproject.toml b/pyproject.toml index 53d38a6d..dd8d7ac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -169,7 +169,7 @@ max-doc-length = 79 convention = "numpy" [tool.pytest.ini_options] -asyncio_mode = "strict" +asyncio_mode = "auto" asyncio_default_fixture_loop_scope="function" # The python_files setting is not for test detection (pytest will pick up any # test files named *_test.py without this setting) but to enable special diff --git a/src/lsst/cmservice/common/jsonpatch.py b/src/lsst/cmservice/common/jsonpatch.py new file mode 100644 index 00000000..b57ce69a --- /dev/null +++ b/src/lsst/cmservice/common/jsonpatch.py @@ -0,0 +1,224 @@ +"""Module implementing functions to support json-patch operations on Python +objects based on RFC6902. +""" + +import operator +from collections.abc import MutableMapping, MutableSequence +from functools import reduce +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import AliasChoices, BaseModel, Field + +type AnyMutable = MutableMapping | MutableSequence + + +class JSONPatchError(Exception): + """Exception raised when a JSON patch operation cannot be completed.""" + + pass + + +class JSONPatch(BaseModel): + """Model representing a PATCH operation using RFC6902. + + This model will generally be accepted as a ``Sequence[JSONPatch]``. 
+ """ + + op: Literal["add", "remove", "replace", "move", "copy", "test"] + path: str = Field( + description="An RFC6901 JSON Pointer", pattern=r"^\/(metadata|spec|configuration|metadata_|data)\/.*$" + ) + value: Any | None = None + from_: str | None = Field( + default=None, + pattern=r"^\/(metadata|spec|configuration|metadata_|data)\/.*$", + validation_alias=AliasChoices("from", "from_"), + ) + + +def apply_json_patch[T: MutableMapping](op: JSONPatch, o: T) -> T: + """Applies a jsonpatch to an object, returning the modified object. + + Modifications are made in-place (i.e., the input object is not copied). + + Notes + ----- + While this JSON Patch operation nominally implements RFC6902, there are + some edge cases inappropriate to the application that are supported by the + RFC but disallowed through lack of support: + + - Unsupported: JSON pointer values that refer to object/dict keys that are + numeric, e.g., {"1": "first", "2": "second"} + - Unsupported: JSON pointer values that refer to an entire object, e.g., + "" -- the JSON Patch must have a root element ("/") per the model. + - Unsupported: JSON pointer values taht refer to a nameless object, e.g., + "/" -- JSON allows object keys to be the empty string ("") but this is + disallowed by the application. 
+ """ + # The JSON Pointer root value is discarded as the rest of the pointer is + # split into parts + op_path = op.path.split("/")[1:] + + # The terminal path part is either the name of a key or an index in a list + # FIXME this assumes that an "integer-string" in the path is always refers + # to a list index, although it could just as well be a key in a dict + # like ``{"1": "first, "2": "second"}`` which is complicated by the + # fact that Python dict keys can be either ints or strs but this is + # not allowed in JSON (i.e., object keys MUST be strings) + # FIXME this doesn't support, e.g., nested lists with multiple index values + # in the path, e.g., ``[["a", "A"], ["b", "B"]]`` + target_key_or_index: str | None = op_path.pop() + if target_key_or_index is None: + raise JSONPatchError("JSON Patch operations on empty keys not allowed.") + + reference_token: int | str + # the reference token is referring to a an array index if the token is + # numeric or is the single character "-" + if target_key_or_index == "-": + reference_token = target_key_or_index + elif target_key_or_index.isnumeric(): + reference_token = int(target_key_or_index) + else: + reference_token = str(target_key_or_index) + + # The remaining parts of the path are a pointer to the object needing + # modification, which should reduce to either a dict or a list + try: + op_target: AnyMutable = reduce(operator.getitem, op_path, o) + except KeyError: + raise JSONPatchError(f"Path {op.path} not found in object") + + match op: + case JSONPatch(op="add", value=new_value): + if reference_token == "-" and isinstance(op_target, MutableSequence): + # The "-" reference token is unique to the add operation and + # means the next element beyond the end of the current list + op_target.append(new_value) + elif isinstance(reference_token, int) and isinstance(op_target, MutableSequence): + op_target.insert(reference_token, new_value) + elif isinstance(reference_token, str) and isinstance(op_target, 
MutableMapping): + op_target[reference_token] = new_value + + case JSONPatch(op="replace", value=new_value): + # The main difference between replace and add is that replace will + # not create new properties or elements in the target + if reference_token == "-": + raise JSONPatchError("Cannot use reference token `-` with replace operation.") + elif isinstance(op_target, MutableMapping): + try: + assert reference_token in op_target.keys() + except AssertionError: + raise JSONPatchError(f"Cannot replace missing key {reference_token} in object") + elif isinstance(reference_token, int) and isinstance(op_target, MutableSequence): + try: + assert reference_token < len(op_target) + except AssertionError: + raise JSONPatchError(f"Cannot replace missing index {reference_token} in object") + + if TYPE_CHECKING: + assert isinstance(op_target, MutableMapping) + op_target[reference_token] = new_value + + case JSONPatch(op="remove"): + if isinstance(reference_token, str) and isinstance(op_target, MutableMapping): + if reference_token == "-": + raise JSONPatchError("Removal operations not allowed on `-` reference token") + _ = op_target.pop(reference_token, None) + elif isinstance(reference_token, int): + try: + _ = op_target.pop(reference_token) + except IndexError: + # The index we are meant to remove does not exist, but that + # is not an error (idempotence) + pass + else: + # This should be unreachable + raise ValueError("Reference token in JSON Patch must be int | str") + + case JSONPatch(op="move", from_=from_location): + # the move operation is equivalent to a remove(from) + add(target) + if TYPE_CHECKING: + assert from_location is not None + + # Handle the from_location with the same logic as the op.path + from_path = from_location.split("/")[1:] + + # Is the last element of the from_path an index or a key? 
+ from_target: str | int = from_path.pop() + try: + from_target = int(from_target) + except ValueError: + pass + + try: + from_object = reduce(operator.getitem, from_path, o) + value = from_object[from_target] + except (KeyError, IndexError): + raise JSONPatchError(f"Path {from_location} not found in object") + + # add the value to the new location + op_target[reference_token] = value # type: ignore[index] + # and remove it from the old + _ = from_object.pop(from_target) + + case JSONPatch(op="copy", from_=from_location): + # The copy op is the same as the move op except the original is not + # removed + if TYPE_CHECKING: + assert from_location is not None + + # Handle the from_location with the same logic as the op.path + from_path = from_location.split("/")[1:] + + # Is the last element of the from_path an index or a key? + from_target = from_path.pop() + try: + from_target = int(from_target) + except ValueError: + pass + + try: + from_object = reduce(operator.getitem, from_path, o) + value = from_object[from_target] + except (KeyError, IndexError): + raise JSONPatchError(f"Path {from_location} not found in object") + + # add the value to the new location + op_target[reference_token] = value # type: ignore[index] + + case JSONPatch(op="test", value=assert_value): + # assert that the patch value is present at the patch path + # The main difference between test and replace is that test does + # not make any modifications after its assertions + if reference_token == "-": + raise JSONPatchError("Cannot use reference token `-` with test operation.") + elif isinstance(op_target, MutableMapping): + try: + assert reference_token in op_target.keys() + except AssertionError: + raise JSONPatchError( + f"Test operation assertion failed: Key {reference_token} does not exist at {op.path}" + ) + elif isinstance(reference_token, int) and isinstance(op_target, MutableSequence): + try: + assert reference_token < len(op_target) + except AssertionError: + raise JSONPatchError( + 
f"Test operation assertion failed: " + f"Index {reference_token} does not exist at {op.path}" + ) + + if TYPE_CHECKING: + assert isinstance(op_target, MutableMapping) + try: + assert op_target[reference_token] == assert_value + except AssertionError: + raise JSONPatchError( + f"Test operation assertion failed: {op.path} does not match value {assert_value}" + ) + + case _: + # Model validation should prevent this from ever happening + raise JSONPatchError(f"Unknown JSON Patch operation: {op.op}") + + return o diff --git a/src/lsst/cmservice/common/types.py b/src/lsst/cmservice/common/types.py index 169a52d4..840d3bfa 100644 --- a/src/lsst/cmservice/common/types.py +++ b/src/lsst/cmservice/common/types.py @@ -1,8 +1,12 @@ +from typing import Annotated + from sqlalchemy.ext.asyncio import AsyncSession as AsyncSessionSA from sqlalchemy.ext.asyncio import async_scoped_session from sqlmodel.ext.asyncio.session import AsyncSession from .. import models +from ..models.serde import EnumSerializer, ManifestKindEnumValidator, StatusEnumValidator +from .enums import ManifestKind, StatusEnum type AnyAsyncSession = AsyncSession | AsyncSessionSA | async_scoped_session """A type union of async database sessions the application may use""" @@ -10,3 +14,15 @@ type AnyCampaignElement = models.Group | models.Campaign | models.Step | models.Job """A type union of Campaign elements""" + + +type StatusField = Annotated[StatusEnum, StatusEnumValidator, EnumSerializer] +"""A type for fields representing a Status with a custom validator tuned for +enums operations. +""" + + +type KindField = Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer] +"""A type for fields representing a Kind with a custom validator tuned for +enums operations. 
+""" diff --git a/src/lsst/cmservice/config.py b/src/lsst/cmservice/config.py index bfb045c1..16487d65 100644 --- a/src/lsst/cmservice/config.py +++ b/src/lsst/cmservice/config.py @@ -8,6 +8,7 @@ AliasChoices, BaseModel, Field, + SecretStr, computed_field, field_serializer, field_validator, @@ -456,6 +457,11 @@ class AsgiConfiguration(BaseModel): default="/cm-service", ) + enable_frontend: bool = Field( + description="Whether to run the frontend web app", + default=True, + ) + frontend_prefix: str = Field( description="The URL prefix for the frontend web app", default="/web_app", @@ -535,13 +541,13 @@ class DatabaseConfiguration(BaseModel): description="The URL for the cm-service database", ) - password: str | None = Field( + password: SecretStr | None = Field( default=None, description="The password for the cm-service database", ) - table_schema: str | None = Field( - default=None, + table_schema: str = Field( + default="public", description="Schema to use for cm-service database", ) @@ -550,6 +556,31 @@ class DatabaseConfiguration(BaseModel): description="SQLAlchemy engine echo setting for the cm-service database", ) + max_overflow: int = Field( + default=10, + description="Maximum connection overflow allowed for QueuePool.", + ) + + pool_size: int = Field( + default=5, + description="Number of open connections kept in the QueuePool", + ) + + pool_recycle: int = Field( + default=-1, + description="Timeout in seconds before connections are recycled", + ) + + pool_timeout: int = Field( + default=30, + description="Wait timeout for acquiring a connection from the pool", + ) + + pool_fields: set[str] = Field( + default={"max_overflow", "pool_size", "pool_recycle", "pool_timeout"}, + description="Set of fields used for connection pool configuration", + ) + class Configuration(BaseSettings): """Configuration for cm-service. 
diff --git a/src/lsst/cmservice/db/__init__.py b/src/lsst/cmservice/db/__init__.py index 11f61d4f..eeec4db4 100644 --- a/src/lsst/cmservice/db/__init__.py +++ b/src/lsst/cmservice/db/__init__.py @@ -1,5 +1,6 @@ """Database table definitions and utility functions""" +from . import campaigns_v2 from .base import Base from .campaign import Campaign from .element import ElementMixin diff --git a/src/lsst/cmservice/db/campaigns_v2.py b/src/lsst/cmservice/db/campaigns_v2.py index aa7802fc..655fa911 100644 --- a/src/lsst/cmservice/db/campaigns_v2.py +++ b/src/lsst/cmservice/db/campaigns_v2.py @@ -1,19 +1,23 @@ from datetime import datetime -from typing import Annotated, Any +from typing import Any from uuid import NAMESPACE_DNS, UUID, uuid5 -from pydantic import AliasChoices, PlainSerializer, PlainValidator, ValidationInfo, model_validator +from pydantic import AliasChoices, ValidationInfo, model_validator from sqlalchemy.dialects import postgresql from sqlalchemy.ext.mutable import MutableDict, MutableList from sqlalchemy.types import PickleType -from sqlmodel import Column, Enum, Field, SQLModel, String +from sqlmodel import Column, Enum, Field, MetaData, SQLModel, String from ..common.enums import ManifestKind, StatusEnum +from ..common.types import KindField, StatusField from ..config import config _default_campaign_namespace = uuid5(namespace=NAMESPACE_DNS, name="io.lsst.cmservice") """Default UUID5 namespace for campaigns""" +metadata: MetaData = MetaData(schema=config.db.table_schema) +"""SQLModel metadata for table models""" + def jsonb_column(name: str, aliases: list[str] | None = None) -> Any: """Constructor for a Field based on a JSONB database column. @@ -47,30 +51,10 @@ def jsonb_column(name: str, aliases: list[str] | None = None) -> Any: # 3. the model of the manifest when updating an object # 4. 
a response model for APIs related to the object -EnumSerializer = PlainSerializer( - lambda x: x.name, - return_type="str", - when_used="always", -) -"""A serializer for enums that produces its name, not the value.""" - - -StatusEnumValidator = PlainValidator(lambda x: StatusEnum[x] if isinstance(x, str) else StatusEnum(x)) -"""A validator for the StatusEnum that can parse the enum from either a name -or a value. -""" - - -ManifestKindEnumValidator = PlainValidator( - lambda x: ManifestKind[x] if isinstance(x, str) else ManifestKind(x) -) -"""A validator for the ManifestKindEnum that can parse the enum from a name -or a value. -""" - class BaseSQLModel(SQLModel): __table_args__ = {"schema": config.db.table_schema} + metadata = metadata class CampaignBase(BaseSQLModel): @@ -80,7 +64,7 @@ class CampaignBase(BaseSQLModel): name: str namespace: UUID owner: str | None = Field(default=None) - status: Annotated[StatusEnum, StatusEnumValidator, EnumSerializer] | None = Field( + status: StatusField | None = Field( default=StatusEnum.waiting, sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), ) @@ -112,7 +96,16 @@ class Campaign(CampaignModel, table=True): __tablename__: str = "campaigns_v2" # type: ignore[misc] - machine: UUID | None + machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") + + +class CampaignUpdate(BaseSQLModel): + """Model representing updatable fields for a PATCH operation on a Campaign + using RFC7396. 
+ """ + + owner: str | None = None + status: StatusField | None = None class NodeBase(BaseSQLModel): @@ -122,11 +115,11 @@ class NodeBase(BaseSQLModel): name: str namespace: UUID version: int - kind: Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer] = Field( + kind: KindField = Field( default=ManifestKind.other, sa_column=Column("kind", Enum(ManifestKind, length=20, native_enum=False, create_constraint=False)), ) - status: Annotated[StatusEnum, StatusEnumValidator, EnumSerializer] | None = Field( + status: StatusField | None = Field( default=StatusEnum.waiting, sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), ) @@ -155,7 +148,7 @@ def custom_model_validator(cls, data: Any, info: ValidationInfo) -> Any: class Node(NodeModel, table=True): __tablename__: str = "nodes_v2" # type: ignore[misc] - machine: UUID | None + machine: UUID | None = Field(foreign_key="machines_v2.id", default=None, ondelete="CASCADE") class EdgeBase(BaseSQLModel): @@ -202,14 +195,20 @@ class MachineBase(BaseSQLModel): state: Any | None = Field(sa_column=Column("state", PickleType)) +class Machine(MachineBase, table=True): + """machines_v2 db table.""" + + __tablename__: str = "machines_v2" # type: ignore[misc] + + class ManifestBase(BaseSQLModel): """manifests_v2 db table""" id: UUID = Field(primary_key=True) name: str version: int - namespace: UUID - kind: Annotated[ManifestKind, EnumSerializer] = Field( + namespace: UUID = Field(foreign_key="campaigns_v2.id") + kind: KindField = Field( default=ManifestKind.other, sa_column=Column("kind", Enum(ManifestKind, length=20, native_enum=False, create_constraint=False)), ) @@ -239,14 +238,14 @@ class Manifest(ManifestBase, table=True): __tablename__: str = "manifests_v2" # type: ignore[misc] -class Task(SQLModel, table=True): +class Task(BaseSQLModel, table=True): """tasks_v2 db table""" __tablename__: str = "tasks_v2" # type: ignore[misc] id: UUID = Field(primary_key=True) - 
namespace: UUID - node: UUID + namespace: UUID = Field(foreign_key="campaigns_v2.id") + node: UUID = Field(foreign_key="nodes_v2.id") priority: int created_at: datetime last_processed_at: datetime @@ -255,10 +254,10 @@ class Task(SQLModel, table=True): site_affinity: list[str] = Field( sa_column=Column("site_affinity", MutableList.as_mutable(postgresql.ARRAY(String()))) ) - status: Annotated[StatusEnum, StatusEnumValidator, EnumSerializer] = Field( + status: StatusField = Field( sa_column=Column("status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False)), ) - previous_status: Annotated[StatusEnum, StatusEnumValidator, EnumSerializer] = Field( + previous_status: StatusField = Field( sa_column=Column( "previous_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) ), @@ -267,11 +266,19 @@ class Task(SQLModel, table=True): class ActivityLogBase(BaseSQLModel): id: UUID = Field(primary_key=True) - namespace: UUID - node: UUID + namespace: UUID = Field(foreign_key="campaigns_v2.id") + node: UUID = Field(foreign_key="nodes_v2.id") operator: str - from_status: Annotated[StatusEnum, EnumSerializer] - to_status: Annotated[StatusEnum, EnumSerializer] + to_status: StatusField = Field( + sa_column=Column( + "to_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) + ), + ) + from_status: StatusField = Field( + sa_column=Column( + "from_status", Enum(StatusEnum, length=20, native_enum=False, create_constraint=False) + ), + ) detail: dict = jsonb_column("detail") diff --git a/src/lsst/cmservice/db/manifests_v2.py b/src/lsst/cmservice/db/manifests_v2.py index 90cef6d2..0b080ba4 100644 --- a/src/lsst/cmservice/db/manifests_v2.py +++ b/src/lsst/cmservice/db/manifests_v2.py @@ -1,34 +1,101 @@ -from typing import Annotated +"""Module for models representing generic CM Service manifests. 
-from pydantic import AliasChoices -from sqlmodel import Field, SQLModel +These manifests are used in APIs, especially when creating resources. They do +not necessarily represent the object's database or ORM model. +""" -from ..common.enums import ManifestKind -from .campaigns_v2 import EnumSerializer, ManifestKindEnumValidator +from typing import Self +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, ValidationInfo, model_validator -# this can probably be a BaseModel since this is not a db relation, but the -# distinction probably doesn't matter -class ManifestWrapper(SQLModel): - """a model for an object's Manifest wrapper, used by APIs where the `spec` - should be the kind's table model, more or less. +from ..common.enums import DEFAULT_NAMESPACE, ManifestKind +from ..common.types import KindField + + +class Manifest[MetadataT, SpecT](BaseModel): + """A parameterized model for an object's Manifest, used by APIs where the + `spec` should be the kind's table model, more or less. """ apiversion: str = Field(default="io.lsst.cmservice/v1") - kind: Annotated[ManifestKind, ManifestKindEnumValidator, EnumSerializer] = Field( - default=ManifestKind.other, - ) - metadata_: dict = Field( - default_factory=dict, - schema_extra={ - "validation_alias": AliasChoices("metadata", "metadata_"), - "serialization_alias": "metadata", - }, + kind: KindField = Field(default=ManifestKind.other) + metadata_: MetadataT = Field( + validation_alias=AliasChoices("metadata", "metadata_"), + serialization_alias="metadata", ) - spec: dict = Field( - default_factory=dict, - schema_extra={ - "validation_alias": AliasChoices("spec", "configuration", "data"), - "serialization_alias": "spec", - }, + spec: SpecT = Field( + validation_alias=AliasChoices("spec", "configuration", "data"), + serialization_alias="spec", ) + + +class ManifestMetadata(BaseModel): + """Generic metadata model for Manifests. 
+
+    Conventionally denormalized fields are excluded from the model_dump when
+    serialized for ORM use.
+    """
+
+    name: str
+    namespace: str
+
+
+class ManifestSpec(BaseModel):
+    """Generic spec model for Manifests.
+
+    Notes
+    -----
+    Any spec body is allowed via config, but any fields that aren't first-class
+    fields won't be subject to validation or available as model attributes
+    except in the ``__pydantic_extra__`` dictionary. The full spec will be
+    expressed via ``model_dump()``.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+
+class VersionedMetadata(ManifestMetadata):
+    """Metadata model for versioned Manifests."""
+
+    version: int = 0
+
+
+class ManifestModelMetadata(VersionedMetadata):
+    """Manifest model for general Manifests. These manifests are versioned but
+    a namespace is optional.
+    """
+
+    namespace: str = Field(default=str(DEFAULT_NAMESPACE))
+
+
+class ManifestModel(Manifest[ManifestModelMetadata, ManifestSpec]):
+    """Manifest model for generic Manifest handling."""
+
+    @model_validator(mode="after")
+    def custom_model_validator(self, info: ValidationInfo) -> Self:
+        """Validate a generic Manifest after a model has been created."""
+        if self.kind in [ManifestKind.campaign, ManifestKind.node, ManifestKind.edge]:
+            raise ValueError(f"Manifests may not be a {self.kind.name} kind.")
+
+        return self
+
+
+class CampaignMetadata(BaseModel):
+    """Metadata model for a Campaign Manifest.
+
+    Campaign metadata does not require a namespace field. 
+ """ + + name: str + + +class CampaignManifest(Manifest[CampaignMetadata, ManifestSpec]): + """validating model for campaigns""" + + @model_validator(mode="after") + def custom_model_validator(self, info: ValidationInfo) -> Self: + """Validate an Campaign Manifest after a model has been created.""" + if self.kind is not ManifestKind.campaign: + raise ValueError("Campaigns may only be created from a manifest") + + return self diff --git a/src/lsst/cmservice/db/session.py b/src/lsst/cmservice/db/session.py index 3824fb5b..9a1f12c0 100644 --- a/src/lsst/cmservice/db/session.py +++ b/src/lsst/cmservice/db/session.py @@ -2,9 +2,9 @@ from collections.abc import AsyncGenerator -# from pydantic import SecretStr #noqa: ERA001 from sqlalchemy import URL, make_url from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine +from sqlalchemy.pool import AsyncAdaptedQueuePool, Pool from sqlmodel.ext.asyncio.session import AsyncSession from ..config import config @@ -21,6 +21,7 @@ class DatabaseSessionDependency: engine: AsyncEngine | None sessionmaker: async_sessionmaker[AsyncSession] | None url: URL + pool_class: type[Pool] = AsyncAdaptedQueuePool def __init__(self) -> None: self.engine = None @@ -39,21 +40,23 @@ async def initialize( If true (default), the database drivername will be forced to an async form. 
""" + await self.aclose() if isinstance(config.db.url, str): self.url = make_url(config.db.url) if use_async and self.url.drivername == "postgresql": self.url = self.url.set(drivername="postgresql+asyncpg") - # FIXME use SecretStr for password - # if isinstance(config.db.password, SecretStr): - # password = config.db.password.get_secret_value() #noqa: ERA001 if config.db.password is not None: - self.url = self.url.set(password=config.db.password) - if self.engine: - await self.engine.dispose() + self.url = self.url.set(password=config.db.password.get_secret_value()) + pool_kwargs = ( + config.db.model_dump(include=config.db.pool_fields) + if self.pool_class is AsyncAdaptedQueuePool + else {} + ) self.engine = create_async_engine( url=self.url, echo=config.db.echo, - # TODO add pool-level configs + poolclass=self.pool_class, + **pool_kwargs, ) self.sessionmaker = async_sessionmaker(self.engine, class_=AsyncSession, expire_on_commit=False) @@ -76,7 +79,7 @@ async def aclose(self) -> None: if self.engine: self.sessionmaker = None await self.engine.dispose() - self._engine = None + self.engine = None db_session_dependency = DatabaseSessionDependency() diff --git a/src/lsst/cmservice/main.py b/src/lsst/cmservice/main.py index b1d4f846..db3133e0 100644 --- a/src/lsst/cmservice/main.py +++ b/src/lsst/cmservice/main.py @@ -13,83 +13,15 @@ from .routers import ( healthz, index, + tags_metadata, v1, + v2, ) from .web_app import web_app configure_uvicorn_logging(config.logging.level) configure_logging(profile=config.logging.profile, log_level=config.logging.level, name=config.asgi.title) -tags_metadata = [ - { - "name": "Loaders", - "description": "Operations that load Objects in to the DB.", - }, - { - "name": "Actions", - "description": "Operations perform actions on existing Objects in to the DB." - "In many cases this will result in the creating of new objects in the DB.", - }, - { - "name": "Campaigns", - "description": "Operations with `campaign`s. 
A `campaign` consists of several processing `step`s " - "which are run sequentially. A `campaign` also holds configuration such as a URL for a butler repo " - "and a production area. `campaign`s must be uniquely named withing a given `production`.", - }, - { - "name": "Steps", - "description": "Operations with `step`s. A `step` consists of several processing `group`s which " - "may be run in parallel. `step`s must be uniquely named within a give `campaign`.", - }, - { - "name": "Groups", - "description": "Operations with `groups`. A `group` can be processed in a single `workflow`, " - "but we also need to account for possible failures. `group`s must be uniquely named within a " - "given `step`.", - }, - { - "name": "Scripts", - "description": "Operations with `scripts`. A `script` does a single operation, either something" - "that is done asynchronously, such as making new collections in the Butler, or creating" - "new objects in the DB, such as new `steps` and `groups`.", - }, - { - "name": "Jobs", - "description": "Operations with `jobs`. 
A `job` runs a single `workflow`: keeps a count" - "of the results data products and keeps track of associated errors.", - }, - { - "name": "Pipetask Error Types", - "description": "Operations with `pipetask_error_type` table.", - }, - { - "name": "Pipetask Errors", - "description": "Operations with `pipetask_error` table.", - }, - { - "name": "Product Sets", - "description": "Operations with `product_set` table.", - }, - { - "name": "Task Sets", - "description": "Operations with `task_set` table.", - }, - { - "name": "Script Dependencies", - "description": "Operations with `script_dependency` table.", - }, - { - "name": "Step Dependencies", - "description": "Operations with `step_dependency` table.", - }, - { - "name": "Wms Task Reports", - "description": "Operations with `wms_task_report` table.", - }, - {"name": "Specifications", "description": "Operations with `specification` table."}, - {"name": "SpecBlocks", "description": "Operations with `spec_block` table."}, -] - @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator: @@ -122,9 +54,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: app.include_router(healthz.health_router, prefix="") app.include_router(index.router, prefix="") app.include_router(v1.router, prefix=config.asgi.prefix) +app.include_router(v2.router, prefix=config.asgi.prefix) # Start the frontend web application. -app.mount(config.asgi.frontend_prefix, web_app) +if config.asgi.enable_frontend: + app.mount(config.asgi.frontend_prefix, web_app) if __name__ == "__main__": diff --git a/src/lsst/cmservice/models/serde.py b/src/lsst/cmservice/models/serde.py new file mode 100644 index 00000000..02ac5c46 --- /dev/null +++ b/src/lsst/cmservice/models/serde.py @@ -0,0 +1,44 @@ +"""Module for serialization and deserialization support for pydantic and +other derivative models. 
+""" + +from enum import EnumType +from functools import partial +from typing import Any + +from pydantic import PlainSerializer, PlainValidator + +from ..common.enums import ManifestKind, StatusEnum + + +def EnumValidator[T: EnumType](value: Any, enum_: T) -> T: + """Create an enum from the input value. The input can be either the + enum name or its value. + + Used as a Validator for a pydantic field. + """ + try: + new_enum: T = enum_[value] if value in enum_.__members__ else enum_(value) + except (KeyError, ValueError): + raise ValueError(f"Value must be a member of {enum_.__qualname__}") + return new_enum + + +EnumSerializer = PlainSerializer( + lambda x: x.name, + return_type="str", + when_used="always", +) +"""A serializer for enums that produces its name, not the value.""" + + +StatusEnumValidator = PlainValidator(partial(EnumValidator, enum_=StatusEnum)) +"""A validator for the StatusEnum that can parse the enum from either a name +or a value. +""" + + +ManifestKindEnumValidator = PlainValidator(partial(EnumValidator, enum_=ManifestKind)) +"""A validator for the ManifestKindEnum that can parse the enum from a name +or a value. +""" diff --git a/src/lsst/cmservice/routers/__init__.py b/src/lsst/cmservice/routers/__init__.py index ecda6158..3ae25ad8 100644 --- a/src/lsst/cmservice/routers/__init__.py +++ b/src/lsst/cmservice/routers/__init__.py @@ -40,3 +40,101 @@ wms_task_reports, wrappers, ) + +tags_metadata = [ + { + "name": "actions", + "description": "Operations perform actions on existing Objects in to the DB." + "In many cases this will result in the creating of new objects in the DB.", + }, + { + "name": "campaigns", + "description": "Operations with `campaign`s. A `campaign` consists of several processing `step`s " + "which are run sequentially. A `campaign` also holds configuration such as a URL for a butler repo " + "and a production area. 
`campaign`s must be uniquely named withing a given `production`.", + }, + { + "name": "edges", + "description": "Operations with `edge`s within a `campaign` graph.", + }, + { + "name": "groups", + "description": "Operations with `groups`. A `group` can be processed in a single `workflow`, " + "but we also need to account for possible failures. `group`s must be uniquely named within a " + "given `step`.", + }, + { + "name": "health", + "description": "Operations that check or report on the health of the application", + }, + { + "name": "internal", + "description": "Operations that are used by processes that are internal to the application", + }, + { + "name": "jobs", + "description": "Operations with `jobs`. A `job` runs a single `workflow`: keeps a count" + "of the results data products and keeps track of associated errors.", + }, + { + "name": "loaders", + "description": "Operations that load Objects in to the DB.", + }, + { + "name": "manifests", + "description": "Operations on manifests.", + }, + { + "name": "pipetask error types", + "description": "Operations with `pipetask_error_type` table.", + }, + { + "name": "pipetask errors", + "description": "Operations with `pipetask_error` table.", + }, + { + "name": "product sets", + "description": "Operations with `product_set` table.", + }, + { + "name": "scripts", + "description": "Operations with `scripts`. 
A `script` does a single operation, either something" + "that is done asynchronously, such as making new collections in the Butler, or creating" + "new objects in the DB, such as new `steps` and `groups`.", + }, + { + "name": "script dependencies", + "description": "Operations with `script_dependency` table.", + }, + { + "name": "script errors", + "description": "Operations with `script_errors` table.", + }, + {"name": "spec blocks", "description": "Operations with `spec_block` table."}, + {"name": "specifications", "description": "Operations with `specification` table."}, + { + "name": "steps", + "description": "Operations with `step`s. A `step` consists of several processing `group`s which " + "may be run in parallel. `step`s must be uniquely named within a give `campaign`.", + }, + { + "name": "step dependencies", + "description": "Operations with `step_dependency` table.", + }, + { + "name": "task sets", + "description": "Operations with `task_set` table.", + }, + { + "name": "v1", + "description": "Operations associated with the v1 legacy application", + }, + { + "name": "v2", + "description": "Operations associated with the v2 application", + }, + { + "name": "wms task reports", + "description": "Operations with `wms_task_report` table.", + }, +] diff --git a/src/lsst/cmservice/routers/actions.py b/src/lsst/cmservice/routers/actions.py index a199b323..41342c32 100644 --- a/src/lsst/cmservice/routers/actions.py +++ b/src/lsst/cmservice/routers/actions.py @@ -10,7 +10,7 @@ router = APIRouter( prefix="/actions", - tags=["Actions"], + tags=["actions"], ) diff --git a/src/lsst/cmservice/routers/campaigns.py b/src/lsst/cmservice/routers/campaigns.py index 0956aa49..3df4ee16 100644 --- a/src/lsst/cmservice/routers/campaigns.py +++ b/src/lsst/cmservice/routers/campaigns.py @@ -29,14 +29,12 @@ UpdateModelClass = models.CampaignUpdate # Specify the associated database table DbClass = db.Campaign -# Specify the tag in the router documentation -TAG_STRING = "Campaigns" # 
Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["campaigns"], ) # Attach functions to the router diff --git a/src/lsst/cmservice/routers/groups.py b/src/lsst/cmservice/routers/groups.py index ddada514..212c9cc1 100644 --- a/src/lsst/cmservice/routers/groups.py +++ b/src/lsst/cmservice/routers/groups.py @@ -18,14 +18,12 @@ UpdateModelClass = models.GroupUpdate # Specify the associated database table DbClass = db.Group -# Specify the tag in the router documentation -TAG_STRING = "Groups" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["groups"], ) diff --git a/src/lsst/cmservice/routers/jobs.py b/src/lsst/cmservice/routers/jobs.py index 9e81b09d..08be3723 100644 --- a/src/lsst/cmservice/routers/jobs.py +++ b/src/lsst/cmservice/routers/jobs.py @@ -24,14 +24,12 @@ UpdateModelClass = models.JobUpdate # Specify the associated database table DbClass = db.Job -# Specify the tag in the router documentation -TAG_STRING = "Jobs" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["jobs"], ) diff --git a/src/lsst/cmservice/routers/loaders.py b/src/lsst/cmservice/routers/loaders.py index 225c7880..bb293837 100644 --- a/src/lsst/cmservice/routers/loaders.py +++ b/src/lsst/cmservice/routers/loaders.py @@ -9,7 +9,7 @@ router = APIRouter( prefix="/load", - tags=["Loaders"], + tags=["loaders"], ) diff --git a/src/lsst/cmservice/routers/pipetask_error_types.py b/src/lsst/cmservice/routers/pipetask_error_types.py index 7139199e..37b40db1 100644 --- a/src/lsst/cmservice/routers/pipetask_error_types.py +++ b/src/lsst/cmservice/routers/pipetask_error_types.py @@ -14,14 +14,12 @@ UpdateModelClass = models.PipetaskErrorTypeUpdate # Specify the associated database table DbClass = db.PipetaskErrorType -# Specify the tag in the router documentation -TAG_STRING = "Pipetask Error Types" # Build the router router = APIRouter( 
prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["pipetask error types"], ) diff --git a/src/lsst/cmservice/routers/pipetask_errors.py b/src/lsst/cmservice/routers/pipetask_errors.py index 352ef24f..07dd8bf8 100644 --- a/src/lsst/cmservice/routers/pipetask_errors.py +++ b/src/lsst/cmservice/routers/pipetask_errors.py @@ -14,14 +14,12 @@ UpdateModelClass = models.PipetaskErrorUpdate # Specify the associated database table DbClass = db.PipetaskError -# Specify the tag in the router documentation -TAG_STRING = "Pipetask Errors" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["pipetask errors"], ) diff --git a/src/lsst/cmservice/routers/product_sets.py b/src/lsst/cmservice/routers/product_sets.py index 34875b8f..29683dac 100644 --- a/src/lsst/cmservice/routers/product_sets.py +++ b/src/lsst/cmservice/routers/product_sets.py @@ -14,14 +14,12 @@ UpdateModelClass = models.ProductSetUpdate # Specify the associated database table DbClass = db.ProductSet -# Specify the tag in the router documentation -TAG_STRING = "Product Sets" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["product sets"], ) diff --git a/src/lsst/cmservice/routers/queues.py b/src/lsst/cmservice/routers/queues.py index 02c52c50..92b1f924 100644 --- a/src/lsst/cmservice/routers/queues.py +++ b/src/lsst/cmservice/routers/queues.py @@ -22,14 +22,12 @@ UpdateModelClass = models.QueueUpdate # Specify the associated database table DbClass = db.Queue -# Specify the tag in the router documentation -TAG_STRING = "Queues" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["queues"], ) diff --git a/src/lsst/cmservice/routers/script_dependencies.py b/src/lsst/cmservice/routers/script_dependencies.py index 9bc9d720..25866e19 100644 --- a/src/lsst/cmservice/routers/script_dependencies.py +++ b/src/lsst/cmservice/routers/script_dependencies.py @@ 
-12,14 +12,12 @@ CreateModelClass = models.DependencyCreate # Specify the associated database table DbClass = db.ScriptDependency -# Specify the tag in the router documentation -TAG_STRING = "Script Dependencies" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["script dependencies"], ) diff --git a/src/lsst/cmservice/routers/script_errors.py b/src/lsst/cmservice/routers/script_errors.py index 4b71d0b1..35a7cec4 100644 --- a/src/lsst/cmservice/routers/script_errors.py +++ b/src/lsst/cmservice/routers/script_errors.py @@ -14,14 +14,12 @@ UpdateModelClass = models.ScriptErrorUpdate # Specify the associated database table DbClass = db.ScriptError -# Specify the tag in the router documentation -TAG_STRING = "ScriptErrors" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["script errors"], ) diff --git a/src/lsst/cmservice/routers/scripts.py b/src/lsst/cmservice/routers/scripts.py index ed6ad47e..903bd634 100644 --- a/src/lsst/cmservice/routers/scripts.py +++ b/src/lsst/cmservice/routers/scripts.py @@ -21,14 +21,12 @@ UpdateModelClass = models.ScriptUpdate # Specify the associated database table DbClass = db.Script -# Specify the tag in the router documentation -TAG_STRING = "Scripts" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["scripts"], ) diff --git a/src/lsst/cmservice/routers/spec_blocks.py b/src/lsst/cmservice/routers/spec_blocks.py index ab469c4d..9e3bb297 100644 --- a/src/lsst/cmservice/routers/spec_blocks.py +++ b/src/lsst/cmservice/routers/spec_blocks.py @@ -14,14 +14,12 @@ UpdateModelClass = models.SpecBlockUpdate # Specify the associated database table DbClass = db.SpecBlock -# Specify the tag in the router documentation -TAG_STRING = "SpecBlocks" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["specblocks"], ) diff --git 
a/src/lsst/cmservice/routers/specifications.py b/src/lsst/cmservice/routers/specifications.py index 47af01b1..85c5ad16 100644 --- a/src/lsst/cmservice/routers/specifications.py +++ b/src/lsst/cmservice/routers/specifications.py @@ -14,14 +14,12 @@ UpdateModelClass = models.SpecificationUpdate # Specify the associated database table DbClass = db.Specification -# Specify the tag in the router documentation -TAG_STRING = "Specifications" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["specifications"], ) diff --git a/src/lsst/cmservice/routers/step_dependencies.py b/src/lsst/cmservice/routers/step_dependencies.py index fc8eeb5f..642f7e57 100644 --- a/src/lsst/cmservice/routers/step_dependencies.py +++ b/src/lsst/cmservice/routers/step_dependencies.py @@ -12,14 +12,12 @@ CreateModelClass = models.DependencyCreate # Specify the associated database table DbClass = db.StepDependency -# Specify the tag in the router documentation -TAG_STRING = "Step Dependencies" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["step dependencies"], ) diff --git a/src/lsst/cmservice/routers/steps.py b/src/lsst/cmservice/routers/steps.py index 707faae3..c3a92679 100644 --- a/src/lsst/cmservice/routers/steps.py +++ b/src/lsst/cmservice/routers/steps.py @@ -14,14 +14,12 @@ UpdateModelClass = models.StepUpdate # Specify the associated database table DbClass = db.Step -# Specify the tag in the router documentation -TAG_STRING = "Steps" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["steps"], ) diff --git a/src/lsst/cmservice/routers/task_sets.py b/src/lsst/cmservice/routers/task_sets.py index 07f223f3..e459c7b2 100644 --- a/src/lsst/cmservice/routers/task_sets.py +++ b/src/lsst/cmservice/routers/task_sets.py @@ -14,14 +14,12 @@ UpdateModelClass = models.TaskSetUpdate # Specify the associated database table DbClass = db.TaskSet -# 
Specify the tag in the router documentation -TAG_STRING = "Task Sets" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["task sets"], ) diff --git a/src/lsst/cmservice/routers/v2/__init__.py b/src/lsst/cmservice/routers/v2/__init__.py new file mode 100644 index 00000000..94d0c1a8 --- /dev/null +++ b/src/lsst/cmservice/routers/v2/__init__.py @@ -0,0 +1,13 @@ +from fastapi import APIRouter + +from . import ( + campaigns, + manifests, +) + +router = APIRouter( + prefix="/v2", +) + +router.include_router(campaigns.router) +router.include_router(manifests.router) diff --git a/src/lsst/cmservice/routers/v2/campaigns.py b/src/lsst/cmservice/routers/v2/campaigns.py new file mode 100644 index 00000000..dfcb8033 --- /dev/null +++ b/src/lsst/cmservice/routers/v2/campaigns.py @@ -0,0 +1,335 @@ +"""http routers for managing Campaign tables. + +The /campaigns endpoint supports a collection resource and single resources +representing campaign objects within CM-Service. 
+""" + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Annotated +from uuid import UUID, uuid5 + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response +from sqlalchemy.orm import aliased +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from ...common.logging import LOGGER +from ...db.campaigns_v2 import Campaign, CampaignUpdate, Edge, Node +from ...db.manifests_v2 import CampaignManifest +from ...db.session import db_session_dependency + +# TODO should probably bind a logger to the fastapi app or something +logger = LOGGER.bind(module=__name__) + + +# Build the router +router = APIRouter( + prefix="/campaigns", + tags=["campaigns", "v2"], +) + + +@router.get( + "/", + summary="Get a list of campaigns", +) +async def read_campaign_collection( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + limit: Annotated[int, Query(le=100)] = 10, + offset: Annotated[int, Query()] = 0, +) -> Sequence[Campaign]: + """...""" + try: + campaigns = await session.exec(select(Campaign).offset(offset).limit(limit)) + + response.headers["Next"] = str( + request.url_for("read_campaign_collection").include_query_params( + offset=(offset + limit), limit=limit + ) + ) + if offset > 0: + response.headers["Previous"] = str( + request.url_for("read_campaign_collection").include_query_params( + offset=(offset - limit), limit=limit + ) + ) + return campaigns.all() + except Exception as msg: + logger.exception() + raise HTTPException(status_code=500, detail=f"{str(msg)}") from msg + + +@router.get( + "/{campaign_name}", + summary="Get campaign detail", +) +async def read_campaign_resource( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + campaign_name: str, +) -> Campaign: + """Fetch a single campaign from the database given either the campaign id + or its name. 
+ """ + s = select(Campaign) + # The input could be a campaign UUID or it could be a literal name. + try: + if campaign_id := UUID(campaign_name): + s = s.where(Campaign.id == campaign_id) + except ValueError: + s = s.where(Campaign.name == campaign_name) + + campaign = (await session.exec(s)).one_or_none() + # set the response headers + if campaign is not None: + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign.id)) + response.headers["Nodes"] = str( + request.url_for("read_campaign_node_collection", campaign_name=campaign.id) + ) + response.headers["Edges"] = str( + request.url_for("read_campaign_edge_collection", campaign_name=campaign.id) + ) + return campaign + else: + raise HTTPException(status_code=404) + + +@router.patch( + "/{campaign_name}", + summary="Update campaign detail", + status_code=202, +) +async def update_campaign_resource( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + campaign_name: str, + patch_data: CampaignUpdate, +) -> Campaign: + """Partial update method for campaigns. + + Should primarily be used to set the status of a campaign, e.g., from + waiting->ready, in order to trigger any validation rules contained in that + transition. + + Another common use case would be to set status to "paused". + + This could be used to update a campaign's metadata, but otherwise the + status is the only field available for modification, and even then there is + not an imperative "change the status" command, rather a request to evolve + the state of a campaign from A to B, which may or may not be successful. + + Rather than manipulating the campaign's record, a change to status should + instead create a work item for the task processing queue for an executor + to discover and attempt to act upon. Barring that, the work should be + delegated to a Background Task. 
This is why the method returns a 202; the + user needs to check back "later" to see if the requested state change has + occurred. + """ + use_rfc7396 = False + use_rfc6902 = False + mutable_fields = [] + if request.headers["Content-Type"] == "application/merge-patch+json": + use_rfc7396 = True + mutable_fields.extend(["owner", "status"]) + elif request.headers["Content-Type"] == "application/json-patch+json": + use_rfc6902 = True + mutable_fields.extend(["configuration", "metadata_"]) + raise HTTPException(status_code=501, detail="Not yet implemented.") + else: + raise HTTPException(status_code=406, detail="Unsupported Content-Type") + + if TYPE_CHECKING: + assert use_rfc7396 + assert not use_rfc6902 + s = select(Campaign) + # The input could be a campaign UUID or it could be a literal name. + try: + if campaign_id := UUID(campaign_name): + s = s.where(Campaign.id == campaign_id) + except ValueError: + s = s.where(Campaign.name == campaign_name) + + campaign = (await session.exec(s)).one_or_none() + if campaign is None: + raise HTTPException(status_code=404, detail="No such campaign") + + # update the campaign with the patch data + update_data = patch_data.model_dump(exclude_unset=True) + campaign.sqlmodel_update(update_data) + session.add(campaign) + await session.commit() + await session.refresh(campaign) + # set the response headers + if campaign is not None: + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign.id)) + response.headers["Nodes"] = str( + request.url_for("read_campaign_node_collection", campaign_name=campaign.id) + ) + response.headers["Edges"] = str( + request.url_for("read_campaign_edge_collection", campaign_name=campaign.id) + ) + return campaign + + +@router.get( + "/{campaign_name}/nodes", + summary="Get campaign Nodes", +) +async def read_campaign_node_collection( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + campaign_name: str, + 
limit: Annotated[int, Query(le=100)] = 10, + offset: Annotated[int, Query()] = 0, +) -> Sequence[Node]: + # This is a convenience api that could also be `/nodes?campaign=... + + # The input could be a campaign UUID or it could be a literal name. + # TODO this could just as well be a campaign query with a join to nodes + s = select(Node) + try: + if campaign_id := UUID(campaign_name): + s = s.where(Node.namespace == campaign_id) + except ValueError: + # FIXME get an id from a name + raise HTTPException(status_code=422, detail="campaign_name must be a uuid") + s = s.offset(offset).limit(limit) + nodes = await session.exec(s) + response.headers["Next"] = str( + request.url_for( + "read_campaign_node_collection", + campaign_name=campaign_name, + ).include_query_params(offset=(offset + limit), limit=limit), + ) + # TODO Previous + return nodes.all() + + +@router.get( + "/{campaign_name}/edges", + summary="Get campaign Edges", +) +async def read_campaign_edge_collection( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + campaign_name: str, + *, + resolve_names: bool = False, +) -> Sequence[Edge]: + # This is a convenience api that could also be `/edges?campaign=... + + # The input could be a campaign UUID or it could be a literal name. + # This is why raw SQL is better than ORMs + # This is probably better off as two queries instead of a "complicated" + # join. 
+ if resolve_names: + source_nodes = aliased(Node, name="source") + target_nodes = aliased(Node, name="target") + s = ( + select( + col(Edge.id).label("id"), + col(Edge.name).label("name"), + col(Edge.namespace).label("namespace"), + col(source_nodes.name).label("source"), + col(target_nodes.name).label("target"), + col(Edge.configuration).label("configuration"), + ) # type: ignore + .join_from(Edge, source_nodes, Edge.source == source_nodes.id) + .join_from(Edge, target_nodes, Edge.target == target_nodes.id) + ) + else: + s = select(Edge) + try: + if campaign_id := UUID(campaign_name): + s = s.where(Edge.namespace == campaign_id) + except ValueError: + # FIXME get an id from a name + raise HTTPException(status_code=422, detail="campaign_name must be a uuid") + edges = await session.exec(s) + return edges.all() + + +@router.delete( + "/{campaign_name}/edges/{edge_name}", + summary="Delete campaign edge", + status_code=204, +) +async def delete_campaign_edge_resource( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + campaign_name: str, + edge_name: str, +) -> None: + """Delete an edge resource from the campaign, using either name or id.""" + # If the campaign name is not a uuid, find the appropriate id + try: + campaign_id = UUID(campaign_name) + except ValueError: + # FIXME get an id from a name + raise HTTPException(status_code=422, detail="campaign_name must be a uuid") + + try: + edge_id = UUID(edge_name) + except ValueError: + edge_id = uuid5(campaign_id, edge_name) + + s = select(Edge).where(Edge.id == edge_id) + edge_to_delete = (await session.exec(s)).one_or_none() + + if edge_to_delete is not None: + await session.delete(edge_to_delete) + await session.commit() + else: + raise HTTPException(status_code=404, detail="No such edge.") + return None + + +@router.post( + "/", + summary="Add a campaign resource", +) +async def create_campaign_resource( + request: Request, + response: Response, + 
session: Annotated[AsyncSession, Depends(db_session_dependency)], + manifest: CampaignManifest, +) -> Campaign: + # Create a campaign spec from the manifest, delegating the creation of new + # dynamic fields to the model validation method, -OR- create new dynamic + # fields here. + campaign = Campaign.model_validate( + dict( + name=manifest.metadata_.name, + metadata_=manifest.metadata_.model_dump(), + # owner = ... # TODO Get username from gafaelfawr # noqa: ERA001 + ) + ) + + # A new campaign comes with a START and END node + start_node = Node.model_validate(dict(name="START", namespace=campaign.id)) + end_node = Node.model_validate(dict(name="END", namespace=campaign.id)) + + # Put the campaign in the database + session.add(campaign) + session.add(start_node) + session.add(end_node) + await session.commit() + await session.refresh(campaign) + + # set the response headers + response.headers["Self"] = str(request.url_for("read_campaign_resource", campaign_name=campaign.id)) + response.headers["Nodes"] = str( + request.url_for("read_campaign_node_collection", campaign_name=campaign.id) + ) + response.headers["Edges"] = str( + request.url_for("read_campaign_edge_collection", campaign_name=campaign.id) + ) + + return campaign diff --git a/src/lsst/cmservice/routers/v2/manifests.py b/src/lsst/cmservice/routers/v2/manifests.py new file mode 100644 index 00000000..308f4a33 --- /dev/null +++ b/src/lsst/cmservice/routers/v2/manifests.py @@ -0,0 +1,248 @@ +"""http routers for managing Manifest tables. + +The /manifests endpoint supports a collection resource and single resources +representing manifest objects within CM-Service. 
+""" + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Annotated +from uuid import UUID, uuid5 + +from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response +from sqlmodel import col, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from ...common.jsonpatch import JSONPatch, JSONPatchError, apply_json_patch +from ...common.logging import LOGGER +from ...db.campaigns_v2 import Campaign, Manifest, _default_campaign_namespace +from ...db.manifests_v2 import ManifestModel +from ...db.session import db_session_dependency + +# TODO should probably bind a logger to the fastapi app or something +logger = LOGGER.bind(module=__name__) + + +# Build the router +router = APIRouter( + prefix="/manifests", + tags=["manifests", "v2"], +) + + +@router.get( + "/", + summary="Get a list of manifests", +) +async def read_manifest_collection( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + offset: Annotated[int, Query()] = 0, + limit: Annotated[int, Query(le=100)] = 10, +) -> Sequence[Manifest]: + """Gets all manifests""" + response.headers["Next"] = str( + request.url_for("read_manifest_collection").include_query_params(offset=(offset + limit), limit=limit) + ) + if offset > 0: + response.headers["Previous"] = str( + request.url_for("read_manifest_collection").include_query_params( + offset=(offset - limit), limit=limit + ) + ) + try: + nodes = await session.exec(select(Manifest).offset(offset).limit(limit)) + return nodes.all() + except Exception as msg: + logger.exception() + raise HTTPException(status_code=500, detail=f"{str(msg)}") from msg + + +@router.get( + "/{manifest_name_or_id}", + summary="Get manifest detail", +) +async def read_single_resource( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + manifest_name_or_id: str, + manifest_version: Annotated[int | None, Query(ge=0, 
alias="version")] = None, +) -> Manifest: + """Fetch a single manifest from the database given either an id or name. + + When available, only the most recent version of the Manifest is returned, + unless the version is provided as part of the query string. + """ + s = select(Manifest) + # The input could be a UUID or it could be a literal name. + try: + if _id := UUID(manifest_name_or_id): + s = s.where(Manifest.id == _id) + except ValueError: + s = s.where(Manifest.name == manifest_name_or_id) + + if manifest_version is None: + s = s.order_by(col(Manifest.version).desc()).limit(1) + else: + s = s.where(Manifest.version == manifest_version) + + manifest = (await session.exec(s)).one_or_none() + if manifest is not None: + response.headers["Self"] = str( + request.url_for("read_single_resource", manifest_name_or_id=manifest.id) + ) + return manifest + else: + raise HTTPException(status_code=404) + + +@router.post( + "/", + summary="Add a manifest resource", + status_code=204, +) +async def create_one_or_more_manifests( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + manifests: ManifestModel | list[ManifestModel], +) -> None: + # We could be given a single manifest or a list of them. 
In the singleton + # case, wrap it in a list so we can treat everything equally + if not isinstance(manifests, list): + manifests = [manifests] + + for manifest in manifests: + _name = manifest.metadata_.name + + # A manifest must exist in the namespace of an existing campaign + # or the default namespace + _namespace: str | None = manifest.metadata_.namespace + if _namespace is None: + _namespace_uuid = _default_campaign_namespace + else: + try: + _namespace_uuid = UUID(_namespace) + except ValueError: + # get the campaign ID by its name to use as a namespace + # it is an error if the namespace/campaign does not exist + # FIXME but this could also be handled by FK constraints + if ( + _campaign_id := ( + await session.exec(select(Campaign.id).where(Campaign.name == _namespace)) + ).one_or_none() + ) is None: + raise HTTPException(status_code=422, detail="Requested namespace does not exist.") + _namespace_uuid = _campaign_id + + # A manifest must be a new version if name+namespace already exists + # check db for manifest as name+namespace, get version and increment + + s = ( + select(Manifest) + .where(Manifest.name == _name) + .where(Manifest.namespace == _namespace_uuid) + .order_by(col(Manifest.version).desc()) + .limit(1) + ) + + _previous = (await session.exec(s)).one_or_none() + _version = _previous.version if _previous else manifest.metadata_.version + _version += 1 + _manifest = Manifest( + id=uuid5(_namespace_uuid, f"{_name}.{_version}"), + name=_name, + namespace=_namespace_uuid, + kind=manifest.kind, + version=_version, + metadata_=manifest.metadata_.model_dump(), + spec=manifest.spec.model_dump(), + ) + + # Put the node in the database + session.add(_manifest) + + await session.commit() + return None + + +@router.patch( + "/{manifest_name_or_id}", + summary="Update manifest detail", + status_code=202, +) +async def update_manifest_resource( + request: Request, + response: Response, + session: Annotated[AsyncSession, Depends(db_session_dependency)], + 
manifest_name_or_id: str, + patch_data: Sequence[JSONPatch], +) -> Manifest: + """Partial update method for manifests. + + A Manifest's spec or metadata may be updated with this PATCH operation. All + updates to a Manifest creates a new version of the Manifest instead of + updating an existing record in-place. This preserves history and keeps + previous manifest versions available. + + A Manifest's name, id, kind, or namespace may not be modified by this + method, and attempts to do so will produce a 4XX client error. + + This PATCH endpoint supports only RFC6902 json-patch requests. + + Notes + ----- + - This API always targets the latest version of a manifest when applying + a patch. This requires and maintains a "linear" sequence of versions; + it is not permissible to "patch" a previous version and create a "tree"- + like history of manifests. For example, every manifest may be diffed + against any previous version without having to consider branches. + """ + use_rfc6902 = False + if request.headers["Content-Type"] == "application/json-patch+json": + use_rfc6902 = True + else: + raise HTTPException(status_code=406, detail="Unsupported Content-Type") + + if TYPE_CHECKING: + assert use_rfc6902 + + s = select(Manifest) + # The input could be a UUID or it could be a literal name. + try: + if _id := UUID(manifest_name_or_id): + s = s.where(Manifest.id == _id) + except ValueError: + s = s.where(Manifest.name == manifest_name_or_id) + + # we want to order and sort by version, in descending order, so we always + # fetch only the most recent version of manifest + # FIXME this implies that when a manifest ID is provided, it should be an + # error if it is not the most recent version. 
+ s = s.order_by(col(Manifest.version).desc()).limit(1) + + old_manifest = (await session.exec(s)).one_or_none() + if old_manifest is None: + raise HTTPException(status_code=404, detail="No such campaign") + + new_manifest = old_manifest.model_dump(by_alias=True) + new_manifest["version"] += 1 + new_manifest["id"] = uuid5(new_manifest["namespace"], f"{new_manifest['name']}.{new_manifest['version']}") + + for patch in patch_data: + try: + apply_json_patch(patch, new_manifest) + except JSONPatchError as e: + raise HTTPException( + status_code=422, + detail=f"Unable to process one or more patch operations: {e}", + ) + + # create Manifest from new_manifest, add to session, and commit + new_manifest_db = Manifest.model_validate(new_manifest) + session.add(new_manifest_db) + await session.commit() + + # TODO response headers + return new_manifest_db diff --git a/src/lsst/cmservice/routers/wms_task_reports.py b/src/lsst/cmservice/routers/wms_task_reports.py index 9355aa33..04e63a04 100644 --- a/src/lsst/cmservice/routers/wms_task_reports.py +++ b/src/lsst/cmservice/routers/wms_task_reports.py @@ -14,14 +14,12 @@ UpdateModelClass = models.WmsTaskReportUpdate # Specify the associated database table DbClass = db.WmsTaskReport -# Specify the tag in the router documentation -TAG_STRING = "Wms Task Reports" # Build the router router = APIRouter( prefix=f"/{DbClass.class_string}", - tags=[TAG_STRING], + tags=["wms task reports"], ) diff --git a/tests/common/test_jsonpatch.py b/tests/common/test_jsonpatch.py new file mode 100644 index 00000000..73e82391 --- /dev/null +++ b/tests/common/test_jsonpatch.py @@ -0,0 +1,150 @@ +from typing import Any + +import pytest + +from lsst.cmservice.common.jsonpatch import JSONPatch, JSONPatchError, apply_json_patch + + +@pytest.fixture +def target_object() -> dict[str, Any]: + return { + "apiVersion": "io.lsst.cmservice/v1", + "spec": { + "one": 1, + "two": 2, + "three": 4, + "a_list": ["a", "b", "c", "e"], + "tag_list": ["yes", "yeah", "yep"], 
+ }, + "metadata": { + "owner": "bob_loblaw", + }, + } + + +def test_jsonpatch_add(target_object: dict[str, Any]) -> None: + """Tests the use of an add operation with a JSON Patch.""" + + # Fail to add a value to an element that does not exist + op = JSONPatch(op="add", path="/spec/b_list/0", value="a") + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + # Fix the missing "four" property in the spec + op = JSONPatch(op="add", path="/spec/four", value=4) + target_object = apply_json_patch(op, target_object) + assert target_object["spec"].get("four") == 4 + + # Insert the missing "d" value in the spec's a_list property + op = JSONPatch(op="add", path="/spec/a_list/3", value="d") + target_object = apply_json_patch(op, target_object) + assert target_object["spec"].get("a_list")[3] == "d" + assert target_object["spec"].get("a_list")[4] == "e" + + # Append to an existing list using "-" + op = JSONPatch(op="add", path="/spec/a_list/-", value="f") + target_object = apply_json_patch(op, target_object) + assert len(target_object["spec"]["a_list"]) == 6 + assert target_object["spec"]["a_list"][-1] == "f" + + +def test_jsonpatch_replace(target_object: dict[str, Any]) -> None: + """Tests the use of a replace operation with a JSON Patch.""" + + # Fail to replace a value for a missing key + op = JSONPatch(op="replace", path="/spec/five", value=5) + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + # Fail to replace a value for a missing index + op = JSONPatch(op="replace", path="/spec/a_list/4", value="e") + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + # Fix the incorrect "three" property in the spec + op = JSONPatch(op="replace", path="/spec/three", value=3) + target_object = apply_json_patch(op, target_object) + assert target_object["spec"]["three"] == 3 + + +def test_jsonpatch_remove(target_object: dict[str, Any]) -> None: + """Tests the use of a remove operation with a JSON 
Patch.""" + + # Remove the first element ("a") of the "a_list" property in the spec + op = JSONPatch(op="remove", path="/spec/a_list/0") + target_object = apply_json_patch(op, target_object) + assert target_object["spec"]["a_list"][0] == "b" + + # Remove the a non-existent index from the same list (not an error) + op = JSONPatch(op="remove", path="/spec/a_list/8") + target_object = apply_json_patch(op, target_object) + assert len(target_object["spec"]["a_list"]) == 3 + + # Remove the previously added key "four" element in the spec + op = JSONPatch(op="remove", path="/spec/four") + target_object = apply_json_patch(op, target_object) + assert "four" not in target_object["spec"].keys() + + # Repeat the previous removal (not an error) + op = JSONPatch(op="remove", path="/spec/four") + target_object = apply_json_patch(op, target_object) + + +def test_jsonpatch_move(target_object: dict[str, Any]) -> None: + """Tests the use of a move operation with a JSON Patch.""" + + # move the tags list from spec to metadata + op = JSONPatch(op="move", path="/metadata/tag_list", from_="/spec/tag_list") + target_object = apply_json_patch(op, target_object) + assert "tag_list" not in target_object["spec"].keys() + assert "tag_list" in target_object["metadata"].keys() + + # Fail to move a nonexistent object + op = JSONPatch(op="move", path="/spec/yes_such_list", from_="/spec/no_such_list") + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + +def test_jsonpatch_copy(target_object: dict[str, Any]) -> None: + """Tests the use of a copy operation with a JSON Patch.""" + + # copy the owner from metadata to spec as the name "pilot" + op = JSONPatch(op="copy", path="/spec/pilot", from_="/metadata/owner") + target_object = apply_json_patch(op, target_object) + assert target_object["spec"]["pilot"] == target_object["metadata"]["owner"] + + # Fail to copy a nonexistent object + op = JSONPatch(op="copy", path="/spec/yes_such_list", from_="/spec/no_such_list") + with 
pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + +def test_jsonpatch_test(target_object: dict[str, Any]) -> None: + """Tests the use of a test/assert operation with a JSON Patch.""" + + # test successful assertion + op = JSONPatch(op="test", path="/metadata/owner", value="bob_loblaw") + _ = apply_json_patch(op, target_object) + + op = JSONPatch(op="test", path="/spec/a_list/0", value="a") + _ = apply_json_patch(op, target_object) + + # test value mismatch + op = JSONPatch(op="test", path="/metadata/owner", value="bob_alice") + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + # test missing key + op = JSONPatch(op="test", path="/metadata/pilot", value="bob_alice") + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + # test missing index + op = JSONPatch(op="test", path="/spec/a_list/8", value="bob_alice") + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) + + # test bad reference token + op = JSONPatch(op="test", path="/spec/a_list/-", value="bob_alice") + with pytest.raises(JSONPatchError): + _ = apply_json_patch(op, target_object) diff --git a/tests/conftest.py b/tests/conftest.py index c7d18291..00997a3d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import importlib from collections.abc import AsyncIterator, Iterator from typing import Any @@ -12,7 +13,7 @@ from safir.testing.uvicorn import UvicornProcess, spawn_uvicorn from sqlalchemy.ext.asyncio import AsyncEngine -from lsst.cmservice import db, main +from lsst.cmservice import db from lsst.cmservice.common.enums import ScriptMethodEnum from lsst.cmservice.config import config as config_ @@ -23,6 +24,8 @@ def set_app_config(monkeypatch: Any) -> None: config_.script_handler = ScriptMethodEnum.bash +# FIXME this fixture should be refactored not to use unnecessary safir helper +# functions, as all this functionality already exists without it. 
@pytest_asyncio.fixture(name="engine") async def engine_fixture() -> AsyncIterator[AsyncEngine]: """Return a SQLAlchemy AsyncEngine configured to talk to the app db.""" @@ -40,8 +43,10 @@ async def app_fixture() -> AsyncIterator[FastAPI]: Wraps the application in a lifespan manager so that startup and shutdown events are sent during test execution. """ - async with LifespanManager(main.app): - yield main.app + main_ = importlib.import_module("lsst.cmservice.main") + app: FastAPI = getattr(main_, "app") + async with LifespanManager(app): + yield app @pytest_asyncio.fixture(name="client") @@ -51,6 +56,8 @@ async def client_fixture(app: FastAPI) -> AsyncIterator[AsyncClient]: yield the_client +# FIXME this fixture should be replaced by patching the CLIRunner's httpx +# client (see client_fixture) @pytest_asyncio.fixture(name="uvicorn") async def uvicorn_fixture(tmp_path_factory: TempPathFactory) -> AsyncIterator[UvicornProcess]: """Spawn and return a uvicorn process hosting the test app.""" diff --git a/tests/models/test_serde.py b/tests/models/test_serde.py new file mode 100644 index 00000000..ab450eeb --- /dev/null +++ b/tests/models/test_serde.py @@ -0,0 +1,33 @@ +import pytest +from pydantic import BaseModel, ValidationError + +from lsst.cmservice.common.enums import ManifestKind, StatusEnum +from lsst.cmservice.common.types import KindField, StatusField + + +class TestModel(BaseModel): + status: StatusField + kind: KindField + + +def test_validators() -> None: + """Test model field enum validators.""" + # test enum validation by name and value + x = TestModel(status=0, kind="campaign") + assert x.status is StatusEnum.waiting + assert x.kind is ManifestKind.campaign + + # test bad input (wrong name) + with pytest.raises(ValidationError): + x = TestModel(status="bad", kind="edge") + + # test bad input (bad value) + with pytest.raises(ValidationError): + x = TestModel(status="waiting", kind=99) + + +def test_serializers() -> None: + x = TestModel(status="accepted", 
kind="node") + y = x.model_dump() + assert y["status"] == "accepted" + assert y["kind"] == "node" diff --git a/tests/v2/__init__.py b/tests/v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/v2/conftest.py b/tests/v2/conftest.py new file mode 100644 index 00000000..ccd19505 --- /dev/null +++ b/tests/v2/conftest.py @@ -0,0 +1,152 @@ +"""Shared conftest module for v2 unit and functional tests.""" + +import importlib +import os +from collections.abc import AsyncGenerator, Generator +from typing import TYPE_CHECKING +from uuid import NAMESPACE_DNS, uuid4 + +import pytest +import pytest_asyncio +from fastapi.testclient import TestClient +from httpx import ASGITransport, AsyncClient +from sqlalchemy import insert +from sqlalchemy.pool import NullPool +from sqlalchemy.schema import CreateSchema, DropSchema +from testcontainers.postgres import PostgresContainer + +from lsst.cmservice.common.types import AnyAsyncSession +from lsst.cmservice.config import config +from lsst.cmservice.db.campaigns_v2 import metadata +from lsst.cmservice.db.session import DatabaseSessionDependency, db_session_dependency + +if TYPE_CHECKING: + from fastapi import FastAPI + +POSTGRES_CONTAINER_IMAGE = "postgres:16" + + +@pytest.fixture(scope="module") +def monkeypatch_module() -> Generator[pytest.MonkeyPatch]: + with pytest.MonkeyPatch.context() as mp: + yield mp + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def rawdb(monkeypatch_module: pytest.MonkeyPatch) -> AsyncGenerator[DatabaseSessionDependency]: + """Test fixture for a postgres container. + + A scoped ephemeral container will be created for the test if the env var + `TEST__LOCAL_DB` is not set; otherwise the fixture will assume that the + correct database is available to the test environment through ordinary + configuration parameters. + + The tests are performed within a random temporary schema that is created + and dropped along with the tables. 
+ """ + + monkeypatch_module.setattr(target=config.asgi, name="enable_frontend", value=False) + monkeypatch_module.setattr(target=config.db, name="table_schema", value=uuid4().hex[:8]) + + if os.getenv("TEST__LOCAL_DB") is not None: + db_session_dependency.pool_class = NullPool + await db_session_dependency.initialize() + assert db_session_dependency.engine is not None + yield db_session_dependency + await db_session_dependency.aclose() + else: + with PostgresContainer( + image=POSTGRES_CONTAINER_IMAGE, + username="cm-service", + password="INSECURE-PASSWORD", + dbname="cm-service", + driver="asyncpg", + ) as postgres: + psql_url = postgres.get_connection_url() + monkeypatch_module.setattr(target=config.db, name="url", value=psql_url) + monkeypatch_module.setattr(target=config.db, name="password", value=postgres.password) + monkeypatch_module.setattr(target=config.db, name="echo", value=True) + db_session_dependency.pool_class = NullPool + await db_session_dependency.initialize() + assert db_session_dependency.engine is not None + yield db_session_dependency + await db_session_dependency.aclose() + + +@pytest_asyncio.fixture(scope="module", loop_scope="module") +async def testdb(rawdb: DatabaseSessionDependency) -> AsyncGenerator[DatabaseSessionDependency]: + """Test fixture for a migrated postgres container. + + This fixture creates all the database objects defined for the ORM metadata + and drops them at the end of the fixture scope. + """ + # v1 objects are created from the legacy DeclarativeBase class and + # v2 objects are created from the SQLModel metadata. 
+ assert rawdb.engine is not None + async with rawdb.engine.begin() as aconn: + await aconn.run_sync(metadata.drop_all) + await aconn.execute(CreateSchema(config.db.table_schema, if_not_exists=True)) + await aconn.run_sync(metadata.create_all) + await aconn.execute( + insert(metadata.tables[f"{metadata.schema}.campaigns_v2"]).values( + id="dda54a0c-6878-5c95-ac4f-007f6808049e", + namespace=str(NAMESPACE_DNS), + name="DEFAULT", + owner="root", + ) + ) + await aconn.commit() + yield rawdb + async with rawdb.engine.begin() as aconn: + await aconn.run_sync(metadata.drop_all) + await aconn.execute(DropSchema(config.db.table_schema, if_exists=True)) + await aconn.commit() + + +@pytest_asyncio.fixture(name="session", scope="module", loop_scope="module") +async def session_fixture(testdb: DatabaseSessionDependency) -> AsyncGenerator[AnyAsyncSession]: + """Test fixture for an async database session""" + assert testdb.engine is not None + assert testdb.sessionmaker is not None + async with testdb.sessionmaker() as session: + try: + yield session + finally: + await session.close() + await testdb.engine.dispose() + + +def client_fixture(session: AnyAsyncSession) -> Generator[TestClient]: + """Test fixture for a FastAPI test client with dependency injection + overridden. + """ + + def get_session_override() -> AnyAsyncSession: + return session + + main_ = importlib.import_module("lsst.cmservice.main") + app: FastAPI = getattr(main_, "app") + + app.dependency_overrides[db_session_dependency] = get_session_override + client = TestClient(app) + yield client + app.dependency_overrides.clear() + + +@pytest_asyncio.fixture(name="aclient", scope="module", loop_scope="module") +async def async_client_fixture(session: AnyAsyncSession) -> AsyncGenerator[AsyncClient]: + """Test fixture for an HTTPX async test client with dependency injection + overridden.
+ """ + main_ = importlib.import_module("lsst.cmservice.main") + app: FastAPI = getattr(main_, "app") + + def get_session_override() -> AnyAsyncSession: + return session + + app.dependency_overrides[db_session_dependency] = get_session_override + async with AsyncClient( + follow_redirects=True, transport=ASGITransport(app), base_url="http://test" + ) as aclient: + yield aclient + app.dependency_overrides.clear() diff --git a/tests/v2/test_campaign_routes.py b/tests/v2/test_campaign_routes.py new file mode 100644 index 00000000..75a51405 --- /dev/null +++ b/tests/v2/test_campaign_routes.py @@ -0,0 +1,183 @@ +"""Tests v2 fastapi campaign routes""" + +from uuid import NAMESPACE_DNS, uuid4, uuid5 + +import pytest +from httpx import AsyncClient + +pytestmark = pytest.mark.asyncio(loop_scope="module") +"""All tests in this module will run in the same event loop.""" + + +async def test_list_campaigns(aclient: AsyncClient) -> None: + """Tests listing the set of all campaigns.""" + # initially, only the default campaign should be available. 
+ x = await aclient.get("/cm-service/v2/campaigns") + assert len(x.json()) == 1 + + +async def test_list_campaign(aclient: AsyncClient) -> None: + """Tests lookup of a single campaign by name and by ID""" + campaign_name = uuid4().hex[-8:] + + x = await aclient.get( + f"/cm-service/v2/campaigns/{campaign_name}", + ) + assert x.status_code == 404 + + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "campaign", + "metadata_": {"name": campaign_name}, + "spec": {}, + }, + ) + assert x.is_success + campaign_id = x.json()["id"] + + x = await aclient.get( + f"/cm-service/v2/campaigns/{campaign_name}", + ) + assert x.is_success + + x = await aclient.get( + f"/cm-service/v2/campaigns/{campaign_id}", + ) + assert x.is_success + + +async def test_negative_campaign(aclient: AsyncClient) -> None: + """Tests campaign api negative results.""" + # Test failure to create campaign with invalid manifest (wrong kind) + campaign_name = uuid4().hex[-8:] + + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "node", + "metadata": { + "name": campaign_name, + }, + "spec": {}, + }, + ) + assert x.is_client_error + + # Test failure to create campaign with incomplete manifest (missing name) + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "campaign", + "metadata": {}, + "spec": {}, + }, + ) + assert x.is_client_error + + +async def test_create_campaign(aclient: AsyncClient) -> None: + campaign_name = uuid4().hex[-8:] + + # Test successful campaign creation + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "campaign", + "metadata": {"name": campaign_name}, + "spec": {}, + }, + ) + assert x.is_success + + campaign = x.json() + # - the campaign should exist in the namespace of the default campaign + assert campaign["namespace"] == 
str(uuid5(NAMESPACE_DNS, "io.lsst.cmservice")) + # - the returned campaign's ID should be the name of the campaign within + # the namespace of the default campaign + assert campaign["id"] == str(uuid5(uuid5(NAMESPACE_DNS, "io.lsst.cmservice"), campaign_name)) + assert campaign["status"] == "waiting" + + # - the response headers should have pointers to other API endpoints for + # the campaign + assert set(["self", "nodes", "edges"]) <= x.headers.keys() + + # - the provided links should be valid + y = await aclient.get(x.headers["self"]) + assert y.is_success + + y = await aclient.get(x.headers["nodes"]) + assert y.is_success + # A new empty campaign should have a START and END node + nodes = y.json() + assert len(nodes) == 2 + y = await aclient.get(x.headers["edges"]) + assert y.is_success + # A new empty campaign should not have any edges + edges = y.json() + assert len(edges) == 0 + + +async def test_patch_campaign(aclient: AsyncClient) -> None: + # Create a new campaign with spec data + campaign_name = uuid4().hex[-8:] + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "apiVersion": "io.lsst.cmservice/v1", + "kind": "campaign", + "metadata_": {"name": campaign_name}, + "spec": { + "collections": ["a", "b", "c"], + "enable_notifications": True, + }, + }, + ) + + assert x.is_success + # The response header 'self' has the canonical url for the new campaign + campaign_url = x.headers["self"] + + # Try an unsupported content-type for patch + y = await aclient.patch( + campaign_url, + json={"status": "ready", "owner": "bob_loblaw"}, + headers={"Content-Type": "application/not-supported+json"}, + ) + assert y.status_code == 406 + + # Update the campaign using RFC6902 + y = await aclient.patch( + campaign_url, + json={"status": "ready", "owner": "bob_loblaw"}, + headers={"Content-Type": "application/json-patch+json"}, + ) + # RFC6902 not implemented + assert y.status_code == 501 + + # Update the campaign using RFC7396 and campaign id + y = await 
aclient.patch( + campaign_url, + json={"status": "ready", "owner": "bob_loblaw"}, + headers={"Content-Type": "application/merge-patch+json"}, + ) + assert y.is_success + updated_campaign = y.json() + assert updated_campaign["owner"] == "bob_loblaw" + assert updated_campaign["status"] == "ready" + + # Update the campaign again using RFC7396, ensuring only a single field + # is patched, using campaign name + y = await aclient.patch( + f"/cm-service/v2/campaigns/{campaign_name}", + json={"owner": "alice_bob"}, + headers={"Content-Type": "application/merge-patch+json"}, + ) + assert y.is_success + updated_campaign = y.json() + assert updated_campaign["owner"] == "alice_bob" + assert updated_campaign["status"] == "ready" diff --git a/tests/v2/test_db.py b/tests/v2/test_db.py new file mode 100644 index 00000000..7995f52f --- /dev/null +++ b/tests/v2/test_db.py @@ -0,0 +1,77 @@ +"""Tests v2 database operations""" + +from uuid import uuid4, uuid5 + +import pytest +from sqlmodel import select + +from lsst.cmservice.db.campaigns_v2 import Campaign, Machine, _default_campaign_namespace +from lsst.cmservice.db.session import DatabaseSessionDependency + + +@pytest.mark.asyncio +async def test_create_campaigns_v2(testdb: DatabaseSessionDependency) -> None: + """Tests the campaigns_v2 table by creating and updating a Campaign.""" + + assert testdb.sessionmaker is not None + + campaign_name = "test_campaign" + campaign = Campaign( + id=uuid5(_default_campaign_namespace, campaign_name), + name=campaign_name, + namespace=_default_campaign_namespace, + owner="test", + metadata_={"mtime": 0, "crtime": 0}, + configuration={"mtime": 0, "crtime": 0}, + ) + async with testdb.sessionmaker() as session: + session.add(campaign) + await session.commit() + + del campaign + + async with testdb.sessionmaker() as session: + statement = select(Campaign).where(Campaign.name == "test_campaign") + results = await session.exec(statement) + campaign = results.one() + campaign.name = "a_new_name" + 
campaign.configuration["mtime"] = 1750107719 + campaign.metadata_["crtime"] = 1750107719 + await session.commit() + + del campaign + + async with testdb.sessionmaker() as session: + statement = select(Campaign).where(Campaign.name == "a_new_name") + results = await session.exec(statement) + campaign = results.one() + assert campaign.name == "a_new_name" + assert "mtime" in campaign.configuration + assert "crtime" in campaign.metadata_ + assert campaign.configuration["mtime"] == 1750107719 + assert campaign.configuration["crtime"] == 0 + assert campaign.metadata_["crtime"] == 1750107719 + assert campaign.metadata_["mtime"] == 0 + + +@pytest.mark.asyncio +async def test_create_machines_v2(testdb: DatabaseSessionDependency) -> None: + """Tests the machines_v2 table by storing + retrieving a pickled object.""" + + assert testdb.sessionmaker is not None + + # the machines table is a PickleType so it doesn't really matter for this + # test what kind of object is being pickled. + o = {"a": [1, 2, 3, 4, {"aa": [[0, 1], [2, 3]]}]} + + machine_id = uuid4() + machine = Machine(id=machine_id, state=o) + async with testdb.sessionmaker() as session: + session.add(machine) + await session.commit() + + async with testdb.sessionmaker() as session: + s = select(Machine).where(Machine.id == machine_id).limit(1) + unpickled = (await session.exec(s)).one() + + assert unpickled.state == o diff --git a/tests/v2/test_manifest_routes.py b/tests/v2/test_manifest_routes.py new file mode 100644 index 00000000..0c4144e1 --- /dev/null +++ b/tests/v2/test_manifest_routes.py @@ -0,0 +1,264 @@ +"""Tests v2 fastapi manifest routes""" + +from uuid import uuid4 + +import pytest +from httpx import AsyncClient + +pytestmark = pytest.mark.asyncio(loop_scope="module") +"""All tests in this module will run in the same event loop.""" + + +async def test_list_manifests(aclient: AsyncClient) -> None: + x = await aclient.get("/cm-service/v2/manifests") + assert x.is_success + assert len(x.json()) == 0 + 
assert True + + +async def test_load_manifests(aclient: AsyncClient) -> None: + campaign_name = uuid4().hex[-8:] + + # Try to create a campaign manifest + x = await aclient.post( + "/cm-service/v2/manifests", + json={ + "apiversion": "io.lsst.cmservice/v1", + "kind": "campaign", + "metadata_": {}, + "spec": {}, + }, + ) + assert x.is_client_error + + # Try to create a manifest without a name + x = await aclient.post( + "/cm-service/v2/manifests", + json={ + "apiversion": "io.lsst.cmservice/v1", + "kind": "other", + "metadata_": {}, + "spec": {}, + }, + ) + assert x.is_client_error + + # Try to create a manifest with an unknown namespace + x = await aclient.post( + "/cm-service/v2/manifests", + json={ + "apiversion": "io.lsst.cmservice/v1", + "kind": "other", + "metadata_": { + "name": uuid4().hex[-8:], + "namespace": campaign_name, + }, + "spec": {}, + }, + ) + assert x.is_client_error + + # Create a manifest in the default namespace + x = await aclient.post( + "/cm-service/v2/manifests", + json={ + "apiversion": "io.lsst.cmservice/v1", + "kind": "other", + "metadata_": { + "name": uuid4().hex[-8:], + }, + "spec": { + "one": 1, + "two": 2, + "three": 4, + }, + }, + ) + assert x.is_success + + # Create a campaign for additional manifests + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "kind": "campaign", + "metadata": {"name": campaign_name}, + "spec": {}, + }, + ) + assert x.is_success + campaign_id = x.json()["id"] + + # Create multiple manifests in the campaign namespace by both its name and + # its id. The first uses the most "proper" manifest field names "metadata" + # and "spec"; the second uses alternate "metadata_" and "data" aliases. 
+ x = await aclient.post( + "/cm-service/v2/manifests", + json=[ + { + "apiversion": "io.lsst.cmservice/v1", + "kind": "other", + "metadata": { + "name": uuid4().hex[-8:], + "namespace": campaign_id, + }, + "spec": { + "one": 1, + "two": 2, + "three": 4, + }, + }, + { + "apiversion": "io.lsst.cmservice/v1", + "kind": "other", + "metadata_": { + "name": uuid4().hex[-8:], + "namespace": campaign_name, + }, + "data": { + "one": 1, + "two": 2, + "three": 4, + }, + }, + ], + ) + assert x.is_success + + # Get all the loaded manifests + x = await aclient.get("/cm-service/v2/manifests") + assert x.is_success + manifests = x.json() + assert len(manifests) == 3 + assert manifests[-1]["spec"]["one"] == 1 + + +async def test_patch_manifest(aclient: AsyncClient) -> None: + """Tests partial update of manifests and single resource retrieval.""" + campaign_name = uuid4().hex[-8:] + manifest_name = uuid4().hex[-8:] + + # Create a campaign and a manifest + x = await aclient.post( + "/cm-service/v2/campaigns", + json={ + "kind": "campaign", + "metadata": {"name": campaign_name}, + "spec": {}, + }, + ) + assert x.is_success + + x = await aclient.post( + "/cm-service/v2/manifests", + json={ + "apiversion": "io.lsst.cmservice/v1", + "kind": "other", + "metadata": { + "name": manifest_name, + "namespace": campaign_name, + }, + "spec": { + "one": 1, + "two": 2, + "three": 4, + "a_list": ["a", "b", "c", "e"], + }, + }, + ) + + # "Correct" the spec of the loaded manifest + x = await aclient.patch( + f"/cm-service/v2/manifests/{manifest_name}", + headers={"Content-Type": "application/json-patch+json"}, + json=[ + { + "op": "replace", + "path": "/spec/three", + "value": 3, + }, + { + "op": "replace", + "path": "/spec/a_list/3", + "value": "d", + }, + { + "op": "add", + "path": "/spec/four", + "value": 4, + }, + { + "op": "add", + "path": "/spec/owner", + "value": "bob_loblaw", + }, + ], + ) + assert x.is_success + patched_manifest = x.json() + assert patched_manifest["version"] == 2 + assert 
patched_manifest["spec"]["three"] == 3 + assert patched_manifest["spec"]["a_list"][3] == "d" + assert patched_manifest["spec"]["four"] == 4 + assert patched_manifest["spec"]["owner"] == "bob_loblaw" + + # In the previous test we added the "owner" to the wrong path, so now we + # want to it from spec->metadata + x = await aclient.patch( + f"/cm-service/v2/manifests/{manifest_name}", + headers={"Content-Type": "application/json-patch+json"}, + json=[ + { + "op": "move", + "path": "/metadata/owner", + "from": "/spec/owner", + } + ], + ) + assert x.is_success + patched_manifest = x.json() + assert patched_manifest["version"] == 3 + assert patched_manifest["metadata"]["owner"] == "bob_loblaw" + assert "owner" not in patched_manifest["spec"] + + # Using the "test" operator as a gating function, try but fail to update + # the previously moved owner field + x = await aclient.patch( + f"/cm-service/v2/manifests/{manifest_name}", + headers={"Content-Type": "application/json-patch+json"}, + json=[ + { + "op": "test", + "path": "/spec/owner", + "value": "bob_loblaw", + }, + { + "op": "replace", + "path": "/spec/owner", + "value": "lob_boblaw", + }, + { + "op": "add", + "path": "/metadata/scope", + "value": "drp", + }, + ], + ) + assert x.is_client_error + + # Get the manifest with multiple versions + # First, make sure when not indicated, the most recent version is returned + # Note: the previous patch with a failing test op must not have created any + # new version. 
+ x = await aclient.get(f"/cm-service/v2/manifests/{manifest_name}") + assert x.is_success + assert x.json()["version"] == 3 + + # RFC6902 prescribes an all-or-nothing patch operation, so the previous op + # with a failing test assertion must not have otherwise completed, e.g., + # the addition of a "scope" key to the manifest's metadata + assert "scope" not in x.json().get("metadata") + + # Next, get a specific version of the manifest + x = await aclient.get(f"/cm-service/v2/manifests/{manifest_name}?version=2") + assert x.is_success + assert x.json()["version"] == 2