Skip to content

Commit 51be4ea

Browse files
nipung90 authored and meta-codesync[bot] committed
Add more exception handling and support for class constructors (#3517)
Summary: Pull Request resolved: #3517.
This diff adds a few features to the static logger:
1) Better exception handling to make sure any failures are silent and don't affect the actual training job.
2) Support for handling class constructors. This includes changing the function name to include the name of the class for constructors and moving the collection of inputs to after the function call succeeds.
3) Moved the static logger enablement restrictions to a configerator file instead of using justknobs metadata. The justknobs will only be used as a kill switch now.
Reviewed By: kausv, saumishr
Differential Revision: D76294357
fbshipit-source-id: 69e5db646e8557d7864647901773f52bbffb0e12
1 parent 14473ce commit 51be4ea

File tree

3 files changed

+128
-33
lines changed

3 files changed

+128
-33
lines changed

torchrec/distributed/logger.py

Lines changed: 40 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,8 @@
88
# mypy: allow-untyped-defs
99
import functools
1010
import inspect
11-
from typing import Any, Callable, TypeVar
11+
import logging
12+
from typing import Any, Callable, Dict, TypeVar
1213

1314
import torchrec.distributed.torchrec_logger as torchrec_logger
1415
from torchrec.distributed.torchrec_logging_handlers import TORCHREC_LOGGER_NAME
@@ -34,16 +35,25 @@ def decorator(func: Callable[_P, _T]): # pyre-ignore
3435
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
3536
msg_dict = torchrec_logger._get_msg_dict(func.__name__, **kwargs)
3637
try:
37-
## Add function input to log message
38-
msg_dict["input"] = _get_input_from_func(func, *args, **kwargs)
3938
# exceptions
4039
result = func(*args, **kwargs)
4140
except BaseException as error:
4241
msg_dict["error"] = f"{error}"
42+
## Add function input to log message
43+
msg_dict["input"] = _get_input_from_func(
44+
func, msg_dict, *args, **kwargs
45+
)
4346
_torchrec_logger.error(msg_dict)
4447
raise
45-
msg_dict["output"] = str(result)
46-
_torchrec_logger.debug(msg_dict)
48+
## Add function input to log message
49+
try:
50+
msg_dict["input"] = _get_input_from_func(
51+
func, msg_dict, *args, **kwargs
52+
)
53+
msg_dict["output"] = str(result)
54+
_torchrec_logger.debug(msg_dict)
55+
except Exception as error:
56+
logging.info(f"Torchrec logger: Failed in static logger: {error}")
4757
return result
4858

4959
return wrapper
@@ -52,15 +62,29 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
5262

5363

5464
def _get_input_from_func(
55-
func: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs
65+
func: Callable[_P, _T],
66+
msg_dict: Dict[str, Any],
67+
*args: _P.args,
68+
**kwargs: _P.kwargs,
5669
) -> str:
57-
signature = inspect.signature(func)
58-
bound_args = signature.bind_partial(*args, **kwargs)
59-
bound_args.apply_defaults()
60-
input_vars = {param.name: param.default for param in signature.parameters.values()}
61-
for key, value in bound_args.arguments.items():
62-
if isinstance(value, (int, float)):
63-
input_vars[key] = value
64-
else:
65-
input_vars[key] = str(value)
66-
return str(input_vars)
70+
try:
71+
signature = inspect.signature(func)
72+
bound_args = signature.bind_partial(*args, **kwargs)
73+
bound_args.apply_defaults()
74+
input_vars = {
75+
param.name: param.default for param in signature.parameters.values()
76+
}
77+
for key, value in bound_args.arguments.items():
78+
if key == "self" and func.__name__ == "__init__":
79+
# Add class name to function name if the function is a constructor
80+
msg_dict["func_name"] = (
81+
f"{value.__class__.__name__}.{msg_dict['func_name']}"
82+
)
83+
if isinstance(value, (int, float)):
84+
input_vars[key] = value
85+
else:
86+
input_vars[key] = str(value)
87+
return str(input_vars)
88+
except Exception as error:
89+
logging.error(f"Torchrec Logger: Error in _get_input_from_func: {error}")
90+
return "Error in _get_input_from_func: " + str(error)

torchrec/distributed/tests/test_logger.py

Lines changed: 75 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,12 @@ def setUp(self) -> None:
2323
"torchrec.distributed.torchrec_logger._get_msg_dict"
2424
)
2525
self.mock_get_msg_dict = self.get_msg_dict_patcher.start()
26-
self.mock_get_msg_dict.return_value = {}
26+
27+
# Return a dictionary with func_name that can be modified by _get_input_from_func
28+
def mock_get_msg_dict_impl(func_name: str, **kwargs: Any) -> dict[str, Any]:
29+
return {"func_name": func_name}
30+
31+
self.mock_get_msg_dict.side_effect = mock_get_msg_dict_impl
2732

2833
# Mock _torchrec_logger
2934
self.logger_patcher = mock.patch("torchrec.distributed.logger._torchrec_logger")
@@ -40,7 +45,8 @@ def test_get_input_from_func_no_args(self) -> None:
4045
def test_func() -> None:
4146
pass
4247

43-
result = _get_input_from_func(test_func)
48+
msg_dict = {"func_name": "test_func"}
49+
result = _get_input_from_func(test_func, msg_dict)
4450
self.assertEqual(result, "{}")
4551

4652
def test_get_input_from_func_with_args(self) -> None:
@@ -49,7 +55,8 @@ def test_get_input_from_func_with_args(self) -> None:
4955
def test_func(_a: int, _b: str) -> None:
5056
pass
5157

52-
result = _get_input_from_func(test_func, 42, "hello")
58+
msg_dict = {"func_name": "test_func"}
59+
result = _get_input_from_func(test_func, msg_dict, 42, "hello")
5360
self.assertEqual(result, "{'_a': 42, '_b': 'hello'}")
5461

5562
def test_get_input_from_func_with_kwargs(self) -> None:
@@ -58,7 +65,8 @@ def test_get_input_from_func_with_kwargs(self) -> None:
5865
def test_func(_a: int = 0, _b: str = "default") -> None:
5966
pass
6067

61-
result = _get_input_from_func(test_func, _b="world")
68+
msg_dict = {"func_name": "test_func"}
69+
result = _get_input_from_func(test_func, msg_dict, _b="world")
6270
self.assertEqual(result, "{'_a': 0, '_b': 'world'}")
6371

6472
def test_get_input_from_func_with_args_and_kwargs(self) -> None:
@@ -69,11 +77,19 @@ def test_func(
6977
) -> None:
7078
pass
7179

72-
result = _get_input_from_func(test_func, 42, "hello", "extra", key="value")
73-
self.assertEqual(
74-
result,
75-
"{'_a': 42, '_b': 'hello', '_args': \"('extra',)\", '_kwargs': \"{'key': 'value'}\"}",
80+
msg_dict = {"func_name": "test_func"}
81+
result = _get_input_from_func(
82+
test_func, msg_dict, 42, "hello", "extra", key="value"
7683
)
84+
self.assertIn("_a", result)
85+
self.assertIn("42", result)
86+
self.assertIn("_b", result)
87+
self.assertIn("hello", result)
88+
self.assertIn("_args", result)
89+
self.assertIn("extra", result)
90+
self.assertIn("_kwargs", result)
91+
self.assertIn("key", result)
92+
self.assertIn("value", result)
7793

7894
def test_torchrec_method_logger_success(self) -> None:
7995
"""Test _torchrec_method_logger with a successful function execution when logging is enabled."""
@@ -142,6 +158,57 @@ def test_torchrec_method_logger_with_wrapper_kwargs(self) -> None:
142158
msg_dict = self.mock_logger.debug.call_args[0][0]
143159
self.assertEqual(msg_dict["output"], "result")
144160

161+
def test_torchrec_method_logger_constructor_with_args(self) -> None:
162+
"""Test _torchrec_method_logger with a class constructor that has positional arguments."""
163+
164+
class TestClass:
165+
@_torchrec_method_logger()
166+
def __init__(self, _a: int, _b: str) -> None:
167+
pass
168+
169+
# Create an instance which will call __init__
170+
_ = TestClass(42, "hello")
171+
172+
# Verify that the logger was called
173+
self.mock_logger.debug.assert_called_once()
174+
msg_dict = self.mock_logger.debug.call_args[0][0]
175+
# Verify that class name was prepended to function name
176+
self.assertEqual(msg_dict["func_name"], "TestClass.__init__")
177+
# Verify the input contains the arguments
178+
self.assertIn("_a", msg_dict["input"])
179+
self.assertIn("42", msg_dict["input"])
180+
self.assertIn("_b", msg_dict["input"])
181+
self.assertIn("hello", msg_dict["input"])
182+
183+
def test_torchrec_method_logger_constructor_with_args_and_kwargs(self) -> None:
184+
"""Test _torchrec_method_logger with a class constructor that has both positional and keyword arguments."""
185+
186+
class TestClass:
187+
@_torchrec_method_logger()
188+
def __init__(
189+
self, _a: int, _b: str = "default", *_args: Any, **_kwargs: Any
190+
) -> None:
191+
pass
192+
193+
# Create an instance which will call __init__
194+
_ = TestClass(42, "hello", "extra", key="value")
195+
196+
# Verify that the logger was called
197+
self.mock_logger.debug.assert_called_once()
198+
msg_dict = self.mock_logger.debug.call_args[0][0]
199+
# Verify that class name was prepended to function name
200+
self.assertEqual(msg_dict["func_name"], "TestClass.__init__")
201+
# Verify the input contains the arguments
202+
self.assertIn("_a", msg_dict["input"])
203+
self.assertIn("42", msg_dict["input"])
204+
self.assertIn("_b", msg_dict["input"])
205+
self.assertIn("hello", msg_dict["input"])
206+
self.assertIn("_args", msg_dict["input"])
207+
self.assertIn("extra", msg_dict["input"])
208+
self.assertIn("_kwargs", msg_dict["input"])
209+
self.assertIn("key", msg_dict["input"])
210+
self.assertIn("value", msg_dict["input"])
211+
145212

146213
if __name__ == "__main__":
147214
unittest.main()

torchrec/distributed/torchrec_logger.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -42,12 +42,16 @@ def _get_logging_handler(
4242

4343

4444
def _get_msg_dict(func_name: str, **kwargs: Any) -> dict[str, Any]:
45-
msg_dict = {
46-
"func_name": f"{func_name}",
47-
}
48-
if dist.is_initialized():
49-
group = kwargs.get("group") or kwargs.get("process_group")
50-
msg_dict["group"] = f"{group}"
51-
msg_dict["world_size"] = f"{dist.get_world_size(group)}"
52-
msg_dict["rank"] = f"{dist.get_rank(group)}"
53-
return msg_dict
45+
try:
46+
msg_dict = {
47+
"func_name": f"{func_name}",
48+
}
49+
if dist.is_initialized():
50+
group = kwargs.get("group") or kwargs.get("process_group")
51+
msg_dict["group"] = f"{group}"
52+
msg_dict["world_size"] = f"{dist.get_world_size(group)}"
53+
msg_dict["rank"] = f"{dist.get_rank(group)}"
54+
return msg_dict
55+
except Exception as error:
56+
logging.error(f"Torchrec Logger: Error in _get_msg_dict: {error}")
57+
return {"_get_msg_dict_error": str(error)}

0 commit comments

Comments
 (0)