|
18 | 18 | import warnings |
19 | 19 | from collections.abc import Callable |
20 | 20 | from contextlib import nullcontext |
21 | | -from copy import copy |
22 | 21 | from functools import wraps |
23 | | -from importlib import import_module |
24 | 22 | from textwrap import indent |
25 | 23 | from typing import Any, cast, TypeVar |
26 | 24 |
|
27 | 25 | import numpy as np |
28 | 26 | import torch |
29 | | -from packaging.version import parse |
| 27 | + |
| 28 | +from pyvers import implement_for # noqa: F401 |
30 | 29 | from tensordict import unravel_key |
31 | 30 | from tensordict.utils import NestedKey |
32 | 31 | from torch import multiprocessing as mp, Tensor |
@@ -390,274 +389,6 @@ def __repr__(self): |
390 | 389 | _CKPT_BACKEND = _Dynamic_CKPT_BACKEND() |
391 | 390 |
|
392 | 391 |
|
393 | | -class implement_for: |
394 | | - """A version decorator that checks the version in the environment and implements a function with the fitting one. |
395 | | -
|
396 | | - If specified module is missing or there is no fitting implementation, call of the decorated function |
397 | | - will lead to the explicit error. |
398 | | - In case of intersected ranges, last fitting implementation is used. |
399 | | -
|
400 | | - This wrapper also works to implement different backends for a same function (eg. gym vs gymnasium, |
401 | | - numpy vs jax-numpy etc). |
402 | | -
|
403 | | - Args: |
404 | | - module_name (str or callable): version is checked for the module with this |
405 | | - name (e.g. "gym"). If a callable is provided, it should return the |
406 | | - module. |
407 | | - from_version: version from which implementation is compatible. Can be open (None). |
408 | | - to_version: version from which implementation is no longer compatible. Can be open (None). |
409 | | -
|
410 | | - Keyword Args: |
411 | | - class_method (bool, optional): if ``True``, the function will be written as a class method. |
412 | | - Defaults to ``False``. |
413 | | - compilable (bool, optional): If ``False``, the module import happens |
414 | | - only on the first call to the wrapped function. If ``True``, the |
415 | | - module import happens when the wrapped function is initialized. This |
416 | | - allows the wrapped function to work well with ``torch.compile``. |
417 | | - Defaults to ``False``. |
418 | | -
|
419 | | - Examples: |
420 | | - >>> @implement_for("gym", "0.13", "0.14") |
421 | | - >>> def fun(self, x): |
422 | | - ... # Older gym versions will return x + 1 |
423 | | - ... return x + 1 |
424 | | - ... |
425 | | - >>> @implement_for("gym", "0.14", "0.23") |
426 | | - >>> def fun(self, x): |
427 | | - ... # More recent gym versions will return x + 2 |
428 | | - ... return x + 2 |
429 | | - ... |
430 | | - >>> @implement_for(lambda: import_module("gym"), "0.23", None) |
431 | | - >>> def fun(self, x): |
432 | | - ... # More recent gym versions will return x + 2 |
433 | | - ... return x + 2 |
434 | | - ... |
435 | | - >>> @implement_for("gymnasium", None, "1.0.0") |
436 | | - >>> def fun(self, x): |
437 | | - ... # If gymnasium is to be used instead of gym, x+3 will be returned |
438 | | - ... return x + 3 |
439 | | - ... |
440 | | -
|
441 | | - This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+. |
442 | | - """ |
443 | | - |
444 | | - # Stores pointers to fitting implementations: dict[func_name] = func_pointer |
445 | | - _implementations = {} |
446 | | - _setters = [] |
447 | | - _cache_modules = {} |
448 | | - |
449 | | - def __init__( |
450 | | - self, |
451 | | - module_name: str | Callable, |
452 | | - from_version: str | None = None, |
453 | | - to_version: str | None = None, |
454 | | - *, |
455 | | - class_method: bool = False, |
456 | | - compilable: bool = False, |
457 | | - ): |
458 | | - self.module_name = module_name |
459 | | - self.from_version = from_version |
460 | | - self.to_version = to_version |
461 | | - self.class_method = class_method |
462 | | - self._compilable = compilable |
463 | | - implement_for._setters.append(self) |
464 | | - |
465 | | - @staticmethod |
466 | | - def check_version(version: str, from_version: str | None, to_version: str | None): |
467 | | - version = parse(".".join([str(v) for v in parse(version).release])) |
468 | | - return (from_version is None or version >= parse(from_version)) and ( |
469 | | - to_version is None or version < parse(to_version) |
470 | | - ) |
471 | | - |
472 | | - @staticmethod |
473 | | - def get_class_that_defined_method(f): |
474 | | - """Returns the class of a method, if it is defined, and None otherwise.""" |
475 | | - out = f.__globals__.get(f.__qualname__.split(".")[0], None) |
476 | | - return out |
477 | | - |
478 | | - @classmethod |
479 | | - def get_func_name(cls, fn): |
480 | | - # produces a name like torchrl.module.Class.method or torchrl.module.function |
481 | | - fn_str = str(fn).split(".") |
482 | | - if fn_str[0].startswith("<bound method "): |
483 | | - first = fn_str[0][len("<bound method ") :] |
484 | | - elif fn_str[0].startswith("<function "): |
485 | | - first = fn_str[0][len("<function ") :] |
486 | | - else: |
487 | | - raise RuntimeError(f"Unknown func representation {fn}") |
488 | | - last = fn_str[1:] |
489 | | - if last: |
490 | | - first = [first] |
491 | | - last[-1] = last[-1].split(" ")[0] |
492 | | - else: |
493 | | - last = [first.split(" ")[0]] |
494 | | - first = [] |
495 | | - return ".".join([fn.__module__] + first + last) |
496 | | - |
497 | | - def _get_cls(self, fn): |
498 | | - cls = self.get_class_that_defined_method(fn) |
499 | | - if cls is None: |
500 | | - # class not yet defined |
501 | | - return |
502 | | - if cls.__class__.__name__ == "function": |
503 | | - cls = inspect.getmodule(fn) |
504 | | - return cls |
505 | | - |
506 | | - def module_set(self): |
507 | | - """Sets the function in its module, if it exists already.""" |
508 | | - prev_setter = type(self)._implementations.get(self.get_func_name(self.fn), None) |
509 | | - if prev_setter is not None: |
510 | | - prev_setter.do_set = False |
511 | | - type(self)._implementations[self.get_func_name(self.fn)] = self |
512 | | - cls = self.get_class_that_defined_method(self.fn) |
513 | | - if cls is not None: |
514 | | - if cls.__class__.__name__ == "function": |
515 | | - cls = inspect.getmodule(self.fn) |
516 | | - else: |
517 | | - # class not yet defined |
518 | | - return |
519 | | - try: |
520 | | - delattr(cls, self.fn.__name__) |
521 | | - except AttributeError: |
522 | | - pass |
523 | | - |
524 | | - name = self.fn.__name__ |
525 | | - if self.class_method: |
526 | | - fn = classmethod(self.fn) |
527 | | - else: |
528 | | - fn = self.fn |
529 | | - setattr(cls, name, fn) |
530 | | - |
531 | | - @classmethod |
532 | | - def import_module(cls, module_name: Callable | str) -> str: |
533 | | - """Imports module and returns its version.""" |
534 | | - if not callable(module_name): |
535 | | - module = cls._cache_modules.get(module_name, None) |
536 | | - if module is None: |
537 | | - if module_name in sys.modules: |
538 | | - sys.modules[module_name] = module = import_module(module_name) |
539 | | - else: |
540 | | - cls._cache_modules[module_name] = module = import_module( |
541 | | - module_name |
542 | | - ) |
543 | | - else: |
544 | | - module = module_name() |
545 | | - return module.__version__ |
546 | | - |
547 | | - _lazy_impl = collections.defaultdict(list) |
548 | | - |
549 | | - def _delazify(self, func_name): |
550 | | - out = None |
551 | | - # Make a copy of the list to avoid issues when clearing during iteration |
552 | | - lazy_calls = implement_for._lazy_impl[func_name][:] |
553 | | - for local_call in lazy_calls: |
554 | | - out = local_call() |
555 | | - # Only clear for compilable decorators, since non-compilable decorators |
556 | | - # need to keep the list to allow multiple lazy calls |
557 | | - # Check if any of the decorators are compilable |
558 | | - any_compilable = any( |
559 | | - hasattr(call, "__self__") and call.__self__._compilable |
560 | | - for call in lazy_calls |
561 | | - ) |
562 | | - if any_compilable: |
563 | | - implement_for._lazy_impl[func_name].clear() |
564 | | - return out |
565 | | - |
566 | | - def __call__(self, fn): |
567 | | - # function names are unique |
568 | | - self.func_name = self.get_func_name(fn) |
569 | | - self.fn = fn |
570 | | - implement_for._lazy_impl[self.func_name].append(self._call) |
571 | | - |
572 | | - if self._compilable: |
573 | | - _call_fn = self._delazify(self.func_name) |
574 | | - |
575 | | - if self.class_method: |
576 | | - return classmethod(_call_fn) |
577 | | - |
578 | | - return _call_fn |
579 | | - else: |
580 | | - |
581 | | - @wraps(fn) |
582 | | - def _lazy_call_fn(*args, **kwargs): |
583 | | - # first time we call the function, we also do the replacement. |
584 | | - # This will cause the imports to occur only during the first call to fn |
585 | | - |
586 | | - result = self._delazify(self.func_name)(*args, **kwargs) |
587 | | - return result |
588 | | - |
589 | | - if self.class_method: |
590 | | - return classmethod(_lazy_call_fn) |
591 | | - |
592 | | - return _lazy_call_fn |
593 | | - |
594 | | - def _call(self): |
595 | | - |
596 | | - # If the module is missing replace the function with the mock. |
597 | | - fn = self.fn |
598 | | - func_name = self.func_name |
599 | | - implementations = implement_for._implementations |
600 | | - |
601 | | - @wraps(fn) |
602 | | - def unsupported(*args, **kwargs): |
603 | | - raise ModuleNotFoundError( |
604 | | - f"Supported version of '{func_name}' has not been found." |
605 | | - ) |
606 | | - |
607 | | - self.do_set = False |
608 | | - # Return fitting implementation if it was encountered before. |
609 | | - if func_name in implementations: |
610 | | - try: |
611 | | - # check that backends don't conflict |
612 | | - version = self.import_module(self.module_name) |
613 | | - if self.check_version(version, self.from_version, self.to_version): |
614 | | - if VERBOSE: |
615 | | - module = import_module(self.module_name) |
616 | | - warnings.warn( |
617 | | - f"Got multiple backends for {func_name}. " |
618 | | - f"Using the last queried ({module} with version {version})." |
619 | | - ) |
620 | | - self.do_set = True |
621 | | - if not self.do_set: |
622 | | - return implementations[func_name].fn |
623 | | - except ModuleNotFoundError: |
624 | | - # then it's ok, there is no conflict |
625 | | - return implementations[func_name].fn |
626 | | - else: |
627 | | - try: |
628 | | - version = self.import_module(self.module_name) |
629 | | - if self.check_version(version, self.from_version, self.to_version): |
630 | | - self.do_set = True |
631 | | - except ModuleNotFoundError: |
632 | | - return unsupported |
633 | | - if self.do_set: |
634 | | - self.module_set() |
635 | | - return fn |
636 | | - return unsupported |
637 | | - |
638 | | - @classmethod |
639 | | - def reset(cls, setters_dict: dict[str, implement_for] = None): |
640 | | - """Resets the setters in setter_dict. |
641 | | -
|
642 | | - ``setter_dict`` is a copy of implementations. We just need to iterate through its |
643 | | - values and call :meth:`module_set` for each. |
644 | | -
|
645 | | - """ |
646 | | - if VERBOSE: |
647 | | - logger.info("resetting implement_for") |
648 | | - if setters_dict is None: |
649 | | - setters_dict = copy(cls._implementations) |
650 | | - for setter in setters_dict.values(): |
651 | | - setter.module_set() |
652 | | - |
653 | | - def __repr__(self): |
654 | | - return ( |
655 | | - f"{self.__class__.__name__}(" |
656 | | - f"module_name={self.module_name}({self.from_version, self.to_version}), " |
657 | | - f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)})" |
658 | | - ) |
659 | | - |
660 | | - |
661 | 392 | def accept_remote_rref_invocation(func): |
662 | 393 | """Decorator that allows a method to be invoked remotely. |
663 | 394 |
|
|
0 commit comments