diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7d81fbd3..eb8f6197 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -32,10 +32,11 @@ jobs: token: ${{ secrets.GITHUB_TOKEN }} args: -p psqlpy --all-features -- -W clippy::all -W clippy::pedantic pytest: - name: ${{matrix.job.os}}-${{matrix.py_version}} + name: ${{matrix.job.os}}-${{matrix.py_version}}-${{ matrix.postgres_version }} strategy: matrix: py_version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + postgres_version: ["14", "15", "16", "17"] job: - os: ubuntu-latest ssl_cmd: sudo apt-get update && sudo apt-get install libssl-dev openssl @@ -43,13 +44,14 @@ jobs: steps: - uses: actions/checkout@v1 - name: Setup Postgres - uses: ./.github/actions/setup_postgres/ + id: postgres + uses: ikalnytskyi/action-setup-postgres@v7 with: username: postgres password: postgres database: psqlpy_test - ssl_on: "on" - id: postgres + ssl: true + postgres-version: ${{ matrix.postgres_version }} - uses: actions-rs/toolchain@v1 with: toolchain: stable @@ -64,4 +66,6 @@ jobs: - name: Install tox run: pip install "tox-gh>=1.2,<2" - name: Run pytest + env: + POSTGRES_CERT_FILE: "${{ steps.postgres.outputs.certificate-path }}" run: tox -v -c tox.ini diff --git a/Cargo.lock b/Cargo.lock index fee82b45..df4dc951 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -881,7 +881,7 @@ checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" [[package]] name = "postgres-derive" version = "0.4.5" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "heck", "proc-macro2", @@ -892,7 +892,7 @@ dependencies = [ [[package]] name = "postgres-openssl" version = "0.5.0" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "openssl", "tokio", @@ -903,7 +903,7 @@ dependencies = [ [[package]] name = "postgres-protocol" version = "0.6.7" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "base64", "byteorder", @@ -920,7 +920,7 @@ dependencies = [ [[package]] name = "postgres-types" version = "0.2.7" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "array-init", "bytes", @@ -1540,7 +1540,7 @@ dependencies = [ [[package]] name = "tokio-postgres" version = "0.7.11" -source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#e4e1047e701318b31c61330e428ebd8ade7ed1cb" +source = "git+https://github.com/chandr-andr/rust-postgres.git?branch=psqlpy#5780895bfa8a0b9142df225b65bc6e59f7dbee61" dependencies = [ "async-trait", "byteorder", diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 30426e5f..a9bfc4d3 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -85,7 +85,10 @@ def number_database_records() -> int: @pytest.fixture def ssl_cert_file() -> str: - return os.environ.get("POSTGRES_CERT_FILE", "./root.crt") + return os.environ.get( + "POSTGRES_CERT_FILE", + "/home/runner/work/_temp/pgdata/server.crt", + ) @pytest.fixture diff --git a/python/tests/test_value_converter.py b/python/tests/test_value_converter.py index 34361b22..ce2f05ed 100644 --- a/python/tests/test_value_converter.py +++ b/python/tests/test_value_converter.py @@ -139,14 +139,18 @@ async def test_as_class( ), ("BOOL", True, True), ("INT2", SmallInt(12), 12), + ("INT2", 12, 12), ("INT4", Integer(121231231), 121231231), + ("INT4", 121231231, 121231231), ("INT8", BigInt(99999999999999999), 99999999999999999), - ("MONEY", BigInt(99999999999999999), 99999999999999999), + ("INT8", 99999999999999999, 99999999999999999), ("MONEY", Money(99999999999999999), 99999999999999999), + ("MONEY", 99999999999999999, 99999999999999999), ("NUMERIC(5, 2)", Decimal("120.12"), Decimal("120.12")), - ("FLOAT8", 32.12329864501953, 32.12329864501953), ("FLOAT4", Float32(32.12329864501953), 32.12329864501953), + ("FLOAT4", 32.12329864501953, 32.12329864501953), ("FLOAT8", Float64(32.12329864501953), 32.12329864501953), + ("FLOAT8", 32.12329864501953, 32.12329864501953), ("DATE", now_datetime.date(), now_datetime.date()), ("TIME", now_datetime.time(), now_datetime.time()), ("TIMESTAMP", now_datetime, now_datetime), @@ -216,448 +220,24 @@ async def test_as_class( Path(((1.7, 2.8), (3.3, 2.5), (9, 9), (1.7, 2.8))), ((1.7, 2.8), (3.3, 2.5), (9.0, 9.0), (1.7, 2.8)), ), - ("LINE", Line([-2, 1, 2]), (-2.0, 1.0, 2.0)), - ("LINE", Line([1, -2, 3]), (1.0, -2.0, 3.0)), - ("LSEG", LineSegment({(1, 2), (9, 9)}), [(1.0, 2.0), (9.0, 9.0)]), - ("LSEG", LineSegment(((1.7, 2.8), (9, 9))), [(1.7, 2.8), (9.0, 9.0)]), - ( - "CIRCLE", - Circle((1.7, 2.8, 3)), - ((1.7, 2.8), 3.0), - ), - ( - "CIRCLE", - Circle([1, 2.8, 3]), - ((1.0, 2.8), 3.0), - ), - ( - "INTERVAL", - datetime.timedelta(days=100, microseconds=100), - datetime.timedelta(days=100, microseconds=100), - ), - ( - "VARCHAR ARRAY", - ["Some String", "Some String"], - ["Some String", "Some String"], - ), - ( - "TEXT ARRAY", - [Text("Some String"), Text("Some String")], - ["Some String", "Some String"], - ), - ("BOOL ARRAY", [True, False], [True, False]), - ("BOOL ARRAY", [[True], [False]], [[True], [False]]), - ("INT2 ARRAY", [SmallInt(12), SmallInt(100)], [12, 100]), - ("INT2 ARRAY", [[SmallInt(12)], [SmallInt(100)]], [[12], [100]]), - ("INT4 ARRAY", [Integer(121231231), Integer(121231231)], [121231231, 121231231]), - ( - "INT4 ARRAY", - [[Integer(121231231)], [Integer(121231231)]], - [[121231231], [121231231]], - ), - ( - "INT8 ARRAY", - [BigInt(99999999999999999), BigInt(99999999999999999)], - [99999999999999999, 99999999999999999], - ), - ( - "INT8 ARRAY", - [[BigInt(99999999999999999)], [BigInt(99999999999999999)]], - [[99999999999999999], [99999999999999999]], - ), - ( - "MONEY ARRAY", - [Money(99999999999999999), Money(99999999999999999)], - [99999999999999999, 99999999999999999], - ), - ( - "MONEY ARRAY", - [[Money(99999999999999999)], [Money(99999999999999999)]], - [[99999999999999999], [99999999999999999]], - ), - ( - "NUMERIC(5, 2) ARRAY", - [Decimal("121.23"), Decimal("188.99")], - [Decimal("121.23"), Decimal("188.99")], - ), - ( - "NUMERIC(5, 2) ARRAY", - [[Decimal("121.23")], [Decimal("188.99")]], - [[Decimal("121.23")], [Decimal("188.99")]], - ), - ( - "FLOAT8 ARRAY", - [32.12329864501953, 32.12329864501953], - [32.12329864501953, 32.12329864501953], - ), - ( - "FLOAT8 ARRAY", - [[32.12329864501953], [32.12329864501953]], - [[32.12329864501953], [32.12329864501953]], - ), - ( - "DATE ARRAY", - [now_datetime.date(), now_datetime.date()], - [now_datetime.date(), now_datetime.date()], - ), - ( - "DATE ARRAY", - [[now_datetime.date()], [now_datetime.date()]], - [[now_datetime.date()], [now_datetime.date()]], - ), - ( - "TIME ARRAY", - [now_datetime.time(), now_datetime.time()], - [now_datetime.time(), now_datetime.time()], - ), - ( - "TIME ARRAY", - [[now_datetime.time()], [now_datetime.time()]], - [[now_datetime.time()], [now_datetime.time()]], - ), - ("TIMESTAMP ARRAY", [now_datetime, now_datetime], [now_datetime, now_datetime]), - ( - "TIMESTAMP ARRAY", - [[now_datetime], [now_datetime]], - [[now_datetime], [now_datetime]], - ), - ( - "TIMESTAMPTZ ARRAY", - [now_datetime_with_tz, now_datetime_with_tz], - [now_datetime_with_tz, now_datetime_with_tz], - ), - ( - "TIMESTAMPTZ ARRAY", - [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], - [now_datetime_with_tz, now_datetime_with_tz_in_asia_jakarta], - ), - ( - "TIMESTAMPTZ ARRAY", - [[now_datetime_with_tz], [now_datetime_with_tz]], - [[now_datetime_with_tz], [now_datetime_with_tz]], - ), - ( - "UUID ARRAY", - [uuid_, uuid_], - [str(uuid_), str(uuid_)], - ), - ( - "UUID ARRAY", - [[uuid_], [uuid_]], - [[str(uuid_)], [str(uuid_)]], - ), - ( - "INET ARRAY", - [IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")], - [IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")], - ), - ( - "INET ARRAY", - [[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]], - [[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]], - ), - ( - "JSONB ARRAY", - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ), - ( - "JSONB ARRAY", - [ - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ], - [ - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ], - ), - ( - "JSONB ARRAY", - [ - JSONB([{"array": "json"}, {"one more": "test"}]), - JSONB([{"array": "json"}, {"one more": "test"}]), - ], - [ - [{"array": "json"}, {"one more": "test"}], - [{"array": "json"}, {"one more": "test"}], - ], - ), - ( - "JSONB ARRAY", - [ - JSONB([[{"array": "json"}], [{"one more": "test"}]]), - JSONB([[{"array": "json"}], [{"one more": "test"}]]), - ], - [ - [[{"array": "json"}], [{"one more": "test"}]], - [[{"array": "json"}], [{"one more": "test"}]], - ], - ), - ( - "JSON ARRAY", - [ - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ), - ( - "JSON ARRAY", - [ - [ - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - ], - [ - JSON( - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ), - ], - ], - [ - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - [ - { - "test": ["something", 123, "here"], - "nested": ["JSON"], - }, - ], - ], - ), - ( - "JSON ARRAY", - [ - JSON([{"array": "json"}, {"one more": "test"}]), - JSON([{"array": "json"}, {"one more": "test"}]), - ], - [ - [{"array": "json"}, {"one more": "test"}], - [{"array": "json"}, {"one more": "test"}], - ], - ), - ( - "JSON ARRAY", - [ - JSON([[{"array": "json"}], [{"one more": "test"}]]), - JSON([[{"array": "json"}], [{"one more": "test"}]]), - ], - [ - [[{"array": "json"}], [{"one more": "test"}]], - [[{"array": "json"}], [{"one more": "test"}]], - ], - ), - ( - "POINT ARRAY", - [ - Point([1.5, 2]), - Point([2, 3]), - ], - [ - (1.5, 2.0), - (2.0, 3.0), - ], - ), - ( - "POINT ARRAY", - [ - [Point([1.5, 2])], - [Point([2, 3])], - ], - [ - [(1.5, 2.0)], - [(2.0, 3.0)], - ], - ), - ( - "BOX ARRAY", - [ - Box([3.5, 3, 9, 9]), - Box([8.5, 8, 9, 9]), - ], - [ - ((9.0, 9.0), (3.5, 3.0)), - ((9.0, 9.0), (8.5, 8.0)), - ], - ), - ( - "BOX ARRAY", - [ - [Box([3.5, 3, 9, 9])], - [Box([8.5, 8, 9, 9])], - ], - [ - [((9.0, 9.0), (3.5, 3.0))], - [((9.0, 9.0), (8.5, 8.0))], - ], - ), - ( - "PATH ARRAY", - [ - Path([(3.5, 3), (9, 9), (8, 8)]), - Path([(3.5, 3), (6, 6), (3.5, 3)]), - ], - [ - [(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)], - ((3.5, 3.0), (6.0, 6.0), (3.5, 3.0)), - ], - ), - ( - "PATH ARRAY", - [ - [Path([(3.5, 3), (9, 9), (8, 8)])], - [Path([(3.5, 3), (6, 6), (3.5, 3)])], - ], - [ - [[(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)]], - [((3.5, 3.0), (6.0, 6.0), (3.5, 3.0))], - ], - ), - ( - "LINE ARRAY", - [ - Line([-2, 1, 2]), - Line([1, -2, 3]), - ], - [ - (-2.0, 1.0, 2.0), - (1.0, -2.0, 3.0), - ], - ), - ( - "LINE ARRAY", - [ - [Line([-2, 1, 2])], - [Line([1, -2, 3])], - ], - [ - [(-2.0, 1.0, 2.0)], - [(1.0, -2.0, 3.0)], - ], - ), - ( - "LSEG ARRAY", - [ - LineSegment({(1, 2), (9, 9)}), - LineSegment([(5.6, 3.1), (4, 5)]), - ], - [ - [(1.0, 2.0), (9.0, 9.0)], - [(5.6, 3.1), (4.0, 5.0)], - ], - ), - ( - "LSEG ARRAY", - [ - [LineSegment({(1, 2), (9, 9)})], - [LineSegment([(5.6, 3.1), (4, 5)])], - ], - [ - [[(1.0, 2.0), (9.0, 9.0)]], - [[(5.6, 3.1), (4.0, 5.0)]], - ], - ), - ( - "CIRCLE ARRAY", - [ - Circle([1.7, 2.8, 3]), - Circle([5, 1.8, 10]), - ], - [ - ((1.7, 2.8), 3.0), - ((5.0, 1.8), 10.0), - ], - ), + ("LINE", Line([-2, 1, 2]), (-2.0, 1.0, 2.0)), + ("LINE", Line([1, -2, 3]), (1.0, -2.0, 3.0)), + ("LSEG", LineSegment({(1, 2), (9, 9)}), [(1.0, 2.0), (9.0, 9.0)]), + ("LSEG", LineSegment(((1.7, 2.8), (9, 9))), [(1.7, 2.8), (9.0, 9.0)]), ( - "CIRCLE ARRAY", - [ - [Circle([1.7, 2.8, 3])], - [Circle([5, 1.8, 10])], - ], - [ - [((1.7, 2.8), 3.0)], - [((5.0, 1.8), 10.0)], - ], + "CIRCLE", + Circle((1.7, 2.8, 3)), + ((1.7, 2.8), 3.0), ), ( - "INTERVAL ARRAY", - [ - datetime.timedelta(days=100, microseconds=100), - datetime.timedelta(days=100, microseconds=100), - ], - [ - datetime.timedelta(days=100, microseconds=100), - datetime.timedelta(days=100, microseconds=100), - ], + "CIRCLE", + Circle([1, 2.8, 3]), + ((1.0, 2.8), 3.0), + ), + ( + "INTERVAL", + datetime.timedelta(days=100, microseconds=100), + datetime.timedelta(days=100, microseconds=100), ), ], ) @@ -666,6 +246,37 @@ async def test_deserialization_simple_into_python( postgres_type: str, py_value: Any, expected_deserialized: Any, +) -> None: + """Test how types can cast from Python and to Python.""" + connection = await psql_pool.connection() + table_name = f"for_test{uuid.uuid4().hex}" + await connection.execute(f"DROP TABLE IF EXISTS {table_name}") + create_table_query = f""" + CREATE TABLE {table_name} (test_field {postgres_type}) + """ + insert_data_query = f""" + INSERT INTO {table_name} VALUES ($1) + """ + await connection.execute(querystring=create_table_query) + await connection.execute( + querystring=insert_data_query, + parameters=[py_value], + ) + + raw_result = await connection.execute( + querystring=f"SELECT test_field FROM {table_name}", + ) + + assert raw_result.result()[0]["test_field"] == expected_deserialized + + await connection.execute(f"DROP TABLE IF EXISTS {table_name}") + + +async def test_aboba( + psql_pool: ConnectionPool, + postgres_type: str = "INT2", + py_value: Any = 2, + expected_deserialized: Any = 2, ) -> None: """Test how types can cast from Python and to Python.""" connection = await psql_pool.connection() @@ -1124,32 +735,29 @@ async def test_empty_array( @pytest.mark.parametrize( ("postgres_type", "py_value", "expected_deserialized"), [ + ("VARCHAR ARRAY", [], []), ( "VARCHAR ARRAY", VarCharArray(["Some String", "Some String"]), ["Some String", "Some String"], ), - ( - "VARCHAR ARRAY", - VarCharArray([]), - [], - ), - ( - "TEXT ARRAY", - TextArray([]), - [], - ), + ("VARCHAR ARRAY", VarCharArray([]), []), + ("TEXT ARRAY", [], []), + ("TEXT ARRAY", TextArray([]), []), ( "TEXT ARRAY", TextArray([Text("Some String"), Text("Some String")]), ["Some String", "Some String"], ), + ("BOOL ARRAY", [], []), ("BOOL ARRAY", BoolArray([]), []), ("BOOL ARRAY", BoolArray([True, False]), [True, False]), ("BOOL ARRAY", BoolArray([[True], [False]]), [[True], [False]]), + ("INT2 ARRAY", [], []), ("INT2 ARRAY", Int16Array([]), []), ("INT2 ARRAY", Int16Array([SmallInt(12), SmallInt(100)]), [12, 100]), ("INT2 ARRAY", Int16Array([[SmallInt(12)], [SmallInt(100)]]), [[12], [100]]), + ("INT4 ARRAY", [], []), ( "INT4 ARRAY", Int32Array([Integer(121231231), Integer(121231231)]), @@ -1160,6 +768,7 @@ async def test_empty_array( Int32Array([[Integer(121231231)], [Integer(121231231)]]), [[121231231], [121231231]], ), + ("INT8 ARRAY", [], []), ( "INT8 ARRAY", Int64Array([BigInt(99999999999999999), BigInt(99999999999999999)]), @@ -1170,16 +779,13 @@ async def test_empty_array( Int64Array([[BigInt(99999999999999999)], [BigInt(99999999999999999)]]), [[99999999999999999], [99999999999999999]], ), + ("MONEY ARRAY", [], []), ( "MONEY ARRAY", MoneyArray([Money(99999999999999999), Money(99999999999999999)]), [99999999999999999, 99999999999999999], ), - ( - "MONEY ARRAY", - MoneyArray([[Money(99999999999999999)], [Money(99999999999999999)]]), - [[99999999999999999], [99999999999999999]], - ), + ("NUMERIC(5, 2) ARRAY", [], []), ( "NUMERIC(5, 2) ARRAY", NumericArray([Decimal("121.23"), Decimal("188.99")]), @@ -1190,6 +796,13 @@ async def test_empty_array( NumericArray([[Decimal("121.23")], [Decimal("188.99")]]), [[Decimal("121.23")], [Decimal("188.99")]], ), + ("FLOAT4 ARRAY", [], []), + ( + "FLOAT4 ARRAY", + [32.12329864501953, 32.12329864501953], + [32.12329864501953, 32.12329864501953], + ), + ("FLOAT8 ARRAY", [], []), ( "FLOAT8 ARRAY", Float64Array([32.12329864501953, 32.12329864501953]), @@ -1200,6 +813,7 @@ async def test_empty_array( Float64Array([[32.12329864501953], [32.12329864501953]]), [[32.12329864501953], [32.12329864501953]], ), + ("DATE ARRAY", [], []), ( "DATE ARRAY", DateArray([now_datetime.date(), now_datetime.date()]), @@ -1210,6 +824,7 @@ async def test_empty_array( DateArray([[now_datetime.date()], [now_datetime.date()]]), [[now_datetime.date()], [now_datetime.date()]], ), + ("TIME ARRAY", [], []), ( "TIME ARRAY", TimeArray([now_datetime.time(), now_datetime.time()]), @@ -1220,6 +835,7 @@ async def test_empty_array( TimeArray([[now_datetime.time()], [now_datetime.time()]]), [[now_datetime.time()], [now_datetime.time()]], ), + ("TIMESTAMP ARRAY", [], []), ( "TIMESTAMP ARRAY", DateTimeArray([now_datetime, now_datetime]), @@ -1230,6 +846,7 @@ async def test_empty_array( DateTimeArray([[now_datetime], [now_datetime]]), [[now_datetime], [now_datetime]], ), + ("TIMESTAMPTZ ARRAY", [], []), ( "TIMESTAMPTZ ARRAY", DateTimeTZArray([now_datetime_with_tz, now_datetime_with_tz]), @@ -1240,16 +857,13 @@ async def test_empty_array( DateTimeTZArray([[now_datetime_with_tz], [now_datetime_with_tz]]), [[now_datetime_with_tz], [now_datetime_with_tz]], ), - ( - "UUID ARRAY", - UUIDArray([uuid_, uuid_]), - [str(uuid_), str(uuid_)], - ), + ("UUID ARRAY", [], []), ( "UUID ARRAY", UUIDArray([[uuid_], [uuid_]]), [[str(uuid_)], [str(uuid_)]], ), + ("INET ARRAY", [], []), ( "INET ARRAY", IpAddressArray([IPv4Address("192.0.0.1"), IPv4Address("192.0.0.1")]), @@ -1260,6 +874,30 @@ async def test_empty_array( IpAddressArray([[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]]), [[IPv4Address("192.0.0.1")], [IPv4Address("192.0.0.1")]], ), + ("JSONB ARRAY", [], []), + ( + "JSONB ARRAY", + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), ( "JSONB ARRAY", JSONBArray( @@ -1344,6 +982,55 @@ async def test_empty_array( [[{"array": "json"}], [{"one more": "test"}]], ], ), + ("JSON ARRAY", [], []), + ( + "JSON ARRAY", + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), + ( + "JSON ARRAY", + JSONArray( + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), + [ + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + { + "test": ["something", 123, "here"], + "nested": ["JSON"], + }, + ], + ), ( "JSON ARRAY", JSONArray( @@ -1436,6 +1123,17 @@ async def test_empty_array( [[{"array": "json"}], [{"one more": "test"}]], ], ), + ( + "POINT ARRAY", + [ + Point([1.5, 2]), + Point([2, 3]), + ], + [ + (1.5, 2.0), + (2.0, 3.0), + ], + ), ( "POINT ARRAY", PointArray( @@ -1449,6 +1147,17 @@ async def test_empty_array( (2.0, 3.0), ], ), + ( + "POINT ARRAY", + [ + [Point([1.5, 2])], + [Point([2, 3])], + ], + [ + [(1.5, 2.0)], + [(2.0, 3.0)], + ], + ), ( "POINT ARRAY", PointArray( @@ -1462,6 +1171,18 @@ async def test_empty_array( [(2.0, 3.0)], ], ), + ("BOX ARRAY", [], []), + ( + "BOX ARRAY", + [ + Box([3.5, 3, 9, 9]), + Box([8.5, 8, 9, 9]), + ], + [ + ((9.0, 9.0), (3.5, 3.0)), + ((9.0, 9.0), (8.5, 8.0)), + ], + ), ( "BOX ARRAY", BoxArray( @@ -1488,6 +1209,18 @@ async def test_empty_array( [((9.0, 9.0), (8.5, 8.0))], ], ), + ("PATH ARRAY", [], []), + ( + "PATH ARRAY", + [ + Path([(3.5, 3), (9, 9), (8, 8)]), + Path([(3.5, 3), (6, 6), (3.5, 3)]), + ], + [ + [(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)], + ((3.5, 3.0), (6.0, 6.0), (3.5, 3.0)), + ], + ), ( "PATH ARRAY", PathArray( @@ -1501,6 +1234,17 @@ async def test_empty_array( ((3.5, 3.0), (6.0, 6.0), (3.5, 3.0)), ], ), + ( + "PATH ARRAY", + [ + [Path([(3.5, 3), (9, 9), (8, 8)])], + [Path([(3.5, 3), (6, 6), (3.5, 3)])], + ], + [ + [[(3.5, 3.0), (9.0, 9.0), (8.0, 8.0)]], + [((3.5, 3.0), (6.0, 6.0), (3.5, 3.0))], + ], + ), ( "PATH ARRAY", PathArray( @@ -1514,6 +1258,18 @@ async def test_empty_array( [((3.5, 3.0), (6.0, 6.0), (3.5, 3.0))], ], ), + ("LINE ARRAY", [], []), + ( + "LINE ARRAY", + [ + Line([-2, 1, 2]), + Line([1, -2, 3]), + ], + [ + (-2.0, 1.0, 2.0), + (1.0, -2.0, 3.0), + ], + ), ( "LINE ARRAY", LineArray( @@ -1527,6 +1283,17 @@ async def test_empty_array( (1.0, -2.0, 3.0), ], ), + ( + "LINE ARRAY", + [ + [Line([-2, 1, 2])], + [Line([1, -2, 3])], + ], + [ + [(-2.0, 1.0, 2.0)], + [(1.0, -2.0, 3.0)], + ], + ), ( "LINE ARRAY", LineArray( @@ -1540,6 +1307,18 @@ async def test_empty_array( [(1.0, -2.0, 3.0)], ], ), + ("LSEG ARRAY", [], []), + ( + "LSEG ARRAY", + [ + LineSegment({(1, 2), (9, 9)}), + LineSegment([(5.6, 3.1), (4, 5)]), + ], + [ + [(1.0, 2.0), (9.0, 9.0)], + [(5.6, 3.1), (4.0, 5.0)], + ], + ), ( "LSEG ARRAY", LsegArray( @@ -1553,6 +1332,17 @@ async def test_empty_array( [(5.6, 3.1), (4.0, 5.0)], ], ), + ( + "LSEG ARRAY", + [ + [LineSegment({(1, 2), (9, 9)})], + [LineSegment([(5.6, 3.1), (4, 5)])], + ], + [ + [[(1.0, 2.0), (9.0, 9.0)]], + [[(5.6, 3.1), (4.0, 5.0)]], + ], + ), ( "LSEG ARRAY", LsegArray( @@ -1566,6 +1356,18 @@ async def test_empty_array( [[(5.6, 3.1), (4.0, 5.0)]], ], ), + ("CIRCLE ARRAY", [], []), + ( + "CIRCLE ARRAY", + [ + Circle([1.7, 2.8, 3]), + Circle([5, 1.8, 10]), + ], + [ + ((1.7, 2.8), 3.0), + ((5.0, 1.8), 10.0), + ], + ), ( "CIRCLE ARRAY", CircleArray( @@ -1592,6 +1394,18 @@ async def test_empty_array( [((5.0, 1.8), 10.0)], ], ), + ("INTERVAL ARRAY", [], []), + ( + "INTERVAL ARRAY", + [ + [datetime.timedelta(days=100, microseconds=100)], + [datetime.timedelta(days=100, microseconds=100)], + ], + [ + [datetime.timedelta(days=100, microseconds=100)], + [datetime.timedelta(days=100, microseconds=100)], + ], + ), ( "INTERVAL ARRAY", IntervalArray( diff --git a/src/driver/common_options.rs b/src/driver/common_options.rs index aebc5837..a76d37dd 100644 --- a/src/driver/common_options.rs +++ b/src/driver/common_options.rs @@ -64,7 +64,7 @@ impl TargetSessionAttrs { } #[pyclass(eq, eq_int)] -#[derive(Clone, Copy, PartialEq)] +#[derive(Clone, Copy, PartialEq, Debug)] pub enum SslMode { /// Do not use TLS. Disable, diff --git a/src/driver/connection.rs b/src/driver/connection.rs index 3c0595bb..d38b71f9 100644 --- a/src/driver/connection.rs +++ b/src/driver/connection.rs @@ -6,7 +6,7 @@ use std::{collections::HashSet, net::IpAddr, sync::Arc}; use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, format_helpers::quote_ident, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, runtime::tokio_runtime, @@ -25,6 +25,7 @@ pub struct Connection { db_client: Option>, db_pool: Option, pg_config: Arc, + prepare: bool, } impl Connection { @@ -33,11 +34,13 @@ impl Connection { db_client: Option>, db_pool: Option, pg_config: Arc, + prepare: bool, ) -> Self { Connection { db_client, db_pool, pg_config, + prepare, } } @@ -54,7 +57,7 @@ impl Connection { impl Default for Connection { fn default() -> Self { - Connection::new(None, None, Arc::new(Config::default())) + Connection::new(None, None, Arc::new(Config::default()), true) } } @@ -137,10 +140,14 @@ impl Connection { return self.pg_config.get_options(); } - async fn __aenter__<'a>(self_: Py) -> RustPSQLDriverPyResult> { - let (db_client, db_pool) = pyo3::Python::with_gil(|gil| { + async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { + let (db_client, db_pool, prepare) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); - (self_.db_client.clone(), self_.db_pool.clone()) + ( + self_.db_client.clone(), + self_.db_pool.clone(), + self_.prepare, + ) }); if db_client.is_some() { @@ -155,7 +162,8 @@ impl Connection { .await??; pyo3::Python::with_gil(|gil| { let mut self_ = self_.borrow_mut(gil); - self_.db_client = Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))); + self_.db_client = + Some(Arc::new(PsqlpyConnection::PoolConn(db_connection, prepare))); }); return Ok(self_); } @@ -169,7 +177,7 @@ impl Connection { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { ( exception.is_none(gil), @@ -205,11 +213,12 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { - return db_client.execute(querystring, parameters, prepared).await; + let res = db_client.execute(querystring, parameters, prepared).await; + return res; } Err(RustPSQLDriverError::ConnectionClosedError) @@ -227,10 +236,7 @@ impl Connection { /// May return Err Result if: /// 1) Connection is closed. /// 2) Cannot execute querystring. - pub async fn execute_batch( - self_: pyo3::Py, - querystring: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn execute_batch(self_: pyo3::Py, querystring: String) -> PSQLPyResult<()> { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -256,7 +262,7 @@ impl Connection { querystring: String, parameters: Option>>, prepared: Option, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -282,7 +288,7 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -312,7 +318,7 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -339,7 +345,7 @@ impl Connection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); if let Some(db_client) = db_client { @@ -365,7 +371,7 @@ impl Connection { read_variant: Option, deferrable: Option, synchronous_commit: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { if let Some(db_client) = &self.db_client { return Ok(Transaction::new( db_client.clone(), @@ -401,7 +407,7 @@ impl Connection { fetch_number: Option, scroll: Option, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { if let Some(db_client) = &self.db_client { return Ok(Cursor::new( db_client.clone(), @@ -446,7 +452,7 @@ impl Connection { table_name: String, columns: Option>, schema_name: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); let mut table_name = quote_ident(&table_name); if let Some(schema_name) = schema_name { diff --git a/src/driver/connection_pool.rs b/src/driver/connection_pool.rs index 24780a6a..aa897012 100644 --- a/src/driver/connection_pool.rs +++ b/src/driver/connection_pool.rs @@ -1,10 +1,11 @@ use crate::runtime::tokio_runtime; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; +use postgres_types::Type; use pyo3::{pyclass, pyfunction, pymethods, Py, PyAny}; use std::sync::Arc; use tokio_postgres::Config; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::{ common_options::{ConnRecyclingMethod, LoadBalanceHosts, SslMode, TargetSessionAttrs}, @@ -14,6 +15,23 @@ use super::{ utils::{build_connection_config, build_manager, build_tls}, }; +#[derive(Debug, Clone)] +pub struct ConnectionPoolConf { + pub ca_file: Option, + pub ssl_mode: Option, + pub prepare: bool, +} + +impl ConnectionPoolConf { + fn new(ca_file: Option, ssl_mode: Option, prepare: bool) -> Self { + Self { + ca_file, + ssl_mode, + prepare, + } + } +} + /// Make new connection pool. /// /// # Errors @@ -75,7 +93,7 @@ pub fn connect( ca_file: Option, max_db_pool_size: Option, conn_recycling_method: Option, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if let Some(max_db_pool_size) = max_db_pool_size { if max_db_pool_size < 2 { return Err(RustPSQLDriverError::ConnectionPoolConfigurationError( @@ -134,12 +152,9 @@ pub fn connect( let pool = db_pool_builder.build()?; - Ok(ConnectionPool { - pool: pool, - pg_config: Arc::new(pg_config), - ca_file: ca_file, - ssl_mode: ssl_mode, - }) + Ok(ConnectionPool::build( + pool, pg_config, ca_file, ssl_mode, None, + )) } #[pyclass] @@ -205,8 +220,7 @@ impl ConnectionPoolStatus { pub struct ConnectionPool { pool: Pool, pg_config: Arc, - ca_file: Option, - ssl_mode: Option, + pool_conf: ConnectionPoolConf, } impl ConnectionPool { @@ -216,14 +230,18 @@ impl ConnectionPool { pg_config: Config, ca_file: Option, ssl_mode: Option, + prepare: Option, ) -> Self { ConnectionPool { pool: pool, pg_config: Arc::new(pg_config), - ca_file: ca_file, - ssl_mode: ssl_mode, + pool_conf: ConnectionPoolConf::new(ca_file, ssl_mode, prepare.unwrap_or(true)), } } + + pub fn remove_prepared_stmt(&mut self, query: &str, types: &[Type]) { + self.pool.manager().statement_caches.remove(query, types); + } } #[pymethods] @@ -289,7 +307,7 @@ impl ConnectionPool { conn_recycling_method: Option, ssl_mode: Option, ca_file: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { connect( dsn, username, @@ -360,32 +378,42 @@ impl ConnectionPool { #[must_use] pub fn acquire(&self) -> Connection { - Connection::new(None, Some(self.pool.clone()), self.pg_config.clone()) + Connection::new( + None, + Some(self.pool.clone()), + self.pg_config.clone(), + self.pool_conf.prepare, + ) } #[must_use] #[allow(clippy::needless_pass_by_value)] pub fn listener(self_: pyo3::Py) -> Listener { - let (pg_config, ca_file, ssl_mode) = pyo3::Python::with_gil(|gil| { + let (pg_config, pool_conf) = pyo3::Python::with_gil(|gil| { let b_gil = self_.borrow(gil); - ( - b_gil.pg_config.clone(), - b_gil.ca_file.clone(), - b_gil.ssl_mode, - ) + (b_gil.pg_config.clone(), b_gil.pool_conf.clone()) }); - Listener::new(pg_config, ca_file, ssl_mode) + Listener::new( + pg_config, + pool_conf.ca_file, + pool_conf.ssl_mode, + pool_conf.prepare, + ) } /// Return new single connection. /// /// # Errors /// May return Err Result if cannot get new connection from the pool. - pub async fn connection(self_: pyo3::Py) -> RustPSQLDriverPyResult { - let (db_pool, pg_config) = pyo3::Python::with_gil(|gil| { + pub async fn connection(self_: pyo3::Py) -> PSQLPyResult { + let (db_pool, pg_config, pool_conf) = pyo3::Python::with_gil(|gil| { let slf = self_.borrow(gil); - (slf.pool.clone(), slf.pg_config.clone()) + ( + slf.pool.clone(), + slf.pg_config.clone(), + slf.pool_conf.clone(), + ) }); let db_connection = tokio_runtime() .spawn(async move { @@ -394,9 +422,13 @@ impl ConnectionPool { .await??; Ok(Connection::new( - Some(Arc::new(PsqlpyConnection::PoolConn(db_connection))), + Some(Arc::new(PsqlpyConnection::PoolConn( + db_connection, + pool_conf.prepare, + ))), None, pg_config, + pool_conf.prepare, )) } diff --git a/src/driver/connection_pool_builder.rs b/src/driver/connection_pool_builder.rs index e0610942..0cd7432b 100644 --- a/src/driver/connection_pool_builder.rs +++ b/src/driver/connection_pool_builder.rs @@ -3,7 +3,7 @@ use std::{net::IpAddr, time::Duration}; use deadpool_postgres::{Manager, ManagerConfig, Pool, RecyclingMethod}; use pyo3::{pyclass, pymethods, Py, Python}; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::{ common_options, @@ -18,6 +18,7 @@ pub struct ConnectionPoolBuilder { conn_recycling_method: Option, ca_file: Option, ssl_mode: Option, + prepare: Option, } #[pymethods] @@ -31,6 +32,7 @@ impl ConnectionPoolBuilder { conn_recycling_method: None, ca_file: None, ssl_mode: None, + prepare: None, } } @@ -38,7 +40,7 @@ impl ConnectionPoolBuilder { /// /// # Errors /// May return error if cannot build new connection pool. - fn build(&self) -> RustPSQLDriverPyResult { + fn build(&self) -> PSQLPyResult { let mgr_config: ManagerConfig; if let Some(conn_recycling_method) = self.conn_recycling_method.as_ref() { mgr_config = ManagerConfig { @@ -68,6 +70,7 @@ impl ConnectionPoolBuilder { self.config.clone(), self.ca_file.clone(), self.ssl_mode, + self.prepare, )) } @@ -84,7 +87,7 @@ impl ConnectionPoolBuilder { /// /// # Error /// If size more than 2. - fn max_pool_size(self_: Py, pool_size: usize) -> RustPSQLDriverPyResult> { + fn max_pool_size(self_: Py, pool_size: usize) -> PSQLPyResult> { if pool_size < 2 { return Err(RustPSQLDriverError::ConnectionPoolConfigurationError( "Maximum database pool size must be more than 1".into(), diff --git a/src/driver/cursor.rs b/src/driver/cursor.rs index f391d1c1..1f435ef5 100644 --- a/src/driver/cursor.rs +++ b/src/driver/cursor.rs @@ -6,7 +6,7 @@ use pyo3::{ use tokio_postgres::{config::Host, Config}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, query_result::PSQLDriverPyQueryResult, runtime::rustdriver_future, }; @@ -23,9 +23,9 @@ trait CursorObjectTrait { querystring: &str, prepared: &Option, parameters: &Option>, - ) -> RustPSQLDriverPyResult<()>; + ) -> PSQLPyResult<()>; - async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()>; + async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> PSQLPyResult<()>; } impl CursorObjectTrait for PsqlpyConnection { @@ -43,7 +43,7 @@ impl CursorObjectTrait for PsqlpyConnection { querystring: &str, prepared: &Option, parameters: &Option>, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let mut cursor_init_query = format!("DECLARE {cursor_name}"); if let Some(scroll) = scroll { if *scroll { @@ -70,7 +70,7 @@ impl CursorObjectTrait for PsqlpyConnection { /// /// # Errors /// May return Err Result if cannot execute querystring. - async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> RustPSQLDriverPyResult<()> { + async fn cursor_close(&self, closed: &bool, cursor_name: &str) -> PSQLPyResult<()> { if *closed { return Err(RustPSQLDriverError::CursorCloseError( "Cursor is already closed".into(), @@ -232,7 +232,7 @@ impl Cursor { slf } - async fn __aenter__<'a>(slf: Py) -> RustPSQLDriverPyResult> { + async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { let (db_transaction, cursor_name, scroll, querystring, prepared, parameters) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); @@ -265,7 +265,7 @@ impl Cursor { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (db_transaction, closed, cursor_name, is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { let self_ = slf.borrow(gil); @@ -307,7 +307,7 @@ impl Cursor { /// we didn't find any solution how to implement it without /// # Errors /// May return Err Result if can't execute querystring. - fn __anext__(&self) -> RustPSQLDriverPyResult> { + fn __anext__(&self) -> PSQLPyResult> { let db_transaction = self.db_transaction.clone(); let fetch_number = self.fetch_number; let cursor_name = self.cursor_name.clone(); @@ -343,7 +343,7 @@ impl Cursor { /// # Errors /// May return Err Result /// if cannot execute querystring for cursor declaration. - pub async fn start(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn start(&mut self) -> PSQLPyResult<()> { let db_transaction_arc = self.db_transaction.clone(); if let Some(db_transaction) = db_transaction_arc { @@ -370,7 +370,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn close(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn close(&mut self) -> PSQLPyResult<()> { let db_transaction_arc = self.db_transaction.clone(); if let Some(db_transaction) = db_transaction_arc { @@ -396,7 +396,7 @@ impl Cursor { pub async fn fetch<'a>( slf: Py, fetch_number: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, inner_fetch_number, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); ( @@ -437,7 +437,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_next<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_next<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -464,7 +464,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_prior<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_prior<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -491,7 +491,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_first<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_first<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -518,7 +518,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_last<'a>(slf: Py) -> RustPSQLDriverPyResult { + pub async fn fetch_last<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -548,7 +548,7 @@ impl Cursor { pub async fn fetch_absolute<'a>( slf: Py, absolute_number: i64, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -582,7 +582,7 @@ impl Cursor { pub async fn fetch_relative<'a>( slf: Py, relative_number: i64, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -613,9 +613,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_forward_all<'a>( - slf: Py, - ) -> RustPSQLDriverPyResult { + pub async fn fetch_forward_all<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -649,7 +647,7 @@ impl Cursor { pub async fn fetch_backward<'a>( slf: Py, backward_count: i64, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) @@ -680,9 +678,7 @@ impl Cursor { /// /// # Errors /// May return Err Result if cannot execute query. - pub async fn fetch_backward_all<'a>( - slf: Py, - ) -> RustPSQLDriverPyResult { + pub async fn fetch_backward_all<'a>(slf: Py) -> PSQLPyResult { let (db_transaction, cursor_name) = Python::with_gil(|gil| { let self_ = slf.borrow(gil); (self_.db_transaction.clone(), self_.cursor_name.clone()) diff --git a/src/driver/inner_connection.rs b/src/driver/inner_connection.rs index ae060baa..d8acc4d8 100644 --- a/src/driver/inner_connection.rs +++ b/src/driver/inner_connection.rs @@ -1,23 +1,20 @@ use bytes::Buf; use deadpool_postgres::Object; -use postgres_types::ToSql; +use postgres_types::{ToSql, Type}; use pyo3::{Py, PyAny, Python}; use std::vec; use tokio_postgres::{Client, CopyInSink, Row, Statement, ToStatement}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, - value_converter::{ - consts::QueryParameter, - funcs::{from_python::convert_parameters_and_qs, to_python::postgres_to_py}, - models::dto::PythonDTO, - }, + statement::{statement::PsqlpyStatement, statement_builder::StatementBuilder}, + value_converter::to_python::postgres_to_py, }; #[allow(clippy::module_name_repetitions)] pub enum PsqlpyConnection { - PoolConn(Object), + PoolConn(Object, bool), SingleConn(Client), } @@ -26,14 +23,39 @@ impl PsqlpyConnection { /// /// # Errors /// May return Err if cannot prepare statement. - pub async fn prepare_cached(&self, query: &str) -> RustPSQLDriverPyResult { + pub async fn prepare(&self, query: &str, prepared: bool) -> PSQLPyResult { match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.prepare_cached(query).await?), + PsqlpyConnection::PoolConn(pconn, _) => { + if prepared { + return Ok(pconn.prepare_cached(query).await?); + } else { + let prepared = pconn.prepare(query).await?; + self.drop_prepared(&prepared).await?; + return Ok(prepared); + } + } PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.prepare(query).await?), } } - /// Prepare cached statement. + /// Delete prepared statement. + /// + /// # Errors + /// May return Err if cannot prepare statement. + pub async fn drop_prepared(&self, stmt: &Statement) -> PSQLPyResult<()> { + let deallocate_query = format!("DEALLOCATE PREPARE {}", stmt.name()); + match self { + PsqlpyConnection::PoolConn(pconn, _) => { + let res = Ok(pconn.batch_execute(&deallocate_query).await?); + res + } + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.batch_execute(&deallocate_query).await?) + } + } + } + + /// Execute statement with parameters. /// /// # Errors /// May return Err if cannot execute statement. @@ -41,25 +63,46 @@ impl PsqlpyConnection { &self, statement: &T, params: &[&(dyn ToSql + Sync)], - ) -> RustPSQLDriverPyResult> + ) -> PSQLPyResult> where T: ?Sized + ToStatement, { match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.query(statement, params).await?), + PsqlpyConnection::PoolConn(pconn, _) => { + return Ok(pconn.query(statement, params).await?) + } PsqlpyConnection::SingleConn(sconn) => { return Ok(sconn.query(statement, params).await?) } } } - /// Prepare cached statement. + /// Execute statement with parameters. + /// + /// # Errors + /// May return Err if cannot execute statement. + pub async fn query_typed( + &self, + statement: &str, + params: &[(&(dyn ToSql + Sync), Type)], + ) -> PSQLPyResult> { + match self { + PsqlpyConnection::PoolConn(pconn, _) => { + return Ok(pconn.query_typed(statement, params).await?) + } + PsqlpyConnection::SingleConn(sconn) => { + return Ok(sconn.query_typed(statement, params).await?) + } + } + } + + /// Batch execute statement. /// /// # Errors /// May return Err if cannot execute statement. - pub async fn batch_execute(&self, query: &str) -> RustPSQLDriverPyResult<()> { + pub async fn batch_execute(&self, query: &str) -> PSQLPyResult<()> { match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.batch_execute(query).await?), + PsqlpyConnection::PoolConn(pconn, _) => return Ok(pconn.batch_execute(query).await?), PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.batch_execute(query).await?), } } @@ -67,17 +110,32 @@ impl PsqlpyConnection { /// Prepare cached statement. /// /// # Errors + /// May return Err if cannot execute copy data. + pub async fn copy_in(&self, statement: &T) -> PSQLPyResult> + where + T: ?Sized + ToStatement, + U: Buf + 'static + Send, + { + match self { + PsqlpyConnection::PoolConn(pconn, _) => return Ok(pconn.copy_in(statement).await?), + PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), + } + } + + /// Executes a statement which returns a single row, returning it. + /// + /// # Errors /// May return Err if cannot execute statement. pub async fn query_one( &self, statement: &T, params: &[&(dyn ToSql + Sync)], - ) -> RustPSQLDriverPyResult + ) -> PSQLPyResult where T: ?Sized + ToStatement, { match self { - PsqlpyConnection::PoolConn(pconn) => { + PsqlpyConnection::PoolConn(pconn, _) => { return Ok(pconn.query_one(statement, params).await?) } PsqlpyConnection::SingleConn(sconn) => { @@ -91,38 +149,31 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { - let prepared = prepared.unwrap_or(true); - - let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; + ) -> PSQLPyResult { + let statement = StatementBuilder::new(querystring, parameters, self, prepared) + .build() + .await?; - let boxed_params = ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); let result = if prepared { self.query( - &self.prepare_cached(&qs).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, - boxed_params, + &self + .prepare(&statement.raw_query(), true) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement, error - {err}" + )) + })?, + &statement.params(), ) .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? } else { - self.query(&qs, boxed_params).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + self.query(statement.raw_query(), &statement.params()) + .await + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? }; Ok(PSQLDriverPyQueryResult::new(result)) @@ -133,38 +184,26 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { - let prepared = prepared.unwrap_or(true); - - let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; + ) -> PSQLPyResult { + let statement = StatementBuilder::new(querystring, parameters, self, prepared) + .build() + .await?; - let boxed_params = ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); - let result = if prepared { - self.query( - &self.prepare_cached(&qs).await.map_err(|err| { + let result = match prepared { + true => self + .query(statement.statement_query()?, &statement.params()) + .await + .map_err(|err| { RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement, error - {err}" )) })?, - boxed_params, - ) - .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? - } else { - self.query(&qs, boxed_params).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + false => self + .query_typed(statement.raw_query(), &statement.params_typed()) + .await + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))?, }; Ok(PSQLDriverPyQueryResult::new(result)) @@ -172,41 +211,40 @@ impl PsqlpyConnection { pub async fn execute_many( &self, - mut querystring: String, + querystring: String, parameters: Option>>, prepared: Option, - ) -> RustPSQLDriverPyResult<()> { - let prepared = prepared.unwrap_or(true); - - let mut params: Vec> = vec![]; + ) -> PSQLPyResult<()> { + let mut statements: Vec = vec![]; if let Some(parameters) = parameters { for vec_of_py_any in parameters { // TODO: Fix multiple qs creation - let (qs, parsed_params) = - convert_parameters_and_qs(querystring.clone(), Some(vec_of_py_any))?; - querystring = qs; - params.push(parsed_params); + let statement = + StatementBuilder::new(querystring.clone(), Some(vec_of_py_any), self, prepared) + .build() + .await?; + + statements.push(statement); } } - for param in params { - let boxed_params = ¶m - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); + for statement in statements { let querystring_result = if prepared { - let prepared_stmt = &self.prepare_cached(&querystring).await; + let prepared_stmt = &self.prepare(&statement.raw_query(), true).await; if let Err(error) = prepared_stmt { return Err(RustPSQLDriverError::ConnectionExecuteError(format!( "Cannot prepare statement in execute_many, operation rolled back {error}", ))); } - self.query(&self.prepare_cached(&querystring).await?, boxed_params) - .await + self.query( + &self.prepare(&statement.raw_query(), true).await?, + &statement.params(), + ) + .await } else { - self.query(&querystring, boxed_params).await + self.query(statement.raw_query(), &statement.params()).await }; if let Err(error) = querystring_result { @@ -224,38 +262,31 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { - let prepared = prepared.unwrap_or(true); - - let (qs, params) = convert_parameters_and_qs(querystring, parameters)?; + ) -> PSQLPyResult { + let statement = StatementBuilder::new(querystring, parameters, self, prepared) + .build() + .await?; - let boxed_params = ¶ms - .iter() - .map(|param| param as &QueryParameter) - .collect::>() - .into_boxed_slice(); + let prepared = prepared.unwrap_or(true); let result = if prepared { self.query_one( - &self.prepare_cached(&qs).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot prepare statement, error - {err}" - )) - })?, - boxed_params, + &self + .prepare(&statement.raw_query(), true) + .await + .map_err(|err| { + RustPSQLDriverError::ConnectionExecuteError(format!( + "Cannot prepare statement, error - {err}" + )) + })?, + &statement.params(), ) .await - .map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? } else { - self.query_one(&qs, boxed_params).await.map_err(|err| { - RustPSQLDriverError::ConnectionExecuteError(format!( - "Cannot execute statement, error - {err}" - )) - })? + self.query_one(statement.raw_query(), &statement.params()) + .await + .map_err(|err| RustPSQLDriverError::ConnectionExecuteError(format!("{err}")))? }; return Ok(result); @@ -266,7 +297,7 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let result = self .fetch_row_raw(querystring, parameters, prepared) .await?; @@ -279,7 +310,7 @@ impl PsqlpyConnection { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let result = self .fetch_row_raw(querystring, parameters, prepared) .await?; @@ -289,19 +320,4 @@ impl PsqlpyConnection { None => Ok(gil.None()), }); } - - /// Prepare cached statement. - /// - /// # Errors - /// May return Err if cannot execute copy data. - pub async fn copy_in(&self, statement: &T) -> RustPSQLDriverPyResult> - where - T: ?Sized + ToStatement, - U: Buf + 'static + Send, - { - match self { - PsqlpyConnection::PoolConn(pconn) => return Ok(pconn.copy_in(statement).await?), - PsqlpyConnection::SingleConn(sconn) => return Ok(sconn.copy_in(statement).await?), - } - } } diff --git a/src/driver/listener/core.rs b/src/driver/listener/core.rs index 83aa9b3e..4a9580af 100644 --- a/src/driver/listener/core.rs +++ b/src/driver/listener/core.rs @@ -18,7 +18,7 @@ use crate::{ inner_connection::PsqlpyConnection, utils::{build_tls, is_coroutine_function, ConfiguredTLS}, }, - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, runtime::{rustdriver_future, tokio_runtime}, }; @@ -42,14 +42,19 @@ pub struct Listener { impl Listener { #[must_use] - pub fn new(pg_config: Arc, ca_file: Option, ssl_mode: Option) -> Self { + pub fn new( + pg_config: Arc, + ca_file: Option, + ssl_mode: Option, + prepare: bool, + ) -> Self { Listener { pg_config: pg_config.clone(), ca_file, ssl_mode, channel_callbacks: Arc::default(), listen_abort_handler: Option::default(), - connection: Connection::new(None, None, pg_config.clone()), + connection: Connection::new(None, None, pg_config.clone(), prepare), receiver: Option::default(), listen_query: Arc::default(), is_listened: Arc::new(RwLock::new(false)), @@ -89,7 +94,7 @@ impl Listener { } #[allow(clippy::unused_async)] - async fn __aenter__<'a>(slf: Py) -> RustPSQLDriverPyResult> { + async fn __aenter__<'a>(slf: Py) -> PSQLPyResult> { Ok(slf) } @@ -99,7 +104,7 @@ impl Listener { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (client, is_exception_none, py_err) = pyo3::Python::with_gil(|gil| { let self_ = slf.borrow(gil); ( @@ -126,7 +131,7 @@ impl Listener { Err(RustPSQLDriverError::ListenerClosedError) } - fn __anext__(&self) -> RustPSQLDriverPyResult>> { + fn __anext__(&self) -> PSQLPyResult>> { let Some(client) = self.connection.db_client() else { return Err(RustPSQLDriverError::ListenerStartError( "Listener doesn't have underlying client, please call startup".into(), @@ -167,7 +172,7 @@ impl Listener { } #[getter] - fn connection(&self) -> RustPSQLDriverPyResult { + fn connection(&self) -> PSQLPyResult { if !self.is_started { return Err(RustPSQLDriverError::ListenerStartError( "Listener isn't started up".into(), @@ -177,7 +182,7 @@ impl Listener { Ok(self.connection.clone()) } - async fn startup(&mut self) -> RustPSQLDriverPyResult<()> { + async fn startup(&mut self) -> PSQLPyResult<()> { if self.is_started { return Err(RustPSQLDriverError::ListenerStartError( "Listener is already started".into(), @@ -222,6 +227,7 @@ impl Listener { Some(Arc::new(PsqlpyConnection::SingleConn(client))), None, self.pg_config.clone(), + false, ); self.is_started = true; @@ -238,11 +244,7 @@ impl Listener { } #[pyo3(signature = (channel, callback))] - async fn add_callback( - &mut self, - channel: String, - callback: Py, - ) -> RustPSQLDriverPyResult<()> { + async fn add_callback(&mut self, channel: String, callback: Py) -> PSQLPyResult<()> { if !is_coroutine_function(callback.clone())? { return Err(RustPSQLDriverError::ListenerCallbackError); } @@ -279,7 +281,7 @@ impl Listener { self.update_listen_query().await; } - fn listen(&mut self) -> RustPSQLDriverPyResult<()> { + fn listen(&mut self) -> PSQLPyResult<()> { let Some(client) = self.connection.db_client() else { return Err(RustPSQLDriverError::ListenerStartError( "Cannot start listening, underlying connection doesn't exist".into(), @@ -343,7 +345,7 @@ async fn dispatch_callback( listener_callback: &ListenerCallback, listener_notification: ListenerNotification, connection: Connection, -) -> RustPSQLDriverPyResult<()> { +) -> PSQLPyResult<()> { listener_callback .call(listener_notification.clone(), connection) .await?; @@ -355,7 +357,7 @@ async fn execute_listen( is_listened: &Arc>, listen_query: &Arc>, client: &Arc, -) -> RustPSQLDriverPyResult<()> { +) -> PSQLPyResult<()> { let mut write_is_listened = is_listened.write().await; if !write_is_listened.eq(&true) { @@ -371,7 +373,7 @@ async fn execute_listen( Ok(()) } -fn process_message(message: Option) -> RustPSQLDriverPyResult { +fn process_message(message: Option) -> PSQLPyResult { let Some(async_message) = message else { return Err(RustPSQLDriverError::ListenerError("Wow".into())); }; diff --git a/src/driver/listener/structs.rs b/src/driver/listener/structs.rs index 4d53a408..6236547e 100644 --- a/src/driver/listener/structs.rs +++ b/src/driver/listener/structs.rs @@ -6,7 +6,7 @@ use tokio_postgres::Notification; use crate::{ driver::connection::Connection, - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, runtime::tokio_runtime, }; @@ -126,7 +126,7 @@ impl ListenerCallback { &self, lister_notification: ListenerNotification, connection: Connection, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (callback, task_locals) = Python::with_gil(|py| (self.callback.clone(), self.task_locals.clone_ref(py))); diff --git a/src/driver/transaction.rs b/src/driver/transaction.rs index 2fa38ba5..60f054b7 100644 --- a/src/driver/transaction.rs +++ b/src/driver/transaction.rs @@ -9,7 +9,7 @@ use pyo3::{ use tokio_postgres::{binary_copy::BinaryCopyInWriter, config::Host, Config}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, format_helpers::quote_ident, query_result::{PSQLDriverPyQueryResult, PSQLDriverSinglePyQueryResult}, }; @@ -29,9 +29,9 @@ pub trait TransactionObjectTrait { read_variant: Option, defferable: Option, synchronous_commit: Option, - ) -> impl std::future::Future> + Send; - fn commit(&self) -> impl std::future::Future> + Send; - fn rollback(&self) -> impl std::future::Future> + Send; + ) -> impl std::future::Future> + Send; + fn commit(&self) -> impl std::future::Future> + Send; + fn rollback(&self) -> impl std::future::Future> + Send; } impl TransactionObjectTrait for PsqlpyConnection { @@ -41,7 +41,7 @@ impl TransactionObjectTrait for PsqlpyConnection { read_variant: Option, deferrable: Option, synchronous_commit: Option, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let mut querystring = "START TRANSACTION".to_string(); if let Some(level) = isolation_level { @@ -84,7 +84,7 @@ impl TransactionObjectTrait for PsqlpyConnection { Ok(()) } - async fn commit(&self) -> RustPSQLDriverPyResult<()> { + async fn commit(&self) -> PSQLPyResult<()> { self.batch_execute("COMMIT;").await.map_err(|err| { RustPSQLDriverError::TransactionCommitError(format!( "Cannot execute COMMIT statement, error - {err}" @@ -92,7 +92,7 @@ impl TransactionObjectTrait for PsqlpyConnection { })?; Ok(()) } - async fn rollback(&self) -> RustPSQLDriverPyResult<()> { + async fn rollback(&self) -> PSQLPyResult<()> { self.batch_execute("ROLLBACK;").await.map_err(|err| { RustPSQLDriverError::TransactionRollbackError(format!( "Cannot execute ROLLBACK statement, error - {err}" @@ -144,7 +144,7 @@ impl Transaction { } } - fn check_is_transaction_ready(&self) -> RustPSQLDriverPyResult<()> { + fn check_is_transaction_ready(&self) -> PSQLPyResult<()> { if !self.is_started { return Err(RustPSQLDriverError::TransactionBeginError( "Transaction is not started, please call begin() on transaction".into(), @@ -242,7 +242,7 @@ impl Transaction { self_ } - async fn __aenter__<'a>(self_: Py) -> RustPSQLDriverPyResult> { + async fn __aenter__<'a>(self_: Py) -> PSQLPyResult> { let ( is_started, is_done, @@ -302,7 +302,7 @@ impl Transaction { _exception_type: Py, exception: Py, _traceback: Py, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (is_transaction_ready, is_exception_none, py_err, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -345,7 +345,7 @@ impl Transaction { /// 1) Transaction is not started /// 2) Transaction is done /// 3) Cannot execute `COMMIT` command - pub async fn commit(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn commit(&mut self) -> PSQLPyResult<()> { self.check_is_transaction_ready()?; if let Some(db_client) = &self.db_client { db_client.commit().await?; @@ -366,7 +366,7 @@ impl Transaction { /// 1) Transaction is not started /// 2) Transaction is done /// 3) Can not execute ROLLBACK command - pub async fn rollback(&mut self) -> RustPSQLDriverPyResult<()> { + pub async fn rollback(&mut self) -> PSQLPyResult<()> { self.check_is_transaction_ready()?; if let Some(db_client) = &self.db_client { db_client.rollback().await?; @@ -394,7 +394,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -419,7 +419,7 @@ impl Transaction { /// May return Err Result if: /// 1) Transaction is closed. /// 2) Cannot execute querystring. - pub async fn execute_batch(self_: Py, querystring: String) -> RustPSQLDriverPyResult<()> { + pub async fn execute_batch(self_: Py, querystring: String) -> PSQLPyResult<()> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -448,7 +448,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -481,7 +481,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -511,7 +511,7 @@ impl Transaction { querystring: String, parameters: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -539,7 +539,7 @@ impl Transaction { querystring: String, parameters: Option>>, prepared: Option, - ) -> RustPSQLDriverPyResult<()> { + ) -> PSQLPyResult<()> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); (self_.check_is_transaction_ready(), self_.db_client.clone()) @@ -564,7 +564,7 @@ impl Transaction { /// 1) Transaction is already started. /// 2) Transaction is done. /// 3) Cannot execute `BEGIN` command. - pub async fn begin(self_: Py) -> RustPSQLDriverPyResult<()> { + pub async fn begin(self_: Py) -> PSQLPyResult<()> { let ( is_started, is_done, @@ -629,10 +629,7 @@ impl Transaction { /// 2) Transaction is done /// 3) Specified savepoint name is exists /// 4) Can not execute SAVEPOINT command - pub async fn create_savepoint( - self_: Py, - savepoint_name: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn create_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { let (is_transaction_ready, is_savepoint_name_exists, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -673,10 +670,7 @@ impl Transaction { /// 2) Transaction is done /// 3) Specified savepoint name doesn't exists /// 4) Can not execute RELEASE SAVEPOINT command - pub async fn release_savepoint( - self_: Py, - savepoint_name: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn release_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { let (is_transaction_ready, is_savepoint_name_exists, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -717,10 +711,7 @@ impl Transaction { /// 2) Transaction is done /// 3) Specified savepoint name doesn't exist /// 4) Can not execute ROLLBACK TO SAVEPOINT command - pub async fn rollback_savepoint( - self_: Py, - savepoint_name: String, - ) -> RustPSQLDriverPyResult<()> { + pub async fn rollback_savepoint(self_: Py, savepoint_name: String) -> PSQLPyResult<()> { let (is_transaction_ready, is_savepoint_name_exists, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -765,7 +756,7 @@ impl Transaction { self_: Py, queries: Option>, prepared: Option, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let (is_transaction_ready, db_client) = pyo3::Python::with_gil(|gil| { let self_ = self_.borrow(gil); @@ -827,7 +818,7 @@ impl Transaction { fetch_number: Option, scroll: Option, prepared: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { if let Some(db_client) = &self.db_client { return Ok(Cursor::new( db_client.clone(), @@ -857,7 +848,7 @@ impl Transaction { table_name: String, columns: Option>, schema_name: Option, - ) -> RustPSQLDriverPyResult { + ) -> PSQLPyResult { let db_client = pyo3::Python::with_gil(|gil| self_.borrow(gil).db_client.clone()); let mut table_name = quote_ident(&table_name); if let Some(schema_name) = schema_name { diff --git a/src/driver/utils.rs b/src/driver/utils.rs index 3d0d59e3..15ca4123 100644 --- a/src/driver/utils.rs +++ b/src/driver/utils.rs @@ -6,7 +6,7 @@ use postgres_openssl::MakeTlsConnector; use pyo3::{types::PyAnyMethods, Py, PyAny, Python}; use tokio_postgres::{Config, NoTls}; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; use super::common_options::{self, LoadBalanceHosts, SslMode, TargetSessionAttrs}; @@ -40,7 +40,7 @@ pub fn build_connection_config( keepalives_retries: Option, load_balance_hosts: Option, ssl_mode: Option, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if tcp_user_timeout_nanosec.is_some() && tcp_user_timeout_sec.is_none() { return Err(RustPSQLDriverError::ConnectionPoolConfigurationError( "tcp_user_timeout_nanosec must be used with tcp_user_timeout_sec param.".into(), @@ -182,7 +182,7 @@ pub enum ConfiguredTLS { pub fn build_tls( ca_file: &Option, ssl_mode: &Option, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if let Some(ca_file) = ca_file { let mut builder = SslConnector::builder(SslMethod::tls())?; builder.set_ca_file(ca_file)?; @@ -224,7 +224,7 @@ pub fn build_manager( /// May return Err Result if cannot /// 1) import inspect /// 2) extract boolean -pub fn is_coroutine_function(function: Py) -> RustPSQLDriverPyResult { +pub fn is_coroutine_function(function: Py) -> PSQLPyResult { let is_coroutine_function: bool = Python::with_gil(|py| { let inspect = py.import("inspect")?; diff --git a/src/exceptions/rust_errors.rs b/src/exceptions/rust_errors.rs index 48af50cb..94b89fa0 100644 --- a/src/exceptions/rust_errors.rs +++ b/src/exceptions/rust_errors.rs @@ -14,7 +14,7 @@ use super::python_errors::{ TransactionRollbackError, TransactionSavepointError, UUIDValueConvertError, }; -pub type RustPSQLDriverPyResult = Result; +pub type PSQLPyResult = Result; #[derive(Error, Debug)] pub enum RustPSQLDriverError { @@ -29,9 +29,9 @@ pub enum RustPSQLDriverError { ConnectionPoolExecuteError(String), // Connection Errors - #[error("Connection error: {0}.")] + #[error("{0}")] BaseConnectionError(String), - #[error("Connection execute error: {0}.")] + #[error("{0}")] ConnectionExecuteError(String), #[error("Underlying connection is returned to the pool")] ConnectionClosedError, @@ -76,12 +76,12 @@ pub enum RustPSQLDriverError { #[error("Can't convert value from driver to python type: {0}")] RustToPyValueConversionError(String), - #[error("Can't convert value from python to rust type: {0}")] + #[error("{0}")] PyToRustValueConversionError(String), #[error("Python exception: {0}.")] RustPyError(#[from] pyo3::PyErr), - #[error("Database engine exception: {0}.")] + #[error("{0}")] RustDriverError(#[from] deadpool_postgres::tokio_postgres::Error), #[error("Database engine pool exception: {0}")] RustConnectionPoolError(#[from] deadpool_postgres::PoolError), diff --git a/src/extra_types.rs b/src/extra_types.rs index 1e8d22b4..b3411eae 100644 --- a/src/extra_types.rs +++ b/src/extra_types.rs @@ -2,6 +2,7 @@ use std::str::FromStr; use geo_types::{Line as RustLineSegment, LineString, Point as RustPoint, Rect as RustRect}; use macaddr::{MacAddr6 as RustMacAddr6, MacAddr8 as RustMacAddr8}; +use postgres_types::Type; use pyo3::{ pyclass, pymethods, types::{PyModule, PyModuleMethods}, @@ -10,16 +11,20 @@ use pyo3::{ use serde_json::Value; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, value_converter::{ additional_types::{Circle as RustCircle, Line as RustLine}, - funcs::from_python::{ - build_flat_geo_coords, build_geo_coords, py_sequence_into_postgres_array, - }, - models::{dto::PythonDTO, serde_value::build_serde_value}, + dto::enums::PythonDTO, + from_python::{build_flat_geo_coords, build_geo_coords, py_sequence_into_postgres_array}, + models::serde_value::build_serde_value, }, }; +pub struct PythonArray; +pub struct PythonDecimal; +pub struct PythonUUID; +pub struct PythonEnum; + #[pyclass] #[derive(Clone)] pub struct PgVector(Vec); @@ -34,7 +39,7 @@ impl PgVector { impl PgVector { #[must_use] - pub fn inner_value(self) -> Vec { + pub fn inner(self) -> Vec { self.0 } } @@ -49,7 +54,7 @@ macro_rules! build_python_type { impl $st_name { #[must_use] - pub fn retrieve_value(&self) -> $rust_type { + pub fn inner(&self) -> $rust_type { self.inner_value } } @@ -135,7 +140,12 @@ macro_rules! build_json_py_type { impl $st_name { #[must_use] - pub fn inner(&self) -> &$rust_type { + pub fn inner(&self) -> $rust_type { + self.inner.clone() + } + + #[must_use] + pub fn inner_ref(&self) -> &$rust_type { &self.inner } } @@ -144,7 +154,7 @@ macro_rules! build_json_py_type { impl $st_name { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_class(value: Py) -> RustPSQLDriverPyResult { + pub fn new_class(value: &Bound<'_, PyAny>) -> PSQLPyResult { Ok(Self { inner: build_serde_value(value)?, }) @@ -180,7 +190,7 @@ macro_rules! build_macaddr_type { impl $st_name { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_class(value: &str) -> RustPSQLDriverPyResult { + pub fn new_class(value: &str) -> PSQLPyResult { Ok(Self { inner: <$rust_type>::from_str(value)?, }) @@ -223,7 +233,7 @@ macro_rules! build_geo_type { impl $st_name { #[must_use] - pub fn retrieve_value(&self) -> $rust_type { + pub fn inner(&self) -> $rust_type { self.inner.clone() } } @@ -241,7 +251,7 @@ build_geo_type!(Circle, RustCircle); impl Point { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_point(value: Py) -> RustPSQLDriverPyResult { + pub fn new_point(value: Py) -> PSQLPyResult { let point_coords = build_geo_coords(value, Some(1))?; Ok(Self { @@ -254,7 +264,7 @@ impl Point { impl Box { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_box(value: Py) -> RustPSQLDriverPyResult { + pub fn new_box(value: Py) -> PSQLPyResult { let box_coords = build_geo_coords(value, Some(2))?; Ok(Self { @@ -267,7 +277,7 @@ impl Box { impl Path { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_path(value: Py) -> RustPSQLDriverPyResult { + pub fn new_path(value: Py) -> PSQLPyResult { let path_coords = build_geo_coords(value, None)?; Ok(Self { @@ -280,7 +290,7 @@ impl Path { impl Line { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_line(value: Py) -> RustPSQLDriverPyResult { + pub fn new_line(value: Py) -> PSQLPyResult { let line_coords = build_flat_geo_coords(value, Some(3))?; Ok(Self { @@ -293,7 +303,7 @@ impl Line { impl LineSegment { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_line_segment(value: Py) -> RustPSQLDriverPyResult { + pub fn new_line_segment(value: Py) -> PSQLPyResult { let line_segment_coords = build_geo_coords(value, Some(2))?; Ok(Self { @@ -306,7 +316,7 @@ impl LineSegment { impl Circle { #[new] #[allow(clippy::missing_errors_doc)] - pub fn new_circle(value: Py) -> RustPSQLDriverPyResult { + pub fn new_circle(value: Py) -> PSQLPyResult { let circle_coords = build_flat_geo_coords(value, Some(3))?; Ok(Self { inner: RustCircle::new(circle_coords[0], circle_coords[1], circle_coords[2]), @@ -315,7 +325,7 @@ impl Circle { } macro_rules! build_array_type { - ($st_name:ident, $kind:path) => { + ($st_name:ident, $kind:path, $elem_kind:path) => { #[pyclass] #[derive(Clone)] pub struct $st_name { @@ -337,11 +347,15 @@ macro_rules! build_array_type { self.inner.clone() } + pub fn element_type() -> Type { + $elem_kind + } + /// Convert incoming sequence from python to internal `PythonDTO`. /// /// # Errors /// May return Err Result if cannot convert sequence to array. - pub fn _convert_to_python_dto(&self) -> RustPSQLDriverPyResult { + pub fn _convert_to_python_dto(&self, elem_type: &Type) -> PSQLPyResult { return Python::with_gil(|gil| { let binding = &self.inner; let bound_inner = Ok::<&pyo3::Bound<'_, pyo3::PyAny>, RustPSQLDriverError>( @@ -349,6 +363,7 @@ macro_rules! build_array_type { )?; Ok::($kind(py_sequence_into_postgres_array( bound_inner, + elem_type, )?)) }); } @@ -356,33 +371,37 @@ macro_rules! build_array_type { }; } -build_array_type!(BoolArray, PythonDTO::PyBoolArray); -build_array_type!(UUIDArray, PythonDTO::PyUuidArray); -build_array_type!(VarCharArray, PythonDTO::PyVarCharArray); -build_array_type!(TextArray, PythonDTO::PyTextArray); -build_array_type!(Int16Array, PythonDTO::PyInt16Array); -build_array_type!(Int32Array, PythonDTO::PyInt32Array); -build_array_type!(Int64Array, PythonDTO::PyInt64Array); -build_array_type!(Float32Array, PythonDTO::PyFloat32Array); -build_array_type!(Float64Array, PythonDTO::PyFloat64Array); -build_array_type!(MoneyArray, PythonDTO::PyMoneyArray); -build_array_type!(IpAddressArray, PythonDTO::PyIpAddressArray); -build_array_type!(JSONBArray, PythonDTO::PyJSONBArray); -build_array_type!(JSONArray, PythonDTO::PyJSONArray); -build_array_type!(DateArray, PythonDTO::PyDateArray); -build_array_type!(TimeArray, PythonDTO::PyTimeArray); -build_array_type!(DateTimeArray, PythonDTO::PyDateTimeArray); -build_array_type!(DateTimeTZArray, PythonDTO::PyDateTimeTZArray); -build_array_type!(MacAddr6Array, PythonDTO::PyMacAddr6Array); -build_array_type!(MacAddr8Array, PythonDTO::PyMacAddr8Array); -build_array_type!(NumericArray, PythonDTO::PyNumericArray); -build_array_type!(PointArray, PythonDTO::PyPointArray); -build_array_type!(BoxArray, PythonDTO::PyBoxArray); -build_array_type!(PathArray, PythonDTO::PyPathArray); -build_array_type!(LineArray, PythonDTO::PyLineArray); -build_array_type!(LsegArray, PythonDTO::PyLsegArray); -build_array_type!(CircleArray, PythonDTO::PyCircleArray); -build_array_type!(IntervalArray, PythonDTO::PyIntervalArray); +build_array_type!(BoolArray, PythonDTO::PyBoolArray, Type::BOOL); +build_array_type!(UUIDArray, PythonDTO::PyUuidArray, Type::UUID); +build_array_type!(VarCharArray, PythonDTO::PyVarCharArray, Type::VARCHAR); +build_array_type!(TextArray, PythonDTO::PyTextArray, Type::TEXT); +build_array_type!(Int16Array, PythonDTO::PyInt16Array, Type::INT2); +build_array_type!(Int32Array, PythonDTO::PyInt32Array, Type::INT4); +build_array_type!(Int64Array, PythonDTO::PyInt64Array, Type::INT8); +build_array_type!(Float32Array, PythonDTO::PyFloat32Array, Type::FLOAT4); +build_array_type!(Float64Array, PythonDTO::PyFloat64Array, Type::FLOAT8); +build_array_type!(MoneyArray, PythonDTO::PyMoneyArray, Type::MONEY); +build_array_type!(IpAddressArray, PythonDTO::PyIpAddressArray, Type::INET); +build_array_type!(JSONBArray, PythonDTO::PyJSONBArray, Type::JSONB); +build_array_type!(JSONArray, PythonDTO::PyJSONArray, Type::JSON); +build_array_type!(DateArray, PythonDTO::PyDateArray, Type::DATE); +build_array_type!(TimeArray, PythonDTO::PyTimeArray, Type::TIME); +build_array_type!(DateTimeArray, PythonDTO::PyDateTimeArray, Type::TIMESTAMP); +build_array_type!( + DateTimeTZArray, + PythonDTO::PyDateTimeTZArray, + Type::TIMESTAMPTZ +); +build_array_type!(MacAddr6Array, PythonDTO::PyMacAddr6Array, Type::MACADDR); +build_array_type!(MacAddr8Array, PythonDTO::PyMacAddr8Array, Type::MACADDR8); +build_array_type!(NumericArray, PythonDTO::PyNumericArray, Type::NUMERIC); +build_array_type!(PointArray, PythonDTO::PyPointArray, Type::POINT); +build_array_type!(BoxArray, PythonDTO::PyBoxArray, Type::BOX); +build_array_type!(PathArray, PythonDTO::PyPathArray, Type::PATH); +build_array_type!(LineArray, PythonDTO::PyLineArray, Type::LINE); +build_array_type!(LsegArray, PythonDTO::PyLsegArray, Type::LSEG); +build_array_type!(CircleArray, PythonDTO::PyCircleArray, Type::CIRCLE); +build_array_type!(IntervalArray, PythonDTO::PyIntervalArray, Type::INTERVAL); #[allow(clippy::module_name_repetitions)] #[allow(clippy::missing_errors_doc)] diff --git a/src/lib.rs b/src/lib.rs index e0e1fe11..6be59c75 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ pub mod format_helpers; pub mod query_result; pub mod row_factories; pub mod runtime; +pub mod statement; pub mod value_converter; use common::add_module; diff --git a/src/query_result.rs b/src/query_result.rs index da393f89..cda02a8b 100644 --- a/src/query_result.rs +++ b/src/query_result.rs @@ -1,10 +1,7 @@ use pyo3::{prelude::*, pyclass, pymethods, types::PyDict, Py, PyAny, Python, ToPyObject}; use tokio_postgres::Row; -use crate::{ - exceptions::rust_errors::RustPSQLDriverPyResult, - value_converter::funcs::to_python::postgres_to_py, -}; +use crate::{exceptions::rust_errors::PSQLPyResult, value_converter::to_python::postgres_to_py}; /// Convert postgres `Row` into Python Dict. /// @@ -18,7 +15,7 @@ fn row_to_dict<'a>( py: Python<'a>, postgres_row: &'a Row, custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let python_dict = PyDict::new(py); for (column_idx, column) in postgres_row.columns().iter().enumerate() { let python_type = postgres_to_py(py, postgres_row, column, column_idx, custom_decoders)?; @@ -30,7 +27,7 @@ fn row_to_dict<'a>( #[pyclass(name = "QueryResult")] #[allow(clippy::module_name_repetitions)] pub struct PSQLDriverPyQueryResult { - inner: Vec, + pub inner: Vec, } impl PSQLDriverPyQueryResult { @@ -65,7 +62,7 @@ impl PSQLDriverPyQueryResult { &self, py: Python<'_>, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let mut result: Vec> = vec![]; for row in &self.inner { result.push(row_to_dict(py, row, &custom_decoders)?); @@ -80,11 +77,7 @@ impl PSQLDriverPyQueryResult { /// May return Err Result if can not convert /// postgres type to python or create new Python class. #[allow(clippy::needless_pass_by_value)] - pub fn as_class<'a>( - &'a self, - py: Python<'a>, - as_class: Py, - ) -> RustPSQLDriverPyResult> { + pub fn as_class<'a>(&'a self, py: Python<'a>, as_class: Py) -> PSQLPyResult> { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &None)?; @@ -108,7 +101,7 @@ impl PSQLDriverPyQueryResult { py: Python<'a>, row_factory: Py, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let mut res: Vec> = vec![]; for row in &self.inner { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, row, &custom_decoders)?; @@ -155,7 +148,7 @@ impl PSQLDriverSinglePyQueryResult { &self, py: Python<'_>, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { Ok(row_to_dict(py, &self.inner, &custom_decoders)?.to_object(py)) } @@ -167,11 +160,7 @@ impl PSQLDriverSinglePyQueryResult { /// postgres type to python, can not create new Python class /// or there are no results. #[allow(clippy::needless_pass_by_value)] - pub fn as_class<'a>( - &'a self, - py: Python<'a>, - as_class: Py, - ) -> RustPSQLDriverPyResult> { + pub fn as_class<'a>(&'a self, py: Python<'a>, as_class: Py) -> PSQLPyResult> { let pydict: pyo3::Bound<'_, PyDict> = row_to_dict(py, &self.inner, &None)?; Ok(as_class.call(py, (), Some(&pydict))?) } @@ -189,7 +178,7 @@ impl PSQLDriverSinglePyQueryResult { py: Python<'a>, row_factory: Py, custom_decoders: Option>, - ) -> RustPSQLDriverPyResult> { + ) -> PSQLPyResult> { let pydict = row_to_dict(py, &self.inner, &custom_decoders)?.to_object(py); Ok(row_factory.call(py, (pydict,), None)?) } diff --git a/src/row_factories.rs b/src/row_factories.rs index 3a2d2de8..e867df0a 100644 --- a/src/row_factories.rs +++ b/src/row_factories.rs @@ -4,11 +4,11 @@ use pyo3::{ wrap_pyfunction, Bound, Py, PyAny, PyResult, Python, ToPyObject, }; -use crate::exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}; +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; #[pyfunction] #[allow(clippy::needless_pass_by_value)] -fn tuple_row(py: Python<'_>, dict_: Py) -> RustPSQLDriverPyResult> { +fn tuple_row(py: Python<'_>, dict_: Py) -> PSQLPyResult> { let dict_ = dict_.downcast_bound::(py).map_err(|_| { RustPSQLDriverError::RustToPyValueConversionError( "as_tuple accepts only dict as a parameter".into(), @@ -29,7 +29,7 @@ impl class_row { } #[allow(clippy::needless_pass_by_value)] - fn __call__(&self, py: Python<'_>, dict_: Py) -> RustPSQLDriverPyResult> { + fn __call__(&self, py: Python<'_>, dict_: Py) -> PSQLPyResult> { let dict_ = dict_.downcast_bound::(py).map_err(|_| { RustPSQLDriverError::RustToPyValueConversionError( "as_tuple accepts only dict as a parameter".into(), diff --git a/src/runtime.rs b/src/runtime.rs index 05889d99..ee6281de 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -1,7 +1,7 @@ use futures_util::Future; use pyo3::{IntoPyObject, Py, PyAny, Python}; -use crate::exceptions::rust_errors::RustPSQLDriverPyResult; +use crate::exceptions::rust_errors::PSQLPyResult; #[allow(clippy::missing_panics_doc)] #[allow(clippy::module_name_repetitions)] @@ -18,9 +18,9 @@ pub fn tokio_runtime() -> &'static tokio::runtime::Runtime { /// # Errors /// /// May return Err Result if future acts incorrect. -pub fn rustdriver_future(py: Python<'_>, future: F) -> RustPSQLDriverPyResult> +pub fn rustdriver_future(py: Python<'_>, future: F) -> PSQLPyResult> where - F: Future> + Send + 'static, + F: Future> + Send + 'static, T: for<'py> IntoPyObject<'py>, { let res = diff --git a/src/statement/cache.rs b/src/statement/cache.rs new file mode 100644 index 00000000..7d78898d --- /dev/null +++ b/src/statement/cache.rs @@ -0,0 +1,50 @@ +use std::collections::HashMap; + +use once_cell::sync::Lazy; +use postgres_types::Type; +use tokio::sync::RwLock; +use tokio_postgres::Statement; + +use super::{query::QueryString, utils::hash_str}; + +#[derive(Default)] +pub(crate) struct StatementsCache(HashMap); + +impl StatementsCache { + pub fn add_cache(&mut self, query: &QueryString, inner_stmt: &Statement) { + self.0 + .insert(query.hash(), StatementCacheInfo::new(query, inner_stmt)); + } + + pub fn get_cache(&self, querystring: &String) -> Option { + let qs_hash = hash_str(&querystring); + + if let Some(cache_info) = self.0.get(&qs_hash) { + return Some(cache_info.clone()); + } + + None + } +} + +#[derive(Clone)] +pub(crate) struct StatementCacheInfo { + pub(crate) query: QueryString, + pub(crate) inner_stmt: Statement, +} + +impl StatementCacheInfo { + fn new(query: &QueryString, inner_stmt: &Statement) -> Self { + return Self { + query: query.clone(), + inner_stmt: inner_stmt.clone(), + }; + } + + pub(crate) fn types(&self) -> Vec { + self.inner_stmt.params().to_vec() + } +} + +pub(crate) static STMTS_CACHE: Lazy> = + Lazy::new(|| RwLock::new(Default::default())); diff --git a/src/statement/mod.rs b/src/statement/mod.rs new file mode 100644 index 00000000..c894b9a8 --- /dev/null +++ b/src/statement/mod.rs @@ -0,0 +1,6 @@ +pub mod cache; +pub mod parameters; +pub mod query; +pub mod statement; +pub mod statement_builder; +pub mod utils; diff --git a/src/statement/parameters.rs b/src/statement/parameters.rs new file mode 100644 index 00000000..0a2d9105 --- /dev/null +++ b/src/statement/parameters.rs @@ -0,0 +1,257 @@ +use std::iter::zip; + +use postgres_types::{ToSql, Type}; +use pyo3::{ + conversion::FromPyObjectBound, + types::{PyAnyMethods, PyMapping}, + Py, PyObject, PyTypeCheck, Python, +}; + +use crate::{ + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + value_converter::{ + dto::enums::PythonDTO, + from_python::{from_python_typed, from_python_untyped}, + }, +}; + +pub type QueryParameter = (dyn ToSql + Sync); + +pub(crate) struct ParametersBuilder { + parameters: Option, + types: Option>, +} + +impl ParametersBuilder { + pub fn new(parameters: &Option, types: Option>) -> Self { + Self { + parameters: parameters.clone(), + types, + } + } + + pub fn prepare( + self, + parameters_names: Option>, + ) -> PSQLPyResult { + let prepared_parameters = + Python::with_gil(|gil| self.prepare_parameters(gil, parameters_names))?; + + Ok(prepared_parameters) + } + + fn prepare_parameters( + self, + gil: Python<'_>, + parameters_names: Option>, + ) -> PSQLPyResult { + if self.parameters.is_none() { + return Ok(PreparedParameters::default()); + } + + let sequence_typed = self.as_type::>(gil); + let mapping_typed = self.downcast_as::(gil); + let mut prepared_parameters: Option = None; + + match (sequence_typed, mapping_typed) { + (Some(sequence), None) => { + prepared_parameters = + Some(SequenceParametersBuilder::new(sequence, self.types).prepare(gil)?); + } + (None, Some(mapping)) => { + if let Some(parameters_names) = parameters_names { + prepared_parameters = Some( + MappingParametersBuilder::new(mapping, self.types) + .prepare(gil, parameters_names)?, + ) + } + } + _ => {} + } + + if let Some(prepared_parameters) = prepared_parameters { + return Ok(prepared_parameters); + } + + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "Parameters must be sequence or mapping".into(), + )); + } + + fn as_type FromPyObjectBound<'a, 'py>>(&self, gil: Python<'_>) -> Option { + if let Some(parameters) = &self.parameters { + let extracted_param = parameters.extract::(gil); + + if let Ok(extracted_param) = extracted_param { + return Some(extracted_param); + } + + return None; + } + + None + } + + fn downcast_as(&self, gil: Python<'_>) -> Option> { + if let Some(parameters) = &self.parameters { + let extracted_param = parameters.downcast_bound::(gil); + + if let Ok(extracted_param) = extracted_param { + return Some(extracted_param.clone().unbind()); + } + + return None; + } + + None + } +} + +pub(crate) struct MappingParametersBuilder { + map_parameters: Py, + types: Option>, +} + +impl MappingParametersBuilder { + fn new(map_parameters: Py, types: Option>) -> Self { + Self { + map_parameters, + types, + } + } + + fn prepare( + self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult { + if self.types.is_some() { + return self.prepare_typed(gil, parameters_names); + } + + self.prepare_not_typed(gil, parameters_names) + } + + fn prepare_typed( + self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult { + let converted_parameters = self + .extract_parameters(gil, parameters_names)? + .iter() + .map(|parameter| from_python_untyped(parameter.bind(gil))) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. + } + + fn prepare_not_typed( + self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult { + let converted_parameters = self + .extract_parameters(gil, parameters_names)? + .iter() + .map(|parameter| from_python_untyped(parameter.bind(gil))) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. + } + + fn extract_parameters( + &self, + gil: Python<'_>, + parameters_names: Vec, + ) -> PSQLPyResult> { + let mut params_as_pyobject: Vec = vec![]; + + for param_name in parameters_names { + match self.map_parameters.bind(gil).get_item(¶m_name) { + Ok(param_value) => params_as_pyobject.push(param_value.unbind()), + Err(_) => { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + format!("Cannot find parameter with name <{}>", param_name).into(), + )) + } + } + } + + Ok(params_as_pyobject) + } +} + +pub(crate) struct SequenceParametersBuilder { + seq_parameters: Vec, + types: Option>, +} + +impl SequenceParametersBuilder { + fn new(seq_parameters: Vec, types: Option>) -> Self { + Self { + seq_parameters: seq_parameters, + types, + } + } + + fn prepare(self, gil: Python<'_>) -> PSQLPyResult { + let types = self.types.clone(); + + if types.is_some() { + return self.prepare_typed(gil, types.clone().unwrap()); + } + + self.prepare_not_typed(gil) + } + + fn prepare_typed(self, gil: Python<'_>, types: Vec) -> PSQLPyResult { + let zipped_params_types = zip(self.seq_parameters, &types); + let converted_parameters = zipped_params_types + .map(|(parameter, type_)| from_python_typed(parameter.bind(gil), &type_)) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, types)) + } + + fn prepare_not_typed(self, gil: Python<'_>) -> PSQLPyResult { + let converted_parameters = self + .seq_parameters + .iter() + .map(|parameter| from_python_untyped(parameter.bind(gil))) + .collect::>>()?; + + Ok(PreparedParameters::new(converted_parameters, vec![])) // TODO: change vec![] to real types. + } +} + +#[derive(Default, Clone, Debug)] +pub struct PreparedParameters { + parameters: Vec, + types: Vec, +} + +impl PreparedParameters { + pub fn new(parameters: Vec, types: Vec) -> Self { + Self { parameters, types } + } + + pub fn params(&self) -> Box<[&(dyn ToSql + Sync)]> { + let params_ref = &self.parameters; + params_ref + .iter() + .map(|param| param as &QueryParameter) + .collect::>() + .into_boxed_slice() + } + + pub fn params_typed(&self) -> Box<[(&(dyn ToSql + Sync), Type)]> { + let params_ref = &self.parameters; + let types = self.types.clone(); + let params_types = zip(params_ref, types); + params_types + .map(|(param, type_)| (param as &QueryParameter, type_)) + .collect::>() + .into_boxed_slice() + } +} diff --git a/src/statement/query.rs b/src/statement/query.rs new file mode 100644 index 00000000..108fe756 --- /dev/null +++ b/src/statement/query.rs @@ -0,0 +1,92 @@ +use std::fmt::Display; + +use regex::Regex; + +use crate::value_converter::consts::KWARGS_PARAMS_REGEXP; + +use super::utils::hash_str; + +#[derive(Clone, Debug)] +pub struct QueryString { + pub(crate) initial_qs: String, + // This field are used when kwargs passed + // from python side as parameters. + pub(crate) converted_qs: Option, +} + +impl Display for QueryString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.query()) + } +} + +impl QueryString { + pub fn new(initial_qs: &String) -> Self { + return Self { + initial_qs: initial_qs.clone(), + converted_qs: None, + }; + } + + pub(crate) fn query(&self) -> &str { + if let Some(converted_qs) = &self.converted_qs { + return converted_qs.query(); + } + + return &self.initial_qs; + } + + pub(crate) fn hash(&self) -> u64 { + hash_str(&self.initial_qs) + } + + pub(crate) fn process_qs(&mut self) { + if !self.is_kwargs_parametrized() { + return (); + } + + let mut counter = 0; + let mut parameters_names = Vec::new(); + + let re = Regex::new(KWARGS_PARAMS_REGEXP).unwrap(); + let result = re.replace_all(&self.initial_qs, |caps: ®ex::Captures| { + let parameter_idx = caps[1].to_string(); + + parameters_names.push(parameter_idx.clone()); + counter += 1; + + format!("${}", &counter) + }); + + self.converted_qs = Some(ConvertedQueryString::new(result.into(), parameters_names)); + } + + fn is_kwargs_parametrized(&self) -> bool { + Regex::new(KWARGS_PARAMS_REGEXP) + .unwrap() + .is_match(&self.initial_qs) + } +} + +#[derive(Clone, Debug)] +pub(crate) struct ConvertedQueryString { + converted_qs: String, + params_names: Vec, +} + +impl ConvertedQueryString { + fn new(converted_qs: String, params_names: Vec) -> Self { + Self { + converted_qs, + params_names, + } + } + + fn query(&self) -> &str { + &self.converted_qs + } + + pub(crate) fn params_names(&self) -> &Vec { + &self.params_names + } +} diff --git a/src/statement/statement.rs b/src/statement/statement.rs new file mode 100644 index 00000000..addaae89 --- /dev/null +++ b/src/statement/statement.rs @@ -0,0 +1,46 @@ +use postgres_types::{ToSql, Type}; +use tokio_postgres::Statement; + +use crate::exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}; + +use super::{parameters::PreparedParameters, query::QueryString}; + +#[derive(Clone, Debug)] +pub struct PsqlpyStatement { + query: QueryString, + prepared_parameters: PreparedParameters, + prepared_statement: Option, +} + +impl PsqlpyStatement { + pub(crate) fn new( + query: QueryString, + prepared_parameters: PreparedParameters, + prepared_statement: Option, + ) -> Self { + Self { + query, + prepared_parameters, + prepared_statement, + } + } + + pub fn raw_query(&self) -> &str { + self.query.query() + } + + pub fn statement_query(&self) -> PSQLPyResult<&Statement> { + match &self.prepared_statement { + Some(prepared_stmt) => return Ok(prepared_stmt), + None => return Err(RustPSQLDriverError::ConnectionExecuteError("No".into())), + } + } + + pub fn params(&self) -> Box<[&(dyn ToSql + Sync)]> { + self.prepared_parameters.params() + } + + pub fn params_typed(&self) -> Box<[(&(dyn ToSql + Sync), Type)]> { + self.prepared_parameters.params_typed() + } +} diff --git a/src/statement/statement_builder.rs b/src/statement/statement_builder.rs new file mode 100644 index 00000000..5954f88c --- /dev/null +++ b/src/statement/statement_builder.rs @@ -0,0 +1,115 @@ +use pyo3::PyObject; +use tokio::sync::RwLockWriteGuard; +use tokio_postgres::Statement; + +use crate::{driver::inner_connection::PsqlpyConnection, exceptions::rust_errors::PSQLPyResult}; + +use super::{ + cache::{StatementCacheInfo, StatementsCache, STMTS_CACHE}, + parameters::ParametersBuilder, + query::QueryString, + statement::PsqlpyStatement, +}; + +pub struct StatementBuilder<'a> { + querystring: String, + parameters: Option, + inner_conn: &'a PsqlpyConnection, + prepared: bool, +} + +impl<'a> StatementBuilder<'a> { + pub fn new( + querystring: String, + parameters: Option, + inner_conn: &'a PsqlpyConnection, + prepared: Option, + ) -> Self { + Self { + querystring, + parameters, + inner_conn, + prepared: prepared.unwrap_or(true), + } + } + + pub async fn build(self) -> PSQLPyResult { + if !self.prepared { + { + let stmt_cache_guard = STMTS_CACHE.read().await; + if let Some(cached) = stmt_cache_guard.get_cache(&self.querystring) { + return self.build_with_cached(cached); + } + } + } + + let stmt_cache_guard = STMTS_CACHE.write().await; + self.build_no_cached(stmt_cache_guard).await + } + + fn build_with_cached(self, cached: StatementCacheInfo) -> PSQLPyResult { + let raw_parameters = ParametersBuilder::new(&self.parameters, Some(cached.types())); + + let parameters_names = if let Some(converted_qs) = &cached.query.converted_qs { + Some(converted_qs.params_names().clone()) + } else { + None + }; + + let prepared_parameters = raw_parameters.prepare(parameters_names)?; + + return Ok(PsqlpyStatement::new( + cached.query, + prepared_parameters, + None, + )); + } + + async fn build_no_cached( + self, + cache_guard: RwLockWriteGuard<'_, StatementsCache>, + ) -> PSQLPyResult { + let mut querystring = QueryString::new(&self.querystring); + querystring.process_qs(); + + let prepared_stmt = self.prepare_query(&querystring, self.prepared).await?; + let parameters_builder = + ParametersBuilder::new(&self.parameters, Some(prepared_stmt.params().to_vec())); + + let parameters_names = if let Some(converted_qs) = &querystring.converted_qs { + Some(converted_qs.params_names().clone()) + } else { + None + }; + + let prepared_parameters = parameters_builder.prepare(parameters_names)?; + + match self.prepared { + true => { + return Ok(PsqlpyStatement::new( + querystring, + prepared_parameters, + Some(prepared_stmt), + )) + } + false => { + self.write_to_cache(cache_guard, &querystring, &prepared_stmt) + .await; + return Ok(PsqlpyStatement::new(querystring, prepared_parameters, None)); + } + } + } + + async fn write_to_cache( + &self, + mut cache_guard: RwLockWriteGuard<'_, StatementsCache>, + query: &QueryString, + inner_stmt: &Statement, + ) { + cache_guard.add_cache(query, inner_stmt); + } + + async fn prepare_query(&self, query: &QueryString, prepared: bool) -> PSQLPyResult { + self.inner_conn.prepare(query.query(), prepared).await + } +} diff --git a/src/statement/utils.rs b/src/statement/utils.rs new file mode 100644 index 00000000..a79f8bdd --- /dev/null +++ b/src/statement/utils.rs @@ -0,0 +1,8 @@ +use std::hash::{DefaultHasher, Hash, Hasher}; + +pub(crate) fn hash_str(string: &String) -> u64 { + let mut hasher = DefaultHasher::new(); + string.hash(&mut hasher); + + hasher.finish() +} diff --git a/src/value_converter/additional_types.rs b/src/value_converter/additional_types.rs index 5dd435a0..1159939a 100644 --- a/src/value_converter/additional_types.rs +++ b/src/value_converter/additional_types.rs @@ -13,6 +13,8 @@ use pyo3::{ use serde::{Deserialize, Serialize}; use tokio_postgres::types::{FromSql, Type}; +pub struct NonePyType; + macro_rules! build_additional_rust_type { ($st_name:ident, $rust_type:ty) => { #[derive(Debug)] diff --git a/src/value_converter/consts.rs b/src/value_converter/consts.rs index 40fa932b..82a34f0f 100644 --- a/src/value_converter/consts.rs +++ b/src/value_converter/consts.rs @@ -1,5 +1,4 @@ use once_cell::sync::Lazy; -use postgres_types::ToSql; use std::{collections::HashMap, sync::RwLock}; use pyo3::{ @@ -8,6 +7,8 @@ use pyo3::{ Bound, Py, PyResult, Python, }; +pub static KWARGS_PARAMS_REGEXP: &str = r"\$\(([^)]+)\)p"; + pub static DECIMAL_CLS: GILOnceCell> = GILOnceCell::new(); pub static TIMEDELTA_CLS: GILOnceCell> = GILOnceCell::new(); pub static KWARGS_QUERYSTRINGS: Lazy)>>> = @@ -33,5 +34,3 @@ pub fn get_timedelta_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> { }) .map(|ty| ty.bind(py)) } - -pub type QueryParameter = (dyn ToSql + Sync); diff --git a/src/value_converter/dto/converter_impls.rs b/src/value_converter/dto/converter_impls.rs new file mode 100644 index 00000000..f50529bc --- /dev/null +++ b/src/value_converter/dto/converter_impls.rs @@ -0,0 +1,220 @@ +use std::net::IpAddr; + +use chrono::{DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; +use pg_interval::Interval; +use postgres_types::Type; +use pyo3::{ + types::{PyAnyMethods, PyDateTime, PyDelta, PyDict}, + Bound, PyAny, +}; +use rust_decimal::Decimal; +use uuid::Uuid; + +use crate::{ + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + extra_types::{self, PythonDecimal, PythonUUID}, + value_converter::{ + additional_types::NonePyType, + from_python::{extract_datetime_from_python_object_attrs, py_sequence_into_postgres_array}, + models::serde_value::build_serde_value, + traits::{ToPythonDTO, ToPythonDTOArray}, + }, +}; + +use super::{enums::PythonDTO, funcs::array_type_to_single_type}; + +impl ToPythonDTO for NonePyType { + fn to_python_dto(_python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + Ok(PythonDTO::PyNone) + } +} + +macro_rules! construct_simple_type_converter { + ($match_type:ty, $kind:path) => { + impl ToPythonDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { + Ok($kind(python_param.extract::<$match_type>()?)) + } + } + }; +} + +construct_simple_type_converter!(bool, PythonDTO::PyBool); +construct_simple_type_converter!(Vec, PythonDTO::PyBytes); +construct_simple_type_converter!(String, PythonDTO::PyString); +construct_simple_type_converter!(f32, PythonDTO::PyFloat32); +construct_simple_type_converter!(f64, PythonDTO::PyFloat64); +construct_simple_type_converter!(i16, PythonDTO::PyIntI16); +construct_simple_type_converter!(i32, PythonDTO::PyIntI32); +construct_simple_type_converter!(i64, PythonDTO::PyIntI64); +construct_simple_type_converter!(NaiveDate, PythonDTO::PyDate); +construct_simple_type_converter!(NaiveTime, PythonDTO::PyTime); + +impl ToPythonDTO for PyDateTime { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + let timestamp_tz = python_param.extract::>(); + if let Ok(pydatetime_tz) = timestamp_tz { + return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); + } + + let timestamp_no_tz = python_param.extract::(); + if let Ok(pydatetime_no_tz) = timestamp_no_tz { + return Ok(PythonDTO::PyDateTime(pydatetime_no_tz)); + } + + let timestamp_tz = extract_datetime_from_python_object_attrs(python_param); + if let Ok(pydatetime_tz) = timestamp_tz { + return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); + } + + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "Can not convert you datetime to rust type".into(), + )); + } +} + +impl ToPythonDTO for PyDelta { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + let duration = python_param.extract::()?; + if let Some(interval) = Interval::from_duration(duration) { + return Ok(PythonDTO::PyInterval(interval)); + } + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "Cannot convert timedelta from Python to inner Rust type.".to_string(), + )); + } +} + +impl ToPythonDTO for PyDict { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + let serde_value = build_serde_value(python_param)?; + + return Ok(PythonDTO::PyJsonb(serde_value)); + } +} + +macro_rules! construct_extra_type_converter { + ($match_type:ty, $kind:path) => { + impl ToPythonDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { + Ok($kind(python_param.extract::<$match_type>()?.inner())) + } + } + }; +} + +construct_extra_type_converter!(extra_types::Text, PythonDTO::PyText); +construct_extra_type_converter!(extra_types::VarChar, PythonDTO::PyVarChar); +construct_extra_type_converter!(extra_types::SmallInt, PythonDTO::PyIntI16); +construct_extra_type_converter!(extra_types::Integer, PythonDTO::PyIntI32); +construct_extra_type_converter!(extra_types::BigInt, PythonDTO::PyIntI64); +construct_extra_type_converter!(extra_types::Float32, PythonDTO::PyFloat32); +construct_extra_type_converter!(extra_types::Float64, PythonDTO::PyFloat64); +construct_extra_type_converter!(extra_types::Money, PythonDTO::PyMoney); +construct_extra_type_converter!(extra_types::JSONB, PythonDTO::PyJsonb); +construct_extra_type_converter!(extra_types::JSON, PythonDTO::PyJson); +construct_extra_type_converter!(extra_types::MacAddr6, PythonDTO::PyMacAddr6); +construct_extra_type_converter!(extra_types::MacAddr8, PythonDTO::PyMacAddr8); +construct_extra_type_converter!(extra_types::Point, PythonDTO::PyPoint); +construct_extra_type_converter!(extra_types::Box, PythonDTO::PyBox); +construct_extra_type_converter!(extra_types::Path, PythonDTO::PyPath); +construct_extra_type_converter!(extra_types::Line, PythonDTO::PyLine); +construct_extra_type_converter!(extra_types::LineSegment, PythonDTO::PyLineSegment); +construct_extra_type_converter!(extra_types::Circle, PythonDTO::PyCircle); +construct_extra_type_converter!(extra_types::PgVector, PythonDTO::PyPgVector); +construct_extra_type_converter!(extra_types::CustomType, PythonDTO::PyCustomType); + +impl ToPythonDTO for PythonDecimal { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + Ok(PythonDTO::PyDecimal(Decimal::from_str_exact( + python_param.str()?.extract::<&str>()?, + )?)) + } +} + +impl ToPythonDTO for PythonUUID { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + Ok(PythonDTO::PyUUID(Uuid::parse_str( + python_param.str()?.extract::<&str>()?, + )?)) + } +} + +impl ToPythonDTOArray for extra_types::PythonArray { + fn to_python_dto( + python_param: &pyo3::Bound<'_, PyAny>, + array_type: Type, + ) -> PSQLPyResult { + let elem_type = array_type_to_single_type(&array_type); + Ok(PythonDTO::PyArray( + py_sequence_into_postgres_array(python_param, &elem_type)?, + array_type, + )) + } +} + +impl ToPythonDTO for IpAddr { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + if let Ok(id_address) = python_param.extract::() { + return Ok(PythonDTO::PyIpAddress(id_address)); + } + + Err(RustPSQLDriverError::PyToRustValueConversionError( + "Parameter passed to IpAddr is incorrect.".to_string(), + )) + } +} + +impl ToPythonDTO for extra_types::PythonEnum { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { + if let Ok(value_attr) = python_param.getattr("value") { + if let Ok(possible_string) = value_attr.extract::() { + return Ok(PythonDTO::PyString(possible_string)); + } + } + + Err(RustPSQLDriverError::PyToRustValueConversionError( + "Cannot convert Enum to inner type".into(), + )) + } +} + +macro_rules! construct_array_type_converter { + ($match_type:ty) => { + impl ToPythonDTO for $match_type { + fn to_python_dto(python_param: &Bound<'_, PyAny>) -> PSQLPyResult { + python_param + .extract::<$match_type>()? + ._convert_to_python_dto(&Self::element_type()) + } + } + }; +} + +construct_array_type_converter!(extra_types::BoolArray); +construct_array_type_converter!(extra_types::UUIDArray); +construct_array_type_converter!(extra_types::VarCharArray); +construct_array_type_converter!(extra_types::TextArray); +construct_array_type_converter!(extra_types::Int16Array); +construct_array_type_converter!(extra_types::Int32Array); +construct_array_type_converter!(extra_types::Int64Array); +construct_array_type_converter!(extra_types::Float32Array); +construct_array_type_converter!(extra_types::Float64Array); +construct_array_type_converter!(extra_types::MoneyArray); +construct_array_type_converter!(extra_types::IpAddressArray); +construct_array_type_converter!(extra_types::JSONBArray); +construct_array_type_converter!(extra_types::JSONArray); +construct_array_type_converter!(extra_types::DateArray); +construct_array_type_converter!(extra_types::TimeArray); +construct_array_type_converter!(extra_types::DateTimeArray); +construct_array_type_converter!(extra_types::DateTimeTZArray); +construct_array_type_converter!(extra_types::MacAddr6Array); +construct_array_type_converter!(extra_types::MacAddr8Array); +construct_array_type_converter!(extra_types::NumericArray); +construct_array_type_converter!(extra_types::PointArray); +construct_array_type_converter!(extra_types::BoxArray); +construct_array_type_converter!(extra_types::PathArray); +construct_array_type_converter!(extra_types::LineArray); +construct_array_type_converter!(extra_types::LsegArray); +construct_array_type_converter!(extra_types::CircleArray); +construct_array_type_converter!(extra_types::IntervalArray); diff --git a/src/value_converter/dto/enums.rs b/src/value_converter/dto/enums.rs new file mode 100644 index 00000000..a90f1527 --- /dev/null +++ b/src/value_converter/dto/enums.rs @@ -0,0 +1,83 @@ +use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; +use geo_types::{Line as LineSegment, LineString, Point, Rect}; +use macaddr::{MacAddr6, MacAddr8}; +use pg_interval::Interval; +use postgres_types::Type; +use rust_decimal::Decimal; +use serde_json::Value; +use std::{fmt::Debug, net::IpAddr}; +use uuid::Uuid; + +use crate::value_converter::additional_types::{Circle, Line}; +use postgres_array::array::Array; + +#[derive(Debug, Clone, PartialEq)] +pub enum PythonDTO { + // Primitive + PyNone, + PyBytes(Vec), + PyBool(bool), + PyUUID(Uuid), + PyVarChar(String), + PyText(String), + PyString(String), + PyIntI16(i16), + PyIntI32(i32), + PyIntI64(i64), + PyIntU32(u32), + PyIntU64(u64), + PyFloat32(f32), + PyFloat64(f64), + PyMoney(i64), + PyDate(NaiveDate), + PyTime(NaiveTime), + PyDateTime(NaiveDateTime), + PyDateTimeTz(DateTime), + PyInterval(Interval), + PyIpAddress(IpAddr), + PyList(Vec, Type), + PyArray(Array, Type), + PyTuple(Vec, Type), + PyJsonb(Value), + PyJson(Value), + PyMacAddr6(MacAddr6), + PyMacAddr8(MacAddr8), + PyDecimal(Decimal), + PyCustomType(Vec), + PyPoint(Point), + PyBox(Rect), + PyPath(LineString), + PyLine(Line), + PyLineSegment(LineSegment), + PyCircle(Circle), + // Arrays + PyBoolArray(Array), + PyUuidArray(Array), + PyVarCharArray(Array), + PyTextArray(Array), + PyInt16Array(Array), + PyInt32Array(Array), + PyInt64Array(Array), + PyFloat32Array(Array), + PyFloat64Array(Array), + PyMoneyArray(Array), + PyIpAddressArray(Array), + PyJSONBArray(Array), + PyJSONArray(Array), + PyDateArray(Array), + PyTimeArray(Array), + PyDateTimeArray(Array), + PyDateTimeTZArray(Array), + PyMacAddr6Array(Array), + PyMacAddr8Array(Array), + PyNumericArray(Array), + PyPointArray(Array), + PyBoxArray(Array), + PyPathArray(Array), + PyLineArray(Array), + PyLsegArray(Array), + PyCircleArray(Array), + PyIntervalArray(Array), + // PgVector + PyPgVector(Vec), +} diff --git a/src/value_converter/dto/funcs.rs b/src/value_converter/dto/funcs.rs new file mode 100644 index 00000000..116db7d0 --- /dev/null +++ b/src/value_converter/dto/funcs.rs @@ -0,0 +1,33 @@ +use postgres_types::Type; + +pub fn array_type_to_single_type(array_type: &Type) -> Type { + match *array_type { + Type::BOOL_ARRAY => Type::BOOL, + Type::UUID_ARRAY => Type::UUID_ARRAY, + Type::VARCHAR_ARRAY => Type::VARCHAR, + Type::TEXT_ARRAY => Type::TEXT, + Type::INT2_ARRAY => Type::INT2, + Type::INT4_ARRAY => Type::INT4, + Type::INT8_ARRAY => Type::INT8, + Type::FLOAT4_ARRAY => Type::FLOAT4, + Type::FLOAT8_ARRAY => Type::FLOAT8, + Type::MONEY_ARRAY => Type::MONEY, + Type::INET_ARRAY => Type::INET, + Type::JSON_ARRAY => Type::JSON, + Type::JSONB_ARRAY => Type::JSONB, + Type::DATE_ARRAY => Type::DATE, + Type::TIME_ARRAY => Type::TIME, + Type::TIMESTAMP_ARRAY => Type::TIMESTAMP, + Type::TIMESTAMPTZ_ARRAY => Type::TIMESTAMPTZ, + Type::INTERVAL_ARRAY => Type::INTERVAL, + Type::MACADDR_ARRAY => Type::MACADDR, + Type::MACADDR8_ARRAY => Type::MACADDR8, + Type::POINT_ARRAY => Type::POINT, + Type::BOX_ARRAY => Type::BOX, + Type::PATH_ARRAY => Type::PATH, + Type::LINE_ARRAY => Type::LINE, + Type::LSEG_ARRAY => Type::LSEG, + Type::CIRCLE_ARRAY => Type::CIRCLE, + _ => Type::ANY, + } +} diff --git a/src/value_converter/models/dto.rs b/src/value_converter/dto/impls.rs similarity index 67% rename from src/value_converter/models/dto.rs rename to src/value_converter/dto/impls.rs index 8609a600..3450dfd0 100644 --- a/src/value_converter/models/dto.rs +++ b/src/value_converter/dto/impls.rs @@ -1,112 +1,47 @@ use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime}; -use geo_types::{Line as LineSegment, LineString, Point, Rect}; -use macaddr::{MacAddr6, MacAddr8}; use pg_interval::Interval; use postgres_types::ToSql; use rust_decimal::Decimal; use serde_json::{json, Value}; -use std::{fmt::Debug, net::IpAddr}; +use std::net::IpAddr; use uuid::Uuid; use bytes::{BufMut, BytesMut}; use postgres_protocol::types; -use pyo3::{PyObject, Python, ToPyObject}; +use pyo3::{Bound, IntoPyObject, PyAny, Python}; use tokio_postgres::types::{to_sql_checked, Type}; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, - value_converter::additional_types::{ - Circle, Line, RustLineSegment, RustLineString, RustPoint, RustRect, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + value_converter::{ + additional_types::{Circle, Line, RustLineSegment, RustLineString, RustPoint, RustRect}, + models::serde_value::pythondto_array_to_serde, }, }; use pgvector::Vector as PgVector; -use postgres_array::{array::Array, Dimension}; -#[derive(Debug, Clone, PartialEq)] -pub enum PythonDTO { - // Primitive - PyNone, - PyBytes(Vec), - PyBool(bool), - PyUUID(Uuid), - PyVarChar(String), - PyText(String), - PyString(String), - PyIntI16(i16), - PyIntI32(i32), - PyIntI64(i64), - PyIntU32(u32), - PyIntU64(u64), - PyFloat32(f32), - PyFloat64(f64), - PyMoney(i64), - PyDate(NaiveDate), - PyTime(NaiveTime), - PyDateTime(NaiveDateTime), - PyDateTimeTz(DateTime), - PyInterval(Interval), - PyIpAddress(IpAddr), - PyList(Vec), - PyArray(Array), - PyTuple(Vec), - PyJsonb(Value), - PyJson(Value), - PyMacAddr6(MacAddr6), - PyMacAddr8(MacAddr8), - PyDecimal(Decimal), - PyCustomType(Vec), - PyPoint(Point), - PyBox(Rect), - PyPath(LineString), - PyLine(Line), - PyLineSegment(LineSegment), - PyCircle(Circle), - // Arrays - PyBoolArray(Array), - PyUuidArray(Array), - PyVarCharArray(Array), - PyTextArray(Array), - PyInt16Array(Array), - PyInt32Array(Array), - PyInt64Array(Array), - PyFloat32Array(Array), - PyFloat64Array(Array), - PyMoneyArray(Array), - PyIpAddressArray(Array), - PyJSONBArray(Array), - PyJSONArray(Array), - PyDateArray(Array), - PyTimeArray(Array), - PyDateTimeArray(Array), - PyDateTimeTZArray(Array), - PyMacAddr6Array(Array), - PyMacAddr8Array(Array), - PyNumericArray(Array), - PyPointArray(Array), - PyBoxArray(Array), - PyPathArray(Array), - PyLineArray(Array), - PyLsegArray(Array), - PyCircleArray(Array), - PyIntervalArray(Array), - // PgVector - PyPgVector(Vec), -} +use super::enums::PythonDTO; + +impl<'py> IntoPyObject<'py> for PythonDTO { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = std::convert::Infallible; -impl ToPyObject for PythonDTO { - fn to_object(&self, py: Python<'_>) -> PyObject { + fn into_pyobject(self, py: Python<'py>) -> Result { match self { - PythonDTO::PyNone => py.None(), - PythonDTO::PyBool(pybool) => pybool.to_object(py), + PythonDTO::PyNone => Ok(py.None().into_bound(py)), + PythonDTO::PyBool(pybool) => Ok(pybool.into_pyobject(py)?.to_owned().into_any()), PythonDTO::PyString(py_string) | PythonDTO::PyText(py_string) - | PythonDTO::PyVarChar(py_string) => py_string.to_object(py), - PythonDTO::PyIntI32(pyint) => pyint.to_object(py), - PythonDTO::PyIntI64(pyint) => pyint.to_object(py), - PythonDTO::PyIntU64(pyint) => pyint.to_object(py), - PythonDTO::PyFloat32(pyfloat) => pyfloat.to_object(py), - PythonDTO::PyFloat64(pyfloat) => pyfloat.to_object(py), - _ => unreachable!(), + | PythonDTO::PyVarChar(py_string) => Ok(py_string.into_pyobject(py)?.into_any()), + PythonDTO::PyIntI32(pyint) => Ok(pyint.into_pyobject(py)?.into_any()), + PythonDTO::PyIntI64(pyint) => Ok(pyint.into_pyobject(py)?.into_any()), + PythonDTO::PyIntU64(pyint) => Ok(pyint.into_pyobject(py)?.into_any()), + PythonDTO::PyFloat32(pyfloat) => Ok(pyfloat.into_pyobject(py)?.into_any()), + PythonDTO::PyFloat64(pyfloat) => Ok(pyfloat.into_pyobject(py)?.into_any()), + _ => { + unreachable!() + } } } } @@ -120,7 +55,7 @@ impl PythonDTO { /// /// # Errors /// May return Err Result if there is no support for passed python type. - pub fn array_type(&self) -> RustPSQLDriverPyResult { + pub fn array_type(&self) -> PSQLPyResult { match self { PythonDTO::PyBool(_) => Ok(tokio_postgres::types::Type::BOOL_ARRAY), PythonDTO::PyUUID(_) => Ok(tokio_postgres::types::Type::UUID_ARRAY), @@ -163,7 +98,7 @@ impl PythonDTO { /// /// # Errors /// May return Err Result if cannot convert python type into rust. - pub fn to_serde_value(&self) -> RustPSQLDriverPyResult { + pub fn to_serde_value(&self) -> PSQLPyResult { match self { PythonDTO::PyNone => Ok(Value::Null), PythonDTO::PyBool(pybool) => Ok(json!(pybool)), @@ -175,7 +110,7 @@ impl PythonDTO { PythonDTO::PyIntU64(pyint) => Ok(json!(pyint)), PythonDTO::PyFloat32(pyfloat) => Ok(json!(pyfloat)), PythonDTO::PyFloat64(pyfloat) => Ok(json!(pyfloat)), - PythonDTO::PyList(pylist) => { + PythonDTO::PyList(pylist, _) => { let mut vec_serde_values: Vec = vec![]; for py_object in pylist { @@ -184,7 +119,9 @@ impl PythonDTO { Ok(json!(vec_serde_values)) } - PythonDTO::PyArray(array) => Ok(json!(pythondto_array_to_serde(Some(array.clone()))?)), + PythonDTO::PyArray(array, _) => { + Ok(json!(pythondto_array_to_serde(Some(array.clone()))?)) + } PythonDTO::PyJsonb(py_dict) | PythonDTO::PyJson(py_dict) => Ok(py_dict.clone()), _ => Err(RustPSQLDriverError::PyToRustValueConversionError( "Cannot convert your type into Rust type".into(), @@ -305,30 +242,11 @@ impl ToSql for PythonDTO { PythonDTO::PyCircle(pycircle) => { <&Circle as ToSql>::to_sql(&pycircle, ty, out)?; } - PythonDTO::PyList(py_iterable) | PythonDTO::PyTuple(py_iterable) => { - let mut items = Vec::new(); - for inner in py_iterable { - items.push(inner); - } - if items.is_empty() { - return_is_null_true = true; - } else { - items.to_sql(&items[0].array_type()?, out)?; - } + PythonDTO::PyList(py_iterable, type_) | PythonDTO::PyTuple(py_iterable, type_) => { + return py_iterable.to_sql(type_, out); } - PythonDTO::PyArray(array) => { - if let Some(first_elem) = array.iter().nth(0) { - match first_elem.array_type() { - Ok(ok_type) => { - array.to_sql(&ok_type, out)?; - } - Err(_) => { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "Cannot define array type.".into(), - ))? - } - } - } + PythonDTO::PyArray(array, type_) => { + return array.to_sql(type_, out); } PythonDTO::PyJsonb(py_dict) | PythonDTO::PyJson(py_dict) => { <&Value as ToSql>::to_sql(&py_dict, ty, out)?; @@ -431,61 +349,3 @@ impl ToSql for PythonDTO { to_sql_checked!(); } - -/// Convert Array of `PythonDTO`s to serde `Value`. -/// -/// It can convert multidimensional arrays. -fn pythondto_array_to_serde(array: Option>) -> RustPSQLDriverPyResult { - match array { - Some(array) => inner_pythondto_array_to_serde( - array.dimensions(), - array.iter().collect::>().as_slice(), - 0, - 0, - ), - None => Ok(Value::Null), - } -} - -/// Inner conversion array of `PythonDTO`s to serde `Value`. -#[allow(clippy::cast_sign_loss)] -fn inner_pythondto_array_to_serde( - dimensions: &[Dimension], - data: &[&PythonDTO], - dimension_index: usize, - mut lower_bound: usize, -) -> RustPSQLDriverPyResult { - let current_dimension = dimensions.get(dimension_index); - - if let Some(current_dimension) = current_dimension { - let possible_next_dimension = dimensions.get(dimension_index + 1); - match possible_next_dimension { - Some(next_dimension) => { - let mut final_list: Value = Value::Array(vec![]); - - for _ in 0..current_dimension.len as usize { - if dimensions.get(dimension_index + 1).is_some() { - let inner_pylist = inner_pythondto_array_to_serde( - dimensions, - &data[lower_bound..next_dimension.len as usize + lower_bound], - dimension_index + 1, - 0, - )?; - match final_list { - Value::Array(ref mut array) => array.push(inner_pylist), - _ => unreachable!(), - } - lower_bound += next_dimension.len as usize; - }; - } - - return Ok(final_list); - } - None => { - return data.iter().map(|x| x.to_serde_value()).collect(); - } - } - } - - Ok(Value::Array(vec![])) -} diff --git a/src/value_converter/dto/mod.rs b/src/value_converter/dto/mod.rs new file mode 100644 index 00000000..49985cf1 --- /dev/null +++ b/src/value_converter/dto/mod.rs @@ -0,0 +1,4 @@ +pub mod converter_impls; +pub mod enums; +pub mod funcs; +pub mod impls; diff --git a/src/value_converter/funcs/from_python.rs b/src/value_converter/from_python.rs similarity index 57% rename from src/value_converter/funcs/from_python.rs rename to src/value_converter/from_python.rs index 4fe73290..fa1d5c60 100644 --- a/src/value_converter/funcs/from_python.rs +++ b/src/value_converter/from_python.rs @@ -2,28 +2,27 @@ use chrono::{self, DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, T use chrono_tz::Tz; use geo_types::{coord, Coord}; use itertools::Itertools; -use pg_interval::Interval; use postgres_array::{Array, Dimension}; -use rust_decimal::Decimal; -use serde_json::{json, Map, Value}; +use postgres_types::Type; use std::net::IpAddr; -use uuid::Uuid; use pyo3::{ types::{ - PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyDictMethods, PyFloat, - PyInt, PyList, PyMapping, PySequence, PySet, PyString, PyTime, PyTuple, PyTypeMethods, + PyAnyMethods, PyBool, PyBytes, PyDate, PyDateTime, PyDelta, PyDict, PyFloat, PyInt, PyList, + PySequence, PySet, PyString, PyTime, PyTuple, PyTypeMethods, }, Bound, Py, PyAny, Python, }; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, extra_types::{self}, - value_converter::{ - consts::KWARGS_QUERYSTRINGS, models::dto::PythonDTO, - utils::extract_value_from_python_object_or_raise, - }, + value_converter::{dto::enums::PythonDTO, utils::extract_value_from_python_object_or_raise}, +}; + +use super::{ + additional_types::NonePyType, + traits::{ToPythonDTO, ToPythonDTOArray}, }; /// Convert single python parameter to `PythonDTO` enum. @@ -33,422 +32,402 @@ use crate::{ /// May return Err Result if python type doesn't have support yet /// or value of the type is incorrect. #[allow(clippy::too_many_lines)] -pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult { +pub fn from_python_untyped(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { if parameter.is_none() { - return Ok(PythonDTO::PyNone); - } - - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyCustomType( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyBool(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyBytes(parameter.extract::>()?)); + return as ToPythonDTO>::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyText( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyVarChar( - parameter.extract::()?.inner(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyString(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyFloat64(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyFloat32( - parameter - .extract::()? - .retrieve_value(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyFloat64( - parameter - .extract::()? - .retrieve_value(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI16( - parameter - .extract::()? - .retrieve_value(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI32( - parameter - .extract::()? - .retrieve_value(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI64( - parameter.extract::()?.retrieve_value(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyMoney( - parameter.extract::()?.retrieve_value(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyIntI32(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - let timestamp_tz = parameter.extract::>(); - if let Ok(pydatetime_tz) = timestamp_tz { - return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); - } - - let timestamp_no_tz = parameter.extract::(); - if let Ok(pydatetime_no_tz) = timestamp_no_tz { - return Ok(PythonDTO::PyDateTime(pydatetime_no_tz)); - } - - let timestamp_tz = extract_datetime_from_python_object_attrs(parameter); - if let Ok(pydatetime_tz) = timestamp_tz { - return Ok(PythonDTO::PyDateTimeTz(pydatetime_tz)); - } - - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "Can not convert you datetime to rust type".into(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyDate(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return Ok(PythonDTO::PyTime(parameter.extract::()?)); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - let duration = parameter.extract::()?; - if let Some(interval) = Interval::from_duration(duration) { - return Ok(PythonDTO::PyInterval(interval)); - } - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "Cannot convert timedelta from Python to inner Rust type.".to_string(), - )); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() | parameter.is_instance_of::() { - return Ok(PythonDTO::PyArray(py_sequence_into_postgres_array( - parameter, - )?)); + return ::to_python_dto(parameter, Type::ANY); } if parameter.is_instance_of::() { - let dict = parameter.downcast::().map_err(|error| { - RustPSQLDriverError::PyToRustValueConversionError(format!( - "Can't cast to inner dict: {error}" - )) - })?; + return ::to_python_dto(parameter); + } - let mut serde_map: Map = Map::new(); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } - for dict_item in dict.items() { - let py_list = dict_item.downcast::().map_err(|error| { - RustPSQLDriverError::PyToRustValueConversionError(format!( - "Cannot cast to list: {error}" - )) - })?; + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } - let key = py_list.get_item(0)?.extract::()?; - let value = py_to_rust(&py_list.get_item(1)?)?; + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } - serde_map.insert(key, value.to_serde_value()?); - } + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } - return Ok(PythonDTO::PyJsonb(Value::Object(serde_map))); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyJsonb( - parameter.extract::()?.inner().clone(), - )); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyJson( - parameter.extract::()?.inner().clone(), - )); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyMacAddr6( - parameter.extract::()?.inner(), - )); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyMacAddr8( - parameter.extract::()?.inner(), - )); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); } if parameter.get_type().name()? == "UUID" { - return Ok(PythonDTO::PyUUID(Uuid::parse_str( - parameter.str()?.extract::<&str>()?, - )?)); + return ::to_python_dto(parameter); } if parameter.get_type().name()? == "decimal.Decimal" || parameter.get_type().name()? == "Decimal" { - return Ok(PythonDTO::PyDecimal(Decimal::from_str_exact( - parameter.str()?.extract::<&str>()?, - )?)); + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyPoint( - parameter.extract::()?.retrieve_value(), - )); + if let Ok(converted_array) = from_python_array_typed(parameter) { + return Ok(converted_array); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyBox( - parameter.extract::()?.retrieve_value(), - )); + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyPath( - parameter.extract::()?.retrieve_value(), - )); + if parameter.extract::().is_ok() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyLine( - parameter.extract::()?.retrieve_value(), - )); + if parameter.getattr("value").is_ok() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyLineSegment( - parameter - .extract::()? - .retrieve_value(), - )); + Err(RustPSQLDriverError::PyToRustValueConversionError(format!( + "Can not covert you type {parameter} into inner one", + ))) +} + +/// Convert single python parameter to `PythonDTO` enum. +/// +/// # Errors +/// +/// May return Err Result if python type doesn't have support yet +/// or value of the type is incorrect. +#[allow(clippy::too_many_lines)] +pub fn from_python_typed( + parameter: &pyo3::Bound<'_, PyAny>, + type_: &Type, +) -> PSQLPyResult { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); } - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyCircle( - parameter.extract::()?.retrieve_value(), - )); + if parameter.is_none() { + return ::to_python_dto(parameter); + } + + if parameter.get_type().name()? == "UUID" { + return ::to_python_dto(parameter); + } + + if parameter.get_type().name()? == "decimal.Decimal" + || parameter.get_type().name()? == "Decimal" + { + return ::to_python_dto(parameter); + } + + if parameter.is_instance_of::() | parameter.is_instance_of::() { + return ::to_python_dto( + parameter, + type_.clone(), + ); } + if let Ok(converted_array) = from_python_array_typed(parameter) { + return Ok(converted_array); + } + + match *type_ { + Type::BYTEA => return as ToPythonDTO>::to_python_dto(parameter), + Type::TEXT => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::VARCHAR => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::XML => return ::to_python_dto(parameter), + Type::BOOL => return ::to_python_dto(parameter), + Type::INT2 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::INT4 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::INT8 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::MONEY => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::FLOAT4 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::FLOAT8 => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + return ::to_python_dto(parameter); + } + Type::INET => return ::to_python_dto(parameter), + Type::DATE => return ::to_python_dto(parameter), + Type::TIME => return ::to_python_dto(parameter), + Type::TIMESTAMP | Type::TIMESTAMPTZ => { + return ::to_python_dto(parameter) + } + Type::INTERVAL => return ::to_python_dto(parameter), + Type::JSONB => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + + return ::to_python_dto(parameter); + } + Type::JSON => { + if parameter.is_instance_of::() { + return ::to_python_dto(parameter); + } + + return ::to_python_dto(parameter); + } + Type::MACADDR => return ::to_python_dto(parameter), + Type::MACADDR8 => return ::to_python_dto(parameter), + Type::POINT => return ::to_python_dto(parameter), + Type::BOX => return ::to_python_dto(parameter), + Type::PATH => return ::to_python_dto(parameter), + Type::LINE => return ::to_python_dto(parameter), + Type::LSEG => return ::to_python_dto(parameter), + Type::CIRCLE => return ::to_python_dto(parameter), + _ => {} + } + + if let Ok(converted_value) = from_python_untyped(parameter) { + return Ok(converted_value); + } + + Err(RustPSQLDriverError::PyToRustValueConversionError(format!( + "Can not covert you type {parameter} into {type_}", + ))) +} + +fn from_python_array_typed(parameter: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult { if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); + return ::to_python_dto(parameter); } if parameter.is_instance_of::() { - return parameter - .extract::()? - ._convert_to_python_dto(); - } - - if parameter.is_instance_of::() { - return Ok(PythonDTO::PyPgVector( - parameter.extract::()?.inner_value(), - )); - } - - if let Ok(id_address) = parameter.extract::() { - return Ok(PythonDTO::PyIpAddress(id_address)); - } - - // It's used for Enum. - // If StrEnum is used on Python side, - // we simply stop at the `is_instance_of::``. - if let Ok(value_attr) = parameter.getattr("value") { - if let Ok(possible_string) = value_attr.extract::() { - return Ok(PythonDTO::PyString(possible_string)); - } + return ::to_python_dto(parameter); } Err(RustPSQLDriverError::PyToRustValueConversionError(format!( - "Can not covert you type {parameter} into inner one", + "Cannot convert parameter in extra types Array", ))) } @@ -462,7 +441,7 @@ pub fn py_to_rust(parameter: &pyo3::Bound<'_, PyAny>) -> RustPSQLDriverPyResult< /// - The retrieved values are invalid for constructing a date, time, or datetime (e.g., invalid month or day) /// - The timezone information (`tzinfo`) is not available or cannot be parsed /// - The resulting datetime is ambiguous or invalid (e.g., due to DST transitions) -fn extract_datetime_from_python_object_attrs( +pub fn extract_datetime_from_python_object_attrs( parameter: &pyo3::Bound<'_, PyAny>, ) -> Result, RustPSQLDriverError> { let year = extract_value_from_python_object_or_raise::(parameter, "year")?; @@ -514,7 +493,8 @@ fn extract_datetime_from_python_object_attrs( #[allow(clippy::cast_possible_wrap)] pub fn py_sequence_into_postgres_array( parameter: &Bound, -) -> RustPSQLDriverPyResult> { + type_: &Type, +) -> PSQLPyResult> { let mut py_seq = parameter .downcast::() .map_err(|_| { @@ -559,7 +539,7 @@ pub fn py_sequence_into_postgres_array( } } - let array_data = py_sequence_into_flat_vec(parameter)?; + let array_data = py_sequence_into_flat_vec(parameter, type_)?; match postgres_array::Array::from_parts_no_panic(array_data, dimensions) { Ok(result_array) => Ok(result_array), Err(err) => Err(RustPSQLDriverError::PyToRustValueConversionError(format!( @@ -574,7 +554,8 @@ pub fn py_sequence_into_postgres_array( /// May return Err Result if cannot convert element into Rust one. pub fn py_sequence_into_flat_vec( parameter: &Bound, -) -> RustPSQLDriverPyResult> { + type_: &Type, +) -> PSQLPyResult> { let py_seq = parameter.downcast::().map_err(|_| { RustPSQLDriverError::PyToRustValueConversionError( "PostgreSQL ARRAY type can be made only from python Sequence".into(), @@ -589,17 +570,17 @@ pub fn py_sequence_into_flat_vec( // Check for the string because it's sequence too, // and in the most cases it should be array type, not new dimension. if ok_seq_elem.is_instance_of::() { - final_vec.push(py_to_rust(&ok_seq_elem)?); + final_vec.push(from_python_typed(&ok_seq_elem, type_)?); continue; } let possible_next_seq = ok_seq_elem.downcast::(); if let Ok(next_seq) = possible_next_seq { - let mut next_vec = py_sequence_into_flat_vec(next_seq)?; + let mut next_vec = py_sequence_into_flat_vec(next_seq, type_)?; final_vec.append(&mut next_vec); } else { - final_vec.push(py_to_rust(&ok_seq_elem)?); + final_vec.push(from_python_typed(&ok_seq_elem, type_)?); continue; } } @@ -607,123 +588,6 @@ pub fn py_sequence_into_flat_vec( Ok(final_vec) } -/// Convert parameters come from python. -/// -/// Parameters for `execute()` method can be either -/// a list or a tuple or a set. -/// -/// We parse every parameter from python object and return -/// Vector of out `PythonDTO`. -/// -/// # Errors -/// -/// May return Err Result if can't convert python object. -#[allow(clippy::needless_pass_by_value)] -pub fn convert_parameters_and_qs( - querystring: String, - parameters: Option>, -) -> RustPSQLDriverPyResult<(String, Vec)> { - let Some(parameters) = parameters else { - return Ok((querystring, vec![])); - }; - - let res = Python::with_gil(|gil| { - let params = parameters.extract::>>(gil).map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - "Cannot convert you parameters argument into Rust type, please use List/Tuple" - .into(), - ) - }); - if let Ok(params) = params { - return Ok((querystring, convert_seq_parameters(params)?)); - } - - let kw_params = parameters.downcast_bound::(gil); - if let Ok(kw_params) = kw_params { - return convert_kwargs_parameters(kw_params, &querystring); - } - - Err(RustPSQLDriverError::PyToRustValueConversionError( - "Parameters must be sequence or mapping".into(), - )) - })?; - - Ok(res) -} - -pub fn convert_kwargs_parameters<'a>( - kw_params: &Bound<'_, PyMapping>, - querystring: &'a str, -) -> RustPSQLDriverPyResult<(String, Vec)> { - let mut result_vec: Vec = vec![]; - let (changed_string, params_names) = parse_kwargs_qs(querystring); - - for param_name in params_names { - match kw_params.get_item(¶m_name) { - Ok(param) => result_vec.push(py_to_rust(¶m)?), - Err(_) => { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - format!("Cannot find parameter with name <{param_name}> in parameters").into(), - )) - } - } - } - - Ok((changed_string, result_vec)) -} - -pub fn convert_seq_parameters( - seq_params: Vec>, -) -> RustPSQLDriverPyResult> { - let mut result_vec: Vec = vec![]; - Python::with_gil(|gil| { - for parameter in seq_params { - result_vec.push(py_to_rust(parameter.bind(gil))?); - } - Ok::<(), RustPSQLDriverError>(()) - })?; - - Ok(result_vec) -} - -/// Convert python List of Dict type or just Dict into serde `Value`. -/// -/// # Errors -/// May return error if cannot convert Python type into Rust one. -#[allow(clippy::needless_pass_by_value)] -pub fn build_serde_value(value: Py) -> RustPSQLDriverPyResult { - Python::with_gil(|gil| { - let bind_value = value.bind(gil); - if bind_value.is_instance_of::() { - let mut result_vec: Vec = vec![]; - - let params = bind_value.extract::>>()?; - - for inner in params { - let inner_bind = inner.bind(gil); - if inner_bind.is_instance_of::() { - let python_dto = py_to_rust(inner_bind)?; - result_vec.push(python_dto.to_serde_value()?); - } else if inner_bind.is_instance_of::() { - let serde_value = build_serde_value(inner)?; - result_vec.push(serde_value); - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must have dicts.".to_string(), - )); - } - } - Ok(json!(result_vec)) - } else if bind_value.is_instance_of::() { - return py_to_rust(bind_value)?.to_serde_value(); - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must be dict value.".to_string(), - )); - } - }) -} - /// Convert two python parameters(x and y) to Coord from `geo_type`. /// Also it checks that passed values is int or float. /// @@ -731,7 +595,7 @@ pub fn build_serde_value(value: Py) -> RustPSQLDriverPyResult { /// /// May return error if cannot convert Python type into Rust one. /// May return error if parameters type isn't correct. -fn convert_py_to_rust_coord_values(parameters: Vec>) -> RustPSQLDriverPyResult> { +fn convert_py_to_rust_coord_values(parameters: Vec>) -> PSQLPyResult> { Python::with_gil(|gil| { let mut coord_values_vec: Vec = vec![]; @@ -746,7 +610,7 @@ fn convert_py_to_rust_coord_values(parameters: Vec>) -> RustPSQLDriver )); } - let python_dto = py_to_rust(parameter_bind)?; + let python_dto = from_python_untyped(parameter_bind)?; match python_dto { PythonDTO::PyIntI16(pyint) => coord_values_vec.push(f64::from(pyint)), PythonDTO::PyIntI32(pyint) => coord_values_vec.push(f64::from(pyint)), @@ -785,7 +649,7 @@ fn convert_py_to_rust_coord_values(parameters: Vec>) -> RustPSQLDriver pub fn build_geo_coords( py_parameters: Py, allowed_length_option: Option, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let mut result_vec: Vec = vec![]; result_vec = Python::with_gil(|gil| { @@ -859,7 +723,7 @@ pub fn build_geo_coords( pub fn build_flat_geo_coords( py_parameters: Py, allowed_length_option: Option, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { Python::with_gil(|gil| { let allowed_length = allowed_length_option.unwrap_or_default(); @@ -893,7 +757,7 @@ pub fn build_flat_geo_coords( /// /// May return error if cannot convert Python type into Rust one. /// May return error if parameters type isn't correct. -fn py_sequence_to_rust(bind_parameters: &Bound) -> RustPSQLDriverPyResult>> { +fn py_sequence_to_rust(bind_parameters: &Bound) -> PSQLPyResult>> { let mut coord_values_sequence_vec: Vec> = vec![]; if bind_parameters.is_instance_of::() { @@ -923,35 +787,3 @@ fn py_sequence_to_rust(bind_parameters: &Bound) -> RustPSQLDriverPyResult Ok::>, RustPSQLDriverError>(coord_values_sequence_vec) } - -fn parse_kwargs_qs(querystring: &str) -> (String, Vec) { - let re = regex::Regex::new(r"\$\(([^)]+)\)p").unwrap(); - - { - let kq_read = KWARGS_QUERYSTRINGS.read().unwrap(); - let qs = kq_read.get(querystring); - - if let Some(qs) = qs { - return qs.clone(); - } - }; - - let mut counter = 0; - let mut sequence = Vec::new(); - - let result = re.replace_all(querystring, |caps: ®ex::Captures| { - let account_id = caps[1].to_string(); - - sequence.push(account_id.clone()); - counter += 1; - - format!("${}", &counter) - }); - - let mut kq_write = KWARGS_QUERYSTRINGS.write().unwrap(); - kq_write.insert( - querystring.to_string(), - (result.clone().into(), sequence.clone()), - ); - (result.into(), sequence) -} diff --git a/src/value_converter/funcs/mod.rs b/src/value_converter/funcs/mod.rs deleted file mode 100644 index 4db4cd38..00000000 --- a/src/value_converter/funcs/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod from_python; -pub mod to_python; diff --git a/src/value_converter/mod.rs b/src/value_converter/mod.rs index e8cbc82b..41c42284 100644 --- a/src/value_converter/mod.rs +++ b/src/value_converter/mod.rs @@ -1,5 +1,8 @@ pub mod additional_types; pub mod consts; -pub mod funcs; +pub mod dto; +pub mod from_python; pub mod models; +pub mod to_python; +pub mod traits; pub mod utils; diff --git a/src/value_converter/models/decimal.rs b/src/value_converter/models/decimal.rs index 13d009cc..44a898a1 100644 --- a/src/value_converter/models/decimal.rs +++ b/src/value_converter/models/decimal.rs @@ -1,5 +1,5 @@ use postgres_types::{FromSql, Type}; -use pyo3::{types::PyAnyMethods, PyObject, Python, ToPyObject}; +use pyo3::{types::PyAnyMethods, Bound, IntoPyObject, PyAny, PyObject, Python, ToPyObject}; use rust_decimal::Decimal; use crate::value_converter::consts::get_decimal_cls; diff --git a/src/value_converter/models/mod.rs b/src/value_converter/models/mod.rs index 92d26e49..b36f3bff 100644 --- a/src/value_converter/models/mod.rs +++ b/src/value_converter/models/mod.rs @@ -1,5 +1,4 @@ pub mod decimal; -pub mod dto; pub mod interval; pub mod serde_value; pub mod uuid; diff --git a/src/value_converter/models/serde_value.rs b/src/value_converter/models/serde_value.rs index b39f7737..392e3fd0 100644 --- a/src/value_converter/models/serde_value.rs +++ b/src/value_converter/models/serde_value.rs @@ -1,16 +1,19 @@ +use postgres_array::{Array, Dimension}; use postgres_types::FromSql; -use serde_json::{json, Value}; -use uuid::Uuid; +use serde_json::{json, Map, Value}; use pyo3::{ - types::{PyAnyMethods, PyDict, PyList}, + types::{PyAnyMethods, PyDict, PyDictMethods, PyList, PyTuple}, Bound, FromPyObject, Py, PyAny, PyObject, PyResult, Python, ToPyObject, }; use tokio_postgres::types::Type; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, - value_converter::funcs::{from_python::py_to_rust, to_python::build_python_from_serde_value}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, + value_converter::{ + dto::enums::PythonDTO, from_python::from_python_untyped, + to_python::build_python_from_serde_value, + }, }; /// Struct for Value. @@ -22,7 +25,7 @@ pub struct InternalSerdeValue(Value); impl<'a> FromPyObject<'a> for InternalSerdeValue { fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { - let serde_value = build_serde_value(ob.clone().unbind())?; + let serde_value = build_serde_value(ob)?; Ok(InternalSerdeValue(serde_value)) } @@ -50,36 +53,64 @@ impl<'a> FromSql<'a> for InternalSerdeValue { } } +fn serde_value_from_list(gil: Python<'_>, bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { + let mut result_vec: Vec = vec![]; + + let params = bind_value.extract::>>()?; + + for inner in params { + let inner_bind = inner.bind(gil); + if inner_bind.is_instance_of::() { + let python_dto = from_python_untyped(inner_bind)?; + result_vec.push(python_dto.to_serde_value()?); + } else if inner_bind.is_instance_of::() { + let serde_value = build_serde_value(inner.bind(gil))?; + result_vec.push(serde_value); + } else { + return Err(RustPSQLDriverError::PyToRustValueConversionError( + "PyJSON must have dicts.".to_string(), + )); + } + } + Ok(json!(result_vec)) +} + +fn serde_value_from_dict(bind_value: &Bound<'_, PyAny>) -> PSQLPyResult { + let dict = bind_value.downcast::().map_err(|error| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "Can't cast to inner dict: {error}" + )) + })?; + + let mut serde_map: Map = Map::new(); + + for dict_item in dict.items() { + let py_list = dict_item.downcast::().map_err(|error| { + RustPSQLDriverError::PyToRustValueConversionError(format!( + "Cannot cast to list: {error}" + )) + })?; + + let key = py_list.get_item(0)?.extract::()?; + let value = from_python_untyped(&py_list.get_item(1)?)?; + + serde_map.insert(key, value.to_serde_value()?); + } + + return Ok(Value::Object(serde_map)); +} + /// Convert python List of Dict type or just Dict into serde `Value`. /// /// # Errors /// May return error if cannot convert Python type into Rust one. #[allow(clippy::needless_pass_by_value)] -pub fn build_serde_value(value: Py) -> RustPSQLDriverPyResult { +pub fn build_serde_value(value: &Bound<'_, PyAny>) -> PSQLPyResult { Python::with_gil(|gil| { - let bind_value = value.bind(gil); - if bind_value.is_instance_of::() { - let mut result_vec: Vec = vec![]; - - let params = bind_value.extract::>>()?; - - for inner in params { - let inner_bind = inner.bind(gil); - if inner_bind.is_instance_of::() { - let python_dto = py_to_rust(inner_bind)?; - result_vec.push(python_dto.to_serde_value()?); - } else if inner_bind.is_instance_of::() { - let serde_value = build_serde_value(inner)?; - result_vec.push(serde_value); - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError( - "PyJSON must have dicts.".to_string(), - )); - } - } - Ok(json!(result_vec)) - } else if bind_value.is_instance_of::() { - return py_to_rust(bind_value)?.to_serde_value(); + if value.is_instance_of::() { + return serde_value_from_list(gil, value); + } else if value.is_instance_of::() { + return serde_value_from_dict(value); } else { return Err(RustPSQLDriverError::PyToRustValueConversionError( "PyJSON must be dict value.".to_string(), @@ -87,3 +118,61 @@ pub fn build_serde_value(value: Py) -> RustPSQLDriverPyResult { } }) } + +/// Convert Array of `PythonDTO`s to serde `Value`. +/// +/// It can convert multidimensional arrays. +pub fn pythondto_array_to_serde(array: Option>) -> PSQLPyResult { + match array { + Some(array) => inner_pythondto_array_to_serde( + array.dimensions(), + array.iter().collect::>().as_slice(), + 0, + 0, + ), + None => Ok(Value::Null), + } +} + +/// Inner conversion array of `PythonDTO`s to serde `Value`. +#[allow(clippy::cast_sign_loss)] +fn inner_pythondto_array_to_serde( + dimensions: &[Dimension], + data: &[&PythonDTO], + dimension_index: usize, + mut lower_bound: usize, +) -> PSQLPyResult { + let current_dimension = dimensions.get(dimension_index); + + if let Some(current_dimension) = current_dimension { + let possible_next_dimension = dimensions.get(dimension_index + 1); + match possible_next_dimension { + Some(next_dimension) => { + let mut final_list: Value = Value::Array(vec![]); + + for _ in 0..current_dimension.len as usize { + if dimensions.get(dimension_index + 1).is_some() { + let inner_pylist = inner_pythondto_array_to_serde( + dimensions, + &data[lower_bound..next_dimension.len as usize + lower_bound], + dimension_index + 1, + 0, + )?; + match final_list { + Value::Array(ref mut array) => array.push(inner_pylist), + _ => unreachable!(), + } + lower_bound += next_dimension.len as usize; + }; + } + + return Ok(final_list); + } + None => { + return data.iter().map(|x| x.to_serde_value()).collect(); + } + } + } + + Ok(Value::Array(vec![])) +} diff --git a/src/value_converter/funcs/to_python.rs b/src/value_converter/to_python.rs similarity index 88% rename from src/value_converter/funcs/to_python.rs rename to src/value_converter/to_python.rs index e65a0085..c0801bac 100644 --- a/src/value_converter/funcs/to_python.rs +++ b/src/value_converter/to_python.rs @@ -17,13 +17,12 @@ use pyo3::{ }; use crate::{ - exceptions::rust_errors::{RustPSQLDriverError, RustPSQLDriverPyResult}, + exceptions::rust_errors::{PSQLPyResult, RustPSQLDriverError}, value_converter::{ additional_types::{ Circle, Line, RustLineSegment, RustLineString, RustMacAddr6, RustMacAddr8, RustPoint, RustRect, }, - consts::KWARGS_QUERYSTRINGS, models::{ decimal::InnerDecimal, interval::InnerInterval, serde_value::InternalSerdeValue, uuid::InternalUuid, @@ -35,10 +34,7 @@ use pgvector::Vector as PgVector; /// Convert serde `Value` into Python object. /// # Errors /// May return Err Result if cannot add new value to Python Dict. -pub fn build_python_from_serde_value( - py: Python<'_>, - value: Value, -) -> RustPSQLDriverPyResult> { +pub fn build_python_from_serde_value(py: Python<'_>, value: Value) -> PSQLPyResult> { match value { Value::Array(massive) => { let mut result_vec: Vec> = vec![]; @@ -76,43 +72,11 @@ pub fn build_python_from_serde_value( } } -fn parse_kwargs_qs(querystring: &str) -> (String, Vec) { - let re = regex::Regex::new(r"\$\(([^)]+)\)p").unwrap(); - - { - let kq_read = KWARGS_QUERYSTRINGS.read().unwrap(); - let qs = kq_read.get(querystring); - - if let Some(qs) = qs { - return qs.clone(); - } - }; - - let mut counter = 0; - let mut sequence = Vec::new(); - - let result = re.replace_all(querystring, |caps: ®ex::Captures| { - let account_id = caps[1].to_string(); - - sequence.push(account_id.clone()); - counter += 1; - - format!("${}", &counter) - }); - - let mut kq_write = KWARGS_QUERYSTRINGS.write().unwrap(); - kq_write.insert( - querystring.to_string(), - (result.clone().into(), sequence.clone()), - ); - (result.into(), sequence) -} - fn composite_field_postgres_to_py<'a, T: FromSql<'a>>( type_: &Type, buf: &mut &'a [u8], is_simple: bool, -) -> RustPSQLDriverPyResult { +) -> PSQLPyResult { if is_simple { return T::from_sql_nullable(type_, Some(buf)).map_err(|err| { RustPSQLDriverError::RustToPyValueConversionError(format!( @@ -196,7 +160,7 @@ fn postgres_bytes_to_py( type_: &Type, buf: &mut &[u8], is_simple: bool, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { match *type_ { // ---------- Bytes Types ---------- // Convert BYTEA type into Vector, then into PyBytes @@ -524,7 +488,7 @@ pub fn other_postgres_bytes_to_py( type_: &Type, buf: &mut &[u8], is_simple: bool, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { if type_.name() == "vector" { let vector = composite_field_postgres_to_py::>(type_, buf, is_simple)?; match vector { @@ -550,7 +514,7 @@ pub fn composite_postgres_to_py( fields: &Vec, buf: &mut &[u8], custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let result_py_dict: Bound<'_, PyDict> = PyDict::new_bound(py); let num_fields = postgres_types::private::read_be_i32(buf).map_err(|err| { @@ -619,7 +583,7 @@ pub fn raw_bytes_data_process( column_name: &str, column_type: &Type, custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { if let Some(custom_decoders) = custom_decoders { let py_encoder_func = custom_decoders .bind(py) @@ -658,7 +622,7 @@ pub fn postgres_to_py( column: &Column, column_i: usize, custom_decoders: &Option>, -) -> RustPSQLDriverPyResult> { +) -> PSQLPyResult> { let raw_bytes_data = row.col_buffer(column_i); if let Some(mut raw_bytes_data) = raw_bytes_data { return raw_bytes_data_process( @@ -671,41 +635,3 @@ pub fn postgres_to_py( } Ok(py.None()) } - -/// Convert Python sequence to Rust vector. -/// Also it checks that sequence has set/list/tuple type. -/// -/// # Errors -/// -/// May return error if cannot convert Python type into Rust one. -/// May return error if parameters type isn't correct. -fn py_sequence_to_rust(bind_parameters: &Bound) -> RustPSQLDriverPyResult>> { - let mut coord_values_sequence_vec: Vec> = vec![]; - - if bind_parameters.is_instance_of::() { - let bind_pyset_parameters = bind_parameters.downcast::().unwrap(); - - for one_parameter in bind_pyset_parameters { - let extracted_parameter = one_parameter.extract::>().map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") - ) - })?; - coord_values_sequence_vec.push(extracted_parameter); - } - } else if bind_parameters.is_instance_of::() - | bind_parameters.is_instance_of::() - { - coord_values_sequence_vec = bind_parameters.extract::>>().map_err(|_| { - RustPSQLDriverError::PyToRustValueConversionError( - format!("Error on sequence type extraction, please use correct list/tuple/set, {bind_parameters}") - ) - })?; - } else { - return Err(RustPSQLDriverError::PyToRustValueConversionError(format!( - "Invalid sequence type, please use list/tuple/set, {bind_parameters}" - ))); - }; - - Ok::>, RustPSQLDriverError>(coord_values_sequence_vec) -} diff --git a/src/value_converter/traits.rs b/src/value_converter/traits.rs new file mode 100644 index 00000000..d9d3512e --- /dev/null +++ b/src/value_converter/traits.rs @@ -0,0 +1,17 @@ +use postgres_types::Type; +use pyo3::PyAny; + +use crate::exceptions::rust_errors::PSQLPyResult; + +use super::dto::enums::PythonDTO; + +pub trait ToPythonDTO { + fn to_python_dto(python_param: &pyo3::Bound<'_, PyAny>) -> PSQLPyResult; +} + +pub trait ToPythonDTOArray { + fn to_python_dto( + python_param: &pyo3::Bound<'_, PyAny>, + array_type_: Type, + ) -> PSQLPyResult; +}