From 50fa2b0bee5313855b324d0864e414fb977edf54 Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Fri, 31 Oct 2025 15:11:28 +0800 Subject: [PATCH 01/12] migrate test_utils, test_schema_check for xpu and add checkpoint.py for test_utils --- test/xpu/checkpoint.py | 1608 +++++++++++++++++++++++++++++++++ test/xpu/test_schema_check.py | 514 +++++++++++ test/xpu/test_utils.py | 1022 +++++++++++++++++++++ 3 files changed, 3144 insertions(+) create mode 100644 test/xpu/checkpoint.py create mode 100644 test/xpu/test_schema_check.py create mode 100644 test/xpu/test_utils.py diff --git a/test/xpu/checkpoint.py b/test/xpu/checkpoint.py new file mode 100644 index 0000000000..6fc9af6a4f --- /dev/null +++ b/test/xpu/checkpoint.py @@ -0,0 +1,1608 @@ +# mypy: allow-untyped-defs +import contextlib +import platform +import uuid +import warnings +import weakref +from collections import defaultdict +from typing import * # noqa: F403 +import enum +from weakref import ReferenceType + +import torch +import torch.fx.traceback as fx_traceback +from torch.utils._pytree import tree_map +from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode +from torch.utils._python_dispatch import TorchDispatchMode + +__all__ = [ + "checkpoint", + "checkpoint_sequential", + "CheckpointError", + "CheckpointFunction", + "check_backward_validity", + "detach_variable", + "get_device_states", + "set_device_states", + "noop_context_fn", + "set_checkpoint_early_stop", + "DefaultDeviceType", + "set_checkpoint_debug_enabled", + "CheckpointPolicy", + "SelectiveCheckpointContext", + "create_selective_checkpoint_contexts", + "SAC_IGNORED_OPS", +] + +_DEFAULT_DETERMINISM_MODE = "default" + +_checkpoint_debug_enabled: Optional[bool] = None + + +@contextlib.contextmanager +def set_checkpoint_debug_enabled(enabled: Optional[bool]): + """ + Context manager that sets whether checkpoint should print additional debug + information when running. See the ``debug`` flag for + :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that + when set, this context manager overrides the value of ``debug`` passed to + checkpoint. To defer to the local setting, pass ``None`` to this context. + + Args: + enabled (bool): Whether checkpoint should print debug information. + Default is 'None'. + """ + global _checkpoint_debug_enabled + try: + prev = _checkpoint_debug_enabled + _checkpoint_debug_enabled = enabled + yield + finally: + _checkpoint_debug_enabled = prev + + +def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: + if isinstance(inputs, tuple): + out = [] + for inp in inputs: + if not isinstance(inp, torch.Tensor): + out.append(inp) + continue + + x = inp.detach() + x.requires_grad = inp.requires_grad + out.append(x) + return tuple(out) + else: + raise RuntimeError( + "Only tuple of tensors is supported. Got Unsupported input type: ", + type(inputs).__name__, + ) + + +def check_backward_validity(inputs: Iterable[Any]) -> None: + if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): + warnings.warn( + "None of the inputs have requires_grad=True. Gradients will be None", stacklevel=2 + ) + + +def _get_device_module(device="cuda"): + if device == "meta": + return torch.device("meta") + device_module = getattr(torch, device) + return device_module + + +class DefaultDeviceType: + r""" + A class that manages the default device type for checkpointing. + + If no non-CPU tensors are present, the default device type will + be used. The default value is 'cuda'. 
The device type is used in + the checkpointing process when determining which device states + to save and restore for recomputation. + """ + + _default_device_type = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" + + @staticmethod + def set_device_type(device: str = "cuda"): + """ + Set the default device type for checkpointing. + + Args: + device (str): The device type to be set as default. Default is 'cuda'. + """ + DefaultDeviceType._default_device_type = device + + @staticmethod + def get_device_type() -> str: + """ + Get the current default device type for checkpointing. + + Returns: + str: The current default device type. + """ + return DefaultDeviceType._default_device_type + + +def _infer_device_type(*args): + device_types = [] + + def add_device_types(arg): + nonlocal device_types + if isinstance(arg, torch.Tensor) and arg.device.type != "cpu": + device_types.append(arg.device.type) + tree_map(add_device_types, args) + + device_types_set = set(device_types) + if len(device_types_set) > 1: + warnings.warn( + "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. " + "Device state will only be saved for devices of a single device type, and the remaining " + "devices will be ignored. Consequently, if any checkpointed functions involve randomness, " + "this may result in incorrect gradients. (Note that if CUDA devices are among the devices " + "detected, it will be prioritized; otherwise, the first device encountered will be selected.)" + f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}", stacklevel=2 + ) + if len(device_types) == 0: + return DefaultDeviceType.get_device_type() + elif "cuda" in device_types_set: + return "cuda" + else: + return device_types[0] + + +# We can't know if the run_fn will internally move some args to different devices, +# which would require logic to preserve rng states for those devices as well. +# We could paranoically stash and restore ALL the rng states for all visible devices, +# but that seems very wasteful for most cases. Compromise: Stash the RNG state for +# the device of all Tensor args. +# +# To consider: maybe get_device_states and set_device_states should reside in torch/random.py? +def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: + # This will not error out if "arg" is a CPU tensor or a non-tensor type because + # the conditionals short-circuit. + fwd_device_ids = [] + + def add_device_ids(arg): + nonlocal fwd_device_ids + if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}: + fwd_device_ids.append(arg.get_device()) + tree_map(add_device_ids, args) + + fwd_device_states = [] + device_module = _get_device_module(_infer_device_type(*args)) + for device_id in fwd_device_ids: + with device_module.device(device_id): + fwd_device_states.append(device_module.get_rng_state()) + + return fwd_device_ids, fwd_device_states + + +def set_device_states(devices, states, *, device_type=None) -> None: + """Sets random number generator states for the specified devices. + + Args: + devices: Device ids to set states for. + states: States to set. + device_type: ``device_type`` of the devices to set states for. Default + is the device returned by a call to ``DefaultDeviceType.get_device_type()``, + which is ``cuda`` if not changed by calling ``DefaultDeviceType::set_device_type()``. 
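+
+    Example (illustrative sketch only; typically paired with :func:`get_device_states`)::
+
+        >>> # xdoctest: +SKIP("illustrative")
+        >>> devices, states = get_device_states(*args)  # stash per-device RNG for tensor args
+        >>> # ... run code that consumes device randomness ...
+        >>> set_device_states(devices, states)  # restore the stashed RNG states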
+ """ + if device_type is None: + device_type = DefaultDeviceType.get_device_type() + if device_type == "meta": + return + device_module = _get_device_module(device_type) + for device, state in zip(devices, states): + with device_module.device(device): + device_module.set_rng_state(state) + + +def _get_autocast_kwargs(device_type="cuda"): + if torch.amp.is_autocast_available(device_type): + device_autocast_kwargs = { + "enabled": torch.is_autocast_enabled(device_type), + "dtype": torch.get_autocast_dtype(device_type), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + else: + device_autocast_kwargs = None + + cpu_autocast_kwargs = { + "enabled": torch.is_autocast_enabled('cpu'), + "dtype": torch.get_autocast_dtype('cpu'), + "cache_enabled": torch.is_autocast_cache_enabled(), + } + + return device_autocast_kwargs, cpu_autocast_kwargs + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, run_function, preserve_rng_state, *args): + check_backward_validity(args) + ctx.run_function = run_function + ctx.preserve_rng_state = preserve_rng_state + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + ctx.device_type = _infer_device_type(*args) + ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs( + ctx.device_type + ) + if preserve_rng_state: + ctx.fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function.) + ctx.had_device_in_fwd = False + device_module = _get_device_module(ctx.device_type) + if getattr(device_module, "_initialized", False): + ctx.had_device_in_fwd = True + ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for i, arg in enumerate(args): + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(i) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + ctx.save_for_backward(*tensor_inputs) + + with torch.no_grad(): + outputs = run_function(*args) + return outputs + + @staticmethod + def backward(ctx, *args): + if not torch.autograd._is_checkpoint_valid(): + raise RuntimeError( + "When use_reentrant=True, torch.utils.checkpoint is incompatible" + " with .grad() or passing an `inputs` parameter to .backward()." + " To resolve this error, you can either set use_reentrant=False," + " or call .backward() without passing the `inputs` argument." + ) + # Copy the list to avoid modifying original list. + inputs = list(ctx.inputs) + tensor_indices = ctx.tensor_indices + tensors = ctx.saved_tensors + + # Fill in inputs with appropriate saved tensors. + for i, idx in enumerate(tensor_indices): + inputs[idx] = tensors[i] + + # Stash the surrounding rng state, and mimic the state that was + # present at this time during forward. Restore the surrounding state + # when we're done. 
+ rng_devices = [] + if ctx.preserve_rng_state and ctx.had_device_in_fwd: + rng_devices = ctx.fwd_devices + with torch.random.fork_rng( + devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type + ): + if ctx.preserve_rng_state: + torch.set_rng_state(ctx.fwd_cpu_state) + if ctx.had_device_in_fwd: + set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type) + detached_inputs = detach_variable(tuple(inputs)) + + device_autocast_ctx = torch.amp.autocast( + device_type=ctx.device_type, **ctx.device_autocast_kwargs + ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext() + with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined] + outputs = ctx.run_function(*detached_inputs) + + if isinstance(outputs, torch.Tensor): + outputs = (outputs,) + + # run backward() with only tensor that requires grad + outputs_with_grad = [] + args_with_grad = [] + for i in range(len(outputs)): + if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: + outputs_with_grad.append(outputs[i]) + args_with_grad.append(args[i]) + if len(outputs_with_grad) == 0: + raise RuntimeError( + "none of output has requires_grad=True," + " this checkpoint() is not necessary" + ) + torch.autograd.backward(outputs_with_grad, args_with_grad) + grads = tuple( + inp.grad if isinstance(inp, torch.Tensor) else None + for inp in detached_inputs + ) + + return (None, None) + grads + + +def noop_context_fn(): + return contextlib.nullcontext(), contextlib.nullcontext() + +# Note: [torch.compile and checkpoint] +# TorchDynamo does not step inside utils.checkpoint function. The flow +# looks likes this +# 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by +# speculatively checking if the forward function is safe to trace. +# 2) If yes, then Dynamo-generated Fx graph has the wrapped higher +# order op. As a result, TorchDynamo does not look inside utils.checkpoint. +# 3) If not, then TorchDynamo falls back to eager by performing a graph +# break. And here, the following disable wrapper ensures that +# TorchDynamo does not trigger again on the frames created by +# utils.checkpoint innards. +@torch._disable_dynamo +def checkpoint( + function, + *args, + use_reentrant: Optional[bool] = None, + context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, + determinism_check: str = _DEFAULT_DETERMINISM_MODE, + debug: bool = False, + early_stop: bool = True, + **kwargs +): + r"""Checkpoint a model or part of the model. + + Activation checkpointing is a technique that trades compute for memory. + Instead of keeping tensors needed for backward alive until they are used in + gradient computation during backward, forward computation in checkpointed + regions omits saving tensors for backward and recomputes them during the + backward pass. Activation checkpointing can be applied to any part of a + model. + + There are currently two checkpointing implementations available, determined + by the :attr:`use_reentrant` parameter. It is recommended that you use + ``use_reentrant=False``. Please refer the note below for a discussion of + their differences. + + .. warning:: + + If the :attr:`function` invocation during the backward pass differs + from the forward pass, e.g., due to a global variable, the checkpointed + version may not be equivalent, potentially causing an + error being raised or leading to silently incorrect gradients. + + .. 
warning:: + + The ``use_reentrant`` parameter should be passed explicitly. In version + 2.9 we will raise an exception if ``use_reentrant`` is not passed. + If you are using the ``use_reentrant=True`` variant, please refer to the + note below for important considerations and potential limitations. + + .. note:: + + The reentrant variant of checkpoint (``use_reentrant=True``) and + the non-reentrant variant of checkpoint (``use_reentrant=False``) + differ in the following ways: + + * Non-reentrant checkpoint stops recomputation as soon as all needed + intermediate activations have been recomputed. This feature is enabled + by default, but can be disabled with :func:`set_checkpoint_early_stop`. + Reentrant checkpoint always recomputes :attr:`function` in its + entirety during the backward pass. + + * The reentrant variant does not record the autograd graph during the + forward pass, as it runs with the forward pass under + :func:`torch.no_grad`. The non-reentrant version does record the + autograd graph, allowing one to perform backward on the graph within + checkpointed regions. + + * The reentrant checkpoint only supports the + :func:`torch.autograd.backward` API for the backward pass without its + `inputs` argument, while the non-reentrant version supports all ways + of performing the backward pass. + + * At least one input and output must have ``requires_grad=True`` for the + reentrant variant. If this condition is unmet, the checkpointed part + of the model will not have gradients. The non-reentrant version does + not have this requirement. + + * The reentrant version does not consider tensors in nested structures + (e.g., custom objects, lists, dicts, etc) as participating in + autograd, while the non-reentrant version does. + + * The reentrant checkpoint does not support checkpointed regions with + detached tensors from the computational graph, whereas the + non-reentrant version does. For the reentrant variant, if the + checkpointed segment contains tensors detached using ``detach()`` or + with :func:`torch.no_grad`, the backward pass will raise an error. + This is because ``checkpoint`` makes all the outputs require gradients + and this causes issues when a tensor is defined to have no gradient in + the model. To avoid this, detach the tensors outside of the + ``checkpoint`` function. + + Args: + function: describes what to run in the forward pass of the model or + part of the model. It should also know how to handle the inputs + passed as the tuple. For example, in LSTM, if user passes + ``(activation, hidden)``, :attr:`function` should correctly use the + first input as ``activation`` and the second input as ``hidden`` + args: tuple containing inputs to the :attr:`function` + + Keyword args: + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. Note that under torch.compile, + this flag doesn't take effect and we always preserve RNG state. + Default: ``True`` + use_reentrant(bool): + specify whether to use the activation checkpoint variant that + requires reentrant autograd. This parameter should be passed + explicitly. In version 2.9 we will raise an exception if + ``use_reentrant`` is not passed. If ``use_reentrant=False``, + ``checkpoint`` will use an implementation that does not require + reentrant autograd. This allows ``checkpoint`` to support additional + functionality, such as working as expected with + ``torch.autograd.grad`` and support for keyword arguments input into + the checkpointed function. 
+ context_fn(Callable, optional): A callable returning a tuple of two + context managers. The function and its recomputation will be run + under the first and second context managers respectively. + This argument is only supported if ``use_reentrant=False``. + determinism_check(str, optional): A string specifying the determinism + check to perform. By default it is set to ``"default"`` which + compares the shapes, dtypes, and devices of the recomputed tensors + against those the saved tensors. To turn off this check, specify + ``"none"``. Currently these are the only two supported values. + Please open an issue if you would like to see more determinism + checks. This argument is only supported if ``use_reentrant=False``, + if ``use_reentrant=True``, the determinism check is always disabled. + debug(bool, optional): If ``True``, error messages will also include + a trace of the operators ran during the original forward computation + as well as the recomputation. This argument is only supported if + ``use_reentrant=False``. + early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops + recomputation as soon as it has computed all needed Tensors. This + argument is ignored if ``use_reentrant=True``. Can be overridden + globally using :func:`set_checkpoint_early_stop` context manager. + Default: ``True``. + + Returns: + Output of running :attr:`function` on :attr:`*args` + """ + if use_reentrant is None: + warnings.warn( + "torch.utils.checkpoint: the use_reentrant parameter should be " + "passed explicitly. Starting in PyTorch 2.9, calling checkpoint " + "without use_reentrant will raise an exception. use_reentrant=False is " + "recommended, but if you need to preserve the current default " + "behavior, you can pass use_reentrant=True. Refer to docs for more " + "details on the differences between the two variants.", + stacklevel=2 + ) + use_reentrant = True + + # Hack to mix *args with **kwargs in a python 2.7-compliant way + preserve = kwargs.pop("preserve_rng_state", True) + if kwargs and use_reentrant: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + if use_reentrant: + if context_fn is not noop_context_fn or debug is not False: + raise ValueError( + "Passing `context_fn` or `debug` is only supported when " + "use_reentrant=False." + ) + return CheckpointFunction.apply(function, preserve, *args) + else: + gen = _checkpoint_without_reentrant_generator( + function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs + ) + # Runs pre-forward logic + next(gen) + ret = function(*args, **kwargs) + # Runs post-forward logic + try: + next(gen) + except StopIteration: + return ret + + +def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs): + r"""Checkpoint a sequential model to save memory. + + Sequential models execute a list of modules/functions in order + (sequentially). Therefore, we can divide such a model in various segments + and checkpoint each segment. All segments except the last will not store + the intermediate activations. The inputs of each checkpointed segment will + be saved for re-running the segment in the backward pass. + + .. warning:: + The ``use_reentrant`` parameter should be passed explicitly. In version + 2.9 we will raise an exception if ``use_reentrant`` is not passed. + If you are using the ``use_reentrant=True` variant, please see + :func:`~torch.utils.checkpoint.checkpoint` for + the important considerations and limitations of this variant. 
It is + recommended that you use ``use_reentrant=False``. + + .. warning: + Since PyTorch 1.4, it allows only one Tensor as the input and + intermediate outputs, just like :class:`torch.nn.Sequential`. + + Args: + functions: A :class:`torch.nn.Sequential` or the list of modules or + functions (comprising the model) to run sequentially. + segments: Number of chunks to create in the model + input: A Tensor that is input to :attr:`functions` + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. + Default: ``True`` + use_reentrant(bool): + specify whether to use the activation checkpoint variant that + requires reentrant autograd. This parameter should be passed + explicitly. In version 2.5 we will raise an exception if + ``use_reentrant`` is not passed. If ``use_reentrant=False``, + ``checkpoint`` will use an implementation that does not require + reentrant autograd. This allows ``checkpoint`` to support additional + functionality, such as working as expected with + ``torch.autograd.grad`` and support for keyword arguments input into + the checkpointed function. + + Returns: + Output of running :attr:`functions` sequentially on :attr:`*inputs` + + Example: + >>> # xdoctest: +SKIP("stub") + >>> model = nn.Sequential(...) + >>> input_var = checkpoint_sequential(model, chunks, input_var) + """ + if use_reentrant is None: + warnings.warn( + "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " + "parameter should be passed explicitly. " + "In version 2.9 we will raise an exception if use_reentrant " + "is not passed. use_reentrant=False is " + "recommended, but if you need to preserve the current default " + "behavior, you can pass use_reentrant=True. Refer to docs for more " + "details on the differences between the two variants.", stacklevel=2 + ) + use_reentrant = True + + # Hack for keyword-only parameter in a python 2.7-compliant way + preserve = kwargs.pop("preserve_rng_state", True) + if kwargs: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + def run_function(start, end, functions): + def forward(input): + for j in range(start, end + 1): + input = functions[j](input) + return input + + return forward + + if isinstance(functions, torch.nn.Sequential): + functions = list(functions.children()) + + segment_size = len(functions) // segments + # the last chunk has to be non-volatile + end = -1 + for start in range(0, segment_size * (segments - 1), segment_size): + end = start + segment_size - 1 + input = checkpoint( + run_function(start, end, functions), + input, + use_reentrant=use_reentrant, + preserve_rng_state=preserve, + ) + return run_function(end + 1, len(functions) - 1, functions)(input) + + +def _internal_assert(cond): + if not cond: + raise AssertionError( + "Something went unexpectedly wrong in activation checkpoint. " + "Please report this bug by filing an issue to PyTorch." + ) + + +# NOTE [ Nestable Checkpoint ] +# +# The semantics of nested checkpoint can be defined by two basic rules. +# Following the two rules leads to an important implication that is central +# to motivating the design. +# +# Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden +# from any outer layers of checkpoint. +# +# Rule 2. The inputs of inner checkpoints are treated as tensors saved to its +# parent checkpoint. +# +# Implication: To recompute any given saved tensor, we need to recompute all of +# the checkpoints wrapping it. +# +# Why is this implied? 
To unpack a saved tensor X during backward we need to +# recompute the inner-most checkpoint (#1), and in order to recompute that +# checkpoint I need to have its inputs, which are managed by that checkpoint's +# parent (#2), which thus also needs to be recomputed first. Continue this line +# of reasoning and we realize that in order to unpack X, all checkpoints that +# were active at the time X was saved need to be recomputed. (unless we have +# already done so in that backward for some other saved tensor). +# +# In practice, we use a noop autograd Function to save inputs as saved tensors. +# During unpack calling ctx.saved_tensor triggers the parent checkpoint to +# recompute. +# +# Rule 3. We should start recomputation as if there are no checkpoints currently +# active. Checkpoints encountered during recomputation are still +# respected. +# +# When we start recomputation, we push the saved variable hook meant for +# recomputation on the stack. See examples in Rule 6 for more context. +# +# * * * * +# +# Beyond the basic semantics specific to nested checkpoint, we impose several +# more constraints that may apply to checkpointing in general. +# +# Rule 4. Lifetime of recomputed tensors +# +# Recomputed tensors are considered specific to particular invocations +# of backward and are always cleared immediately as they are unpacked +# Particularly, we require this to happen even if retain_graph=True. +# +# [ Implementation details of Rule 4 ] +# +# If we were okay with recomputed tensors staying alive after backward is run +# with retain_graph=True, we would store recomputed variables as the values of a +# WeakKeyDictionary and pack strong references to the keys, so that as we +# backward, those packed keys would be cleared as long as retain_graph=False. +# Clearing the packed key clears the corresponding entry in the WKD. +# +# If we wish recomputed variables to be immediately cleared as we unpack them in +# the retain_graph=True case, we cannot rely on the packed keys to be cleared by +# backward automatically. Instead of packing the strong reference to the key +# directly, we pack a container object, which we manually clear as we unpack. +# +# An important detail is that if a second backward happens, the second +# recomputation needs to reset the container with a newly created key. +# +# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we +# know we need. +# +# [ Implementation details of Rule 5 ] +# +# During recomputation, raise an exception if the number of recomputed tensors +# matches the number of tensors that we expected to recompute. We wrap the +# recomputation call with a try-catch to catch this specific exception. See +# Rule #6 below for some examples. +# +# Rule 6. We support doing backward inside checkpoint context +# +# [ retain_graph is True] +# +# def fn(x): +# y = x.sin() +# z = y.cos() +# gx, = torch.autograd.grad(z, x, retains_grad=True) +# return gx, z +# +# out = checkpoint(fn)(inp) +# out.backward() +# +# Because z is saved by cos while checkpoint is enabled, it would not be +# actually saved, and so the .grad() call inside must trigger a recomputation. +# +# During recomputation the "inner pack hook" has two responsibilities: +# +# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors +# 2) Pack the actual tensor (detached) so that one may perform backward on the +# recomputed graph. 
The tensors saved to this graph will live until the end +# of recomputation, or die earlier if someone performs backward with +# retain_graph=False. +# +# More generally performing backward on the recomputed graph occurs in the +# following cases: +# - If backward is performed inside forward, +# - During the original forward IF early-stop is disabled +# - During the original backward +# - If there are multiple .grad()/.backward() calls, we would perform backward +# on the recomputed graph even if early-stop is enabled (see the example below) +# +# [ retain_graph is False ] +# +# The example below shows what happens if during recomputation we find that some +# of the tensors we are trying to recompute have already been cleared. +# +# Spoiler: we don't do anything special, we just skip over them! +# +# def fn(x): +# y = x.sin() # (1) +# z = y.cos() # (2) +# gx, = torch.autograd.grad(z, x) # (3) +# return x.cos() * gx # (4) +# +# out = checkpoint(fn)(inp) +# out.backward() # (5) +# +# 1, 2. Don't save x and y since we are inside a checkpoint. +# 3. Trigger a recompute of fn since x and y weren't saved. +# And depending on whether early stop is enabled, either stop at (2) or +# continue running the function. +# Because we are running backward with retain_graph=False, we clear x and y's +# holders. +# 4. Don't save x since we are inside a checkpoint. +# 5. Calling backward triggers another recompute of fn. During recompute, we see +# that x and y have already been cleared in the original graph as indicated +# by holder=None. We skip over them. We still save x at (4) (since its holder +# is still alive.) + +_enable_checkpoint_early_stop: Optional[bool] = None + + +@contextlib.contextmanager +def set_checkpoint_early_stop(enable: bool): + """Context manager that sets whether checkpoint should stop recomputation early. + + By default, non-reentrant checkpoint stops recomputation as soon as it + has computed all needed Tensors. This context manager can be used to disable + that feature if it is problematic for your specific application. + + This context manager only needs to be active when forward is run. It does + not need to be active during backward. + + Example:: + + >>> # xdoctest: +SKIP(failing) + >>> message = "saved tensors default hooks are disabled" + >>> with set_checkpoint_early_stop(False): + ... # Any checkpoint under this context manager will respect this + ... # context manager, even if its backward is performed outside. + ... out = checkpoint(fn, inputs) + ... 
+ >>> out.backward() + """ + global _enable_checkpoint_early_stop + try: + prev = _enable_checkpoint_early_stop + _enable_checkpoint_early_stop = enable + yield + finally: + _enable_checkpoint_early_stop = prev + + +class _Handle: + pass + + +class _Holder: + def __init__(self): + self.handles: Dict[int, Optional[_Handle]] = {} + + +class _NoopSaveInputs(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(*args): + return torch.empty((0,)) + + @staticmethod + def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: + # Only tensors can be saved with ctx.save_for_backward, everything else + # is captured by get_args, which is saved directly on ctx + tensor_indices, tensors = zip( + *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)] + ) + idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)} + # args but with tensors replaced with None as placeholders + args = [None if isinstance(o, torch.Tensor) else o for o in inputs] + + def get_args(saved_tensors): + # restore the placeholders with the original tensors grabbed from + # ctx.saved_tensors (which may be saved on a parent checkpoint if + # this checkpoint is nested, and that would trigger a recursive + # unpack!) + ret = [ + saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o + for i, o in enumerate(args) + ] + # grab the tail since we also saved the dummy to avoid having to explicitly + # handle the case where there are no tensor inputs + return ret[1:] + + ctx.get_args = get_args + ctx.save_for_backward(*tensors) + + @staticmethod + def backward(ctx, *grad_outputs): + raise AssertionError("Did not expect to backward on this graph") + + +class _CheckpointFrame: + def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): + self.recompute_fn = recompute_fn + self.input_saver = None + self.weak_holders: List[ReferenceType] = [] + # We store this as a weakkeydictionary so that in the case of a partial + # backward, the entries in the dict are cleared alongside the Holder + # which will be removed when the SavedVariable is cleared. + self.recomputed: DefaultDict[ + int, weakref.WeakKeyDictionary[_Handle, torch.Tensor] + ] = defaultdict(weakref.WeakKeyDictionary) + # We need both recomp_counter and recomputed since they can diverge + # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885 + self.recomp_counter: DefaultDict[int, int] = defaultdict(int) + self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool) + + # See Rule 5 + self.early_stop = early_stop + + # Debugging + self.metadata_fn = metadata_fn + self.unpack_error_cb = unpack_error_cb + self.x_metadatas = [] + self.forward_completed = False + self.ignore_saved_mismatch = False + + def check_recomputed_tensors_match(self, gid): + if self.ignore_saved_mismatch: + # TODO: we can probably make this check stricter by checking that + # the metadata of the first tensors still match. + return + # NOTE [ Error handling for checkpoint ] + # + # At a high level, we need to check that the tensors saved + # during original forward matches tensors saved during recompute + # This means handling 3 cases: + # + # 1. During recompute, more tensors were saved. + # + # Usually this is hidden due to the StopRecomputationError + # but if early stop is not enabled, or we would have errored + # anyway because there aren't enough weak_holders. But we + # do want to have a nice error. See the _recomputation_hook + # for details. 
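+        # (In practice, a count mismatch usually comes from non-determinism in
+        # the checkpointed function, e.g. data-dependent control flow or
+        # randomness not captured by the preserved RNG state.)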
+ if not len(self.weak_holders) == self.recomp_counter[gid]: + # 2. During recompute, fewer tensors were saved + # + # We know that every time we save something do original forward + # we append to weak_holder, and every time we save a tensor + # during recompute we increment recompute_counter. + raise CheckpointError( + "torch.utils.checkpoint: A different number of tensors was saved " + "during the original forward and recomputation.\n" + f"Number of tensors saved during forward: {len(self.weak_holders)}\n" + f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}.\n" + f"{_debug_tip_msg}" + ) + + # 3. During recompute, the same tensors were saved, but they + # have different metadata + nb_meta_different = [] + for idx, weak_holder in enumerate(self.weak_holders): + holder = weak_holder() + if holder is None: + continue + # We've seen all holders since we iterate over them in order + # For every holder that is still alive now, it must've been + # alive when we saw it during recompute, therefore, the + # gid must be set. + _internal_assert(gid in holder.handles) + # We know this is the first unpack, so it couldn't have been set + # to None yet. + _internal_assert(holder.handles[gid] is not None) + # We always set these together in the recomputation hook + _internal_assert(holder.handles[gid] in self.recomputed[gid]) + # see pack hook, x_metadata is 1:1 with weak_holders. + x_meta = self.x_metadatas[idx] + recomputed_x = self.recomputed[gid][holder.handles[gid]] + if x_meta != self.metadata_fn(recomputed_x): + nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x))) + + if len(nb_meta_different) > 0: + mismatched_tensors = "" + for idx, x_meta, recomputed_meta in nb_meta_different: + mismatched_tensors += ( + f"tensor at position {idx}:\n" + f"saved metadata: {x_meta}\n" + f"recomputed metadata: {recomputed_meta}\n" + ) + raise CheckpointError( + "torch.utils.checkpoint: Recomputed values for the following tensors " + "have different metadata than during the forward pass.\n" + f"{mismatched_tensors}.\n" + f"{_debug_tip_msg}" + ) + + +_debug_tip_msg = """ +Tip: To see a more detailed error message, either pass `debug=True` to +`torch.utils.checkpoint.checkpoint(...)` or wrap the code block +with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to +enable checkpoint‑debug mode globally. +""" + + +_checkpoint_error_template = """ \ +An error happened while unpacking tensors; dumping logs of latest computation +because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`. +Scroll all the way down for guidance on how to navigate these logs. + ++~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ +| 1. Stack traces of the operators that ran in the original forward | ++------------------------------------------------------------------------------+ + +{forward_traces} ++~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ +| 2. Stack traces of the operators that ran during recomputation | ++------------------------------------------------------------------------------+ + +{recompute_traces} ++~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ +| 3. Log of operators in the original forward and recomputation | ++------------------------------------------------------------------------------+ +(Scroll up to correlate stack traces with each operation listed below. This + helps identify their source in the code.) 
+ +IMPORTANT: Differences in "detach" calls between the original forward and the + recomputation are expected. They are introduced by the checkpointing + mechanism and can be ignored. + +Operations executed during the original forward: + +{forward_ops} + +Operations executed during recomputation: + +{recompute_ops} + ++------------------------------------------------------------------------------+ + ERROR: Detected non-determinism while running activation checkpointing + + You are seeing this error because you passed `debug=True` to checkpoint and + tensors to be saved during the original forward and differ between those saved + during recomputation. This can happen if different operators were ran in the + original forward and in the recomputation. + + To identify where the mismatch may be coming from, you can do the following: + + 1) Compare the operators ran during original forward and recomputation to + see where they differ. These operators are printed above in the order they + were executed. + + 2) Review the stack trace for each operator to locate its invocation source. + Each operator's stack trace is printed in their execution order. + + Note that the logs can be quite long. Here's how they are structured: + (Tip: you can Ctrl-f for these headers) + + 1. Stack traces of the operators that ran in the original forward + 2. Stack traces of the operators that ran during recomputation + 3. Log of operators in the original forward and recomputation + 4. Error message <--- You are here +-------------------------------------------------------------------------------- +""" + +class CheckpointError(RuntimeError): + pass + + +def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]: + # This function returns the context_fn and error_cb to be used by the + # checkpointing mechanism. error_cb is invoked when an error is detected + # during unpack. 
+ + # record_context_cpp is not support on non-linux non-x86_64 platforms + cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' + + class CaptureLogs: + def __init__(self): + self.logs = None + self.tbs = None + + def get_context_manager(self): + @contextlib.contextmanager + def logging_mode(): + with LoggingTensorMode(), \ + capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: + # pyrefly: ignore [bad-assignment] + self.logs, self.tbs = logs_and_tb + yield logs_and_tb + return logging_mode() + + capture_logs_fwd = CaptureLogs() + capture_logs_recompute = CaptureLogs() + + def unpack_error_cb(e: CheckpointError): + def get_str_tb(label, capture_logs): + out = "" + total_len = len(capture_logs.logs) + for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)): + out += f"{log} ({i + 1} of {total_len} in {label})\n\n" + found_torch_dispatch = False + for line in tb: + # Start printing stack trace only after __torch_dispatch__ is found + is_torch_dispatch = line['name'] == '__torch_dispatch__' + if not found_torch_dispatch and not is_torch_dispatch: + continue + elif is_torch_dispatch: + found_torch_dispatch = True + continue + out += f"{line['filename']}:{line['line']}:{line['name']}\n" + out += "\n\n" + return out + if capture_logs_fwd.logs is None: + raise AssertionError("capture_logs_fwd.logs is None") + if capture_logs_recompute.logs is None: + raise AssertionError("capture_logs_recompute.logs is None") + raise CheckpointError( + _checkpoint_error_template.format( + forward_traces=get_str_tb("original", capture_logs_fwd), + recompute_traces=get_str_tb("recompute", capture_logs_recompute), + forward_ops="\n".join(capture_logs_fwd.logs), + recompute_ops="\n".join(capture_logs_recompute.logs) + ) + ) from e + + def context_fn(): + return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager() + + return context_fn, unpack_error_cb + +def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: + # These properties are fast to check, easy to understand + return { + "shape": x.shape, + "dtype": x.dtype, + "device": x.device + } + +_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { + _DEFAULT_DETERMINISM_MODE: _default_meta_extractor, + "none": lambda _: None, +} + +# See Rule 5 +class _StopRecomputationError(Exception): + pass + + +class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): + def __init__(self, target_frame_ref: ReferenceType, gid: int): + def pack_hook(x): + x = x.detach() if x.requires_grad else x + target_frame = target_frame_ref() + if target_frame is None: + raise AssertionError("Internal error: target_frame reference is None") + recomp_idx = target_frame.recomp_counter[gid] + target_frame.recomp_counter[gid] += 1 + + if recomp_idx >= len(target_frame.weak_holders): + if target_frame.early_stop: + raise AssertionError("Unexpected state: target_frame.early_stop is set") + if not target_frame.forward_completed: + # We run into this case when early stop is not enabled and do + # grad within checkpoint. + # We need to set this flag, so we don't error out later when + # we check if the number of tensors saved during forward and + # recomputation match. 
+ target_frame.ignore_saved_mismatch = True + return x + raise CheckpointError( + "torch.utils.checkpoint: trying to save more tensors during " + "recomputation than during the original forward pass.\n" + f"{_debug_tip_msg}" + ) + + holder = target_frame.weak_holders[recomp_idx]() + + # This holder may have been cleared because someone may have called + # backward within forward. If so, we don't need to save. + if holder is not None: + _internal_assert(holder.handles.get(gid, None) is None) + holder.handles[gid] = _Handle() + target_frame.recomputed[gid][holder.handles[gid]] = x + + if target_frame.early_stop and target_frame.recomp_counter[gid] == len( + target_frame.weak_holders + ): + raise _StopRecomputationError + # See Rule 6: [ retain_graph is True ] above + return x + + def unpack_hook(x): + # See Rule 6: [ retain_graph is True ] above for an example of when + # the graph created during recomputation could be backwarded. + return x + + super().__init__(pack_hook, unpack_hook) + + +# torch._disable_dynamo creates a reference cycle with decorated function +# This function is used to ensure that the decorated function does not have +# a closure, so that other objects aren't also kept alive. +# https://github.com/pytorch/pytorch/issues/154642 +# Note: does not work when fn is compiled +@torch._disable_dynamo +def _run_fn_with_dynamo_disabled(fn, *args, **kwargs): + return fn(*args, **kwargs) + + +class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): + def __init__(self, frame): + def pack_hook(x): + # See Rule 4 above + holder = _Holder() + frame.weak_holders.append(weakref.ref(holder)) + # Save metadata to detect non-determinism + if frame.metadata_fn is not None: + with torch.no_grad(): + frame.x_metadatas.append(frame.metadata_fn(x)) + return holder + + def unpack_hook(holder): + gid = torch._C._current_graph_task_id() + if gid == -1: + # generate a temporary id if we trigger unpack outside of a backward call + gid = int(uuid.uuid4()) + + if not frame.is_recomputed[gid]: + ctx = frame.input_saver.grad_fn + args = ctx.get_args(ctx.saved_tensors) + + try: + with _recomputation_hook( + weakref.ref(frame), gid + ), torch.autograd.enable_grad(): + # See Note: [compiled autograd and checkpoint unpack hook] + _run_fn_with_dynamo_disabled(frame.recompute_fn, *args) + except _StopRecomputationError: + pass + frame.is_recomputed[gid] = True + frame.check_recomputed_tensors_match(gid) + + _internal_assert(gid in holder.handles) + + if holder.handles[gid] is None: + raise CheckpointError( + "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already " + "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do " + "so only once. Otherwise please open an issue with details on your use case." 
+ ) + _internal_assert(holder.handles[gid] in frame.recomputed[gid]) + ret = frame.recomputed[gid][holder.handles[gid]] + holder.handles[gid] = None + return ret + + if frame.unpack_error_cb is not None: + def unpack_hook_with_error_cb(holder): + try: + return unpack_hook(holder) + except CheckpointError as e: + frame.unpack_error_cb(e) + super().__init__(pack_hook, unpack_hook_with_error_cb) + else: + super().__init__(pack_hook, unpack_hook) + + +def _is_compiling(func, args, kwargs): + # Check if we are under AOTAutograd tracing + # Checking that a functional mode is active should always do what we want + return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None + + +class _VersionWrapper: + # Check that cached tensors are not mutated. + def __init__(self, val): + self.val: Union[torch.Tensor, Any] = val + self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None + + def get_val(self, allow_cache_entry_mutation): + if self.version is not None and not allow_cache_entry_mutation: + if self.val._version != self.version: + # Can we give user a stack trace of where the mutation happened? + raise RuntimeError( + "Tensor cached during selective activation checkpoint has been mutated" + ) + return self.val + + +def _maybe_detach(x, any_ret_has_alias_info): + # We detach for two separate reasons: + # - For view ops, we need to ensure that when the tensor is returned from + # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr + # - Avoid reference cycles + # For case 1, it is not enough to check whether x has differentiable dtype + # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. + # when the tensor is a view. + if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): + with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): + # Ensure that view performed beneath autograd properly propagates + # version counter. TODO: Use reentrant_dispatch instead of + # manually manipulating dispatch keys. Using reentrant_dispatch + # would respect inference_mode, though that is not relevant for + # this case. + x = x.detach() + return x + + +class SelectiveCheckpointContext: + """ + Context passed to policy function during selective checkpointing. + + This class is used to pass relevant metadata to the policy function during + selective checkpointing. The metadata includes whether the current invocation + of the policy function is during recomputation or not. + + Example: + >>> # xdoctest: +SKIP(stub) + >>> + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> print(ctx.is_recompute) + >>> + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) + """ + def __init__(self, *, is_recompute): + self.is_recompute = is_recompute + + +class CheckpointPolicy(enum.Enum): + """ + Enum for specifying the policy for checkpointing during backpropagation. 
+ + The following policies are supported: + + - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward + pass and will not be recomputed during the backward pass + - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the + forward pass and will be recomputed during the backward pass + + Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden + by other subsystems like `torch.compile`. + + .. note:: + A policy function that always returns ``PREFER_RECOMPUTE`` is + equivalent to vanilla checkpointing. + + A policy function that returns ``PREFER_SAVE`` every op is + NOT equivalent to not using checkpointing. Using such a policy would + save additional tensors not limited to ones that are actually needed for + gradient computation. + """ + MUST_SAVE = 0 + PREFER_SAVE = 1 + MUST_RECOMPUTE = 2 + PREFER_RECOMPUTE = 3 + + +def _policy_from_bool(b): + # For backward compatibility + return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE + + +SAC_IGNORED_OPS = { + # AC inserts different number of detach during forward and recompute. + torch.ops.aten.detach.default, + # AC's determinism check invokes additional metadata ops during forward. + # With subclasses involved, these metadata ops become dispatchable, this + # can result in incorrectness if these ops are selected cached. + torch.ops.prim.device.default, +} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) # type: ignore[has-type] + + +class _CachingTorchDispatchMode(TorchDispatchMode): + # Used together with _CachedTorchDispatchMode to implement SAC. + def __init__(self, policy_fn, storage): + self.policy_fn = policy_fn + self.storage = storage + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if func in SAC_IGNORED_OPS: + return func(*args, **kwargs) + + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), + func, *args, **kwargs) + if isinstance(policy, bool): + policy = _policy_from_bool(policy) + + is_compiling = _is_compiling(func, args, kwargs) + + if is_compiling: + # Overwrite each node's "recompute" tag to add in the user annotation. + fx_traceback.current_meta["recompute"] = policy + + out = func(*args, **kwargs) + + # HOPs don't support func._schema + # HOPs don't alias -> this is always true today and will be always true for a long time + # TODO HOPs don't mutate -> this is always true today but will not be true forever + if isinstance(func, torch._ops.HigherOrderOperator): + any_ret_has_alias_info = False + else: + any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) + return out + +class _CachedTorchDispatchMode(TorchDispatchMode): + # Used together with _CachedTorchDispatchMode to implement SAC. 
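+    # During recomputation, an op whose policy is {MUST,PREFER}_SAVE (or any op
+    # when compiling) is not re-executed: its output is popped from `storage`,
+    # which _CachingTorchDispatchMode populated during the original forward.
+    # All other ops run normally.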
+ def __init__(self, policy_fn, storage, allow_cache_entry_mutation): + self.policy_fn = policy_fn + self.storage = storage + self.allow_cache_entry_mutation = allow_cache_entry_mutation + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + if func in SAC_IGNORED_OPS: + return func(*args, **kwargs) + + kwargs = {} if kwargs is None else kwargs + policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), + func, *args, **kwargs) + if isinstance(policy, bool): + policy = _policy_from_bool(policy) + + is_compiling = _is_compiling(func, args, kwargs) + + if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + storage = self.storage.get(func) + if storage is None: + raise RuntimeError(f"{func} encountered during backward, but not found in storage") + if len(storage) == 0: + raise RuntimeError( + "Trying to backward an extra time. You are only allowed to backward once " + "on any region computed under selective activation checkpoint." + ) + out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) + else: + out = func(*args, **kwargs) + return out + + +def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): + """ + Helper to avoid recomputing certain ops during activation checkpointing. + + Use this with `torch.utils.checkpoint.checkpoint` to control which + operations are recomputed during the backward pass. + + Args: + policy_fn_or_list (Callable or List): + - If a policy function is provided, it should accept a + :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and + kwargs to the op, and return a :class:`CheckpointPolicy` enum value + indicating whether the execution of the op should be recomputed or not. + - If a list of operations is provided, it is equivalent to a policy + returning `CheckpointPolicy.MUST_SAVE` for the specified + operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other + operations. + allow_cache_entry_mutation (bool, optional): By default, an error is + raised if any tensors cached by selective activation checkpoint are + mutated in order to ensure correctness. If set to `True`, this check + is disabled. + Returns: + A tuple of two context managers. + + Example: + >>> # xdoctest: +REQUIRES(LINUX) + >>> import functools + >>> + >>> x = torch.rand(10, 10, requires_grad=True) + >>> y = torch.rand(10, 10, requires_grad=True) + >>> + >>> ops_to_save = [ + >>> torch.ops.aten.mm.default, + >>> ] + >>> + >>> def policy_fn(ctx, op, *args, **kwargs): + >>> if op in ops_to_save: + >>> return CheckpointPolicy.MUST_SAVE + >>> else: + >>> return CheckpointPolicy.PREFER_RECOMPUTE + >>> + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) + >>> + >>> # or equivalently + >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) + >>> + >>> def fn(x, y): + >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y + >>> + >>> out = torch.utils.checkpoint.checkpoint( + >>> fn, x, y, + >>> use_reentrant=False, + >>> context_fn=context_fn, + >>> ) + """ + # NB: If grad_mode is disabled, checkpoint would not run forward under + # context_fn anyway, so proceed as usual. + if isinstance(policy_fn_or_list, list): + for op in policy_fn_or_list: + if not isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + _extra_msg = ( + "Please update the OpOverloadPacket to a specific OpOverload." 
+ "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." + ) if isinstance(op, torch._ops.OpOverloadPacket) else "" + raise ValueError( + f"Expected op in `op_list` to be an OpOverload but got: {op} " + f"of type {type(op)}. {_extra_msg}" + ) + + def policy_fn(ctx, op, *args, **kwargs): + if op in policy_fn_or_list: + return CheckpointPolicy.MUST_SAVE + else: + return CheckpointPolicy.PREFER_RECOMPUTE + elif callable(policy_fn_or_list): + policy_fn = policy_fn_or_list + else: + raise TypeError("policy_fn_or_list must be either a function or a list of ops.") + + storage: Dict[Any, List[Any]] = defaultdict(list) + return ( + _CachingTorchDispatchMode(policy_fn, storage), + _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), + ) + +# NB: this helper wraps fn before calling checkpoint_impl. kwargs and +# saving/restoring of global state is handled here. + +def _checkpoint_without_reentrant_generator( + fn, + preserve_rng_state=True, + context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, + determinism_check: str = _DEFAULT_DETERMINISM_MODE, + debug: bool = False, + early_stop: bool = True, + *args, + **kwargs +): + """Checkpointing without reentrant autograd. + + Args: + fn: describes what to run in the forward pass of the model or + part of the model. It should also know how to handle the inputs + passed as the tuple. For example, in LSTM, if user passes + ``(activation, hidden)``, :attr:`function` should correctly use the + first input as ``activation`` and the second input as ``hidden`` + preserve_rng_state(bool, optional): Omit stashing and restoring + the RNG state during each checkpoint. + Default: ``True`` + context_fn(Callable, optional): A callable returning a tuple of two + context managers. The function and its recomputation will be run + under the first and second context managers respectively. + determinism_check(str, optional): A string specifying the determinism + check to perform. By default it is set to ``"default"`` which + compares the shapes, dtypes, and devices of the recomputed tensors + against those the saved tensors. To turn off this check, specify + ``"none"``. Currently these are the only two supported values. + Please open an issue if you would like to see more determinism + checks. + debug(bool, optional): If ``True``, error messages will also include + a trace of the operators ran during the original forward computation + as well as the recomputation. + early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops + recomputation as soon as it has computed all needed Tensors. Can be + overridden globally using :func:`set_checkpoint_early_stop` context + manager. Default: ``True``. + *args: Arguments to pass in to the given ``function``. + **kwargs: Keyword arguments to pass into the given ``function``. 
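+
+    The generator is driven by :func:`checkpoint` roughly as follows (a sketch
+    of the call pattern used in this file)::
+
+        gen = _checkpoint_without_reentrant_generator(fn, preserve, ...)
+        next(gen)           # pre-forward: set up the frame and saved-tensor hooks
+        ret = fn(*args, **kwargs)
+        next(gen)           # post-forward: mark completion; raises StopIteration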
+ """ + unpack_error_cb = None + + if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug: + if context_fn != noop_context_fn: + raise ValueError( + "debug=True is incompatible with non-default context_fn" + ) + context_fn, unpack_error_cb = _get_debug_context_and_cb() + + if determinism_check in _allowed_determinism_checks_to_fns: + metadata_fn = _allowed_determinism_checks_to_fns[determinism_check] + else: + raise ValueError( + f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, " + f"but got {determinism_check}" + ) + + device_type = _infer_device_type(*args) + device_module = _get_device_module(device_type) + forward_context, recompute_context = context_fn() + if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: + if ( + not isinstance(forward_context, TorchDispatchMode) + or not isinstance(recompute_context, TorchDispatchMode) + ): + raise AssertionError( + "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + "must generate a tuple of two `TorchDispatchMode`s." + ) + # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. + device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type) + + if preserve_rng_state: + fwd_cpu_state = torch.get_rng_state() + # Don't eagerly initialize the cuda context by accident. + # (If the user intends that the context is initialized later, within their + # run_function, we SHOULD actually stash the cuda state here. Unfortunately, + # we have no way to anticipate this will happen before we run the function. + # If they do so, we raise an error.) + had_device_in_fwd = False + if getattr(device_module, "_initialized", False): + had_device_in_fwd = True + fwd_devices, fwd_device_states = get_device_states(*args) + + def recompute_fn(*inputs): + kwargs, *args = inputs + # This will be called later during recomputation. This wrapping enables + # the necessary global state to be captured. + rng_devices = [] + if preserve_rng_state and had_device_in_fwd: + rng_devices = fwd_devices + with torch.random.fork_rng( + devices=rng_devices, enabled=preserve_rng_state, device_type=device_type + ): + if preserve_rng_state: + torch.set_rng_state(fwd_cpu_state) + if had_device_in_fwd: + set_device_states(fwd_devices, fwd_device_states, device_type=device_type) + + device_autocast_ctx = torch.amp.autocast( + device_type=device_type, **device_autocast_kwargs + ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext() + with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined] + fn(*args, **kwargs) + + new_frame = _CheckpointFrame( + recompute_fn, + _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop, + unpack_error_cb, + metadata_fn + ) + dummy = torch.empty((0,), requires_grad=True) + new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) + + # When ambient grad_mode is False + if new_frame.input_saver.grad_fn is None: + yield + return + + with _checkpoint_hook(new_frame), forward_context: + yield + new_frame.forward_completed = True + + if getattr(device_module, "_initialized", False) and \ + preserve_rng_state and not had_device_in_fwd: # type: ignore[possibly-undefined] + # Device was not initialized before running the forward, so we didn't + # stash the device state. 
+ raise RuntimeError( + "PyTorch's device state was initialized in the forward pass " + "of a Checkpoint, which is not allowed. Please open an issue " + "if you need this feature." + ) + + return + +# Note: [compiled autograd and checkpoint unpack hook] +# When tracing via compiled autograd, this hook will be visible to the +# compiler if the forward of this checkpointed region ran in eager. +# If the forward had ran under compile, it would have been wrapped in a +# higher order op. See Note: [torch.compile and checkpoint]. +# +# Since we run the recomputation hook under a enable_grad context, +# AOTDispatch will trace a joint graph for this hook, and may +# save different activations than in eager. This conflicts with the +# strict activation count checks in `frame.check_recomputed_tensors_match`. +# So, we disable this hook to force it to recompute eager checkpointed regions +# in eager. This could be removed if we can disable the partitioner for this +# graph segment. diff --git a/test/xpu/test_schema_check.py b/test/xpu/test_schema_check.py new file mode 100644 index 0000000000..6d6410f073 --- /dev/null +++ b/test/xpu/test_schema_check.py @@ -0,0 +1,514 @@ +# Owner(s): ["oncall: jit"] +# ruff: noqa: F841 + +import os +import sys +import torch +from torch.utils._pytree import tree_map +import unittest + +from torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO +from torch.fx.operator_schemas import normalize_function +from torch._subclasses.schema_check_mode import SchemaCheckMode +from torch.utils._python_dispatch import TorchDispatchMode +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.jit_utils import JitTestCase +from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests +from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf +pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +sys.path.append(pytorch_test_dir) + + + +def secretly_aliasing(x): + return x.view(-1) + +def secretly_mutating(x): + x.mul_(2) + return x * 3 + +def output_is_input(x): + return x + +custom_lib = torch.library.Library("bad_schemas", "DEF") # noqa: TOR901 +custom_lib.define("secretly_aliasing(Tensor x) -> Tensor") +custom_lib.define("secretly_mutating(Tensor x) -> Tensor") +custom_lib.define("output_is_input(Tensor(a) x) -> Tensor(a)") + +custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU") # noqa: TOR901 +custom_lib_cpu.impl("secretly_aliasing", secretly_aliasing) +custom_lib_cpu.impl("secretly_mutating", secretly_mutating) +custom_lib_cpu.impl("output_is_input", output_is_input) + +custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta") # noqa: TOR901 +custom_lib_meta.impl("secretly_aliasing", secretly_aliasing) +custom_lib_meta.impl("secretly_mutating", secretly_mutating) +custom_lib_meta.impl("output_is_input", output_is_input) + +# This TorchDispatchTensor Subclass is used to simulate an incorrect schema +# which is then used to test that SchemaCheckMode behaves as expected + +class IncorrectAliasTensor(torch.Tensor): + ALIAS_ARG_OUT = {"aten::add"} + ALIAS_OUT_OUT = {"aten::aminmax"} + MUTATE_ARGS_OUT = {"aten::sub"} + + elem: torch.Tensor + + __slots__ = ['elem'] + + @staticmethod + def __new__(cls, elem, *args, **kwargs): + # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any + # memory for the class in question, but it should still + # advertise the same device as before + r = 
torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] + cls, elem.size(), + strides=elem.stride(), storage_offset=elem.storage_offset(), + # TODO: clone storage aliasing + dtype=elem.dtype, layout=elem.layout, + device=elem.device, requires_grad=kwargs.get("requires_grad", False) + ) + # ...the real tensor is held as an element on the tensor. + r.elem = elem.detach() if r.requires_grad else elem + return r + + def __repr__(self): + return super().__repr__(tensor_contents=f"{self.elem}") + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + def unwrap(e): + return e.elem if isinstance(e, cls) else e + + def wrap(e): + return cls(e) if isinstance(e, torch.Tensor) else e + unwrapped_args = tree_map(unwrap, args) + out = func(*unwrapped_args, **tree_map(unwrap, kwargs)) + if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT: + args[0].elem = out + if func._schema.name in IncorrectAliasTensor.MUTATE_ARGS_OUT: + args[0].elem = torch.rand(args[0].elem.shape) + if func._schema.name in IncorrectAliasTensor.ALIAS_OUT_OUT: + incorrect_out = list(out) + incorrect_out[0] = incorrect_out[1] + return tree_map(wrap, tuple(incorrect_out)) + + return tree_map(wrap, out) + +# Tests various schema checking functionalities. +class TestSchemaCheck(JitTestCase): + def setUp(self): + if TEST_WITH_TORCHDYNAMO: + self.skipTest("SchemaCheckMode is ignored by dynamo") + super().setUp() + + # Tests that SchemaCheckMode records operator order with grad + def test_schema_check_mode_operator_order(self): + with SchemaCheckMode() as schema_check: + x = torch.rand((3, 3), requires_grad=True) + x.relu().sin() + self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops) + + # Tests that SchemaCheckMode records operator order without grad + def test_schema_check_mode_operator_order_without_grad(self): + with SchemaCheckMode() as schema_check: + x = torch.rand((3, 3), requires_grad=False) + x.relu().sin() + self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops) + + # Tests that SchemaCheckMode records mutations and aliases with none expected + def test_schema_check_mode_mutated_aliasing_none(self): + # NB: previously requires_grad=True, but this induces a detach for + # saved variable + x = torch.rand((3, 3)) + with SchemaCheckMode() as schema_check: + actual = x.relu().sin() + self.assertEqual([], schema_check.mutated) + self.assertEqual([], schema_check.aliasing) + + # Tests that SchemaCheckMode records mutations and aliases with mutation expected + def test_schema_check_mode_mutated_aliasing_mutation(self): + actual = torch.rand((3, 3), requires_grad=False) + with SchemaCheckMode() as schema_check: + actual.sinh_() + self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated) + self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing) + + # Tests that SchemaCheckMode records mutations and aliases with resize_ + def test_schema_check_mode_mutated_aliasing_resize_(self): + actual = torch.rand((3, 3), requires_grad=False) + with SchemaCheckMode() as schema_check: + actual.resize_(9) + self.assertEqual([('aten::resize_', 'input')], schema_check.mutated) + self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing) + + # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs + def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self): + actual = torch.rand((3, 3)) + y = actual + with SchemaCheckMode() as schema_check: + actual.add_(y) + 
self.assertEqual( + [ + ('aten::add_', 'input'), + ('aten::add_', 'other') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::add_', 'input', 'output_0'), + ('aten::add_', 'other', 'output_0') + ], + schema_check.aliasing + ) + + # Tests that SchemaCheckMode records mutations and alias with as_strided + def test_schema_check_mode_mutated_aliasing_as_strided(self): + x = torch.rand((3, 6, 4)) + with SchemaCheckMode() as schema_check: + x.as_strided_([3, 6, 4], [9, 1, 1]) + self.assertEqual( + [ + ('aten::as_strided_', 'input') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::as_strided_', 'input', 'output_0') + ], + schema_check.aliasing + ) + + # Tests that SchemaCheckMode records mutations and aliases with multiple outputs + def test_schema_check_mode_mutated_aliasing_multiple_outputs(self): + x = torch.arange(9.) + m_actual = torch.arange(9.) + e_actual = torch.zeros([9], dtype=torch.int32) + with SchemaCheckMode() as schema_check: + torch.frexp(x, out=(m_actual, e_actual)) + self.assertEqual( + [ + ('aten::frexp', 'mantissa'), + ('aten::frexp', 'exponent') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::frexp', 'mantissa', 'output_0'), + ('aten::frexp', 'exponent', 'output_1') + ], + schema_check.aliasing + ) + + # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs + def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self): + x = torch.rand((3, 3)) + actual = torch.zeros(3) + with SchemaCheckMode() as schema_check: + torch.aminmax(x, dim=0, out=[actual, actual]) + self.assertEqual( + [ + ('aten::aminmax', 'min'), + ('aten::aminmax', 'max') + ], + schema_check.mutated + ) + self.assertEqual( + [ + ('aten::aminmax', 'min', 'output_0'), + ('aten::aminmax', 'min', 'output_1'), + ('aten::aminmax', 'max', 'output_0'), + ('aten::aminmax', 'max', 'output_1') + ], + schema_check.aliasing + ) + + # Tests that SchemaCheckMode wraps torch.Tensor + def test_schema_check_mode_functionality(self): + x = torch.rand((3, 3), requires_grad=True) + expected = x.relu().sin() + with SchemaCheckMode(): + actual = x.relu().sin() + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overridden + def test_schema_check_mode_functionality_default_replaced(self): + x = torch.rand((3, 3), requires_grad=True) + expected = x.add(x, alpha=2) + with SchemaCheckMode(): + actual = x.add(x, alpha=2) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument + def test_schema_check_mode_functionality_list_input(self): + a = torch.rand((3, 3)) + b = torch.rand((3, 3)) + c = torch.rand((3, 3)) + expected = torch.linalg.multi_dot([a, b, c]) + with SchemaCheckMode(): + actual = torch.linalg.multi_dot([a, b, c]) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation + def test_schema_check_mode_functionality_wildcard_after(self): + x = torch.rand((3, 3)) + expected = x.chunk(6) + with SchemaCheckMode(): + actual = x.chunk(6) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input + @unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.") + def test_schema_check_mode_functionality_kwarg_tensor(self): + x = torch.rand((3, 5)) + w = torch.rand(4) + expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True) + with SchemaCheckMode(): + actual = 
torch.stft(x, 4, win_length=4, window=w, return_complex=True) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op + def test_schema_check_mode_functionality_mutable_inputs(self): + expected = torch.rand((3, 3), requires_grad=False) + actual = torch.clone(expected) + expected.sinh_() + with SchemaCheckMode(): + actual.sinh_() + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps Torch.tensor when inputs alias + def test_schema_check_mode_functionality_aliasing_inputs(self): + expected = torch.rand((3, 3)) + x = expected + actual = torch.clone(expected) + y = actual + expected.add_(x) + with SchemaCheckMode(): + actual.add_(y) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with multiple tensor outputs + def test_schema_check_mode_functionality_with_multiple_outputs(self): + x = torch.arange(9.) + m_expected, e_expected = torch.frexp(x) + m_actual = torch.arange(9.) + e_actual = torch.zeros([9], dtype=torch.int32) + with SchemaCheckMode(): + torch.frexp(x, out=(m_actual, e_actual)) + self.assertEqual(m_expected, m_actual) + self.assertEqual(e_expected, e_actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with aliasing outputs due to aliasing inputs + def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self): + x = torch.rand((3, 3)) + actual = torch.zeros(3) + with SchemaCheckMode(): + torch.aminmax(x, dim=0, out=[actual, actual]) + self.assertEqual(torch.amax(x, dim=0), actual) + + # Tests that SchemaCheckMode wraps Torch.tensor in ops with real Device input + def test_schema_check_mode_functionality_device_input(self): + with SchemaCheckMode(): + x = torch.rand((3, 3), device="cpu", dtype=torch.double) + y = x + x + self.assertEqual(x + x, y) + + # Tests that SchemaCheckMode wraps Torch.tensor in special training op edge case + def test_schema_check_mode_functionality_training_op(self): + x = torch.rand((3, 3), requires_grad=True) + batch = torch.nn.BatchNorm1d(3, track_running_stats=True) + expected = batch(x) + with SchemaCheckMode(): + actual = batch(x) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with nested training op edge case + def test_schema_check_mode_functionality_nested_training_op(self): + actual = torch.rand((3, 3)) + batch = torch.nn.BatchNorm1d(3, track_running_stats=True) + expected = torch.clone(actual) + expected.sinh_() + expected.tanh_() + expected.relu_() + expected = batch(expected) + + with SchemaCheckMode(): + actual.sinh_() + actual.tanh_() + actual.relu_() + actual = batch(actual) + self.assertEqual(expected, actual) + + # Tests that SchemaCheckMode wraps Torch.tensor with empty list input + def test_schema_check_mode_empty_list_input(self): + expected = torch.atleast_1d([]) + with SchemaCheckMode(): + actual = torch.atleast_1d([]) + self.assertEqual(expected, actual) + + # Tests that an exception is raised for a mismatching mutation + def test_mutation_check_fail(self): + with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): + x = torch.rand((3, 3)) + y = torch.rand((3, 3)) + with SchemaCheckMode(): + IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y)) + + # # Tests that an exception is raised for a mismatching mutation over multiple ops + def test_mutation_check_fail_multiple_operators(self): + with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): + x = torch.rand((3, 3)) + y 
= torch.rand((3, 3)) + with SchemaCheckMode(): + IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y)) + + # Tests that an exception is raised for a mismatching alias + def test_alias_check_fail_simple(self): + with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): + x = torch.rand((3, 3), requires_grad=True) + y = torch.rand((3, 3)) + with SchemaCheckMode(): + IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2) + + # Tests that an exception is raised for a mismatching alias over multiple ops + def test_alias_check_fail_multiple_operators(self): + with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): + x = torch.rand((3, 3), requires_grad=True) + y = torch.zeros((3, 3), requires_grad=True) + with SchemaCheckMode(): + IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2) + + # Tests that an exception is raised for a centered mismatching alias over multiple ops + def test_alias_check_fail_multiple_operators_centered(self): + with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): + x = torch.rand((3, 3), requires_grad=True) + y = torch.zeros((3, 3), requires_grad=True) + with SchemaCheckMode(): + IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu() + + # Tests that an exception is raised for a centered mismatching alias over multiple ops + def test_alias_check_fail_outputs_unexpectedly_aliasing(self): + with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"): + x = torch.rand((3, 3)) + with SchemaCheckMode() as s: + IncorrectAliasTensor(x).aminmax(dim=0) + + # When this file was written, python op registration didn't exist. + # It's probably worth re-writing the entire file to use it, + # but instead I just added extra tests. 
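+    # A minimal sketch of what that rewrite could look like with the newer
+    # Python registration API (torch.library.custom_op, available in recent
+    # PyTorch releases). The "bad_schemas_v2" namespace and the op below are
+    # hypothetical and are not registered anywhere in this file:
+    #
+    #     @torch.library.custom_op("bad_schemas_v2::secretly_mutating", mutates_args=())
+    #     def _secretly_mutating_v2(x: torch.Tensor) -> torch.Tensor:
+    #         x.mul_(2)   # mutates its input despite advertising mutates_args=()
+    #         return x * 3
+    #
+    # SchemaCheckMode should flag such an op the same way it flags the
+    # Library-registered versions defined at the top of this file.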
+ def test_alias_check_fail_custom_ops_secretly_aliasing(self): + def f(x): + return torch.ops.bad_schemas.secretly_aliasing(x) + + x = torch.rand((3, 3)) + with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"): + with SchemaCheckMode() as s: + out = f(x) + + def test_alias_check_fail_custom_ops_secretly_mutating(self): + def f(x): + return torch.ops.bad_schemas.secretly_mutating(x) + + x = torch.rand((3, 3)) + with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"): + with SchemaCheckMode() as s: + out = f(x) + + def test_alias_check_fail_custom_ops_output_is_input(self): + def f(x): + return torch.ops.bad_schemas.output_is_input(x) + + x = torch.rand((3, 3)) + with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"): + with SchemaCheckMode() as s: + out = f(x) + + # Tests that is_alias_of returns as expected + def test_is_alias_of_basic(self): + x = torch.rand((3, 3), requires_grad=True) + y = torch.rand((3, 3), requires_grad=True) + y = x.add(x, alpha=2) + self.assertTrue(torch._C._is_alias_of(x, x)) + self.assertFalse(torch._C._is_alias_of(x, y)) + + # Tests that is_alias_of returns as expected with empty containers + def test_is_alias_of_empty_container(self): + x = [] + y = torch.rand((3, 3), requires_grad=True) + self.assertFalse(torch._C._is_alias_of(x, x)) + self.assertFalse(torch._C._is_alias_of(x, y)) + + # Tests that overlaps returns as expected + def test_overlaps_basic(self): + x = torch.rand((3, 3), requires_grad=True) + y = torch.rand((3, 3), requires_grad=True) + z = [x, y] + self.assertTrue(torch._C._overlaps(x, x)) + self.assertFalse(torch._C._overlaps(x, y)) + self.assertTrue(torch._C._overlaps(z, x)) + self.assertTrue(torch._C._overlaps(z, y)) + + # Tests that overlaps returns correctly with empty containers + def test_overlaps_empty_container(self): + x = [] + y = [torch.rand((3, 3), requires_grad=True)] + # Empty containers return false + self.assertFalse(torch._C._overlaps(y, x)) + self.assertTrue(torch._C._overlaps(y, y)) + + # Tests that SchemaInfo Bindings work as expected + def test_schema_info_bind_basic(self): + class SchemaInfoBindTestMode(TorchDispatchMode): + def __init__(self, test_self): + self.test_self = test_self + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + named_arg_list = normalize_function( + func, + args, + kwargs, + normalize_to_only_use_kwargs=True + ).kwargs + schema_info_value_test = torch._C._SchemaInfo(func._schema) + schema_info_values_test = torch._C._SchemaInfo(func._schema) + self.test_self.assertFalse(schema_info_value_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + self.test_self.assertFalse(schema_info_values_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + for i in named_arg_list: + schema_info_value_test.add_argument_value(i, named_arg_list[i]) + schema_info_values_test.add_argument_values(named_arg_list) + self.test_self.assertTrue(schema_info_value_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + self.test_self.assertTrue(schema_info_values_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + + return func(*args, **kwargs) + x = torch.rand((3, 3)) 
+ with SchemaInfoBindTestMode(self) as schemaInfoCheck: + x.add(x) + +class TestSchemaCheckModeOpInfo(JitTestCase): + @ops(op_db, dtypes=OpDTypes.supported) + @slowTestIf(IS_WINDOWS) + def test_schema_correctness(self, device, dtype, op): + # Currently torch.equal isn't supported with torch.complex32 + # There's also errors with complex64 and complex128 + if (dtype == torch.complex32): + return + for sample in op.sample_inputs(device, dtype, requires_grad=False): + with SchemaCheckMode(): + op(sample.input, *sample.args, **sample.kwargs) + +instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda", "xpu"), allow_xpu=True) + +if __name__ == '__main__': + run_tests() diff --git a/test/xpu/test_utils.py b/test/xpu/test_utils.py new file mode 100644 index 0000000000..4c4876b283 --- /dev/null +++ b/test/xpu/test_utils.py @@ -0,0 +1,1022 @@ +# mypy: allow-untyped-defs +# Owner(s): ["module: unknown"] + +import os +import random +import shutil +import subprocess +import sys +import tempfile +import textwrap +import traceback +import unittest +import warnings +from typing import Any + +import torch +import torch.cuda +import torch.nn as nn +import torch.utils.cpp_extension +import torch.utils.data +from torch._utils import try_import +from torch._utils_internal import deprecated +from torch.testing._internal.common_cuda import TEST_MULTIGPU +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + onlyCPU, + ops, +) +from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] + IS_FBCODE, + IS_SANDCASTLE, + IS_WINDOWS, + load_tests, +) +from torch.testing._internal.inductor_utils import HAS_GPU, get_gpu_type +from torch.utils._device import set_device +from torch.utils._pytree import tree_all_only, tree_any +from torch.utils._traceback import ( + CapturedTraceback, + format_traceback_short, + report_compile_source_on_error, +) +from checkpoint import ( + _infer_device_type, + checkpoint, + checkpoint_sequential, + get_device_states, +) +from torch.utils.data import DataLoader + + +# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for +# sharding on sandcastle. This line silences flake warnings +load_tests = load_tests # noqa: PLW0127 + +HAS_CUDA = torch.cuda.is_available() +HAS_XPU = torch.xpu.is_available() +HAS_GPU = HAS_CUDA or HAS_XPU +device_type = get_gpu_type() + + +from torch.testing._internal.common_utils import run_tests, TestCase + + +# mypy: disable-error-code="name-defined" + + +class RandomDatasetMock(torch.utils.data.Dataset): + def __getitem__(self, index): + return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) + + def __len__(self): + return 1000 + + +class TestCheckpoint(TestCase): + # This runs checkpoint_sequential on each of the nets in + # module_lists_to_compare, and compares them against the uncheckpointed model. 
+ # To compare, it checks outputs as well as input gradients and parameter gradients + def _check_checkpoint_sequential( + self, + model, + module_lists_to_compare, + num_chunks, + input, + use_reentrant, + ): + # not checkpointed + out = model(input) + out_not_checkpointed = out.detach().clone() + model.zero_grad() + out.sum().backward() + grad_not_checkpointed = { + name: param.grad.detach().clone() + for name, param in model.named_parameters() + } + input_grad_not_checkpointed = input.grad.detach().clone() + for model_to_compare in module_lists_to_compare: + # checkpointed model by passing list of modules + detached = input.detach() + detached.requires_grad = True + + # pass list of modules to checkpoint + out = checkpoint_sequential( + model_to_compare, num_chunks, detached, use_reentrant=use_reentrant + ) + out_checkpointed = out.detach().clone() + model.zero_grad() + out.sum().backward() + grad_checkpointed = { + name: param.grad.detach().clone() + for name, param in model.named_parameters() + } + input_grad_checkpointed = detached.grad.detach().clone() + # compare outputs as well as the gradients of input and parameters + self.assertEqual(out_checkpointed, out_not_checkpointed) + self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed) + for name in grad_checkpointed: + self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name]) + + # Test whether checkpoint is being triggered or not. For this, we check + # the number of times forward pass happens + def test_checkpoint_trigger(self): + class Net(nn.Module): + def __init__(self) -> None: + super().__init__() + self.counter = 0 + + def forward(self, input_var): + self.counter += 1 + # For reentrant, need to have autograd actually + # pack a tensor to trigger recomp + ret = input_var * torch.tensor(2.0) + return ret + + # checkpointed + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + modules = [Net() for _ in range(10)] + for m in modules: + self.assertEqual(m.counter, 0) + input_var = torch.randn(3, 4, requires_grad=True) + out = checkpoint_sequential( + modules, 2, input_var, use_reentrant=use_reentrant + ) + for m in modules: + self.assertEqual(m.counter, 1) + out.sum().backward() + for m in modules[: (len(modules) // 2)]: + self.assertEqual(m.counter, 2) + for m in modules[(len(modules) // 2) :]: + self.assertEqual(m.counter, 1) + + def test_checkpoint_valid(self): + model = nn.Sequential( + nn.Linear(100, 50), + nn.ReLU(), + nn.Linear(50, 20), + nn.ReLU(), + nn.Linear(20, 5), + nn.ReLU(), + ) + + input_var = torch.randn(1, 100, requires_grad=True) + + # checkpointed + chunks = 2 + modules = list(model.children()) + out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True) + with self.assertRaisesRegex( + RuntimeError, "torch.utils.checkpoint is incompatible" + ): + torch.autograd.grad( + outputs=[out], + grad_outputs=[torch.ones(1, 5)], + inputs=[input_var], + create_graph=True, + ) + # works with use_reentrant=False, and grads are the same + out = model(input_var) + grads_no_checkpoint = torch.autograd.grad( + outputs=[out], + grad_outputs=[torch.ones(1, 5)], + inputs=[input_var], + create_graph=True, + ) + out_checkpoint = checkpoint_sequential( + modules, chunks, input_var, use_reentrant=False + ) + # check outputs are the same + self.assertEqual(out_checkpoint, out) + grads_checkpoint = torch.autograd.grad( + outputs=[out_checkpoint], + grad_outputs=[torch.ones(1, 5)], + inputs=[input_var], + create_graph=True, + ) + 
self.assertEqual(grads_no_checkpoint, grads_checkpoint) + + def test_checkpoint(self): + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + model = nn.Sequential( + nn.Linear(100, 50), + nn.ReLU(), + nn.Linear(50, 20), + nn.ReLU(), + nn.Linear(20, 5), + nn.ReLU(), + ) + + # Compare uncheckpointed model with its checkpointed counterparts + # In addition to running checkpoint_sequential on the nn.Sequential + # instance, we also run the function on the list of functions within + # the module. + self._check_checkpoint_sequential( + model, + [list(model.children()), model], + 2, + torch.randn(1, 100, requires_grad=True), + use_reentrant=use_reentrant, + ) + + def test_checkpoint_module_list(self): + class ModuleListNet(nn.Module): + def __init__(self) -> None: + super().__init__() + module_list = [ + nn.Linear(100, 50), + nn.ReLU(), + nn.Linear(50, 20), + nn.ReLU(), + nn.Linear(20, 5), + nn.ReLU(), + ] + self.module_list = nn.ModuleList(module_list) + + def forward(self, input): + for layer in self.module_list: + input = layer(input) + return input + + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + model = ModuleListNet() + + # Compare uncheckpointed model with its checkpointed counterparts. + self._check_checkpoint_sequential( + model, + [list(model.module_list.children()), model.module_list], + 2, + torch.randn(1, 100, requires_grad=True), + use_reentrant=use_reentrant, + ) + + def test_checkpoint_sequential_deprecated_multiple_args(self): + class Two(nn.Module): + def forward(self, a, b): + return a, b + + model = nn.Sequential(Two()) + a = torch.randn(1, 100, requires_grad=True) + b = torch.randn(1, 100, requires_grad=True) + + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + with self.assertRaises(TypeError): + checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] + + def test_checkpoint_sequential_deprecated_no_args(self): + class Noop(nn.Module): + def forward(self): + pass + + model = nn.Sequential(Noop()) + for use_reentrant in [True, False]: + with self.subTest(use_reentrant=use_reentrant): + with self.assertRaises(TypeError): + checkpoint_sequential(model, 1) # type: ignore[call-arg] + + def test_checkpoint_rng_cpu(self): + for _ in range(5): + inp = torch.randn(20000, device="cpu").requires_grad_() + phase1 = torch.nn.Dropout() + phase2 = torch.nn.Dropout() + + def run_fn(input): + return phase2(input) + + state = torch.get_rng_state() + + out = phase1(inp) + out = checkpoint(run_fn, out, use_reentrant=True) + out.sum().backward() + grad_with_checkpointing = inp.grad + + torch.set_rng_state(state) + + inp.grad = None + + out = phase1(inp) + out = run_fn(out) + out.sum().backward() + grad_no_checkpointing = inp.grad + + self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) + + @unittest.skipIf(not HAS_GPU, "No GPU") + def test_checkpoint_rng_gpu(self): + for _ in range(5): + inp = torch.randn(20000, device=device_type).requires_grad_() + phase1 = torch.nn.Dropout() + phase2 = torch.nn.Dropout() + + def run_fn(input): + return phase2(input) + + state = torch.get_device_module(device_type).get_rng_state() + + out = phase1(inp) + out = checkpoint(run_fn, out, use_reentrant=True) + out.sum().backward() + grad_with_checkpointing = inp.grad + + torch.get_device_module(device_type).set_rng_state(state) + + inp.grad = None + + out = phase1(inp) + out = run_fn(out) + out.sum().backward() + grad_no_checkpointing = inp.grad + + 
self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) + + @unittest.skipIf(not HAS_GPU, "No GPU") + def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self): + inp = torch.randn(2, device=device_type).requires_grad_() + layer = torch.nn.Dropout() + + def run_fn(input): + return layer(input) + + out = checkpoint(run_fn, inp, use_reentrant=False, preserve_rng_state=False) + out.sum().backward() + # This should run without error + + def test_checkpoint_non_tensor(self): + def run_fn(tensor1, tensor2): + if tensor2 is None: + return tensor1 + return tensor1 + tensor2 + + input_var = torch.randn(1, 100, requires_grad=True) + out = checkpoint(run_fn, input_var, None, use_reentrant=True) + out.sum().backward() + + def test_checkpoint_non_tensor_inputs_outputs(self): + def foo(t1, t2, scale, t3): + t4 = t1 + t2 * t3 + t5 = t1 * t2 + t3 + t4 *= scale + t5 *= scale + return scale, t4, None, True, t5, "bar", t1 + + t1 = torch.rand(10, requires_grad=True) + t2 = torch.rand(10, requires_grad=True) + t3 = torch.rand(10) + scale = random.randint(0, 10) + res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) + self.assertEqual(scale, res[0]) + self.assertEqual((t1 + t2 * t3) * scale, res[1]) + self.assertEqual(None, res[2]) + self.assertEqual(True, res[3]) + self.assertEqual((t1 * t2 + t3) * scale, res[4]) + self.assertEqual("bar", res[5]) + self.assertEqual(t1, res[6]) + + # Validate running backward. + res[1].sum().backward(retain_graph=True) + res[4].sum().backward(retain_graph=True) + res[6].sum().backward() + with self.assertRaisesRegex( + RuntimeError, "Trying to backward through the graph a second time" + ): + res[6].sum().backward() + t1_grad = t1.grad + t2_grad = t2.grad + + # Reset grads, run without checkpoint and validate we receive same grads. 
+ t1.grad = None + t2.grad = None + res = foo(t1, t2, scale, t3) + torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()]) + self.assertEqual(t1.grad, t1_grad) + self.assertEqual(t2.grad, t2_grad) + + def test_checkpoint_no_tensors(self): + def foo(t1, t2, scale, t3): + t4 = t1 + t2 * t3 + t5 = t1 * t2 + t3 + t4 *= scale + t5 *= scale + return scale, t4, None, True, t5, "bar", t1 + + t1 = random.random() + t2 = random.random() + t3 = random.random() + scale = random.randint(0, 10) + res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) + self.assertEqual(scale, res[0]) + self.assertEqual((t1 + t2 * t3) * scale, res[1]) + self.assertEqual(None, res[2]) + self.assertEqual(True, res[3]) + self.assertEqual((t1 * t2 + t3) * scale, res[4]) + self.assertEqual("bar", res[5]) + self.assertEqual(t1, res[6]) + + def test_checkpoint_partial_grad(self): + def run_fn(tensor1, tensor2): + # tensor 2 is used for other application logic + return tensor1, tensor2 + + input_var = torch.randn(1, 4, requires_grad=True) + input_var2 = torch.randn(1, 4, requires_grad=False) + out = checkpoint(run_fn, input_var, input_var2, use_reentrant=True) + out[0].sum().backward() + + def run_fn2(tensor1, tensor2): + return tensor1 + + input_var = torch.randn(1, 4, requires_grad=False) + input_var2 = torch.randn(1, 4, requires_grad=True) + with self.assertRaisesRegex( + RuntimeError, + r"none of output has requires_grad=True, this checkpoint\(\) is not necessary", + ): + out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True) + out.sum().backward() + + @unittest.skipIf(not HAS_GPU, "Test requires GPU") + def test_checkpointing_without_reentrant_early_free(self): + # I don't know how to check if the temporary saved variable buffer + # get de-allocated directly. So using GPU memory usage as a proxy + + def _do_test(fn, should_free): + stats: list[int] = [] + + def track(x, idx): + # Track that at each step of the backward, some Tensor were + # de-allocated (which correspond to the checkpoint storage being + # emptied at each step) + def hook(_unused): + self.assertEqual(len(stats), idx) + torch.get_device_module(device_type).synchronize() + stats.append(torch.get_device_module(device_type).memory_allocated()) + if idx > 0: + if should_free: + self.assertLess(stats[idx], stats[idx - 1]) + else: + self.assertEqual(stats[idx], stats[idx - 1]) + + x.register_hook(hook) + + def test_fn(x): + # The main property of this function is that it contains multiple + # operations that save gradients in a chain. 
+ x = x**2 + track(x, 2) + x = x**2 + track(x, 1) + x = x**2 + track(x, 0) + x = x**2 + return x.sum() + + fn(test_fn) + + return stats + + x = torch.zeros(10, device=device_type, requires_grad=True) + x.grad = torch.zeros_like(x) + + # In a regular backward, buffers get eagerly freed + non_retain_stats = _do_test(lambda fn: fn(x).backward(), True) + + # In a retain_grad backward, buffers get preserved + _unused_retain_stats = _do_test( + lambda fn: fn(x).backward(retain_graph=True), False + ) + + # In a regular backward with checkpoint, buffers get eagerly freed + checkpoint_non_retain_stats = _do_test( + lambda fn: checkpoint(fn, x, use_reentrant=False).backward(), True + ) + + # In a retain_grad backward with checkpoint, buffers get eagerly freed + checkpoint_retain_stats = _do_test( + lambda fn: checkpoint(fn, x, use_reentrant=False).backward( + retain_graph=True + ), + True, + ) + + self.assertEqual(non_retain_stats, checkpoint_non_retain_stats) + self.assertEqual(non_retain_stats, checkpoint_retain_stats) + + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") + def test_get_device_states_recursive(self): + inp = { + "foo": torch.rand(10, device=f"{device_type}:0"), + "bar": [torch.rand(10, device=f"{device_type}:1")], + } + device_ids, device_states = get_device_states(inp) + self.assertEqual(2, len(device_ids)) + self.assertEqual(2, len(device_states)) + self.assertEqual(0, device_ids[0]) + self.assertEqual(1, device_ids[1]) + self.assertTrue(isinstance(device_states[0], torch.Tensor)) + self.assertTrue(isinstance(device_states[1], torch.Tensor)) + + def test_infer_device_state_recursive_meta(self): + inp = {"foo": torch.rand(10, device="meta")} + device_type = _infer_device_type(inp) + self.assertEqual("meta", device_type) + + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") + def test_infer_device_state_recursive_multi_gpu(self): + # Check that no warning is issued for either gpu:0, gpu:1 or + # gpu:0, gpu:0 cases since they are both the same device type + global g_device_type + inp = { + "foo": torch.rand(10, device=f"{g_device_type}:0"), + "bar": [torch.rand(10, device=f"{g_device_type}:1")], + } + with warnings.catch_warnings(): + warnings.simplefilter("error") + _device_type = _infer_device_type(inp) + self.assertEqual(g_device_type, _device_type) + inp = { + "foo": torch.rand(10, device=f"{g_device_type}:0"), + "bar": [torch.rand(10, device=f"{g_device_type}:0")], + } + with warnings.catch_warnings(): + warnings.simplefilter("error") + _device_type = _infer_device_type(inp) + self.assertEqual(g_device_type, _device_type) + # Check that a warning is issued for gpu:0, meta and that it includes + # device type information + inp = { + "foo": torch.rand(10, device=f"{g_device_type}:0"), + "bar": [torch.rand(10, device="meta")], + } + with warnings.catch_warnings(record=True) as w: + _device_type = _infer_device_type(inp) + self.assertEqual(g_device_type, _device_type) + self.assertEqual(len(w), 1) + warning_msg = str(w[-1].message) + self.assertTrue( + "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices" + in warning_msg + ) + self.assertTrue(f"Device types: ['{device_type}', 'meta']" in warning_msg) + self.assertTrue(f"first device type: {device_type}" in warning_msg) + + +class TestDataLoaderUtils(TestCase): + MAX_TIMEOUT_IN_SECOND = 300 + + def test_random_seed(self): + def run(): + dataloader = torch.utils.data.DataLoader( + RandomDatasetMock(), + batch_size=2, + num_workers=4, + shuffle=True, + 
timeout=self.MAX_TIMEOUT_IN_SECOND, + ) + return next(iter(dataloader)) + + torch.manual_seed(2018) + x1 = run() + torch.manual_seed(2018) + x2 = run() + self.assertEqual(x1, x2) + + def test_single_keep(self): + # torch.rand(5, 3, 3, 2) is a Tensor here; technically not a valid input because + # not a Dataset subclass, but needs to stay working so add ignore's + # for type checking with mypy + dataloader: DataLoader = DataLoader( + torch.rand(5, 3, 3, 2), # type: ignore[arg-type] + batch_size=3, + num_workers=0, + drop_last=False, + ) + dataiter = iter(dataloader) + self.assertEqual(len(list(dataiter)), 2) + + def test_single_drop(self): + dataloader: DataLoader = DataLoader( + torch.rand(5, 3, 3, 2), # type: ignore[arg-type] + batch_size=3, + num_workers=0, + drop_last=True, + ) + dataiter = iter(dataloader) + self.assertEqual(len(list(dataiter)), 1) + + @unittest.skip( + "FIXME: Intermittent GPU out-of-memory error on Windows and time-out under ASAN" + ) + def test_multi_keep(self): + dataloader: DataLoader = DataLoader( + torch.rand(5, 3, 3, 2), # type: ignore[arg-type] + batch_size=3, + num_workers=2, + drop_last=False, + timeout=self.MAX_TIMEOUT_IN_SECOND, + ) + dataiter = iter(dataloader) + self.assertEqual(len(list(dataiter)), 2) + + def test_multi_drop(self): + dataloader: DataLoader = DataLoader( + torch.rand(5, 3, 3, 2), # type: ignore[arg-type] + batch_size=3, + num_workers=2, + drop_last=True, + timeout=self.MAX_TIMEOUT_IN_SECOND, + ) + dataiter = iter(dataloader) + self.assertEqual(len(list(dataiter)), 1) + + +test_dir = os.path.abspath(os.path.dirname(str(__file__))) + + +from torch.utils.collect_env import get_pretty_env_info + + +@unittest.skipIf(IS_FBCODE, "runs pip which is not available internally") +class TestCollectEnv(TestCase): + def test_smoke(self): + info_output = get_pretty_env_info() + self.assertTrue(info_output.count("\n") >= 17) + + +class TestHipify(TestCase): + def test_import_hipify(self): + from torch.utils.hipify import hipify_python # noqa: F401 + + +class TestHipifyTrie(TestCase): + def setUp(self): + from torch.utils.hipify import hipify_python + + self.trie = hipify_python.Trie() + + def test_add_and_search_trie(self): + self.trie.add("banana") + self.assertTrue(self.trie.search("banana")) + self.assertFalse(self.trie.search("ban")) + self.assertFalse(self.trie.search("dog")) + + def test_add_multiple_and_search_trie(self): + words_to_add = ["banana", "apple", "orange"] + for word in words_to_add: + self.trie.add(word) + + for word in words_to_add: + self.assertTrue(self.trie.search(word)) + + for word in ["ban", "dog", "okay", "app"]: + self.assertFalse(self.trie.search(word)) + + def test_quote_escape(self): + orig_chars = ["*", "[", ".", "+", "a", "z", "-"] + quoted_strs = ["\\*", "\\[", "\\.", "\\+", "a", "z", "\\-"] + for i in range(len(orig_chars)): + self.assertEqual(self.trie.quote(orig_chars[i]), quoted_strs[i]) + + @unittest.skipIf(HAS_XPU, "XPU not supported hipify") + def test_export_trie_to_regex(self): + words_to_add = [ + "__CUDACC__", + "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", + "CUDA_ERROR_ARRAY_IS_MAPPED", + "CUDA_ERROR_NOT_MAPPED", + "CUDA_ERROR_INVALID_SOURCE", + ] + for word in words_to_add: + self.trie.add(word) + regex = self.trie.export_to_regex() + expected_regex = r"(?:CUDA_ERROR_(?:ARRAY_IS_MAPPED|CONTEXT_ALREADY_CURRENT|INVALID_SOURCE|NOT_MAPPED)|__CUDACC__)" + self.assertEqual(regex, expected_regex) + + def test_prefix_words_export_trie_to_regex(self): + # test case where some nodes have both children and are also leaf 
nodes. + words_to_add = ["apple", "app", "ban", "banana"] + for word in words_to_add: + self.trie.add(word) + regex = self.trie.export_to_regex() + expected_regex = r"(?:app(?:le)?|ban(?:ana)?)" + self.assertEqual(regex, expected_regex) + + @unittest.skipIf(HAS_XPU, "XPU not supported hipify") + def test_single_export_trie_to_regex(self): + words_to_add = ["cudaErrorInvalidMemcpyDirection"] + for word in words_to_add: + self.trie.add(word) + regex = self.trie.export_to_regex() + expected_regex = "cudaErrorInvalidMemcpyDirection" + self.assertEqual(regex, expected_regex) + + def test_char_export_trie_to_regex(self): + self.trie.add("a") + self.assertEqual(self.trie.export_to_regex(), "a") + self.trie.add("b") + self.assertEqual(self.trie.export_to_regex(), "[ab]") + + def test_special_char_export_trie_to_regex(self): + self.trie.add(r"c*") + self.assertEqual(self.trie.export_to_regex(), r"c\*") + + +class TestAssert(TestCase): + def test_assert_true(self): + # verify assertions work as expected + # bool argument + torch._assert(True, "foo") + with self.assertRaisesRegex(AssertionError, "bar"): + torch._assert(False, "bar") + # tensor argument + torch._assert(torch.tensor([True], dtype=torch.bool), "foo") + with self.assertRaisesRegex(AssertionError, "bar"): + torch._assert(torch.tensor([False], dtype=torch.bool), "bar") + + def test_assert_scriptable(self): + class M(torch.nn.Module): + def forward(self, x): + torch._assert(x.sum() > 0, "foo") + return x + + m = M() + # scriptable + ms = torch.jit.script(m) + # data can be passed without errors + x = torch.randn(4, 4).fill_(1.0) + ms(x) + with self.assertRaisesRegex(torch.jit.Error, "foo"): + ms(torch.tensor([False], dtype=torch.bool)) + + +@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only") +class TestStandaloneCPPJIT(TestCase): + def test_load_standalone(self): + build_dir = tempfile.mkdtemp() + try: + src_path = os.path.join(build_dir, "main.cpp") + src = textwrap.dedent( + """\ + #include + #include + int main() { + auto x = torch::eye(3); + std::cout << x << std::endl; + } + """ + ) + with open(src_path, "w") as f: + f.write(src) + + exec_path = torch.utils.cpp_extension.load( + "standalone_load_test", + src_path, + build_directory=build_dir, + is_python_module=False, + is_standalone=True, + ) + + ext = ".exe" if IS_WINDOWS else "" + self.assertEqual( + exec_path, os.path.join(build_dir, f"standalone_load_test{ext}") + ) + + for shell in [True, False]: + r = subprocess.run( + [exec_path], + shell=shell, + stdout=subprocess.PIPE, + ) + self.assertEqual(r.returncode, 0) + self.assertEqual( + # Windows prints "\r\n" for newlines. 
+ textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"), + textwrap.dedent( + """\ + 1 0 0 + 0 1 0 + 0 0 1 + [ CPUFloatType{3,3} ] + """ + ), + ) + + finally: + shutil.rmtree(build_dir) + + +class TestRenderUtils(TestCase): + def test_basic(self): + self.assertExpectedInline( + torch._utils.render_call(torch.sum, [torch.randn(100)], {"dim": 0}), + """torch.sum(tensor([...], size=(100,)), dim=0)""", + ) + self.assertExpectedInline( + torch._utils.render_call(torch.sum, [torch.randn(100, 100)], {"dim": 0}), + """torch.sum(tensor([...], size=(100, 100)), dim=0)""", + ) + + +class TestDeviceUtils(TestCase): + def test_basic(self): + with torch.device("meta") as dev: + x = torch.empty(3, 3) + self.assertEqual(x.device.type, "meta") + self.assertEqual(dev, torch.device("meta")) + + def test_decorator(self): + @set_device("meta") + def f(): + return torch.empty(3, 3) + + self.assertEqual(f().device.type, "meta") + + def test_decorator_generator(self): + @set_device("meta") + def f(): + yield torch.empty(3, 3) + yield torch.empty(3, 3) + + r1, r2 = list(f()) + self.assertEqual(r1.device.type, "meta") + self.assertEqual(r2.device.type, "meta") + + def test_nn_module(self): + with torch.device("meta"): + m = nn.Linear(40, 50) + self.assertEqual(m.weight.device.type, "meta") + + def test_set_default_device(self): + try: + torch.set_default_device("meta") + r = torch.empty(2, 2) + finally: + torch.set_default_device(None) + + self.assertEqual(r.device.type, "meta") + + def test_get_default_device(self): + torch.set_default_device("meta") + self.assertEqual(torch.get_default_device().type, "meta") + torch.set_default_device(None) + + @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") + def test_get_default_device_more(self): + try: + torch.set_default_device(device_type) + self.assertEqual(torch.get_default_device(), torch.tensor([]).device) + torch.set_default_device(None) + + torch.set_default_device(device_type) + torch.get_device_module(device_type).set_device(f"{device_type}:1") + self.assertEqual(torch.get_default_device(), torch.tensor([]).device) + torch.set_default_device(None) + + torch.set_default_device(f"{device_type}:1") + self.assertEqual(torch.get_default_device(), torch.tensor([]).device) + torch.set_default_device(None) + + torch.set_default_device(f"{device_type}:1") + with torch.device(f"{device_type}:0"): + self.assertEqual(torch.get_default_device(), torch.device(f"{device_type}", 0)) + + torch.set_default_device("cpu") + self.assertEqual(torch.get_default_device(), torch.device("cpu")) + with torch.device(f"{device_type}:0"): + self.assertEqual(torch.get_default_device(), torch.device(f"{device_type}", 0)) + + self.assertEqual(torch.get_default_device(), torch.device("cpu")) + finally: + # Reset the device at the end. + torch.set_default_device(None) + + @onlyCPU + @ops(op_db) + def test_device_mode_ops(self, device, dtype, op): + func = op.get_op() + samples = op.sample_inputs(device, dtype, requires_grad=False) + for sample in samples: + # Only test samples which don't have Tensor inputs. However, + # we don't test the factory property on OpInfo as it is very, + # very incomplete + if tree_any( + lambda x: isinstance(x, torch.Tensor), + (sample.input, sample.args, sample.kwargs), + ): + continue + # Many OpInfos will explicitly pass in a device. DeviceContext + # will respect device if it is explicitly specified. To test + # DeviceContext, we have to remove the device kwarg in this case. 
+ # NB: Can't pass None to sample_inputs, the function can't + # handle it. + kwargs = sample.kwargs.copy() + kwargs.pop("device", None) + with torch.device("meta"): + r = func(sample.input, *sample.args, **kwargs) + + def is_meta_device(x: torch.Tensor) -> bool: + return x.device.type == "meta" + + self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r)) + + +instantiate_device_type_tests(TestDeviceUtils, globals()) + + +class TestCppExtensionUtils(TestCase): + def test_cpp_compiler_is_ok(self): + self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("c++")) + + def test_cc_compiler_is_ok(self): + self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("cc")) + + +class TestTraceback(TestCase): + def test_basic(self): + source = """\ +def f(x): + def g(x): + raise RuntimeError # HEYA + + x = x * 3 + return g(x) + 1 +""" + + out: dict[str, Any] = {} + scope = {"__compile_source__": source} + exec(source, scope, out) + + try: + with report_compile_source_on_error(): + out["f"](1) + except RuntimeError as e: + self.assertIn("HEYA", "".join(traceback.format_tb(e.__traceback__))) + + def test_format_traceback_short(self): + try: + raise RuntimeError + except RuntimeError as e: + self.assertRegex( + format_traceback_short(e.__traceback__), + r".*test_utils.py:\d+ in test_format_traceback_short", + ) + + def test_captured_traceback(self): + self.assertIn( + "test_captured_traceback", "".join(CapturedTraceback.extract().format()) + ) + + def test_captured_traceback_format_all(self): + rs = CapturedTraceback.format_all( + [CapturedTraceback.extract(), CapturedTraceback.extract()] + ) + self.assertEqual(len(rs), 2) + self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) + + def test_captured_traceback_format_all_cached(self): + tb = CapturedTraceback.extract() + tb.format() # cached + rs = CapturedTraceback.format_all([tb, CapturedTraceback.extract()]) + self.assertEqual(len(rs), 2) + self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) + + +class TestTryImport(TestCase): + def test_import_imported(self): + self.assertIn("os", sys.modules) + os_module = try_import("os") + self.assertIs(os_module, os) + + def test_import_existing(self): + self.assertNotIn("imaplib", sys.modules) + imaplib_module = try_import("imaplib") + self.assertIsNotNone(imaplib_module) + self.assertFalse(hasattr(imaplib_module, "not_attribute")) + self.assertTrue(hasattr(imaplib_module, "IMAP4")) + + def test_import_missing(self): + missing_module = try_import("missing_module") + self.assertIsNone(missing_module) + + +@deprecated() +def _deprecated_api(x, y=15): + return x + y + + +class TestDeprecate(TestCase): + def test_deprecated(self): + with self.assertWarnsRegex(Warning, "is DEPRECATED"): + deprecated_api(1, 2) # noqa: F821 + with self.assertWarnsRegex(Warning, "is DEPRECATED"): + deprecated_api(1, y=2) # noqa: F821 + _deprecated_api(1, 2) + _deprecated_api(1, y=2) + + +if __name__ == "__main__": + run_tests() From 6bb31a9569ebed90884adcbbfdcf59e43011495f Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Fri, 31 Oct 2025 15:23:40 +0800 Subject: [PATCH 02/12] replace the g_device_type to device_type --- test/xpu/test_utils.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/test/xpu/test_utils.py b/test/xpu/test_utils.py index 4c4876b283..d67107aa56 100644 --- a/test/xpu/test_utils.py +++ b/test/xpu/test_utils.py @@ -529,32 +529,32 @@ def test_infer_device_state_recursive_meta(self): def 
test_infer_device_state_recursive_multi_gpu(self): # Check that no warning is issued for either gpu:0, gpu:1 or # gpu:0, gpu:0 cases since they are both the same device type - global g_device_type + global device_type inp = { - "foo": torch.rand(10, device=f"{g_device_type}:0"), - "bar": [torch.rand(10, device=f"{g_device_type}:1")], + "foo": torch.rand(10, device=f"{device_type}:0"), + "bar": [torch.rand(10, device=f"{device_type}:1")], } with warnings.catch_warnings(): warnings.simplefilter("error") _device_type = _infer_device_type(inp) - self.assertEqual(g_device_type, _device_type) + self.assertEqual(device_type, _device_type) inp = { - "foo": torch.rand(10, device=f"{g_device_type}:0"), - "bar": [torch.rand(10, device=f"{g_device_type}:0")], + "foo": torch.rand(10, device=f"{device_type}:0"), + "bar": [torch.rand(10, device=f"{device_type}:0")], } with warnings.catch_warnings(): warnings.simplefilter("error") _device_type = _infer_device_type(inp) - self.assertEqual(g_device_type, _device_type) + self.assertEqual(device_type, _device_type) # Check that a warning is issued for gpu:0, meta and that it includes # device type information inp = { - "foo": torch.rand(10, device=f"{g_device_type}:0"), + "foo": torch.rand(10, device=f"{device_type}:0"), "bar": [torch.rand(10, device="meta")], } with warnings.catch_warnings(record=True) as w: _device_type = _infer_device_type(inp) - self.assertEqual(g_device_type, _device_type) + self.assertEqual(device_type, _device_type) self.assertEqual(len(w), 1) warning_msg = str(w[-1].message) self.assertTrue( From c4524cdb46d326d44935d86fda681d67f6372571 Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Fri, 31 Oct 2025 15:31:58 +0800 Subject: [PATCH 03/12] Remove has_gpu --- test/xpu/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/xpu/test_utils.py b/test/xpu/test_utils.py index d67107aa56..6ca46cc838 100644 --- a/test/xpu/test_utils.py +++ b/test/xpu/test_utils.py @@ -33,7 +33,7 @@ IS_WINDOWS, load_tests, ) -from torch.testing._internal.inductor_utils import HAS_GPU, get_gpu_type +from torch.testing._internal.inductor_utils import get_gpu_type from torch.utils._device import set_device from torch.utils._pytree import tree_all_only, tree_any from torch.utils._traceback import ( From 23d40e6bc413f665557916356c7591e815068004 Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Mon, 3 Nov 2025 10:47:34 +0800 Subject: [PATCH 04/12] Update format with autopep8 --- test/xpu/checkpoint.py | 21 +++++++++++++++++---- test/xpu/test_utils.py | 10 +++------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/test/xpu/checkpoint.py b/test/xpu/checkpoint.py index 6fc9af6a4f..fbadbb648a 100644 --- a/test/xpu/checkpoint.py +++ b/test/xpu/checkpoint.py @@ -299,7 +299,8 @@ def backward(ctx, *args): device_autocast_ctx = torch.amp.autocast( device_type=ctx.device_type, **ctx.device_autocast_kwargs ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext() - with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): # type: ignore[attr-defined] + # type: ignore[attr-defined] + with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): @@ -340,6 +341,8 @@ def noop_context_fn(): # break. And here, the following disable wrapper ensures that # TorchDynamo does not trigger again on the frames created by # utils.checkpoint innards. 
+ + @torch._disable_dynamo def checkpoint( function, @@ -986,6 +989,7 @@ def check_recomputed_tensors_match(self, gid): -------------------------------------------------------------------------------- """ + class CheckpointError(RuntimeError): pass @@ -1007,7 +1011,7 @@ def get_context_manager(self): @contextlib.contextmanager def logging_mode(): with LoggingTensorMode(), \ - capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: + capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: # pyrefly: ignore [bad-assignment] self.logs, self.tbs = logs_and_tb yield logs_and_tb @@ -1052,6 +1056,7 @@ def context_fn(): return context_fn, unpack_error_cb + def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: # These properties are fast to check, easy to understand return { @@ -1060,12 +1065,15 @@ def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: "device": x.device } + _allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { _DEFAULT_DETERMINISM_MODE: _default_meta_extractor, "none": lambda _: None, } # See Rule 5 + + class _StopRecomputationError(Exception): pass @@ -1251,6 +1259,7 @@ class SelectiveCheckpointContext: >>> context_fn=context_fn, >>> ) """ + def __init__(self, *, is_recompute): self.is_recompute = is_recompute @@ -1332,9 +1341,11 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: - self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out)) + self.storage[func].append(tree_map(lambda x: _VersionWrapper( + _maybe_detach(x, any_ret_has_alias_info)), out)) return out + class _CachedTorchDispatchMode(TorchDispatchMode): # Used together with _CachedTorchDispatchMode to implement SAC. def __init__(self, policy_fn, storage, allow_cache_entry_mutation): @@ -1457,6 +1468,7 @@ def policy_fn(ctx, op, *args, **kwargs): # NB: this helper wraps fn before calling checkpoint_impl. kwargs and # saving/restoring of global state is handled here. 
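The two dispatch modes above implement selective activation checkpoint (SAC): results of selected ops are cached during the forward pass and replayed during recomputation. A short sketch of the op-list form of the helper, assuming the public torch.utils.checkpoint surface that this file mirrors; the choice of torch.ops.aten.mm.default and the shapes are illustrative:

import functools

import torch
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts

# Results of these ops are stored by the caching mode during forward and
# popped back out by the cached mode during recomputation.
ops_to_save = [torch.ops.aten.mm.default]


def fn(x, w):
    return torch.mm(x, w).sin().sum()


x = torch.randn(16, 16, requires_grad=True)
w = torch.randn(16, 16, requires_grad=True)

out = checkpoint(
    fn,
    x,
    w,
    use_reentrant=False,
    context_fn=functools.partial(create_selective_checkpoint_contexts, ops_to_save),
)
out.backward()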
+ def _checkpoint_without_reentrant_generator( fn, preserve_rng_state=True, @@ -1560,7 +1572,8 @@ def recompute_fn(*inputs): device_autocast_ctx = torch.amp.autocast( device_type=device_type, **device_autocast_kwargs ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext() - with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: # type: ignore[attr-defined] + # type: ignore[attr-defined] + with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: fn(*args, **kwargs) new_frame = _CheckpointFrame( diff --git a/test/xpu/test_utils.py b/test/xpu/test_utils.py index 6ca46cc838..03db402568 100644 --- a/test/xpu/test_utils.py +++ b/test/xpu/test_utils.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs # Owner(s): ["module: unknown"] +from torch.utils.collect_env import get_pretty_env_info +from torch.testing._internal.common_utils import run_tests, TestCase import os import random import shutil @@ -60,9 +62,6 @@ device_type = get_gpu_type() -from torch.testing._internal.common_utils import run_tests, TestCase - - # mypy: disable-error-code="name-defined" @@ -149,7 +148,7 @@ def forward(self, input_var): out.sum().backward() for m in modules[: (len(modules) // 2)]: self.assertEqual(m.counter, 2) - for m in modules[(len(modules) // 2) :]: + for m in modules[(len(modules) // 2):]: self.assertEqual(m.counter, 1) def test_checkpoint_valid(self): @@ -637,9 +636,6 @@ def test_multi_drop(self): test_dir = os.path.abspath(os.path.dirname(str(__file__))) -from torch.utils.collect_env import get_pretty_env_info - - @unittest.skipIf(IS_FBCODE, "runs pip which is not available internally") class TestCollectEnv(TestCase): def test_smoke(self): From 72a975645ad4d1329a100e5c14038e5ec8a409d8 Mon Sep 17 00:00:00 2001 From: erxin Date: Tue, 4 Nov 2025 10:24:01 +0000 Subject: [PATCH 05/12] Fix lint issue --- test/xpu/checkpoint.py | 219 +++++++++++++++++++++++----------- test/xpu/test_schema_check.py | 218 +++++++++++++++++++-------------- test/xpu/test_utils.py | 32 +++-- 3 files changed, 299 insertions(+), 170 deletions(-) diff --git a/test/xpu/checkpoint.py b/test/xpu/checkpoint.py index fbadbb648a..705a5b64e8 100644 --- a/test/xpu/checkpoint.py +++ b/test/xpu/checkpoint.py @@ -1,19 +1,19 @@ # mypy: allow-untyped-defs import contextlib +import enum import platform import uuid import warnings import weakref from collections import defaultdict from typing import * # noqa: F403 -import enum from weakref import ReferenceType import torch import torch.fx.traceback as fx_traceback -from torch.utils._pytree import tree_map from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map __all__ = [ "checkpoint", @@ -83,7 +83,8 @@ def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: def check_backward_validity(inputs: Iterable[Any]) -> None: if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): warnings.warn( - "None of the inputs have requires_grad=True. Gradients will be None", stacklevel=2 + "None of the inputs have requires_grad=True. Gradients will be None", + stacklevel=2, ) @@ -104,7 +105,9 @@ class DefaultDeviceType: to save and restore for recomputation. 
""" - _default_device_type = acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" + _default_device_type = ( + acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" + ) @staticmethod def set_device_type(device: str = "cuda"): @@ -134,6 +137,7 @@ def add_device_types(arg): nonlocal device_types if isinstance(arg, torch.Tensor) and arg.device.type != "cpu": device_types.append(arg.device.type) + tree_map(add_device_types, args) device_types_set = set(device_types) @@ -144,7 +148,8 @@ def add_device_types(arg): "devices will be ignored. Consequently, if any checkpointed functions involve randomness, " "this may result in incorrect gradients. (Note that if CUDA devices are among the devices " "detected, it will be prioritized; otherwise, the first device encountered will be selected.)" - f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}", stacklevel=2 + f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}", + stacklevel=2, ) if len(device_types) == 0: return DefaultDeviceType.get_device_type() @@ -170,6 +175,7 @@ def add_device_ids(arg): nonlocal fwd_device_ids if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}: fwd_device_ids.append(arg.get_device()) + tree_map(add_device_ids, args) fwd_device_states = [] @@ -212,8 +218,8 @@ def _get_autocast_kwargs(device_type="cuda"): device_autocast_kwargs = None cpu_autocast_kwargs = { - "enabled": torch.is_autocast_enabled('cpu'), - "dtype": torch.get_autocast_dtype('cpu'), + "enabled": torch.is_autocast_enabled("cpu"), + "dtype": torch.get_autocast_dtype("cpu"), "cache_enabled": torch.is_autocast_cache_enabled(), } @@ -288,19 +294,31 @@ def backward(ctx, *args): if ctx.preserve_rng_state and ctx.had_device_in_fwd: rng_devices = ctx.fwd_devices with torch.random.fork_rng( - devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type + devices=rng_devices, + enabled=ctx.preserve_rng_state, + device_type=ctx.device_type, ): if ctx.preserve_rng_state: torch.set_rng_state(ctx.fwd_cpu_state) if ctx.had_device_in_fwd: - set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type) + set_device_states( + ctx.fwd_devices, + ctx.fwd_device_states, + device_type=ctx.device_type, + ) detached_inputs = detach_variable(tuple(inputs)) - device_autocast_ctx = torch.amp.autocast( - device_type=ctx.device_type, **ctx.device_autocast_kwargs - ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext() + device_autocast_ctx = ( + torch.amp.autocast( + device_type=ctx.device_type, **ctx.device_autocast_kwargs + ) + if torch.amp.is_autocast_available(ctx.device_type) + else contextlib.nullcontext() + ) # type: ignore[attr-defined] - with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs): + with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast( + "cpu", **ctx.cpu_autocast_kwargs + ): outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): @@ -330,6 +348,7 @@ def backward(ctx, *args): def noop_context_fn(): return contextlib.nullcontext(), contextlib.nullcontext() + # Note: [torch.compile and checkpoint] # TorchDynamo does not step inside utils.checkpoint function. The flow # looks likes this @@ -352,7 +371,7 @@ def checkpoint( determinism_check: str = _DEFAULT_DETERMINISM_MODE, debug: bool = False, early_stop: bool = True, - **kwargs + **kwargs, ): r"""Checkpoint a model or part of the model. 
@@ -480,7 +499,7 @@ def checkpoint( "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " "details on the differences between the two variants.", - stacklevel=2 + stacklevel=2, ) use_reentrant = True @@ -500,7 +519,14 @@ def checkpoint( return CheckpointFunction.apply(function, preserve, *args) else: gen = _checkpoint_without_reentrant_generator( - function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs + function, + preserve, + context_fn, + determinism_check, + debug, + early_stop, + *args, + **kwargs, ) # Runs pre-forward logic next(gen) @@ -568,7 +594,8 @@ def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwar "is not passed. use_reentrant=False is " "recommended, but if you need to preserve the current default " "behavior, you can pass use_reentrant=True. Refer to docs for more " - "details on the differences between the two variants.", stacklevel=2 + "details on the differences between the two variants.", + stacklevel=2, ) use_reentrant = True @@ -994,13 +1021,15 @@ class CheckpointError(RuntimeError): pass -def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]: +def _get_debug_context_and_cb() -> ( + Tuple[Callable[[], Any], Callable[[CheckpointError], None]] +): # This function returns the context_fn and error_cb to be used by the # checkpointing mechanism. error_cb is invoked when an error is detected # during unpack. # record_context_cpp is not support on non-linux non-x86_64 platforms - cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux' + cpp_tb = platform.machine() == "x86_64" and platform.system() == "Linux" class CaptureLogs: def __init__(self): @@ -1010,11 +1039,13 @@ def __init__(self): def get_context_manager(self): @contextlib.contextmanager def logging_mode(): - with LoggingTensorMode(), \ - capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb: + with LoggingTensorMode(), capture_logs( + True, python_tb=True, script_tb=True, cpp_tb=cpp_tb + ) as logs_and_tb: # pyrefly: ignore [bad-assignment] self.logs, self.tbs = logs_and_tb yield logs_and_tb + return logging_mode() capture_logs_fwd = CaptureLogs() @@ -1029,7 +1060,7 @@ def get_str_tb(label, capture_logs): found_torch_dispatch = False for line in tb: # Start printing stack trace only after __torch_dispatch__ is found - is_torch_dispatch = line['name'] == '__torch_dispatch__' + is_torch_dispatch = line["name"] == "__torch_dispatch__" if not found_torch_dispatch and not is_torch_dispatch: continue elif is_torch_dispatch: @@ -1038,6 +1069,7 @@ def get_str_tb(label, capture_logs): out += f"{line['filename']}:{line['line']}:{line['name']}\n" out += "\n\n" return out + if capture_logs_fwd.logs is None: raise AssertionError("capture_logs_fwd.logs is None") if capture_logs_recompute.logs is None: @@ -1047,23 +1079,22 @@ def get_str_tb(label, capture_logs): forward_traces=get_str_tb("original", capture_logs_fwd), recompute_traces=get_str_tb("recompute", capture_logs_recompute), forward_ops="\n".join(capture_logs_fwd.logs), - recompute_ops="\n".join(capture_logs_recompute.logs) + recompute_ops="\n".join(capture_logs_recompute.logs), ) ) from e def context_fn(): - return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager() + return ( + capture_logs_fwd.get_context_manager(), + capture_logs_recompute.get_context_manager(), + ) return context_fn, unpack_error_cb def 
_default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: # These properties are fast to check, easy to understand - return { - "shape": x.shape, - "dtype": x.dtype, - "device": x.device - } + return {"shape": x.shape, "dtype": x.dtype, "device": x.device} _allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { @@ -1090,7 +1121,9 @@ def pack_hook(x): if recomp_idx >= len(target_frame.weak_holders): if target_frame.early_stop: - raise AssertionError("Unexpected state: target_frame.early_stop is set") + raise AssertionError( + "Unexpected state: target_frame.early_stop is set" + ) if not target_frame.forward_completed: # We run into this case when early stop is not enabled and do # grad within checkpoint. @@ -1186,11 +1219,13 @@ def unpack_hook(holder): return ret if frame.unpack_error_cb is not None: + def unpack_hook_with_error_cb(holder): try: return unpack_hook(holder) except CheckpointError as e: frame.unpack_error_cb(e) + super().__init__(pack_hook, unpack_hook_with_error_cb) else: super().__init__(pack_hook, unpack_hook) @@ -1206,7 +1241,9 @@ class _VersionWrapper: # Check that cached tensors are not mutated. def __init__(self, val): self.val: Union[torch.Tensor, Any] = val - self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None + self.version: Optional[int] = ( + val._version if isinstance(val, torch.Tensor) else None + ) def get_val(self, allow_cache_entry_mutation): if self.version is not None and not allow_cache_entry_mutation: @@ -1226,8 +1263,12 @@ def _maybe_detach(x, any_ret_has_alias_info): # For case 1, it is not enough to check whether x has differentiable dtype # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. # when the tensor is a view. - if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info): - with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False): + if isinstance(x, torch.Tensor) and ( + x.is_floating_point() or x.is_complex() or any_ret_has_alias_info + ): + with torch._C._SetExcludeDispatchKeyGuard( + torch._C.DispatchKey.ADInplaceOrView, False + ): # Ensure that view performed beneath autograd properly propagates # version counter. TODO: Use reentrant_dispatch instead of # manually manipulating dispatch keys. Using reentrant_dispatch @@ -1287,6 +1328,7 @@ class CheckpointPolicy(enum.Enum): save additional tensors not limited to ones that are actually needed for gradient computation. """ + MUST_SAVE = 0 PREFER_SAVE = 1 MUST_RECOMPUTE = 2 @@ -1305,7 +1347,9 @@ def _policy_from_bool(b): # With subclasses involved, these metadata ops become dispatchable, this # can result in incorrectness if these ops are selected cached. 
torch.ops.prim.device.default, -} | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns) # type: ignore[has-type] +} | set( + torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns +) # type: ignore[has-type] class _CachingTorchDispatchMode(TorchDispatchMode): @@ -1319,8 +1363,9 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) kwargs = {} if kwargs is None else kwargs - policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False), - func, *args, **kwargs) + policy = self.policy_fn( + SelectiveCheckpointContext(is_recompute=False), func, *args, **kwargs + ) if isinstance(policy, bool): policy = _policy_from_bool(policy) @@ -1338,11 +1383,20 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if isinstance(func, torch._ops.HigherOrderOperator): any_ret_has_alias_info = False else: - any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns) + any_ret_has_alias_info = any( + ret.alias_info is not None for ret in func._schema.returns + ) - if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: - self.storage[func].append(tree_map(lambda x: _VersionWrapper( - _maybe_detach(x, any_ret_has_alias_info)), out)) + if ( + policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) + or is_compiling + ): + self.storage[func].append( + tree_map( + lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), + out, + ) + ) return out @@ -1358,29 +1412,39 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) kwargs = {} if kwargs is None else kwargs - policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True), - func, *args, **kwargs) + policy = self.policy_fn( + SelectiveCheckpointContext(is_recompute=True), func, *args, **kwargs + ) if isinstance(policy, bool): policy = _policy_from_bool(policy) is_compiling = _is_compiling(func, args, kwargs) - if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling: + if ( + policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) + or is_compiling + ): storage = self.storage.get(func) if storage is None: - raise RuntimeError(f"{func} encountered during backward, but not found in storage") + raise RuntimeError( + f"{func} encountered during backward, but not found in storage" + ) if len(storage) == 0: raise RuntimeError( "Trying to backward an extra time. You are only allowed to backward once " "on any region computed under selective activation checkpoint." ) - out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0)) + out = tree_map( + lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0) + ) else: out = func(*args, **kwargs) return out -def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False): +def create_selective_checkpoint_contexts( + policy_fn_or_list, allow_cache_entry_mutation=False +): """ Helper to avoid recomputing certain ops during activation checkpointing. @@ -1439,11 +1503,17 @@ def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mu # context_fn anyway, so proceed as usual. 
if isinstance(policy_fn_or_list, list): for op in policy_fn_or_list: - if not isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + if not isinstance( + op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ): _extra_msg = ( - "Please update the OpOverloadPacket to a specific OpOverload." - "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." - ) if isinstance(op, torch._ops.OpOverloadPacket) else "" + ( + "Please update the OpOverloadPacket to a specific OpOverload." + "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." + ) + if isinstance(op, torch._ops.OpOverloadPacket) + else "" + ) raise ValueError( f"Expected op in `op_list` to be an OpOverload but got: {op} " f"of type {type(op)}. {_extra_msg}" @@ -1454,6 +1524,7 @@ def policy_fn(ctx, op, *args, **kwargs): return CheckpointPolicy.MUST_SAVE else: return CheckpointPolicy.PREFER_RECOMPUTE + elif callable(policy_fn_or_list): policy_fn = policy_fn_or_list else: @@ -1465,6 +1536,7 @@ def policy_fn(ctx, op, *args, **kwargs): _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), ) + # NB: this helper wraps fn before calling checkpoint_impl. kwargs and # saving/restoring of global state is handled here. @@ -1477,7 +1549,7 @@ def _checkpoint_without_reentrant_generator( debug: bool = False, early_stop: bool = True, *args, - **kwargs + **kwargs, ): """Checkpointing without reentrant autograd. @@ -1514,9 +1586,7 @@ def _checkpoint_without_reentrant_generator( if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug: if context_fn != noop_context_fn: - raise ValueError( - "debug=True is incompatible with non-default context_fn" - ) + raise ValueError("debug=True is incompatible with non-default context_fn") context_fn, unpack_error_cb = _get_debug_context_and_cb() if determinism_check in _allowed_determinism_checks_to_fns: @@ -1531,16 +1601,17 @@ def _checkpoint_without_reentrant_generator( device_module = _get_device_module(device_type) forward_context, recompute_context = context_fn() if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: - if ( - not isinstance(forward_context, TorchDispatchMode) - or not isinstance(recompute_context, TorchDispatchMode) + if not isinstance(forward_context, TorchDispatchMode) or not isinstance( + recompute_context, TorchDispatchMode ): raise AssertionError( "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " "must generate a tuple of two `TorchDispatchMode`s." ) # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. 
- device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type) + device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs( + device_type=device_type + ) if preserve_rng_state: fwd_cpu_state = torch.get_rng_state() @@ -1567,20 +1638,28 @@ def recompute_fn(*inputs): if preserve_rng_state: torch.set_rng_state(fwd_cpu_state) if had_device_in_fwd: - set_device_states(fwd_devices, fwd_device_states, device_type=device_type) - - device_autocast_ctx = torch.amp.autocast( - device_type=device_type, **device_autocast_kwargs - ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext() + set_device_states( + fwd_devices, fwd_device_states, device_type=device_type + ) + + device_autocast_ctx = ( + torch.amp.autocast(device_type=device_type, **device_autocast_kwargs) + if torch.amp.is_autocast_available(device_type) + else contextlib.nullcontext() + ) # type: ignore[attr-defined] - with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context: + with device_autocast_ctx, torch.amp.autocast( + "cpu", **cpu_autocast_kwargs + ), recompute_context: fn(*args, **kwargs) new_frame = _CheckpointFrame( recompute_fn, - _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop, + _enable_checkpoint_early_stop + if _enable_checkpoint_early_stop is not None + else early_stop, unpack_error_cb, - metadata_fn + metadata_fn, ) dummy = torch.empty((0,), requires_grad=True) new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) @@ -1594,8 +1673,11 @@ def recompute_fn(*inputs): yield new_frame.forward_completed = True - if getattr(device_module, "_initialized", False) and \ - preserve_rng_state and not had_device_in_fwd: # type: ignore[possibly-undefined] + if ( + getattr(device_module, "_initialized", False) + and preserve_rng_state + and not had_device_in_fwd + ): # type: ignore[possibly-undefined] # Device was not initialized before running the forward, so we didn't # stash the device state. raise RuntimeError( @@ -1606,6 +1688,7 @@ def recompute_fn(*inputs): return + # Note: [compiled autograd and checkpoint unpack hook] # When tracing via compiled autograd, this hook will be visible to the # compiler if the forward of this checkpointed region ran in eager. 
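The RNG stashing and restoring in recompute_fn above is what keeps randomness inside a checkpointed region consistent between the original forward and the recomputation. A small sketch, assuming the public torch.utils.checkpoint entry point; the dropout probability and tensor size are arbitrary:

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def noisy_fn(x):
    # Dropout draws random numbers; the stashed CPU/device RNG states shown
    # above make the recomputation reuse the same mask during backward.
    return F.dropout(x, p=0.5, training=True).sum()


x = torch.randn(32, requires_grad=True)
out = checkpoint(noisy_fn, x, use_reentrant=False, preserve_rng_state=True)
out.backward()
# Kept elements contribute 1 / (1 - p) = 2.0 to the gradient, dropped ones 0.
print(x.grad)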
diff --git a/test/xpu/test_schema_check.py b/test/xpu/test_schema_check.py index 6d6410f073..d305c9daf4 100644 --- a/test/xpu/test_schema_check.py +++ b/test/xpu/test_schema_check.py @@ -3,33 +3,44 @@ import os import sys -import torch -from torch.utils._pytree import tree_map import unittest -from torch.testing._internal.common_utils import run_tests, TEST_WITH_TORCHDYNAMO -from torch.fx.operator_schemas import normalize_function +import torch from torch._subclasses.schema_check_mode import SchemaCheckMode -from torch.utils._python_dispatch import TorchDispatchMode +from torch.fx.operator_schemas import normalize_function +from torch.testing._internal.common_device_type import ( + instantiate_device_type_tests, + OpDTypes, + ops, +) from torch.testing._internal.common_methods_invocations import op_db +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + run_tests, + slowTestIf, + TEST_WITH_TORCHDYNAMO, +) from torch.testing._internal.jit_utils import JitTestCase -from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests -from torch.testing._internal.common_utils import IS_WINDOWS, slowTestIf +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map + pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) - def secretly_aliasing(x): return x.view(-1) + def secretly_mutating(x): x.mul_(2) return x * 3 + def output_is_input(x): return x + custom_lib = torch.library.Library("bad_schemas", "DEF") # noqa: TOR901 custom_lib.define("secretly_aliasing(Tensor x) -> Tensor") custom_lib.define("secretly_mutating(Tensor x) -> Tensor") @@ -48,6 +59,7 @@ def output_is_input(x): # This TorchDispatchTensor Subclass is used to simulate an incorrect schema # which is then used to test that SchemaCheckMode behaves as expected + class IncorrectAliasTensor(torch.Tensor): ALIAS_ARG_OUT = {"aten::add"} ALIAS_OUT_OUT = {"aten::aminmax"} @@ -55,7 +67,7 @@ class IncorrectAliasTensor(torch.Tensor): elem: torch.Tensor - __slots__ = ['elem'] + __slots__ = ["elem"] @staticmethod def __new__(cls, elem, *args, **kwargs): @@ -63,11 +75,15 @@ def __new__(cls, elem, *args, **kwargs): # memory for the class in question, but it should still # advertise the same device as before r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] - cls, elem.size(), - strides=elem.stride(), storage_offset=elem.storage_offset(), + cls, + elem.size(), + strides=elem.stride(), + storage_offset=elem.storage_offset(), # TODO: clone storage aliasing - dtype=elem.dtype, layout=elem.layout, - device=elem.device, requires_grad=kwargs.get("requires_grad", False) + dtype=elem.dtype, + layout=elem.layout, + device=elem.device, + requires_grad=kwargs.get("requires_grad", False), ) # ...the real tensor is held as an element on the tensor. r.elem = elem.detach() if r.requires_grad else elem @@ -83,6 +99,7 @@ def unwrap(e): def wrap(e): return cls(e) if isinstance(e, torch.Tensor) else e + unwrapped_args = tree_map(unwrap, args) out = func(*unwrapped_args, **tree_map(unwrap, kwargs)) if func._schema.name in IncorrectAliasTensor.ALIAS_ARG_OUT: @@ -96,6 +113,7 @@ def wrap(e): return tree_map(wrap, out) + # Tests various schema checking functionalities. 
class TestSchemaCheck(JitTestCase): def setUp(self): @@ -108,7 +126,9 @@ def test_schema_check_mode_operator_order(self): with SchemaCheckMode() as schema_check: x = torch.rand((3, 3), requires_grad=True) x.relu().sin() - self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops) + self.assertEqual( + ["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops + ) # Tests that SchemaCheckMode records operator order without grad def test_schema_check_mode_operator_order_without_grad(self): @@ -132,16 +152,18 @@ def test_schema_check_mode_mutated_aliasing_mutation(self): actual = torch.rand((3, 3), requires_grad=False) with SchemaCheckMode() as schema_check: actual.sinh_() - self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated) - self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing) + self.assertEqual([("aten::sinh_", "input")], schema_check.mutated) + self.assertEqual([("aten::sinh_", "input", "output_0")], schema_check.aliasing) # Tests that SchemaCheckMode records mutations and aliases with resize_ def test_schema_check_mode_mutated_aliasing_resize_(self): actual = torch.rand((3, 3), requires_grad=False) with SchemaCheckMode() as schema_check: actual.resize_(9) - self.assertEqual([('aten::resize_', 'input')], schema_check.mutated) - self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing) + self.assertEqual([("aten::resize_", "input")], schema_check.mutated) + self.assertEqual( + [("aten::resize_", "input", "output_0")], schema_check.aliasing + ) # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self): @@ -150,18 +172,11 @@ def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self): with SchemaCheckMode() as schema_check: actual.add_(y) self.assertEqual( - [ - ('aten::add_', 'input'), - ('aten::add_', 'other') - ], - schema_check.mutated + [("aten::add_", "input"), ("aten::add_", "other")], schema_check.mutated ) self.assertEqual( - [ - ('aten::add_', 'input', 'output_0'), - ('aten::add_', 'other', 'output_0') - ], - schema_check.aliasing + [("aten::add_", "input", "output_0"), ("aten::add_", "other", "output_0")], + schema_check.aliasing, ) # Tests that SchemaCheckMode records mutations and alias with as_strided @@ -169,39 +184,28 @@ def test_schema_check_mode_mutated_aliasing_as_strided(self): x = torch.rand((3, 6, 4)) with SchemaCheckMode() as schema_check: x.as_strided_([3, 6, 4], [9, 1, 1]) + self.assertEqual([("aten::as_strided_", "input")], schema_check.mutated) self.assertEqual( - [ - ('aten::as_strided_', 'input') - ], - schema_check.mutated - ) - self.assertEqual( - [ - ('aten::as_strided_', 'input', 'output_0') - ], - schema_check.aliasing + [("aten::as_strided_", "input", "output_0")], schema_check.aliasing ) # Tests that SchemaCheckMode records mutations and aliases with multiple outputs def test_schema_check_mode_mutated_aliasing_multiple_outputs(self): - x = torch.arange(9.) - m_actual = torch.arange(9.) 
+ x = torch.arange(9.0) + m_actual = torch.arange(9.0) e_actual = torch.zeros([9], dtype=torch.int32) with SchemaCheckMode() as schema_check: torch.frexp(x, out=(m_actual, e_actual)) self.assertEqual( - [ - ('aten::frexp', 'mantissa'), - ('aten::frexp', 'exponent') - ], - schema_check.mutated + [("aten::frexp", "mantissa"), ("aten::frexp", "exponent")], + schema_check.mutated, ) self.assertEqual( [ - ('aten::frexp', 'mantissa', 'output_0'), - ('aten::frexp', 'exponent', 'output_1') + ("aten::frexp", "mantissa", "output_0"), + ("aten::frexp", "exponent", "output_1"), ], - schema_check.aliasing + schema_check.aliasing, ) # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs @@ -211,20 +215,16 @@ def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self): with SchemaCheckMode() as schema_check: torch.aminmax(x, dim=0, out=[actual, actual]) self.assertEqual( - [ - ('aten::aminmax', 'min'), - ('aten::aminmax', 'max') - ], - schema_check.mutated + [("aten::aminmax", "min"), ("aten::aminmax", "max")], schema_check.mutated ) self.assertEqual( [ - ('aten::aminmax', 'min', 'output_0'), - ('aten::aminmax', 'min', 'output_1'), - ('aten::aminmax', 'max', 'output_0'), - ('aten::aminmax', 'max', 'output_1') + ("aten::aminmax", "min", "output_0"), + ("aten::aminmax", "min", "output_1"), + ("aten::aminmax", "max", "output_0"), + ("aten::aminmax", "max", "output_1"), ], - schema_check.aliasing + schema_check.aliasing, ) # Tests that SchemaCheckMode wraps torch.Tensor @@ -293,9 +293,9 @@ def test_schema_check_mode_functionality_aliasing_inputs(self): # Tests that SchemaCheckMode wraps Torch.tensor with multiple tensor outputs def test_schema_check_mode_functionality_with_multiple_outputs(self): - x = torch.arange(9.) + x = torch.arange(9.0) m_expected, e_expected = torch.frexp(x) - m_actual = torch.arange(9.) 
+ m_actual = torch.arange(9.0) e_actual = torch.zeros([9], dtype=torch.int32) with SchemaCheckMode(): torch.frexp(x, out=(m_actual, e_actual)) @@ -352,7 +352,9 @@ def test_schema_check_mode_empty_list_input(self): # Tests that an exception is raised for a mismatching mutation def test_mutation_check_fail(self): - with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): + with self.assertRaisesRegex( + RuntimeError, "Argument input is not defined as mutable but was mutated" + ): x = torch.rand((3, 3)) y = torch.rand((3, 3)) with SchemaCheckMode(): @@ -360,7 +362,9 @@ def test_mutation_check_fail(self): # # Tests that an exception is raised for a mismatching mutation over multiple ops def test_mutation_check_fail_multiple_operators(self): - with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"): + with self.assertRaisesRegex( + RuntimeError, "Argument input is not defined as mutable but was mutated" + ): x = torch.rand((3, 3)) y = torch.rand((3, 3)) with SchemaCheckMode(): @@ -368,7 +372,10 @@ def test_mutation_check_fail_multiple_operators(self): # Tests that an exception is raised for a mismatching alias def test_alias_check_fail_simple(self): - with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): + with self.assertRaisesRegex( + RuntimeError, + "Argument input is not defined to alias output but was aliasing", + ): x = torch.rand((3, 3), requires_grad=True) y = torch.rand((3, 3)) with SchemaCheckMode(): @@ -376,19 +383,29 @@ def test_alias_check_fail_simple(self): # Tests that an exception is raised for a mismatching alias over multiple ops def test_alias_check_fail_multiple_operators(self): - with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): + with self.assertRaisesRegex( + RuntimeError, + "Argument input is not defined to alias output but was aliasing", + ): x = torch.rand((3, 3), requires_grad=True) y = torch.zeros((3, 3), requires_grad=True) with SchemaCheckMode(): - IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2) + IncorrectAliasTensor(x).sin().relu().add( + IncorrectAliasTensor(y), alpha=2 + ) # Tests that an exception is raised for a centered mismatching alias over multiple ops def test_alias_check_fail_multiple_operators_centered(self): - with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"): + with self.assertRaisesRegex( + RuntimeError, + "Argument input is not defined to alias output but was aliasing", + ): x = torch.rand((3, 3), requires_grad=True) y = torch.zeros((3, 3), requires_grad=True) with SchemaCheckMode(): - IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu() + IncorrectAliasTensor(x).sin().add( + IncorrectAliasTensor(y), alpha=2 + ).relu() # Tests that an exception is raised for a centered mismatching alias over multiple ops def test_alias_check_fail_outputs_unexpectedly_aliasing(self): @@ -405,7 +422,9 @@ def f(x): return torch.ops.bad_schemas.secretly_aliasing(x) x = torch.rand((3, 3)) - with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"): + with self.assertRaisesRegex( + RuntimeError, "not defined to alias output but was aliasing" + ): with SchemaCheckMode() as s: out = f(x) @@ -414,7 +433,9 @@ def f(x): return torch.ops.bad_schemas.secretly_mutating(x) x = torch.rand((3, 3)) - with self.assertRaisesRegex(RuntimeError, 
"not defined as mutable but was mutated"): + with self.assertRaisesRegex( + RuntimeError, "not defined as mutable but was mutated" + ): with SchemaCheckMode() as s: out = f(x) @@ -423,7 +444,9 @@ def f(x): return torch.ops.bad_schemas.output_is_input(x) x = torch.rand((3, 3)) - with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"): + with self.assertRaisesRegex( + RuntimeError, "are not allowed to directly return inputs" + ): with SchemaCheckMode() as s: out = f(x) @@ -468,47 +491,64 @@ def __init__(self, test_self): def __torch_dispatch__(self, func, types, args=(), kwargs=None): named_arg_list = normalize_function( - func, - args, - kwargs, - normalize_to_only_use_kwargs=True + func, args, kwargs, normalize_to_only_use_kwargs=True ).kwargs schema_info_value_test = torch._C._SchemaInfo(func._schema) schema_info_values_test = torch._C._SchemaInfo(func._schema) - self.test_self.assertFalse(schema_info_value_test.may_alias( - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) - self.test_self.assertFalse(schema_info_values_test.may_alias( - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + self.test_self.assertFalse( + schema_info_value_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1), + ) + ) + self.test_self.assertFalse( + schema_info_values_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1), + ) + ) for i in named_arg_list: schema_info_value_test.add_argument_value(i, named_arg_list[i]) schema_info_values_test.add_argument_values(named_arg_list) - self.test_self.assertTrue(schema_info_value_test.may_alias( - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) - self.test_self.assertTrue(schema_info_values_test.may_alias( - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), - torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1))) + self.test_self.assertTrue( + schema_info_value_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1), + ) + ) + self.test_self.assertTrue( + schema_info_values_test.may_alias( + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0), + torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1), + ) + ) return func(*args, **kwargs) + x = torch.rand((3, 3)) with SchemaInfoBindTestMode(self) as schemaInfoCheck: x.add(x) + class TestSchemaCheckModeOpInfo(JitTestCase): @ops(op_db, dtypes=OpDTypes.supported) @slowTestIf(IS_WINDOWS) def test_schema_correctness(self, device, dtype, op): # Currently torch.equal isn't supported with torch.complex32 # There's also errors with complex64 and complex128 - if (dtype == torch.complex32): + if dtype == torch.complex32: return for sample in op.sample_inputs(device, dtype, requires_grad=False): with SchemaCheckMode(): op(sample.input, *sample.args, **sample.kwargs) -instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda", "xpu"), allow_xpu=True) -if __name__ == '__main__': +instantiate_device_type_tests( + TestSchemaCheckModeOpInfo, + globals(), + only_for=("cpu", "cuda", "xpu"), + allow_xpu=True, +) + +if __name__ == "__main__": run_tests() diff --git 
a/test/xpu/test_utils.py b/test/xpu/test_utils.py index 03db402568..b6efc8db8a 100644 --- a/test/xpu/test_utils.py +++ b/test/xpu/test_utils.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs # Owner(s): ["module: unknown"] -from torch.utils.collect_env import get_pretty_env_info -from torch.testing._internal.common_utils import run_tests, TestCase import os import random import shutil @@ -20,6 +18,12 @@ import torch.nn as nn import torch.utils.cpp_extension import torch.utils.data +from checkpoint import ( + _infer_device_type, + checkpoint, + checkpoint_sequential, + get_device_states, +) from torch._utils import try_import from torch._utils_internal import deprecated from torch.testing._internal.common_cuda import TEST_MULTIGPU @@ -34,6 +38,8 @@ IS_SANDCASTLE, IS_WINDOWS, load_tests, + run_tests, + TestCase, ) from torch.testing._internal.inductor_utils import get_gpu_type from torch.utils._device import set_device @@ -43,15 +49,9 @@ format_traceback_short, report_compile_source_on_error, ) -from checkpoint import ( - _infer_device_type, - checkpoint, - checkpoint_sequential, - get_device_states, -) +from torch.utils.collect_env import get_pretty_env_info from torch.utils.data import DataLoader - # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. This line silences flake warnings load_tests = load_tests # noqa: PLW0127 @@ -148,7 +148,7 @@ def forward(self, input_var): out.sum().backward() for m in modules[: (len(modules) // 2)]: self.assertEqual(m.counter, 2) - for m in modules[(len(modules) // 2):]: + for m in modules[(len(modules) // 2) :]: self.assertEqual(m.counter, 1) def test_checkpoint_valid(self): @@ -453,7 +453,9 @@ def track(x, idx): def hook(_unused): self.assertEqual(len(stats), idx) torch.get_device_module(device_type).synchronize() - stats.append(torch.get_device_module(device_type).memory_allocated()) + stats.append( + torch.get_device_module(device_type).memory_allocated() + ) if idx > 0: if should_free: self.assertLess(stats[idx], stats[idx - 1]) @@ -878,12 +880,16 @@ def test_get_default_device_more(self): torch.set_default_device(f"{device_type}:1") with torch.device(f"{device_type}:0"): - self.assertEqual(torch.get_default_device(), torch.device(f"{device_type}", 0)) + self.assertEqual( + torch.get_default_device(), torch.device(f"{device_type}", 0) + ) torch.set_default_device("cpu") self.assertEqual(torch.get_default_device(), torch.device("cpu")) with torch.device(f"{device_type}:0"): - self.assertEqual(torch.get_default_device(), torch.device(f"{device_type}", 0)) + self.assertEqual( + torch.get_default_device(), torch.device(f"{device_type}", 0) + ) self.assertEqual(torch.get_default_device(), torch.device("cpu")) finally: From c0dc3649b3a274f57d330b4da7bf616b773390ee Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Wed, 5 Nov 2025 14:33:55 +0800 Subject: [PATCH 06/12] Add test cases into skip list common file to trigger the test in CI --- test/xpu/skip_list_common.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 866a75d36a..16fa5cc212 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -96,6 +96,8 @@ "test_scaled_dot_product_attention_3D_input_dim_2D_attn_mask_dropout_p_0_5_xpu", "test_scaled_dot_product_attention_3D_input_dim_2D_attn_mask_dropout_p_0_2_xpu", ), + "test_utils.py": None, + "test_schema_check.py": None, "test_complex_xpu.py": None, "test_modules_xpu.py": ( # oneDNN issues From 
4a1c7700da8d935e968992b1cf28637bfa5d98ed Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Thu, 6 Nov 2025 11:07:02 +0800 Subject: [PATCH 07/12] Add skip failed cases for test schema correctness which not relative to the changes --- test/xpu/skip_list_common.py | 38 +++++++++++++++++++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 16fa5cc212..5f45d8efee 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -97,7 +97,43 @@ "test_scaled_dot_product_attention_3D_input_dim_2D_attn_mask_dropout_p_0_2_xpu", ), "test_utils.py": None, - "test_schema_check.py": None, + "test_schema_check.py": ( + # Skipped failed cases for third_party.torch-xpu-ops.test.xpu.test_schema_check.TestSchemaCheckModeOpInfoXPU + "test_schema_correctness_max_reduction_no_dim_xpu_uint16", + "test_schema_correctness_max_reduction_no_dim_xpu_uint32", + "test_schema_correctness_max_reduction_no_dim_xpu_uint64", + "test_schema_correctness_max_reduction_with_dim_xpu_uint16", + "test_schema_correctness_max_reduction_with_dim_xpu_uint32", + "test_schema_correctness_max_reduction_with_dim_xpu_uint64", + "test_schema_correctness_min_reduction_no_dim_xpu_uint16", + "test_schema_correctness_min_reduction_no_dim_xpu_uint32", + "test_schema_correctness_min_reduction_no_dim_xpu_uint64", + "test_schema_correctness_min_reduction_with_dim_xpu_uint16", + "test_schema_correctness_min_reduction_with_dim_xpu_uint32", + "test_schema_correctness_min_reduction_with_dim_xpu_uint64", + "test_schema_correctness_amax_xpu_uint16", + "test_schema_correctness_amax_xpu_uint32", + "test_schema_correctness_amax_xpu_uint64", + "test_schema_correctness_amin_xpu_uint16", + "test_schema_correctness_amin_xpu_uint32", + "test_schema_correctness_amin_xpu_uint64", + "test_schema_correctness_aminmax_xpu_uint16", + "test_schema_correctness_aminmax_xpu_uint32", + "test_schema_correctness_aminmax_xpu_uint64", + "test_schema_correctness_nn_functional_conv_transpose2d_xpu_bfloat16", + "test_schema_correctness_nn_functional_conv_transpose2d_xpu_complex128", + "test_schema_correctness_nn_functional_conv_transpose2d_xpu_complex64", + "test_schema_correctness_nn_functional_conv_transpose2d_xpu_float16", + "test_schema_correctness_nn_functional_conv_transpose2d_xpu_float32", + "test_schema_correctness_nn_functional_conv_transpose2d_xpu_float64", + "test_schema_correctness_nn_functional_conv_transpose3d_xpu_bfloat16", + "test_schema_correctness_nn_functional_conv_transpose3d_xpu_complex128", + "test_schema_correctness_nn_functional_conv_transpose3d_xpu_complex64", + "test_schema_correctness_nn_functional_conv_transpose3d_xpu_float16", + "test_schema_correctness_nn_functional_conv_transpose3d_xpu_float32", + "test_schema_correctness_nn_functional_conv_transpose3d_xpu_float64", + "test_schema_correctness_torch_ops_aten__flash_attention_forward_xpu_float16", + ), "test_complex_xpu.py": None, "test_modules_xpu.py": ( # oneDNN issues From d97f7072dfc3dc2bc8eca97f2d617b3acee023c0 Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Thu, 6 Nov 2025 11:07:44 +0800 Subject: [PATCH 08/12] Fix intend --- test/xpu/skip_list_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 5f45d8efee..85bf466caa 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -133,7 +133,7 @@ "test_schema_correctness_nn_functional_conv_transpose3d_xpu_float32", 
"test_schema_correctness_nn_functional_conv_transpose3d_xpu_float64", "test_schema_correctness_torch_ops_aten__flash_attention_forward_xpu_float16", - ), + ), "test_complex_xpu.py": None, "test_modules_xpu.py": ( # oneDNN issues From 5fe36a89d760e717cb016ed061275220bc5edea1 Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Mon, 10 Nov 2025 11:47:30 +0800 Subject: [PATCH 09/12] Add relative directory import for checkpoint --- test/xpu/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/xpu/test_utils.py b/test/xpu/test_utils.py index b6efc8db8a..e3108ebd3e 100644 --- a/test/xpu/test_utils.py +++ b/test/xpu/test_utils.py @@ -18,7 +18,7 @@ import torch.nn as nn import torch.utils.cpp_extension import torch.utils.data -from checkpoint import ( +from .checkpoint import ( _infer_device_type, checkpoint, checkpoint_sequential, From 7ea55db6272c948ebc6574bacdfcec7b41e02a39 Mon Sep 17 00:00:00 2001 From: shangerxin Date: Mon, 10 Nov 2025 03:51:26 +0000 Subject: [PATCH 10/12] Fix lint issue --- test/xpu/test_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/xpu/test_utils.py b/test/xpu/test_utils.py index e3108ebd3e..420504954e 100644 --- a/test/xpu/test_utils.py +++ b/test/xpu/test_utils.py @@ -18,12 +18,6 @@ import torch.nn as nn import torch.utils.cpp_extension import torch.utils.data -from .checkpoint import ( - _infer_device_type, - checkpoint, - checkpoint_sequential, - get_device_states, -) from torch._utils import try_import from torch._utils_internal import deprecated from torch.testing._internal.common_cuda import TEST_MULTIGPU @@ -52,6 +46,13 @@ from torch.utils.collect_env import get_pretty_env_info from torch.utils.data import DataLoader +from .checkpoint import ( + _infer_device_type, + checkpoint, + checkpoint_sequential, + get_device_states, +) + # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for # sharding on sandcastle. 
This line silences flake warnings load_tests = load_tests # noqa: PLW0127 From 4d295ac1995446b3e5e3f5c4d22173b76f18e43e Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Thu, 13 Nov 2025 14:30:23 +0800 Subject: [PATCH 11/12] Remove the cases from skip list which are already added into the issue --- test/xpu/skip_list_common.py | 38 +----------------------------------- 1 file changed, 1 insertion(+), 37 deletions(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 85bf466caa..16fa5cc212 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -97,43 +97,7 @@ "test_scaled_dot_product_attention_3D_input_dim_2D_attn_mask_dropout_p_0_2_xpu", ), "test_utils.py": None, - "test_schema_check.py": ( - # Skipped failed cases for third_party.torch-xpu-ops.test.xpu.test_schema_check.TestSchemaCheckModeOpInfoXPU - "test_schema_correctness_max_reduction_no_dim_xpu_uint16", - "test_schema_correctness_max_reduction_no_dim_xpu_uint32", - "test_schema_correctness_max_reduction_no_dim_xpu_uint64", - "test_schema_correctness_max_reduction_with_dim_xpu_uint16", - "test_schema_correctness_max_reduction_with_dim_xpu_uint32", - "test_schema_correctness_max_reduction_with_dim_xpu_uint64", - "test_schema_correctness_min_reduction_no_dim_xpu_uint16", - "test_schema_correctness_min_reduction_no_dim_xpu_uint32", - "test_schema_correctness_min_reduction_no_dim_xpu_uint64", - "test_schema_correctness_min_reduction_with_dim_xpu_uint16", - "test_schema_correctness_min_reduction_with_dim_xpu_uint32", - "test_schema_correctness_min_reduction_with_dim_xpu_uint64", - "test_schema_correctness_amax_xpu_uint16", - "test_schema_correctness_amax_xpu_uint32", - "test_schema_correctness_amax_xpu_uint64", - "test_schema_correctness_amin_xpu_uint16", - "test_schema_correctness_amin_xpu_uint32", - "test_schema_correctness_amin_xpu_uint64", - "test_schema_correctness_aminmax_xpu_uint16", - "test_schema_correctness_aminmax_xpu_uint32", - "test_schema_correctness_aminmax_xpu_uint64", - "test_schema_correctness_nn_functional_conv_transpose2d_xpu_bfloat16", - "test_schema_correctness_nn_functional_conv_transpose2d_xpu_complex128", - "test_schema_correctness_nn_functional_conv_transpose2d_xpu_complex64", - "test_schema_correctness_nn_functional_conv_transpose2d_xpu_float16", - "test_schema_correctness_nn_functional_conv_transpose2d_xpu_float32", - "test_schema_correctness_nn_functional_conv_transpose2d_xpu_float64", - "test_schema_correctness_nn_functional_conv_transpose3d_xpu_bfloat16", - "test_schema_correctness_nn_functional_conv_transpose3d_xpu_complex128", - "test_schema_correctness_nn_functional_conv_transpose3d_xpu_complex64", - "test_schema_correctness_nn_functional_conv_transpose3d_xpu_float16", - "test_schema_correctness_nn_functional_conv_transpose3d_xpu_float32", - "test_schema_correctness_nn_functional_conv_transpose3d_xpu_float64", - "test_schema_correctness_torch_ops_aten__flash_attention_forward_xpu_float16", - ), + "test_schema_check.py": None, "test_complex_xpu.py": None, "test_modules_xpu.py": ( # oneDNN issues From 147e81affa2333355e1e97764d37d71ef8057885 Mon Sep 17 00:00:00 2001 From: Erxin Shang Date: Tue, 18 Nov 2025 11:06:33 +0800 Subject: [PATCH 12/12] Remove test_utils relative changes out of xpu which are already meregd into torch --- test/xpu/checkpoint.py | 1704 ---------------------------------- test/xpu/skip_list_common.py | 1 - test/xpu/test_utils.py | 1025 -------------------- 3 files changed, 2730 deletions(-) delete mode 100644 test/xpu/checkpoint.py delete 
mode 100644 test/xpu/test_utils.py diff --git a/test/xpu/checkpoint.py b/test/xpu/checkpoint.py deleted file mode 100644 index 705a5b64e8..0000000000 --- a/test/xpu/checkpoint.py +++ /dev/null @@ -1,1704 +0,0 @@ -# mypy: allow-untyped-defs -import contextlib -import enum -import platform -import uuid -import warnings -import weakref -from collections import defaultdict -from typing import * # noqa: F403 -from weakref import ReferenceType - -import torch -import torch.fx.traceback as fx_traceback -from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode -from torch.utils._python_dispatch import TorchDispatchMode -from torch.utils._pytree import tree_map - -__all__ = [ - "checkpoint", - "checkpoint_sequential", - "CheckpointError", - "CheckpointFunction", - "check_backward_validity", - "detach_variable", - "get_device_states", - "set_device_states", - "noop_context_fn", - "set_checkpoint_early_stop", - "DefaultDeviceType", - "set_checkpoint_debug_enabled", - "CheckpointPolicy", - "SelectiveCheckpointContext", - "create_selective_checkpoint_contexts", - "SAC_IGNORED_OPS", -] - -_DEFAULT_DETERMINISM_MODE = "default" - -_checkpoint_debug_enabled: Optional[bool] = None - - -@contextlib.contextmanager -def set_checkpoint_debug_enabled(enabled: Optional[bool]): - """ - Context manager that sets whether checkpoint should print additional debug - information when running. See the ``debug`` flag for - :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that - when set, this context manager overrides the value of ``debug`` passed to - checkpoint. To defer to the local setting, pass ``None`` to this context. - - Args: - enabled (bool): Whether checkpoint should print debug information. - Default is 'None'. - """ - global _checkpoint_debug_enabled - try: - prev = _checkpoint_debug_enabled - _checkpoint_debug_enabled = enabled - yield - finally: - _checkpoint_debug_enabled = prev - - -def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: - if isinstance(inputs, tuple): - out = [] - for inp in inputs: - if not isinstance(inp, torch.Tensor): - out.append(inp) - continue - - x = inp.detach() - x.requires_grad = inp.requires_grad - out.append(x) - return tuple(out) - else: - raise RuntimeError( - "Only tuple of tensors is supported. Got Unsupported input type: ", - type(inputs).__name__, - ) - - -def check_backward_validity(inputs: Iterable[Any]) -> None: - if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)): - warnings.warn( - "None of the inputs have requires_grad=True. Gradients will be None", - stacklevel=2, - ) - - -def _get_device_module(device="cuda"): - if device == "meta": - return torch.device("meta") - device_module = getattr(torch, device) - return device_module - - -class DefaultDeviceType: - r""" - A class that manages the default device type for checkpointing. - - If no non-CPU tensors are present, the default device type will - be used. The default value is 'cuda'. The device type is used in - the checkpointing process when determining which device states - to save and restore for recomputation. - """ - - _default_device_type = ( - acc.type if (acc := torch.accelerator.current_accelerator(True)) else "cpu" - ) - - @staticmethod - def set_device_type(device: str = "cuda"): - """ - Set the default device type for checkpointing. - - Args: - device (str): The device type to be set as default. Default is 'cuda'. 
- """ - DefaultDeviceType._default_device_type = device - - @staticmethod - def get_device_type() -> str: - """ - Get the current default device type for checkpointing. - - Returns: - str: The current default device type. - """ - return DefaultDeviceType._default_device_type - - -def _infer_device_type(*args): - device_types = [] - - def add_device_types(arg): - nonlocal device_types - if isinstance(arg, torch.Tensor) and arg.device.type != "cpu": - device_types.append(arg.device.type) - - tree_map(add_device_types, args) - - device_types_set = set(device_types) - if len(device_types_set) > 1: - warnings.warn( - "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. " - "Device state will only be saved for devices of a single device type, and the remaining " - "devices will be ignored. Consequently, if any checkpointed functions involve randomness, " - "this may result in incorrect gradients. (Note that if CUDA devices are among the devices " - "detected, it will be prioritized; otherwise, the first device encountered will be selected.)" - f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}", - stacklevel=2, - ) - if len(device_types) == 0: - return DefaultDeviceType.get_device_type() - elif "cuda" in device_types_set: - return "cuda" - else: - return device_types[0] - - -# We can't know if the run_fn will internally move some args to different devices, -# which would require logic to preserve rng states for those devices as well. -# We could paranoically stash and restore ALL the rng states for all visible devices, -# but that seems very wasteful for most cases. Compromise: Stash the RNG state for -# the device of all Tensor args. -# -# To consider: maybe get_device_states and set_device_states should reside in torch/random.py? -def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]: - # This will not error out if "arg" is a CPU tensor or a non-tensor type because - # the conditionals short-circuit. - fwd_device_ids = [] - - def add_device_ids(arg): - nonlocal fwd_device_ids - if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}: - fwd_device_ids.append(arg.get_device()) - - tree_map(add_device_ids, args) - - fwd_device_states = [] - device_module = _get_device_module(_infer_device_type(*args)) - for device_id in fwd_device_ids: - with device_module.device(device_id): - fwd_device_states.append(device_module.get_rng_state()) - - return fwd_device_ids, fwd_device_states - - -def set_device_states(devices, states, *, device_type=None) -> None: - """Sets random number generator states for the specified devices. - - Args: - devices: Device ids to set states for. - states: States to set. - device_type: ``device_type`` of the devices to set states for. Default - is the device returned by a call to ``DefaultDeviceType.get_device_type()``, - which is ``cuda`` if not changed by calling ``DefaultDeviceType::set_device_type()``. 
- """ - if device_type is None: - device_type = DefaultDeviceType.get_device_type() - if device_type == "meta": - return - device_module = _get_device_module(device_type) - for device, state in zip(devices, states): - with device_module.device(device): - device_module.set_rng_state(state) - - -def _get_autocast_kwargs(device_type="cuda"): - if torch.amp.is_autocast_available(device_type): - device_autocast_kwargs = { - "enabled": torch.is_autocast_enabled(device_type), - "dtype": torch.get_autocast_dtype(device_type), - "cache_enabled": torch.is_autocast_cache_enabled(), - } - else: - device_autocast_kwargs = None - - cpu_autocast_kwargs = { - "enabled": torch.is_autocast_enabled("cpu"), - "dtype": torch.get_autocast_dtype("cpu"), - "cache_enabled": torch.is_autocast_cache_enabled(), - } - - return device_autocast_kwargs, cpu_autocast_kwargs - - -class CheckpointFunction(torch.autograd.Function): - @staticmethod - # pyrefly: ignore [bad-override] - def forward(ctx, run_function, preserve_rng_state, *args): - check_backward_validity(args) - ctx.run_function = run_function - ctx.preserve_rng_state = preserve_rng_state - # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. - ctx.device_type = _infer_device_type(*args) - ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs( - ctx.device_type - ) - if preserve_rng_state: - ctx.fwd_cpu_state = torch.get_rng_state() - # Don't eagerly initialize the cuda context by accident. - # (If the user intends that the context is initialized later, within their - # run_function, we SHOULD actually stash the cuda state here. Unfortunately, - # we have no way to anticipate this will happen before we run the function.) - ctx.had_device_in_fwd = False - device_module = _get_device_module(ctx.device_type) - if getattr(device_module, "_initialized", False): - ctx.had_device_in_fwd = True - ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args) - - # Save non-tensor inputs in ctx, keep a placeholder None for tensors - # to be filled out during the backward. - ctx.inputs = [] - ctx.tensor_indices = [] - tensor_inputs = [] - for i, arg in enumerate(args): - if torch.is_tensor(arg): - tensor_inputs.append(arg) - ctx.tensor_indices.append(i) - ctx.inputs.append(None) - else: - ctx.inputs.append(arg) - - ctx.save_for_backward(*tensor_inputs) - - with torch.no_grad(): - outputs = run_function(*args) - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError( - "When use_reentrant=True, torch.utils.checkpoint is incompatible" - " with .grad() or passing an `inputs` parameter to .backward()." - " To resolve this error, you can either set use_reentrant=False," - " or call .backward() without passing the `inputs` argument." - ) - # Copy the list to avoid modifying original list. - inputs = list(ctx.inputs) - tensor_indices = ctx.tensor_indices - tensors = ctx.saved_tensors - - # Fill in inputs with appropriate saved tensors. - for i, idx in enumerate(tensor_indices): - inputs[idx] = tensors[i] - - # Stash the surrounding rng state, and mimic the state that was - # present at this time during forward. Restore the surrounding state - # when we're done. 
- rng_devices = [] - if ctx.preserve_rng_state and ctx.had_device_in_fwd: - rng_devices = ctx.fwd_devices - with torch.random.fork_rng( - devices=rng_devices, - enabled=ctx.preserve_rng_state, - device_type=ctx.device_type, - ): - if ctx.preserve_rng_state: - torch.set_rng_state(ctx.fwd_cpu_state) - if ctx.had_device_in_fwd: - set_device_states( - ctx.fwd_devices, - ctx.fwd_device_states, - device_type=ctx.device_type, - ) - detached_inputs = detach_variable(tuple(inputs)) - - device_autocast_ctx = ( - torch.amp.autocast( - device_type=ctx.device_type, **ctx.device_autocast_kwargs - ) - if torch.amp.is_autocast_available(ctx.device_type) - else contextlib.nullcontext() - ) - # type: ignore[attr-defined] - with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast( - "cpu", **ctx.cpu_autocast_kwargs - ): - outputs = ctx.run_function(*detached_inputs) - - if isinstance(outputs, torch.Tensor): - outputs = (outputs,) - - # run backward() with only tensor that requires grad - outputs_with_grad = [] - args_with_grad = [] - for i in range(len(outputs)): - if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: - outputs_with_grad.append(outputs[i]) - args_with_grad.append(args[i]) - if len(outputs_with_grad) == 0: - raise RuntimeError( - "none of output has requires_grad=True," - " this checkpoint() is not necessary" - ) - torch.autograd.backward(outputs_with_grad, args_with_grad) - grads = tuple( - inp.grad if isinstance(inp, torch.Tensor) else None - for inp in detached_inputs - ) - - return (None, None) + grads - - -def noop_context_fn(): - return contextlib.nullcontext(), contextlib.nullcontext() - - -# Note: [torch.compile and checkpoint] -# TorchDynamo does not step inside utils.checkpoint function. The flow -# looks likes this -# 1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by -# speculatively checking if the forward function is safe to trace. -# 2) If yes, then Dynamo-generated Fx graph has the wrapped higher -# order op. As a result, TorchDynamo does not look inside utils.checkpoint. -# 3) If not, then TorchDynamo falls back to eager by performing a graph -# break. And here, the following disable wrapper ensures that -# TorchDynamo does not trigger again on the frames created by -# utils.checkpoint innards. - - -@torch._disable_dynamo -def checkpoint( - function, - *args, - use_reentrant: Optional[bool] = None, - context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, - determinism_check: str = _DEFAULT_DETERMINISM_MODE, - debug: bool = False, - early_stop: bool = True, - **kwargs, -): - r"""Checkpoint a model or part of the model. - - Activation checkpointing is a technique that trades compute for memory. - Instead of keeping tensors needed for backward alive until they are used in - gradient computation during backward, forward computation in checkpointed - regions omits saving tensors for backward and recomputes them during the - backward pass. Activation checkpointing can be applied to any part of a - model. - - There are currently two checkpointing implementations available, determined - by the :attr:`use_reentrant` parameter. It is recommended that you use - ``use_reentrant=False``. Please refer the note below for a discussion of - their differences. - - .. 
warning:: - - If the :attr:`function` invocation during the backward pass differs - from the forward pass, e.g., due to a global variable, the checkpointed - version may not be equivalent, potentially causing an - error being raised or leading to silently incorrect gradients. - - .. warning:: - - The ``use_reentrant`` parameter should be passed explicitly. In version - 2.9 we will raise an exception if ``use_reentrant`` is not passed. - If you are using the ``use_reentrant=True`` variant, please refer to the - note below for important considerations and potential limitations. - - .. note:: - - The reentrant variant of checkpoint (``use_reentrant=True``) and - the non-reentrant variant of checkpoint (``use_reentrant=False``) - differ in the following ways: - - * Non-reentrant checkpoint stops recomputation as soon as all needed - intermediate activations have been recomputed. This feature is enabled - by default, but can be disabled with :func:`set_checkpoint_early_stop`. - Reentrant checkpoint always recomputes :attr:`function` in its - entirety during the backward pass. - - * The reentrant variant does not record the autograd graph during the - forward pass, as it runs with the forward pass under - :func:`torch.no_grad`. The non-reentrant version does record the - autograd graph, allowing one to perform backward on the graph within - checkpointed regions. - - * The reentrant checkpoint only supports the - :func:`torch.autograd.backward` API for the backward pass without its - `inputs` argument, while the non-reentrant version supports all ways - of performing the backward pass. - - * At least one input and output must have ``requires_grad=True`` for the - reentrant variant. If this condition is unmet, the checkpointed part - of the model will not have gradients. The non-reentrant version does - not have this requirement. - - * The reentrant version does not consider tensors in nested structures - (e.g., custom objects, lists, dicts, etc) as participating in - autograd, while the non-reentrant version does. - - * The reentrant checkpoint does not support checkpointed regions with - detached tensors from the computational graph, whereas the - non-reentrant version does. For the reentrant variant, if the - checkpointed segment contains tensors detached using ``detach()`` or - with :func:`torch.no_grad`, the backward pass will raise an error. - This is because ``checkpoint`` makes all the outputs require gradients - and this causes issues when a tensor is defined to have no gradient in - the model. To avoid this, detach the tensors outside of the - ``checkpoint`` function. - - Args: - function: describes what to run in the forward pass of the model or - part of the model. It should also know how to handle the inputs - passed as the tuple. For example, in LSTM, if user passes - ``(activation, hidden)``, :attr:`function` should correctly use the - first input as ``activation`` and the second input as ``hidden`` - args: tuple containing inputs to the :attr:`function` - - Keyword args: - preserve_rng_state(bool, optional): Omit stashing and restoring - the RNG state during each checkpoint. Note that under torch.compile, - this flag doesn't take effect and we always preserve RNG state. - Default: ``True`` - use_reentrant(bool): - specify whether to use the activation checkpoint variant that - requires reentrant autograd. This parameter should be passed - explicitly. In version 2.9 we will raise an exception if - ``use_reentrant`` is not passed. 
If ``use_reentrant=False``, - ``checkpoint`` will use an implementation that does not require - reentrant autograd. This allows ``checkpoint`` to support additional - functionality, such as working as expected with - ``torch.autograd.grad`` and support for keyword arguments input into - the checkpointed function. - context_fn(Callable, optional): A callable returning a tuple of two - context managers. The function and its recomputation will be run - under the first and second context managers respectively. - This argument is only supported if ``use_reentrant=False``. - determinism_check(str, optional): A string specifying the determinism - check to perform. By default it is set to ``"default"`` which - compares the shapes, dtypes, and devices of the recomputed tensors - against those the saved tensors. To turn off this check, specify - ``"none"``. Currently these are the only two supported values. - Please open an issue if you would like to see more determinism - checks. This argument is only supported if ``use_reentrant=False``, - if ``use_reentrant=True``, the determinism check is always disabled. - debug(bool, optional): If ``True``, error messages will also include - a trace of the operators ran during the original forward computation - as well as the recomputation. This argument is only supported if - ``use_reentrant=False``. - early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops - recomputation as soon as it has computed all needed Tensors. This - argument is ignored if ``use_reentrant=True``. Can be overridden - globally using :func:`set_checkpoint_early_stop` context manager. - Default: ``True``. - - Returns: - Output of running :attr:`function` on :attr:`*args` - """ - if use_reentrant is None: - warnings.warn( - "torch.utils.checkpoint: the use_reentrant parameter should be " - "passed explicitly. Starting in PyTorch 2.9, calling checkpoint " - "without use_reentrant will raise an exception. use_reentrant=False is " - "recommended, but if you need to preserve the current default " - "behavior, you can pass use_reentrant=True. Refer to docs for more " - "details on the differences between the two variants.", - stacklevel=2, - ) - use_reentrant = True - - # Hack to mix *args with **kwargs in a python 2.7-compliant way - preserve = kwargs.pop("preserve_rng_state", True) - if kwargs and use_reentrant: - raise ValueError( - "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) - ) - - if use_reentrant: - if context_fn is not noop_context_fn or debug is not False: - raise ValueError( - "Passing `context_fn` or `debug` is only supported when " - "use_reentrant=False." - ) - return CheckpointFunction.apply(function, preserve, *args) - else: - gen = _checkpoint_without_reentrant_generator( - function, - preserve, - context_fn, - determinism_check, - debug, - early_stop, - *args, - **kwargs, - ) - # Runs pre-forward logic - next(gen) - ret = function(*args, **kwargs) - # Runs post-forward logic - try: - next(gen) - except StopIteration: - return ret - - -def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs): - r"""Checkpoint a sequential model to save memory. - - Sequential models execute a list of modules/functions in order - (sequentially). Therefore, we can divide such a model in various segments - and checkpoint each segment. All segments except the last will not store - the intermediate activations. The inputs of each checkpointed segment will - be saved for re-running the segment in the backward pass. - - .. 
warning:: - The ``use_reentrant`` parameter should be passed explicitly. In version - 2.9 we will raise an exception if ``use_reentrant`` is not passed. - If you are using the ``use_reentrant=True` variant, please see - :func:`~torch.utils.checkpoint.checkpoint` for - the important considerations and limitations of this variant. It is - recommended that you use ``use_reentrant=False``. - - .. warning: - Since PyTorch 1.4, it allows only one Tensor as the input and - intermediate outputs, just like :class:`torch.nn.Sequential`. - - Args: - functions: A :class:`torch.nn.Sequential` or the list of modules or - functions (comprising the model) to run sequentially. - segments: Number of chunks to create in the model - input: A Tensor that is input to :attr:`functions` - preserve_rng_state(bool, optional): Omit stashing and restoring - the RNG state during each checkpoint. - Default: ``True`` - use_reentrant(bool): - specify whether to use the activation checkpoint variant that - requires reentrant autograd. This parameter should be passed - explicitly. In version 2.5 we will raise an exception if - ``use_reentrant`` is not passed. If ``use_reentrant=False``, - ``checkpoint`` will use an implementation that does not require - reentrant autograd. This allows ``checkpoint`` to support additional - functionality, such as working as expected with - ``torch.autograd.grad`` and support for keyword arguments input into - the checkpointed function. - - Returns: - Output of running :attr:`functions` sequentially on :attr:`*inputs` - - Example: - >>> # xdoctest: +SKIP("stub") - >>> model = nn.Sequential(...) - >>> input_var = checkpoint_sequential(model, chunks, input_var) - """ - if use_reentrant is None: - warnings.warn( - "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant " - "parameter should be passed explicitly. " - "In version 2.9 we will raise an exception if use_reentrant " - "is not passed. use_reentrant=False is " - "recommended, but if you need to preserve the current default " - "behavior, you can pass use_reentrant=True. Refer to docs for more " - "details on the differences between the two variants.", - stacklevel=2, - ) - use_reentrant = True - - # Hack for keyword-only parameter in a python 2.7-compliant way - preserve = kwargs.pop("preserve_rng_state", True) - if kwargs: - raise ValueError( - "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) - ) - - def run_function(start, end, functions): - def forward(input): - for j in range(start, end + 1): - input = functions[j](input) - return input - - return forward - - if isinstance(functions, torch.nn.Sequential): - functions = list(functions.children()) - - segment_size = len(functions) // segments - # the last chunk has to be non-volatile - end = -1 - for start in range(0, segment_size * (segments - 1), segment_size): - end = start + segment_size - 1 - input = checkpoint( - run_function(start, end, functions), - input, - use_reentrant=use_reentrant, - preserve_rng_state=preserve, - ) - return run_function(end + 1, len(functions) - 1, functions)(input) - - -def _internal_assert(cond): - if not cond: - raise AssertionError( - "Something went unexpectedly wrong in activation checkpoint. " - "Please report this bug by filing an issue to PyTorch." - ) - - -# NOTE [ Nestable Checkpoint ] -# -# The semantics of nested checkpoint can be defined by two basic rules. -# Following the two rules leads to an important implication that is central -# to motivating the design. -# -# Rule 1. 
Saved tensors are managed by inner-most checkpoint only and hidden -# from any outer layers of checkpoint. -# -# Rule 2. The inputs of inner checkpoints are treated as tensors saved to its -# parent checkpoint. -# -# Implication: To recompute any given saved tensor, we need to recompute all of -# the checkpoints wrapping it. -# -# Why is this implied? To unpack a saved tensor X during backward we need to -# recompute the inner-most checkpoint (#1), and in order to recompute that -# checkpoint I need to have its inputs, which are managed by that checkpoint's -# parent (#2), which thus also needs to be recomputed first. Continue this line -# of reasoning and we realize that in order to unpack X, all checkpoints that -# were active at the time X was saved need to be recomputed. (unless we have -# already done so in that backward for some other saved tensor). -# -# In practice, we use a noop autograd Function to save inputs as saved tensors. -# During unpack calling ctx.saved_tensor triggers the parent checkpoint to -# recompute. -# -# Rule 3. We should start recomputation as if there are no checkpoints currently -# active. Checkpoints encountered during recomputation are still -# respected. -# -# When we start recomputation, we push the saved variable hook meant for -# recomputation on the stack. See examples in Rule 6 for more context. -# -# * * * * -# -# Beyond the basic semantics specific to nested checkpoint, we impose several -# more constraints that may apply to checkpointing in general. -# -# Rule 4. Lifetime of recomputed tensors -# -# Recomputed tensors are considered specific to particular invocations -# of backward and are always cleared immediately as they are unpacked -# Particularly, we require this to happen even if retain_graph=True. -# -# [ Implementation details of Rule 4 ] -# -# If we were okay with recomputed tensors staying alive after backward is run -# with retain_graph=True, we would store recomputed variables as the values of a -# WeakKeyDictionary and pack strong references to the keys, so that as we -# backward, those packed keys would be cleared as long as retain_graph=False. -# Clearing the packed key clears the corresponding entry in the WKD. -# -# If we wish recomputed variables to be immediately cleared as we unpack them in -# the retain_graph=True case, we cannot rely on the packed keys to be cleared by -# backward automatically. Instead of packing the strong reference to the key -# directly, we pack a container object, which we manually clear as we unpack. -# -# An important detail is that if a second backward happens, the second -# recomputation needs to reset the container with a newly created key. -# -# Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we -# know we need. -# -# [ Implementation details of Rule 5 ] -# -# During recomputation, raise an exception if the number of recomputed tensors -# matches the number of tensors that we expected to recompute. We wrap the -# recomputation call with a try-catch to catch this specific exception. See -# Rule #6 below for some examples. -# -# Rule 6. We support doing backward inside checkpoint context -# -# [ retain_graph is True] -# -# def fn(x): -# y = x.sin() -# z = y.cos() -# gx, = torch.autograd.grad(z, x, retains_grad=True) -# return gx, z -# -# out = checkpoint(fn)(inp) -# out.backward() -# -# Because z is saved by cos while checkpoint is enabled, it would not be -# actually saved, and so the .grad() call inside must trigger a recomputation. 
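A minimal runnable sketch of the nesting these rules describe (gradients are unchanged by wrapping); it assumes only the ``checkpoint`` defined in this file, which mirrors ``torch.utils.checkpoint.checkpoint``:

    import torch
    from torch.utils.checkpoint import checkpoint

    def inner(x):
        return x.sin().cos()

    def outer(x):
        # Tensors saved inside the inner region are managed by the inner
        # checkpoint (Rule 1); the inner region's inputs are saved to the
        # outer checkpoint (Rule 2).
        return checkpoint(inner, x, use_reentrant=False).exp()

    x = torch.randn(4, requires_grad=True)
    out = checkpoint(outer, x, use_reentrant=False)
    out.sum().backward()

    reference = torch.autograd.grad(x.sin().cos().exp().sum(), x)[0]
    print(torch.allclose(x.grad, reference))  # expected: True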
-# -# During recomputation the "inner pack hook" has two responsibilities: -# -# 1) As usual, populating the WeakKeyDictionary storing recomputed tensors -# 2) Pack the actual tensor (detached) so that one may perform backward on the -# recomputed graph. The tensors saved to this graph will live until the end -# of recomputation, or die earlier if someone performs backward with -# retain_graph=False. -# -# More generally performing backward on the recomputed graph occurs in the -# following cases: -# - If backward is performed inside forward, -# - During the original forward IF early-stop is disabled -# - During the original backward -# - If there are multiple .grad()/.backward() calls, we would perform backward -# on the recomputed graph even if early-stop is enabled (see the example below) -# -# [ retain_graph is False ] -# -# The example below shows what happens if during recomputation we find that some -# of the tensors we are trying to recompute have already been cleared. -# -# Spoiler: we don't do anything special, we just skip over them! -# -# def fn(x): -# y = x.sin() # (1) -# z = y.cos() # (2) -# gx, = torch.autograd.grad(z, x) # (3) -# return x.cos() * gx # (4) -# -# out = checkpoint(fn)(inp) -# out.backward() # (5) -# -# 1, 2. Don't save x and y since we are inside a checkpoint. -# 3. Trigger a recompute of fn since x and y weren't saved. -# And depending on whether early stop is enabled, either stop at (2) or -# continue running the function. -# Because we are running backward with retain_graph=False, we clear x and y's -# holders. -# 4. Don't save x since we are inside a checkpoint. -# 5. Calling backward triggers another recompute of fn. During recompute, we see -# that x and y have already been cleared in the original graph as indicated -# by holder=None. We skip over them. We still save x at (4) (since its holder -# is still alive.) - -_enable_checkpoint_early_stop: Optional[bool] = None - - -@contextlib.contextmanager -def set_checkpoint_early_stop(enable: bool): - """Context manager that sets whether checkpoint should stop recomputation early. - - By default, non-reentrant checkpoint stops recomputation as soon as it - has computed all needed Tensors. This context manager can be used to disable - that feature if it is problematic for your specific application. - - This context manager only needs to be active when forward is run. It does - not need to be active during backward. - - Example:: - - >>> # xdoctest: +SKIP(failing) - >>> message = "saved tensors default hooks are disabled" - >>> with set_checkpoint_early_stop(False): - ... # Any checkpoint under this context manager will respect this - ... # context manager, even if its backward is performed outside. - ... out = checkpoint(fn, inputs) - ... 
- >>> out.backward() - """ - global _enable_checkpoint_early_stop - try: - prev = _enable_checkpoint_early_stop - _enable_checkpoint_early_stop = enable - yield - finally: - _enable_checkpoint_early_stop = prev - - -class _Handle: - pass - - -class _Holder: - def __init__(self): - self.handles: Dict[int, Optional[_Handle]] = {} - - -class _NoopSaveInputs(torch.autograd.Function): - @staticmethod - # pyrefly: ignore [bad-override] - def forward(*args): - return torch.empty((0,)) - - @staticmethod - def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: - # Only tensors can be saved with ctx.save_for_backward, everything else - # is captured by get_args, which is saved directly on ctx - tensor_indices, tensors = zip( - *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)] - ) - idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)} - # args but with tensors replaced with None as placeholders - args = [None if isinstance(o, torch.Tensor) else o for o in inputs] - - def get_args(saved_tensors): - # restore the placeholders with the original tensors grabbed from - # ctx.saved_tensors (which may be saved on a parent checkpoint if - # this checkpoint is nested, and that would trigger a recursive - # unpack!) - ret = [ - saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o - for i, o in enumerate(args) - ] - # grab the tail since we also saved the dummy to avoid having to explicitly - # handle the case where there are no tensor inputs - return ret[1:] - - ctx.get_args = get_args - ctx.save_for_backward(*tensors) - - @staticmethod - def backward(ctx, *grad_outputs): - raise AssertionError("Did not expect to backward on this graph") - - -class _CheckpointFrame: - def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn): - self.recompute_fn = recompute_fn - self.input_saver = None - self.weak_holders: List[ReferenceType] = [] - # We store this as a weakkeydictionary so that in the case of a partial - # backward, the entries in the dict are cleared alongside the Holder - # which will be removed when the SavedVariable is cleared. - self.recomputed: DefaultDict[ - int, weakref.WeakKeyDictionary[_Handle, torch.Tensor] - ] = defaultdict(weakref.WeakKeyDictionary) - # We need both recomp_counter and recomputed since they can diverge - # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885 - self.recomp_counter: DefaultDict[int, int] = defaultdict(int) - self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool) - - # See Rule 5 - self.early_stop = early_stop - - # Debugging - self.metadata_fn = metadata_fn - self.unpack_error_cb = unpack_error_cb - self.x_metadatas = [] - self.forward_completed = False - self.ignore_saved_mismatch = False - - def check_recomputed_tensors_match(self, gid): - if self.ignore_saved_mismatch: - # TODO: we can probably make this check stricter by checking that - # the metadata of the first tensors still match. - return - # NOTE [ Error handling for checkpoint ] - # - # At a high level, we need to check that the tensors saved - # during original forward matches tensors saved during recompute - # This means handling 3 cases: - # - # 1. During recompute, more tensors were saved. - # - # Usually this is hidden due to the StopRecomputationError - # but if early stop is not enabled, or we would have errored - # anyway because there aren't enough weak_holders. But we - # do want to have a nice error. See the _recomputation_hook - # for details. 
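The frame/holder machinery above is driven by pack/unpack saved-tensor hooks; the two subclasses of ``torch.autograd.graph.saved_tensors_hooks`` appear further below. As a standalone sketch of that underlying primitive (the save-to-CPU pattern, not checkpoint-specific):

    import torch

    def pack_hook(x):
        # Called whenever autograd saves a tensor for backward: keep a CPU
        # copy together with the original device.
        return (x.device, x.to("cpu"))

    def unpack_hook(packed):
        # Called when backward needs the tensor again: move it back.
        device, tensor = packed
        return tensor.to(device)

    a = torch.randn(8, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        out = (a * a).sum()   # mul saves its operands through pack_hook
    out.backward()            # unpack_hook rebuilds them here
    print(a.grad)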
- if not len(self.weak_holders) == self.recomp_counter[gid]: - # 2. During recompute, fewer tensors were saved - # - # We know that every time we save something do original forward - # we append to weak_holder, and every time we save a tensor - # during recompute we increment recompute_counter. - raise CheckpointError( - "torch.utils.checkpoint: A different number of tensors was saved " - "during the original forward and recomputation.\n" - f"Number of tensors saved during forward: {len(self.weak_holders)}\n" - f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}.\n" - f"{_debug_tip_msg}" - ) - - # 3. During recompute, the same tensors were saved, but they - # have different metadata - nb_meta_different = [] - for idx, weak_holder in enumerate(self.weak_holders): - holder = weak_holder() - if holder is None: - continue - # We've seen all holders since we iterate over them in order - # For every holder that is still alive now, it must've been - # alive when we saw it during recompute, therefore, the - # gid must be set. - _internal_assert(gid in holder.handles) - # We know this is the first unpack, so it couldn't have been set - # to None yet. - _internal_assert(holder.handles[gid] is not None) - # We always set these together in the recomputation hook - _internal_assert(holder.handles[gid] in self.recomputed[gid]) - # see pack hook, x_metadata is 1:1 with weak_holders. - x_meta = self.x_metadatas[idx] - recomputed_x = self.recomputed[gid][holder.handles[gid]] - if x_meta != self.metadata_fn(recomputed_x): - nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x))) - - if len(nb_meta_different) > 0: - mismatched_tensors = "" - for idx, x_meta, recomputed_meta in nb_meta_different: - mismatched_tensors += ( - f"tensor at position {idx}:\n" - f"saved metadata: {x_meta}\n" - f"recomputed metadata: {recomputed_meta}\n" - ) - raise CheckpointError( - "torch.utils.checkpoint: Recomputed values for the following tensors " - "have different metadata than during the forward pass.\n" - f"{mismatched_tensors}.\n" - f"{_debug_tip_msg}" - ) - - -_debug_tip_msg = """ -Tip: To see a more detailed error message, either pass `debug=True` to -`torch.utils.checkpoint.checkpoint(...)` or wrap the code block -with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to -enable checkpoint‑debug mode globally. -""" - - -_checkpoint_error_template = """ \ -An error happened while unpacking tensors; dumping logs of latest computation -because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`. -Scroll all the way down for guidance on how to navigate these logs. - -+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ -| 1. Stack traces of the operators that ran in the original forward | -+------------------------------------------------------------------------------+ - -{forward_traces} -+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ -| 2. Stack traces of the operators that ran during recomputation | -+------------------------------------------------------------------------------+ - -{recompute_traces} -+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+ -| 3. Log of operators in the original forward and recomputation | -+------------------------------------------------------------------------------+ -(Scroll up to correlate stack traces with each operation listed below. This - helps identify their source in the code.) 
- -IMPORTANT: Differences in "detach" calls between the original forward and the - recomputation are expected. They are introduced by the checkpointing - mechanism and can be ignored. - -Operations executed during the original forward: - -{forward_ops} - -Operations executed during recomputation: - -{recompute_ops} - -+------------------------------------------------------------------------------+ - ERROR: Detected non-determinism while running activation checkpointing - - You are seeing this error because you passed `debug=True` to checkpoint and - tensors to be saved during the original forward and differ between those saved - during recomputation. This can happen if different operators were ran in the - original forward and in the recomputation. - - To identify where the mismatch may be coming from, you can do the following: - - 1) Compare the operators ran during original forward and recomputation to - see where they differ. These operators are printed above in the order they - were executed. - - 2) Review the stack trace for each operator to locate its invocation source. - Each operator's stack trace is printed in their execution order. - - Note that the logs can be quite long. Here's how they are structured: - (Tip: you can Ctrl-f for these headers) - - 1. Stack traces of the operators that ran in the original forward - 2. Stack traces of the operators that ran during recomputation - 3. Log of operators in the original forward and recomputation - 4. Error message <--- You are here --------------------------------------------------------------------------------- -""" - - -class CheckpointError(RuntimeError): - pass - - -def _get_debug_context_and_cb() -> ( - Tuple[Callable[[], Any], Callable[[CheckpointError], None]] -): - # This function returns the context_fn and error_cb to be used by the - # checkpointing mechanism. error_cb is invoked when an error is detected - # during unpack. 
- - # record_context_cpp is not support on non-linux non-x86_64 platforms - cpp_tb = platform.machine() == "x86_64" and platform.system() == "Linux" - - class CaptureLogs: - def __init__(self): - self.logs = None - self.tbs = None - - def get_context_manager(self): - @contextlib.contextmanager - def logging_mode(): - with LoggingTensorMode(), capture_logs( - True, python_tb=True, script_tb=True, cpp_tb=cpp_tb - ) as logs_and_tb: - # pyrefly: ignore [bad-assignment] - self.logs, self.tbs = logs_and_tb - yield logs_and_tb - - return logging_mode() - - capture_logs_fwd = CaptureLogs() - capture_logs_recompute = CaptureLogs() - - def unpack_error_cb(e: CheckpointError): - def get_str_tb(label, capture_logs): - out = "" - total_len = len(capture_logs.logs) - for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)): - out += f"{log} ({i + 1} of {total_len} in {label})\n\n" - found_torch_dispatch = False - for line in tb: - # Start printing stack trace only after __torch_dispatch__ is found - is_torch_dispatch = line["name"] == "__torch_dispatch__" - if not found_torch_dispatch and not is_torch_dispatch: - continue - elif is_torch_dispatch: - found_torch_dispatch = True - continue - out += f"{line['filename']}:{line['line']}:{line['name']}\n" - out += "\n\n" - return out - - if capture_logs_fwd.logs is None: - raise AssertionError("capture_logs_fwd.logs is None") - if capture_logs_recompute.logs is None: - raise AssertionError("capture_logs_recompute.logs is None") - raise CheckpointError( - _checkpoint_error_template.format( - forward_traces=get_str_tb("original", capture_logs_fwd), - recompute_traces=get_str_tb("recompute", capture_logs_recompute), - forward_ops="\n".join(capture_logs_fwd.logs), - recompute_ops="\n".join(capture_logs_recompute.logs), - ) - ) from e - - def context_fn(): - return ( - capture_logs_fwd.get_context_manager(), - capture_logs_recompute.get_context_manager(), - ) - - return context_fn, unpack_error_cb - - -def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]: - # These properties are fast to check, easy to understand - return {"shape": x.shape, "dtype": x.dtype, "device": x.device} - - -_allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = { - _DEFAULT_DETERMINISM_MODE: _default_meta_extractor, - "none": lambda _: None, -} - -# See Rule 5 - - -class _StopRecomputationError(Exception): - pass - - -class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks): - def __init__(self, target_frame_ref: ReferenceType, gid: int): - def pack_hook(x): - x = x.detach() if x.requires_grad else x - target_frame = target_frame_ref() - if target_frame is None: - raise AssertionError("Internal error: target_frame reference is None") - recomp_idx = target_frame.recomp_counter[gid] - target_frame.recomp_counter[gid] += 1 - - if recomp_idx >= len(target_frame.weak_holders): - if target_frame.early_stop: - raise AssertionError( - "Unexpected state: target_frame.early_stop is set" - ) - if not target_frame.forward_completed: - # We run into this case when early stop is not enabled and do - # grad within checkpoint. - # We need to set this flag, so we don't error out later when - # we check if the number of tensors saved during forward and - # recomputation match. 
- target_frame.ignore_saved_mismatch = True - return x - raise CheckpointError( - "torch.utils.checkpoint: trying to save more tensors during " - "recomputation than during the original forward pass.\n" - f"{_debug_tip_msg}" - ) - - holder = target_frame.weak_holders[recomp_idx]() - - # This holder may have been cleared because someone may have called - # backward within forward. If so, we don't need to save. - if holder is not None: - _internal_assert(holder.handles.get(gid, None) is None) - holder.handles[gid] = _Handle() - target_frame.recomputed[gid][holder.handles[gid]] = x - - if target_frame.early_stop and target_frame.recomp_counter[gid] == len( - target_frame.weak_holders - ): - raise _StopRecomputationError - # See Rule 6: [ retain_graph is True ] above - return x - - def unpack_hook(x): - # See Rule 6: [ retain_graph is True ] above for an example of when - # the graph created during recomputation could be backwarded. - return x - - super().__init__(pack_hook, unpack_hook) - - -# torch._disable_dynamo creates a reference cycle with decorated function -# This function is used to ensure that the decorated function does not have -# a closure, so that other objects aren't also kept alive. -# https://github.com/pytorch/pytorch/issues/154642 -# Note: does not work when fn is compiled -@torch._disable_dynamo -def _run_fn_with_dynamo_disabled(fn, *args, **kwargs): - return fn(*args, **kwargs) - - -class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks): - def __init__(self, frame): - def pack_hook(x): - # See Rule 4 above - holder = _Holder() - frame.weak_holders.append(weakref.ref(holder)) - # Save metadata to detect non-determinism - if frame.metadata_fn is not None: - with torch.no_grad(): - frame.x_metadatas.append(frame.metadata_fn(x)) - return holder - - def unpack_hook(holder): - gid = torch._C._current_graph_task_id() - if gid == -1: - # generate a temporary id if we trigger unpack outside of a backward call - gid = int(uuid.uuid4()) - - if not frame.is_recomputed[gid]: - ctx = frame.input_saver.grad_fn - args = ctx.get_args(ctx.saved_tensors) - - try: - with _recomputation_hook( - weakref.ref(frame), gid - ), torch.autograd.enable_grad(): - # See Note: [compiled autograd and checkpoint unpack hook] - _run_fn_with_dynamo_disabled(frame.recompute_fn, *args) - except _StopRecomputationError: - pass - frame.is_recomputed[gid] = True - frame.check_recomputed_tensors_match(gid) - - _internal_assert(gid in holder.handles) - - if holder.handles[gid] is None: - raise CheckpointError( - "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already " - "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do " - "so only once. Otherwise please open an issue with details on your use case." 
- ) - _internal_assert(holder.handles[gid] in frame.recomputed[gid]) - ret = frame.recomputed[gid][holder.handles[gid]] - holder.handles[gid] = None - return ret - - if frame.unpack_error_cb is not None: - - def unpack_hook_with_error_cb(holder): - try: - return unpack_hook(holder) - except CheckpointError as e: - frame.unpack_error_cb(e) - - super().__init__(pack_hook, unpack_hook_with_error_cb) - else: - super().__init__(pack_hook, unpack_hook) - - -def _is_compiling(func, args, kwargs): - # Check if we are under AOTAutograd tracing - # Checking that a functional mode is active should always do what we want - return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) is not None - - -class _VersionWrapper: - # Check that cached tensors are not mutated. - def __init__(self, val): - self.val: Union[torch.Tensor, Any] = val - self.version: Optional[int] = ( - val._version if isinstance(val, torch.Tensor) else None - ) - - def get_val(self, allow_cache_entry_mutation): - if self.version is not None and not allow_cache_entry_mutation: - if self.val._version != self.version: - # Can we give user a stack trace of where the mutation happened? - raise RuntimeError( - "Tensor cached during selective activation checkpoint has been mutated" - ) - return self.val - - -def _maybe_detach(x, any_ret_has_alias_info): - # We detach for two separate reasons: - # - For view ops, we need to ensure that when the tensor is returned from - # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr - # - Avoid reference cycles - # For case 1, it is not enough to check whether x has differentiable dtype - # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g. - # when the tensor is a view. - if isinstance(x, torch.Tensor) and ( - x.is_floating_point() or x.is_complex() or any_ret_has_alias_info - ): - with torch._C._SetExcludeDispatchKeyGuard( - torch._C.DispatchKey.ADInplaceOrView, False - ): - # Ensure that view performed beneath autograd properly propagates - # version counter. TODO: Use reentrant_dispatch instead of - # manually manipulating dispatch keys. Using reentrant_dispatch - # would respect inference_mode, though that is not relevant for - # this case. - x = x.detach() - return x - - -class SelectiveCheckpointContext: - """ - Context passed to policy function during selective checkpointing. - - This class is used to pass relevant metadata to the policy function during - selective checkpointing. The metadata includes whether the current invocation - of the policy function is during recomputation or not. - - Example: - >>> # xdoctest: +SKIP(stub) - >>> - >>> def policy_fn(ctx, op, *args, **kwargs): - >>> print(ctx.is_recompute) - >>> - >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) - >>> - >>> out = torch.utils.checkpoint.checkpoint( - >>> fn, x, y, - >>> use_reentrant=False, - >>> context_fn=context_fn, - >>> ) - """ - - def __init__(self, *, is_recompute): - self.is_recompute = is_recompute - - -class CheckpointPolicy(enum.Enum): - """ - Enum for specifying the policy for checkpointing during backpropagation. 
- - The following policies are supported: - - - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward - pass and will not be recomputed during the backward pass - - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the - forward pass and will be recomputed during the backward pass - - Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden - by other subsystems like `torch.compile`. - - .. note:: - A policy function that always returns ``PREFER_RECOMPUTE`` is - equivalent to vanilla checkpointing. - - A policy function that returns ``PREFER_SAVE`` every op is - NOT equivalent to not using checkpointing. Using such a policy would - save additional tensors not limited to ones that are actually needed for - gradient computation. - """ - - MUST_SAVE = 0 - PREFER_SAVE = 1 - MUST_RECOMPUTE = 2 - PREFER_RECOMPUTE = 3 - - -def _policy_from_bool(b): - # For backward compatibility - return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE - - -SAC_IGNORED_OPS = { - # AC inserts different number of detach during forward and recompute. - torch.ops.aten.detach.default, - # AC's determinism check invokes additional metadata ops during forward. - # With subclasses involved, these metadata ops become dispatchable, this - # can result in incorrectness if these ops are selected cached. - torch.ops.prim.device.default, -} | set( - torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns -) # type: ignore[has-type] - - -class _CachingTorchDispatchMode(TorchDispatchMode): - # Used together with _CachedTorchDispatchMode to implement SAC. - def __init__(self, policy_fn, storage): - self.policy_fn = policy_fn - self.storage = storage - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if func in SAC_IGNORED_OPS: - return func(*args, **kwargs) - - kwargs = {} if kwargs is None else kwargs - policy = self.policy_fn( - SelectiveCheckpointContext(is_recompute=False), func, *args, **kwargs - ) - if isinstance(policy, bool): - policy = _policy_from_bool(policy) - - is_compiling = _is_compiling(func, args, kwargs) - - if is_compiling: - # Overwrite each node's "recompute" tag to add in the user annotation. - fx_traceback.current_meta["recompute"] = policy - - out = func(*args, **kwargs) - - # HOPs don't support func._schema - # HOPs don't alias -> this is always true today and will be always true for a long time - # TODO HOPs don't mutate -> this is always true today but will not be true forever - if isinstance(func, torch._ops.HigherOrderOperator): - any_ret_has_alias_info = False - else: - any_ret_has_alias_info = any( - ret.alias_info is not None for ret in func._schema.returns - ) - - if ( - policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) - or is_compiling - ): - self.storage[func].append( - tree_map( - lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), - out, - ) - ) - return out - - -class _CachedTorchDispatchMode(TorchDispatchMode): - # Used together with _CachedTorchDispatchMode to implement SAC. 
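Both SAC modes work by intercepting every dispatched op. A stripped-down, hypothetical ``TorchDispatchMode`` that merely counts ops, to show the shape of the mechanism the two classes here build on:

    from collections import Counter

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class OpCounterMode(TorchDispatchMode):
        def __init__(self):
            super().__init__()
            self.counts = Counter()

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = kwargs or {}
            self.counts[func] += 1   # record the op, then run it unchanged
            return func(*args, **kwargs)

    with OpCounterMode() as mode:
        x = torch.randn(4, 4)
        (x @ x).relu().sum()

    print(dict(mode.counts))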
- def __init__(self, policy_fn, storage, allow_cache_entry_mutation): - self.policy_fn = policy_fn - self.storage = storage - self.allow_cache_entry_mutation = allow_cache_entry_mutation - - def __torch_dispatch__(self, func, types, args=(), kwargs=None): - if func in SAC_IGNORED_OPS: - return func(*args, **kwargs) - - kwargs = {} if kwargs is None else kwargs - policy = self.policy_fn( - SelectiveCheckpointContext(is_recompute=True), func, *args, **kwargs - ) - if isinstance(policy, bool): - policy = _policy_from_bool(policy) - - is_compiling = _is_compiling(func, args, kwargs) - - if ( - policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) - or is_compiling - ): - storage = self.storage.get(func) - if storage is None: - raise RuntimeError( - f"{func} encountered during backward, but not found in storage" - ) - if len(storage) == 0: - raise RuntimeError( - "Trying to backward an extra time. You are only allowed to backward once " - "on any region computed under selective activation checkpoint." - ) - out = tree_map( - lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0) - ) - else: - out = func(*args, **kwargs) - return out - - -def create_selective_checkpoint_contexts( - policy_fn_or_list, allow_cache_entry_mutation=False -): - """ - Helper to avoid recomputing certain ops during activation checkpointing. - - Use this with `torch.utils.checkpoint.checkpoint` to control which - operations are recomputed during the backward pass. - - Args: - policy_fn_or_list (Callable or List): - - If a policy function is provided, it should accept a - :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and - kwargs to the op, and return a :class:`CheckpointPolicy` enum value - indicating whether the execution of the op should be recomputed or not. - - If a list of operations is provided, it is equivalent to a policy - returning `CheckpointPolicy.MUST_SAVE` for the specified - operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other - operations. - allow_cache_entry_mutation (bool, optional): By default, an error is - raised if any tensors cached by selective activation checkpoint are - mutated in order to ensure correctness. If set to `True`, this check - is disabled. - Returns: - A tuple of two context managers. - - Example: - >>> # xdoctest: +REQUIRES(LINUX) - >>> import functools - >>> - >>> x = torch.rand(10, 10, requires_grad=True) - >>> y = torch.rand(10, 10, requires_grad=True) - >>> - >>> ops_to_save = [ - >>> torch.ops.aten.mm.default, - >>> ] - >>> - >>> def policy_fn(ctx, op, *args, **kwargs): - >>> if op in ops_to_save: - >>> return CheckpointPolicy.MUST_SAVE - >>> else: - >>> return CheckpointPolicy.PREFER_RECOMPUTE - >>> - >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn) - >>> - >>> # or equivalently - >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save) - >>> - >>> def fn(x, y): - >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y - >>> - >>> out = torch.utils.checkpoint.checkpoint( - >>> fn, x, y, - >>> use_reentrant=False, - >>> context_fn=context_fn, - >>> ) - """ - # NB: If grad_mode is disabled, checkpoint would not run forward under - # context_fn anyway, so proceed as usual. - if isinstance(policy_fn_or_list, list): - for op in policy_fn_or_list: - if not isinstance( - op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) - ): - _extra_msg = ( - ( - "Please update the OpOverloadPacket to a specific OpOverload." 
- "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`." - ) - if isinstance(op, torch._ops.OpOverloadPacket) - else "" - ) - raise ValueError( - f"Expected op in `op_list` to be an OpOverload but got: {op} " - f"of type {type(op)}. {_extra_msg}" - ) - - def policy_fn(ctx, op, *args, **kwargs): - if op in policy_fn_or_list: - return CheckpointPolicy.MUST_SAVE - else: - return CheckpointPolicy.PREFER_RECOMPUTE - - elif callable(policy_fn_or_list): - policy_fn = policy_fn_or_list - else: - raise TypeError("policy_fn_or_list must be either a function or a list of ops.") - - storage: Dict[Any, List[Any]] = defaultdict(list) - return ( - _CachingTorchDispatchMode(policy_fn, storage), - _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation), - ) - - -# NB: this helper wraps fn before calling checkpoint_impl. kwargs and -# saving/restoring of global state is handled here. - - -def _checkpoint_without_reentrant_generator( - fn, - preserve_rng_state=True, - context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn, - determinism_check: str = _DEFAULT_DETERMINISM_MODE, - debug: bool = False, - early_stop: bool = True, - *args, - **kwargs, -): - """Checkpointing without reentrant autograd. - - Args: - fn: describes what to run in the forward pass of the model or - part of the model. It should also know how to handle the inputs - passed as the tuple. For example, in LSTM, if user passes - ``(activation, hidden)``, :attr:`function` should correctly use the - first input as ``activation`` and the second input as ``hidden`` - preserve_rng_state(bool, optional): Omit stashing and restoring - the RNG state during each checkpoint. - Default: ``True`` - context_fn(Callable, optional): A callable returning a tuple of two - context managers. The function and its recomputation will be run - under the first and second context managers respectively. - determinism_check(str, optional): A string specifying the determinism - check to perform. By default it is set to ``"default"`` which - compares the shapes, dtypes, and devices of the recomputed tensors - against those the saved tensors. To turn off this check, specify - ``"none"``. Currently these are the only two supported values. - Please open an issue if you would like to see more determinism - checks. - debug(bool, optional): If ``True``, error messages will also include - a trace of the operators ran during the original forward computation - as well as the recomputation. - early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops - recomputation as soon as it has computed all needed Tensors. Can be - overridden globally using :func:`set_checkpoint_early_stop` context - manager. Default: ``True``. - *args: Arguments to pass in to the given ``function``. - **kwargs: Keyword arguments to pass into the given ``function``. 
- """ - unpack_error_cb = None - - if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug: - if context_fn != noop_context_fn: - raise ValueError("debug=True is incompatible with non-default context_fn") - context_fn, unpack_error_cb = _get_debug_context_and_cb() - - if determinism_check in _allowed_determinism_checks_to_fns: - metadata_fn = _allowed_determinism_checks_to_fns[determinism_check] - else: - raise ValueError( - f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, " - f"but got {determinism_check}" - ) - - device_type = _infer_device_type(*args) - device_module = _get_device_module(device_type) - forward_context, recompute_context = context_fn() - if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn: - if not isinstance(forward_context, TorchDispatchMode) or not isinstance( - recompute_context, TorchDispatchMode - ): - raise AssertionError( - "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " - "must generate a tuple of two `TorchDispatchMode`s." - ) - # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu. - device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs( - device_type=device_type - ) - - if preserve_rng_state: - fwd_cpu_state = torch.get_rng_state() - # Don't eagerly initialize the cuda context by accident. - # (If the user intends that the context is initialized later, within their - # run_function, we SHOULD actually stash the cuda state here. Unfortunately, - # we have no way to anticipate this will happen before we run the function. - # If they do so, we raise an error.) - had_device_in_fwd = False - if getattr(device_module, "_initialized", False): - had_device_in_fwd = True - fwd_devices, fwd_device_states = get_device_states(*args) - - def recompute_fn(*inputs): - kwargs, *args = inputs - # This will be called later during recomputation. This wrapping enables - # the necessary global state to be captured. - rng_devices = [] - if preserve_rng_state and had_device_in_fwd: - rng_devices = fwd_devices - with torch.random.fork_rng( - devices=rng_devices, enabled=preserve_rng_state, device_type=device_type - ): - if preserve_rng_state: - torch.set_rng_state(fwd_cpu_state) - if had_device_in_fwd: - set_device_states( - fwd_devices, fwd_device_states, device_type=device_type - ) - - device_autocast_ctx = ( - torch.amp.autocast(device_type=device_type, **device_autocast_kwargs) - if torch.amp.is_autocast_available(device_type) - else contextlib.nullcontext() - ) - # type: ignore[attr-defined] - with device_autocast_ctx, torch.amp.autocast( - "cpu", **cpu_autocast_kwargs - ), recompute_context: - fn(*args, **kwargs) - - new_frame = _CheckpointFrame( - recompute_fn, - _enable_checkpoint_early_stop - if _enable_checkpoint_early_stop is not None - else early_stop, - unpack_error_cb, - metadata_fn, - ) - dummy = torch.empty((0,), requires_grad=True) - new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args) - - # When ambient grad_mode is False - if new_frame.input_saver.grad_fn is None: - yield - return - - with _checkpoint_hook(new_frame), forward_context: - yield - new_frame.forward_completed = True - - if ( - getattr(device_module, "_initialized", False) - and preserve_rng_state - and not had_device_in_fwd - ): # type: ignore[possibly-undefined] - # Device was not initialized before running the forward, so we didn't - # stash the device state. 
- raise RuntimeError( - "PyTorch's device state was initialized in the forward pass " - "of a Checkpoint, which is not allowed. Please open an issue " - "if you need this feature." - ) - - return - - -# Note: [compiled autograd and checkpoint unpack hook] -# When tracing via compiled autograd, this hook will be visible to the -# compiler if the forward of this checkpointed region ran in eager. -# If the forward had ran under compile, it would have been wrapped in a -# higher order op. See Note: [torch.compile and checkpoint]. -# -# Since we run the recomputation hook under a enable_grad context, -# AOTDispatch will trace a joint graph for this hook, and may -# save different activations than in eager. This conflicts with the -# strict activation count checks in `frame.check_recomputed_tensors_match`. -# So, we disable this hook to force it to recompute eager checkpointed regions -# in eager. This could be removed if we can disable the partitioner for this -# graph segment. diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 16fa5cc212..b2eb0cbfcc 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -96,7 +96,6 @@ "test_scaled_dot_product_attention_3D_input_dim_2D_attn_mask_dropout_p_0_5_xpu", "test_scaled_dot_product_attention_3D_input_dim_2D_attn_mask_dropout_p_0_2_xpu", ), - "test_utils.py": None, "test_schema_check.py": None, "test_complex_xpu.py": None, "test_modules_xpu.py": ( diff --git a/test/xpu/test_utils.py b/test/xpu/test_utils.py deleted file mode 100644 index 420504954e..0000000000 --- a/test/xpu/test_utils.py +++ /dev/null @@ -1,1025 +0,0 @@ -# mypy: allow-untyped-defs -# Owner(s): ["module: unknown"] - -import os -import random -import shutil -import subprocess -import sys -import tempfile -import textwrap -import traceback -import unittest -import warnings -from typing import Any - -import torch -import torch.cuda -import torch.nn as nn -import torch.utils.cpp_extension -import torch.utils.data -from torch._utils import try_import -from torch._utils_internal import deprecated -from torch.testing._internal.common_cuda import TEST_MULTIGPU -from torch.testing._internal.common_device_type import ( - instantiate_device_type_tests, - onlyCPU, - ops, -) -from torch.testing._internal.common_methods_invocations import op_db -from torch.testing._internal.common_utils import ( # type: ignore[attr-defined] - IS_FBCODE, - IS_SANDCASTLE, - IS_WINDOWS, - load_tests, - run_tests, - TestCase, -) -from torch.testing._internal.inductor_utils import get_gpu_type -from torch.utils._device import set_device -from torch.utils._pytree import tree_all_only, tree_any -from torch.utils._traceback import ( - CapturedTraceback, - format_traceback_short, - report_compile_source_on_error, -) -from torch.utils.collect_env import get_pretty_env_info -from torch.utils.data import DataLoader - -from .checkpoint import ( - _infer_device_type, - checkpoint, - checkpoint_sequential, - get_device_states, -) - -# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for -# sharding on sandcastle. 
This line silences flake warnings -load_tests = load_tests # noqa: PLW0127 - -HAS_CUDA = torch.cuda.is_available() -HAS_XPU = torch.xpu.is_available() -HAS_GPU = HAS_CUDA or HAS_XPU -device_type = get_gpu_type() - - -# mypy: disable-error-code="name-defined" - - -class RandomDatasetMock(torch.utils.data.Dataset): - def __getitem__(self, index): - return torch.tensor([torch.rand(1).item(), random.uniform(0, 1)]) - - def __len__(self): - return 1000 - - -class TestCheckpoint(TestCase): - # This runs checkpoint_sequential on each of the nets in - # module_lists_to_compare, and compares them against the uncheckpointed model. - # To compare, it checks outputs as well as input gradients and parameter gradients - def _check_checkpoint_sequential( - self, - model, - module_lists_to_compare, - num_chunks, - input, - use_reentrant, - ): - # not checkpointed - out = model(input) - out_not_checkpointed = out.detach().clone() - model.zero_grad() - out.sum().backward() - grad_not_checkpointed = { - name: param.grad.detach().clone() - for name, param in model.named_parameters() - } - input_grad_not_checkpointed = input.grad.detach().clone() - for model_to_compare in module_lists_to_compare: - # checkpointed model by passing list of modules - detached = input.detach() - detached.requires_grad = True - - # pass list of modules to checkpoint - out = checkpoint_sequential( - model_to_compare, num_chunks, detached, use_reentrant=use_reentrant - ) - out_checkpointed = out.detach().clone() - model.zero_grad() - out.sum().backward() - grad_checkpointed = { - name: param.grad.detach().clone() - for name, param in model.named_parameters() - } - input_grad_checkpointed = detached.grad.detach().clone() - # compare outputs as well as the gradients of input and parameters - self.assertEqual(out_checkpointed, out_not_checkpointed) - self.assertEqual(input_grad_not_checkpointed, input_grad_checkpointed) - for name in grad_checkpointed: - self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name]) - - # Test whether checkpoint is being triggered or not. 
For this, we check - # the number of times forward pass happens - def test_checkpoint_trigger(self): - class Net(nn.Module): - def __init__(self) -> None: - super().__init__() - self.counter = 0 - - def forward(self, input_var): - self.counter += 1 - # For reentrant, need to have autograd actually - # pack a tensor to trigger recomp - ret = input_var * torch.tensor(2.0) - return ret - - # checkpointed - for use_reentrant in [True, False]: - with self.subTest(use_reentrant=use_reentrant): - modules = [Net() for _ in range(10)] - for m in modules: - self.assertEqual(m.counter, 0) - input_var = torch.randn(3, 4, requires_grad=True) - out = checkpoint_sequential( - modules, 2, input_var, use_reentrant=use_reentrant - ) - for m in modules: - self.assertEqual(m.counter, 1) - out.sum().backward() - for m in modules[: (len(modules) // 2)]: - self.assertEqual(m.counter, 2) - for m in modules[(len(modules) // 2) :]: - self.assertEqual(m.counter, 1) - - def test_checkpoint_valid(self): - model = nn.Sequential( - nn.Linear(100, 50), - nn.ReLU(), - nn.Linear(50, 20), - nn.ReLU(), - nn.Linear(20, 5), - nn.ReLU(), - ) - - input_var = torch.randn(1, 100, requires_grad=True) - - # checkpointed - chunks = 2 - modules = list(model.children()) - out = checkpoint_sequential(modules, chunks, input_var, use_reentrant=True) - with self.assertRaisesRegex( - RuntimeError, "torch.utils.checkpoint is incompatible" - ): - torch.autograd.grad( - outputs=[out], - grad_outputs=[torch.ones(1, 5)], - inputs=[input_var], - create_graph=True, - ) - # works with use_reentrant=False, and grads are the same - out = model(input_var) - grads_no_checkpoint = torch.autograd.grad( - outputs=[out], - grad_outputs=[torch.ones(1, 5)], - inputs=[input_var], - create_graph=True, - ) - out_checkpoint = checkpoint_sequential( - modules, chunks, input_var, use_reentrant=False - ) - # check outputs are the same - self.assertEqual(out_checkpoint, out) - grads_checkpoint = torch.autograd.grad( - outputs=[out_checkpoint], - grad_outputs=[torch.ones(1, 5)], - inputs=[input_var], - create_graph=True, - ) - self.assertEqual(grads_no_checkpoint, grads_checkpoint) - - def test_checkpoint(self): - for use_reentrant in [True, False]: - with self.subTest(use_reentrant=use_reentrant): - model = nn.Sequential( - nn.Linear(100, 50), - nn.ReLU(), - nn.Linear(50, 20), - nn.ReLU(), - nn.Linear(20, 5), - nn.ReLU(), - ) - - # Compare uncheckpointed model with its checkpointed counterparts - # In addition to running checkpoint_sequential on the nn.Sequential - # instance, we also run the function on the list of functions within - # the module. - self._check_checkpoint_sequential( - model, - [list(model.children()), model], - 2, - torch.randn(1, 100, requires_grad=True), - use_reentrant=use_reentrant, - ) - - def test_checkpoint_module_list(self): - class ModuleListNet(nn.Module): - def __init__(self) -> None: - super().__init__() - module_list = [ - nn.Linear(100, 50), - nn.ReLU(), - nn.Linear(50, 20), - nn.ReLU(), - nn.Linear(20, 5), - nn.ReLU(), - ] - self.module_list = nn.ModuleList(module_list) - - def forward(self, input): - for layer in self.module_list: - input = layer(input) - return input - - for use_reentrant in [True, False]: - with self.subTest(use_reentrant=use_reentrant): - model = ModuleListNet() - - # Compare uncheckpointed model with its checkpointed counterparts. 
- self._check_checkpoint_sequential( - model, - [list(model.module_list.children()), model.module_list], - 2, - torch.randn(1, 100, requires_grad=True), - use_reentrant=use_reentrant, - ) - - def test_checkpoint_sequential_deprecated_multiple_args(self): - class Two(nn.Module): - def forward(self, a, b): - return a, b - - model = nn.Sequential(Two()) - a = torch.randn(1, 100, requires_grad=True) - b = torch.randn(1, 100, requires_grad=True) - - for use_reentrant in [True, False]: - with self.subTest(use_reentrant=use_reentrant): - with self.assertRaises(TypeError): - checkpoint_sequential(model, 1, a, b) # type: ignore[call-arg] - - def test_checkpoint_sequential_deprecated_no_args(self): - class Noop(nn.Module): - def forward(self): - pass - - model = nn.Sequential(Noop()) - for use_reentrant in [True, False]: - with self.subTest(use_reentrant=use_reentrant): - with self.assertRaises(TypeError): - checkpoint_sequential(model, 1) # type: ignore[call-arg] - - def test_checkpoint_rng_cpu(self): - for _ in range(5): - inp = torch.randn(20000, device="cpu").requires_grad_() - phase1 = torch.nn.Dropout() - phase2 = torch.nn.Dropout() - - def run_fn(input): - return phase2(input) - - state = torch.get_rng_state() - - out = phase1(inp) - out = checkpoint(run_fn, out, use_reentrant=True) - out.sum().backward() - grad_with_checkpointing = inp.grad - - torch.set_rng_state(state) - - inp.grad = None - - out = phase1(inp) - out = run_fn(out) - out.sum().backward() - grad_no_checkpointing = inp.grad - - self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) - - @unittest.skipIf(not HAS_GPU, "No GPU") - def test_checkpoint_rng_gpu(self): - for _ in range(5): - inp = torch.randn(20000, device=device_type).requires_grad_() - phase1 = torch.nn.Dropout() - phase2 = torch.nn.Dropout() - - def run_fn(input): - return phase2(input) - - state = torch.get_device_module(device_type).get_rng_state() - - out = phase1(inp) - out = checkpoint(run_fn, out, use_reentrant=True) - out.sum().backward() - grad_with_checkpointing = inp.grad - - torch.get_device_module(device_type).set_rng_state(state) - - inp.grad = None - - out = phase1(inp) - out = run_fn(out) - out.sum().backward() - grad_no_checkpointing = inp.grad - - self.assertEqual(grad_with_checkpointing, grad_no_checkpointing) - - @unittest.skipIf(not HAS_GPU, "No GPU") - def test_checkpoint_not_preserve_rng_state_and_without_reentrant(self): - inp = torch.randn(2, device=device_type).requires_grad_() - layer = torch.nn.Dropout() - - def run_fn(input): - return layer(input) - - out = checkpoint(run_fn, inp, use_reentrant=False, preserve_rng_state=False) - out.sum().backward() - # This should run without error - - def test_checkpoint_non_tensor(self): - def run_fn(tensor1, tensor2): - if tensor2 is None: - return tensor1 - return tensor1 + tensor2 - - input_var = torch.randn(1, 100, requires_grad=True) - out = checkpoint(run_fn, input_var, None, use_reentrant=True) - out.sum().backward() - - def test_checkpoint_non_tensor_inputs_outputs(self): - def foo(t1, t2, scale, t3): - t4 = t1 + t2 * t3 - t5 = t1 * t2 + t3 - t4 *= scale - t5 *= scale - return scale, t4, None, True, t5, "bar", t1 - - t1 = torch.rand(10, requires_grad=True) - t2 = torch.rand(10, requires_grad=True) - t3 = torch.rand(10) - scale = random.randint(0, 10) - res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) - self.assertEqual(scale, res[0]) - self.assertEqual((t1 + t2 * t3) * scale, res[1]) - self.assertEqual(None, res[2]) - self.assertEqual(True, res[3]) - self.assertEqual((t1 
* t2 + t3) * scale, res[4]) - self.assertEqual("bar", res[5]) - self.assertEqual(t1, res[6]) - - # Validate running backward. - res[1].sum().backward(retain_graph=True) - res[4].sum().backward(retain_graph=True) - res[6].sum().backward() - with self.assertRaisesRegex( - RuntimeError, "Trying to backward through the graph a second time" - ): - res[6].sum().backward() - t1_grad = t1.grad - t2_grad = t2.grad - - # Reset grads, run without checkpoint and validate we receive same grads. - t1.grad = None - t2.grad = None - res = foo(t1, t2, scale, t3) - torch.autograd.backward([res[1].sum(), res[4].sum(), res[6].sum()]) - self.assertEqual(t1.grad, t1_grad) - self.assertEqual(t2.grad, t2_grad) - - def test_checkpoint_no_tensors(self): - def foo(t1, t2, scale, t3): - t4 = t1 + t2 * t3 - t5 = t1 * t2 + t3 - t4 *= scale - t5 *= scale - return scale, t4, None, True, t5, "bar", t1 - - t1 = random.random() - t2 = random.random() - t3 = random.random() - scale = random.randint(0, 10) - res = checkpoint(foo, t1, t2, scale, t3, use_reentrant=True) - self.assertEqual(scale, res[0]) - self.assertEqual((t1 + t2 * t3) * scale, res[1]) - self.assertEqual(None, res[2]) - self.assertEqual(True, res[3]) - self.assertEqual((t1 * t2 + t3) * scale, res[4]) - self.assertEqual("bar", res[5]) - self.assertEqual(t1, res[6]) - - def test_checkpoint_partial_grad(self): - def run_fn(tensor1, tensor2): - # tensor 2 is used for other application logic - return tensor1, tensor2 - - input_var = torch.randn(1, 4, requires_grad=True) - input_var2 = torch.randn(1, 4, requires_grad=False) - out = checkpoint(run_fn, input_var, input_var2, use_reentrant=True) - out[0].sum().backward() - - def run_fn2(tensor1, tensor2): - return tensor1 - - input_var = torch.randn(1, 4, requires_grad=False) - input_var2 = torch.randn(1, 4, requires_grad=True) - with self.assertRaisesRegex( - RuntimeError, - r"none of output has requires_grad=True, this checkpoint\(\) is not necessary", - ): - out = checkpoint(run_fn2, input_var, input_var2, use_reentrant=True) - out.sum().backward() - - @unittest.skipIf(not HAS_GPU, "Test requires GPU") - def test_checkpointing_without_reentrant_early_free(self): - # I don't know how to check if the temporary saved variable buffer - # get de-allocated directly. So using GPU memory usage as a proxy - - def _do_test(fn, should_free): - stats: list[int] = [] - - def track(x, idx): - # Track that at each step of the backward, some Tensor were - # de-allocated (which correspond to the checkpoint storage being - # emptied at each step) - def hook(_unused): - self.assertEqual(len(stats), idx) - torch.get_device_module(device_type).synchronize() - stats.append( - torch.get_device_module(device_type).memory_allocated() - ) - if idx > 0: - if should_free: - self.assertLess(stats[idx], stats[idx - 1]) - else: - self.assertEqual(stats[idx], stats[idx - 1]) - - x.register_hook(hook) - - def test_fn(x): - # The main property of this function is that it contains multiple - # operations that save gradients in a chain. 
- x = x**2 - track(x, 2) - x = x**2 - track(x, 1) - x = x**2 - track(x, 0) - x = x**2 - return x.sum() - - fn(test_fn) - - return stats - - x = torch.zeros(10, device=device_type, requires_grad=True) - x.grad = torch.zeros_like(x) - - # In a regular backward, buffers get eagerly freed - non_retain_stats = _do_test(lambda fn: fn(x).backward(), True) - - # In a retain_grad backward, buffers get preserved - _unused_retain_stats = _do_test( - lambda fn: fn(x).backward(retain_graph=True), False - ) - - # In a regular backward with checkpoint, buffers get eagerly freed - checkpoint_non_retain_stats = _do_test( - lambda fn: checkpoint(fn, x, use_reentrant=False).backward(), True - ) - - # In a retain_grad backward with checkpoint, buffers get eagerly freed - checkpoint_retain_stats = _do_test( - lambda fn: checkpoint(fn, x, use_reentrant=False).backward( - retain_graph=True - ), - True, - ) - - self.assertEqual(non_retain_stats, checkpoint_non_retain_stats) - self.assertEqual(non_retain_stats, checkpoint_retain_stats) - - @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - def test_get_device_states_recursive(self): - inp = { - "foo": torch.rand(10, device=f"{device_type}:0"), - "bar": [torch.rand(10, device=f"{device_type}:1")], - } - device_ids, device_states = get_device_states(inp) - self.assertEqual(2, len(device_ids)) - self.assertEqual(2, len(device_states)) - self.assertEqual(0, device_ids[0]) - self.assertEqual(1, device_ids[1]) - self.assertTrue(isinstance(device_states[0], torch.Tensor)) - self.assertTrue(isinstance(device_states[1], torch.Tensor)) - - def test_infer_device_state_recursive_meta(self): - inp = {"foo": torch.rand(10, device="meta")} - device_type = _infer_device_type(inp) - self.assertEqual("meta", device_type) - - @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - def test_infer_device_state_recursive_multi_gpu(self): - # Check that no warning is issued for either gpu:0, gpu:1 or - # gpu:0, gpu:0 cases since they are both the same device type - global device_type - inp = { - "foo": torch.rand(10, device=f"{device_type}:0"), - "bar": [torch.rand(10, device=f"{device_type}:1")], - } - with warnings.catch_warnings(): - warnings.simplefilter("error") - _device_type = _infer_device_type(inp) - self.assertEqual(device_type, _device_type) - inp = { - "foo": torch.rand(10, device=f"{device_type}:0"), - "bar": [torch.rand(10, device=f"{device_type}:0")], - } - with warnings.catch_warnings(): - warnings.simplefilter("error") - _device_type = _infer_device_type(inp) - self.assertEqual(device_type, _device_type) - # Check that a warning is issued for gpu:0, meta and that it includes - # device type information - inp = { - "foo": torch.rand(10, device=f"{device_type}:0"), - "bar": [torch.rand(10, device="meta")], - } - with warnings.catch_warnings(record=True) as w: - _device_type = _infer_device_type(inp) - self.assertEqual(device_type, _device_type) - self.assertEqual(len(w), 1) - warning_msg = str(w[-1].message) - self.assertTrue( - "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices" - in warning_msg - ) - self.assertTrue(f"Device types: ['{device_type}', 'meta']" in warning_msg) - self.assertTrue(f"first device type: {device_type}" in warning_msg) - - -class TestDataLoaderUtils(TestCase): - MAX_TIMEOUT_IN_SECOND = 300 - - def test_random_seed(self): - def run(): - dataloader = torch.utils.data.DataLoader( - RandomDatasetMock(), - batch_size=2, - num_workers=4, - shuffle=True, - timeout=self.MAX_TIMEOUT_IN_SECOND, - 
) - return next(iter(dataloader)) - - torch.manual_seed(2018) - x1 = run() - torch.manual_seed(2018) - x2 = run() - self.assertEqual(x1, x2) - - def test_single_keep(self): - # torch.rand(5, 3, 3, 2) is a Tensor here; technically not a valid input because - # not a Dataset subclass, but needs to stay working so add ignore's - # for type checking with mypy - dataloader: DataLoader = DataLoader( - torch.rand(5, 3, 3, 2), # type: ignore[arg-type] - batch_size=3, - num_workers=0, - drop_last=False, - ) - dataiter = iter(dataloader) - self.assertEqual(len(list(dataiter)), 2) - - def test_single_drop(self): - dataloader: DataLoader = DataLoader( - torch.rand(5, 3, 3, 2), # type: ignore[arg-type] - batch_size=3, - num_workers=0, - drop_last=True, - ) - dataiter = iter(dataloader) - self.assertEqual(len(list(dataiter)), 1) - - @unittest.skip( - "FIXME: Intermittent GPU out-of-memory error on Windows and time-out under ASAN" - ) - def test_multi_keep(self): - dataloader: DataLoader = DataLoader( - torch.rand(5, 3, 3, 2), # type: ignore[arg-type] - batch_size=3, - num_workers=2, - drop_last=False, - timeout=self.MAX_TIMEOUT_IN_SECOND, - ) - dataiter = iter(dataloader) - self.assertEqual(len(list(dataiter)), 2) - - def test_multi_drop(self): - dataloader: DataLoader = DataLoader( - torch.rand(5, 3, 3, 2), # type: ignore[arg-type] - batch_size=3, - num_workers=2, - drop_last=True, - timeout=self.MAX_TIMEOUT_IN_SECOND, - ) - dataiter = iter(dataloader) - self.assertEqual(len(list(dataiter)), 1) - - -test_dir = os.path.abspath(os.path.dirname(str(__file__))) - - -@unittest.skipIf(IS_FBCODE, "runs pip which is not available internally") -class TestCollectEnv(TestCase): - def test_smoke(self): - info_output = get_pretty_env_info() - self.assertTrue(info_output.count("\n") >= 17) - - -class TestHipify(TestCase): - def test_import_hipify(self): - from torch.utils.hipify import hipify_python # noqa: F401 - - -class TestHipifyTrie(TestCase): - def setUp(self): - from torch.utils.hipify import hipify_python - - self.trie = hipify_python.Trie() - - def test_add_and_search_trie(self): - self.trie.add("banana") - self.assertTrue(self.trie.search("banana")) - self.assertFalse(self.trie.search("ban")) - self.assertFalse(self.trie.search("dog")) - - def test_add_multiple_and_search_trie(self): - words_to_add = ["banana", "apple", "orange"] - for word in words_to_add: - self.trie.add(word) - - for word in words_to_add: - self.assertTrue(self.trie.search(word)) - - for word in ["ban", "dog", "okay", "app"]: - self.assertFalse(self.trie.search(word)) - - def test_quote_escape(self): - orig_chars = ["*", "[", ".", "+", "a", "z", "-"] - quoted_strs = ["\\*", "\\[", "\\.", "\\+", "a", "z", "\\-"] - for i in range(len(orig_chars)): - self.assertEqual(self.trie.quote(orig_chars[i]), quoted_strs[i]) - - @unittest.skipIf(HAS_XPU, "XPU not supported hipify") - def test_export_trie_to_regex(self): - words_to_add = [ - "__CUDACC__", - "CUDA_ERROR_CONTEXT_ALREADY_CURRENT", - "CUDA_ERROR_ARRAY_IS_MAPPED", - "CUDA_ERROR_NOT_MAPPED", - "CUDA_ERROR_INVALID_SOURCE", - ] - for word in words_to_add: - self.trie.add(word) - regex = self.trie.export_to_regex() - expected_regex = r"(?:CUDA_ERROR_(?:ARRAY_IS_MAPPED|CONTEXT_ALREADY_CURRENT|INVALID_SOURCE|NOT_MAPPED)|__CUDACC__)" - self.assertEqual(regex, expected_regex) - - def test_prefix_words_export_trie_to_regex(self): - # test case where some nodes have both children and are also leaf nodes. 
- words_to_add = ["apple", "app", "ban", "banana"] - for word in words_to_add: - self.trie.add(word) - regex = self.trie.export_to_regex() - expected_regex = r"(?:app(?:le)?|ban(?:ana)?)" - self.assertEqual(regex, expected_regex) - - @unittest.skipIf(HAS_XPU, "XPU not supported hipify") - def test_single_export_trie_to_regex(self): - words_to_add = ["cudaErrorInvalidMemcpyDirection"] - for word in words_to_add: - self.trie.add(word) - regex = self.trie.export_to_regex() - expected_regex = "cudaErrorInvalidMemcpyDirection" - self.assertEqual(regex, expected_regex) - - def test_char_export_trie_to_regex(self): - self.trie.add("a") - self.assertEqual(self.trie.export_to_regex(), "a") - self.trie.add("b") - self.assertEqual(self.trie.export_to_regex(), "[ab]") - - def test_special_char_export_trie_to_regex(self): - self.trie.add(r"c*") - self.assertEqual(self.trie.export_to_regex(), r"c\*") - - -class TestAssert(TestCase): - def test_assert_true(self): - # verify assertions work as expected - # bool argument - torch._assert(True, "foo") - with self.assertRaisesRegex(AssertionError, "bar"): - torch._assert(False, "bar") - # tensor argument - torch._assert(torch.tensor([True], dtype=torch.bool), "foo") - with self.assertRaisesRegex(AssertionError, "bar"): - torch._assert(torch.tensor([False], dtype=torch.bool), "bar") - - def test_assert_scriptable(self): - class M(torch.nn.Module): - def forward(self, x): - torch._assert(x.sum() > 0, "foo") - return x - - m = M() - # scriptable - ms = torch.jit.script(m) - # data can be passed without errors - x = torch.randn(4, 4).fill_(1.0) - ms(x) - with self.assertRaisesRegex(torch.jit.Error, "foo"): - ms(torch.tensor([False], dtype=torch.bool)) - - -@unittest.skipIf(IS_SANDCASTLE, "cpp_extension is OSS only") -class TestStandaloneCPPJIT(TestCase): - def test_load_standalone(self): - build_dir = tempfile.mkdtemp() - try: - src_path = os.path.join(build_dir, "main.cpp") - src = textwrap.dedent( - """\ - #include - #include - int main() { - auto x = torch::eye(3); - std::cout << x << std::endl; - } - """ - ) - with open(src_path, "w") as f: - f.write(src) - - exec_path = torch.utils.cpp_extension.load( - "standalone_load_test", - src_path, - build_directory=build_dir, - is_python_module=False, - is_standalone=True, - ) - - ext = ".exe" if IS_WINDOWS else "" - self.assertEqual( - exec_path, os.path.join(build_dir, f"standalone_load_test{ext}") - ) - - for shell in [True, False]: - r = subprocess.run( - [exec_path], - shell=shell, - stdout=subprocess.PIPE, - ) - self.assertEqual(r.returncode, 0) - self.assertEqual( - # Windows prints "\r\n" for newlines. 
- textwrap.dedent(r.stdout.decode("utf-8")).replace("\r\n", "\n"), - textwrap.dedent( - """\ - 1 0 0 - 0 1 0 - 0 0 1 - [ CPUFloatType{3,3} ] - """ - ), - ) - - finally: - shutil.rmtree(build_dir) - - -class TestRenderUtils(TestCase): - def test_basic(self): - self.assertExpectedInline( - torch._utils.render_call(torch.sum, [torch.randn(100)], {"dim": 0}), - """torch.sum(tensor([...], size=(100,)), dim=0)""", - ) - self.assertExpectedInline( - torch._utils.render_call(torch.sum, [torch.randn(100, 100)], {"dim": 0}), - """torch.sum(tensor([...], size=(100, 100)), dim=0)""", - ) - - -class TestDeviceUtils(TestCase): - def test_basic(self): - with torch.device("meta") as dev: - x = torch.empty(3, 3) - self.assertEqual(x.device.type, "meta") - self.assertEqual(dev, torch.device("meta")) - - def test_decorator(self): - @set_device("meta") - def f(): - return torch.empty(3, 3) - - self.assertEqual(f().device.type, "meta") - - def test_decorator_generator(self): - @set_device("meta") - def f(): - yield torch.empty(3, 3) - yield torch.empty(3, 3) - - r1, r2 = list(f()) - self.assertEqual(r1.device.type, "meta") - self.assertEqual(r2.device.type, "meta") - - def test_nn_module(self): - with torch.device("meta"): - m = nn.Linear(40, 50) - self.assertEqual(m.weight.device.type, "meta") - - def test_set_default_device(self): - try: - torch.set_default_device("meta") - r = torch.empty(2, 2) - finally: - torch.set_default_device(None) - - self.assertEqual(r.device.type, "meta") - - def test_get_default_device(self): - torch.set_default_device("meta") - self.assertEqual(torch.get_default_device().type, "meta") - torch.set_default_device(None) - - @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported") - def test_get_default_device_more(self): - try: - torch.set_default_device(device_type) - self.assertEqual(torch.get_default_device(), torch.tensor([]).device) - torch.set_default_device(None) - - torch.set_default_device(device_type) - torch.get_device_module(device_type).set_device(f"{device_type}:1") - self.assertEqual(torch.get_default_device(), torch.tensor([]).device) - torch.set_default_device(None) - - torch.set_default_device(f"{device_type}:1") - self.assertEqual(torch.get_default_device(), torch.tensor([]).device) - torch.set_default_device(None) - - torch.set_default_device(f"{device_type}:1") - with torch.device(f"{device_type}:0"): - self.assertEqual( - torch.get_default_device(), torch.device(f"{device_type}", 0) - ) - - torch.set_default_device("cpu") - self.assertEqual(torch.get_default_device(), torch.device("cpu")) - with torch.device(f"{device_type}:0"): - self.assertEqual( - torch.get_default_device(), torch.device(f"{device_type}", 0) - ) - - self.assertEqual(torch.get_default_device(), torch.device("cpu")) - finally: - # Reset the device at the end. - torch.set_default_device(None) - - @onlyCPU - @ops(op_db) - def test_device_mode_ops(self, device, dtype, op): - func = op.get_op() - samples = op.sample_inputs(device, dtype, requires_grad=False) - for sample in samples: - # Only test samples which don't have Tensor inputs. However, - # we don't test the factory property on OpInfo as it is very, - # very incomplete - if tree_any( - lambda x: isinstance(x, torch.Tensor), - (sample.input, sample.args, sample.kwargs), - ): - continue - # Many OpInfos will explicitly pass in a device. DeviceContext - # will respect device if it is explicitly specified. To test - # DeviceContext, we have to remove the device kwarg in this case. 
- # NB: Can't pass None to sample_inputs, the function can't - # handle it. - kwargs = sample.kwargs.copy() - kwargs.pop("device", None) - with torch.device("meta"): - r = func(sample.input, *sample.args, **kwargs) - - def is_meta_device(x: torch.Tensor) -> bool: - return x.device.type == "meta" - - self.assertTrue(tree_all_only(torch.Tensor, is_meta_device, r)) - - -instantiate_device_type_tests(TestDeviceUtils, globals()) - - -class TestCppExtensionUtils(TestCase): - def test_cpp_compiler_is_ok(self): - self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("c++")) - - def test_cc_compiler_is_ok(self): - self.assertTrue(torch.utils.cpp_extension.check_compiler_ok_for_platform("cc")) - - -class TestTraceback(TestCase): - def test_basic(self): - source = """\ -def f(x): - def g(x): - raise RuntimeError # HEYA - - x = x * 3 - return g(x) + 1 -""" - - out: dict[str, Any] = {} - scope = {"__compile_source__": source} - exec(source, scope, out) - - try: - with report_compile_source_on_error(): - out["f"](1) - except RuntimeError as e: - self.assertIn("HEYA", "".join(traceback.format_tb(e.__traceback__))) - - def test_format_traceback_short(self): - try: - raise RuntimeError - except RuntimeError as e: - self.assertRegex( - format_traceback_short(e.__traceback__), - r".*test_utils.py:\d+ in test_format_traceback_short", - ) - - def test_captured_traceback(self): - self.assertIn( - "test_captured_traceback", "".join(CapturedTraceback.extract().format()) - ) - - def test_captured_traceback_format_all(self): - rs = CapturedTraceback.format_all( - [CapturedTraceback.extract(), CapturedTraceback.extract()] - ) - self.assertEqual(len(rs), 2) - self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) - - def test_captured_traceback_format_all_cached(self): - tb = CapturedTraceback.extract() - tb.format() # cached - rs = CapturedTraceback.format_all([tb, CapturedTraceback.extract()]) - self.assertEqual(len(rs), 2) - self.assertIn("test_captured_traceback_format_all", "".join(rs[0])) - - -class TestTryImport(TestCase): - def test_import_imported(self): - self.assertIn("os", sys.modules) - os_module = try_import("os") - self.assertIs(os_module, os) - - def test_import_existing(self): - self.assertNotIn("imaplib", sys.modules) - imaplib_module = try_import("imaplib") - self.assertIsNotNone(imaplib_module) - self.assertFalse(hasattr(imaplib_module, "not_attribute")) - self.assertTrue(hasattr(imaplib_module, "IMAP4")) - - def test_import_missing(self): - missing_module = try_import("missing_module") - self.assertIsNone(missing_module) - - -@deprecated() -def _deprecated_api(x, y=15): - return x + y - - -class TestDeprecate(TestCase): - def test_deprecated(self): - with self.assertWarnsRegex(Warning, "is DEPRECATED"): - deprecated_api(1, 2) # noqa: F821 - with self.assertWarnsRegex(Warning, "is DEPRECATED"): - deprecated_api(1, y=2) # noqa: F821 - _deprecated_api(1, 2) - _deprecated_api(1, y=2) - - -if __name__ == "__main__": - run_tests()
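
A minimal standalone sketch (illustration only, not part of the patch) of the RNG-preservation behaviour the checkpoint code and the deleted test_checkpoint_rng_* cases rely on: the recompute path stashes the RNG state so that recomputation sees the same random numbers as the original forward. The helper name run_twice_with_stashed_rng is made up for this sketch; only public torch APIs are used.

    import torch

    def run_twice_with_stashed_rng(fn, x):
        # Stash the CPU RNG state, run once, restore the state, and run again.
        # This mirrors what the checkpoint recompute path does via
        # fork_rng/set_rng_state when preserve_rng_state=True: both
        # executions observe identical random numbers.
        state = torch.get_rng_state()
        first = fn(x)
        torch.set_rng_state(state)
        second = fn(x)
        return first, second

    dropout = torch.nn.Dropout(p=0.5)   # random op, as in the dropout-based tests
    x = torch.randn(16)
    a, b = run_twice_with_stashed_rng(dropout, x)
    assert torch.equal(a, b)            # identical dropout mask both times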
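
Likewise, a short sketch (again not part of the patch) of the equivalence the deleted TestCheckpoint cases assert: checkpoint_sequential is expected to reproduce the output and input gradient of the uncheckpointed nn.Sequential run. The sketch imports the public torch.utils.checkpoint.checkpoint_sequential rather than the local test/xpu/checkpoint.py copy used by the tests, and only compares input gradients, not parameter gradients.

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint_sequential

    model = nn.Sequential(nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, 20), nn.ReLU())
    inp = torch.randn(1, 100, requires_grad=True)

    # Reference: plain forward/backward.
    ref_out = model(inp)
    ref_out.sum().backward()
    ref_grad = inp.grad.detach().clone()

    # Checkpointed: same modules split into two chunks; output and input
    # gradient should match the reference run.
    inp.grad = None
    ckpt_out = checkpoint_sequential(list(model.children()), 2, inp, use_reentrant=False)
    ckpt_out.sum().backward()

    assert torch.allclose(ref_out, ckpt_out)
    assert torch.allclose(ref_grad, inp.grad)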