1+ from __future__ import annotations
2+
3+ from collections .abc import Sequence
4+ from itertools import cycle
5+
6+ import pytest
7+
8+ from xdist .remote import Producer
9+ from xdist .workermanage import parse_tx_spec_config
10+ from xdist .workermanage import WorkerController
11+
12+
13+ class SingleCollectScheduling :
14+ """Implement scheduling with a single test collection phase.
15+
16+ This differs from LoadScheduling by:
17+ 1. Only collecting tests on the first node
18+ 2. Skipping collection on other nodes
19+ 3. Not checking for collection equality
20+
21+ This can significantly improve startup time by avoiding redundant collection
22+ and collection verification across multiple worker processes.
23+ """
24+
25+ def __init__ (self , config : pytest .Config , log : Producer | None = None ) -> None :
26+ self .numnodes = len (parse_tx_spec_config (config ))
27+ self .node2pending : dict [WorkerController , list [int ]] = {}
28+ self .pending : list [int ] = []
29+ self .collection : list [str ] | None = None
30+ self .first_node : WorkerController | None = None
31+ if log is None :
32+ self .log = Producer ("singlecollect" )
33+ else :
34+ self .log = log .singlecollect
35+ self .config = config
36+ self .maxschedchunk = self .config .getoption ("maxschedchunk" )
37+ self .collection_done = False
38+
39+ @property
40+ def nodes (self ) -> list [WorkerController ]:
41+ """A list of all nodes in the scheduler."""
42+ return list (self .node2pending .keys ())
43+
44+ @property
45+ def collection_is_completed (self ) -> bool :
46+ """Return True once we have collected tests from the first node."""
47+ return self .collection_done
48+
49+ @property
50+ def tests_finished (self ) -> bool :
51+ """Return True if all tests have been executed by the nodes."""
52+ if not self .collection_is_completed :
53+ return False
54+ if self .pending :
55+ return False
56+ for pending in self .node2pending .values ():
57+ if len (pending ) >= 2 :
58+ return False
59+ return True
60+
61+ @property
62+ def has_pending (self ) -> bool :
63+ """Return True if there are pending test items."""
64+ if self .pending :
65+ return True
66+ for pending in self .node2pending .values ():
67+ if pending :
68+ return True
69+ return False
70+
71+ def add_node (self , node : WorkerController ) -> None :
72+ """Add a new node to the scheduler."""
73+ assert node not in self .node2pending
74+ self .node2pending [node ] = []
75+
76+ # Remember the first node as our collector
77+ if self .first_node is None :
78+ self .first_node = node
79+ self .log (f"Using { node .gateway .id } as collection node" )
80+
81+ def add_node_collection (
82+ self , node : WorkerController , collection : Sequence [str ]
83+ ) -> None :
84+ """Only use collection from the first node."""
85+ # We only care about collection from the first node
86+ if node == self .first_node :
87+ self .log (f"Received collection from first node { node .gateway .id } " )
88+ self .collection = list (collection )
89+ self .collection_done = True
90+ else :
91+ # Skip collection verification for other nodes
92+ self .log (f"Ignoring collection from node { node .gateway .id } " )
93+
94+ def mark_test_complete (
95+ self , node : WorkerController , item_index : int , duration : float = 0
96+ ) -> None :
97+ """Mark test item as completed by node."""
98+ self .node2pending [node ].remove (item_index )
99+ self .check_schedule (node , duration = duration )
100+
101+ def mark_test_pending (self , item : str ) -> None :
102+ assert self .collection is not None
103+ self .pending .insert (
104+ 0 ,
105+ self .collection .index (item ),
106+ )
107+ for node in self .node2pending :
108+ self .check_schedule (node )
109+
110+ def remove_pending_tests_from_node (
111+ self ,
112+ node : WorkerController ,
113+ indices : Sequence [int ],
114+ ) -> None :
115+ raise NotImplementedError ()
116+
117+ def check_schedule (self , node : WorkerController , duration : float = 0 ) -> None :
118+ """Maybe schedule new items on the node."""
119+ if node .shutting_down :
120+ return
121+
122+ if self .pending :
123+ # how many nodes do we have?
124+ num_nodes = len (self .node2pending )
125+ # if our node goes below a heuristic minimum, fill it out to
126+ # heuristic maximum
127+ items_per_node_min = max (2 , len (self .pending ) // num_nodes // 4 )
128+ items_per_node_max = max (2 , len (self .pending ) // num_nodes // 2 )
129+ node_pending = self .node2pending [node ]
130+ if len (node_pending ) < items_per_node_min :
131+ if duration >= 0.1 and len (node_pending ) >= 2 :
132+ # seems the node is doing long-running tests
133+ # and has enough items to continue
134+ # so let's rather wait with sending new items
135+ return
136+ num_send = items_per_node_max - len (node_pending )
137+ # keep at least 2 tests pending even if --maxschedchunk=1
138+ maxschedchunk = max (2 - len (node_pending ), self .maxschedchunk )
139+ self ._send_tests (node , min (num_send , maxschedchunk ))
140+ else :
141+ node .shutdown ()
142+
143+ self .log ("num items waiting for node:" , len (self .pending ))
144+
145+ def remove_node (self , node : WorkerController ) -> str | None :
146+ """Remove a node from the scheduler."""
147+ pending = self .node2pending .pop (node )
148+
149+ # If this is the first node (collector), reset it
150+ if node == self .first_node :
151+ self .first_node = None
152+
153+ if not pending :
154+ return None
155+
156+ # Reassign pending items if the node had any
157+ assert self .collection is not None
158+ crashitem = self .collection [pending .pop (0 )]
159+ self .pending .extend (pending )
160+ for node in self .node2pending :
161+ self .check_schedule (node )
162+ return crashitem
163+
164+ def schedule (self ) -> None :
165+ """Initiate distribution of the test collection."""
166+ assert self .collection_is_completed
167+
168+ # Initial distribution already happened, reschedule on all nodes
169+ if self .pending :
170+ for node in self .nodes :
171+ self .check_schedule (node )
172+ return
173+
174+ # Initialize the index of pending items
175+ assert self .collection is not None
176+ self .pending [:] = range (len (self .collection ))
177+ if not self .collection :
178+ return
179+
180+ if self .maxschedchunk is None :
181+ self .maxschedchunk = len (self .collection )
182+
183+ # Send a batch of tests to run. If we don't have at least two
184+ # tests per node, we have to send them all so that we can send
185+ # shutdown signals and get all nodes working.
186+ if len (self .pending ) < 2 * len (self .nodes ):
187+ # Distribute tests round-robin
188+ nodes = cycle (self .nodes )
189+ for _ in range (len (self .pending )):
190+ self ._send_tests (next (nodes ), 1 )
191+ else :
192+ # how many items per node do we have about?
193+ items_per_node = len (self .collection ) // len (self .node2pending )
194+ # take a fraction of tests for initial distribution
195+ node_chunksize = min (items_per_node // 4 , self .maxschedchunk )
196+ node_chunksize = max (node_chunksize , 2 )
197+ # and initialize each node with a chunk of tests
198+ for node in self .nodes :
199+ self ._send_tests (node , node_chunksize )
200+
201+ if not self .pending :
202+ # initial distribution sent all tests, start node shutdown
203+ for node in self .nodes :
204+ node .shutdown ()
205+
206+ def _send_tests (self , node : WorkerController , num : int ) -> None :
207+ tests_per_node = self .pending [:num ]
208+ if tests_per_node :
209+ del self .pending [:num ]
210+ self .node2pending [node ].extend (tests_per_node )
211+ node .send_runtest_some (tests_per_node )
0 commit comments