1
- import json
2
1
from typing import Any , List
3
2
4
3
from ..error .illegal_attr_checker import IllegalAttrChecker
5
4
from ..error .uncallable_namespace import UncallableNamespace
5
+ import json
6
6
7
7
8
8
class GNNNodeClassificationRunner (UncallableNamespace , IllegalAttrChecker ):
9
- def train (
10
- self ,
11
- graph_name : str ,
12
- model_name : str ,
13
- feature_properties : List [str ],
14
- target_property : str ,
15
- target_node_label : str = None ,
16
- node_labels : List [str ] = None ,
17
- ) -> "Series[Any]" :
9
+ def train (self , graph_name : str , model_name : str , feature_properties : List [str ], target_property : str ,
10
+ target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
18
11
configMap = {
19
12
"featureProperties" : feature_properties ,
20
13
"targetProperty" : target_property ,
@@ -27,18 +20,21 @@ def train(
27
20
mlTrainingConfig = json .dumps (configMap )
28
21
# TODO query available node labels
29
22
node_labels = ["Paper" ] if not node_labels else node_labels
23
+
24
+ # use arrow direclty here
30
25
self ._query_runner .run_query (
31
- f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { node_properties } }})"
32
- )
26
+ "CALL gds.upload.graph($config)" ,
27
+ params = {"config" :
28
+ {"graph_name" : graph_name ,
29
+ "mlTrainingConfig" : mlTrainingConfig ,
30
+ "modelName" : model_name ,
31
+ "nodeLabels" : node_labels ,
32
+ "nodeProperties" : node_properties
33
+ }
34
+ })
35
+
33
36
34
- def predict (
35
- self ,
36
- graph_name : str ,
37
- model_name : str ,
38
- feature_properties : List [str ],
39
- target_node_label : str = None ,
40
- node_labels : List [str ] = None ,
41
- ) -> "Series[Any]" :
37
+ def predict (self , graph_name : str , model_name : str , feature_properties : List [str ], target_node_label : str = None , node_labels : List [str ] = None ) -> "Series[Any]" :
42
38
configMap = {
43
39
"featureProperties" : feature_properties ,
44
40
"job_type" : "predict" ,
@@ -49,5 +45,4 @@ def predict(
49
45
# TODO query available node labels
50
46
node_labels = ["Paper" ] if not node_labels else node_labels
51
47
self ._query_runner .run_query (
52
- f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})"
53
- )
48
+ f"CALL gds.upload.graph('{ graph_name } ', {{mlTrainingConfig: '{ mlTrainingConfig } ', modelName: '{ model_name } ', nodeLabels: { node_labels } , nodeProperties: { feature_properties } }})" )
0 commit comments