11import logging
22import os
3+ from dataclasses import dataclass
34from datetime import datetime
45from pathlib import Path
56from typing import Generator
1415from graphdatascience .arrow_client .arrow_info import ArrowInfo
1516from graphdatascience .arrow_client .authenticated_flight_client import AuthenticatedArrowClient
1617from graphdatascience .query_runner .neo4j_query_runner import Neo4jQueryRunner
18+ from graphdatascience .session .dbms_connection_info import DbmsConnectionInfo
1719
1820LOGGER = logging .getLogger (__name__ )
1921
2022
23+ @dataclass
24+ class GdsSessionConnectionInfo :
25+ host : str
26+ arrow_port : int
27+ bolt_port : int
28+
29+
2130@pytest .fixture (scope = "package" )
2231def password_dir (tmp_path_factory : pytest .TempPathFactory ) -> Generator [Path , None , None ]:
2332 """Create a temporary file and return its path."""
@@ -45,9 +54,12 @@ def latest_neo4j_version() -> str:
4554 return previous_month .strftime ("%Y.%m.0" )
4655
4756
48- def start_session (inside_ci : bool , logs_dir : Path , network : Network , password_dir : Path ) -> Generator [str , None , None ]:
49- if session_uri := os .environ .get ("GDS_SESSION_URI" ):
50- yield session_uri
57+ def start_session (
58+ inside_ci : bool , logs_dir : Path , network : Network , password_dir : Path
59+ ) -> Generator [GdsSessionConnectionInfo , None , None ]:
60+ if (session_uri := os .environ .get ("GDS_SESSION_URI" )) is not None :
61+ uri_parts = session_uri .split (":" )
62+ yield GdsSessionConnectionInfo (host = uri_parts [0 ], arrow_port = 8491 , bolt_port = int (uri_parts [1 ]))
5163 return
5264
5365 session_image = os .getenv (
@@ -68,7 +80,11 @@ def start_session(inside_ci: bool, logs_dir: Path, network: Network, password_di
6880 session_container = session_container .with_network (network ).with_network_aliases ("gds-session" )
6981 with session_container as session_container :
7082 wait_for_logs (session_container , "Running GDS tasks: 0" )
71- yield f"{ session_container .get_container_host_ip ()} :{ session_container .get_exposed_port (8491 )} "
83+ yield GdsSessionConnectionInfo (
84+ host = session_container .get_container_host_ip (),
85+ arrow_port = session_container .get_exposed_port (8491 ),
86+ bolt_port = - 1 , # not used in tests
87+ )
7288 stdout , stderr = session_container .get_logs ()
7389
7490 if stderr :
@@ -82,18 +98,18 @@ def start_session(inside_ci: bool, logs_dir: Path, network: Network, password_di
8298 f .write (stdout .decode ("utf-8" ))
8399
84100
85- def create_arrow_client (session_url : str ) -> AuthenticatedArrowClient :
101+ def create_arrow_client (session_uri : GdsSessionConnectionInfo ) -> AuthenticatedArrowClient :
86102 """Create an authenticated Arrow client connected to the session container."""
87103
88104 return AuthenticatedArrowClient .create (
89- arrow_info = ArrowInfo (session_url , True , True , ["v1" , "v2" ]),
105+ arrow_info = ArrowInfo (f" { session_uri . host } : { session_uri . arrow_port } " , True , True , ["v1" , "v2" ]),
90106 auth = UsernamePasswordAuthentication ("neo4j" , "password" ),
91107 encrypted = False ,
92108 advertised_listen_address = ("gds-session" , 8491 ),
93109 )
94110
95111
96- def start_database (inside_ci : bool , logs_dir : Path , network : Network ) -> Generator [DockerContainer , None , None ]:
112+ def start_database (inside_ci : bool , logs_dir : Path , network : Network ) -> Generator [DbmsConnectionInfo , None , None ]:
97113 default_neo4j_image = (
98114 f"europe-west1-docker.pkg.dev/neo4j-aura-image-artifacts/aura/neo4j-enterprise:{ latest_neo4j_version ()} "
99115 )
@@ -116,7 +132,11 @@ def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generat
116132 )
117133 with db_container as db_container :
118134 wait_for_logs (db_container , "Started." )
119- yield db_container
135+ yield DbmsConnectionInfo (
136+ uri = f"{ db_container .get_container_host_ip ()} :{ db_container .get_exposed_port (7687 )} " ,
137+ username = "neo4j" ,
138+ password = "password" ,
139+ )
120140 stdout , stderr = db_container .get_logs ()
121141
122142 if stderr :
@@ -130,11 +150,9 @@ def start_database(inside_ci: bool, logs_dir: Path, network: Network) -> Generat
130150 f .write (stdout .decode ("utf-8" ))
131151
132152
133- def create_db_query_runner (neo4j_container : DockerContainer ) -> Generator [Neo4jQueryRunner , None , None ]:
134- host = neo4j_container .get_container_host_ip ()
135- port = 7687
153+ def create_db_query_runner (neo4j_connection : DbmsConnectionInfo ) -> Generator [Neo4jQueryRunner , None , None ]:
136154 query_runner = Neo4jQueryRunner .create_for_db (
137- f"bolt://{ host } : { port } " ,
155+ f"bolt://{ neo4j_connection . uri } " ,
138156 ("neo4j" , "password" ),
139157 )
140158 yield query_runner
0 commit comments