From 7dad1731478db931b30f0ba8e6b750a5c7544d05 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 3 Nov 2025 11:35:20 +0530
Subject: [PATCH] error early in auto_cpu_offload

---
 .../modular_pipelines/components_manager.py | 26 +++++++++++++++++++-------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py
index cb7e8fb73697..e16abb382313 100644
--- a/src/diffusers/modular_pipelines/components_manager.py
+++ b/src/diffusers/modular_pipelines/components_manager.py
@@ -160,7 +160,10 @@ def __call__(self, hooks, model_id, model, execution_device):
         if len(hooks) == 0:
             return []
 
-        current_module_size = model.get_memory_footprint()
+        try:
+            current_module_size = model.get_memory_footprint()
+        except AttributeError:
+            raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}`.")
 
         device_type = execution_device.type
         device_module = getattr(torch, device_type, torch.cuda)
@@ -703,7 +706,20 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
         if not is_accelerate_available():
             raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
 
-        # TODO: add a warning if mem_get_info isn't available on `device`.
+        if device is None:
+            device = get_device()
+        if not isinstance(device, torch.device):
+            device = torch.device(device)
+
+        device_type = device.type
+        device_module = getattr(torch, device_type, torch.cuda)
+        if not hasattr(device_module, "mem_get_info"):
+            raise NotImplementedError(
+                f"`enable_auto_cpu_offload()` relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
+            )
+
+        if device.index is None:
+            device = torch.device(f"{device.type}:{0}")
 
         for name, component in self.components.items():
             if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
@@ -711,11 +727,7 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = None,
                 self.disable_auto_cpu_offload()
 
         offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
-        if device is None:
-            device = get_device()
-        device = torch.device(device)
-        if device.index is None:
-            device = torch.device(f"{device.type}:{0}")
+
         all_hooks = []
         for name, component in self.components.items():
             if isinstance(component, torch.nn.Module):