Skip to content

Commit 1839c7f

Browse files
committed
Pass connection info instead of docker container in conftest
easier to also replace with an address of your own
1 parent 87b992b commit 1839c7f

File tree

3 files changed

+49
-26
lines changed

3 files changed

+49
-26
lines changed

graphdatascience/tests/integrationV2/procedure_surface/arrow/conftest.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from graphdatascience import QueryRunner
1010
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
11+
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1112
from graphdatascience.tests.integrationV2.procedure_surface.conftest import (
13+
GdsSessionConnectionInfo,
1214
create_arrow_client,
1315
create_db_query_runner,
1416
start_database,
@@ -19,22 +21,22 @@
1921

2022

2123
@pytest.fixture(scope="package")
22-
def session_container(
24+
def session_connection(
2325
network: Network, password_dir: Path, logs_dir: Path, inside_ci: bool
24-
) -> Generator[DockerContainer, None, None]:
26+
) -> Generator[GdsSessionConnectionInfo, None, None]:
2527
yield from start_session(inside_ci, logs_dir, network, password_dir)
2628

2729

2830
@pytest.fixture(scope="package")
29-
def arrow_client(session_container: DockerContainer) -> AuthenticatedArrowClient:
30-
return create_arrow_client(session_container)
31+
def arrow_client(session_connection: DockerContainer) -> AuthenticatedArrowClient:
32+
return create_arrow_client(session_connection)
3133

3234

3335
@pytest.fixture(scope="package")
34-
def neo4j_container(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DockerContainer, None, None]:
36+
def neo4j_connection(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DbmsConnectionInfo, None, None]:
3537
yield from start_database(inside_ci, logs_dir, network)
3638

3739

3840
@pytest.fixture(scope="package")
39-
def query_runner(neo4j_container: DockerContainer) -> Generator[QueryRunner, None, None]:
40-
yield from create_db_query_runner(neo4j_container)
41+
def query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[QueryRunner, None, None]:
42+
yield from create_db_query_runner(neo4j_connection)

graphdatascience/tests/integrationV2/procedure_surface/conftest.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import logging
22
import os
3+
from dataclasses import dataclass
34
from datetime import datetime
45
from pathlib import Path
56
from typing import Generator
@@ -14,10 +15,18 @@
1415
from graphdatascience.arrow_client.arrow_info import ArrowInfo
1516
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
1617
from graphdatascience.query_runner.neo4j_query_runner import Neo4jQueryRunner
18+
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1719

1820
LOGGER = logging.getLogger(__name__)
1921

2022

23+
@dataclass
24+
class GdsSessionConnectionInfo:
25+
host: str
26+
arrow_port: int
27+
bolt_port: int
28+
29+
2130
@pytest.fixture(scope="package")
2231
def password_dir(tmp_path_factory: pytest.TempPathFactory) -> Generator[Path, None, None]:
2332
"""Create a temporary file and return its path."""
@@ -45,9 +54,12 @@ def latest_neo4j_version() -> str:
4554
return previous_month.strftime("%Y.%m.0")
4655

4756

48-
def start_session(inside_ci: bool, logs_dir: Path, network: Network, password_dir: Path) -> Generator[str, None, None]:
49-
if session_uri := os.environ.get("GDS_SESSION_URI"):
50-
yield session_uri
57+
def start_session(
58+
inside_ci: bool, logs_dir: Path, network: Network, password_dir: Path
59+
) -> Generator[GdsSessionConnectionInfo, None, None]:
60+
if (session_uri := os.environ.get("GDS_SESSION_URI")) is not None:
61+
uri_parts = session_uri.split(":")
62+
yield GdsSessionConnectionInfo(host=uri_parts[0], arrow_port=8491, bolt_port=int(uri_parts[1]))
5163
return
5264

5365
session_image = os.getenv(
@@ -68,7 +80,11 @@ def start_session(inside_ci: bool, logs_dir: Path, network: Network, password_di
6880
session_container = session_container.with_network(network).with_network_aliases("gds-session")
6981
with session_container as session_container:
7082
wait_for_logs(session_container, "Running GDS tasks: 0")
71-
yield f"{session_container.get_container_host_ip()}:{session_container.get_exposed_port(8491)}"
83+
yield GdsSessionConnectionInfo(
84+
host=session_container.get_container_host_ip(),
85+
arrow_port=session_container.get_exposed_port(8491),
86+
bolt_port=-1, # not used in tests
87+
)
7288
stdout, stderr = session_container.get_logs()
7389

7490
if stderr:
@@ -82,18 +98,18 @@ def start_session(inside_ci: bool, logs_dir: Path, network: Network, password_di
8298
f.write(stdout.decode("utf-8"))
8399

84100

85-
def create_arrow_client(session_url: str) -> AuthenticatedArrowClient:
101+
def create_arrow_client(session_uri: GdsSessionConnectionInfo) -> AuthenticatedArrowClient:
86102
"""Create an authenticated Arrow client connected to the session container."""
87103

88104
return AuthenticatedArrowClient.create(
89-
arrow_info=ArrowInfo(session_url, True, True, ["v1", "v2"]),
105+
arrow_info=ArrowInfo(f"{session_uri.host}:{session_uri.arrow_port}", True, True, ["v1", "v2"]),
90106
auth=UsernamePasswordAuthentication("neo4j", "password"),
91107
encrypted=False,
92108
advertised_listen_address=("gds-session", 8491),
93109
)
94110

95111

96-
def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generator[DockerContainer, None, None]:
112+
def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generator[DbmsConnectionInfo, None, None]:
97113
default_neo4j_image = (
98114
f"europe-west1-docker.pkg.dev/neo4j-aura-image-artifacts/aura/neo4j-enterprise:{latest_neo4j_version()}"
99115
)
@@ -116,7 +132,11 @@ def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generat
116132
)
117133
with db_container as db_container:
118134
wait_for_logs(db_container, "Started.")
119-
yield db_container
135+
yield DbmsConnectionInfo(
136+
uri=f"{db_container.get_container_host_ip()}:{db_container.get_exposed_port(7687)}",
137+
username="neo4j",
138+
password="password",
139+
)
120140
stdout, stderr = db_container.get_logs()
121141

122142
if stderr:
@@ -130,11 +150,9 @@ def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generat
130150
f.write(stdout.decode("utf-8"))
131151

132152

133-
def create_db_query_runner(neo4j_container: DockerContainer) -> Generator[Neo4jQueryRunner, None, None]:
134-
host = neo4j_container.get_container_host_ip()
135-
port = 7687
153+
def create_db_query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[Neo4jQueryRunner, None, None]:
136154
query_runner = Neo4jQueryRunner.create_for_db(
137-
f"bolt://{host}:{port}",
155+
f"bolt://{neo4j_connection.uri}",
138156
("neo4j", "password"),
139157
)
140158
yield query_runner

graphdatascience/tests/integrationV2/procedure_surface/session/conftest.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
from typing import Generator
33

44
import pytest
5-
from testcontainers.core.container import DockerContainer
65
from testcontainers.core.network import Network
76

87
from graphdatascience import QueryRunner
98
from graphdatascience.arrow_client.authenticated_flight_client import AuthenticatedArrowClient
9+
from graphdatascience.session.dbms_connection_info import DbmsConnectionInfo
1010
from graphdatascience.tests.integrationV2.procedure_surface.conftest import (
11+
GdsSessionConnectionInfo,
1112
create_arrow_client,
1213
create_db_query_runner,
1314
start_database,
@@ -16,20 +17,22 @@
1617

1718

1819
@pytest.fixture(scope="package")
19-
def session_url(network: Network, password_dir: Path, logs_dir: Path, inside_ci: bool) -> Generator[str, None, None]:
20+
def session_connection(
21+
network: Network, password_dir: Path, logs_dir: Path, inside_ci: bool
22+
) -> Generator[GdsSessionConnectionInfo, None, None]:
2023
yield from start_session(inside_ci, logs_dir, network, password_dir)
2124

2225

2326
@pytest.fixture(scope="package")
24-
def arrow_client(session_url: str) -> AuthenticatedArrowClient:
25-
return create_arrow_client(session_url)
27+
def arrow_client(session_connection: GdsSessionConnectionInfo) -> AuthenticatedArrowClient:
28+
return create_arrow_client(session_connection)
2629

2730

2831
@pytest.fixture(scope="package")
29-
def neo4j_container(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DockerContainer, None, None]:
32+
def neo4j_connection(network: Network, logs_dir: Path, inside_ci: bool) -> Generator[DbmsConnectionInfo, None, None]:
3033
yield from start_database(inside_ci, logs_dir, network)
3134

3235

3336
@pytest.fixture(scope="package")
34-
def db_query_runner(neo4j_container: DockerContainer) -> Generator[QueryRunner, None, None]:
35-
yield from create_db_query_runner(neo4j_container)
37+
def db_query_runner(neo4j_connection: DbmsConnectionInfo) -> Generator[QueryRunner, None, None]:
38+
yield from create_db_query_runner(neo4j_connection)

0 commit comments

Comments
 (0)