Skip to content

Commit 4f605e4

Browse files
committed
WIP
1 parent e5e8ef2 commit 4f605e4

File tree

1 file changed

+17
-22
lines changed

1 file changed

+17
-22
lines changed

graphdatascience/gnn/gnn_nc_runner.py

+17-22
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
1-
import json
21
from typing import Any, List
32

43
from ..error.illegal_attr_checker import IllegalAttrChecker
54
from ..error.uncallable_namespace import UncallableNamespace
5+
import json
66

77

88
class GNNNodeClassificationRunner(UncallableNamespace, IllegalAttrChecker):
9-
def train(
10-
self,
11-
graph_name: str,
12-
model_name: str,
13-
feature_properties: List[str],
14-
target_property: str,
15-
target_node_label: str = None,
16-
node_labels: List[str] = None,
17-
) -> "Series[Any]":
9+
def train(self, graph_name: str, model_name: str, feature_properties: List[str], target_property: str,
10+
target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
1811
configMap = {
1912
"featureProperties": feature_properties,
2013
"targetProperty": target_property,
@@ -27,18 +20,21 @@ def train(
2720
mlTrainingConfig = json.dumps(configMap)
2821
# TODO query available node labels
2922
node_labels = ["Paper"] if not node_labels else node_labels
23+
24+
# use arrow direclty here
3025
self._query_runner.run_query(
31-
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {node_properties}}})"
32-
)
26+
"CALL gds.upload.graph($config)",
27+
params={"config":
28+
{"graph_name": graph_name,
29+
"mlTrainingConfig": mlTrainingConfig,
30+
"modelName": model_name,
31+
"nodeLabels": node_labels,
32+
"nodeProperties": node_properties
33+
}
34+
})
35+
3336

34-
def predict(
35-
self,
36-
graph_name: str,
37-
model_name: str,
38-
feature_properties: List[str],
39-
target_node_label: str = None,
40-
node_labels: List[str] = None,
41-
) -> "Series[Any]":
37+
def predict(self, graph_name: str, model_name: str, feature_properties: List[str], target_node_label: str = None, node_labels: List[str] = None) -> "Series[Any]":
4238
configMap = {
4339
"featureProperties": feature_properties,
4440
"job_type": "predict",
@@ -49,5 +45,4 @@ def predict(
4945
# TODO query available node labels
5046
node_labels = ["Paper"] if not node_labels else node_labels
5147
self._query_runner.run_query(
52-
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})"
53-
)
48+
f"CALL gds.upload.graph('{graph_name}', {{mlTrainingConfig: '{mlTrainingConfig}', modelName: '{model_name}', nodeLabels: {node_labels}, nodeProperties: {feature_properties}}})")

0 commit comments

Comments
 (0)