1- # mypy: allow-untyped-defs
2-
3- import time
4- import threading
5-
6- from . import mpcontext
7-
81"""Instrumentation for measuring high-level time spent on various tasks inside the runner.
92
103This is lower fidelity than an actual profile, but allows custom data to be considered,
2619 do_teardown()
2720"""
2821
29- class NullInstrument :
30- def set (self , stack ):
22+ from __future__ import annotations
23+
24+ import threading
25+ import time
26+ from abc import ABCMeta , abstractmethod
27+ from typing import TYPE_CHECKING , Iterable , Sequence
28+
29+ from . import mpcontext
30+
31+ if TYPE_CHECKING :
32+ import multiprocessing
33+ import sys
34+ from multiprocessing .process import BaseProcess
35+ from types import TracebackType
36+
37+ if sys .version_info >= (3 , 10 ):
38+ from typing import TypeAlias
39+ else :
40+ from typing_extensions import TypeAlias
41+
42+ if sys .version_info >= (3 , 11 ):
43+ from typing import Self
44+ else :
45+ from typing_extensions import Self
46+
47+
48+ class AbstractInstrument (metaclass = ABCMeta ):
49+ @abstractmethod
50+ def __enter__ (self ) -> AbstractInstrumentHandler :
51+ ...
52+
53+ @abstractmethod
54+ def __exit__ (
55+ self ,
56+ exc_type : type [BaseException ] | None ,
57+ exc_val : BaseException | None ,
58+ exc_tb : TracebackType | None ,
59+ ) -> None :
60+ ...
61+
62+
63+ class AbstractInstrumentHandler (metaclass = ABCMeta ):
64+ @abstractmethod
65+ def set (self , stack : Sequence [str ]) -> None :
66+ """Set the current task to stack
67+
68+ :param stack: A list of strings defining the current task.
69+ These are interpreted like a stack trace so that ["foo"] and
70+ ["foo", "bar"] both show up as descendants of "foo"
71+ """
72+ ...
73+
74+ @abstractmethod
75+ def pause (self ) -> None :
76+ """Stop recording a task on the current thread. This is useful if the thread
77+ is purely waiting on the results of other threads"""
78+ ...
79+
80+
81+ class NullInstrument (AbstractInstrument , AbstractInstrumentHandler ):
82+ def set (self , stack : Sequence [str ]) -> None :
3183 """Set the current task to stack
3284
3385 :param stack: A list of strings defining the current task.
@@ -36,37 +88,47 @@ def set(self, stack):
3688 """
3789 pass
3890
39- def pause (self ):
91+ def pause (self ) -> None :
4092 """Stop recording a task on the current thread. This is useful if the thread
4193 is purely waiting on the results of other threads"""
4294 pass
4395
44- def __enter__ (self ):
96+ def __enter__ (self ) -> Self :
4597 return self
4698
47- def __exit__ (self , * args , ** kwargs ):
99+ def __exit__ (
100+ self ,
101+ exc_type : type [BaseException ] | None ,
102+ exc_val : BaseException | None ,
103+ exc_tb : TracebackType | None ,
104+ ) -> None :
48105 return
49106
50107
51- class InstrumentWriter :
52- def __init__ (self , queue ):
108+ _InstrumentQueue : TypeAlias = "multiprocessing.Queue[tuple[str, int | None, float, Sequence[str] | None]]"
109+
110+
111+ class InstrumentWriter (AbstractInstrumentHandler ):
112+ def __init__ (
113+ self ,
114+ queue : _InstrumentQueue ,
115+ ) -> None :
53116 self .queue = queue
54117
55- def set (self , stack ) :
56- stack . insert ( 0 , threading .current_thread ().name )
118+ def set (self , stack : Sequence [ str ]) -> None :
119+ stack = [ threading .current_thread ().name , * stack ]
57120 stack = self ._check_stack (stack )
58121 self .queue .put (("set" , threading .current_thread ().ident , time .time (), stack ))
59122
60- def pause (self ):
123+ def pause (self ) -> None :
61124 self .queue .put (("pause" , threading .current_thread ().ident , time .time (), None ))
62125
63- def _check_stack (self , stack ):
64- assert isinstance (stack , (tuple , list ))
126+ def _check_stack (self , stack : Sequence [str ]) -> Sequence [str ]:
65127 return [item .replace (" " , "_" ) for item in stack ]
66128
67129
68- class Instrument :
69- def __init__ (self , file_path ) :
130+ class Instrument ( AbstractInstrument ) :
131+ def __init__ (self , file_path : str ) -> None :
70132 """Instrument that collects data from multiple threads and sums the time in each
71133 thread. The output is in the format required by flamegraph.pl to enable visualisation
72134 of the time spent in each task.
@@ -75,12 +137,10 @@ def __init__(self, file_path):
75137 at the path will be overwritten
76138 """
77139 self .path = file_path
78- self .queue = None
79- self .current = None
80- self .start_time = None
81- self .instrument_proc = None
140+ self .queue : _InstrumentQueue | None = None
141+ self .instrument_proc : BaseProcess | None = None
82142
83- def __enter__ (self ):
143+ def __enter__ (self ) -> InstrumentWriter :
84144 assert self .instrument_proc is None
85145 assert self .queue is None
86146 mp = mpcontext .get_context ()
@@ -89,16 +149,24 @@ def __enter__(self):
89149 self .instrument_proc .start ()
90150 return InstrumentWriter (self .queue )
91151
92- def __exit__ (self , * args , ** kwargs ):
152+ def __exit__ (
153+ self ,
154+ exc_type : type [BaseException ] | None ,
155+ exc_val : BaseException | None ,
156+ exc_tb : TracebackType | None ,
157+ ) -> None :
158+ assert self .instrument_proc is not None
159+ assert self .queue is not None
93160 self .queue .put (("stop" , None , time .time (), None ))
94161 self .instrument_proc .join ()
95162 self .instrument_proc = None
96163 self .queue = None
97164
98- def run (self ):
165+ def run (self ) -> None :
166+ assert self .queue is not None
99167 known_commands = {"stop" , "pause" , "set" }
100168 with open (self .path , "w" ) as f :
101- thread_data = {}
169+ thread_data : dict [ int | None , tuple [ Sequence [ str ], float ]] = {}
102170 while True :
103171 command , thread , time_stamp , stack = self .queue .get ()
104172 assert command in known_commands
@@ -107,15 +175,24 @@ def run(self):
107175 # before exiting. Otherwise for either 'set' or 'pause' we only need to dump
108176 # information from the current stack (if any) that was recording on the reporting
109177 # thread (as that stack is no longer active).
110- items = [ ]
178+ items : Iterable [ tuple [ Sequence [ str ], float ] ]
111179 if command == "stop" :
112180 items = thread_data .values ()
113181 elif thread in thread_data :
114- items .append (thread_data .pop (thread ))
182+ items = [thread_data .pop (thread )]
183+ else :
184+ items = []
115185 for output_stack , start_time in items :
116- f .write ("%s %d\n " % (";" .join (output_stack ), int (1000 * (time_stamp - start_time ))))
186+ f .write (
187+ "%s %d\n "
188+ % (
189+ ";" .join (output_stack ),
190+ int (1000 * (time_stamp - start_time )),
191+ )
192+ )
117193
118194 if command == "set" :
195+ assert stack is not None
119196 thread_data [thread ] = (stack , time_stamp )
120197 elif command == "stop" :
121198 break
0 commit comments