|
26 | 26 | from typing import ( |
27 | 27 | TYPE_CHECKING, |
28 | 28 | Any, |
| 29 | + Generator, |
29 | 30 | Iterator, |
30 | 31 | Mapping, |
31 | 32 | Optional, |
|
72 | 73 | from pymongo.write_concern import WriteConcern |
73 | 74 |
|
74 | 75 | if TYPE_CHECKING: |
75 | | - from pymongo.asynchronous.collection import AsyncCollection |
| 76 | + from pymongo.asynchronous.collection import AsyncCollection, _WriteOp |
76 | 77 | from pymongo.asynchronous.mongo_client import AsyncMongoClient |
77 | 78 | from pymongo.asynchronous.pool import AsyncConnection |
78 | 79 | from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline |
@@ -214,28 +215,45 @@ def add_delete( |
214 | 215 | self.is_retryable = False |
215 | 216 | self.ops.append((_DELETE, cmd)) |
216 | 217 |
|
217 | | - def gen_ordered(self) -> Iterator[Optional[_Run]]: |
| 218 | + def gen_ordered(self, requests) -> Iterator[Optional[_Run]]: |
218 | 219 | """Generate batches of operations, batched by type of |
219 | 220 | operation, in the order **provided**. |
220 | 221 | """ |
221 | 222 | run = None |
222 | | - for idx, (op_type, operation) in enumerate(self.ops): |
| 223 | + for idx, request in enumerate(requests): |
| 224 | + try: |
| 225 | + request._add_to_bulk(self) |
| 226 | + except AttributeError: |
| 227 | + raise TypeError(f"{request!r} is not a valid request") from None |
| 228 | + (op_type, operation) = self.ops[idx] |
223 | 229 | if run is None: |
224 | 230 | run = _Run(op_type) |
225 | 231 | elif run.op_type != op_type: |
226 | 232 | yield run |
227 | 233 | run = _Run(op_type) |
228 | 234 | run.add(idx, operation) |
| 235 | + if run is None: |
| 236 | + raise InvalidOperation("No operations to execute") |
229 | 237 | yield run |
230 | 238 |
|
231 | | - def gen_unordered(self) -> Iterator[_Run]: |
| 239 | + def gen_unordered(self, requests) -> Iterator[_Run]: |
232 | 240 | """Generate batches of operations, batched by type of |
233 | 241 | operation, in arbitrary order. |
234 | 242 | """ |
235 | 243 | operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] |
236 | | - for idx, (op_type, operation) in enumerate(self.ops): |
| 244 | + for idx, request in enumerate(requests): |
| 245 | + try: |
| 246 | + request._add_to_bulk(self) |
| 247 | + except AttributeError: |
| 248 | + raise TypeError(f"{request!r} is not a valid request") from None |
| 249 | + (op_type, operation) = self.ops[idx] |
237 | 250 | operations[op_type].add(idx, operation) |
238 | | - |
| 251 | + if ( |
| 252 | + len(operations[_INSERT].ops) == 0 |
| 253 | + and len(operations[_UPDATE].ops) == 0 |
| 254 | + and len(operations[_DELETE].ops) == 0 |
| 255 | + ): |
| 256 | + raise InvalidOperation("No operations to execute") |
239 | 257 | for run in operations: |
240 | 258 | if run.ops: |
241 | 259 | yield run |
@@ -726,23 +744,22 @@ async def execute_no_results( |
726 | 744 |
|
727 | 745 | async def execute( |
728 | 746 | self, |
| 747 | + generator: Generator[_WriteOp[_DocumentType]], |
729 | 748 | write_concern: WriteConcern, |
730 | 749 | session: Optional[AsyncClientSession], |
731 | 750 | operation: str, |
732 | 751 | ) -> Any: |
733 | 752 | """Execute operations.""" |
734 | | - if not self.ops: |
735 | | - raise InvalidOperation("No operations to execute") |
736 | 753 | if self.executed: |
737 | 754 | raise InvalidOperation("Bulk operations can only be executed once.") |
738 | 755 | self.executed = True |
739 | 756 | write_concern = write_concern or self.collection.write_concern |
740 | 757 | session = _validate_session_write_concern(session, write_concern) |
741 | 758 |
|
742 | 759 | if self.ordered: |
743 | | - generator = self.gen_ordered() |
| 760 | + generator = self.gen_ordered(generator) |
744 | 761 | else: |
745 | | - generator = self.gen_unordered() |
| 762 | + generator = self.gen_unordered(generator) |
746 | 763 |
|
747 | 764 | client = self.collection.database.client |
748 | 765 | if not write_concern.acknowledged: |
|
0 commit comments