Skip to content

Implement v1/put_node_properties #857

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

## New features

* Added the ability to upload additional node properties via the GdsArrowClient


## Bug fixes

Expand Down
94 changes: 94 additions & 0 deletions graphdatascience/query_runner/gds_arrow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,87 @@ def upload_triplets(
"""
self._upload_data(graph_name, "triplet", triplet_data, batch_size, progress_callback)

def put_node_properties(
self,
graph_name: str,
database: str,
node_labels: Optional[Union[str, list[str]]] = None,
consecutive_ids: bool = False,
concurrency: Optional[int] = None,
) -> None:
"""
Starts a new node properties upload process on the GDS server.

Parameters
----------
graph_name : str
The name of the graph
database : str
The name of the database to which the graph belongs
node_labels : Optional[Union[str, List[str]]]
The name of the node labels to upload (default is None)
consecutive_ids : bool
Whether the node IDs in the input data are consecutive (default is False)
concurrency : Optional[int]
The number of threads used on the server side when uploading the properties
"""
config: dict[str, Any] = {
"name": graph_name,
"database_name": database,
"consecutive_ids": consecutive_ids,
}

if concurrency:
config["concurrency"] = concurrency
if node_labels is not None:
if isinstance(node_labels, str):
config["node_labels"] = [node_labels]
else:
config["node_labels"] = node_labels

self._send_action("PUT_NODE_PROPERTIES", config)

def upload_node_properties(
self,
graph_name: str,
node_data: Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], pandas.DataFrame],
batch_size: int = 10_000,
progress_callback: Callable[[int], None] = lambda x: None,
) -> None:
"""
Uploads node property data to the server.

Parameters
----------
graph_name : str
The name of the graph
node_data : Union[pyarrow.Table, Iterable[pyarrow.RecordBatch], DataFrame]
The node property data to upload
batch_size : int
The number of rows per batch
progress_callback : Callable[[int], None]
A callback function that is called with the number of rows uploaded after each batch
"""
self._upload_data(graph_name, "node_properties", node_data, batch_size, progress_callback)

def put_node_properties_done(self, graph_name: str) -> NodePropertiesLoadDoneResult:
"""
Notifies the server that all node property data has been sent.

Parameters
----------
graph_name : str
The name of the graph

Returns
-------
NodePropertiesLoadDoneResult
A result object containing the name of the graph and the number of properties loaded
"""
return NodePropertiesLoadDoneResult.from_json(
self._send_action("PUT_NODE_PROPERTIES_DONE", {"name": graph_name})
)

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
# Remove the FlightClient as it isn't serializable
Expand Down Expand Up @@ -963,3 +1044,16 @@ class TripletLoadDoneResult:
@classmethod
def from_json(cls, json: dict[str, Any]) -> TripletLoadDoneResult:
return cls(name=json["name"], node_count=json["node_count"], relationship_count=json["relationship_count"])


@dataclass(repr=True, frozen=True)
class NodePropertiesLoadDoneResult:
name: str
node_count: int

@classmethod
def from_json(cls, json: dict[str, Any]) -> NodePropertiesLoadDoneResult:
return cls(
name=json["name"],
node_count=json["node_count"],
)
66 changes: 66 additions & 0 deletions graphdatascience/tests/unit/test_gds_arrow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
response = {"name": "g", "relationship_count": 42}
elif "TRIPLET_LOAD_DONE" in actionType:
response = {"name": "g", "node_count": 42, "relationship_count": 1337}
elif "PUT_NODE_PROPERTIES_DONE" in actionType:
response = {"name": "g", "node_count": 42}
elif "PUT_NODE_PROPERTIES" in actionType:
response = {"name": "g"}
else:
response = {}
return [json.dumps(response).encode("utf-8")]
Expand Down Expand Up @@ -102,6 +106,10 @@ def do_action(self, context: Any, action: ActionParam) -> list[bytes]:
response = {"name": "g", "relationship_count": 42}
elif "TRIPLET_LOAD_DONE" in actionType:
response = {"name": "g", "node_count": 42, "relationship_count": 1337}
elif "PUT_NODE_PROPERTIES" == actionType:
response = {"name": "g"}
elif "PUT_NODE_PROPERTIES_DONE" == actionType:
response = {"name": "g", "node_count": 42}
else:
response = {}
return [json.dumps(response).encode("utf-8")]
Expand Down Expand Up @@ -258,6 +266,64 @@ def test_triplet_load_done_action(flight_server: FlightServer, flight_client: Gd
assert_action(actions[0], "v1/TRIPLET_LOAD_DONE", {"name": "g"})


def test_put_node_properties_with_defaults(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
flight_client.put_node_properties("g", "DB")
actions = flight_server._actions
assert len(actions) == 1
assert_action(actions[0], "v1/PUT_NODE_PROPERTIES", {"name": "g", "database_name": "DB", "consecutive_ids": False})


def test_put_node_properties_with_single_label(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
flight_client.put_node_properties("g", "DB", "Label1")
actions = flight_server._actions
assert len(actions) == 1
assert_action(
actions[0],
"v1/PUT_NODE_PROPERTIES",
{"name": "g", "database_name": "DB", "consecutive_ids": False, "node_labels": ["Label1"]},
)


def test_put_node_properties_with_options(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
flight_client.put_node_properties("g", "DB", ["Label1", "Label2"], consecutive_ids=True, concurrency=42)
actions = flight_server._actions
assert len(actions) == 1
assert_action(
actions[0],
"v1/PUT_NODE_PROPERTIES",
{
"name": "g",
"database_name": "DB",
"consecutive_ids": True,
"concurrency": 42,
"node_labels": ["Label1", "Label2"],
},
)


def test_put_node_properties_with_flaky_server(
flaky_flight_server: FlakyFlightServer, flaky_flight_client: GdsArrowClient
) -> None:
flaky_flight_client.put_node_properties("g", "DB", "Label1")
actions = flaky_flight_server._actions
assert len(actions) == flaky_flight_server.expected_retries()
assert_action(
actions[0],
"v1/PUT_NODE_PROPERTIES",
{"name": "g", "database_name": "DB", "consecutive_ids": False, "node_labels": ["Label1"]},
)


def test_put_node_properties_done(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
response = flight_client.put_node_properties_done("g")
assert response.name == "g"
assert response.node_count == 42

actions = flight_server._actions
assert len(actions) == 1
assert_action(actions[0], "v1/PUT_NODE_PROPERTIES_DONE", {"name": "g"})


def test_abort_action(flight_server: FlightServer, flight_client: GdsArrowClient) -> None:
flight_client.abort("g")
actions = flight_server._actions
Expand Down