Skip to content

Commit 9083821

Browse files
authored
PYTHON-3454 Specifying a generic type for a collection does not correctly enforce type safety when inserting data (#1081)
1 parent f08776c commit 9083821

File tree

10 files changed

+58
-28
lines changed

10 files changed

+58
-28
lines changed

.github/workflows/test-python.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ jobs:
6767
# Test overshadowed codec_options.py file
6868
mypy --install-types --non-interactive bson/codec_options.py
6969
mypy --install-types --non-interactive --disable-error-code var-annotated --disable-error-code attr-defined --disable-error-code union-attr --disable-error-code assignment --disable-error-code no-redef --disable-error-code index --allow-redefinition --allow-untyped-globals --exclude "test/mypy_fails/*.*" test
70+
python -m pip install -U typing_extensions
71+
mypy --install-types --non-interactive test/test_mypy.py
7072
7173
linkcheck:
7274
name: Check Links

doc/examples/type_hints.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ Note that when using :class:`~bson.son.SON`, the key and value types must be giv
9292
Typed Collection
9393
----------------
9494

95-
You can use :py:class:`~typing.TypedDict` (Python 3.8+) when using a well-defined schema for the data in a :class:`~pymongo.collection.Collection`:
95+
You can use :py:class:`~typing.TypedDict` (Python 3.8+) when using a well-defined schema for the data in a
96+
:class:`~pymongo.collection.Collection`. Note that all `schema validation`_ for inserts and updates is done on the server.
97+
The insert methods automatically add an "_id" field to any document that does not already have one.
9698

9799
.. doctest::
98100

@@ -105,10 +107,12 @@ You can use :py:class:`~typing.TypedDict` (Python 3.8+) when using a well-define
105107
...
106108
>>> client: MongoClient = MongoClient()
107109
>>> collection: Collection[Movie] = client.test.test
108-
>>> inserted = collection.insert_one({"name": "Jurassic Park", "year": 1993 })
110+
>>> inserted = collection.insert_one(Movie(name="Jurassic Park", year=1993))
109111
>>> result = collection.find_one({"name": "Jurassic Park"})
110112
>>> assert result is not None
111113
>>> assert result["year"] == 1993
114+
>>> # The "_id" field is not type checked, even though it is present, because it is added by PyMongo.
115+
>>> assert type(result["_id"]) == ObjectId
112116

113117
Typed Database
114118
--------------
@@ -243,3 +247,4 @@ Another example is trying to set a value on a :class:`~bson.raw_bson.RawBSONDocu
243247
.. _limitations in mypy: https://github.com/python/mypy/issues/3737
244248
.. _mypy config: https://mypy.readthedocs.io/en/stable/config_file.html
245249
.. _test_mypy module: https://github.com/mongodb/mongo-python-driver/blob/master/test/test_mypy.py
250+
.. _schema validation: https://www.mongodb.com/docs/manual/core/schema-validation/#when-to-use-schema-validation

pymongo/client_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def _max_time_expired_error(exc):
435435

436436
# From the transactions spec, all the retryable writes errors plus
437437
# WriteConcernFailed.
438-
_UNKNOWN_COMMIT_ERROR_CODES = _RETRYABLE_ERROR_CODES | frozenset(
438+
_UNKNOWN_COMMIT_ERROR_CODES: frozenset = _RETRYABLE_ERROR_CODES | frozenset(
439439
[
440440
64, # WriteConcernFailed
441441
50, # MaxTimeMSExpired

pymongo/collection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
InsertOneResult,
7272
UpdateResult,
7373
)
74-
from pymongo.typings import _CollationIn, _DocumentIn, _DocumentType, _Pipeline
74+
from pymongo.typings import _CollationIn, _DocumentType, _Pipeline
7575
from pymongo.write_concern import WriteConcern
7676

7777
_FIND_AND_MODIFY_DOC_FIELDS = {"value": 1}
@@ -566,7 +566,7 @@ def _insert_command(session, sock_info, retryable_write):
566566

567567
def insert_one(
568568
self,
569-
document: _DocumentIn,
569+
document: Union[_DocumentType, RawBSONDocument],
570570
bypass_document_validation: bool = False,
571571
session: Optional["ClientSession"] = None,
572572
comment: Optional[Any] = None,
@@ -614,7 +614,7 @@ def insert_one(
614614
"""
615615
common.validate_is_document_type("document", document)
616616
if not (isinstance(document, RawBSONDocument) or "_id" in document):
617-
document["_id"] = ObjectId()
617+
document["_id"] = ObjectId() # type: ignore[index]
618618

619619
write_concern = self._write_concern_for(session)
620620
return InsertOneResult(
@@ -633,7 +633,7 @@ def insert_one(
633633
@_csot.apply
634634
def insert_many(
635635
self,
636-
documents: Iterable[_DocumentIn],
636+
documents: Iterable[Union[_DocumentType, RawBSONDocument]],
637637
ordered: bool = True,
638638
bypass_document_validation: bool = False,
639639
session: Optional["ClientSession"] = None,
@@ -697,7 +697,7 @@ def gen():
697697
common.validate_is_document_type("document", document)
698698
if not isinstance(document, RawBSONDocument):
699699
if "_id" not in document:
700-
document["_id"] = ObjectId()
700+
document["_id"] = ObjectId() # type: ignore[index]
701701
inserted_ids.append(document["_id"])
702702
yield (message._INSERT, document)
703703

pymongo/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
# From the SDAM spec, the "not primary" error codes are combined with the
4545
# "node is recovering" error codes (of which the "node is shutting down"
4646
# errors are a subset).
47-
_NOT_PRIMARY_CODES = (
47+
_NOT_PRIMARY_CODES: frozenset = (
4848
frozenset(
4949
[
5050
10058, # LegacyNotPrimary <=3.2 "not primary" error code
@@ -58,7 +58,7 @@
5858
| _SHUTDOWN_CODES
5959
)
6060
# From the retryable writes spec.
61-
_RETRYABLE_ERROR_CODES = _NOT_PRIMARY_CODES | frozenset(
61+
_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset(
6262
[
6363
7, # HostNotFound
6464
6, # HostUnreachable

pymongo/monitoring.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def register(listener: _EventListener) -> None:
528528
# Note - to avoid bugs from forgetting which of these are all lowercase and
529529
# which are camelCase, and at the same time avoid having to add a test for
530530
# every command, use all lowercase here and test against command_name.lower().
531-
_SENSITIVE_COMMANDS = set(
531+
_SENSITIVE_COMMANDS: set = set(
532532
[
533533
"authenticate",
534534
"saslstart",

test/__init__.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,10 @@
4343
HAVE_IPADDRESS = True
4444
except ImportError:
4545
HAVE_IPADDRESS = False
46-
4746
from contextlib import contextmanager
4847
from functools import wraps
4948
from test.version import Version
50-
from typing import Callable, Dict, Generator, no_type_check
49+
from typing import Any, Callable, Dict, Generator, no_type_check
5150
from unittest import SkipTest
5251
from urllib.parse import quote_plus
5352

@@ -331,7 +330,9 @@ def hello(self):
331330

332331
def _connect(self, host, port, **kwargs):
333332
kwargs.update(self.default_client_options)
334-
client = pymongo.MongoClient(host, port, serverSelectionTimeoutMS=5000, **kwargs)
333+
client: MongoClient = pymongo.MongoClient(
334+
host, port, serverSelectionTimeoutMS=5000, **kwargs
335+
)
335336
try:
336337
try:
337338
client.admin.command(HelloCompat.LEGACY_CMD) # Can we connect?
@@ -356,7 +357,7 @@ def _init_client(self):
356357
if self.client is not None:
357358
# Return early when connected to dataLake as mongohoused does not
358359
# support the getCmdLineOpts command and is tested without TLS.
359-
build_info = self.client.admin.command("buildInfo")
360+
build_info: Any = self.client.admin.command("buildInfo")
360361
if "dataLake" in build_info:
361362
self.is_data_lake = True
362363
self.auth_enabled = True
@@ -521,14 +522,16 @@ def has_secondaries(self):
521522
@property
522523
def storage_engine(self):
523524
try:
524-
return self.server_status.get("storageEngine", {}).get("name")
525+
return self.server_status.get("storageEngine", {}).get( # type:ignore[union-attr]
526+
"name"
527+
)
525528
except AttributeError:
526529
# Raised if self.server_status is None.
527530
return None
528531

529532
def _check_user_provided(self):
530533
"""Return True if db_user/db_password is already an admin user."""
531-
client = pymongo.MongoClient(
534+
client: MongoClient = pymongo.MongoClient(
532535
host,
533536
port,
534537
username=db_user,
@@ -694,7 +697,7 @@ def supports_secondary_read_pref(self):
694697
if self.has_secondaries:
695698
return True
696699
if self.is_mongos:
697-
shard = self.client.config.shards.find_one()["host"]
700+
shard = self.client.config.shards.find_one()["host"] # type:ignore[index]
698701
num_members = shard.count(",") + 1
699702
return num_members > 1
700703
return False
@@ -1015,12 +1018,12 @@ def fork(
10151018
"""
10161019

10171020
def _print_threads(*args: object) -> None:
1018-
if _print_threads.called:
1021+
if _print_threads.called: # type:ignore[attr-defined]
10191022
return
1020-
_print_threads.called = True
1023+
_print_threads.called = True # type:ignore[attr-defined]
10211024
print_thread_tracebacks()
10221025

1023-
_print_threads.called = False
1026+
_print_threads.called = False # type:ignore[attr-defined]
10241027

10251028
def _target() -> None:
10261029
signal.signal(signal.SIGUSR1, _print_threads)

test/test_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ def test_insert_many_invalid(self):
785785
db.test.insert_many(1) # type: ignore[arg-type]
786786

787787
with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"):
788-
db.test.insert_many(RawBSONDocument(encode({"_id": 2}))) # type: ignore[arg-type]
788+
db.test.insert_many(RawBSONDocument(encode({"_id": 2})))
789789

790790
def test_delete_one(self):
791791
self.db.test.drop()

test/test_mypy.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,20 @@
1414

1515
"""Test that each file in mypy_fails/ actually fails mypy, and test some
1616
sample client code that uses PyMongo typings."""
17-
1817
import os
1918
import tempfile
2019
import unittest
2120
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List
2221

2322
try:
24-
from typing import TypedDict # type: ignore[attr-defined]
23+
from typing_extensions import TypedDict
2524

26-
# Not available in Python 3.7
2725
class Movie(TypedDict): # type: ignore[misc]
2826
name: str
2927
year: int
3028

3129
except ImportError:
32-
TypeDict = None
30+
TypedDict = None
3331

3432

3533
try:
@@ -304,6 +302,28 @@ def test_typeddict_document_type(self) -> None:
304302
assert retreived["year"] == 1
305303
assert retreived["name"] == "a"
306304

305+
@only_type_check
306+
def test_typeddict_document_type_insertion(self) -> None:
307+
client: MongoClient[Movie] = MongoClient()
308+
coll = client.test.test
309+
mov = {"name": "THX-1138", "year": 1971}
310+
movie = Movie(name="THX-1138", year=1971)
311+
coll.insert_one(mov) # type: ignore[arg-type]
312+
coll.insert_one({"name": "THX-1138", "year": 1971}) # This will work because it is in-line.
313+
coll.insert_one(movie)
314+
coll.insert_many([mov]) # type: ignore[list-item]
315+
coll.insert_many([movie])
316+
bad_mov = {"name": "THX-1138", "year": "WRONG TYPE"}
317+
bad_movie = Movie(name="THX-1138", year="WRONG TYPE") # type: ignore[typeddict-item]
318+
coll.insert_one(bad_mov) # type:ignore[arg-type]
319+
coll.insert_one({"name": "THX-1138", "year": "WRONG TYPE"}) # type: ignore[typeddict-item]
320+
coll.insert_one(bad_movie)
321+
coll.insert_many([bad_mov]) # type: ignore[list-item]
322+
coll.insert_many(
323+
[{"name": "THX-1138", "year": "WRONG TYPE"}] # type: ignore[typeddict-item]
324+
)
325+
coll.insert_many([bad_movie])
326+
307327
@only_type_check
308328
def test_raw_bson_document_type(self) -> None:
309329
client = MongoClient(document_class=RawBSONDocument)

test/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,7 @@ def ensure_all_connected(client: MongoClient) -> None:
601601
Depending on the use-case, the caller may need to clear any event listeners
602602
that are configured on the client.
603603
"""
604-
hello = client.admin.command(HelloCompat.LEGACY_CMD)
604+
hello: dict = client.admin.command(HelloCompat.LEGACY_CMD)
605605
if "setName" not in hello:
606606
raise ConfigurationError("cluster is not a replica set")
607607

@@ -612,7 +612,7 @@ def ensure_all_connected(client: MongoClient) -> None:
612612
def discover():
613613
i = 0
614614
while i < 100 and connected_host_list != target_host_list:
615-
hello = client.admin.command(
615+
hello: dict = client.admin.command(
616616
HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY
617617
)
618618
connected_host_list.update([hello["me"]])

0 commit comments

Comments
 (0)