Skip to content

Commit 4518919

Browse files
committed
Convert model classes to use keyword arguments
1 parent c84b047 commit 4518919

File tree

6 files changed

+203
-108
lines changed

6 files changed

+203
-108
lines changed

HISTORY.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ History
1010
* BREAKING: The ``raw`` attribute on the model classes has been replaced
1111
with a ``to_dict()`` method. This can be used to get a representation of
1212
the object that is suitable for serialization.
13-
* BREAKING: The record classes now require all arguments other than ``locales``
14-
to be keyword arguments.
13+
* BREAKING: The model and record classes now require all arguments other than
14+
``locales`` to be keyword arguments.
1515
* BREAKING: ``geoip2.mixins`` has been made internal. This normally would not
1616
have been used by external code.
1717
* IMPORTANT: Python 3.9 or greater is required. If you are using an older

geoip2/database.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,9 @@ def _model_for(
252252
ip_address: IPAddress,
253253
) -> Union[Country, Enterprise, City]:
254254
(record, prefix_len) = self._get(types, ip_address)
255-
traits = record.setdefault("traits", {})
256-
traits["ip_address"] = ip_address
257-
traits["prefix_len"] = prefix_len
258-
return model_class(record, locales=self._locales)
255+
return model_class(
256+
self._locales, ip_address=ip_address, prefix_len=prefix_len, **record
257+
)
259258

260259
def _flat_model_for(
261260
self,
@@ -266,9 +265,7 @@ def _flat_model_for(
266265
ip_address: IPAddress,
267266
) -> Union[ConnectionType, ISP, AnonymousIP, Domain, ASN]:
268267
(record, prefix_len) = self._get(types, ip_address)
269-
record["ip_address"] = ip_address
270-
record["prefix_len"] = prefix_len
271-
return model_class(record)
268+
return model_class(ip_address=ip_address, prefix_len=prefix_len, **record)
272269

273270
def metadata(
274271
self,

geoip2/models.py

Lines changed: 155 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
1212
"""
1313

14-
# pylint: disable=too-many-instance-attributes,too-few-public-methods
14+
# pylint: disable=too-many-instance-attributes,too-few-public-methods,too-many-arguments
1515
import ipaddress
1616
from abc import ABCMeta
17-
from typing import Any, cast, Dict, Optional, Sequence, Union
17+
from typing import Dict, List, Optional, Sequence, Union
1818

1919
import geoip2.records
2020
from geoip2._internal import Model
21+
from geoip2.types import IPAddress
2122

2223

2324
class Country(Model):
@@ -76,30 +77,44 @@ class Country(Model):
7677
traits: geoip2.records.Traits
7778

7879
def __init__(
79-
self, raw_response: Dict[str, Any], locales: Optional[Sequence[str]] = None
80+
self,
81+
locales: Optional[Sequence[str]],
82+
*,
83+
continent: Optional[Dict] = None,
84+
country: Optional[Dict] = None,
85+
ip_address: Optional[IPAddress] = None,
86+
maxmind: Optional[Dict] = None,
87+
prefix_len: Optional[int] = None,
88+
registered_country: Optional[Dict] = None,
89+
represented_country: Optional[Dict] = None,
90+
traits: Optional[Dict] = None,
91+
**_,
8092
) -> None:
81-
if locales is None:
82-
locales = ["en"]
8393
self._locales = locales
84-
self.continent = geoip2.records.Continent(
85-
locales, **raw_response.get("continent", {})
86-
)
87-
self.country = geoip2.records.Country(
88-
locales, **raw_response.get("country", {})
89-
)
94+
self.continent = geoip2.records.Continent(locales, **(continent or {}))
95+
self.country = geoip2.records.Country(locales, **(country or {}))
9096
self.registered_country = geoip2.records.Country(
91-
locales, **raw_response.get("registered_country", {})
97+
locales, **(registered_country or {})
9298
)
9399
self.represented_country = geoip2.records.RepresentedCountry(
94-
locales, **raw_response.get("represented_country", {})
100+
locales, **(represented_country or {})
95101
)
96102

97-
self.maxmind = geoip2.records.MaxMind(**raw_response.get("maxmind", {}))
103+
self.maxmind = geoip2.records.MaxMind(**(maxmind or {}))
104+
105+
traits = traits or {}
106+
if ip_address is not None:
107+
traits["ip_address"] = ip_address
108+
if prefix_len is not None:
109+
traits["prefix_len"] = prefix_len
98110

99-
self.traits = geoip2.records.Traits(**raw_response.get("traits", {}))
111+
self.traits = geoip2.records.Traits(**traits)
100112

101113
def __repr__(self) -> str:
102-
return f"{self.__module__}.{self.__class__.__name__}({self.to_dict()}, {self._locales})"
114+
return (
115+
f"{self.__module__}.{self.__class__.__name__}({self._locales}, "
116+
f"{', '.join(f'{k}={repr(v)}' for k, v in self.to_dict().items())})"
117+
)
103118

104119

105120
class City(Country):
@@ -179,15 +194,38 @@ class City(Country):
179194
subdivisions: geoip2.records.Subdivisions
180195

181196
def __init__(
182-
self, raw_response: Dict[str, Any], locales: Optional[Sequence[str]] = None
197+
self,
198+
locales: Optional[Sequence[str]],
199+
*,
200+
city: Optional[Dict] = None,
201+
continent: Optional[Dict] = None,
202+
country: Optional[Dict] = None,
203+
location: Optional[Dict] = None,
204+
ip_address: Optional[IPAddress] = None,
205+
maxmind: Optional[Dict] = None,
206+
postal: Optional[Dict] = None,
207+
prefix_len: Optional[int] = None,
208+
registered_country: Optional[Dict] = None,
209+
represented_country: Optional[Dict] = None,
210+
subdivisions: Optional[List[Dict]] = None,
211+
traits: Optional[Dict] = None,
212+
**_,
183213
) -> None:
184-
super().__init__(raw_response, locales)
185-
self.city = geoip2.records.City(locales, **raw_response.get("city", {}))
186-
self.location = geoip2.records.Location(**raw_response.get("location", {}))
187-
self.postal = geoip2.records.Postal(**raw_response.get("postal", {}))
188-
self.subdivisions = geoip2.records.Subdivisions(
189-
locales, *raw_response.get("subdivisions", [])
214+
super().__init__(
215+
locales,
216+
continent=continent,
217+
country=country,
218+
ip_address=ip_address,
219+
maxmind=maxmind,
220+
prefix_len=prefix_len,
221+
registered_country=registered_country,
222+
represented_country=represented_country,
223+
traits=traits,
190224
)
225+
self.city = geoip2.records.City(locales, **(city or {}))
226+
self.location = geoip2.records.Location(**(location or {}))
227+
self.postal = geoip2.records.Postal(**(postal or {}))
228+
self.subdivisions = geoip2.records.Subdivisions(locales, *(subdivisions or []))
191229

192230

193231
class Insights(City):
@@ -325,20 +363,28 @@ class SimpleModel(Model, metaclass=ABCMeta):
325363
_network: Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]
326364
_prefix_len: int
327365

328-
def __init__(self, raw: Dict[str, Union[bool, str, int]]) -> None:
329-
if network := raw.get("network"):
366+
def __init__(
367+
self,
368+
ip_address: Optional[str],
369+
network: Optional[str],
370+
prefix_len: Optional[int],
371+
) -> None:
372+
if network:
330373
self._network = ipaddress.ip_network(network, False)
331374
self._prefix_len = self._network.prefixlen
332375
else:
333376
# This case is for MMDB lookups where performance is paramount.
334377
# This is why we don't generate the network unless .network is
335378
# used.
336379
self._network = None
337-
self._prefix_len = cast(int, raw.get("prefix_len"))
338-
self.ip_address = cast(str, raw.get("ip_address"))
380+
self._prefix_len = prefix_len
381+
self.ip_address = ip_address
339382

340383
def __repr__(self) -> str:
341-
return f"{self.__module__}.{self.__class__.__name__}({self.to_dict()})"
384+
return (
385+
f"{self.__module__}.{self.__class__.__name__}"
386+
f"({', '.join(f'{k}={repr(v)}' for k, v in self.to_dict().items())})"
387+
)
342388

343389
@property
344390
def network(self) -> Optional[Union[ipaddress.IPv4Network, ipaddress.IPv6Network]]:
@@ -427,14 +473,27 @@ class AnonymousIP(SimpleModel):
427473
is_residential_proxy: bool
428474
is_tor_exit_node: bool
429475

430-
def __init__(self, raw: Dict[str, bool]) -> None:
431-
super().__init__(raw) # type: ignore
432-
self.is_anonymous = raw.get("is_anonymous", False)
433-
self.is_anonymous_vpn = raw.get("is_anonymous_vpn", False)
434-
self.is_hosting_provider = raw.get("is_hosting_provider", False)
435-
self.is_public_proxy = raw.get("is_public_proxy", False)
436-
self.is_residential_proxy = raw.get("is_residential_proxy", False)
437-
self.is_tor_exit_node = raw.get("is_tor_exit_node", False)
476+
def __init__(
477+
self,
478+
*,
479+
is_anonymous: bool = False,
480+
is_anonymous_vpn: bool = False,
481+
is_hosting_provider: bool = False,
482+
is_public_proxy: bool = False,
483+
is_residential_proxy: bool = False,
484+
is_tor_exit_node: bool = False,
485+
ip_address: Optional[str] = None,
486+
network: Optional[str] = None,
487+
prefix_len: Optional[int] = None,
488+
**_,
489+
) -> None:
490+
super().__init__(ip_address, network, prefix_len)
491+
self.is_anonymous = is_anonymous
492+
self.is_anonymous_vpn = is_anonymous_vpn
493+
self.is_hosting_provider = is_hosting_provider
494+
self.is_public_proxy = is_public_proxy
495+
self.is_residential_proxy = is_residential_proxy
496+
self.is_tor_exit_node = is_tor_exit_node
438497

439498

440499
class ASN(SimpleModel):
@@ -474,14 +533,19 @@ class ASN(SimpleModel):
474533
autonomous_system_organization: Optional[str]
475534

476535
# pylint:disable=too-many-arguments,too-many-positional-arguments
477-
def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
478-
super().__init__(raw)
479-
self.autonomous_system_number = cast(
480-
Optional[int], raw.get("autonomous_system_number")
481-
)
482-
self.autonomous_system_organization = cast(
483-
Optional[str], raw.get("autonomous_system_organization")
484-
)
536+
def __init__(
537+
self,
538+
*,
539+
autonomous_system_number: Optional[int] = None,
540+
autonomous_system_organization: Optional[str] = None,
541+
ip_address: Optional[str] = None,
542+
network: Optional[str] = None,
543+
prefix_len: Optional[int] = None,
544+
**_,
545+
) -> None:
546+
super().__init__(ip_address, network, prefix_len)
547+
self.autonomous_system_number = autonomous_system_number
548+
self.autonomous_system_organization = autonomous_system_organization
485549

486550

487551
class ConnectionType(SimpleModel):
@@ -520,9 +584,17 @@ class ConnectionType(SimpleModel):
520584

521585
connection_type: Optional[str]
522586

523-
def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
524-
super().__init__(raw)
525-
self.connection_type = cast(Optional[str], raw.get("connection_type"))
587+
def __init__(
588+
self,
589+
*,
590+
connection_type: Optional[str] = None,
591+
ip_address: Optional[str] = None,
592+
network: Optional[str] = None,
593+
prefix_len: Optional[int] = None,
594+
**_,
595+
) -> None:
596+
super().__init__(ip_address, network, prefix_len)
597+
self.connection_type = connection_type
526598

527599

528600
class Domain(SimpleModel):
@@ -554,9 +626,17 @@ class Domain(SimpleModel):
554626

555627
domain: Optional[str]
556628

557-
def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
558-
super().__init__(raw)
559-
self.domain = cast(Optional[str], raw.get("domain"))
629+
def __init__(
630+
self,
631+
*,
632+
domain: Optional[str] = None,
633+
ip_address: Optional[str] = None,
634+
network: Optional[str] = None,
635+
prefix_len: Optional[int] = None,
636+
**_,
637+
) -> None:
638+
super().__init__(ip_address, network, prefix_len)
639+
self.domain = domain
560640

561641

562642
class ISP(ASN):
@@ -626,9 +706,28 @@ class ISP(ASN):
626706
organization: Optional[str]
627707

628708
# pylint:disable=too-many-arguments,too-many-positional-arguments
629-
def __init__(self, raw: Dict[str, Union[str, int]]) -> None:
630-
super().__init__(raw)
631-
self.isp = cast(Optional[str], raw.get("isp"))
632-
self.mobile_country_code = cast(Optional[str], raw.get("mobile_country_code"))
633-
self.mobile_network_code = cast(Optional[str], raw.get("mobile_network_code"))
634-
self.organization = cast(Optional[str], raw.get("organization"))
709+
def __init__(
710+
self,
711+
*,
712+
autonomous_system_number: Optional[int] = None,
713+
autonomous_system_organization: Optional[str] = None,
714+
isp: Optional[str] = None,
715+
mobile_country_code: Optional[str] = None,
716+
mobile_network_code: Optional[str] = None,
717+
organization: Optional[str] = None,
718+
ip_address: Optional[str] = None,
719+
network: Optional[str] = None,
720+
prefix_len: Optional[int] = None,
721+
**_,
722+
) -> None:
723+
super().__init__(
724+
autonomous_system_number=autonomous_system_number,
725+
autonomous_system_organization=autonomous_system_organization,
726+
ip_address=ip_address,
727+
network=network,
728+
prefix_len=prefix_len,
729+
)
730+
self.isp = isp
731+
self.mobile_country_code = mobile_country_code
732+
self.mobile_network_code = mobile_network_code
733+
self.organization = organization

geoip2/webservice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ async def _response_for(
350350
if status != 200:
351351
raise self._exception_for_error(status, content_type, body, uri)
352352
decoded_body = self._handle_success(body, uri)
353-
return model_class(decoded_body, locales=self._locales)
353+
return model_class(self._locales, **decoded_body)
354354

355355
async def close(self):
356356
"""Close underlying session
@@ -499,7 +499,7 @@ def _response_for(
499499
if status != 200:
500500
raise self._exception_for_error(status, content_type, body, uri)
501501
decoded_body = self._handle_success(body, uri)
502-
return model_class(decoded_body, locales=self._locales)
502+
return model_class(self._locales, **decoded_body)
503503

504504
def close(self):
505505
"""Close underlying session

tests/database_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_connection_type(self) -> None:
161161

162162
self.assertRegex(
163163
str(record),
164-
r"ConnectionType\(\{.*Cellular.*\}\)",
164+
r"ConnectionType\(.*Cellular.*\)",
165165
"ConnectionType str representation is reasonable",
166166
)
167167

@@ -197,7 +197,7 @@ def test_domain(self) -> None:
197197

198198
self.assertRegex(
199199
str(record),
200-
r"Domain\(\{.*maxmind.com.*\}\)",
200+
r"Domain\(.*maxmind.com.*\)",
201201
"Domain str representation is reasonable",
202202
)
203203

@@ -247,7 +247,7 @@ def test_isp(self) -> None:
247247

248248
self.assertRegex(
249249
str(record),
250-
r"ISP\(\{.*Telstra.*\}\)",
250+
r"ISP\(.*Telstra.*\)",
251251
"ISP str representation is reasonable",
252252
)
253253

0 commit comments

Comments
 (0)