From 6e33ee391abfa8c118688ba4f803a47d7e851132 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 16:45:08 +0800 Subject: [PATCH 01/29] debug error --- comfy/model_management.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index d82d5b8b00ae..e0a09776123b 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -509,12 +509,29 @@ def should_reload_model(self, force_patch_weights=False): return False def model_unload(self, memory_to_free=None, unpatch_weights=True): - if memory_to_free is not None: - if memory_to_free < self.model.loaded_size(): - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - if freed >= memory_to_free: - return False - self.model.detach(unpatch_weights) + logging.info(f"model_unload: {self.model.model.__class__.__name__}") + logging.info(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") + logging.info(f"unpatch_weights: {unpatch_weights}") + logging.info(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") + logging.info(f"offload_device: {self.model.offload_device}") + available_memory = get_free_memory(self.model.offload_device) + logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + try: + if memory_to_free is not None: + if memory_to_free < self.model.loaded_size(): + logging.info("Do partially unload") + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + logging.info(f"partially_unload freed: {freed/(1024*1024*1024)} GB") + if freed >= memory_to_free: + return False + logging.info("Do full unload") + self.model.detach(unpatch_weights) + logging.info("Do full unload done") + except Exception as e: + logging.error(f"Error in model_unload: {e}") + available_memory = get_free_memory(self.model.offload_device) + logging.info(f"after error, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + return False self.model_finalizer.detach() self.model_finalizer = None self.real_model = None @@ -567,6 +584,7 @@ def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() def free_memory(memory_required, device, keep_loaded=[]): + logging.info("start to free mem") cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -587,7 +605,7 @@ def free_memory(memory_required, device, keep_loaded=[]): if free_mem > memory_required: break memory_to_free = memory_required - free_mem - logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") + logging.info(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") if current_loaded_models[i].model_unload(memory_to_free): unloaded_model.append(i) @@ -604,6 +622,7 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): + logging.info(f"start to load models") cleanup_models_gc() global vram_state @@ -625,6 +644,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] for x in models: + logging.info(f"loading model: {x.model.__class__.__name__}") loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) From fa19dd46200e5708f0e17e24622939257bfcffca Mon Sep 17 00:00:00 2001 From: 
strint Date: Thu, 16 Oct 2025 17:00:47 +0800 Subject: [PATCH 02/29] debug offload --- comfy/model_management.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index e0a09776123b..840239a272b9 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -516,6 +516,9 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): logging.info(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + if available_memory < memory_to_free: + logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Required: {memory_to_free/(1024*1024*1024)} GB") + return False try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): From f40e00cb357754ae99a2eac59d8fbfae2f23607a Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 19:38:13 +0800 Subject: [PATCH 03/29] add detail debug --- execution.py | 26 ++++++++++++++++++++++++++ server.py | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/execution.py b/execution.py index 1dc35738b823..69bd53502966 100644 --- a/execution.py +++ b/execution.py @@ -400,7 +400,12 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] + + # Log node execution start + logging.info(f"๐Ÿ“ Node [{display_node_id}] START: {class_type}") + if caches.outputs.get(unique_id) is not None: + logging.info(f"โœ… Node [{display_node_id}] CACHED: {class_type} (using cached output)") if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) @@ -446,15 +451,20 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) + logging.info(f"๐Ÿ” Node [{display_node_id}] Getting input data for {class_type}") input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + logging.info(f"๐Ÿ“ฅ Node [{display_node_id}] Input data ready, keys: {list(input_data_all.keys())}") if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) obj = caches.objects.get(unique_id) if obj is None: + logging.info(f"๐Ÿ—๏ธ Node [{display_node_id}] Creating new instance of {class_type}") obj = class_def() caches.objects.set(unique_id, obj) + else: + logging.info(f"โ™ป๏ธ Node [{display_node_id}] Reusing cached instance of {class_type}") if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -493,7 +503,9 @@ def execution_block_cb(block): def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? 
GraphBuilder.set_default_prefix(unique_id, call_index, 0) + logging.info(f"โš™๏ธ Node [{display_node_id}] Executing {class_type}") output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) + logging.info(f"๐Ÿ“ค Node [{display_node_id}] Execution completed, has_subgraph: {has_subgraph}, has_pending: {has_pending_tasks}") if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -572,6 +584,7 @@ async def await_completion(): for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] + logging.error(f"โŒ Node [{display_node_id}] FAILED: {class_type}") logging.error(f"!!! Exception during processing !!! {ex}") logging.error(traceback.format_exc()) tips = "" @@ -593,6 +606,8 @@ async def await_completion(): get_progress_state().finish_progress(unique_id) executed.add(unique_id) + + logging.info(f"โœ… Node [{display_node_id}] SUCCESS: {class_type} completed") return (ExecutionResult.SUCCESS, None, None) @@ -649,6 +664,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): + logging.info(f"๐Ÿš€ Workflow execution START: prompt_id={prompt_id}, nodes_count={len(prompt)}") nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -672,6 +688,9 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= for node_id in prompt: if self.caches.outputs.get(node_id) is not None: cached_nodes.append(node_id) + + if len(cached_nodes) > 0: + logging.info(f"๐Ÿ’พ Workflow has {len(cached_nodes)} cached nodes: {cached_nodes}") comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", @@ -684,6 +703,8 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): execution_list.add_node(node_id) + + logging.info(f"๐Ÿ“‹ Workflow execution list prepared, executing {len(execute_outputs)} output nodes") while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() @@ -695,6 +716,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: + logging.error(f"๐Ÿ’ฅ Workflow execution FAILED at node {node_id}") self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break elif result == ExecutionResult.PENDING: @@ -703,6 +725,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= execution_list.complete_node_execution() else: # Only execute when the while-loop ends without break + logging.info(f"๐ŸŽ‰ Workflow execution SUCCESS: prompt_id={prompt_id}, executed_nodes={len(executed)}") self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} @@ -719,7 +742,10 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= } 
self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: + logging.info("๐Ÿงน Unloading all models (DISABLE_SMART_MEMORY is enabled)") comfy.model_management.unload_all_models() + + logging.info(f"โœจ Workflow execution COMPLETED: prompt_id={prompt_id}") async def validate_inputs(prompt_id, prompt, item, validated): diff --git a/server.py b/server.py index 80e9d3fa78a0..515307bf6de6 100644 --- a/server.py +++ b/server.py @@ -673,7 +673,7 @@ async def get_queue(request): @routes.post("/prompt") async def post_prompt(request): - logging.info("got prompt") + logging.info("got prompt in debug comfyui") json_data = await request.json() json_data = self.trigger_on_prompt(json_data) From 2b222962c3f9168d2333a73ea2dd525880ec215c Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 21:42:02 +0800 Subject: [PATCH 04/29] add debug log --- comfy/model_base.py | 3 +++ comfy/sd.py | 2 ++ nodes.py | 2 ++ 3 files changed, 7 insertions(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 8274c7dea192..7dead016704b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -297,8 +297,11 @@ def load_model_weights(self, sd, unet_prefix=""): if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) + logging.info(f"load model weights start, keys {keys}") to_load = self.model_config.process_unet_state_dict(to_load) + logging.info(f"load model {self.model_config} weights process end, keys {keys}") m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + logging.info(f"load model {self.model_config} weights end, keys {keys}") if len(m) > 0: logging.warning("unet missing: {}".format(m)) diff --git a/comfy/sd.py b/comfy/sd.py index 28bee248dae1..16d54f08b7ac 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1347,7 +1347,9 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): sd = comfy.utils.load_torch_file(unet_path) + logging.info(f"load model start, path {unet_path}") model = load_diffusion_model_state_dict(sd, model_options=model_options) + logging.info(f"load model end, path {unet_path}") if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/nodes.py b/nodes.py index 7cfa8ca1411d..25ccc9e421bc 100644 --- a/nodes.py +++ b/nodes.py @@ -922,7 +922,9 @@ def load_unet(self, unet_name, weight_dtype): model_options["dtype"] = torch.float8_e5m2 unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) + logging.info(f"load unet node start, path {unet_path}") model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) + logging.info(f"load unet node end, path {unet_path}") return (model,) class CLIPLoader: From c1eac555c011f05ff4a3393ce0c86964314ccc18 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 21:42:48 +0800 Subject: [PATCH 05/29] add debug log --- comfy/model_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 7dead016704b..6c8ee69b47d3 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -291,6 +291,7 @@ def extra_conds(self, **kwargs): return out def load_model_weights(self, sd, unet_prefix=""): + import pdb; pdb.set_trace() to_load = {} keys = list(sd.keys()) for k in keys: From 9352987e9bc625dd5b4f1acdbf059ad5c2382172 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:25:17 +0800 
Subject: [PATCH 06/29] add log --- comfy/model_base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 6c8ee69b47d3..75d469221b8b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -60,6 +60,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher +from comfy.model_management import get_free_memory class ModelType(Enum): EPS = 1 @@ -291,18 +292,19 @@ def extra_conds(self, **kwargs): return out def load_model_weights(self, sd, unet_prefix=""): - import pdb; pdb.set_trace() to_load = {} keys = list(sd.keys()) for k in keys: if k.startswith(unet_prefix): to_load[k[len(unet_prefix):]] = sd.pop(k) - logging.info(f"load model weights start, keys {keys}") + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") to_load = self.model_config.process_unet_state_dict(to_load) - logging.info(f"load model {self.model_config} weights process end, keys {keys}") + logging.info(f"load model {self.model_config} weights process end") m, u = self.diffusion_model.load_state_dict(to_load, strict=False) - logging.info(f"load model {self.model_config} weights end, keys {keys}") + free_cpu_memory = get_free_memory(torch.device("cpu")) + logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: logging.warning("unet missing: {}".format(m)) From a207301c25e7fd83723152fc343a5ac49f983f4d Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:28:06 +0800 Subject: [PATCH 07/29] rm useless log --- execution.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/execution.py b/execution.py index 69bd53502966..c3a4cc5faff1 100644 --- a/execution.py +++ b/execution.py @@ -401,11 +401,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - # Log node execution start - logging.info(f"๐Ÿ“ Node [{display_node_id}] START: {class_type}") if caches.outputs.get(unique_id) is not None: - logging.info(f"โœ… Node [{display_node_id}] CACHED: {class_type} (using cached output)") if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id) @@ -451,20 +448,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - logging.info(f"๐Ÿ” Node [{display_node_id}] Getting input data for {class_type}") input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) - logging.info(f"๐Ÿ“ฅ Node [{display_node_id}] Input data ready, keys: {list(input_data_all.keys())}") if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) obj = caches.objects.get(unique_id) if obj is None: - logging.info(f"๐Ÿ—๏ธ Node [{display_node_id}] Creating new instance of {class_type}") obj = class_def() caches.objects.set(unique_id, obj) - else: - logging.info(f"โ™ป๏ธ Node [{display_node_id}] Reusing cached 
instance of {class_type}") if issubclass(class_def, _ComfyNodeInternal): lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None @@ -503,9 +495,7 @@ def execution_block_cb(block): def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? GraphBuilder.set_default_prefix(unique_id, call_index, 0) - logging.info(f"โš™๏ธ Node [{display_node_id}] Executing {class_type}") output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) - logging.info(f"๐Ÿ“ค Node [{display_node_id}] Execution completed, has_subgraph: {has_subgraph}, has_pending: {has_pending_tasks}") if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -584,7 +574,6 @@ async def await_completion(): for name, inputs in input_data_all.items(): input_data_formatted[name] = [format_value(x) for x in inputs] - logging.error(f"โŒ Node [{display_node_id}] FAILED: {class_type}") logging.error(f"!!! Exception during processing !!! {ex}") logging.error(traceback.format_exc()) tips = "" @@ -607,7 +596,6 @@ async def await_completion(): get_progress_state().finish_progress(unique_id) executed.add(unique_id) - logging.info(f"โœ… Node [{display_node_id}] SUCCESS: {class_type} completed") return (ExecutionResult.SUCCESS, None, None) From 71b23d12e45e39fb2e94da510b823831e9a7b151 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:34:55 +0800 Subject: [PATCH 08/29] rm useless log --- execution.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/execution.py b/execution.py index c3a4cc5faff1..53f2953572c5 100644 --- a/execution.py +++ b/execution.py @@ -652,7 +652,6 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): - logging.info(f"๐Ÿš€ Workflow execution START: prompt_id={prompt_id}, nodes_count={len(prompt)}") nodes.interrupt_processing(False) if "client_id" in extra_data: @@ -676,9 +675,6 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= for node_id in prompt: if self.caches.outputs.get(node_id) is not None: cached_nodes.append(node_id) - - if len(cached_nodes) > 0: - logging.info(f"๐Ÿ’พ Workflow has {len(cached_nodes)} cached nodes: {cached_nodes}") comfy.model_management.cleanup_models_gc() self.add_message("execution_cached", @@ -691,8 +687,6 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= current_outputs = self.caches.outputs.all_node_ids() for node_id in list(execute_outputs): execution_list.add_node(node_id) - - logging.info(f"๐Ÿ“‹ Workflow execution list prepared, executing {len(execute_outputs)} output nodes") while not execution_list.is_empty(): node_id, error, ex = await execution_list.stage_node_execution() @@ -704,7 +698,6 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: - logging.error(f"๐Ÿ’ฅ Workflow execution FAILED at node 
{node_id}") self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break elif result == ExecutionResult.PENDING: @@ -713,7 +706,6 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= execution_list.complete_node_execution() else: # Only execute when the while-loop ends without break - logging.info(f"๐ŸŽ‰ Workflow execution SUCCESS: prompt_id={prompt_id}, executed_nodes={len(executed)}") self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False) ui_outputs = {} @@ -730,10 +722,7 @@ async def execute_async(self, prompt, prompt_id, extra_data={}, execute_outputs= } self.server.last_node_id = None if comfy.model_management.DISABLE_SMART_MEMORY: - logging.info("๐Ÿงน Unloading all models (DISABLE_SMART_MEMORY is enabled)") comfy.model_management.unload_all_models() - - logging.info(f"โœจ Workflow execution COMPLETED: prompt_id={prompt_id}") async def validate_inputs(prompt_id, prompt, item, validated): From e5ff6a1b53211ce3130cc0de071ce137714e03a4 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 16 Oct 2025 22:47:03 +0800 Subject: [PATCH 09/29] refine log --- comfy/model_base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_base.py b/comfy/model_base.py index 75d469221b8b..b0bb0cfb047f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -300,6 +300,7 @@ def load_model_weights(self, sd, unet_prefix=""): free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) logging.info(f"load model {self.model_config} weights process end") m, u = self.diffusion_model.load_state_dict(to_load, strict=False) From 5c3c6c02b237b3728348f90567b7236cfc45b8b7 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 17 Oct 2025 16:33:14 +0800 Subject: [PATCH 10/29] add debug log of cpu load --- .../ldm/modules/diffusionmodules/openaimodel.py | 12 ++++++++++++ comfy/model_base.py | 17 +++++++++++++++++ comfy/model_patcher.py | 5 +++++ comfy/sd.py | 1 + 4 files changed, 35 insertions(+) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4c8d53cac9c2..ff6e96a3cdc2 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -911,3 +911,15 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf return self.id_predictor(h) else: return self.out(h) + + + def load_state_dict(self, state_dict, strict=True): + """Override load_state_dict() to add logging""" + logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") + + # Call parent's load_state_dict method + result = super().load_state_dict(state_dict, strict=strict) + + logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") + + return result diff --git a/comfy/model_base.py b/comfy/model_base.py index b0bb0cfb047f..7d474a76a1fc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -303,6 +303,8 @@ def load_model_weights(self, sd, unet_prefix=""): logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) logging.info(f"load model {self.model_config} 
weights process end") + # TODO(sf): to mmap + # diffusion_model is UNetModel m, u = self.diffusion_model.load_state_dict(to_load, strict=False) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") @@ -384,6 +386,21 @@ def memory_required(self, input_shape, cond_shapes={}): #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) + + def to(self, *args, **kwargs): + """Override to() to add custom device management logic""" + old_device = self.device if hasattr(self, 'device') else None + + result = super().to(*args, **kwargs) + + if len(args) > 0: + if isinstance(args[0], (torch.device, str)): + new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0] + if 'device' in kwargs: + new_device = kwargs['device'] + + logging.info(f"BaseModel moved from {old_device} to {new_device}") + return result def extra_conds_shapes(self, **kwargs): return {} diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index c0b68fb8cff7..ea91bd2c5613 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -486,6 +486,7 @@ def get_model_object(self, name: str) -> torch.nn.Module: return comfy.utils.get_attr(self.model, name) def model_patches_to(self, device): + # TODO(sf): to mmap to = self.model_options["transformer_options"] if "patches" in to: patches = to["patches"] @@ -783,6 +784,8 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.backup.clear() if device_to is not None: + # TODO(sf): to mmap + # self.model is what module? self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 @@ -837,6 +840,8 @@ def partially_unload(self, device_to, memory_to_free=0): bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights + # TODO(sf): to mmap + # m is what module? 
m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: diff --git a/comfy/sd.py b/comfy/sd.py index 16d54f08b7ac..89a1f30b89a0 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1321,6 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() + logging.info(f"loader load model to offload device: {offload_device}") unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: weight_dtype = None From 6583cc0142466473922a59d2e646881693cff011 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 17 Oct 2025 18:28:25 +0800 Subject: [PATCH 11/29] debug load mem --- comfy/ldm/flux/model.py | 13 +++++++++++++ comfy/ldm/modules/diffusionmodules/openaimodel.py | 1 + comfy/model_base.py | 2 ++ comfy/sd.py | 4 ++++ comfy/utils.py | 6 ++++++ 5 files changed, 26 insertions(+) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 14f90cea55e6..263cdae26054 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -7,6 +7,7 @@ from einops import rearrange, repeat import comfy.ldm.common_dit import comfy.patcher_extension +import logging from .layers import ( DoubleStreamBlock, @@ -278,3 +279,15 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] + + def load_state_dict(self, state_dict, strict=True): + import pdb; pdb.set_trace() + """Override load_state_dict() to add logging""" + logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") + + # Call parent's load_state_dict method + result = super().load_state_dict(state_dict, strict=strict) + + logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") + + return result diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index ff6e96a3cdc2..e847700c6b13 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -914,6 +914,7 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf def load_state_dict(self, state_dict, strict=True): + import pdb; pdb.set_trace() """Override load_state_dict() to add logging""" logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") diff --git a/comfy/model_base.py b/comfy/model_base.py index 7d474a76a1fc..34dd160372e5 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -305,6 +305,8 @@ def load_model_weights(self, sd, unet_prefix=""): logging.info(f"load model {self.model_config} weights process end") # TODO(sf): to mmap # diffusion_model is UNetModel + import pdb; pdb.set_trace() + # TODO(sf): here needs to avoid load mmap into cpu mem m, u = self.diffusion_model.load_state_dict(to_load, strict=False) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") diff --git a/comfy/sd.py b/comfy/sd.py index 89a1f30b89a0..7005a1b53887 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ 
-1338,6 +1338,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") + import pdb; pdb.set_trace() model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() @@ -1347,10 +1348,13 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): + # TODO(sf): here load file into mem sd = comfy.utils.load_torch_file(unet_path) logging.info(f"load model start, path {unet_path}") + import pdb; pdb.set_trace() model = load_diffusion_model_state_dict(sd, model_options=model_options) logging.info(f"load model end, path {unet_path}") + import pdb; pdb.set_trace() if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/utils.py b/comfy/utils.py index 0fd03f165b7c..a664024515a5 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -55,11 +55,15 @@ def scalar(*args, **kwargs): logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): + # TODO(sf): here load file into mmap + logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}") if device is None: device = torch.device("cpu") metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: + if not DISABLE_MMAP: + logging.info(f"load_torch_file safetensors mmap True") with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: sd = {} for k in f.keys(): @@ -80,6 +84,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: torch_args = {} if MMAP_TORCH_FILES: + logging.info(f"load_torch_file mmap True") torch_args["mmap"] = True if safe_load or ALWAYS_SAFE_LOAD: @@ -97,6 +102,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd + import pdb; pdb.set_trace() return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): From 49597bfa3e36c78635db4234106611846fbc4117 Mon Sep 17 00:00:00 2001 From: strint Date: Fri, 17 Oct 2025 21:43:49 +0800 Subject: [PATCH 12/29] load remains mmap --- comfy/ldm/flux/model.py | 6 +++--- comfy/ldm/modules/diffusionmodules/openaimodel.py | 6 +++--- comfy/model_base.py | 4 ++-- comfy/sd.py | 6 +++--- comfy/utils.py | 2 +- 5 files changed, 12 insertions(+), 12 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 263cdae26054..da46ed2ed389 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -280,13 +280,13 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] - def load_state_dict(self, state_dict, strict=True): - import pdb; pdb.set_trace() + def load_state_dict(self, state_dict, strict=True, assign=False): + # import pdb; pdb.set_trace() """Override load_state_dict() to add logging""" logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") # Call parent's load_state_dict method - result = 
super().load_state_dict(state_dict, strict=strict) + result = super().load_state_dict(state_dict, strict=strict, assign=assign) logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index e847700c6b13..2cdf711d4c7b 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -913,13 +913,13 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf return self.out(h) - def load_state_dict(self, state_dict, strict=True): - import pdb; pdb.set_trace() + def load_state_dict(self, state_dict, strict=True, assign=False): + # import pdb; pdb.set_trace() """Override load_state_dict() to add logging""" logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") # Call parent's load_state_dict method - result = super().load_state_dict(state_dict, strict=strict) + result = super().load_state_dict(state_dict, strict=strict, assign=assign) logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") diff --git a/comfy/model_base.py b/comfy/model_base.py index 34dd160372e5..409e7fb87969 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -305,9 +305,9 @@ def load_model_weights(self, sd, unet_prefix=""): logging.info(f"load model {self.model_config} weights process end") # TODO(sf): to mmap # diffusion_model is UNetModel - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() # TODO(sf): here needs to avoid load mmap into cpu mem - m, u = self.diffusion_model.load_state_dict(to_load, strict=False) + m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: diff --git a/comfy/sd.py b/comfy/sd.py index 7005a1b53887..a956884fb7a8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1338,7 +1338,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() @@ -1351,10 +1351,10 @@ def load_diffusion_model(unet_path, model_options={}): # TODO(sf): here load file into mem sd = comfy.utils.load_torch_file(unet_path) logging.info(f"load model start, path {unet_path}") - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() model = load_diffusion_model_state_dict(sd, model_options=model_options) logging.info(f"load model end, path {unet_path}") - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/utils.py b/comfy/utils.py index a664024515a5..4c22f684c44e 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -102,7 +102,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd - import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, 
metadata=None): From 21ebcada1da1466d2a3fe91c9e517156ed5172cf Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 20 Oct 2025 16:22:50 +0800 Subject: [PATCH 13/29] debug free mem --- comfy/model_management.py | 1 + 1 file changed, 1 insertion(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 840239a272b9..79f0434199be 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -588,6 +588,7 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[]): logging.info("start to free mem") + import pdb; pdb.set_trace() cleanup_models_gc() unloaded_model = [] can_unload = [] From 4ac827d56454838e051fb05b0047fea06359bcc7 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 20 Oct 2025 18:27:38 +0800 Subject: [PATCH 14/29] unload partial --- comfy/model_management.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 79f0434199be..30a509670eaa 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -516,9 +516,17 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): logging.info(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - if available_memory < memory_to_free: - logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Required: {memory_to_free/(1024*1024*1024)} GB") + reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage + if available_memory < reserved_memory: + logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") return False + else: + offload_memory = available_memory - reserved_memory + + if offload_memory < memory_to_free: + memory_to_free = offload_memory + logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") + logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): From e9e1d2f0e82af07b701a72e20c171625cdc1f402 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 00:40:14 +0800 Subject: [PATCH 15/29] add mmap tensor --- comfy/model_patcher.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index ea91bd2c5613..e4d8507d00aa 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,6 +27,7 @@ from typing import Callable, Optional import torch +import tensordict import comfy.float import comfy.hooks @@ -37,6 +38,9 @@ from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor: + return tensordict.MemoryMappedTensor.from_tensor(t) + def string_to_seed(data): crc = 0xFFFFFFFF @@ -784,9 +788,37 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.backup.clear() if device_to is not None: - # TODO(sf): to mmap - # self.model is what module? 
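# Illustrative aside (not part of the patch series): a minimal sketch of what the
# to_mmap() helper introduced above is expected to do, assuming the tensordict
# package these patches import. MemoryMappedTensor.from_tensor() copies the data
# into a file-backed, memory-mapped buffer, so the weight bytes no longer have to
# stay resident in anonymous CPU RAM; together with the earlier
# load_state_dict(..., assign=True) change, file-backed weights should avoid being
# copied back into RAM at load time.
import torch
import tensordict

def mmap_roundtrip_sketch():
    w = torch.randn(1024, 1024)                            # ordinary heap-allocated CPU tensor
    w_mmap = tensordict.MemoryMappedTensor.from_tensor(w)  # file-backed copy of the same data
    same = torch.equal(w, w_mmap)                          # expected True: same values, different storage
    # When the weight is needed for compute again, a normal device move should
    # materialize a regular dense tensor on the target device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return same, w_mmap.to(device)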
- self.model.to(device_to) + # Temporarily register to_mmap method to the model + # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 + def _to_mmap_method(self): + """Convert all parameters and buffers to memory-mapped tensors + + This method mimics PyTorch's Module.to() behavior but converts + tensors to memory-mapped format instead. + """ + import pdb; pdb.set_trace() + logging.info(f"model {self.model.__class__.__name__} is calling to_mmap method") + def convert_fn(t): + if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter): + return to_mmap(t) + elif isinstance(t, torch.nn.Parameter): + # For parameters, convert the data and wrap back in Parameter + param_mmap = to_mmap(t.data) + return torch.nn.Parameter(param_mmap, requires_grad=t.requires_grad) + return t + + return self._apply(convert_fn) + + # Bind the method to the model instance + import types + self.model.to_mmap = types.MethodType(_to_mmap_method, self.model) + + # Call the to_mmap method + self.model.to_mmap() + + # Optionally clean up the temporary method + # delattr(self.model, 'to_mmap') + self.model.device = device_to self.model.model_loaded_weight_memory = 0 From 49561788cfccecd872808515a3975df772155a75 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 02:03:38 +0800 Subject: [PATCH 16/29] fix log --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index e4d8507d00aa..10ac1e7dea01 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -797,7 +797,7 @@ def _to_mmap_method(self): tensors to memory-mapped format instead. """ import pdb; pdb.set_trace() - logging.info(f"model {self.model.__class__.__name__} is calling to_mmap method") + logging.info(f"model {self.__class__.__name__} is calling to_mmap method") def convert_fn(t): if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter): return to_mmap(t) From 8aeebbf7ef0e6b54e41473661fb0ea216d380e29 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 02:27:40 +0800 Subject: [PATCH 17/29] fix to --- comfy/model_patcher.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 10ac1e7dea01..4c7cd5e3e121 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -794,17 +794,28 @@ def _to_mmap_method(self): """Convert all parameters and buffers to memory-mapped tensors This method mimics PyTorch's Module.to() behavior but converts - tensors to memory-mapped format instead. + tensors to memory-mapped format instead, using _apply() method. + + Note: For Parameters, we modify .data in-place because + MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. + For buffers, _apply() will automatically update the reference. 
""" - import pdb; pdb.set_trace() logging.info(f"model {self.__class__.__name__} is calling to_mmap method") + def convert_fn(t): - if isinstance(t, torch.Tensor) and not isinstance(t, torch.nn.Parameter): + """Convert function for _apply() + + - For Parameters: modify .data and return the Parameter object + - For buffers (plain Tensors): return new MemoryMappedTensor + """ + if isinstance(t, torch.nn.Parameter): + # For parameters, modify data in-place and return the parameter + if isinstance(t.data, torch.Tensor): + t.data = to_mmap(t.data) + return t + elif isinstance(t, torch.Tensor): + # For buffers (plain tensors), return the converted tensor return to_mmap(t) - elif isinstance(t, torch.nn.Parameter): - # For parameters, convert the data and wrap back in Parameter - param_mmap = to_mmap(t.data) - return torch.nn.Parameter(param_mmap, requires_grad=t.requires_grad) return t return self._apply(convert_fn) From 05c2518c6dea831cc15031dc8833afddbbb5a33e Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 02:59:51 +0800 Subject: [PATCH 18/29] refact mmap --- comfy/model_patcher.py | 91 ++++++++++++++++++++++-------------------- 1 file changed, 47 insertions(+), 44 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4c7cd5e3e121..d2e3a296ad15 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -37,9 +37,52 @@ import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP +from comfy.model_management import get_free_memory def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor: return tensordict.MemoryMappedTensor.from_tensor(t) + +def model_to_mmap(model: torch.nn.Module): + """Convert all parameters and buffers to memory-mapped tensors + + This function mimics PyTorch's Module.to() behavior but converts + tensors to memory-mapped format instead, using _apply() method. + + Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 + + Note: For Parameters, we modify .data in-place because + MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. + For buffers, _apply() will automatically update the reference. 
+ + Args: + model: PyTorch module to convert + + Returns: + The same model with all tensors converted to memory-mapped format + """ + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + + def convert_fn(t): + """Convert function for _apply() + + - For Parameters: modify .data and return the Parameter object + - For buffers (plain Tensors): return new MemoryMappedTensor + """ + if isinstance(t, torch.nn.Parameter): + # For parameters, modify data in-place and return the parameter + if isinstance(t.data, torch.Tensor): + t.data = to_mmap(t.data) + return t + elif isinstance(t, torch.Tensor): + # For buffers (plain tensors), return the converted tensor + return to_mmap(t) + return t + + new_model = model._apply(convert_fn) + free_cpu_mem = get_free_memory(torch.device("cpu")) + logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + return new_model def string_to_seed(data): @@ -787,50 +830,9 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.model.current_weight_patches_uuid = None self.backup.clear() - if device_to is not None: - # Temporarily register to_mmap method to the model - # Reference: https://github.com/pytorch/pytorch/blob/0fabc3ba44823f257e70ce397d989c8de5e362c1/torch/nn/modules/module.py#L1244 - def _to_mmap_method(self): - """Convert all parameters and buffers to memory-mapped tensors - - This method mimics PyTorch's Module.to() behavior but converts - tensors to memory-mapped format instead, using _apply() method. - - Note: For Parameters, we modify .data in-place because - MemoryMappedTensor cannot be wrapped in torch.nn.Parameter. - For buffers, _apply() will automatically update the reference. - """ - logging.info(f"model {self.__class__.__name__} is calling to_mmap method") - - def convert_fn(t): - """Convert function for _apply() - - - For Parameters: modify .data and return the Parameter object - - For buffers (plain Tensors): return new MemoryMappedTensor - """ - if isinstance(t, torch.nn.Parameter): - # For parameters, modify data in-place and return the parameter - if isinstance(t.data, torch.Tensor): - t.data = to_mmap(t.data) - return t - elif isinstance(t, torch.Tensor): - # For buffers (plain tensors), return the converted tensor - return to_mmap(t) - return t - - return self._apply(convert_fn) - - # Bind the method to the model instance - import types - self.model.to_mmap = types.MethodType(_to_mmap_method, self.model) - # Call the to_mmap method - self.model.to_mmap() - - # Optionally clean up the temporary method - # delattr(self.model, 'to_mmap') - - self.model.device = device_to + model_to_mmap(self.model) + self.model.device = device_to self.model.model_loaded_weight_memory = 0 for m in self.model.modules(): @@ -885,7 +887,8 @@ def partially_unload(self, device_to, memory_to_free=0): cast_weight = self.force_cast_weights # TODO(sf): to mmap # m is what module? 
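# Illustrative usage sketch (not part of the patch series): how the model_to_mmap()
# helper above is expected to behave on a small module, assuming tensordict is
# installed. Parameters keep their identity (only .data is swapped for a
# MemoryMappedTensor via _apply), so existing references to the Parameter objects
# stay valid while the weight bytes move out of anonymous CPU RAM.
import torch
import tensordict

def model_to_mmap_sketch():
    m = torch.nn.Linear(8, 8)

    def convert_fn(t):
        # Same pattern as model_to_mmap(): swap Parameter.data in place,
        # replace plain buffer tensors with memory-mapped copies.
        if isinstance(t, torch.nn.Parameter):
            t.data = tensordict.MemoryMappedTensor.from_tensor(t.data)
            return t
        if isinstance(t, torch.Tensor):
            return tensordict.MemoryMappedTensor.from_tensor(t)
        return t

    before_ptr = m.weight.data_ptr()
    m._apply(convert_fn)
    moved = m.weight.data_ptr() != before_ptr  # expected True: weight now lives in a file-backed buffer
    # When the module is needed on the GPU again, a regular .to() call is expected
    # to materialize ordinary CUDA tensors from the memory-mapped storage.
    if torch.cuda.is_available():
        m.to("cuda")
    return moved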
- m.to(device_to) + # m.to(device_to) + model_to_mmap(m) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: From 2f0d56656eea7da5a9dda2e5b0061b31bc5aefbd Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 11:38:17 +0800 Subject: [PATCH 19/29] refine code --- comfy/ldm/flux/model.py | 12 ----------- .../modules/diffusionmodules/openaimodel.py | 14 +------------ comfy/model_base.py | 20 +------------------ comfy/model_management.py | 1 - comfy/model_patcher.py | 10 ++++++---- 5 files changed, 8 insertions(+), 49 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index da46ed2ed389..a07c3ca95af1 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -279,15 +279,3 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig] - - def load_state_dict(self, state_dict, strict=True, assign=False): - # import pdb; pdb.set_trace() - """Override load_state_dict() to add logging""" - logging.info(f"Flux load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") - - # Call parent's load_state_dict method - result = super().load_state_dict(state_dict, strict=strict, assign=assign) - - logging.info(f"Flux load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") - - return result diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 2cdf711d4c7b..cd89977167f4 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -911,16 +911,4 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf return self.id_predictor(h) else: return self.out(h) - - - def load_state_dict(self, state_dict, strict=True, assign=False): - # import pdb; pdb.set_trace() - """Override load_state_dict() to add logging""" - logging.info(f"UNetModel load_state_dict start, strict={strict}, state_dict keys count={len(state_dict)}") - - # Call parent's load_state_dict method - result = super().load_state_dict(state_dict, strict=strict, assign=assign) - - logging.info(f"UNetModel load_state_dict end, strict={strict}, state_dict keys count={len(state_dict)}") - - return result + \ No newline at end of file diff --git a/comfy/model_base.py b/comfy/model_base.py index 409e7fb87969..d2d4aa93d682 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -303,10 +303,7 @@ def load_model_weights(self, sd, unet_prefix=""): logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) logging.info(f"load model {self.model_config} weights process end") - # TODO(sf): to mmap - # diffusion_model is UNetModel - # import pdb; pdb.set_trace() - # TODO(sf): here needs to avoid load mmap into cpu mem + # replace tensor with mmap tensor by assign m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) free_cpu_memory = get_free_memory(torch.device("cpu")) logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") @@ -389,21 +386,6 @@ def memory_required(self, input_shape, 
cond_shapes={}): area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) - def to(self, *args, **kwargs): - """Override to() to add custom device management logic""" - old_device = self.device if hasattr(self, 'device') else None - - result = super().to(*args, **kwargs) - - if len(args) > 0: - if isinstance(args[0], (torch.device, str)): - new_device = torch.device(args[0]) if isinstance(args[0], str) else args[0] - if 'device' in kwargs: - new_device = kwargs['device'] - - logging.info(f"BaseModel moved from {old_device} to {new_device}") - return result - def extra_conds_shapes(self, **kwargs): return {} diff --git a/comfy/model_management.py b/comfy/model_management.py index 30a509670eaa..4c29b07e1dd4 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -596,7 +596,6 @@ def minimum_inference_memory(): def free_memory(memory_required, device, keep_loaded=[]): logging.info("start to free mem") - import pdb; pdb.set_trace() cleanup_models_gc() unloaded_model = [] can_unload = [] diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index d2e3a296ad15..1c725663a7c7 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -831,8 +831,11 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): self.backup.clear() - model_to_mmap(self.model) - self.model.device = device_to + if device_to is not None: + # offload to mmap + model_to_mmap(self.model) + self.model.device = device_to + self.model.model_loaded_weight_memory = 0 for m in self.model.modules(): @@ -885,8 +888,7 @@ def partially_unload(self, device_to, memory_to_free=0): bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - # TODO(sf): to mmap - # m is what module? 
+ # offload to mmap # m.to(device_to) model_to_mmap(m) module_mem += move_weight_functions(m, device_to) From 2d010f545c9df6c3e07b7560ba7887432261947f Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 11:54:56 +0800 Subject: [PATCH 20/29] refine code --- comfy/ldm/flux/model.py | 1 - .../modules/diffusionmodules/openaimodel.py | 3 +- comfy/model_base.py | 10 +++---- comfy/model_management.py | 30 +++++++++---------- comfy/model_patcher.py | 4 +-- comfy/sd.py | 8 +---- comfy/utils.py | 7 ++--- execution.py | 3 -- nodes.py | 2 -- server.py | 2 +- 10 files changed, 27 insertions(+), 43 deletions(-) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index a07c3ca95af1..14f90cea55e6 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -7,7 +7,6 @@ from einops import rearrange, repeat import comfy.ldm.common_dit import comfy.patcher_extension -import logging from .layers import ( DoubleStreamBlock, diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index cd89977167f4..4963811a8bea 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -910,5 +910,4 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf if self.predict_codebook_ids: return self.id_predictor(h) else: - return self.out(h) - \ No newline at end of file + return self.out(h) \ No newline at end of file diff --git a/comfy/model_base.py b/comfy/model_base.py index d2d4aa93d682..d6ef644dd42b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -299,14 +299,14 @@ def load_model_weights(self, sd, unet_prefix=""): to_load[k[len(unet_prefix):]] = sd.pop(k) free_cpu_memory = get_free_memory(torch.device("cpu")) - logging.info(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") - logging.info(f"model destination device {next(self.diffusion_model.parameters()).device}") + logging.debug(f"load model weights start, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + logging.debug(f"model destination device {next(self.diffusion_model.parameters()).device}") to_load = self.model_config.process_unet_state_dict(to_load) - logging.info(f"load model {self.model_config} weights process end") + logging.debug(f"load model {self.model_config} weights process end") # replace tensor with mmap tensor by assign m, u = self.diffusion_model.load_state_dict(to_load, strict=False, assign=True) free_cpu_memory = get_free_memory(torch.device("cpu")) - logging.info(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") + logging.debug(f"load model {self.model_config} weights end, free cpu memory size {free_cpu_memory/(1024*1024*1024)} GB") if len(m) > 0: logging.warning("unet missing: {}".format(m)) @@ -385,7 +385,7 @@ def memory_required(self, input_shape, cond_shapes={}): #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. 
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) - + def extra_conds_shapes(self, **kwargs): return {} diff --git a/comfy/model_management.py b/comfy/model_management.py index 4c29b07e1dd4..70a5039efcdc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -509,16 +509,16 @@ def should_reload_model(self, force_patch_weights=False): return False def model_unload(self, memory_to_free=None, unpatch_weights=True): - logging.info(f"model_unload: {self.model.model.__class__.__name__}") - logging.info(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") - logging.info(f"unpatch_weights: {unpatch_weights}") - logging.info(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") - logging.info(f"offload_device: {self.model.offload_device}") + logging.debug(f"model_unload: {self.model.model.__class__.__name__}") + logging.debug(f"memory_to_free: {memory_to_free/(1024*1024*1024)} GB") + logging.debug(f"unpatch_weights: {unpatch_weights}") + logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") + logging.debug(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) - logging.info(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage if available_memory < reserved_memory: - logging.error(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") + logging.warning(f"Not enough cpu memory to unload. 
Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") return False else: offload_memory = available_memory - reserved_memory @@ -530,14 +530,14 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): - logging.info("Do partially unload") + logging.debug("Do partially unload") freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - logging.info(f"partially_unload freed: {freed/(1024*1024*1024)} GB") + logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") if freed >= memory_to_free: return False - logging.info("Do full unload") + logging.debug("Do full unload") self.model.detach(unpatch_weights) - logging.info("Do full unload done") + logging.debug("Do full unload done") except Exception as e: logging.error(f"Error in model_unload: {e}") available_memory = get_free_memory(self.model.offload_device) @@ -595,7 +595,7 @@ def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() def free_memory(memory_required, device, keep_loaded=[]): - logging.info("start to free mem") + logging.debug("start to free mem") cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -616,7 +616,7 @@ def free_memory(memory_required, device, keep_loaded=[]): if free_mem > memory_required: break memory_to_free = memory_required - free_mem - logging.info(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") + logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") if current_loaded_models[i].model_unload(memory_to_free): unloaded_model.append(i) @@ -633,7 +633,7 @@ def free_memory(memory_required, device, keep_loaded=[]): return unloaded_models def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False): - logging.info(f"start to load models") + logging.debug(f"start to load models") cleanup_models_gc() global vram_state @@ -655,7 +655,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu models_to_load = [] for x in models: - logging.info(f"loading model: {x.model.__class__.__name__}") + logging.debug(f"start loading model to vram: {x.model.__class__.__name__}") loaded_model = LoadedModel(x) try: loaded_model_index = current_loaded_models.index(loaded_model) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 1c725663a7c7..63bae24d36be 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -61,7 +61,7 @@ def model_to_mmap(model: torch.nn.Module): The same model with all tensors converted to memory-mapped format """ free_cpu_mem = get_free_memory(torch.device("cpu")) - logging.info(f"Converting model {model.__class__.__name__} to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + logging.debug(f"Converting model {model.__class__.__name__} to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") def convert_fn(t): """Convert function for _apply() @@ -81,7 +81,7 @@ def convert_fn(t): new_model = model._apply(convert_fn) free_cpu_mem = get_free_memory(torch.device("cpu")) - logging.info(f"Model {model.__class__.__name__} converted to mmap, cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") + logging.debug(f"Model {model.__class__.__name__} converted to mmap, current free cpu memory: {free_cpu_mem/(1024*1024*1024)} GB") return new_model diff --git a/comfy/sd.py b/comfy/sd.py index a956884fb7a8..3651da5e7552 
100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1321,7 +1321,7 @@ def load_diffusion_model_state_dict(sd, model_options={}): logging.warning("{} {}".format(diffusers_keys[k], k)) offload_device = model_management.unet_offload_device() - logging.info(f"loader load model to offload device: {offload_device}") + logging.debug(f"loader load model to offload device: {offload_device}") unet_weight_dtype = list(model_config.supported_inference_dtypes) if model_config.scaled_fp8 is not None: weight_dtype = None @@ -1338,7 +1338,6 @@ def load_diffusion_model_state_dict(sd, model_options={}): model_config.optimizations["fp8"] = True model = model_config.get_model(new_sd, "") - # import pdb; pdb.set_trace() model = model.to(offload_device) model.load_model_weights(new_sd, "") left_over = sd.keys() @@ -1348,13 +1347,8 @@ def load_diffusion_model_state_dict(sd, model_options={}): def load_diffusion_model(unet_path, model_options={}): - # TODO(sf): here load file into mem sd = comfy.utils.load_torch_file(unet_path) - logging.info(f"load model start, path {unet_path}") - # import pdb; pdb.set_trace() model = load_diffusion_model_state_dict(sd, model_options=model_options) - logging.info(f"load model end, path {unet_path}") - # import pdb; pdb.set_trace() if model is None: logging.error("ERROR UNSUPPORTED DIFFUSION MODEL {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}\n{}".format(unet_path, model_detection_error_hint(unet_path, sd))) diff --git a/comfy/utils.py b/comfy/utils.py index 4c22f684c44e..be6ab759655e 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -55,15 +55,13 @@ def scalar(*args, **kwargs): logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.") def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): - # TODO(sf): here load file into mmap - logging.info(f"load_torch_file start, ckpt={ckpt}, safe_load={safe_load}, device={device}, return_metadata={return_metadata}") if device is None: device = torch.device("cpu") metadata = None if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): try: if not DISABLE_MMAP: - logging.info(f"load_torch_file safetensors mmap True") + logging.debug(f"load_torch_file of safetensors into mmap True") with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: sd = {} for k in f.keys(): @@ -84,7 +82,7 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): else: torch_args = {} if MMAP_TORCH_FILES: - logging.info(f"load_torch_file mmap True") + logging.debug(f"load_torch_file of torch state dict into mmap True") torch_args["mmap"] = True if safe_load or ALWAYS_SAFE_LOAD: @@ -102,7 +100,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False): sd = pl_sd else: sd = pl_sd - # import pdb; pdb.set_trace() return (sd, metadata) if return_metadata else sd def save_torch_file(sd, ckpt, metadata=None): diff --git a/execution.py b/execution.py index 53f2953572c5..1dc35738b823 100644 --- a/execution.py +++ b/execution.py @@ -400,8 +400,6 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = nodes.NODE_CLASS_MAPPINGS[class_type] - - if caches.outputs.get(unique_id) is not None: if server.client_id is not None: cached_output = caches.ui.get(unique_id) or {} @@ -595,7 +593,6 @@ async def 
await_completion(): get_progress_state().finish_progress(unique_id) executed.add(unique_id) - return (ExecutionResult.SUCCESS, None, None) diff --git a/nodes.py b/nodes.py index 25ccc9e421bc..7cfa8ca1411d 100644 --- a/nodes.py +++ b/nodes.py @@ -922,9 +922,7 @@ def load_unet(self, unet_name, weight_dtype): model_options["dtype"] = torch.float8_e5m2 unet_path = folder_paths.get_full_path_or_raise("diffusion_models", unet_name) - logging.info(f"load unet node start, path {unet_path}") model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options) - logging.info(f"load unet node end, path {unet_path}") return (model,) class CLIPLoader: diff --git a/server.py b/server.py index 515307bf6de6..80e9d3fa78a0 100644 --- a/server.py +++ b/server.py @@ -673,7 +673,7 @@ async def get_queue(request): @routes.post("/prompt") async def post_prompt(request): - logging.info("got prompt in debug comfyui") + logging.info("got prompt") json_data = await request.json() json_data = self.trigger_on_prompt(json_data) From fff56de63cfe9ad7057d8403c13c6428d57593c5 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 11:59:59 +0800 Subject: [PATCH 21/29] fix format --- comfy/ldm/modules/diffusionmodules/openaimodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/ldm/modules/diffusionmodules/openaimodel.py b/comfy/ldm/modules/diffusionmodules/openaimodel.py index 4963811a8bea..4c8d53cac9c2 100644 --- a/comfy/ldm/modules/diffusionmodules/openaimodel.py +++ b/comfy/ldm/modules/diffusionmodules/openaimodel.py @@ -910,4 +910,4 @@ def _forward(self, x, timesteps=None, context=None, y=None, control=None, transf if self.predict_codebook_ids: return self.id_predictor(h) else: - return self.out(h) \ No newline at end of file + return self.out(h) From 08e094ed81b66e23876a5cc8be1bb9f40f213061 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 17:00:56 +0800 Subject: [PATCH 22/29] use native mmap --- comfy/model_patcher.py | 78 +++++++- tests/execution/test_model_mmap.py | 280 +++++++++++++++++++++++++++++ 2 files changed, 355 insertions(+), 3 deletions(-) create mode 100644 tests/execution/test_model_mmap.py diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 63bae24d36be..0f4445d33a0c 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -27,7 +27,10 @@ from typing import Callable, Optional import torch -import tensordict +import os +import tempfile +import weakref +import gc import comfy.float import comfy.hooks @@ -39,8 +42,77 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory -def to_mmap(t: torch.Tensor) -> tensordict.MemoryMappedTensor: - return tensordict.MemoryMappedTensor.from_tensor(t) + +def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: + """ + Convert a tensor to a memory-mapped CPU tensor using PyTorch's native mmap support. 
+ """ + # Move to CPU if needed + if t.is_cuda: + cpu_tensor = t.cpu() + else: + cpu_tensor = t + + # Create temporary file + if filename is None: + temp_file = tempfile.mktemp(suffix='.pt', prefix='comfy_mmap_') + else: + temp_file = filename + + # Save tensor to file + torch.save(cpu_tensor, temp_file) + + # If we created a CPU copy from CUDA, delete it to free memory + if t.is_cuda: + del cpu_tensor + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Load with mmap - this doesn't load all data into RAM + mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) + + # Register cleanup callback + def _cleanup(): + try: + if os.path.exists(temp_file): + os.remove(temp_file) + logging.debug(f"Cleaned up mmap file: {temp_file}") + except Exception: + pass + + weakref.finalize(mmap_tensor, _cleanup) + + # Save original 'to' method + original_to = mmap_tensor.to + + # Create custom 'to' method that cleans up file when moving to CUDA + def custom_to(*args, **kwargs): + # Determine target device + target_device = None + if len(args) > 0: + if isinstance(args[0], torch.device): + target_device = args[0] + elif isinstance(args[0], str): + target_device = torch.device(args[0]) + if 'device' in kwargs: + target_device = kwargs['device'] + if isinstance(target_device, str): + target_device = torch.device(target_device) + + # Call original 'to' method first to move data + result = original_to(*args, **kwargs) + + # If moved to CUDA, cleanup the mmap file after the move + if target_device is not None and target_device.type == 'cuda': + _cleanup() + + return result + + # Replace the 'to' method + mmap_tensor.to = custom_to + + return mmap_tensor def model_to_mmap(model: torch.nn.Module): """Convert all parameters and buffers to memory-mapped tensors diff --git a/tests/execution/test_model_mmap.py b/tests/execution/test_model_mmap.py new file mode 100644 index 000000000000..65dbe01bd691 --- /dev/null +++ b/tests/execution/test_model_mmap.py @@ -0,0 +1,280 @@ +import pytest +import torch +import torch.nn as nn +import psutil +import os +import gc +import tempfile +from comfy.model_patcher import model_to_mmap, to_mmap + + +class LargeModel(nn.Module): + """A simple model with large parameters for testing memory mapping""" + + def __init__(self, size_gb=10): + super().__init__() + # Calculate number of float32 elements needed for target size + # 1 GB = 1024^3 bytes, float32 = 4 bytes + bytes_per_gb = 1024 * 1024 * 1024 + elements_per_gb = bytes_per_gb // 4 # float32 is 4 bytes + total_elements = int(size_gb * elements_per_gb) + + # Create a large linear layer + # Split into multiple layers to avoid single tensor size limits + self.layers = nn.ModuleList() + elements_per_layer = 500 * 1024 * 1024 # 500M elements per layer (~2GB) + num_layers = (total_elements + elements_per_layer - 1) // elements_per_layer + + for i in range(num_layers): + if i == num_layers - 1: + # Last layer gets the remaining elements + remaining = total_elements - (i * elements_per_layer) + in_features = int(remaining ** 0.5) + out_features = (remaining + in_features - 1) // in_features + else: + in_features = int(elements_per_layer ** 0.5) + out_features = (elements_per_layer + in_features - 1) // in_features + + # Create layer without bias to control size precisely + self.layers.append(nn.Linear(in_features, out_features, bias=False)) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return x + + +def get_process_memory_gb(): + """Get current process memory 
usage in GB""" + process = psutil.Process(os.getpid()) + mem_info = process.memory_info() + return mem_info.rss / (1024 ** 3) # Convert to GB + + +def get_model_size_gb(model): + """Calculate model size in GB""" + total_size = 0 + for param in model.parameters(): + total_size += param.nelement() * param.element_size() + for buffer in model.buffers(): + total_size += buffer.nelement() * buffer.element_size() + return total_size / (1024 ** 3) + + +def test_model_to_mmap_memory_efficiency(): + """Test that model_to_mmap reduces memory usage for a 10GB model to less than 1GB + + The typical use case is: + 1. Load a large model on CUDA + 2. Convert to mmap to offload from GPU to disk-backed memory + 3. This frees GPU memory and reduces CPU RAM usage + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection before starting + gc.collect() + torch.cuda.empty_cache() + + # Record initial memory + initial_cpu_memory = get_process_memory_gb() + initial_gpu_memory = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"\nInitial CPU memory: {initial_cpu_memory:.2f} GB") + print(f"Initial GPU memory: {initial_gpu_memory:.2f} GB") + + # Create a 10GB model + print("Creating 10GB model...") + model = LargeModel(size_gb=10) + + # Verify model size + model_size = get_model_size_gb(model) + print(f"Model size: {model_size:.2f} GB") + assert model_size >= 9.5, f"Model size {model_size:.2f} GB is less than expected 10 GB" + + # Move model to CUDA + print("Moving model to CUDA...") + model = model.cuda() + torch.cuda.synchronize() + + # Memory after moving to CUDA + cpu_after_cuda = get_process_memory_gb() + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after moving to CUDA: {cpu_after_cuda:.2f} GB") + print(f"GPU memory after moving to CUDA: {gpu_after_cuda:.2f} GB") + + # Convert to mmap (this should move model from GPU to disk-backed memory) + # Note: model_to_mmap modifies the model in-place via _apply() + # so model and model_mmap will be the same object + print("Converting model to mmap...") + model_mmap = model_to_mmap(model) + + # Verify that model and model_mmap are the same object (in-place modification) + assert model is model_mmap, "model_to_mmap should modify the model in-place" + + # Force garbage collection and clear CUDA cache + # The original CUDA tensors should be automatically freed when replaced + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Memory after mmap conversion + cpu_after_mmap = get_process_memory_gb() + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f"CPU memory after mmap: {cpu_after_mmap:.2f} GB") + print(f"GPU memory after mmap: {gpu_after_mmap:.2f} GB") + + # Calculate memory changes from CUDA state (the baseline we're converting from) + cpu_increase = cpu_after_mmap - cpu_after_cuda + gpu_decrease = gpu_after_cuda - gpu_after_mmap # Should be positive (freed) + print(f"\nCPU memory increase from CUDA: {cpu_increase:.2f} GB") + print(f"GPU memory freed: {gpu_decrease:.2f} GB") + + # Verify that CPU memory usage increase is less than 1GB + # The mmap should use disk-backed storage, keeping CPU RAM usage low + # We use 1.5 GB threshold to account for overhead + assert cpu_increase < 1.5, ( + f"CPU memory increase after mmap ({cpu_increase:.2f} GB) should be less than 1.5 GB. 
" + f"CUDA state: {cpu_after_cuda:.2f} GB, After mmap: {cpu_after_mmap:.2f} GB" + ) + + # Verify that GPU memory has been freed + # We expect at least 9 GB to be freed (original 10GB model with some tolerance) + assert gpu_decrease > 9.0, ( + f"GPU memory should be freed after mmap. " + f"Freed: {gpu_decrease:.2f} GB (from {gpu_after_cuda:.2f} to {gpu_after_mmap:.2f} GB), expected > 9 GB" + ) + + # Verify the model is still functional (basic sanity check) + assert model_mmap is not None + assert len(list(model_mmap.parameters())) > 0 + + print(f"\nโœ“ Test passed!") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 1.5 GB") + print(f" GPU memory freed: {gpu_decrease:.2f} GB > 9.0 GB") + print(f" Model successfully offloaded from GPU to disk-backed memory") + + # Cleanup (model and model_mmap are the same object) + del model, model_mmap + gc.collect() + torch.cuda.empty_cache() + + +def test_to_mmap_cuda_cycle(): + """Test CUDA -> mmap -> CUDA cycle + + This test verifies: + 1. CUDA tensor can be converted to mmap tensor + 2. CPU memory increase is minimal when using mmap (< 0.1 GB) + 3. GPU memory is freed when converting to mmap + 4. mmap tensor can be moved back to CUDA + 5. Data remains consistent throughout the cycle + 6. mmap file is automatically cleaned up when moved to CUDA + """ + + # Check if CUDA is available + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available, skipping test") + + # Force garbage collection + gc.collect() + torch.cuda.empty_cache() + + print("\nTest: CUDA -> mmap -> CUDA cycle") + + # Record initial CPU memory + initial_cpu_memory = get_process_memory_gb() + print(f"Initial CPU memory: {initial_cpu_memory:.2f} GB") + + # Step 1: Create a CUDA tensor + print("\n1. Creating CUDA tensor...") + original_data = torch.randn(5000, 5000).cuda() + original_sum = original_data.sum().item() + print(f" Shape: {original_data.shape}") + print(f" Device: {original_data.device}") + print(f" Sum: {original_sum:.2f}") + + # Record GPU and CPU memory after CUDA allocation + cpu_after_cuda = get_process_memory_gb() + gpu_before_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_before_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_cuda:.2f} GB") + + # Step 2: Convert to mmap tensor + print("\n2. Converting to mmap tensor...") + mmap_tensor = to_mmap(original_data) + del original_data + gc.collect() + torch.cuda.empty_cache() + + print(f" Device: {mmap_tensor.device}") + print(f" Sum: {mmap_tensor.sum().item():.2f}") + + # Verify GPU memory is freed + gpu_after_mmap = torch.cuda.memory_allocated() / (1024 ** 3) + cpu_after_mmap = get_process_memory_gb() + print(f" GPU memory freed: {gpu_before_mmap - gpu_after_mmap:.2f} GB") + print(f" CPU memory: {cpu_after_mmap:.2f} GB") + + # Verify GPU memory is freed + assert gpu_after_mmap < 0.1, f"GPU memory should be freed, but {gpu_after_mmap:.2f} GB still allocated" + + # Verify CPU memory increase is minimal (should be close to 0 due to mmap) + cpu_increase = cpu_after_mmap - cpu_after_cuda + print(f" CPU memory increase: {cpu_increase:.2f} GB") + assert cpu_increase < 0.1, f"CPU memory should increase minimally, but increased by {cpu_increase:.2f} GB" + + # Get the temp file path (we'll check if it gets cleaned up) + # The file should exist at this point + temp_files_before = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files exist: {temp_files_before}") + + # Step 3: Move back to CUDA + print("\n3. 
Moving back to CUDA...") + cuda_tensor = mmap_tensor.to('cuda') + torch.cuda.synchronize() + + print(f" Device: {cuda_tensor.device}") + final_sum = cuda_tensor.sum().item() + print(f" Sum: {final_sum:.2f}") + + # Verify GPU memory is used again + gpu_after_cuda = torch.cuda.memory_allocated() / (1024 ** 3) + print(f" GPU memory: {gpu_after_cuda:.2f} GB") + + # Step 4: Verify data consistency + print("\n4. Verifying data consistency...") + sum_diff = abs(original_sum - final_sum) + print(f" Original sum: {original_sum:.2f}") + print(f" Final sum: {final_sum:.2f}") + print(f" Difference: {sum_diff:.6f}") + assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}" + + # Step 5: Verify file cleanup + print("\n5. Verifying file cleanup...") + gc.collect() + import time + time.sleep(0.1) # Give OS time to clean up + temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) + print(f" Temp mmap files after: {temp_files_after}") + # File should be cleaned up when moved to CUDA + assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after moving to CUDA" + + print("\nโœ“ Test passed!") + print(" CUDA -> mmap -> CUDA cycle works correctly") + print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)") + print(" Data consistency maintained") + print(" File cleanup successful") + + # Cleanup + del mmap_tensor, cuda_tensor + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == "__main__": + # Run the tests directly + test_model_to_mmap_memory_efficiency() + test_to_mmap_cuda_cycle() + From 80383932ec63056261d584771eba8c8c1eb51ebf Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 18:00:31 +0800 Subject: [PATCH 23/29] lazy rm file --- comfy/model_patcher.py | 55 +++++++++++++++--------------- tests/execution/test_model_mmap.py | 16 +++++---- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 0f4445d33a0c..4b0c5b9c5b75 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -72,7 +72,7 @@ def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: # Load with mmap - this doesn't load all data into RAM mmap_tensor = torch.load(temp_file, map_location='cpu', mmap=True, weights_only=False) - # Register cleanup callback + # Register cleanup callback - will be called when tensor is garbage collected def _cleanup(): try: if os.path.exists(temp_file): @@ -83,34 +83,35 @@ def _cleanup(): weakref.finalize(mmap_tensor, _cleanup) - # Save original 'to' method - original_to = mmap_tensor.to + # # Save original 'to' method + # original_to = mmap_tensor.to - # Create custom 'to' method that cleans up file when moving to CUDA - def custom_to(*args, **kwargs): - # Determine target device - target_device = None - if len(args) > 0: - if isinstance(args[0], torch.device): - target_device = args[0] - elif isinstance(args[0], str): - target_device = torch.device(args[0]) - if 'device' in kwargs: - target_device = kwargs['device'] - if isinstance(target_device, str): - target_device = torch.device(target_device) - - # Call original 'to' method first to move data - result = original_to(*args, **kwargs) - - # If moved to CUDA, cleanup the mmap file after the move - if target_device is not None and target_device.type == 'cuda': - _cleanup() - - return result + # # Create custom 'to' method that cleans up file when moving to CUDA + # def custom_to(*args, **kwargs): + # # Determine target device + # 
target_device = None + # if len(args) > 0: + # if isinstance(args[0], torch.device): + # target_device = args[0] + # elif isinstance(args[0], str): + # target_device = torch.device(args[0]) + # if 'device' in kwargs: + # target_device = kwargs['device'] + # if isinstance(target_device, str): + # target_device = torch.device(target_device) + # + # # Call original 'to' method first to move data + # result = original_to(*args, **kwargs) + # + # # NOTE: Cleanup disabled to avoid blocking model load performance + # # If moved to CUDA, cleanup the mmap file after the move + # if target_device is not None and target_device.type == 'cuda': + # _cleanup() + # + # return result - # Replace the 'to' method - mmap_tensor.to = custom_to + # # Replace the 'to' method + # mmap_tensor.to = custom_to return mmap_tensor diff --git a/tests/execution/test_model_mmap.py b/tests/execution/test_model_mmap.py index 65dbe01bd691..7a608c9316b1 100644 --- a/tests/execution/test_model_mmap.py +++ b/tests/execution/test_model_mmap.py @@ -170,7 +170,7 @@ def test_to_mmap_cuda_cycle(): 3. GPU memory is freed when converting to mmap 4. mmap tensor can be moved back to CUDA 5. Data remains consistent throughout the cycle - 6. mmap file is automatically cleaned up when moved to CUDA + 6. mmap file is automatically cleaned up via garbage collection """ # Check if CUDA is available @@ -251,24 +251,26 @@ def test_to_mmap_cuda_cycle(): print(f" Difference: {sum_diff:.6f}") assert sum_diff < 0.01, f"Data should be consistent, but difference is {sum_diff:.6f}" - # Step 5: Verify file cleanup + # Step 5: Verify file cleanup (delayed until garbage collection) print("\n5. Verifying file cleanup...") + # Delete the mmap tensor reference to trigger garbage collection + del mmap_tensor gc.collect() import time time.sleep(0.1) # Give OS time to clean up temp_files_after = len([f for f in os.listdir(tempfile.gettempdir()) if f.startswith('comfy_mmap_')]) - print(f" Temp mmap files after: {temp_files_after}") - # File should be cleaned up when moved to CUDA - assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after moving to CUDA" + print(f" Temp mmap files after GC: {temp_files_after}") + # File should be cleaned up after garbage collection + assert temp_files_after <= temp_files_before, "mmap file should be cleaned up after garbage collection" print("\nโœ“ Test passed!") print(" CUDA -> mmap -> CUDA cycle works correctly") print(f" CPU memory increase: {cpu_increase:.2f} GB < 0.1 GB (mmap efficiency)") print(" Data consistency maintained") - print(" File cleanup successful") + print(" File cleanup successful (via garbage collection)") # Cleanup - del mmap_tensor, cuda_tensor + del cuda_tensor # mmap_tensor already deleted in Step 5 gc.collect() torch.cuda.empty_cache() From 98ba3115110164d3c81ef01dc7b9790c67539328 Mon Sep 17 00:00:00 2001 From: strint Date: Tue, 21 Oct 2025 19:06:34 +0800 Subject: [PATCH 24/29] add env --- comfy/model_patcher.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 4b0c5b9c5b75..f379c230b9e2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -42,6 +42,13 @@ from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP from comfy.model_management import get_free_memory +def need_mmap() -> bool: + free_cpu_mem = get_free_memory(torch.device("cpu")) + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "1024")) + if free_cpu_mem < 
mmap_mem_threshold_gb * 1024 * 1024 * 1024: + logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") + return True + return False def to_mmap(t: torch.Tensor, filename: Optional[str] = None) -> torch.Tensor: """ @@ -905,8 +912,11 @@ def unpatch_model(self, device_to=None, unpatch_weights=True): if device_to is not None: - # offload to mmap - model_to_mmap(self.model) + if need_mmap(): + # offload to mmap + model_to_mmap(self.model) + else: + self.model.to(device_to) self.model.device = device_to self.model.model_loaded_weight_memory = 0 @@ -961,9 +971,11 @@ def partially_unload(self, device_to, memory_to_free=0): bias_key = "{}.bias".format(n) if move_weight: cast_weight = self.force_cast_weights - # offload to mmap - # m.to(device_to) - model_to_mmap(m) + if need_mmap(): + # offload to mmap + model_to_mmap(m) + else: + m.to(device_to) module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: From aab0e244f7b221a00b9049f8dfaa0706185b22bd Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 14:44:51 +0800 Subject: [PATCH 25/29] fix MMAP_MEM_THRESHOLD_GB default --- comfy/model_patcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f379c230b9e2..115e401b3ab0 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -44,7 +44,7 @@ def need_mmap() -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) - mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "1024")) + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: logging.debug(f"Enabling mmap, current free cpu memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") return True From 58d28edade40ecd45ac7b20272f51a978a3045d2 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 15:50:57 +0800 Subject: [PATCH 26/29] no limit for offload size --- comfy/model_management.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 0dc471fb8cc6..8bf4e68fbd57 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -526,17 +526,17 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): logging.debug(f"offload_device: {self.model.offload_device}") available_memory = get_free_memory(self.model.offload_device) logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage - if available_memory < reserved_memory: - logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") - return False - else: - offload_memory = available_memory - reserved_memory - - if offload_memory < memory_to_free: - memory_to_free = offload_memory - logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") - logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") + # reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage + # if available_memory < reserved_memory: + # logging.warning(f"Not enough cpu memory to unload. 
Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") + # return False + # else: + # offload_memory = available_memory - reserved_memory + # + # if offload_memory < memory_to_free: + # memory_to_free = offload_memory + # logging.info(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") + # logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") try: if memory_to_free is not None: if memory_to_free < self.model.loaded_size(): From c312733b8cd28010e3370716f1311d3b30067b13 Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 15:53:35 +0800 Subject: [PATCH 27/29] refine log --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index 8bf4e68fbd57..f4ed13899cfa 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -553,6 +553,10 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): available_memory = get_free_memory(self.model.offload_device) logging.info(f"after error, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") return False + finally: + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + self.model_finalizer.detach() self.model_finalizer = None self.real_model = None From dc7c77e78cb219f149c448cb961ae5122be7ce6b Mon Sep 17 00:00:00 2001 From: strint Date: Thu, 23 Oct 2025 18:09:47 +0800 Subject: [PATCH 28/29] better partial unload --- comfy/model_management.py | 62 +++++++++++++++++++++++++-------------- comfy/model_patcher.py | 7 +++-- 2 files changed, 45 insertions(+), 24 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index f4ed13899cfa..f2e23c446e59 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -26,6 +26,14 @@ import platform import weakref import gc +import os + +def get_mmap_mem_threshold_gb(): + mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + return mmap_mem_threshold_gb + +def get_free_disk(): + return psutil.disk_usage("/").free class VRAMState(Enum): DISABLED = 0 #No vram present: no need to move models to vram @@ -524,9 +532,7 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): logging.debug(f"unpatch_weights: {unpatch_weights}") logging.debug(f"loaded_size: {self.model.loaded_size()/(1024*1024*1024)} GB") logging.debug(f"offload_device: {self.model.offload_device}") - available_memory = get_free_memory(self.model.offload_device) - logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") - # reserved_memory = 1024*1024*1024 # 1GB reserved memory for other usage + # if available_memory < reserved_memory: # logging.warning(f"Not enough cpu memory to unload. Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB") # return False @@ -537,30 +543,42 @@ def model_unload(self, memory_to_free=None, unpatch_weights=True): # memory_to_free = offload_memory # logging.info(f"Not enough cpu memory to unload. 
Available: {available_memory/(1024*1024*1024)} GB, Reserved: {reserved_memory/(1024*1024*1024)} GB, Offload: {offload_memory/(1024*1024*1024)} GB") # logging.info(f"Set memory_to_free to {memory_to_free/(1024*1024*1024)} GB") - try: - if memory_to_free is not None: - if memory_to_free < self.model.loaded_size(): - logging.debug("Do partially unload") - freed = self.model.partially_unload(self.model.offload_device, memory_to_free) - logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") - if freed >= memory_to_free: - return False + + if memory_to_free is None: + # free the full model + memory_to_free = self.model.loaded_size() + + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"before unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + + mmap_mem_threshold = get_mmap_mem_threshold_gb() * 1024 * 1024 * 1024 # this is reserved memory for other system usage + if memory_to_free > available_memory - mmap_mem_threshold or memory_to_free < self.model.loaded_size(): + partially_unload = True + else: + partially_unload = False + + if partially_unload: + logging.debug("Do partially unload") + freed = self.model.partially_unload(self.model.offload_device, memory_to_free) + logging.debug(f"partially_unload freed vram: {freed/(1024*1024*1024)} GB") + if freed < memory_to_free: + logging.warning(f"Partially unload not enough memory, freed {freed/(1024*1024*1024)} GB, memory_to_free {memory_to_free/(1024*1024*1024)} GB") + else: logging.debug("Do full unload") self.model.detach(unpatch_weights) logging.debug("Do full unload done") - except Exception as e: - logging.error(f"Error in model_unload: {e}") - available_memory = get_free_memory(self.model.offload_device) - logging.info(f"after error, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + self.model_finalizer.detach() + self.model_finalizer = None + self.real_model = None + + available_memory = get_free_memory(self.model.offload_device) + logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + + if partially_unload: return False - finally: - available_memory = get_free_memory(self.model.offload_device) - logging.debug(f"after unload, available_memory of offload device {self.model.offload_device}: {available_memory/(1024*1024*1024)} GB") + else: + return True - self.model_finalizer.detach() - self.model_finalizer = None - self.real_model = None - return True def model_use_more_vram(self, extra_memory, force_patch_weights=False): return self.model.partially_load(self.device, extra_memory, force_patch_weights=force_patch_weights) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 115e401b3ab0..361f15e5b9c7 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,11 +40,11 @@ import comfy.utils from comfy.comfy_types import UnetWrapperFunction from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP -from comfy.model_management import get_free_memory +from comfy.model_management import get_free_memory, get_mmap_mem_threshold_gb, get_free_disk def need_mmap() -> bool: free_cpu_mem = get_free_memory(torch.device("cpu")) - mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + mmap_mem_threshold_gb = get_mmap_mem_threshold_gb() if free_cpu_mem < mmap_mem_threshold_gb * 1024 * 1024 * 1024: logging.debug(f"Enabling mmap, current free cpu 
memory {free_cpu_mem/(1024*1024*1024)} GB < {mmap_mem_threshold_gb} GB") return True @@ -972,6 +972,9 @@ def partially_unload(self, device_to, memory_to_free=0): if move_weight: cast_weight = self.force_cast_weights if need_mmap(): + if get_free_disk() < module_mem: + logging.warning(f"Not enough disk space to offload {n} to mmap, current free disk space {get_free_disk()/(1024*1024*1024)} GB < {module_mem/(1024*1024*1024)} GB") + break # offload to mmap model_to_mmap(m) else: From 5c5fbddbbe71c986e43214002069d0edd1260445 Mon Sep 17 00:00:00 2001 From: strint Date: Mon, 17 Nov 2025 15:34:50 +0800 Subject: [PATCH 29/29] debug mmap --- comfy/model_management.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/comfy/model_management.py b/comfy/model_management.py index f2e23c446e59..a2ad5db2aac7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -28,8 +28,12 @@ import gc import os +from functools import lru_cache + +@lru_cache(maxsize=1) def get_mmap_mem_threshold_gb(): mmap_mem_threshold_gb = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) + logging.debug(f"MMAP_MEM_THRESHOLD_GB: {mmap_mem_threshold_gb}") return mmap_mem_threshold_gb def get_free_disk():
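
Note on the native-mmap path added in PATCH 22/29: to_mmap() writes a CPU copy of the tensor to a temporary .pt file, reloads it with torch.load(..., mmap=True) so the data stays disk-backed rather than resident in RAM, and registers a weakref.finalize hook that deletes the file once the mapped tensor is garbage collected. Below is a minimal standalone sketch of that round trip, not the patched function itself; it assumes PyTorch 2.1+ (the mmap flag of torch.load) and uses tempfile.mkstemp where the patch uses tempfile.mktemp.

import os
import tempfile
import weakref

import torch


def offload_tensor_to_mmap(t: torch.Tensor) -> torch.Tensor:
    # Write a CPU copy to a temporary .pt file, then reload it memory-mapped so
    # pages are faulted in from disk on demand instead of living in RAM.
    cpu_tensor = t.detach().cpu()
    fd, path = tempfile.mkstemp(suffix=".pt", prefix="comfy_mmap_")
    os.close(fd)
    torch.save(cpu_tensor, path)
    del cpu_tensor
    mapped = torch.load(path, map_location="cpu", mmap=True)
    # Delete the backing file when the mapped tensor is collected, mirroring the
    # weakref.finalize cleanup in the patch.
    weakref.finalize(mapped, os.remove, path)
    return mapped


x = torch.randn(1024, 1024)
y = offload_tensor_to_mmap(x)
assert torch.equal(x, y)

mkstemp is used here only because it avoids the filename race that tempfile.mktemp has; the behaviour is otherwise the same as the patched helper.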
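
PATCHes 24 through 29 gate whether offloaded weights go to ordinary CPU RAM or to mmap-backed files: need_mmap() compares the free memory of the offload device against the MMAP_MEM_THRESHOLD_GB environment variable (default 0 after PATCH 25, i.e. mmap disabled), and PATCH 28 additionally skips the spill when the free disk space cannot hold the module. The sketch below is illustrative only: psutil.virtual_memory() stands in for comfy.model_management.get_free_memory() on the CPU device, and pick_offload_target is a made-up helper, not part of the diff.

import os

import psutil

GiB = 1024 ** 3


def need_mmap() -> bool:
    # MMAP_MEM_THRESHOLD_GB=0 (the default after PATCH 25) disables mmap offload.
    threshold = int(os.environ.get("MMAP_MEM_THRESHOLD_GB", "0")) * GiB
    free_cpu = psutil.virtual_memory().available  # stand-in for get_free_memory(cpu)
    return free_cpu < threshold


def pick_offload_target(module_bytes: int) -> str:
    if not need_mmap():
        return "cpu"    # enough RAM: plain module.to(offload_device)
    if psutil.disk_usage("/").free < module_bytes:
        return "stay"   # PATCH 28: not enough disk, leave the module where it is
    return "mmap"       # spill the module's weights to temporary .pt files


print(pick_offload_target(2 * GiB))

In the actual partially_unload loop the "stay" case is a break, so modules that would not fit on disk simply keep their VRAM.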
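
PATCH 28/29 also reworks model_unload() to choose between a partial and a full unload: a missing memory_to_free is treated as "free the whole model", and the unload becomes partial either when only part of the loaded size is requested or when the offload device cannot absorb the whole model without dipping below the reserve given by get_mmap_mem_threshold_gb(). The condition below paraphrases the diff under assumed argument names; choose_unload itself is not part of the patch.

def choose_unload(memory_to_free, loaded_size, available_offload_mem, reserve):
    # Treat "no amount given" as "free the whole model", as the patch does.
    if memory_to_free is None:
        memory_to_free = loaded_size
    # Partial unload when only part of the model needs to go, or when the offload
    # device cannot take the whole model without dipping below the reserve.
    partial = (memory_to_free > available_offload_mem - reserve
               or memory_to_free < loaded_size)
    return "partial" if partial else "full"


GiB = 1024 ** 3
print(choose_unload(None, 12 * GiB, 64 * GiB, 4 * GiB))     # -> "full"
print(choose_unload(3 * GiB, 12 * GiB, 64 * GiB, 4 * GiB))  # -> "partial"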