Change polling for progress logging to exponential backoff #873

Open · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions changelog.md
@@ -21,6 +21,7 @@
 * Improve error message if session is expired.
 * Improve robustness of Arrow client against connection errors such as `FlightUnavailableError` and `FlightTimedOutError`.
 * Return dedicated error class `SessionStatusError` if a session failed or expired.
+* Reduce the number of calls that check for progress updates: previously every 0.5 seconds, now with exponential backoff capped at 10 seconds.
 
 
 ## Other changes
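The schedule behind the backoff entry above can be previewed with tenacity directly. A minimal sketch, not part of the PR; it assumes `wait_exponential` only reads `attempt_number` from the retry state, which is why a tiny stand-in object suffices:

```python
# Sketch: preview the polling schedule described in the changelog entry above.
# Parameters match the runners changed in this PR: min=0.5, exp_base=1.5, max=10.
from tenacity import wait_exponential

backoff = wait_exponential(min=0.5, exp_base=1.5, max=10)


class _FakeRetryState:
    """Minimal stand-in; assumes wait_exponential only reads attempt_number."""

    def __init__(self, attempt_number: int) -> None:
        self.attempt_number = attempt_number


for n in range(1, 11):
    # Waits grow by a factor of 1.5 per check, never below 0.5s, capped at 10s.
    print(f"check #{n}: wait {backoff(_FakeRetryState(n)):.2f}s")
```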
9 changes: 8 additions & 1 deletion graphdatascience/query_runner/neo4j_query_runner.py
@@ -8,6 +8,7 @@
 
 import neo4j
 from pandas import DataFrame
+from tenacity import wait_exponential
 
 from ..call_parameters import CallParameters
 from ..error.endpoint_suggester import generate_suggestive_error_message
@@ -114,7 +115,13 @@ def __init__(
         self._server_version: Optional[ServerVersion] = None
         self._show_progress = show_progress
         self._progress_logger = QueryProgressLogger(
-            self.__run_cypher_simplified_for_query_progress_logger, self.server_version
+            run_cypher_func=self.__run_cypher_simplified_for_query_progress_logger,
+            server_version_func=self.server_version,
+            log_interval=wait_exponential(
+                max=10,
+                exp_base=1.5,
+                min=0.5,
+            ),
         )
         self._instance_description = instance_description
 
62 changes: 38 additions & 24 deletions graphdatascience/query_runner/progress/query_progress_logger.py
@@ -1,10 +1,13 @@
 import warnings
-from concurrent.futures import Future, ThreadPoolExecutor, wait
-from typing import Any, Callable, NoReturn, Optional
+from concurrent import futures
+from typing import Any, Callable, NoReturn, Optional, Union
 
 from pandas import DataFrame
+from tenacity import Retrying, wait
 from tqdm.auto import tqdm
 
+from graphdatascience.retry_utils.retry_utils import retry_until_future
+
 from ...server_version.server_version import ServerVersion
 from .progress_provider import ProgressProvider, TaskWithProgress
 from .query_progress_provider import CypherQueryFunction, QueryProgressProvider, ServerVersionFunction
@@ -18,16 +21,24 @@ def __init__(
         self,
         run_cypher_func: CypherQueryFunction,
         server_version_func: ServerVersionFunction,
-        polling_interval: float = 0.5,
+        log_interval: Union[float, wait.wait_base] = 0.5,
+        initial_wait_time: float = 0.5,
         progress_bar_options: dict[str, Any] = {},
     ):
         self._run_cypher_func = run_cypher_func
         self._server_version_func = server_version_func
         self._static_progress_provider = StaticProgressProvider()
         self._query_progress_provider = QueryProgressProvider(run_cypher_func, server_version_func)
-        self._polling_interval = polling_interval
         self._progress_bar_options = progress_bar_options
 
+        self._initial_wait_time = initial_wait_time
+        if isinstance(log_interval, float):
+            self._wait_base: wait.wait_base = wait.wait_fixed(log_interval)
+        elif isinstance(log_interval, wait.wait_base):
+            self._wait_base = log_interval
+        else:
+            raise ValueError("log_interval must be a float or an instance of wait_base")
+
     def run_with_progress_logging(
         self, runnable: DataFrameProducer, job_id: str, database: Optional[str] = None
     ) -> DataFrame:
@@ -38,9 +49,10 @@ def run_with_progress_logging(
         # Entries in the static progress store are already visible at this point.
         progress_provider = self._select_progress_provider(job_id)
 
-        with ThreadPoolExecutor() as executor:
+        with futures.ThreadPoolExecutor() as executor:
             future = executor.submit(runnable)
 
+            futures.wait([future], timeout=self._initial_wait_time)  # wait for progress task to be available

Review comment from the author:

I think we should also run the progress logger inside the thread pool:
check every 0.5s whether the main query is still running, and cancel the
progress thread once it is done. Right now we might delay execution by up to 10s.
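
One way that suggestion could look, as a hypothetical sketch (not part of this PR; `next_wait` and `fetch_and_log_progress` are illustrative stand-ins):

```python
# Hypothetical sketch of the suggestion above: sleep in short 0.5s slices so the
# main query's completion is noticed quickly, while still fetching progress only
# once the (possibly exponential) log interval has elapsed.
import time
from concurrent.futures import Future
from typing import Any, Callable


def responsive_log_loop(
    future: Future[Any],
    next_wait: Callable[[int], float],  # backoff interval per attempt, illustrative
    fetch_and_log_progress: Callable[[], None],  # stand-in for one progress poll
) -> None:
    attempt = 1
    next_log_at = time.monotonic() + next_wait(attempt)
    while not future.done():
        time.sleep(0.5)  # completion is detected within ~0.5s instead of up to 10s
        if time.monotonic() >= next_log_at:
            fetch_and_log_progress()
            attempt += 1
            next_log_at = time.monotonic() + next_wait(attempt)
```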

             self._log(future, job_id, progress_provider, database)
 
             if future.exception():
@@ -56,29 +68,33 @@ def _select_progress_provider(self, job_id: str) -> ProgressProvider:
         )
 
     def _log(
-        self, future: Future[Any], job_id: str, progress_provider: ProgressProvider, database: Optional[str] = None
+        self,
+        future: futures.Future[Any],
+        job_id: str,
+        progress_provider: ProgressProvider,
+        database: Optional[str] = None,
     ) -> None:
         pbar: Optional[tqdm[NoReturn]] = None
         warn_if_failure = True
 
-        while wait([future], timeout=self._polling_interval).not_done:
-            try:
-                task_with_progress = progress_provider.root_task_with_progress(job_id, database)
-                if pbar is None:
-                    pbar = self._init_pbar(task_with_progress)
-
-                self._update_pbar(pbar, task_with_progress)
-            except Exception as e:
-                # Do nothing if the procedure either:
-                # * has not started yet,
-                # * has already completed.
-                if f"No task with job id `{job_id}` was found" in str(e):
-                    continue
-                else:
+        for attempt in Retrying(wait=self._wait_base, retry=retry_until_future(future)):
+            with attempt:
+                try:
+                    task_with_progress = progress_provider.root_task_with_progress(job_id, database)
+                    if pbar is None:
+                        pbar = self._init_pbar(task_with_progress)
+
+                    self._update_pbar(pbar, task_with_progress)
+                except Exception as e:
+                    # Do nothing if the procedure either:
+                    # * has not started yet,
+                    # * has already completed.
+                    if f"No task with job id `{job_id}` was found" in str(e):
+                        continue
 
                     if warn_if_failure:
                         warnings.warn(f"Unable to get progress: {str(e)}", RuntimeWarning)
                         warn_if_failure = False
                     continue
 
         if pbar is not None:
             self._finish_pbar(future, pbar)
@@ -91,7 +107,6 @@ def _init_pbar(self, task_with_progress: TaskWithProgress) -> tqdm:  # type: ignore
             total=None,
             unit="",
             desc=root_task_name,
-            maxinterval=self._polling_interval,
             bar_format="{desc} [elapsed: {elapsed} {postfix}]",
             **self._progress_bar_options,
         )
@@ -100,7 +115,6 @@ def _init_pbar(self, task_with_progress: TaskWithProgress) -> tqdm:  # type: ignore
             total=100,
             unit="%",
             desc=root_task_name,
-            maxinterval=self._polling_interval,
             **self._progress_bar_options,
         )
 
@@ -118,7 +132,7 @@ def _update_pbar(self, pbar: tqdm, task_with_progress: TaskWithProgress) -> None:
         else:
             pbar.refresh()
 
-    def _finish_pbar(self, future: Future[Any], pbar: tqdm) -> None:  # type: ignore
+    def _finish_pbar(self, future: futures.Future[Any], pbar: tqdm) -> None:  # type: ignore
         if future.exception():
             pbar.set_postfix_str("status: FAILED", refresh=True)
             return
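Taken together, the new `log_interval` parameter accepts either a plain float (wrapped into `wait.wait_fixed`) or any tenacity wait strategy. A usage sketch based on this diff; the two callables are illustrative stand-ins:

```python
# Usage sketch for the API introduced above (my_run_cypher and my_server_version
# are stand-ins, not part of the library).
from typing import Optional

from pandas import DataFrame
from tenacity import wait_exponential

from graphdatascience import ServerVersion
from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger


def my_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
    return DataFrame()  # stand-in: would run the query against Neo4j


def my_server_version() -> ServerVersion:
    return ServerVersion(3, 0, 0)  # stand-in


# A float keeps the old fixed-interval behaviour:
fixed = QueryProgressLogger(my_run_cypher, my_server_version, log_interval=0.5)

# A tenacity wait strategy enables backoff, matching both runners in this PR:
backoff = QueryProgressLogger(
    run_cypher_func=my_run_cypher,
    server_version_func=my_server_version,
    log_interval=wait_exponential(min=0.5, exp_base=1.5, max=10),
)
```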
10 changes: 8 additions & 2 deletions graphdatascience/query_runner/session_query_runner.py
@@ -5,6 +5,7 @@
 from uuid import uuid4
 
 from pandas import DataFrame
+from tenacity import wait_exponential
 
 from graphdatascience.query_runner.graph_constructor import GraphConstructor
 from graphdatascience.query_runner.progress.query_progress_logger import QueryProgressLogger
@@ -41,8 +42,13 @@ def __init__(
         self._resolved_protocol_version = ProtocolVersionResolver(db_query_runner).resolve()
         self._show_progress = show_progress
         self._progress_logger = QueryProgressLogger(
-            lambda query, database: self._gds_query_runner.run_cypher(query=query, database=database),
-            self._gds_query_runner.server_version,
+            run_cypher_func=lambda query, database: self._gds_query_runner.run_cypher(query=query, database=database),
+            server_version_func=self._gds_query_runner.server_version,
+            log_interval=wait_exponential(
+                max=10,
+                exp_base=1.5,
+                min=0.5,
+            ),
         )
 
     def run_cypher(
14 changes: 13 additions & 1 deletion graphdatascience/retry_utils/retry_utils.py
@@ -1,7 +1,8 @@
 import logging
 import typing
+from concurrent.futures import Future
 
-from tenacity import RetryCallState
+from tenacity import RetryCallState, retry_base
 
 
 def before_log(
@@ -18,3 +19,14 @@ def log_it(retry_state: RetryCallState) -> None:
         )
 
     return log_it
+
+
+class retry_until_future(retry_base):
+    def __init__(
+        self,
+        future: Future[typing.Any],
+    ):
+        self._future = future
+
+    def __call__(self, retry_state: "RetryCallState") -> bool:
+        return not self._future.done()
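This class plugs into tenacity's retry machinery: tenacity also re-runs successful attempts while the retry predicate returns True, so a `Retrying` loop using it repeats until the future completes. A small demo sketch (not in the PR), mirroring how `_log` uses it above:

```python
# Demo sketch: poll until a background task finishes, mirroring _log above.
import time
from concurrent.futures import ThreadPoolExecutor

from tenacity import Retrying, wait_fixed

from graphdatascience.retry_utils.retry_utils import retry_until_future

with ThreadPoolExecutor() as executor:
    future = executor.submit(time.sleep, 2)  # stand-in for the main query
    for attempt in Retrying(wait=wait_fixed(0.5), retry=retry_until_future(future)):
        with attempt:
            print("main query still running, checking progress...")
    # Iteration ends once future.done() is True; no RetryError is raised.
```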
(unit tests for QueryProgressLogger)
@@ -5,6 +5,7 @@
 from typing import Optional
 
 from pandas import DataFrame
+from tenacity import wait
 
 from graphdatascience import ServerVersion
 from graphdatascience.query_runner.progress.progress_provider import TaskWithProgress
@@ -31,6 +32,35 @@ def fake_query() -> DataFrame:
     assert df["result"][0] == 42
 
 
+def test_log_interval() -> None:
+    def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
+        assert "CALL gds.listProgress('foo')" in query
+        assert database == "database"
+
+        return DataFrame([{"progress": "n/a", "taskName": "Test task", "status": "RUNNING"}])
+
+    def fake_query() -> DataFrame:
+        time.sleep(0.5)
+        return DataFrame([{"result": 42}])
+
+    with StringIO() as pbarOutputStream:
+        qpl = QueryProgressLogger(
+            fake_run_cypher,
+            lambda: ServerVersion(3, 0, 0),
+            log_interval=wait.wait_fixed(0.1),
+            initial_wait_time=0,
+            progress_bar_options={"file": pbarOutputStream, "mininterval": 0},
+        )
+        df = qpl.run_with_progress_logging(fake_query, "foo", "database")
+
+        running_output = pbarOutputStream.getvalue().split("\r")[:-1]
+
+        assert len(running_output) > 4
+        assert len(running_output) < 15
+
+        assert df["result"][0] == 42
+
+
 def test_skips_progress_logging_for_old_server_version() -> None:
     def fake_run_cypher(query: str, database: Optional[str] = None) -> DataFrame:
         print("Should not be called!")
2 changes: 1 addition & 1 deletion scripts/test_envs/gds_plugin_enterprise/compose.yml
@@ -1,6 +1,6 @@
 services:
   neo4j:
-    image: neo4j:enterprise
+    image: neo4j:5-enterprise
     volumes:
       - ${HOME}/.gds_license:/licenses/.gds_license
     environment: