Skip to content

Commit ae2af59

Browse files
brs96FlorentinDorazve
committed
Add all configs to CRD
Co-authored-by: Florentin Dörre <florentin.dorre@neo4j.com> Co-authored-by: Olga Razvenskaia <olga.razvenskaia@neo4j.com>
1 parent f1977a3 commit ae2af59

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

graphdatascience/gnn/gnn_nc_runner.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -12,23 +12,22 @@ def train(self, graph_name: str, model_name: str, feature_properties: List[str],
1212
"featureProperties": feature_properties,
1313
"targetProperty": target_property,
1414
"job_type": "train",
15+
"nodeProperties": feature_properties + [target_property]
1516
}
1617

17-
node_properties = feature_properties + [target_property]
1818
if target_node_label:
1919
configMap["targetNodeLabel"] = target_node_label
20+
if node_labels:
21+
configMap["nodeLabels"] = node_labels
22+
2023
mlTrainingConfig = json.dumps(configMap)
21-
# TODO query available node labels
22-
node_labels = ["Paper"] if not node_labels else node_labels
2324

2425
# token and uri will be injected by arrow_query_runner
2526
self._query_runner.run_query(
2627
f"CALL gds.upload.graph($graph_name, $config)",
2728
params={"graph_name": graph_name, "config": {
2829
"mlTrainingConfig": mlTrainingConfig,
29-
"modelName": model_name,
30-
"nodeLabels": node_labels,
31-
"nodeProperties": node_properties
30+
"modelName": model_name
3231
}}
3332
)
3433

0 commit comments

Comments
 (0)