Skip to content

Commit cc8abe7

Browse files
authored
[Versioning] Import implement_for from pyvers (#3166)
1 parent 66b9a21 commit cc8abe7

File tree

3 files changed

+5
-272
lines changed

3 files changed

+5
-272
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ classifiers = [
2929
requires-python = ">=3.9"
3030
dependencies = [
3131
"torch>=2.1.0",
32+
"pyvers",
3233
"numpy",
3334
"packaging",
3435
"cloudpickle",

torchrl/_utils.py

Lines changed: 2 additions & 271 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@
1818
import warnings
1919
from collections.abc import Callable
2020
from contextlib import nullcontext
21-
from copy import copy
2221
from functools import wraps
23-
from importlib import import_module
2422
from textwrap import indent
2523
from typing import Any, cast, TypeVar
2624

2725
import numpy as np
2826
import torch
29-
from packaging.version import parse
27+
28+
from pyvers import implement_for # noqa: F401
3029
from tensordict import unravel_key
3130
from tensordict.utils import NestedKey
3231
from torch import multiprocessing as mp, Tensor
@@ -390,274 +389,6 @@ def __repr__(self):
390389
_CKPT_BACKEND = _Dynamic_CKPT_BACKEND()
391390

392391

393-
class implement_for:
394-
"""A version decorator that checks the version in the environment and implements a function with the fitting one.
395-
396-
If specified module is missing or there is no fitting implementation, call of the decorated function
397-
will lead to the explicit error.
398-
In case of intersected ranges, last fitting implementation is used.
399-
400-
This wrapper also works to implement different backends for a same function (eg. gym vs gymnasium,
401-
numpy vs jax-numpy etc).
402-
403-
Args:
404-
module_name (str or callable): version is checked for the module with this
405-
name (e.g. "gym"). If a callable is provided, it should return the
406-
module.
407-
from_version: version from which implementation is compatible. Can be open (None).
408-
to_version: version from which implementation is no longer compatible. Can be open (None).
409-
410-
Keyword Args:
411-
class_method (bool, optional): if ``True``, the function will be written as a class method.
412-
Defaults to ``False``.
413-
compilable (bool, optional): If ``False``, the module import happens
414-
only on the first call to the wrapped function. If ``True``, the
415-
module import happens when the wrapped function is initialized. This
416-
allows the wrapped function to work well with ``torch.compile``.
417-
Defaults to ``False``.
418-
419-
Examples:
420-
>>> @implement_for("gym", "0.13", "0.14")
421-
>>> def fun(self, x):
422-
... # Older gym versions will return x + 1
423-
... return x + 1
424-
...
425-
>>> @implement_for("gym", "0.14", "0.23")
426-
>>> def fun(self, x):
427-
... # More recent gym versions will return x + 2
428-
... return x + 2
429-
...
430-
>>> @implement_for(lambda: import_module("gym"), "0.23", None)
431-
>>> def fun(self, x):
432-
... # More recent gym versions will return x + 2
433-
... return x + 2
434-
...
435-
>>> @implement_for("gymnasium", None, "1.0.0")
436-
>>> def fun(self, x):
437-
... # If gymnasium is to be used instead of gym, x+3 will be returned
438-
... return x + 3
439-
...
440-
441-
This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+.
442-
"""
443-
444-
# Stores pointers to fitting implementations: dict[func_name] = func_pointer
445-
_implementations = {}
446-
_setters = []
447-
_cache_modules = {}
448-
449-
def __init__(
450-
self,
451-
module_name: str | Callable,
452-
from_version: str | None = None,
453-
to_version: str | None = None,
454-
*,
455-
class_method: bool = False,
456-
compilable: bool = False,
457-
):
458-
self.module_name = module_name
459-
self.from_version = from_version
460-
self.to_version = to_version
461-
self.class_method = class_method
462-
self._compilable = compilable
463-
implement_for._setters.append(self)
464-
465-
@staticmethod
466-
def check_version(version: str, from_version: str | None, to_version: str | None):
467-
version = parse(".".join([str(v) for v in parse(version).release]))
468-
return (from_version is None or version >= parse(from_version)) and (
469-
to_version is None or version < parse(to_version)
470-
)
471-
472-
@staticmethod
473-
def get_class_that_defined_method(f):
474-
"""Returns the class of a method, if it is defined, and None otherwise."""
475-
out = f.__globals__.get(f.__qualname__.split(".")[0], None)
476-
return out
477-
478-
@classmethod
479-
def get_func_name(cls, fn):
480-
# produces a name like torchrl.module.Class.method or torchrl.module.function
481-
fn_str = str(fn).split(".")
482-
if fn_str[0].startswith("<bound method "):
483-
first = fn_str[0][len("<bound method ") :]
484-
elif fn_str[0].startswith("<function "):
485-
first = fn_str[0][len("<function ") :]
486-
else:
487-
raise RuntimeError(f"Unknown func representation {fn}")
488-
last = fn_str[1:]
489-
if last:
490-
first = [first]
491-
last[-1] = last[-1].split(" ")[0]
492-
else:
493-
last = [first.split(" ")[0]]
494-
first = []
495-
return ".".join([fn.__module__] + first + last)
496-
497-
def _get_cls(self, fn):
498-
cls = self.get_class_that_defined_method(fn)
499-
if cls is None:
500-
# class not yet defined
501-
return
502-
if cls.__class__.__name__ == "function":
503-
cls = inspect.getmodule(fn)
504-
return cls
505-
506-
def module_set(self):
507-
"""Sets the function in its module, if it exists already."""
508-
prev_setter = type(self)._implementations.get(self.get_func_name(self.fn), None)
509-
if prev_setter is not None:
510-
prev_setter.do_set = False
511-
type(self)._implementations[self.get_func_name(self.fn)] = self
512-
cls = self.get_class_that_defined_method(self.fn)
513-
if cls is not None:
514-
if cls.__class__.__name__ == "function":
515-
cls = inspect.getmodule(self.fn)
516-
else:
517-
# class not yet defined
518-
return
519-
try:
520-
delattr(cls, self.fn.__name__)
521-
except AttributeError:
522-
pass
523-
524-
name = self.fn.__name__
525-
if self.class_method:
526-
fn = classmethod(self.fn)
527-
else:
528-
fn = self.fn
529-
setattr(cls, name, fn)
530-
531-
@classmethod
532-
def import_module(cls, module_name: Callable | str) -> str:
533-
"""Imports module and returns its version."""
534-
if not callable(module_name):
535-
module = cls._cache_modules.get(module_name, None)
536-
if module is None:
537-
if module_name in sys.modules:
538-
sys.modules[module_name] = module = import_module(module_name)
539-
else:
540-
cls._cache_modules[module_name] = module = import_module(
541-
module_name
542-
)
543-
else:
544-
module = module_name()
545-
return module.__version__
546-
547-
_lazy_impl = collections.defaultdict(list)
548-
549-
def _delazify(self, func_name):
550-
out = None
551-
# Make a copy of the list to avoid issues when clearing during iteration
552-
lazy_calls = implement_for._lazy_impl[func_name][:]
553-
for local_call in lazy_calls:
554-
out = local_call()
555-
# Only clear for compilable decorators, since non-compilable decorators
556-
# need to keep the list to allow multiple lazy calls
557-
# Check if any of the decorators are compilable
558-
any_compilable = any(
559-
hasattr(call, "__self__") and call.__self__._compilable
560-
for call in lazy_calls
561-
)
562-
if any_compilable:
563-
implement_for._lazy_impl[func_name].clear()
564-
return out
565-
566-
def __call__(self, fn):
567-
# function names are unique
568-
self.func_name = self.get_func_name(fn)
569-
self.fn = fn
570-
implement_for._lazy_impl[self.func_name].append(self._call)
571-
572-
if self._compilable:
573-
_call_fn = self._delazify(self.func_name)
574-
575-
if self.class_method:
576-
return classmethod(_call_fn)
577-
578-
return _call_fn
579-
else:
580-
581-
@wraps(fn)
582-
def _lazy_call_fn(*args, **kwargs):
583-
# first time we call the function, we also do the replacement.
584-
# This will cause the imports to occur only during the first call to fn
585-
586-
result = self._delazify(self.func_name)(*args, **kwargs)
587-
return result
588-
589-
if self.class_method:
590-
return classmethod(_lazy_call_fn)
591-
592-
return _lazy_call_fn
593-
594-
def _call(self):
595-
596-
# If the module is missing replace the function with the mock.
597-
fn = self.fn
598-
func_name = self.func_name
599-
implementations = implement_for._implementations
600-
601-
@wraps(fn)
602-
def unsupported(*args, **kwargs):
603-
raise ModuleNotFoundError(
604-
f"Supported version of '{func_name}' has not been found."
605-
)
606-
607-
self.do_set = False
608-
# Return fitting implementation if it was encountered before.
609-
if func_name in implementations:
610-
try:
611-
# check that backends don't conflict
612-
version = self.import_module(self.module_name)
613-
if self.check_version(version, self.from_version, self.to_version):
614-
if VERBOSE:
615-
module = import_module(self.module_name)
616-
warnings.warn(
617-
f"Got multiple backends for {func_name}. "
618-
f"Using the last queried ({module} with version {version})."
619-
)
620-
self.do_set = True
621-
if not self.do_set:
622-
return implementations[func_name].fn
623-
except ModuleNotFoundError:
624-
# then it's ok, there is no conflict
625-
return implementations[func_name].fn
626-
else:
627-
try:
628-
version = self.import_module(self.module_name)
629-
if self.check_version(version, self.from_version, self.to_version):
630-
self.do_set = True
631-
except ModuleNotFoundError:
632-
return unsupported
633-
if self.do_set:
634-
self.module_set()
635-
return fn
636-
return unsupported
637-
638-
@classmethod
639-
def reset(cls, setters_dict: dict[str, implement_for] = None):
640-
"""Resets the setters in setter_dict.
641-
642-
``setter_dict`` is a copy of implementations. We just need to iterate through its
643-
values and call :meth:`module_set` for each.
644-
645-
"""
646-
if VERBOSE:
647-
logger.info("resetting implement_for")
648-
if setters_dict is None:
649-
setters_dict = copy(cls._implementations)
650-
for setter in setters_dict.values():
651-
setter.module_set()
652-
653-
def __repr__(self):
654-
return (
655-
f"{self.__class__.__name__}("
656-
f"module_name={self.module_name}({self.from_version, self.to_version}), "
657-
f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)})"
658-
)
659-
660-
661392
def accept_remote_rref_invocation(func):
662393
"""Decorator that allows a method to be invoked remotely.
663394

torchrl/trainers/trainers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,7 @@ class LogScalar(TrainerHookBase):
10511051
Args:
10521052
key (NestedKey): the key where to find the value in the input batch.
10531053
Can be a string for simple keys or a tuple for nested keys.
1054+
Default is `torchrl.trainers.trainers.REWARD_KEY` (= `("next", "reward")`).
10541055
logname (str, optional): name of the metric to be logged. If None, will use
10551056
the key as the log name. Default is None.
10561057
log_pbar (bool, optional): if ``True``, the value will be logged on
@@ -1077,7 +1078,7 @@ class LogScalar(TrainerHookBase):
10771078

10781079
def __init__(
10791080
self,
1080-
key: NestedKey,
1081+
key: NestedKey = REWARD_KEY,
10811082
logname: str | None = None,
10821083
log_pbar: bool = False,
10831084
include_std: bool = True,

0 commit comments

Comments
 (0)