1- from dataclasses import asdict , dataclass , field , fields
1+ from dataclasses import asdict , dataclass , field , fields , is_dataclass
22from functools import cached_property
33from json import loads
44from logging import warning
77from re import split
88from subprocess import CompletedProcess
99from subprocess import run as subprocess_run
10+ from types import TracebackType
1011from typing import Any , Callable , Literal , Optional , TypeVar , Union , cast
1112from urllib .error import HTTPError , URLError
1213from urllib .request import urlopen
1819_WARNINGS = {"DOCKER_COMPOSE_GET_CONFIG" : "get_config is experimental, see testcontainers/testcontainers-python#669" }
1920
2021
21- def _ignore_properties (cls : type [_IPT ], dict_ : any ) -> _IPT :
22+ def _ignore_properties (cls : type [_IPT ], dict_ : Any ) -> _IPT :
2223 """omits extra fields like @JsonIgnoreProperties(ignoreUnknown = true)
2324
2425 https://gist.github.com/alexanderankin/2a4549ac03554a31bef6eaaf2eaf7fd5"""
2526 if isinstance (dict_ , cls ):
2627 return dict_
28+ if not is_dataclass (cls ):
29+ raise TypeError (f"Expected a dataclass type, got { cls } " )
2730 class_fields = {f .name for f in fields (cls )}
2831 filtered = {k : v for k , v in dict_ .items () if k in class_fields }
29- return cls (** filtered )
32+ return cast ( "_IPT" , cls (** filtered ) )
3033
3134
3235@dataclass
33- class PublishedPort :
36+ class PublishedPortModel :
3437 """
3538 Class that represents the response we get from compose when inquiring status
3639 via `DockerCompose.get_running_containers()`.
3740 """
3841
3942 URL : Optional [str ] = None
40- TargetPort : Optional [str ] = None
41- PublishedPort : Optional [str ] = None
43+ TargetPort : Optional [int ] = None
44+ PublishedPort : Optional [int ] = None
4245 Protocol : Optional [str ] = None
4346
44- def normalize (self ):
47+ def normalize (self ) -> "PublishedPortModel" :
4548 url_not_usable = system () == "Windows" and self .URL == "0.0.0.0"
4649 if url_not_usable :
4750 self_dict = asdict (self )
4851 self_dict .update ({"URL" : "127.0.0.1" })
49- return PublishedPort (** self_dict )
52+ return PublishedPortModel (** self_dict )
5053 return self
5154
5255
@@ -75,19 +78,19 @@ class ComposeContainer:
7578 Service : Optional [str ] = None
7679 State : Optional [str ] = None
7780 Health : Optional [str ] = None
78- ExitCode : Optional [str ] = None
79- Publishers : list [PublishedPort ] = field (default_factory = list )
81+ ExitCode : Optional [int ] = None
82+ Publishers : list [PublishedPortModel ] = field (default_factory = list )
8083
81- def __post_init__ (self ):
84+ def __post_init__ (self ) -> None :
8285 if self .Publishers :
83- self .Publishers = [_ignore_properties (PublishedPort , p ) for p in self .Publishers ]
86+ self .Publishers = [_ignore_properties (PublishedPortModel , p ) for p in self .Publishers ]
8487
8588 def get_publisher (
8689 self ,
8790 by_port : Optional [int ] = None ,
8891 by_host : Optional [str ] = None ,
89- prefer_ip_version : Literal ["IPV4 " , "IPv6" ] = "IPv4" ,
90- ) -> PublishedPort :
92+ prefer_ip_version : Literal ["IPv4 " , "IPv6" ] = "IPv4" ,
93+ ) -> PublishedPortModel :
9194 remaining_publishers = self .Publishers
9295
9396 remaining_publishers = [r for r in remaining_publishers if self ._matches_protocol (prefer_ip_version , r )]
@@ -109,8 +112,9 @@ def get_publisher(
109112 )
110113
111114 @staticmethod
112- def _matches_protocol (prefer_ip_version , r ):
113- return (":" in r .URL ) is (prefer_ip_version == "IPv6" )
115+ def _matches_protocol (prefer_ip_version : str , r : PublishedPortModel ) -> bool :
116+ r_url = r .URL
117+ return (r_url is not None and ":" in r_url ) is (prefer_ip_version == "IPv6" )
114118
115119
116120@dataclass
@@ -164,7 +168,7 @@ class DockerCompose:
164168 image: "hello-world"
165169 """
166170
167- context : Union [str , PathLike ]
171+ context : Union [str , PathLike [ str ] ]
168172 compose_file_name : Optional [Union [str , list [str ]]] = None
169173 pull : bool = False
170174 build : bool = False
@@ -175,15 +179,17 @@ class DockerCompose:
175179 docker_command_path : Optional [str ] = None
176180 profiles : Optional [list [str ]] = None
177181
178- def __post_init__ (self ):
182+ def __post_init__ (self ) -> None :
179183 if isinstance (self .compose_file_name , str ):
180184 self .compose_file_name = [self .compose_file_name ]
181185
182186 def __enter__ (self ) -> "DockerCompose" :
183187 self .start ()
184188 return self
185189
186- def __exit__ (self , exc_type , exc_val , exc_tb ) -> None :
190+ def __exit__ (
191+ self , exc_type : Optional [type [BaseException ]], exc_val : Optional [BaseException ], exc_tb : Optional [TracebackType ]
192+ ) -> None :
187193 self .stop (not self .keep_volumes )
188194
189195 def docker_compose_command (self ) -> list [str ]:
@@ -235,7 +241,7 @@ def start(self) -> None:
235241
236242 self ._run_command (cmd = up_cmd )
237243
238- def stop (self , down = True ) -> None :
244+ def stop (self , down : bool = True ) -> None :
239245 """
240246 Stops the docker compose environment.
241247 """
@@ -295,7 +301,7 @@ def get_config(
295301 cmd_output = self ._run_command (cmd = config_cmd ).stdout
296302 return cast (dict [str , Any ], loads (cmd_output )) # noqa: TC006
297303
298- def get_containers (self , include_all = False ) -> list [ComposeContainer ]:
304+ def get_containers (self , include_all : bool = False ) -> list [ComposeContainer ]:
299305 """
300306 Fetch information about running containers via `docker compose ps --format json`.
301307 Available only in V2 of compose.
@@ -370,17 +376,18 @@ def exec_in_container(
370376 """
371377 if not service_name :
372378 service_name = self .get_container ().Service
373- exec_cmd = [* self .compose_command_property , "exec" , "-T" , service_name , * command ]
379+ assert service_name
380+ exec_cmd : list [str ] = [* self .compose_command_property , "exec" , "-T" , service_name , * command ]
374381 result = self ._run_command (cmd = exec_cmd )
375382
376- return ( result .stdout .decode ("utf-8" ), result .stderr .decode ("utf-8" ), result .returncode )
383+ return result .stdout .decode ("utf-8" ), result .stderr .decode ("utf-8" ), result .returncode
377384
378385 def _run_command (
379386 self ,
380387 cmd : Union [str , list [str ]],
381388 context : Optional [str ] = None ,
382389 ) -> CompletedProcess [bytes ]:
383- context = context or self .context
390+ context = context or str ( self .context )
384391 return subprocess_run (
385392 cmd ,
386393 capture_output = True ,
@@ -392,7 +399,7 @@ def get_service_port(
392399 self ,
393400 service_name : Optional [str ] = None ,
394401 port : Optional [int ] = None ,
395- ):
402+ ) -> Optional [ int ] :
396403 """
397404 Returns the mapped port for one of the services.
398405
@@ -408,13 +415,14 @@ def get_service_port(
408415 str:
409416 The mapped port on the host
410417 """
411- return self .get_container (service_name ).get_publisher (by_port = port ).normalize ().PublishedPort
418+ normalize : PublishedPortModel = self .get_container (service_name ).get_publisher (by_port = port ).normalize ()
419+ return normalize .PublishedPort
412420
413421 def get_service_host (
414422 self ,
415423 service_name : Optional [str ] = None ,
416424 port : Optional [int ] = None ,
417- ):
425+ ) -> Optional [ str ] :
418426 """
419427 Returns the host for one of the services.
420428
@@ -430,13 +438,17 @@ def get_service_host(
430438 str:
431439 The hostname for the service
432440 """
433- return self .get_container (service_name ).get_publisher (by_port = port ).normalize ().URL
441+ container : ComposeContainer = self .get_container (service_name )
442+ publisher : PublishedPortModel = container .get_publisher (by_port = port )
443+ normalize : PublishedPortModel = publisher .normalize ()
444+ url : Optional [str ] = normalize .URL
445+ return url
434446
435447 def get_service_host_and_port (
436448 self ,
437449 service_name : Optional [str ] = None ,
438450 port : Optional [int ] = None ,
439- ):
451+ ) -> tuple [ Optional [ str ], Optional [ int ]] :
440452 publisher = self .get_container (service_name ).get_publisher (by_port = port ).normalize ()
441453 return publisher .URL , publisher .PublishedPort
442454
0 commit comments