Skip to content

Commit 7cbb64b

Browse files
committed
Implement client endpoints for gnn/graph sage training
1 parent daf86bb commit 7cbb64b

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

graphdatascience/endpoints.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .algo.single_mode_algo_endpoints import SingleModeAlgoEndpoints
22
from .call_builder import IndirectAlphaCallBuilder, IndirectBetaCallBuilder
3+
from .gnn.gnn_endpoints import GnnEndpoints
34
from .graph.graph_endpoints import (
45
GraphAlphaEndpoints,
56
GraphBetaEndpoints,
@@ -32,7 +33,7 @@
3233
"""
3334

3435

35-
class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints):
36+
class DirectEndpoints(DirectSystemEndpoints, DirectUtilEndpoints, GraphEndpoints, PipelineEndpoints, ModelEndpoints, GnnEndpoints):
3637
def __init__(self, query_runner: QueryRunner, namespace: str, server_version: ServerVersion):
3738
super().__init__(query_runner, namespace, server_version)
3839

graphdatascience/gnn/__init__.py

Whitespace-only changes.

graphdatascience/gnn/gnn_endpoints.py

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from .gnn_nc_runner import GNNNodeClassificationRunner
2+
from ..caller_base import CallerBase
3+
from ..error.illegal_attr_checker import IllegalAttrChecker
4+
from ..error.uncallable_namespace import UncallableNamespace
5+
6+
class GNNRunner(UncallableNamespace, IllegalAttrChecker):
7+
@property
8+
def nodeClassification(self) -> GNNNodeClassificationRunner:
9+
return GNNNodeClassificationRunner(self._query_runner, f"{self._namespace}.nodeClassification", self._server_version)
10+
11+
class GnnEndpoints(CallerBase):
12+
@property
13+
def gnn(self) -> GNNRunner:
14+
return GNNRunner(self._query_runner, f"{self._namespace}.gnn", self._server_version)
15+
16+
17+

graphdatascience/gnn/gnn_nc_runner.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Any, List
2+
3+
from ..error.illegal_attr_checker import IllegalAttrChecker
4+
from ..error.uncallable_namespace import UncallableNamespace
5+
import json
6+
7+
8+
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
9+
def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str,
10+
target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
11+
configMap = {
12+
"featureProperties": feature_properties,
13+
"targetProperty": target_property,
14+
}
15+
node_properties = feature_properties + [target_property]
16+
if target_node_label:
17+
configMap["targetNodeLabel"] = target_node_label
18+
mlTrainingConfig = json.dumps(configMap)
19+
# TODO query avaiable node labels
20+
node_labels = ["Paper"] if not node_labels else node_labels
21+
self._query_runner.run_query(f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})")

graphdatascience/ignored_server_endpoints.py

+1
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
"gds.alpha.pipeline.nodeRegression.predict.stream",
4848
"gds.alpha.pipeline.nodeRegression.selectFeatures",
4949
"gds.alpha.pipeline.nodeRegression.train",
50+
"gds.gnn.nc",
5051
"gds.similarity.cosine",
5152
"gds.similarity.euclidean",
5253
"gds.similarity.euclideanDistance",

0 commit comments

Comments
 (0)