@@ -23,7 +23,12 @@ def setUp(self) -> None:
2323 "torchrec.distributed.torchrec_logger._get_msg_dict"
2424 )
2525 self .mock_get_msg_dict = self .get_msg_dict_patcher .start ()
26- self .mock_get_msg_dict .return_value = {}
26+
27+ # Return a dictionary with func_name that can be modified by _get_input_from_func
28+ def mock_get_msg_dict_impl (func_name : str , ** kwargs : Any ) -> dict [str , Any ]:
29+ return {"func_name" : func_name }
30+
31+ self .mock_get_msg_dict .side_effect = mock_get_msg_dict_impl
2732
2833 # Mock _torchrec_logger
2934 self .logger_patcher = mock .patch ("torchrec.distributed.logger._torchrec_logger" )
@@ -40,7 +45,8 @@ def test_get_input_from_func_no_args(self) -> None:
4045 def test_func () -> None :
4146 pass
4247
43- result = _get_input_from_func (test_func )
48+ msg_dict = {"func_name" : "test_func" }
49+ result = _get_input_from_func (test_func , msg_dict )
4450 self .assertEqual (result , "{}" )
4551
4652 def test_get_input_from_func_with_args (self ) -> None :
@@ -49,7 +55,8 @@ def test_get_input_from_func_with_args(self) -> None:
4955 def test_func (_a : int , _b : str ) -> None :
5056 pass
5157
52- result = _get_input_from_func (test_func , 42 , "hello" )
58+ msg_dict = {"func_name" : "test_func" }
59+ result = _get_input_from_func (test_func , msg_dict , 42 , "hello" )
5360 self .assertEqual (result , "{'_a': 42, '_b': 'hello'}" )
5461
5562 def test_get_input_from_func_with_kwargs (self ) -> None :
@@ -58,7 +65,8 @@ def test_get_input_from_func_with_kwargs(self) -> None:
5865 def test_func (_a : int = 0 , _b : str = "default" ) -> None :
5966 pass
6067
61- result = _get_input_from_func (test_func , _b = "world" )
68+ msg_dict = {"func_name" : "test_func" }
69+ result = _get_input_from_func (test_func , msg_dict , _b = "world" )
6270 self .assertEqual (result , "{'_a': 0, '_b': 'world'}" )
6371
6472 def test_get_input_from_func_with_args_and_kwargs (self ) -> None :
@@ -69,11 +77,19 @@ def test_func(
6977 ) -> None :
7078 pass
7179
72- result = _get_input_from_func (test_func , 42 , "hello" , "extra" , key = "value" )
73- self .assertEqual (
74- result ,
75- "{'_a': 42, '_b': 'hello', '_args': \" ('extra',)\" , '_kwargs': \" {'key': 'value'}\" }" ,
80+ msg_dict = {"func_name" : "test_func" }
81+ result = _get_input_from_func (
82+ test_func , msg_dict , 42 , "hello" , "extra" , key = "value"
7683 )
84+ self .assertIn ("_a" , result )
85+ self .assertIn ("42" , result )
86+ self .assertIn ("_b" , result )
87+ self .assertIn ("hello" , result )
88+ self .assertIn ("_args" , result )
89+ self .assertIn ("extra" , result )
90+ self .assertIn ("_kwargs" , result )
91+ self .assertIn ("key" , result )
92+ self .assertIn ("value" , result )
7793
7894 def test_torchrec_method_logger_success (self ) -> None :
7995 """Test _torchrec_method_logger with a successful function execution when logging is enabled."""
@@ -142,6 +158,57 @@ def test_torchrec_method_logger_with_wrapper_kwargs(self) -> None:
142158 msg_dict = self .mock_logger .debug .call_args [0 ][0 ]
143159 self .assertEqual (msg_dict ["output" ], "result" )
144160
161+ def test_torchrec_method_logger_constructor_with_args (self ) -> None :
162+ """Test _torchrec_method_logger with a class constructor that has positional arguments."""
163+
164+ class TestClass :
165+ @_torchrec_method_logger ()
166+ def __init__ (self , _a : int , _b : str ) -> None :
167+ pass
168+
169+ # Create an instance which will call __init__
170+ _ = TestClass (42 , "hello" )
171+
172+ # Verify that the logger was called
173+ self .mock_logger .debug .assert_called_once ()
174+ msg_dict = self .mock_logger .debug .call_args [0 ][0 ]
175+ # Verify that class name was prepended to function name
176+ self .assertEqual (msg_dict ["func_name" ], "TestClass.__init__" )
177+ # Verify the input contains the arguments
178+ self .assertIn ("_a" , msg_dict ["input" ])
179+ self .assertIn ("42" , msg_dict ["input" ])
180+ self .assertIn ("_b" , msg_dict ["input" ])
181+ self .assertIn ("hello" , msg_dict ["input" ])
182+
183+ def test_torchrec_method_logger_constructor_with_args_and_kwargs (self ) -> None :
184+ """Test _torchrec_method_logger with a class constructor that has both positional and keyword arguments."""
185+
186+ class TestClass :
187+ @_torchrec_method_logger ()
188+ def __init__ (
189+ self , _a : int , _b : str = "default" , * _args : Any , ** _kwargs : Any
190+ ) -> None :
191+ pass
192+
193+ # Create an instance which will call __init__
194+ _ = TestClass (42 , "hello" , "extra" , key = "value" )
195+
196+ # Verify that the logger was called
197+ self .mock_logger .debug .assert_called_once ()
198+ msg_dict = self .mock_logger .debug .call_args [0 ][0 ]
199+ # Verify that class name was prepended to function name
200+ self .assertEqual (msg_dict ["func_name" ], "TestClass.__init__" )
201+ # Verify the input contains the arguments
202+ self .assertIn ("_a" , msg_dict ["input" ])
203+ self .assertIn ("42" , msg_dict ["input" ])
204+ self .assertIn ("_b" , msg_dict ["input" ])
205+ self .assertIn ("hello" , msg_dict ["input" ])
206+ self .assertIn ("_args" , msg_dict ["input" ])
207+ self .assertIn ("extra" , msg_dict ["input" ])
208+ self .assertIn ("_kwargs" , msg_dict ["input" ])
209+ self .assertIn ("key" , msg_dict ["input" ])
210+ self .assertIn ("value" , msg_dict ["input" ])
211+
145212
146213if __name__ == "__main__" :
147214 unittest .main ()
0 commit comments