diff --git a/notion_client/client.py b/notion_client/client.py index 4c2bb2f..8ba9475 100644 --- a/notion_client/client.py +++ b/notion_client/client.py @@ -1,10 +1,10 @@ """Synchronous and asynchronous clients for Notion's API.""" import json import logging -from abc import abstractclassmethod +from abc import abstractmethod from dataclasses import dataclass from types import TracebackType -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Generic, Optional, Type, Union import httpx from httpx import Request, Response @@ -24,7 +24,7 @@ is_api_error_code, ) from notion_client.logging import make_console_logger -from notion_client.typing import SyncAsync +from notion_client.typing import ClientType, ResponseType, SyncAsync @dataclass @@ -52,7 +52,7 @@ class ClientOptions: notion_version: str = "2022-06-28" -class BaseClient: +class BaseClient(Generic[ClientType]): def __init__( self, client: Union[httpx.Client, httpx.AsyncClient], @@ -71,12 +71,12 @@ def __init__( self._clients: List[Union[httpx.Client, httpx.AsyncClient]] = [] self.client = client - self.blocks = BlocksEndpoint(self) - self.databases = DatabasesEndpoint(self) - self.users = UsersEndpoint(self) - self.pages = PagesEndpoint(self) - self.search = SearchEndpoint(self) - self.comments = CommentsEndpoint(self) + self.blocks = BlocksEndpoint[ClientType](self) + self.databases = DatabasesEndpoint[ClientType](self) + self.users = UsersEndpoint[ClientType](self) + self.pages = PagesEndpoint[ClientType](self) + self.search = SearchEndpoint[ClientType](self) + self.comments = CommentsEndpoint[ClientType](self) @property def client(self) -> Union[httpx.Client, httpx.AsyncClient]: @@ -131,15 +131,16 @@ def _parse_response(self, response: Response) -> Any: return body - @abstractclassmethod + @abstractmethod def request( self, path: str, method: str, + cast_to: Type[ResponseType], query: Optional[Dict[Any, Any]] = None, body: Optional[Dict[Any, Any]] = None, auth: Optional[str] = None, - ) -> SyncAsync[Any]: + ) -> SyncAsync[ResponseType]: # noqa pass @@ -181,17 +182,18 @@ def request( self, path: str, method: str, + cast_to: Type[ResponseType], query: Optional[Dict[Any, Any]] = None, body: Optional[Dict[Any, Any]] = None, auth: Optional[str] = None, - ) -> Any: + ) -> ResponseType: """Send an HTTP request.""" request = self._build_request(method, path, query, body, auth) try: response = self.client.send(request) except httpx.TimeoutException: raise RequestTimeoutError() - return self._parse_response(response) + return cast_to(self._parse_response(response)) class AsyncClient(BaseClient): @@ -231,14 +233,15 @@ async def request( self, path: str, method: str, + cast_to: Type[ResponseType], query: Optional[Dict[Any, Any]] = None, body: Optional[Dict[Any, Any]] = None, auth: Optional[str] = None, - ) -> Any: + ) -> ResponseType: """Send an HTTP request asynchronously.""" request = self._build_request(method, path, query, body, auth) try: response = await self.client.send(request) except httpx.TimeoutException: raise RequestTimeoutError() - return self._parse_response(response) + return cast_to(self._parse_response(response)) diff --git a/notion_client/helpers.py b/notion_client/helpers.py index 7ade264..c93866f 100644 --- a/notion_client/helpers.py +++ b/notion_client/helpers.py @@ -1,10 +1,10 @@ """Utility functions for notion-sdk-py.""" -from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Generator, List +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Generator, Mapping, List from urllib.parse import urlparse from uuid import UUID -def pick(base: Dict[Any, Any], *keys: str) -> Dict[Any, Any]: +def pick(base: Mapping[Any, Any], *keys: str) -> Dict[Any, Any]: """Return a dict composed of key value pairs for keys passed as args.""" result = {} for key in keys: diff --git a/notion_client/typing.py b/notion_client/typing.py index 97ebc80..9b14c81 100644 --- a/notion_client/typing.py +++ b/notion_client/typing.py @@ -1,5 +1,11 @@ """Custom type definitions for notion-sdk-py.""" -from typing import Awaitable, TypeVar, Union +from typing import TYPE_CHECKING, Any, Awaitable, Mapping, TypeVar, Union + +if TYPE_CHECKING: # pragma: no cover + from notion_client.client import BaseClient T = TypeVar("T") SyncAsync = Union[T, Awaitable[T]] + +ClientType = TypeVar("ClientType", bound=BaseClient) +ResponseType = TypeVar("ResponseType", bound=Mapping[Any, Any]) \ No newline at end of file