Skip to content

Commit 3d5e8ab

Browse files
committed
WIP KNN
1 parent 090fddd commit 3d5e8ab

File tree

14 files changed

+1372
-2
lines changed

14 files changed

+1372
-2
lines changed

graphdatascience/arrow_client/authenticated_flight_client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def do_action_with_retry(self, endpoint: str, payload: Union[bytes, dict[str, An
189189
wait=self._retry_config.wait,
190190
)
191191
def run_with_retry() -> Iterator[Result]:
192+
# TODO collect result to avoid lazy response status eval
192193
return self.do_action(endpoint, payload)
193194

194195
return run_with_retry()

graphdatascience/procedure_surface/api/similarity/__init__.py

Whitespace-only changes.

graphdatascience/procedure_surface/api/similarity/knn_endpoints.py

Lines changed: 418 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from typing import Any, Dict, Optional, Union
2+
3+
from pandas import DataFrame
4+
5+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
6+
7+
from ...arrow_client.authenticated_flight_client import AuthenticatedArrowClient
8+
from ...arrow_client.v2.data_mapper_utils import deserialize_single
9+
from ...arrow_client.v2.job_client import JobClient
10+
from ...arrow_client.v2.mutation_client import MutationClient
11+
from ...arrow_client.v2.remote_write_back_client import RemoteWriteBackClient
12+
from ..api.estimation_result import EstimationResult
13+
from ..utils.config_converter import ConfigConverter
14+
15+
16+
# TODO find common parts with node_property_endpoints and refactor into a base class
class RelationshipEndpointsHelper:
    """
    Helper class for Arrow algorithm endpoints that work with relationships.

    Provides common functionality for job execution, mutation, streaming, and
    writing results back through the Arrow client.
    """

    def __init__(
        self,
        arrow_client: AuthenticatedArrowClient,
        write_back_client: Optional[RemoteWriteBackClient] = None,
        show_progress: bool = True,
    ):
        self._arrow_client = arrow_client
        self._write_back_client = write_back_client
        self._show_progress = show_progress

    def _resolve_show_progress(self, config: Dict[str, Any]) -> bool:
        """Progress is logged only if both the per-call config and this helper allow it."""
        return config.get("logProgress", True) and self._show_progress

    def run_job_and_get_summary(self, endpoint: str, G: GraphV2, config: Dict[str, Any]) -> Dict[str, Any]:
        """Run a job and return the computation summary.

        ``G`` is currently unused but kept for signature symmetry with the
        other ``run_job_*`` methods.
        """
        show_progress = self._resolve_show_progress(config)

        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress)
        return JobClient.get_summary(self._arrow_client, job_id)

    def run_job_and_mutate(
        self, endpoint: str, G: GraphV2, config: Dict[str, Any], mutate_property: str, mutate_relationship_type: str
    ) -> Dict[str, Any]:
        """Run a job, mutate the in-memory graph, and return summary with mutation details.

        NOTE(review): this calls ``MutationClient.mutate_node_property`` even though
        the summary is populated with relationship-mutation fields — confirm this is
        the intended mutation entry point for relationship results.
        """
        show_progress = self._resolve_show_progress(config)
        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress)
        mutate_result = MutationClient.mutate_node_property(self._arrow_client, job_id, mutate_property)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        # modify computation result to include mutation details
        computation_result["relationshipsWritten"] = mutate_result.relationships_written
        computation_result["mutateMillis"] = mutate_result.mutate_millis

        # Use a distinct name: the original walrus assignment shadowed the
        # `config` parameter of this method.
        result_config = computation_result.get("configuration", None)
        if result_config is not None:
            result_config["mutateProperty"] = mutate_property
            result_config["mutateRelationshipType"] = mutate_relationship_type
            # Strip write-mode keys that do not apply in mutate mode.
            result_config.pop("writeConcurrency", None)
            result_config.pop("writeToResultStore", None)
            result_config.pop("writeProperty", None)
            result_config.pop("writeMillis", None)

        return computation_result

    def run_job_and_stream(self, endpoint: str, G: GraphV2, config: Dict[str, Any]) -> DataFrame:
        """Run a job and return streamed results as a DataFrame."""
        show_progress = self._resolve_show_progress(config)
        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=show_progress)
        return JobClient.stream_results(self._arrow_client, G.name(), job_id)

    def run_job_and_write(
        self,
        endpoint: str,
        G: GraphV2,
        config: Dict[str, Any],
        *,
        relationship_type_overwrite: str,
        property_overwrites: Union[str, dict[str, str]],
        write_concurrency: Optional[int],
        concurrency: Optional[int],
    ) -> Dict[str, Any]:
        """Run a job, write results back, and return summary including write time.

        Raises:
            Exception: if no write-back client was configured for this helper.
        """
        # Fail fast: do not launch the (potentially expensive) job if we
        # cannot write its results back afterwards.
        if self._write_back_client is None:
            raise Exception("Write back client is not initialized")

        show_progress = self._resolve_show_progress(config)
        job_id = JobClient.run_job_and_wait(self._arrow_client, endpoint, config, show_progress=show_progress)
        computation_result = JobClient.get_summary(self._arrow_client, job_id)

        if isinstance(property_overwrites, str):
            # The remote write back procedure allows specifying a single overwrite. The key is ignored.
            property_overwrites = {property_overwrites: property_overwrites}

        write_result = self._write_back_client.write(
            G.name(),
            job_id,
            # An explicit write concurrency wins over the general concurrency setting.
            concurrency=write_concurrency if write_concurrency is not None else concurrency,
            property_overwrites=property_overwrites,
            relationship_type_overwrite=relationship_type_overwrite,
            log_progress=show_progress,
        )

        # modify computation result to include write details
        computation_result["writeMillis"] = write_result.write_millis

        return computation_result

    def create_base_config(self, G: GraphV2, **kwargs: Any) -> Dict[str, Any]:
        """Create base configuration including the graph name."""
        return ConfigConverter.convert_to_gds_config(graph_name=G.name(), **kwargs)

    def create_estimate_config(self, **kwargs: Any) -> Dict[str, Any]:
        """Create configuration for estimation (no graph name required)."""
        return ConfigConverter.convert_to_gds_config(**kwargs)

    def estimate(
        self,
        estimate_endpoint: str,
        G: Union[GraphV2, dict[str, Any]],
        algo_config: Optional[dict[str, Any]] = None,
    ) -> EstimationResult:
        """Estimate memory requirements for the algorithm.

        ``G`` may be an existing graph or a projection configuration dict.
        """
        if isinstance(G, GraphV2):
            payload: dict[str, Any] = {"graphName": G.name()}
        elif isinstance(G, dict):
            payload = G
        else:
            # Error message corrected: the parameter is `G`, which must be a
            # graph or a projection configuration (the old message referenced
            # nonexistent parameter names).
            raise ValueError("G must be a GraphV2 or a projection configuration dict.")

        payload.update(algo_config or {})

        res = self._arrow_client.do_action_with_retry(estimate_endpoint, payload)

        return EstimationResult(**deserialize_single(res))

graphdatascience/procedure_surface/arrow/similarity/__init__.py

Whitespace-only changes.
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
from typing import Any, List, Optional, Union
2+
3+
from pandas import DataFrame
4+
5+
from graphdatascience.procedure_surface.api.catalog.graph_api import GraphV2
6+
from graphdatascience.procedure_surface.api.estimation_result import EstimationResult
7+
from graphdatascience.procedure_surface.api.similarity.knn_endpoints import (
8+
KnnEndpoints,
9+
KnnMutateResult,
10+
KnnStatsResult,
11+
KnnWriteResult,
12+
)
13+
from graphdatascience.procedure_surface.arrow.relationship_endpoints_helper import RelationshipEndpointsHelper
14+
15+
16+
class KnnArrowEndpoints(KnnEndpoints):
    """Arrow-backed implementation of the K-Nearest Neighbours (KNN) similarity endpoints."""

    # All execution modes hit the same server-side endpoint.
    _ENDPOINT = "v2/similarity.knn"

    def __init__(self, endpoints_helper: RelationshipEndpointsHelper):
        self._endpoints_helper = endpoints_helper

    @staticmethod
    def _algo_kwargs(
        node_properties: Union[str, List[str], dict[str, str]],
        top_k: Optional[int],
        similarity_cutoff: Optional[float],
        delta_threshold: Optional[float],
        max_iterations: Optional[int],
        sample_rate: Optional[float],
        perturbation_rate: Optional[float],
        random_joins: Optional[int],
        random_seed: Optional[int],
        initial_sampler: Optional[Any],
        relationship_types: Optional[List[str]],
        node_labels: Optional[List[str]],
        sudo: Optional[bool],
        log_progress: bool,
        username: Optional[str],
        concurrency: Optional[Any],
        job_id: Optional[Any],
    ) -> dict[str, Any]:
        """Map the shared snake_case KNN parameters to their GDS camelCase config keys.

        Factored out because all five endpoint modes (mutate / stats / stream /
        write / estimate) accept exactly the same algorithm parameters.
        """
        return dict(
            nodeProperties=node_properties,
            topK=top_k,
            similarityCutoff=similarity_cutoff,
            deltaThreshold=delta_threshold,
            maxIterations=max_iterations,
            sampleRate=sample_rate,
            perturbationRate=perturbation_rate,
            randomJoins=random_joins,
            randomSeed=random_seed,
            initialSampler=initial_sampler,
            relationshipTypes=relationship_types,
            nodeLabels=node_labels,
            sudo=sudo,
            logProgress=log_progress,
            username=username,
            concurrency=concurrency,
            jobId=job_id,
        )

    def mutate(
        self,
        G: GraphV2,
        mutate_relationship_type: str,
        mutate_property: str,
        node_properties: Union[str, List[str], dict[str, str]],
        top_k: Optional[int] = None,
        similarity_cutoff: Optional[float] = None,
        delta_threshold: Optional[float] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[float] = None,
        perturbation_rate: Optional[float] = None,
        random_joins: Optional[int] = None,
        random_seed: Optional[int] = None,
        initial_sampler: Optional[Any] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: bool = True,
        username: Optional[str] = None,
        concurrency: Optional[Any] = None,
        job_id: Optional[Any] = None,
    ) -> KnnMutateResult:
        """Run KNN and mutate the in-memory graph with the new similarity relationships."""
        config = self._endpoints_helper.create_base_config(
            G,
            **self._algo_kwargs(
                node_properties, top_k, similarity_cutoff, delta_threshold, max_iterations,
                sample_rate, perturbation_rate, random_joins, random_seed, initial_sampler,
                relationship_types, node_labels, sudo, log_progress, username, concurrency, job_id,
            ),
        )

        result = self._endpoints_helper.run_job_and_mutate(
            self._ENDPOINT, G, config, mutate_property, mutate_relationship_type
        )

        return KnnMutateResult(**result)

    def stats(
        self,
        G: GraphV2,
        node_properties: Union[str, List[str], dict[str, str]],
        top_k: Optional[int] = None,
        similarity_cutoff: Optional[float] = None,
        delta_threshold: Optional[float] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[float] = None,
        perturbation_rate: Optional[float] = None,
        random_joins: Optional[int] = None,
        random_seed: Optional[int] = None,
        initial_sampler: Optional[Any] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: bool = True,
        username: Optional[str] = None,
        concurrency: Optional[Any] = None,
        job_id: Optional[Any] = None,
    ) -> KnnStatsResult:
        """Run KNN and return only the computation statistics (no graph changes)."""
        config = self._endpoints_helper.create_base_config(
            G,
            **self._algo_kwargs(
                node_properties, top_k, similarity_cutoff, delta_threshold, max_iterations,
                sample_rate, perturbation_rate, random_joins, random_seed, initial_sampler,
                relationship_types, node_labels, sudo, log_progress, username, concurrency, job_id,
            ),
        )

        result = self._endpoints_helper.run_job_and_get_summary(self._ENDPOINT, G, config)

        return KnnStatsResult(**result)

    def stream(
        self,
        G: GraphV2,
        node_properties: Union[str, List[str], dict[str, str]],
        top_k: Optional[int] = None,
        similarity_cutoff: Optional[float] = None,
        delta_threshold: Optional[float] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[float] = None,
        perturbation_rate: Optional[float] = None,
        random_joins: Optional[int] = None,
        random_seed: Optional[int] = None,
        initial_sampler: Optional[Any] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: bool = True,
        username: Optional[str] = None,
        concurrency: Optional[Any] = None,
        job_id: Optional[Any] = None,
    ) -> DataFrame:
        """Run KNN and stream the similarity results as a DataFrame."""
        config = self._endpoints_helper.create_base_config(
            G,
            **self._algo_kwargs(
                node_properties, top_k, similarity_cutoff, delta_threshold, max_iterations,
                sample_rate, perturbation_rate, random_joins, random_seed, initial_sampler,
                relationship_types, node_labels, sudo, log_progress, username, concurrency, job_id,
            ),
        )

        return self._endpoints_helper.run_job_and_stream(self._ENDPOINT, G, config)

    def write(
        self,
        G: GraphV2,
        write_relationship_type: str,
        write_property: str,
        node_properties: Union[str, List[str], dict[str, str]],
        top_k: Optional[int] = None,
        similarity_cutoff: Optional[float] = None,
        delta_threshold: Optional[float] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[float] = None,
        perturbation_rate: Optional[float] = None,
        random_joins: Optional[int] = None,
        random_seed: Optional[int] = None,
        initial_sampler: Optional[Any] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: bool = True,
        username: Optional[str] = None,
        concurrency: Optional[Any] = None,
        job_id: Optional[Any] = None,
        write_concurrency: Optional[int] = None,
    ) -> KnnWriteResult:
        """Run KNN and write the similarity relationships back to the database."""
        config = self._endpoints_helper.create_base_config(
            G,
            **self._algo_kwargs(
                node_properties, top_k, similarity_cutoff, delta_threshold, max_iterations,
                sample_rate, perturbation_rate, random_joins, random_seed, initial_sampler,
                relationship_types, node_labels, sudo, log_progress, username, concurrency, job_id,
            ),
        )

        result = self._endpoints_helper.run_job_and_write(
            self._ENDPOINT,
            G,
            config,
            relationship_type_overwrite=write_relationship_type,
            property_overwrites=write_property,
            write_concurrency=write_concurrency,
            concurrency=None,
        )

        return KnnWriteResult(**result)

    def estimate(
        self,
        G: GraphV2,
        node_properties: Union[str, List[str], dict[str, str]],
        top_k: Optional[int] = None,
        similarity_cutoff: Optional[float] = None,
        delta_threshold: Optional[float] = None,
        max_iterations: Optional[int] = None,
        sample_rate: Optional[float] = None,
        perturbation_rate: Optional[float] = None,
        random_joins: Optional[int] = None,
        random_seed: Optional[int] = None,
        initial_sampler: Optional[Any] = None,
        relationship_types: Optional[List[str]] = None,
        node_labels: Optional[List[str]] = None,
        sudo: Optional[bool] = None,
        log_progress: bool = True,
        username: Optional[str] = None,
        concurrency: Optional[Any] = None,
        job_id: Optional[Any] = None,
    ) -> EstimationResult:
        """Estimate memory requirements for running KNN on the given graph."""
        config = self._endpoints_helper.create_estimate_config(
            **self._algo_kwargs(
                node_properties, top_k, similarity_cutoff, delta_threshold, max_iterations,
                sample_rate, perturbation_rate, random_joins, random_seed, initial_sampler,
                relationship_types, node_labels, sudo, log_progress, username, concurrency, job_id,
            ),
        )

        return self._endpoints_helper.estimate(self._ENDPOINT, G, config)

0 commit comments

Comments
 (0)