From 9757b53c7ce50a6fd90caab7b21cbe783c7614f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Mon, 18 Aug 2025 17:42:17 +0200 Subject: [PATCH 1/3] WIP graphsage support --- graphdatascience/model/v2/graphsage_model.py | 54 ++++ graphdatascience/model/v2/model.py | 173 ++++++++++ graphdatascience/model/v2/model_info.py | 21 ++ .../api/graphsage_endpoints.py | 301 ++++++++++++++++++ .../cypher/graphsage_cypher_endpoints.py | 185 +++++++++++ .../cypher/test_graphsage_cypher_endpoints.py | 126 ++++++++ 6 files changed, 860 insertions(+) create mode 100644 graphdatascience/model/v2/graphsage_model.py create mode 100644 graphdatascience/model/v2/model.py create mode 100644 graphdatascience/model/v2/model_info.py create mode 100644 graphdatascience/procedure_surface/api/graphsage_endpoints.py create mode 100644 graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py diff --git a/graphdatascience/model/v2/graphsage_model.py b/graphdatascience/model/v2/graphsage_model.py new file mode 100644 index 000000000..776468048 --- /dev/null +++ b/graphdatascience/model/v2/graphsage_model.py @@ -0,0 +1,54 @@ +from typing import Any + +from pandas import Series + +from ...call_parameters import CallParameters +from ...graph.graph_object import Graph +from ...graph.graph_type_check import graph_type_check +from ..model import Model + + +class GraphSageModelV2(Model): + """ + Represents a GraphSAGE model in the model catalog. + Construct this using :func:`gds.graphSage.train()`. + """ + + def _endpoint_prefix(self) -> str: + return "gds.beta.graphSage." + + @graph_type_check + def predict_write(self, G: Graph, **config: Any) -> "Series[Any]": + """ + Generate embeddings for the given graph and write the results to the database. + + Args: + G: The graph to generate embeddings for. + **config: The config for the prediction. 
+ + Returns: + The result of the write operation. + + """ + endpoint = self._endpoint_prefix() + "write" + config["modelName"] = self.name() + params = CallParameters(graph_name=G.name(), config=config) + + return self._query_runner.call_procedure( # type: ignore + endpoint=endpoint, params=params, logging=True + ).squeeze() + + @graph_type_check + def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]": + """ + Estimate the memory needed to generate embeddings for the given graph and write the results to the database. + + Args: + G: The graph to generate embeddings for. + **config: The config for the prediction. + + Returns: + The memory needed to generate embeddings for the given graph and write the results to the database. + + """ + return self._estimate_predict("write", G.name(), config) diff --git a/graphdatascience/model/v2/model.py b/graphdatascience/model/v2/model.py new file mode 100644 index 000000000..b0e952ddb --- /dev/null +++ b/graphdatascience/model/v2/model.py @@ -0,0 +1,173 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any + +from pandas import DataFrame, Series + +from graphdatascience.model.v2.model_info import ModelInfo + +from ..call_parameters import CallParameters +from ..graph.graph_object import Graph +from ..graph.graph_type_check import graph_type_check +from ..query_runner.query_runner import QueryRunner +from ..server_version.compatible_with import compatible_with +from ..server_version.server_version import ServerVersion + + +class InfoProvider(ABC): + @abstractmethod + def fetch(self, model_name: str) -> ModelInfo: + """Return the task with progress for the given job_id.""" + pass + + +class Model(ABC): + def __init__(self, name: str, info_provider: InfoProvider): + self._name = name + self._info_provider = info_provider + + # TODO estimate mode, predict modes on here? 
+ # implement Cypher and Arrow info_provider and stuff + + def name(self) -> str: + """ + Get the name of the model. + + Returns: + The name of the model. + + """ + return self._name + + def type(self) -> str: + """ + Get the type of the model. + + Returns: + The type of the model. + + """ + return self._info_provider.fetch(self._name).type + + def train_config(self) -> Series[Any]: + """ + Get the train config of the model. + + Returns: + The train config of the model. + + """ + return self._info_provider.fetch(self._name).train_config + + def graph_schema(self) -> Series[Any]: + """ + Get the graph schema of the model. + + Returns: + The graph schema of the model. + + """ + return self._info_provider.fetch(self._name).graph_schema + + def loaded(self) -> bool: + """ + Check whether the model is loaded in memory. + + Returns: + True if the model is loaded in memory, False otherwise. + + """ + return self._info_provider.fetch(self._name).loaded + + def stored(self) -> bool: + """ + Check whether the model is stored on disk. + + Returns: + True if the model is stored on disk, False otherwise. + + """ + return self._info_provider.fetch(self._name).stored + + def creation_time(self) -> datetime.datetime: + """ + Get the creation time of the model. + + Returns: + The creation time of the model. + + """ + return self._info_provider.fetch(self._name).creation_time + + def shared(self) -> bool: + """ + Check whether the model is shared. + + Returns: + True if the model is shared, False otherwise. + + """ + return self._info_provider.fetch(self._name).shared + + def published(self) -> bool: + """ + Check whether the model is published. + + Returns: + True if the model is published, False otherwise. + + """ + return self._info_provider.fetch(self._name).published + + def model_info(self) -> dict[str, Any]: + """ + Get the model info of the model. + + Returns: + The model info of the model. 
+ + """ + return self._info_provider.fetch(self._name).model_info + + def exists(self) -> bool: + """ + Check whether the model exists. + + Returns: + True if the model exists, False otherwise. + + """ + raise NotImplementedError() + + def drop(self, failIfMissing: bool = False) -> Series[Any]: + """ + Drop the model. + + Args: + failIfMissing: If True, an error is thrown if the model does not exist. If False, no error is thrown. + + Returns: + The result of the drop operation. + + """ + raise NotImplementedError() + + def metrics(self) -> Series[Any]: + """ + Get the metrics of the model. + + Returns: + The metrics of the model. + + """ + model_info = self._info_provider.fetch(self._name).model_info + metrics: Series[Any] = Series(model_info["metrics"]) + return metrics + + def __str__(self) -> str: + return f"{self.__class__.__name__}(name={self.name()}, type={self.type()})" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._info_provider.fetch(self._name).to_dict()})" diff --git a/graphdatascience/model/v2/model_info.py b/graphdatascience/model/v2/model_info.py new file mode 100644 index 000000000..e3345bc58 --- /dev/null +++ b/graphdatascience/model/v2/model_info.py @@ -0,0 +1,21 @@ +import datetime +from typing import Any +from pydantic import BaseModel +from abc import ABC, abstractmethod +from pydantic.alias_generators import to_camel + + +class ModelInfo(BaseModel, alias_generator=to_camel): + name: str + type: str + train_config: dict[str, Any] + graph_schema: dict[str, Any] + loaded: bool + stored: bool + shared: bool + published: bool + model_info: dict[str, Any] # TODO better typing in actual model? + creation_time: datetime.datetime # TODO correct type? 
/ conversion needed + + def __getitem__(self, item: str) -> Any: + return getattr(self, item) diff --git a/graphdatascience/procedure_surface/api/graphsage_endpoints.py b/graphdatascience/procedure_surface/api/graphsage_endpoints.py new file mode 100644 index 000000000..8d7852d68 --- /dev/null +++ b/graphdatascience/procedure_surface/api/graphsage_endpoints.py @@ -0,0 +1,301 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, List, Optional + +from pandas import DataFrame +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel + +from graphdatascience.model.v2.graphsage_model import GraphSageModel + +from ...graph.graph_object import Graph + + +class GraphSageEndpoints(ABC): + """ + Abstract base class defining the API for the GraphSage algorithm. + """ + + @abstractmethod + def train( + self, + G: Graph, + model_name: str, + feature_properties: List[str], + activation_function: Optional[Any] = None, + negative_sample_weight: Optional[int] = None, + embedding_dimension: Optional[int] = None, + tolerance: Optional[float] = None, + learning_rate: Optional[float] = None, + max_iterations: Optional[int] = None, + sample_sizes: Optional[List[int]] = None, + aggregator: Optional[Any] = None, + penalty_l2: Optional[float] = None, + search_depth: Optional[int] = None, + epochs: Optional[int] = None, + projected_feature_dimension: Optional[int] = None, + batch_sampling_ratio: Optional[float] = None, + store_model_to_disk: Optional[bool] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + relationship_weight_property: Optional[str] = None, + random_seed: Optional[Any] = None, + ) -> GraphSageModel: + """ + Trains a GraphSage model on 
the given graph. + + Parameters + ---------- + G : Graph + The graph to run the algorithm on + model_name : str + Name under which the model will be stored + feature_properties : List[str] + The names of the node properties to use as input features + activation_function : Optional[Any], default=None + The activation function to apply after each layer + negative_sample_weight : Optional[int], default=None + Weight of negative samples in the loss function + embedding_dimension : Optional[int], default=None + The dimension of the generated embeddings + tolerance : Optional[float], default=None + Tolerance for early stopping based on loss improvement + learning_rate : Optional[float], default=None + Learning rate for the training optimization + max_iterations : Optional[int], default=None + Maximum number of training iterations + sample_sizes : Optional[List[int]], default=None + Number of neighbors to sample at each layer + aggregator : Optional[Any], default=None + The aggregator function for neighborhood aggregation + penalty_l2 : Optional[float], default=None + L2 regularization penalty + search_depth : Optional[int], default=None + Maximum search depth for neighbor sampling + epochs : Optional[int], default=None + Number of training epochs + projected_feature_dimension : Optional[int], default=None + Dimension to project input features to before training + batch_sampling_ratio : Optional[float], default=None + Ratio of nodes to sample for each training batch + store_model_to_disk : Optional[bool], default=None + Whether to persist the model to disk + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run + username : Optional[str] = None + The username to attribute the procedure run to + log_progress : Optional[bool], default=None + Whether to log progress + sudo : 
Optional[bool], default=None + Override memory estimation limits + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + batch_size : Optional[int], default=None + Batch size for training + relationship_weight_property : Optional[str], default=None + The property name that contains weight + random_seed : Optional[Any], default=None + Random seed for reproducible results + + Returns + ------- + GraphSageTrainResult + Training metrics and model information + """ + + @abstractmethod + def mutate( + self, + G: Graph, + model_name: str, + mutate_property: str, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + ) -> GraphSageMutateResult: + """ + Executes the GraphSage algorithm using a trained model and writes the results back to the graph as a node property. 
+ + Parameters + ---------- + G : Graph + The graph to run the algorithm on + model_name : str + Name of the trained GraphSage model to use + mutate_property : str + The name of the node property to store the embeddings + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run + username : Optional[str] = None + The username to attribute the procedure run to + log_progress : Optional[bool], default=None + Whether to log progress + sudo : Optional[bool], default=None + Override memory estimation limits + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + batch_size : Optional[int], default=None + Batch size for inference + + Returns + ------- + GraphSageMutateResult + Algorithm metrics and statistics + """ + + @abstractmethod + def stream( + self, + G: Graph, + model_name: str, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + ) -> DataFrame: + """ + Executes the GraphSage algorithm using a trained model and returns the results as a stream. 
+ + Parameters + ---------- + G : Graph + The graph to run the algorithm on + model_name : str + Name of the trained GraphSage model to use + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run + username : Optional[str] = None + The username to attribute the procedure run to + log_progress : Optional[bool], default=None + Whether to log progress + sudo : Optional[bool], default=None + Override memory estimation limits + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + batch_size : Optional[int], default=None + Batch size for inference + + Returns + ------- + DataFrame + Embeddings as a stream with columns nodeId and embedding + """ + + @abstractmethod + def write( + self, + G: Graph, + model_name: str, + write_property: str, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + write_concurrency: Optional[Any] = None, + ) -> GraphSageWriteResult: + """ + Executes the GraphSage algorithm using a trained model and writes the results back to the database. 
+ + Parameters + ---------- + G : Graph + The graph to run the algorithm on + model_name : str + Name of the trained GraphSage model to use + write_property : str + The name of the node property to write the embeddings to + relationship_types : Optional[List[str]], default=None + The relationship types used to select relationships for this algorithm run + node_labels : Optional[List[str]], default=None + The node labels used to select nodes for this algorithm run + username : Optional[str] = None + The username to attribute the procedure run to + log_progress : Optional[bool], default=None + Whether to log progress + sudo : Optional[bool], default=None + Override memory estimation limits + concurrency : Optional[Any], default=None + The number of concurrent threads + job_id : Optional[Any], default=None + An identifier for the job + batch_size : Optional[int], default=None + Batch size for inference + write_concurrency : Optional[Any], default=None + The number of concurrent threads used for writing result + + Returns + ------- + GraphSageWriteResult + Algorithm metrics and statistics + """ + + +class GraphSageTrainResult(BaseModel): + model_config = ConfigDict(alias_generator=to_camel) + + model_info: dict[str, Any] + configuration: dict[str, Any] + train_millis: int + + def __getitem__(self, item: str) -> Any: + return self.__dict__[item] + + +class GraphSageMutateResult(BaseModel): + model_config = ConfigDict(alias_generator=to_camel) + + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + mutate_millis: int + configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return self.__dict__[item] + + +class GraphSageWriteResult(BaseModel): + model_config = ConfigDict(alias_generator=to_camel) + + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + write_millis: int + configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return 
self.__dict__[item] diff --git a/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py new file mode 100644 index 000000000..5ba3193db --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py @@ -0,0 +1,185 @@ +from typing import Any, List, Optional + +from pandas import DataFrame + +from ...call_parameters import CallParameters +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner +from ..api.graphsage_endpoints import ( + GraphSageEndpoints, + GraphSageMutateResult, + GraphSageTrainResult, + GraphSageWriteResult, +) +from ..utils.config_converter import ConfigConverter + + +class GraphSageCypherEndpoints(GraphSageEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def train( + self, + G: Graph, + model_name: str, + feature_properties: List[str], + activation_function: Optional[Any] = None, + negative_sample_weight: Optional[int] = None, + embedding_dimension: Optional[int] = None, + tolerance: Optional[float] = None, + learning_rate: Optional[float] = None, + max_iterations: Optional[int] = None, + sample_sizes: Optional[List[int]] = None, + aggregator: Optional[Any] = None, + penalty_l2: Optional[float] = None, + search_depth: Optional[int] = None, + epochs: Optional[int] = None, + projected_feature_dimension: Optional[int] = None, + batch_sampling_ratio: Optional[float] = None, + store_model_to_disk: Optional[bool] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + relationship_weight_property: Optional[str] = None, + random_seed: Optional[Any] = None, + ) -> GraphSageTrainResult: + config = 
ConfigConverter.convert_to_gds_config( + model_name=model_name, + feature_properties=feature_properties, + activation_function=activation_function, + negative_sample_weight=negative_sample_weight, + embedding_dimension=embedding_dimension, + tolerance=tolerance, + learning_rate=learning_rate, + max_iterations=max_iterations, + sample_sizes=sample_sizes, + aggregator=aggregator, + penalty_l2=penalty_l2, + search_depth=search_depth, + epochs=epochs, + projected_feature_dimension=projected_feature_dimension, + batch_sampling_ratio=batch_sampling_ratio, + store_model_to_disk=store_model_to_disk, + relationship_types=relationship_types, + node_labels=node_labels, + username=username, + log_progress=log_progress, + sudo=sudo, + concurrency=concurrency, + job_id=job_id, + batch_size=batch_size, + relationship_weight_property=relationship_weight_property, + random_seed=random_seed, + ) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.train", params=params).squeeze() + + return GraphSageTrainResult(**result.to_dict()) + + def mutate( + self, + G: Graph, + model_name: str, + mutate_property: str, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + ) -> GraphSageMutateResult: + config = ConfigConverter.convert_to_gds_config( + model_name=model_name, + mutate_property=mutate_property, + relationship_types=relationship_types, + node_labels=node_labels, + username=username, + log_progress=log_progress, + sudo=sudo, + concurrency=concurrency, + job_id=job_id, + batch_size=batch_size, + ) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = 
self._query_runner.call_procedure(endpoint="gds.beta.graphSage.mutate", params=params).squeeze() + + return GraphSageMutateResult(**result.to_dict()) + + def stream( + self, + G: Graph, + model_name: str, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + ) -> DataFrame: + config = ConfigConverter.convert_to_gds_config( + model_name=model_name, + relationship_types=relationship_types, + node_labels=node_labels, + username=username, + log_progress=log_progress, + sudo=sudo, + concurrency=concurrency, + job_id=job_id, + batch_size=batch_size, + ) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.beta.graphSage.stream", params=params) + + def write( + self, + G: Graph, + model_name: str, + write_property: str, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + write_concurrency: Optional[Any] = None, + ) -> GraphSageWriteResult: + config = ConfigConverter.convert_to_gds_config( + model_name=model_name, + write_property=write_property, + relationship_types=relationship_types, + node_labels=node_labels, + username=username, + log_progress=log_progress, + sudo=sudo, + concurrency=concurrency, + job_id=job_id, + batch_size=batch_size, + write_concurrency=write_concurrency, + ) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.write", 
params=params).squeeze() + + return GraphSageWriteResult(**result.to_dict()) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py new file mode 100644 index 000000000..22c175d08 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py @@ -0,0 +1,126 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph, QueryRunner +from graphdatascience.procedure_surface.cypher.graphsage_cypher_endpoints import GraphSageCypherEndpoints + + +@pytest.fixture +def sample_graph_with_features(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (a)-[:REL]->(c), + (b)-[:REL]->(c) + """ + + query_runner.run_cypher(create_statement) + + query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {nodeProperties: 'feature'}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + query_runner.run_cypher("CALL gds.graph.drop('g')") + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def graphsage_endpoints(query_runner: QueryRunner) -> Generator[GraphSageCypherEndpoints, None, None]: + yield GraphSageCypherEndpoints(query_runner) + + +def test_graphsage_train(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: + """Test GraphSage train operation.""" + result = graphsage_endpoints.train( + G=sample_graph_with_features, + model_name="testModel", + feature_properties=["feature"], + embedding_dimension=64, + ) + + assert result.train_millis >= 0 + assert result.model_info is not None + assert result.configuration is not None + assert "testModel" in str(result.model_info) + + +def 
test_graphsage_mutate(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: + """Test GraphSage mutate operation.""" + # First train a model + graphsage_endpoints.train( + G=sample_graph_with_features, + model_name="testMutateModel", + feature_properties=["feature"], + embedding_dimension=64, + ) + + # Then use it for mutate + result = graphsage_endpoints.mutate( + G=sample_graph_with_features, + model_name="testMutateModel", + mutate_property="graphsage_embedding", + ) + + assert result.node_count == 3 + assert result.node_properties_written == 3 + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.mutate_millis >= 0 + assert result.configuration is not None + + +def test_graphsage_stream(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: + """Test GraphSage stream operation.""" + # First train a model + graphsage_endpoints.train( + G=sample_graph_with_features, + model_name="testStreamModel", + feature_properties=["feature"], + embedding_dimension=64, + ) + + # Then use it for stream + result = graphsage_endpoints.stream( + G=sample_graph_with_features, + model_name="testStreamModel", + ) + + assert len(result) == 3 # We have 3 nodes + + # Check that we have the expected result structure + # For Cypher endpoints, this returns a DataFrame with string columns + assert "nodeId" in result.columns + assert "embedding" in result.columns + + +def test_graphsage_write(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: + """Test GraphSage write operation.""" + # First train a model + graphsage_endpoints.train( + G=sample_graph_with_features, + model_name="testWriteModel", + feature_properties=["feature"], + embedding_dimension=64, + ) + + # Then use it for write + result = graphsage_endpoints.write( + G=sample_graph_with_features, + model_name="testWriteModel", + write_property="graphsage_embedding", + ) + + assert 
result.node_count == 3 + assert result.node_properties_written == 3 + assert result.pre_processing_millis >= 0 + assert result.compute_millis >= 0 + assert result.write_millis >= 0 + assert result.configuration is not None From 5f0fdf17401d49f649fe722f5380334549c647c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Tue, 19 Aug 2025 16:24:30 +0200 Subject: [PATCH 2/3] Support GraphSage train + Model catalog ops in v2 endpoints --- graphdatascience/model/v2/graphsage_model.py | 39 +++- graphdatascience/model/v2/model.py | 142 ++----------- graphdatascience/model/v2/model_api.py | 51 +++++ graphdatascience/model/v2/model_info.py | 13 +- .../api/graphsage_endpoints.py | 195 +----------------- .../arrow/graphsage_arrow_endpoints.py | 86 ++++++++ .../arrow/model_api_arrow.py | 57 +++++ .../cypher/graphsage_cypher_endpoints.py | 111 +--------- .../cypher/model_api_cypher.py | 49 +++++ .../arrow/test_graphsage_arrow_endpoints.py | 56 +++++ .../arrow/test_model_api_arrow.py | 85 ++++++++ .../cypher/test_graphsage_cypher_endpoints.py | 84 +------- .../cypher/test_model_api_cypher.py | 83 ++++++++ 13 files changed, 533 insertions(+), 518 deletions(-) create mode 100644 graphdatascience/model/v2/model_api.py create mode 100644 graphdatascience/procedure_surface/arrow/graphsage_arrow_endpoints.py create mode 100644 graphdatascience/procedure_surface/arrow/model_api_arrow.py create mode 100644 graphdatascience/procedure_surface/cypher/model_api_cypher.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/cypher/test_model_api_cypher.py diff --git a/graphdatascience/model/v2/graphsage_model.py b/graphdatascience/model/v2/graphsage_model.py index 776468048..022e400fa 100644 --- 
a/graphdatascience/model/v2/graphsage_model.py +++ b/graphdatascience/model/v2/graphsage_model.py @@ -1,11 +1,12 @@ from typing import Any from pandas import Series +from pydantic import BaseModel +from pydantic.alias_generators import to_camel -from ...call_parameters import CallParameters from ...graph.graph_object import Graph from ...graph.graph_type_check import graph_type_check -from ..model import Model +from .model import Model class GraphSageModelV2(Model): @@ -30,13 +31,7 @@ def predict_write(self, G: Graph, **config: Any) -> "Series[Any]": The result of the write operation. """ - endpoint = self._endpoint_prefix() + "write" - config["modelName"] = self.name() - params = CallParameters(graph_name=G.name(), config=config) - - return self._query_runner.call_procedure( # type: ignore - endpoint=endpoint, params=params, logging=True - ).squeeze() + raise ValueError @graph_type_check def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]": @@ -51,4 +46,28 @@ def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]": The memory needed to generate embeddings for the given graph and write the results to the database. 
""" - return self._estimate_predict("write", G.name(), config) + raise ValueError + + +class GraphSageMutateResult(BaseModel, alias_generator=to_camel): + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + mutate_millis: int + configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return self.__dict__[item] + + +class GraphSageWriteResult(BaseModel, alias_generator=to_camel): + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + write_millis: int + configuration: dict[str, Any] + + def __getitem__(self, item: str) -> Any: + return self.__dict__[item] diff --git a/graphdatascience/model/v2/model.py b/graphdatascience/model/v2/model.py index b0e952ddb..18ff2500c 100644 --- a/graphdatascience/model/v2/model.py +++ b/graphdatascience/model/v2/model.py @@ -1,32 +1,17 @@ from __future__ import annotations -from abc import ABC, abstractmethod -from datetime import datetime -from typing import Any +from abc import ABC +from typing import Optional -from pandas import DataFrame, Series - -from graphdatascience.model.v2.model_info import ModelInfo - -from ..call_parameters import CallParameters -from ..graph.graph_object import Graph -from ..graph.graph_type_check import graph_type_check -from ..query_runner.query_runner import QueryRunner -from ..server_version.compatible_with import compatible_with -from ..server_version.server_version import ServerVersion - - -class InfoProvider(ABC): - @abstractmethod - def fetch(self, model_name: str) -> ModelInfo: - """Return the task with progress for the given job_id.""" - pass +from graphdatascience.model.v2.model_api import ModelApi +from graphdatascience.model.v2.model_info import ModelDetails +# Compared to v1 Model offering typed parameters for predict endpoints class Model(ABC): - def __init__(self, name: str, info_provider: InfoProvider): + def __init__(self, name: str, model_api: ModelApi): self._name = name - 
self._info_provider = info_provider + self._model_api = model_api # TODO estimate mode, predict modes on here? # implement Cypher and Arrow info_provider and stuff @@ -41,95 +26,8 @@ def name(self) -> str: """ return self._name - def type(self) -> str: - """ - Get the type of the model. - - Returns: - The type of the model. - - """ - return self._info_provider.fetch(self._name).type - - def train_config(self) -> Series[Any]: - """ - Get the train config of the model. - - Returns: - The train config of the model. - - """ - return self._info_provider.fetch(self._name).train_config - - def graph_schema(self) -> Series[Any]: - """ - Get the graph schema of the model. - - Returns: - The graph schema of the model. - - """ - return self._info_provider.fetch(self._name).graph_schema - - def loaded(self) -> bool: - """ - Check whether the model is loaded in memory. - - Returns: - True if the model is loaded in memory, False otherwise. - - """ - return self._info_provider.fetch(self._name).loaded - - def stored(self) -> bool: - """ - Check whether the model is stored on disk. - - Returns: - True if the model is stored on disk, False otherwise. - - """ - return self._info_provider.fetch(self._name).stored - - def creation_time(self) -> datetime.datetime: - """ - Get the creation time of the model. - - Returns: - The creation time of the model. - - """ - return self._info_provider.fetch(self._name).creation_time - - def shared(self) -> bool: - """ - Check whether the model is shared. - - Returns: - True if the model is shared, False otherwise. - - """ - return self._info_provider.fetch(self._name).shared - - def published(self) -> bool: - """ - Check whether the model is published. - - Returns: - True if the model is published, False otherwise. - - """ - return self._info_provider.fetch(self._name).published - - def model_info(self) -> dict[str, Any]: - """ - Get the model info of the model. - - Returns: - The model info of the model. 
- - """ - return self._info_provider.fetch(self._name).model_info + def details(self) -> ModelDetails: + return self._model_api.get(self._name) def exists(self) -> bool: """ @@ -139,9 +37,9 @@ def exists(self) -> bool: True if the model exists, False otherwise. """ - raise NotImplementedError() + return self._model_api.exists(self._name) - def drop(self, failIfMissing: bool = False) -> Series[Any]: + def drop(self, failIfMissing: bool = False) -> Optional[ModelDetails]: """ Drop the model. @@ -152,22 +50,10 @@ def drop(self, failIfMissing: bool = False) -> Series[Any]: The result of the drop operation. """ - raise NotImplementedError() - - def metrics(self) -> Series[Any]: - """ - Get the metrics of the model. - - Returns: - The metrics of the model. - - """ - model_info = self._info_provider.fetch(self._name).model_info - metrics: Series[Any] = Series(model_info["metrics"]) - return metrics + return self._model_api.drop(self._name, failIfMissing) def __str__(self) -> str: - return f"{self.__class__.__name__}(name={self.name()}, type={self.type()})" + return f"{self.__class__.__name__}(name={self.name()}, type={self.details().type})" def __repr__(self) -> str: - return f"{self.__class__.__name__}({self._info_provider.fetch(self._name).to_dict()})" + return f"{self.__class__.__name__}({self.details().model_dump()})" diff --git a/graphdatascience/model/v2/model_api.py b/graphdatascience/model/v2/model_api.py new file mode 100644 index 000000000..a6c867814 --- /dev/null +++ b/graphdatascience/model/v2/model_api.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from graphdatascience.model.v2.model_info import ModelDetails + + +class ModelApi(ABC): + """ + Abstract base class defining the API for model operations. + This class is intended to be subclassed by specific model implementations. + """ + + @abstractmethod + def exists(self, model: str) -> bool: + """ + Check if a specific model exists. 
+ + Args: + model: The name of the model. + + Returns: + True if the model exists, False otherwise. + """ + pass + + @abstractmethod + def get(self, model: str) -> ModelDetails: + """ + Get the details of a specific model. + + Args: + model: The name of the model. + + Returns: + The details of the model. + """ + pass + + @abstractmethod + def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]: + """ + Drop a specific model. + + Args: + model: The name of the model. + fail_if_missing: If True, an error is thrown if the model does not exist. If False, no error is thrown. + + Returns: + The result of the drop operation. + """ + pass diff --git a/graphdatascience/model/v2/model_info.py b/graphdatascience/model/v2/model_info.py index e3345bc58..356fb13c3 100644 --- a/graphdatascience/model/v2/model_info.py +++ b/graphdatascience/model/v2/model_info.py @@ -1,21 +1,20 @@ import datetime from typing import Any -from pydantic import BaseModel -from abc import ABC, abstractmethod + +from pydantic import BaseModel, Field from pydantic.alias_generators import to_camel -class ModelInfo(BaseModel, alias_generator=to_camel): - name: str - type: str +class ModelDetails(BaseModel, alias_generator=to_camel): + name: str = Field(alias="modelName") + type: str = Field(alias="modelType") train_config: dict[str, Any] graph_schema: dict[str, Any] loaded: bool stored: bool - shared: bool published: bool model_info: dict[str, Any] # TODO better typing in actual model? - creation_time: datetime.datetime # TODO correct type? 
/ conversion needed + creation_time: datetime.datetime def __getitem__(self, item: str) -> Any: return getattr(self, item) diff --git a/graphdatascience/procedure_surface/api/graphsage_endpoints.py b/graphdatascience/procedure_surface/api/graphsage_endpoints.py index 8d7852d68..c2c6adcb2 100644 --- a/graphdatascience/procedure_surface/api/graphsage_endpoints.py +++ b/graphdatascience/procedure_surface/api/graphsage_endpoints.py @@ -3,11 +3,8 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional -from pandas import DataFrame -from pydantic import BaseModel, ConfigDict -from pydantic.alias_generators import to_camel - -from graphdatascience.model.v2.graphsage_model import GraphSageModel +from graphdatascience.model.v2.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.api.base_result import BaseResult from ...graph.graph_object import Graph @@ -47,7 +44,7 @@ def train( batch_size: Optional[int] = None, relationship_weight_property: Optional[str] = None, random_seed: Optional[Any] = None, - ) -> GraphSageModel: + ) -> tuple[GraphSageModelV2, GraphSageTrainResult]: """ Trains a GraphSage model on the given graph. @@ -110,192 +107,12 @@ def train( Returns ------- - GraphSageTrainResult - Training metrics and model information - """ - - @abstractmethod - def mutate( - self, - G: Graph, - model_name: str, - mutate_property: str, - relationship_types: Optional[List[str]] = None, - node_labels: Optional[List[str]] = None, - username: Optional[str] = None, - log_progress: Optional[bool] = None, - sudo: Optional[bool] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, - batch_size: Optional[int] = None, - ) -> GraphSageMutateResult: - """ - Executes the GraphSage algorithm using a trained model and writes the results back to the graph as a node property. 
- - Parameters - ---------- - G : Graph - The graph to run the algorithm on - model_name : str - Name of the trained GraphSage model to use - mutate_property : str - The name of the node property to store the embeddings - relationship_types : Optional[List[str]], default=None - The relationship types used to select relationships for this algorithm run - node_labels : Optional[List[str]], default=None - The node labels used to select nodes for this algorithm run - username : Optional[str] = None - The username to attribute the procedure run to - log_progress : Optional[bool], default=None - Whether to log progress - sudo : Optional[bool], default=None - Override memory estimation limits - concurrency : Optional[Any], default=None - The number of concurrent threads - job_id : Optional[Any], default=None - An identifier for the job - batch_size : Optional[int], default=None - Batch size for inference - - Returns - ------- - GraphSageMutateResult - Algorithm metrics and statistics - """ - - @abstractmethod - def stream( - self, - G: Graph, - model_name: str, - relationship_types: Optional[List[str]] = None, - node_labels: Optional[List[str]] = None, - username: Optional[str] = None, - log_progress: Optional[bool] = None, - sudo: Optional[bool] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, - batch_size: Optional[int] = None, - ) -> DataFrame: - """ - Executes the GraphSage algorithm using a trained model and returns the results as a stream. 
- - Parameters - ---------- - G : Graph - The graph to run the algorithm on - model_name : str - Name of the trained GraphSage model to use - relationship_types : Optional[List[str]], default=None - The relationship types used to select relationships for this algorithm run - node_labels : Optional[List[str]], default=None - The node labels used to select nodes for this algorithm run - username : Optional[str] = None - The username to attribute the procedure run to - log_progress : Optional[bool], default=None - Whether to log progress - sudo : Optional[bool], default=None - Override memory estimation limits - concurrency : Optional[Any], default=None - The number of concurrent threads - job_id : Optional[Any], default=None - An identifier for the job - batch_size : Optional[int], default=None - Batch size for inference - - Returns - ------- - DataFrame - Embeddings as a stream with columns nodeId and embedding + GraphSageModelV2 + Trained model """ - @abstractmethod - def write( - self, - G: Graph, - model_name: str, - write_property: str, - relationship_types: Optional[List[str]] = None, - node_labels: Optional[List[str]] = None, - username: Optional[str] = None, - log_progress: Optional[bool] = None, - sudo: Optional[bool] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, - batch_size: Optional[int] = None, - write_concurrency: Optional[Any] = None, - ) -> GraphSageWriteResult: - """ - Executes the GraphSage algorithm using a trained model and writes the results back to the database. 
- - Parameters - ---------- - G : Graph - The graph to run the algorithm on - model_name : str - Name of the trained GraphSage model to use - write_property : str - The name of the node property to write the embeddings to - relationship_types : Optional[List[str]], default=None - The relationship types used to select relationships for this algorithm run - node_labels : Optional[List[str]], default=None - The node labels used to select nodes for this algorithm run - username : Optional[str] = None - The username to attribute the procedure run to - log_progress : Optional[bool], default=None - Whether to log progress - sudo : Optional[bool], default=None - Override memory estimation limits - concurrency : Optional[Any], default=None - The number of concurrent threads - job_id : Optional[Any], default=None - An identifier for the job - batch_size : Optional[int], default=None - Batch size for inference - write_concurrency : Optional[Any], default=None - The number of concurrent threads used for writing result - - Returns - ------- - GraphSageWriteResult - Algorithm metrics and statistics - """ - - -class GraphSageTrainResult(BaseModel): - model_config = ConfigDict(alias_generator=to_camel) +class GraphSageTrainResult(BaseResult): model_info: dict[str, Any] configuration: dict[str, Any] train_millis: int - - def __getitem__(self, item: str) -> Any: - return self.__dict__[item] - - -class GraphSageMutateResult(BaseModel): - model_config = ConfigDict(alias_generator=to_camel) - - node_count: int - node_properties_written: int - pre_processing_millis: int - compute_millis: int - mutate_millis: int - configuration: dict[str, Any] - - def __getitem__(self, item: str) -> Any: - return self.__dict__[item] - - -class GraphSageWriteResult(BaseModel): - model_config = ConfigDict(alias_generator=to_camel) - - node_count: int - node_properties_written: int - pre_processing_millis: int - compute_millis: int - write_millis: int - configuration: dict[str, Any] - - def 
__getitem__(self, item: str) -> Any: - return self.__dict__[item] diff --git a/graphdatascience/procedure_surface/arrow/graphsage_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graphsage_arrow_endpoints.py new file mode 100644 index 000000000..478579c64 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/graphsage_arrow_endpoints.py @@ -0,0 +1,86 @@ +from typing import Any, List, Optional + +from graphdatascience.model.v2.graphsage_model import GraphSageModelV2 + +from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from ...graph.graph_object import Graph +from ..api.graphsage_endpoints import ( + GraphSageEndpoints, + GraphSageTrainResult, +) +from .model_api_arrow import ModelApiArrow +from .node_property_endpoints import NodePropertyEndpoints + + +class GraphSageArrowEndpoints(GraphSageEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + self._node_property_endpoints = NodePropertyEndpoints(arrow_client) + self._model_api = ModelApiArrow(arrow_client) + + def train( + self, + G: Graph, + model_name: str, + feature_properties: List[str], + activation_function: Optional[Any] = None, + negative_sample_weight: Optional[int] = None, + embedding_dimension: Optional[int] = None, + tolerance: Optional[float] = None, + learning_rate: Optional[float] = None, + max_iterations: Optional[int] = None, + sample_sizes: Optional[List[int]] = None, + aggregator: Optional[Any] = None, + penalty_l2: Optional[float] = None, + search_depth: Optional[int] = None, + epochs: Optional[int] = None, + projected_feature_dimension: Optional[int] = None, + batch_sampling_ratio: Optional[float] = None, + store_model_to_disk: Optional[bool] = None, + relationship_types: Optional[List[str]] = None, + node_labels: Optional[List[str]] = None, + username: Optional[str] = None, + log_progress: Optional[bool] = None, + sudo: Optional[bool] = None, + concurrency: Optional[Any] = None, + 
job_id: Optional[Any] = None, + batch_size: Optional[int] = None, + relationship_weight_property: Optional[str] = None, + random_seed: Optional[Any] = None, + ) -> tuple[GraphSageModelV2, GraphSageTrainResult]: + config = self._node_property_endpoints.create_base_config( + G, + model_name=model_name, + feature_properties=feature_properties, + activation_function=activation_function, + negative_sample_weight=negative_sample_weight, + embedding_dimension=embedding_dimension, + tolerance=tolerance, + learning_rate=learning_rate, + max_iterations=max_iterations, + sample_sizes=sample_sizes, + aggregator=aggregator, + penalty_l2=penalty_l2, + search_depth=search_depth, + epochs=epochs, + projected_feature_dimension=projected_feature_dimension, + batch_sampling_ratio=batch_sampling_ratio, + store_model_to_disk=store_model_to_disk, + relationship_types=relationship_types, + node_labels=node_labels, + username=username, + log_progress=log_progress, + sudo=sudo, + concurrency=concurrency, + job_id=job_id, + batch_size=batch_size, + relationship_weight_property=relationship_weight_property, + random_seed=random_seed, + ) + + result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.graphSage.train", G, config) + + model = GraphSageModelV2(model_name, self._model_api) + train_result = GraphSageTrainResult(**result) + + return model, train_result diff --git a/graphdatascience/procedure_surface/arrow/model_api_arrow.py b/graphdatascience/procedure_surface/arrow/model_api_arrow.py new file mode 100644 index 000000000..b550bd466 --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/model_api_arrow.py @@ -0,0 +1,57 @@ +import datetime +import json +import re +from typing import Any, Optional + +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize +from graphdatascience.model.v2.model_api import ModelApi +from 
graphdatascience.model.v2.model_info import ModelDetails + + +class ModelApiArrow(ModelApi): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client: AuthenticatedArrowClient = arrow_client + super().__init__() + + def exists(self, model: str) -> bool: + raw_result = self._arrow_client.do_action_with_retry( + "v2/model.exists", payload=json.dumps({"modelName": model}).encode("utf-8") + ) + result = deserialize(raw_result) + + if not result: + return False + + return True + + def get(self, model: str) -> ModelDetails: + raw_result = self._arrow_client.do_action_with_retry( + "v2/model.get", payload=json.dumps({"modelName": model}).encode("utf-8") + ) + result = deserialize(raw_result) + + if not result: + raise ValueError(f"There is no '{model}' in the model catalog") + + return self._parse_model_details(result[0]) + + def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]: + raw_result = self._arrow_client.do_action_with_retry( + "v2/model.drop", payload=json.dumps({"modelName": model, "failIfMissing": fail_if_missing}).encode("utf-8") + ) + result = deserialize(raw_result) + + if not result: + return None + + return self._parse_model_details(result[0]) + + def _parse_model_details(self, input: dict[str, Any]) -> ModelDetails: + creation_time = input.pop("creationTime") + if creation_time and isinstance(creation_time, str): + # Trim microseconds from 9 digits to 6 digits + trimmed = re.sub(r"\.(\d{6})\d+", r".\1", creation_time) + input["creationTime"] = datetime.datetime.strptime(trimmed, "%Y-%m-%dT%H:%M:%S.%fZ[%Z]") + + return ModelDetails(**input) diff --git a/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py index 5ba3193db..95ac7de03 100644 --- a/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py +++ b/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py @@ -1,15 +1,14 @@ from 
typing import Any, List, Optional -from pandas import DataFrame +from graphdatascience.model.v2.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.cypher.model_api_cypher import ModelApiCypher from ...call_parameters import CallParameters from ...graph.graph_object import Graph from ...query_runner.query_runner import QueryRunner from ..api.graphsage_endpoints import ( GraphSageEndpoints, - GraphSageMutateResult, GraphSageTrainResult, - GraphSageWriteResult, ) from ..utils.config_converter import ConfigConverter @@ -47,7 +46,7 @@ def train( batch_size: Optional[int] = None, relationship_weight_property: Optional[str] = None, random_seed: Optional[Any] = None, - ) -> GraphSageTrainResult: + ) -> tuple[GraphSageModelV2, GraphSageTrainResult]: config = ConfigConverter.convert_to_gds_config( model_name=model_name, feature_properties=feature_properties, @@ -80,106 +79,8 @@ def train( params = CallParameters(graph_name=G.name(), config=config) params.ensure_job_id_in_config() - result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.train", params=params).squeeze() + result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.train", params=params).iloc[0] - return GraphSageTrainResult(**result.to_dict()) - - def mutate( - self, - G: Graph, - model_name: str, - mutate_property: str, - relationship_types: Optional[List[str]] = None, - node_labels: Optional[List[str]] = None, - username: Optional[str] = None, - log_progress: Optional[bool] = None, - sudo: Optional[bool] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, - batch_size: Optional[int] = None, - ) -> GraphSageMutateResult: - config = ConfigConverter.convert_to_gds_config( - model_name=model_name, - mutate_property=mutate_property, - relationship_types=relationship_types, - node_labels=node_labels, - username=username, - log_progress=log_progress, - sudo=sudo, - concurrency=concurrency, - job_id=job_id, - batch_size=batch_size, - 
) - - params = CallParameters(graph_name=G.name(), config=config) - params.ensure_job_id_in_config() - - result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.mutate", params=params).squeeze() - - return GraphSageMutateResult(**result.to_dict()) - - def stream( - self, - G: Graph, - model_name: str, - relationship_types: Optional[List[str]] = None, - node_labels: Optional[List[str]] = None, - username: Optional[str] = None, - log_progress: Optional[bool] = None, - sudo: Optional[bool] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, - batch_size: Optional[int] = None, - ) -> DataFrame: - config = ConfigConverter.convert_to_gds_config( - model_name=model_name, - relationship_types=relationship_types, - node_labels=node_labels, - username=username, - log_progress=log_progress, - sudo=sudo, - concurrency=concurrency, - job_id=job_id, - batch_size=batch_size, + return GraphSageModelV2(name=model_name, model_api=ModelApiCypher(self._query_runner)), GraphSageTrainResult( + **result.to_dict() ) - - params = CallParameters(graph_name=G.name(), config=config) - params.ensure_job_id_in_config() - - return self._query_runner.call_procedure(endpoint="gds.beta.graphSage.stream", params=params) - - def write( - self, - G: Graph, - model_name: str, - write_property: str, - relationship_types: Optional[List[str]] = None, - node_labels: Optional[List[str]] = None, - username: Optional[str] = None, - log_progress: Optional[bool] = None, - sudo: Optional[bool] = None, - concurrency: Optional[Any] = None, - job_id: Optional[Any] = None, - batch_size: Optional[int] = None, - write_concurrency: Optional[Any] = None, - ) -> GraphSageWriteResult: - config = ConfigConverter.convert_to_gds_config( - model_name=model_name, - write_property=write_property, - relationship_types=relationship_types, - node_labels=node_labels, - username=username, - log_progress=log_progress, - sudo=sudo, - concurrency=concurrency, - job_id=job_id, - 
batch_size=batch_size, - write_concurrency=write_concurrency, - ) - - params = CallParameters(graph_name=G.name(), config=config) - params.ensure_job_id_in_config() - - result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.write", params=params).squeeze() - - return GraphSageWriteResult(**result.to_dict()) diff --git a/graphdatascience/procedure_surface/cypher/model_api_cypher.py b/graphdatascience/procedure_surface/cypher/model_api_cypher.py new file mode 100644 index 000000000..dbe33e680 --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/model_api_cypher.py @@ -0,0 +1,49 @@ +from typing import Any, Optional + +import neo4j + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.model.v2.model_api import ModelApi +from graphdatascience.model.v2.model_info import ModelDetails +from graphdatascience.query_runner.query_runner import QueryRunner + + +class ModelApiCypher(ModelApi): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + super().__init__() + + def exists(self, model: str) -> bool: + params = CallParameters(name=model) + + result = self._query_runner.call_procedure("gds.model.exists", params=params, custom_error=False) + if result.empty: + return False + + return result.iloc[0]["exists"] # type: ignore + + def get(self, model: str) -> ModelDetails: + params = CallParameters(name=model) + + result = self._query_runner.call_procedure("gds.model.list", params=params, custom_error=False) + if result.empty: + raise ValueError(f"There is no '{model}' in the model catalog") + + return self._to_model_details(result.iloc[0].to_dict()) + + def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]: + params = CallParameters(model_name=model, fail_if_missing=fail_if_missing) + + result = self._query_runner.call_procedure("gds.model.drop", params=params, custom_error=False) + + if result.empty: + return None + + return 
self._to_model_details(result.iloc[0].to_dict()) + + def _to_model_details(self, result: dict[str, Any]) -> ModelDetails: + creation_time = result.get("creationTime", None) + if creation_time and isinstance(creation_time, neo4j.time.DateTime): + result["creationTime"] = creation_time.to_native() + + return ModelDetails(**result) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py new file mode 100644 index 000000000..4f955a224 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py @@ -0,0 +1,56 @@ +import json +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.arrow.graphsage_arrow_endpoints import GraphSageArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + CREATE + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (d: Node {feature: 4.0}), + (a)-[:REL]->(b), + (b)-[:REL]->(c), + (c)-[:REL]->(d), + (d)-[:REL]->(a) + """ + + yield create_graph(arrow_client, "g", gdl) + arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8")) + + +@pytest.fixture +def graphsage_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[GraphSageArrowEndpoints, None, None]: + yield GraphSageArrowEndpoints(arrow_client) + + +def test_graphsage_train(graphsage_endpoints: GraphSageArrowEndpoints, sample_graph: Graph) -> None: + """Test GraphSage train operation.""" + model, result = graphsage_endpoints.train( + G=sample_graph, + 
model_name="testGraphSageModel", + feature_properties=["feature"], + embedding_dimension=1, + epochs=1, # Use minimal epochs for faster testing + max_iterations=1, # Use minimal iterations for faster testing + ) + + # Check the result + assert result.train_millis >= 0 + assert result.configuration is not None + assert result.model_info is not None + + # Check the model + assert model.name() == "testGraphSageModel" + assert model.exists() + + # Clean up the model + model.drop() diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py new file mode 100644 index 000000000..ef2e0c9fe --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py @@ -0,0 +1,85 @@ +import json +from typing import Generator + +import pytest +from pyarrow.flight import FlightServerError + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.arrow.graphsage_arrow_endpoints import GraphSageArrowEndpoints +from graphdatascience.procedure_surface.arrow.model_api_arrow import ModelApiArrow +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + (a: Node {age: 1}) + (b: Node {age: 2}) + (c: Node {age: 3}) + (d: Node {age: 4}) + (e: Node {age: 5}) + (f: Node {age: 6}) + (a)-[:REL]->(b) + (b)-[:REL]->(c) + (c)-[:REL]->(a) + (d)-[:REL]->(e) + (e)-[:REL]->(f) + (f)-[:REL]->(d) + """ + + yield create_graph(arrow_client, "model_api_g", gdl) + arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "model_api_g"}).encode("utf-8")) + + +@pytest.fixture +def gs_model(arrow_client: AuthenticatedArrowClient, sample_graph: Graph) -> 
Generator[str, None, None]: + model, _ = GraphSageArrowEndpoints(arrow_client).train( + G=sample_graph, + model_name="gs-model", + feature_properties=["age"], + embedding_dimension=1, + sample_sizes=[1], + max_iterations=1, + search_depth=1, + ) + + yield model.name() + + arrow_client.do_action_with_retry("v2/model.drop", json.dumps({"modelName": model.name()}).encode("utf-8")) + + +@pytest.fixture +def model_api(arrow_client: AuthenticatedArrowClient) -> Generator[ModelApiArrow, None, None]: + yield ModelApiArrow(arrow_client) + + +def test_model_get(gs_model: str, model_api: ModelApiArrow) -> None: + model = model_api.get(gs_model) + + assert model.name == gs_model + assert model.type == "graphSage" + + with pytest.raises(ValueError, match="There is no 'nonexistent-model' in the model catalog"): + model_api.get("nonexistent-model") + + +def test_model_exists(gs_model: str, model_api: ModelApiArrow) -> None: + assert model_api.exists(gs_model) + assert not model_api.exists("nonexistent-model") + + +def test_model_delete(gs_model: str, model_api: ModelApiArrow) -> None: + model_details = model_api.drop(gs_model, fail_if_missing=False) + + assert model_details is not None + assert model_details.name == gs_model + + # Check that the model no longer exists + assert not model_api.exists(gs_model) + + # Attempt to drop a non-existing model + assert model_api.drop("nonexistent-model", fail_if_missing=False) is None + + with pytest.raises(FlightServerError, match="Model with name `nonexistent-model` does not exist"): + model_api.drop("nonexistent-model", fail_if_missing=True) diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py index 22c175d08..f144c70a5 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py +++ 
b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py @@ -39,88 +39,14 @@ def graphsage_endpoints(query_runner: QueryRunner) -> Generator[GraphSageCypherE def test_graphsage_train(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: """Test GraphSage train operation.""" - result = graphsage_endpoints.train( + model, train_result = graphsage_endpoints.train( G=sample_graph_with_features, model_name="testModel", feature_properties=["feature"], embedding_dimension=64, ) - assert result.train_millis >= 0 - assert result.model_info is not None - assert result.configuration is not None - assert "testModel" in str(result.model_info) - - -def test_graphsage_mutate(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: - """Test GraphSage mutate operation.""" - # First train a model - graphsage_endpoints.train( - G=sample_graph_with_features, - model_name="testMutateModel", - feature_properties=["feature"], - embedding_dimension=64, - ) - - # Then use it for mutate - result = graphsage_endpoints.mutate( - G=sample_graph_with_features, - model_name="testMutateModel", - mutate_property="graphsage_embedding", - ) - - assert result.node_count == 3 - assert result.node_properties_written == 3 - assert result.pre_processing_millis >= 0 - assert result.compute_millis >= 0 - assert result.mutate_millis >= 0 - assert result.configuration is not None - - -def test_graphsage_stream(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: - """Test GraphSage stream operation.""" - # First train a model - graphsage_endpoints.train( - G=sample_graph_with_features, - model_name="testStreamModel", - feature_properties=["feature"], - embedding_dimension=64, - ) - - # Then use it for stream - result = graphsage_endpoints.stream( - G=sample_graph_with_features, - model_name="testStreamModel", - ) - - assert len(result) == 3 # We have 3 nodes 
- - # Check that we have the expected result structure - # For Cypher endpoints, this returns a DataFrame with string columns - assert "nodeId" in result.columns - assert "embedding" in result.columns - - -def test_graphsage_write(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: - """Test GraphSage write operation.""" - # First train a model - graphsage_endpoints.train( - G=sample_graph_with_features, - model_name="testWriteModel", - feature_properties=["feature"], - embedding_dimension=64, - ) - - # Then use it for write - result = graphsage_endpoints.write( - G=sample_graph_with_features, - model_name="testWriteModel", - write_property="graphsage_embedding", - ) - - assert result.node_count == 3 - assert result.node_properties_written == 3 - assert result.pre_processing_millis >= 0 - assert result.compute_millis >= 0 - assert result.write_millis >= 0 - assert result.configuration is not None + assert train_result.train_millis >= 0 + assert train_result.model_info is not None + assert train_result.configuration is not None + assert model.name() == "testModel" diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_model_api_cypher.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_model_api_cypher.py new file mode 100644 index 000000000..428b0d24f --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_model_api_cypher.py @@ -0,0 +1,83 @@ +from typing import Generator + +import pytest +from neo4j.exceptions import Neo4jError + +from graphdatascience import Graph, QueryRunner +from graphdatascience.procedure_surface.cypher.model_api_cypher import ModelApiCypher + + +@pytest.fixture +def sample_graph(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {age: 1}), + (b: Node {age: 2}), + (c: Node {age: 3}), + (a)-[:REL]->(c), + (b)-[:REL]->(c) + """ + + query_runner.run_cypher(create_statement) + + 
query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {sourceNodeProperties: properties(n), targetNodeProperties: properties(m)}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + query_runner.run_cypher("CALL gds.graph.drop('g')") + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def gs_model(query_runner: QueryRunner, sample_graph: Graph) -> Generator[str, None, None]: + train_result = query_runner.run_cypher( + "CALL gds.beta.graphSage.train($graph, {modelName: 'gs-model', featureProperties:['age'], embeddingDimension: 1, sampleSizes: [1], maxIterations: 1, searchDepth: 1})", + {"graph": sample_graph.name()}, + ) + + model_name = train_result.iloc[0]["modelInfo"]["modelName"] + + yield model_name # type: ignore + + query_runner.run_cypher("CALL gds.model.drop($name, false)", {"name": model_name}) + + +@pytest.fixture +def model_api(query_runner: QueryRunner) -> Generator[ModelApiCypher, None, None]: + yield ModelApiCypher(query_runner) + + +def test_model_get(gs_model: str, model_api: ModelApiCypher) -> None: + model = model_api.get(gs_model) + + assert model.name == gs_model + assert model.type == "graphSage" + + with pytest.raises(ValueError, match="There is no 'nonexistent-model' in the model catalog"): + model_api.get("nonexistent-model") + + +def test_model_exists(gs_model: str, model_api: ModelApiCypher) -> None: + assert model_api.exists(gs_model) + assert not model_api.exists("nonexistent-model") + + +def test_model_delete(gs_model: str, model_api: ModelApiCypher) -> None: + model_details = model_api.drop(gs_model, fail_if_missing=False) + + assert model_details is not None + assert model_details.name == gs_model + + # Check that the model no longer exists + assert not model_api.exists(gs_model) + + # Attempt to drop a non-existing model + assert model_api.drop("nonexistent-model", fail_if_missing=False) is None + + with pytest.raises(Neo4jError, match="Model with 
name `nonexistent-model` does not exist"): + model_api.drop("nonexistent-model", fail_if_missing=True) From a99546f6325ac0c2446770278deba8a24692af46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florentin=20D=C3=B6rre?= Date: Tue, 19 Aug 2025 17:44:51 +0200 Subject: [PATCH 3/3] Implement gs predict endpoints --- graphdatascience/model/v2/graphsage_model.py | 73 ------ graphdatascience/model/v2/model.py | 2 +- graphdatascience/model/v2/model_api.py | 2 +- .../v2/{model_info.py => model_details.py} | 0 .../api/graphsage_predict_endpoints.py | 51 +++++ ...points.py => graphsage_train_endpoints.py} | 4 +- .../api/model/graphsage_model.py | 214 ++++++++++++++++++ .../graphsage_predict_arrow_endpoints.py | 63 ++++++ ....py => graphsage_train_arrow_endpoints.py} | 13 +- .../arrow/model_api_arrow.py | 2 +- .../graphsage_predict_cypher_endpoints.py | 58 +++++ ...py => graphsage_train_cypher_endpoints.py} | 17 +- .../cypher/model_api_cypher.py | 2 +- .../arrow/test_graphsage_arrow_endpoints.py | 8 +- .../test_graphsage_predict_arrow_endpoints.py | 74 ++++++ .../arrow/test_model_api_arrow.py | 4 +- .../cypher/test_graphsage_cypher_endpoints.py | 14 +- ...test_graphsage_predict_cypher_endpoints.py | 82 +++++++ 18 files changed, 580 insertions(+), 103 deletions(-) delete mode 100644 graphdatascience/model/v2/graphsage_model.py rename graphdatascience/model/v2/{model_info.py => model_details.py} (100%) create mode 100644 graphdatascience/procedure_surface/api/graphsage_predict_endpoints.py rename graphdatascience/procedure_surface/api/{graphsage_endpoints.py => graphsage_train_endpoints.py} (97%) create mode 100644 graphdatascience/procedure_surface/api/model/graphsage_model.py create mode 100644 graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py rename graphdatascience/procedure_surface/arrow/{graphsage_arrow_endpoints.py => graphsage_train_arrow_endpoints.py} (86%) create mode 100644 
graphdatascience/procedure_surface/cypher/graphsage_predict_cypher_endpoints.py rename graphdatascience/procedure_surface/cypher/{graphsage_cypher_endpoints.py => graphsage_train_cypher_endpoints.py} (83%) create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_predict_arrow_endpoints.py create mode 100644 graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_predict_cypher_endpoints.py diff --git a/graphdatascience/model/v2/graphsage_model.py b/graphdatascience/model/v2/graphsage_model.py deleted file mode 100644 index 022e400fa..000000000 --- a/graphdatascience/model/v2/graphsage_model.py +++ /dev/null @@ -1,73 +0,0 @@ -from typing import Any - -from pandas import Series -from pydantic import BaseModel -from pydantic.alias_generators import to_camel - -from ...graph.graph_object import Graph -from ...graph.graph_type_check import graph_type_check -from .model import Model - - -class GraphSageModelV2(Model): - """ - Represents a GraphSAGE model in the model catalog. - Construct this using :func:`gds.graphSage.train()`. - """ - - def _endpoint_prefix(self) -> str: - return "gds.beta.graphSage." - - @graph_type_check - def predict_write(self, G: Graph, **config: Any) -> "Series[Any]": - """ - Generate embeddings for the given graph and write the results to the database. - - Args: - G: The graph to generate embeddings for. - **config: The config for the prediction. - - Returns: - The result of the write operation. - - """ - raise ValueError - - @graph_type_check - def predict_write_estimate(self, G: Graph, **config: Any) -> "Series[Any]": - """ - Estimate the memory needed to generate embeddings for the given graph and write the results to the database. - - Args: - G: The graph to generate embeddings for. - **config: The config for the prediction. - - Returns: - The memory needed to generate embeddings for the given graph and write the results to the database. 
- - """ - raise ValueError - - -class GraphSageMutateResult(BaseModel, alias_generator=to_camel): - node_count: int - node_properties_written: int - pre_processing_millis: int - compute_millis: int - mutate_millis: int - configuration: dict[str, Any] - - def __getitem__(self, item: str) -> Any: - return self.__dict__[item] - - -class GraphSageWriteResult(BaseModel, alias_generator=to_camel): - node_count: int - node_properties_written: int - pre_processing_millis: int - compute_millis: int - write_millis: int - configuration: dict[str, Any] - - def __getitem__(self, item: str) -> Any: - return self.__dict__[item] diff --git a/graphdatascience/model/v2/model.py b/graphdatascience/model/v2/model.py index 18ff2500c..7f34fbeab 100644 --- a/graphdatascience/model/v2/model.py +++ b/graphdatascience/model/v2/model.py @@ -4,7 +4,7 @@ from typing import Optional from graphdatascience.model.v2.model_api import ModelApi -from graphdatascience.model.v2.model_info import ModelDetails +from graphdatascience.model.v2.model_details import ModelDetails # Compared to v1 Model offering typed parameters for predict endpoints diff --git a/graphdatascience/model/v2/model_api.py b/graphdatascience/model/v2/model_api.py index a6c867814..d0bbc36c4 100644 --- a/graphdatascience/model/v2/model_api.py +++ b/graphdatascience/model/v2/model_api.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional -from graphdatascience.model.v2.model_info import ModelDetails +from graphdatascience.model.v2.model_details import ModelDetails class ModelApi(ABC): diff --git a/graphdatascience/model/v2/model_info.py b/graphdatascience/model/v2/model_details.py similarity index 100% rename from graphdatascience/model/v2/model_info.py rename to graphdatascience/model/v2/model_details.py diff --git a/graphdatascience/procedure_surface/api/graphsage_predict_endpoints.py b/graphdatascience/procedure_surface/api/graphsage_predict_endpoints.py new file mode 100644 index 000000000..87b188819 
--- /dev/null +++ b/graphdatascience/procedure_surface/api/graphsage_predict_endpoints.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any + +from pandas import DataFrame + +from graphdatascience.procedure_surface.api.base_result import BaseResult +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult + +from ...graph.graph_object import Graph + + +class GraphSagePredictEndpoints(ABC): + """ + Abstract base class defining the prediction API (stream, write, mutate, estimate) for trained GraphSage models. + """ + + @abstractmethod + def stream(self, G: Graph, **config: Any) -> DataFrame: + pass + + @abstractmethod + def write(self, G: Graph, **config: Any) -> GraphSageWriteResult: + pass + + @abstractmethod + def mutate(self, G: Graph, **config: Any) -> GraphSageMutateResult: + pass + + @abstractmethod + def estimate(self, G: Graph, **config: Any) -> EstimationResult: + pass + + +class GraphSageMutateResult(BaseResult): + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + mutate_millis: int + configuration: dict[str, Any] + + +class GraphSageWriteResult(BaseResult): + node_count: int + node_properties_written: int + pre_processing_millis: int + compute_millis: int + write_millis: int + configuration: dict[str, Any] diff --git a/graphdatascience/procedure_surface/api/graphsage_endpoints.py b/graphdatascience/procedure_surface/api/graphsage_train_endpoints.py similarity index 97% rename from graphdatascience/procedure_surface/api/graphsage_endpoints.py rename to graphdatascience/procedure_surface/api/graphsage_train_endpoints.py index c2c6adcb2..fd63a2718 100644 --- a/graphdatascience/procedure_surface/api/graphsage_endpoints.py +++ b/graphdatascience/procedure_surface/api/graphsage_train_endpoints.py @@ -3,13 +3,13 @@ from abc import ABC, abstractmethod from typing import Any, List, Optional -from graphdatascience.model.v2.graphsage_model import GraphSageModelV2 from 
graphdatascience.procedure_surface.api.base_result import BaseResult +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 from ...graph.graph_object import Graph -class GraphSageEndpoints(ABC): +class GraphSageTrainEndpoints(ABC): """ Abstract base class defining the API for the GraphSage algorithm. """ diff --git a/graphdatascience/procedure_surface/api/model/graphsage_model.py b/graphdatascience/procedure_surface/api/model/graphsage_model.py new file mode 100644 index 000000000..4b1b532f4 --- /dev/null +++ b/graphdatascience/procedure_surface/api/model/graphsage_model.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import Optional + +from pandas import DataFrame + +from graphdatascience.model.v2.model_api import ModelApi +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.graphsage_predict_endpoints import ( + GraphSageMutateResult, + GraphSagePredictEndpoints, + GraphSageWriteResult, +) + +from ....graph.graph_object import Graph +from ....graph.graph_type_check import graph_type_check +from ....model.v2.model import Model + + +class GraphSageModelV2(Model): + """ + Represents a GraphSAGE model in the model catalog. + Construct this using :func:`gds.graphSage.train()`. 
+ """ + + def __init__(self, name: str, model_api: ModelApi, predict_endpoints: GraphSagePredictEndpoints) -> None: + super().__init__(name, model_api) + self._predict_endpoints = predict_endpoints + + @graph_type_check + def predict_write( + self, + G: Graph, + write_property: str, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + write_concurrency: Optional[int] = None, + write_to_result_store: Optional[bool] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> GraphSageWriteResult: + """ + Generate embeddings for the given graph and write the results to the database. + + Args: + G: The graph to generate embeddings for. + write_property: The property to write the embeddings to. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + write_concurrency: The concurrency for writing. + write_to_result_store: Whether to write to the result store. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The result of the write operation. 
+ + """ + return self._predict_endpoints.write( + G, + modelName=self.name(), + writeProperty=write_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + writeConcurrency=write_concurrency, + writeToResultStore=write_to_result_store, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) + + def predict_stream( + self, + G: Graph, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> DataFrame: + """ + Generate embeddings for the given graph and stream the results. + + Args: + G: The graph to generate embeddings for. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The streaming results as a DataFrame. 
+ + """ + return self._predict_endpoints.stream( + G, + modelName=self.name(), + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) + + def predict_mutate( + self, + G: Graph, + mutate_property: str, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> GraphSageMutateResult: + """ + Generate embeddings for the given graph and mutate the graph with the results. + + Args: + G: The graph to generate embeddings for. + mutate_property: The property to mutate with the embeddings. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The result of the mutate operation. 
+ + """ + return self._predict_endpoints.mutate( + G, + modelName=self.name(), + mutateProperty=mutate_property, + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) + + @graph_type_check + def predict_estimate( + self, + G: Graph, + relationship_types: Optional[list[str]] = None, + node_labels: Optional[list[str]] = None, + batch_size: Optional[int] = None, + concurrency: Optional[int] = None, + log_progress: Optional[bool] = None, + username: Optional[str] = None, + sudo: Optional[bool] = None, + job_id: Optional[str] = None, + ) -> EstimationResult: + """ + Estimate the memory needed to generate embeddings for the given graph. + + Args: + G: The graph to generate embeddings for. + relationship_types: The relationship types to consider. + node_labels: The node labels to consider. + batch_size: The batch size for prediction. + concurrency: The concurrency for computation. + log_progress: Whether to log progress. + username: The username for the operation. + sudo: Whether to use sudo privileges. + job_id: The job ID for the operation. + + Returns: + The estimated memory needed to generate embeddings for the given graph. 
+ + """ + return self._predict_endpoints.estimate( + G, + modelName=self.name(), + relationshipTypes=relationship_types, + nodeLabels=node_labels, + batchSize=batch_size, + concurrency=concurrency, + logProgress=log_progress, + username=username, + sudo=sudo, + jobId=job_id, + ) diff --git a/graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py new file mode 100644 index 000000000..13fa1127e --- /dev/null +++ b/graphdatascience/procedure_surface/arrow/graphsage_predict_arrow_endpoints.py @@ -0,0 +1,63 @@ +from typing import Any + +from pandas import DataFrame + +from graphdatascience.graph.graph_object import Graph +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.graphsage_predict_endpoints import ( + GraphSageMutateResult, + GraphSagePredictEndpoints, + GraphSageWriteResult, +) + +from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from .model_api_arrow import ModelApiArrow +from .node_property_endpoints import NodePropertyEndpoints + + +class GraphSagePredictArrowEndpoints(GraphSagePredictEndpoints): + def __init__(self, arrow_client: AuthenticatedArrowClient): + self._arrow_client = arrow_client + self._node_property_endpoints = NodePropertyEndpoints(arrow_client) + self._model_api = ModelApiArrow(arrow_client) + + def stream(self, G: Graph, **config: Any) -> DataFrame: + config = self._node_property_endpoints.create_base_config(G, **config) + + return self._node_property_endpoints.run_job_and_stream("v2/embeddings.graphSage", G, config) + + def write(self, G: Graph, **config: Any) -> GraphSageWriteResult: + config = self._node_property_endpoints.create_base_config(G, **config) + + raw_result = self._node_property_endpoints.run_job_and_write( + "v2/embeddings.graphSage", + G, + config, + config.get("writeConcurrency"), + config.get("concurrency"), + 
) + + return GraphSageWriteResult(**raw_result) + + def mutate(self, G: Graph, **config: Any) -> GraphSageMutateResult: + config = self._node_property_endpoints.create_base_config(G, **config) + + mutateProperty = config.pop("mutateProperty", "") + + raw_result = self._node_property_endpoints.run_job_and_mutate( + "v2/embeddings.graphSage", + G, + config, + mutateProperty, + ) + + return GraphSageMutateResult(**raw_result) + + def estimate(self, G: Graph, **config: Any) -> EstimationResult: + config = self._node_property_endpoints.create_estimate_config(**config) + + return self._node_property_endpoints.estimate( + "v2/embeddings.graphSage.estimate", + G, + config, + ) diff --git a/graphdatascience/procedure_surface/arrow/graphsage_arrow_endpoints.py b/graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py similarity index 86% rename from graphdatascience/procedure_surface/arrow/graphsage_arrow_endpoints.py rename to graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py index 478579c64..e3ef22ff9 100644 --- a/graphdatascience/procedure_surface/arrow/graphsage_arrow_endpoints.py +++ b/graphdatascience/procedure_surface/arrow/graphsage_train_arrow_endpoints.py @@ -1,18 +1,19 @@ from typing import Any, List, Optional -from graphdatascience.model.v2.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.arrow.graphsage_predict_arrow_endpoints import GraphSagePredictArrowEndpoints from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient from ...graph.graph_object import Graph -from ..api.graphsage_endpoints import ( - GraphSageEndpoints, +from ..api.graphsage_train_endpoints import ( + GraphSageTrainEndpoints, GraphSageTrainResult, ) from .model_api_arrow import ModelApiArrow from .node_property_endpoints import NodePropertyEndpoints -class GraphSageArrowEndpoints(GraphSageEndpoints): 
+class GraphSageTrainArrowEndpoints(GraphSageTrainEndpoints): def __init__(self, arrow_client: AuthenticatedArrowClient): self._arrow_client = arrow_client self._node_property_endpoints = NodePropertyEndpoints(arrow_client) @@ -80,7 +81,9 @@ def train( result = self._node_property_endpoints.run_job_and_get_summary("v2/embeddings.graphSage.train", G, config) - model = GraphSageModelV2(model_name, self._model_api) + model = GraphSageModelV2( + model_name, self._model_api, predict_endpoints=GraphSagePredictArrowEndpoints(self._arrow_client) + ) train_result = GraphSageTrainResult(**result) return model, train_result diff --git a/graphdatascience/procedure_surface/arrow/model_api_arrow.py b/graphdatascience/procedure_surface/arrow/model_api_arrow.py index b550bd466..7fa730c72 100644 --- a/graphdatascience/procedure_surface/arrow/model_api_arrow.py +++ b/graphdatascience/procedure_surface/arrow/model_api_arrow.py @@ -6,7 +6,7 @@ from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient from graphdatascience.arrow_client.v2.data_mapper_utils import deserialize from graphdatascience.model.v2.model_api import ModelApi -from graphdatascience.model.v2.model_info import ModelDetails +from graphdatascience.model.v2.model_details import ModelDetails class ModelApiArrow(ModelApi): diff --git a/graphdatascience/procedure_surface/cypher/graphsage_predict_cypher_endpoints.py b/graphdatascience/procedure_surface/cypher/graphsage_predict_cypher_endpoints.py new file mode 100644 index 000000000..9c43d2a8f --- /dev/null +++ b/graphdatascience/procedure_surface/cypher/graphsage_predict_cypher_endpoints.py @@ -0,0 +1,58 @@ +from typing import Any + +from pandas import DataFrame + +from graphdatascience.call_parameters import CallParameters +from graphdatascience.procedure_surface.api.estimation_result import EstimationResult +from graphdatascience.procedure_surface.api.graphsage_predict_endpoints import ( + GraphSageMutateResult, + 
GraphSagePredictEndpoints, + GraphSageWriteResult, +) +from graphdatascience.procedure_surface.utils.config_converter import ConfigConverter + +from ...graph.graph_object import Graph +from ...query_runner.query_runner import QueryRunner + + +class GraphSagePredictCypherEndpoints(GraphSagePredictEndpoints): + def __init__(self, query_runner: QueryRunner): + self._query_runner = query_runner + + def stream(self, G: Graph, **config: Any) -> DataFrame: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + return self._query_runner.call_procedure(endpoint="gds.beta.graphSage.stream", params=params) + + def write(self, G: Graph, **config: Any) -> GraphSageWriteResult: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + raw_result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.write", params=params) + + return GraphSageWriteResult(**raw_result.iloc[0].to_dict()) + + def mutate(self, G: Graph, **config: Any) -> GraphSageMutateResult: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + raw_result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.mutate", params=params) + + return GraphSageMutateResult(**raw_result.iloc[0].to_dict()) + + def estimate(self, G: Graph, **config: Any) -> EstimationResult: + config = ConfigConverter.convert_to_gds_config(**config) + + params = CallParameters(graph_name=G.name(), config=config) + params.ensure_job_id_in_config() + + raw_result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.stream.estimate", params=params) + + return EstimationResult(**raw_result.iloc[0].to_dict()) diff --git a/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py 
b/graphdatascience/procedure_surface/cypher/graphsage_train_cypher_endpoints.py similarity index 83% rename from graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py rename to graphdatascience/procedure_surface/cypher/graphsage_train_cypher_endpoints.py index 95ac7de03..4364e3cee 100644 --- a/graphdatascience/procedure_surface/cypher/graphsage_cypher_endpoints.py +++ b/graphdatascience/procedure_surface/cypher/graphsage_train_cypher_endpoints.py @@ -1,19 +1,20 @@ from typing import Any, List, Optional -from graphdatascience.model.v2.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.cypher.graphsage_predict_cypher_endpoints import GraphSagePredictCypherEndpoints from graphdatascience.procedure_surface.cypher.model_api_cypher import ModelApiCypher from ...call_parameters import CallParameters from ...graph.graph_object import Graph from ...query_runner.query_runner import QueryRunner -from ..api.graphsage_endpoints import ( - GraphSageEndpoints, +from ..api.graphsage_train_endpoints import ( + GraphSageTrainEndpoints, GraphSageTrainResult, ) from ..utils.config_converter import ConfigConverter -class GraphSageCypherEndpoints(GraphSageEndpoints): +class GraphSageTrainCypherEndpoints(GraphSageTrainEndpoints): def __init__(self, query_runner: QueryRunner): self._query_runner = query_runner @@ -81,6 +82,8 @@ def train( result = self._query_runner.call_procedure(endpoint="gds.beta.graphSage.train", params=params).iloc[0] - return GraphSageModelV2(name=model_name, model_api=ModelApiCypher(self._query_runner)), GraphSageTrainResult( - **result.to_dict() - ) + return GraphSageModelV2( + name=model_name, + model_api=ModelApiCypher(self._query_runner), + predict_endpoints=GraphSagePredictCypherEndpoints(self._query_runner), + ), GraphSageTrainResult(**result.to_dict()) diff --git 
a/graphdatascience/procedure_surface/cypher/model_api_cypher.py b/graphdatascience/procedure_surface/cypher/model_api_cypher.py index dbe33e680..01772829f 100644 --- a/graphdatascience/procedure_surface/cypher/model_api_cypher.py +++ b/graphdatascience/procedure_surface/cypher/model_api_cypher.py @@ -4,7 +4,7 @@ from graphdatascience.call_parameters import CallParameters from graphdatascience.model.v2.model_api import ModelApi -from graphdatascience.model.v2.model_info import ModelDetails +from graphdatascience.model.v2.model_details import ModelDetails from graphdatascience.query_runner.query_runner import QueryRunner diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py index 4f955a224..bf31e91e2 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_arrow_endpoints.py @@ -5,7 +5,7 @@ from graphdatascience import Graph from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient -from graphdatascience.procedure_surface.arrow.graphsage_arrow_endpoints import GraphSageArrowEndpoints +from graphdatascience.procedure_surface.arrow.graphsage_train_arrow_endpoints import GraphSageTrainArrowEndpoints from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph @@ -28,11 +28,11 @@ def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, Non @pytest.fixture -def graphsage_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[GraphSageArrowEndpoints, None, None]: - yield GraphSageArrowEndpoints(arrow_client) +def graphsage_endpoints(arrow_client: AuthenticatedArrowClient) -> Generator[GraphSageTrainArrowEndpoints, None, None]: + yield GraphSageTrainArrowEndpoints(arrow_client) -def 
test_graphsage_train(graphsage_endpoints: GraphSageArrowEndpoints, sample_graph: Graph) -> None: +def test_graphsage_train(graphsage_endpoints: GraphSageTrainArrowEndpoints, sample_graph: Graph) -> None: """Test GraphSage train operation.""" model, result = graphsage_endpoints.train( G=sample_graph, diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_predict_arrow_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_predict_arrow_endpoints.py new file mode 100644 index 000000000..618a9cfaa --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_graphsage_predict_arrow_endpoints.py @@ -0,0 +1,74 @@ +import json +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.arrow.graphsage_train_arrow_endpoints import GraphSageTrainArrowEndpoints +from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph + + +@pytest.fixture +def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, None, None]: + gdl = """ + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (d: Node {feature: 4.0}), + (a)-[:REL]->(b), + (b)-[:REL]->(c), + (c)-[:REL]->(d), + (d)-[:REL]->(a) + """ + + yield create_graph(arrow_client, "g", gdl) + arrow_client.do_action("v2/graph.drop", json.dumps({"graphName": "g"}).encode("utf-8")) + + +@pytest.fixture +def gs_model(arrow_client: AuthenticatedArrowClient, sample_graph: Graph) -> Generator[GraphSageModelV2, None, None]: + model, _ = GraphSageTrainArrowEndpoints(arrow_client).train( + G=sample_graph, + model_name="gs-model", + feature_properties=["feature"], + embedding_dimension=1, + sample_sizes=[1], + 
max_iterations=1, + search_depth=1, + ) + + yield model + + arrow_client.do_action_with_retry("v2/model.drop", json.dumps({"modelName": model.name()}).encode("utf-8")) + + +def test_stream(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_stream(sample_graph, concurrency=4) + + assert set(result.columns) == {"nodeId", "embedding"} + assert len(result) == 4 + + +def test_mutate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_mutate(sample_graph, concurrency=4, mutate_property="embedding") + + assert result.node_properties_written == 4 + + +def test_write(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + with pytest.raises(Exception, match="Write back client is not initialized"): + gs_model.predict_write(sample_graph, write_property="embedding", concurrency=4, write_concurrency=2) + + +def test_estimate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_estimate(sample_graph, concurrency=4) + + assert result.node_count == 4 + assert result.relationship_count == 4 + assert "KiB" in result.required_memory + assert result.bytes_min > 0 + assert result.bytes_max > 0 + assert result.heap_percentage_min > 0 + assert result.heap_percentage_max > 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py index ef2e0c9fe..98021454a 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/arrow/test_model_api_arrow.py @@ -6,7 +6,7 @@ from graphdatascience import Graph from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient -from graphdatascience.procedure_surface.arrow.graphsage_arrow_endpoints import GraphSageArrowEndpoints +from 
graphdatascience.procedure_surface.arrow.graphsage_train_arrow_endpoints import GraphSageTrainArrowEndpoints from graphdatascience.procedure_surface.arrow.model_api_arrow import ModelApiArrow from graphdatascience.tests.integrationV2.procedure_surface.arrow.graph_creation_helper import create_graph @@ -34,7 +34,7 @@ def sample_graph(arrow_client: AuthenticatedArrowClient) -> Generator[Graph, Non @pytest.fixture def gs_model(arrow_client: AuthenticatedArrowClient, sample_graph: Graph) -> Generator[str, None, None]: - model, _ = GraphSageArrowEndpoints(arrow_client).train( + model, _ = GraphSageTrainArrowEndpoints(arrow_client).train( G=sample_graph, model_name="gs-model", feature_properties=["age"], diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py index f144c70a5..264f3f08a 100644 --- a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_cypher_endpoints.py @@ -3,7 +3,7 @@ import pytest from graphdatascience import Graph, QueryRunner -from graphdatascience.procedure_surface.cypher.graphsage_cypher_endpoints import GraphSageCypherEndpoints +from graphdatascience.procedure_surface.cypher.graphsage_train_cypher_endpoints import GraphSageTrainCypherEndpoints @pytest.fixture @@ -22,7 +22,7 @@ def sample_graph_with_features(query_runner: QueryRunner) -> Generator[Graph, No query_runner.run_cypher(""" MATCH (n) OPTIONAL MATCH (n)-[r]->(m) - WITH gds.graph.project('g', n, m, {nodeProperties: 'feature'}) AS G + WITH gds.graph.project('g', n, m, {sourceNodeProperties: properties(n), targetNodeProperties: properties(m)}) AS G RETURN G """) @@ -33,17 +33,19 @@ def sample_graph_with_features(query_runner: QueryRunner) -> Generator[Graph, No @pytest.fixture -def graphsage_endpoints(query_runner: 
QueryRunner) -> Generator[GraphSageCypherEndpoints, None, None]: - yield GraphSageCypherEndpoints(query_runner) +def graphsage_endpoints(query_runner: QueryRunner) -> Generator[GraphSageTrainCypherEndpoints, None, None]: + yield GraphSageTrainCypherEndpoints(query_runner) -def test_graphsage_train(graphsage_endpoints: GraphSageCypherEndpoints, sample_graph_with_features: Graph) -> None: +def test_graphsage_train(graphsage_endpoints: GraphSageTrainCypherEndpoints, sample_graph_with_features: Graph) -> None: """Test GraphSage train operation.""" model, train_result = graphsage_endpoints.train( G=sample_graph_with_features, model_name="testModel", feature_properties=["feature"], - embedding_dimension=64, + embedding_dimension=1, + epochs=1, # Use minimal epochs for faster testing + max_iterations=1, # Use minimal iterations for faster testing ) assert train_result.train_millis >= 0 diff --git a/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_predict_cypher_endpoints.py b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_predict_cypher_endpoints.py new file mode 100644 index 000000000..bd94c2d13 --- /dev/null +++ b/graphdatascience/tests/integrationV2/procedure_surface/cypher/test_graphsage_predict_cypher_endpoints.py @@ -0,0 +1,82 @@ +from typing import Generator + +import pytest + +from graphdatascience import Graph +from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2 +from graphdatascience.procedure_surface.cypher.graphsage_train_cypher_endpoints import GraphSageTrainCypherEndpoints +from graphdatascience.query_runner.query_runner import QueryRunner + + +@pytest.fixture +def sample_graph(query_runner: QueryRunner) -> Generator[Graph, None, None]: + create_statement = """ + CREATE + (a: Node {feature: 1.0}), + (b: Node {feature: 2.0}), + (c: Node {feature: 3.0}), + (a)-[:REL]->(c), + (b)-[:REL]->(c) + """ + + query_runner.run_cypher(create_statement) + + 
query_runner.run_cypher(""" + MATCH (n) + OPTIONAL MATCH (n)-[r]->(m) + WITH gds.graph.project('g', n, m, {sourceNodeProperties: properties(n), targetNodeProperties: properties(m)}) AS G + RETURN G + """) + + yield Graph("g", query_runner) + + query_runner.run_cypher("CALL gds.graph.drop('g')") + query_runner.run_cypher("MATCH (n) DETACH DELETE n") + + +@pytest.fixture +def gs_model(query_runner: QueryRunner, sample_graph: Graph) -> Generator[GraphSageModelV2, None, None]: + model, _ = GraphSageTrainCypherEndpoints(query_runner).train( + G=sample_graph, + model_name="gs-model", + feature_properties=["feature"], + embedding_dimension=1, + sample_sizes=[1], + max_iterations=1, + search_depth=1, + ) + + yield model + + query_runner.run_cypher("CALL gds.model.drop('gs-model')") + + +def test_stream(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_stream(sample_graph, concurrency=4) + + assert set(result.columns) == {"nodeId", "embedding"} + assert len(result) == 3 + + +def test_mutate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_mutate(sample_graph, concurrency=4, mutate_property="embedding") + + assert result.node_properties_written == 3 + + +def test_write(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_write(sample_graph, write_property="embedding", concurrency=4, write_concurrency=2) + + assert result.node_properties_written == 3 + + +def test_estimate(gs_model: GraphSageModelV2, sample_graph: Graph) -> None: + result = gs_model.predict_estimate(sample_graph, concurrency=4) + + assert result.node_count == 3 + assert result.relationship_count == 2 + assert "KiB" in result.required_memory + assert result.bytes_min > 0 + assert result.bytes_max > 0 + assert result.heap_percentage_min > 0 + assert result.heap_percentage_max > 0