88
99from __future__ import annotations
1010
11+ import collections
1112import contextlib
1213import enum
1314import os
1415import sys
1516import time
1617from typing import Any
1718from typing import Generator
19+ from typing import Iterable
1820from typing import Literal
1921from typing import Sequence
2022from typing import TypedDict
23+ from typing import Union
2124import warnings
2225
2326from _pytest .config import _prepareconfig
@@ -66,7 +69,44 @@ def worker_title(title: str) -> None:
6669
6770class Marker (enum .Enum ):
6871 SHUTDOWN = 0
69- QUEUE_REPLACED = 1
72+
73+
74+ class TestQueue :
75+ """A simple queue that can be inspected and modified while the lock is held via the ``lock()`` method."""
76+
77+ Item = Union [int , Literal [Marker .SHUTDOWN ]]
78+
79+ def __init__ (self , execmodel : execnet .gateway_base .ExecModel ):
80+ self ._items : collections .deque [TestQueue .Item ] = collections .deque ()
81+ self ._lock = execmodel .RLock () # type: ignore[no-untyped-call]
82+ self ._has_items_event = execmodel .Event ()
83+
84+ def get (self ) -> Item :
85+ while True :
86+ with self .lock () as locked_items :
87+ if locked_items :
88+ return locked_items .popleft ()
89+
90+ self ._has_items_event .wait ()
91+
92+ def put (self , item : Item ) -> None :
93+ with self .lock () as locked_items :
94+ locked_items .append (item )
95+
96+ def replace (self , iterable : Iterable [Item ]) -> None :
97+ with self .lock ():
98+ self ._items = collections .deque (iterable )
99+
100+ @contextlib .contextmanager
101+ def lock (self ) -> Generator [collections .deque [Item ], None , None ]:
102+ with self ._lock :
103+ try :
104+ yield self ._items
105+ finally :
106+ if self ._items :
107+ self ._has_items_event .set ()
108+ else :
109+ self ._has_items_event .clear ()
70110
71111
72112class WorkerInteractor :
@@ -77,22 +117,10 @@ def __init__(self, config: pytest.Config, channel: execnet.Channel) -> None:
77117 self .testrunuid = workerinput ["testrunuid" ]
78118 self .log = Producer (f"worker-{ self .workerid } " , enabled = config .option .debug )
79119 self .channel = channel
80- self .torun = self ._make_queue ( )
120+ self .torun = TestQueue ( self .channel . gateway . execmodel )
81121 self .nextitem_index : int | None | Literal [Marker .SHUTDOWN ] = None
82122 config .pluginmanager .register (self )
83123
84- def _make_queue (self ) -> Any :
85- return self .channel .gateway .execmodel .queue .Queue ()
86-
87- def _get_next_item_index (self ) -> int | Literal [Marker .SHUTDOWN ]:
88- """Gets the next item from test queue. Handles the case when the queue
89- is replaced concurrently in another thread.
90- """
91- result = self .torun .get ()
92- while result is Marker .QUEUE_REPLACED :
93- result = self .torun .get ()
94- return result # type: ignore[no-any-return]
95-
96124 def sendevent (self , name : str , ** kwargs : object ) -> None :
97125 self .log ("sending" , name , kwargs )
98126 self .channel .send ((name , kwargs ))
@@ -146,30 +174,34 @@ def handle_command(
146174 self .steal (kwargs ["indices" ])
147175
148176 def steal (self , indices : Sequence [int ]) -> None :
149- indices_set = set ( indices )
150- stolen = []
177+ """
178+ Remove tests from the queue.
151179
152- old_queue , self .torun = self .torun , self ._make_queue ()
180+ Removes either all requested tests, or none, if some of these tests
181+ are not in the queue (for example, if they were processed already).
153182
154- def old_queue_get_nowait_noraise () -> int | None :
155- with contextlib .suppress (self .channel .gateway .execmodel .queue .Empty ):
156- return old_queue .get_nowait () # type: ignore[no-any-return]
157- return None
183+ :param indices: indices of the tests to remove.
184+ """
185+ requested_set = set (indices )
186+
187+ with self .torun .lock () as locked_queue :
188+ stolen = list (item for item in locked_queue if item in requested_set )
158189
159- for i in iter (old_queue_get_nowait_noraise , None ):
160- if i in indices_set :
161- stolen .append (i )
190+ # Stealing only if all requested tests are still pending
191+ if len (stolen ) == len (requested_set ):
192+ self .torun .replace (
193+ item for item in locked_queue if item not in requested_set
194+ )
162195 else :
163- self . torun . put ( i )
196+ stolen = []
164197
165198 self .sendevent ("unscheduled" , indices = stolen )
166- old_queue .put (Marker .QUEUE_REPLACED )
167199
168200 @pytest .hookimpl
169201 def pytest_runtestloop (self , session : pytest .Session ) -> bool :
170202 self .log ("entering main loop" )
171203 self .channel .setcallback (self .handle_command , endmarker = Marker .SHUTDOWN )
172- self .nextitem_index = self ._get_next_item_index ()
204+ self .nextitem_index = self .torun . get ()
173205 while self .nextitem_index is not Marker .SHUTDOWN :
174206 self .run_one_test ()
175207 if session .shouldfail or session .shouldstop :
@@ -179,7 +211,7 @@ def pytest_runtestloop(self, session: pytest.Session) -> bool:
179211 def run_one_test (self ) -> None :
180212 assert isinstance (self .nextitem_index , int )
181213 self .item_index = self .nextitem_index
182- self .nextitem_index = self ._get_next_item_index ()
214+ self .nextitem_index = self .torun . get ()
183215
184216 items = self .session .items
185217 item = items [self .item_index ]
0 commit comments