File tree Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Expand file tree Collapse file tree 3 files changed +22
-2
lines changed Original file line number Diff line number Diff line change 5050 auto_unwrap_transformed_env ,
5151 compile_with_warmup ,
5252 implement_for ,
53+ logger ,
5354 set_auto_unwrap_transformed_env ,
5455 timeit ,
5556)
5657
58+ torchrl_logger = logger
59+
5760# Filter warnings in subprocesses: True by default given the multiple optional
5861# deps of the library. This can be turned on via `torchrl.filter_warnings_subprocess = False`.
5962filter_warnings_subprocess = True
@@ -108,4 +111,6 @@ def _inv(self):
108111 "implement_for" ,
109112 "set_auto_unwrap_transformed_env" ,
110113 "timeit" ,
114+ "logger" ,
115+ "torchrl_logger" ,
111116]
Original file line number Diff line number Diff line change @@ -1122,3 +1122,17 @@ def auto_unwrap_transformed_env(allow_none=False):
11221122 elif _AUTO_UNWRAP is None :
11231123 return _DEFAULT_AUTO_UNWRAP
11241124 return strtobool (_AUTO_UNWRAP ) if isinstance (_AUTO_UNWRAP , str ) else _AUTO_UNWRAP
1125+
1126+
1127+ def safe_is_current_stream_capturing ():
1128+ """A safe proxy to torch.cuda.is_current_stream_capturing."""
1129+ if not torch .cuda .is_available ():
1130+ return False
1131+ try :
1132+ return torch .cuda .is_current_stream_capturing ()
1133+ except Exception as error :
1134+ warnings .warn (
1135+ f"torch.cuda.is_current_stream_capturing() exited unexpectedly with the error message { error = } . "
1136+ f"Returning False by default."
1137+ )
1138+ return False
Original file line number Diff line number Diff line change 1515from torch .distributions import constraints
1616from torch .distributions .transforms import _InverseTransform
1717
18+ from torchrl ._utils import safe_is_current_stream_capturing
1819from torchrl .modules .distributions .truncated_normal import (
1920 TruncatedNormal as _TruncatedNormal ,
2021)
@@ -358,7 +359,7 @@ def __init__(
358359 event_dims = min (1 , loc .ndim )
359360
360361 err_msg = "TanhNormal high values must be strictly greater than low values"
361- if not is_compiling () and not torch . cuda . is_current_stream_capturing ():
362+ if not is_compiling () and not safe_is_current_stream_capturing ():
362363 if isinstance (high , torch .Tensor ) or isinstance (low , torch .Tensor ):
363364 if not (high > low ).all ():
364365 raise RuntimeError (err_msg )
@@ -377,7 +378,7 @@ def __init__(
377378 low = torch .as_tensor (low , device = loc .device )
378379 elif low .device != loc .device :
379380 low = low .to (loc .device )
380- if not is_compiling () and not torch . cuda . is_current_stream_capturing ():
381+ if not is_compiling () and not safe_is_current_stream_capturing ():
381382 self .non_trivial_max = (high != 1.0 ).any ()
382383 self .non_trivial_min = (low != - 1.0 ).any ()
383384 else :
You can’t perform that action at this time.
0 commit comments