Add graphsage to v2 #938

Open · wants to merge 3 commits into main
59 changes: 59 additions & 0 deletions graphdatascience/model/v2/model.py
@@ -0,0 +1,59 @@
from __future__ import annotations

from abc import ABC
from typing import Optional

from graphdatascience.model.v2.model_api import ModelApi
from graphdatascience.model.v2.model_details import ModelDetails


# Compared to the v1 Model, this version offers typed parameters for the predict endpoints
class Model(ABC):
def __init__(self, name: str, model_api: ModelApi):
self._name = name
self._model_api = model_api

# TODO: should estimate and predict modes live on here?
# TODO: implement Cypher and Arrow info_provider

def name(self) -> str:
"""
Get the name of the model.

Returns:
The name of the model.

"""
return self._name

def details(self) -> ModelDetails:
"""
Get the details of the model.

Returns:
The details of the model.

"""
return self._model_api.get(self._name)

def exists(self) -> bool:
"""
Check whether the model exists.

Returns:
True if the model exists, False otherwise.

"""
return self._model_api.exists(self._name)

def drop(self, failIfMissing: bool = False) -> Optional[ModelDetails]:
"""
Drop the model.

Args:
failIfMissing: If True, an error is thrown if the model does not exist. If False, no error is thrown.

Returns:
The details of the dropped model, or None if the model did not exist.

"""
return self._model_api.drop(self._name, failIfMissing)

def __str__(self) -> str:
return f"{self.__class__.__name__}(name={self.name()}, type={self.details().type})"

def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.details().model_dump()})"
51 changes: 51 additions & 0 deletions graphdatascience/model/v2/model_api.py
@@ -0,0 +1,51 @@
from abc import ABC, abstractmethod
from typing import Optional

from graphdatascience.model.v2.model_details import ModelDetails


class ModelApi(ABC):
"""
Abstract base class defining the API for model operations.
This class is intended to be subclassed by specific model implementations.
"""

@abstractmethod
def exists(self, model: str) -> bool:
"""
Check if a specific model exists.

Args:
model: The name of the model.

Returns:
True if the model exists, False otherwise.
"""
pass

@abstractmethod
def get(self, model: str) -> ModelDetails:
"""
Get the details of a specific model.

Args:
model: The name of the model.

Returns:
The details of the model.
"""
pass

@abstractmethod
def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]:
"""
Drop a specific model.

Args:
model: The name of the model.
fail_if_missing: If True, an error is thrown if the model does not exist. If False, no error is thrown.

Returns:
The details of the dropped model, or None if the model did not exist.
"""
pass
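
As a sketch, here is a minimal in-memory implementation of ModelApi (entirely hypothetical, e.g. for unit tests; the real implementations would presumably go through Cypher or Arrow):

from typing import Optional

from graphdatascience.model.v2.model_api import ModelApi
from graphdatascience.model.v2.model_details import ModelDetails


class InMemoryModelApi(ModelApi):
    """Keeps ModelDetails in a plain dict instead of calling a server."""

    def __init__(self) -> None:
        self._models: dict[str, ModelDetails] = {}

    def exists(self, model: str) -> bool:
        return model in self._models

    def get(self, model: str) -> ModelDetails:
        return self._models[model]

    def drop(self, model: str, fail_if_missing: bool) -> Optional[ModelDetails]:
        if model not in self._models:
            if fail_if_missing:
                raise ValueError(f"Model '{model}' does not exist")
            return None
        return self._models.pop(model)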
20 changes: 20 additions & 0 deletions graphdatascience/model/v2/model_details.py
@@ -0,0 +1,20 @@
import datetime
from typing import Any

from pydantic import BaseModel, Field
from pydantic.alias_generators import to_camel


class ModelDetails(BaseModel, alias_generator=to_camel):
name: str = Field(alias="modelName")
type: str = Field(alias="modelType")
train_config: dict[str, Any]
graph_schema: dict[str, Any]
loaded: bool
stored: bool
published: bool
model_info: dict[str, Any] # TODO better typing in actual model?
creation_time: datetime.datetime

def __getitem__(self, item: str) -> Any:
return getattr(self, item)
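
A quick sketch of how ModelDetails would parse a camelCase procedure result, given the to_camel alias generator plus the explicit modelName/modelType aliases; the field values below are made up:

# Illustrative only; the dict mimics a camelCase result row from the server.
import datetime

from graphdatascience.model.v2.model_details import ModelDetails

details = ModelDetails.model_validate(
    {
        "modelName": "my-model",
        "modelType": "graphSage",
        "trainConfig": {"embeddingDimension": 64},
        "graphSchema": {},
        "loaded": True,
        "stored": False,
        "published": False,
        "modelInfo": {},
        "creationTime": datetime.datetime.now(datetime.timezone.utc),
    }
)
assert details["name"] == details.name  # __getitem__ mirrors attribute access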
@@ -0,0 +1,51 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

from pandas import DataFrame

from graphdatascience.procedure_surface.api.base_result import BaseResult
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult

from ...graph.graph_object import Graph


class GraphSagePredictEndpoints(ABC):
"""
Abstract base class defining the predict endpoints for the GraphSage algorithm.
"""

@abstractmethod
def stream(self, G: Graph, **config: Any) -> DataFrame:
pass

@abstractmethod
def write(self, G: Graph, **config: Any) -> GraphSageWriteResult:
pass

@abstractmethod
def mutate(self, G: Graph, **config: Any) -> GraphSageMutateResult:
pass

@abstractmethod
def estimate(self, G: Graph, **config: Any) -> EstimationResult:
pass


class GraphSageMutateResult(BaseResult):
node_count: int
node_properties_written: int
pre_processing_millis: int
compute_millis: int
mutate_millis: int
configuration: dict[str, Any]


class GraphSageWriteResult(BaseResult):
node_count: int
node_properties_written: int
pre_processing_millis: int
compute_millis: int
write_millis: int
configuration: dict[str, Any]
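
A hedged usage sketch of the predict endpoints; `endpoints` and `G` are assumed to exist, and the config keys shown (modelName, mutateProperty) are assumptions based on the corresponding GDS procedure configuration rather than anything defined in this diff:

# Hypothetical usage; `endpoints` is a concrete GraphSagePredictEndpoints
# implementation and `G` an already-projected Graph.
embeddings = endpoints.stream(G, modelName="my-graphsage-model")
result = endpoints.mutate(G, modelName="my-graphsage-model", mutateProperty="embedding")
print(result.node_properties_written)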
118 changes: 118 additions & 0 deletions graphdatascience/procedure_surface/api/graphsage_train_endpoints.py
@@ -0,0 +1,118 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, List, Optional

from graphdatascience.procedure_surface.api.base_result import BaseResult
from graphdatascience.procedure_surface.api.model.graphsage_model import GraphSageModelV2

from ...graph.graph_object import Graph


class GraphSageTrainEndpoints(ABC):
"""
Abstract base class defining the train endpoint for the GraphSage algorithm.
"""

@abstractmethod
def train(
self,
G: Graph,
model_name: str,
feature_properties: List[str],
activation_function: Optional[Any] = None,
negative_sample_weight: Optional[int] = None,
embedding_dimension: Optional[int] = None,
tolerance: Optional[float] = None,
learning_rate: Optional[float] = None,
max_iterations: Optional[int] = None,
sample_sizes: Optional[List[int]] = None,
aggregator: Optional[Any] = None,
penalty_l2: Optional[float] = None,
search_depth: Optional[int] = None,
epochs: Optional[int] = None,
projected_feature_dimension: Optional[int] = None,
batch_sampling_ratio: Optional[float] = None,
store_model_to_disk: Optional[bool] = None,
relationship_types: Optional[List[str]] = None,
node_labels: Optional[List[str]] = None,
username: Optional[str] = None,
log_progress: Optional[bool] = None,
sudo: Optional[bool] = None,
concurrency: Optional[Any] = None,
job_id: Optional[Any] = None,
batch_size: Optional[int] = None,
relationship_weight_property: Optional[str] = None,
random_seed: Optional[Any] = None,
) -> tuple[GraphSageModelV2, GraphSageTrainResult]:
"""
Trains a GraphSage model on the given graph.

Parameters
----------
G : Graph
The graph to run the algorithm on
model_name : str
Name under which the model will be stored
feature_properties : List[str]
The names of the node properties to use as input features
activation_function : Optional[Any], default=None
The activation function to apply after each layer
negative_sample_weight : Optional[int], default=None
Weight of negative samples in the loss function
embedding_dimension : Optional[int], default=None
The dimension of the generated embeddings
tolerance : Optional[float], default=None
Tolerance for early stopping based on loss improvement
learning_rate : Optional[float], default=None
Learning rate for the training optimization
max_iterations : Optional[int], default=None
Maximum number of training iterations
sample_sizes : Optional[List[int]], default=None
Number of neighbors to sample at each layer
aggregator : Optional[Any], default=None
The aggregator function for neighborhood aggregation
penalty_l2 : Optional[float], default=None
L2 regularization penalty
search_depth : Optional[int], default=None
Maximum search depth for neighbor sampling
epochs : Optional[int], default=None
Number of training epochs
projected_feature_dimension : Optional[int], default=None
Dimension to project input features to before training
batch_sampling_ratio : Optional[float], default=None
Ratio of nodes to sample for each training batch
store_model_to_disk : Optional[bool], default=None
Whether to persist the model to disk
relationship_types : Optional[List[str]], default=None
The relationship types used to select relationships for this algorithm run
node_labels : Optional[List[str]], default=None
The node labels used to select nodes for this algorithm run
username : Optional[str], default=None
The username to attribute the procedure run to
log_progress : Optional[bool], default=None
Whether to log progress
sudo : Optional[bool], default=None
Override memory estimation limits
concurrency : Optional[Any], default=None
The number of concurrent threads
job_id : Optional[Any], default=None
An identifier for the job
batch_size : Optional[int], default=None
Batch size for training
relationship_weight_property : Optional[str], default=None
The name of the relationship property to use as weights
random_seed : Optional[Any], default=None
Random seed for reproducible results

Returns
-------
tuple[GraphSageModelV2, GraphSageTrainResult]
The trained model along with the training result
"""


class GraphSageTrainResult(BaseResult):
model_info: dict[str, Any]
configuration: dict[str, Any]
train_millis: int
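
Finally, a hedged sketch of what calling train could look like, assuming a concrete GraphSageTrainEndpoints instance (`train_endpoints`) and an existing Graph `G`; the feature property names and config values are made up:

# Hypothetical usage of the train endpoint defined above.
model, train_result = train_endpoints.train(
    G,
    model_name="my-graphsage-model",
    feature_properties=["age", "score"],
    embedding_dimension=64,
    epochs=5,
)
print(train_result.train_millis)
print(model.name())  # GraphSageModelV2 is expected to expose the v2 Model API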