From e56e423e440cfd0f62fdc2f0de06549e170d0dcd Mon Sep 17 00:00:00 2001
From: hben35096 <139383150+hben35096@users.noreply.github.com>
Date: Fri, 9 May 2025 01:51:30 +0800
Subject: [PATCH 1/6] Update clip_vision.py

---
 comfy/clip_vision.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 00aab9164e5e..8f592652e798 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -29,7 +29,7 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)
 
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear")
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]

From d22feacb052028ea074a3d753903ed4d1b291a1b Mon Sep 17 00:00:00 2001
From: hben35096 <139383150+hben35096@users.noreply.github.com>
Date: Fri, 9 May 2025 02:26:43 +0800
Subject: [PATCH 2/6] Added support for Moore Threads GPUs

---
 comfy/model_management.py | 37 +++++++++++++++++++++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 44aff37625c0..f92820375da2 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -128,6 +128,15 @@ def get_supported_float8_types():
 except:
     mlu_available = False
 
+try:
+    import torch_musa
+    _ = torch_musa.device_count()
+    musa_available = torch_musa.is_available()
+    if musa_available:
+        logging.info("MUSA device detected: {}".format(torch_musa.get_device_name(0)))
+except:
+    musa_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
 
@@ -151,6 +160,12 @@ def is_mlu():
         return True
     return False
 
+def is_musa():
+    global musa_available
+    if musa_available:
+        return True
+    return False
+
 def get_torch_device():
     global directml_enabled
     global cpu_state
@@ -168,6 +183,8 @@ def get_torch_device():
             return torch.device("npu", torch.npu.current_device())
         elif is_mlu():
             return torch.device("mlu", torch.mlu.current_device())
+        elif is_musa():
+            return torch.device('musa', torch.musa.current_device())
         else:
             return torch.device(torch.cuda.current_device())
 
@@ -200,6 +217,12 @@ def get_total_memory(dev=None, torch_total_too=False):
             _, mem_total_mlu = torch.mlu.mem_get_info(dev)
             mem_total_torch = mem_reserved
             mem_total = mem_total_mlu
+        elif is_musa():
+            stats = torch.musa.memory_stats(dev)
+            mem_reserved = stats['reserved_bytes.all.current']
+            _, mem_total = torch.musa.mem_get_info(dev)
+            mem_total_torch = mem_reserved
+
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_reserved = stats['reserved_bytes.all.current']
@@ -1099,6 +1122,14 @@ def get_free_memory(dev=None, torch_free_too=False):
             mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
             mem_free_torch = mem_reserved - mem_active
             mem_free_total = mem_free_mlu + mem_free_torch
+        elif is_musa():
+            stats = torch.musa.memory_stats(dev)
+            mem_active = stats['active_bytes.all.current']
+            mem_reserved = stats['reserved_bytes.all.current']
+            mem_free_musa, _ = torch.musa.mem_get_info(dev)
+            mem_free_torch = mem_reserved - mem_active
+            mem_free_total = mem_free_musa + mem_free_torch
+
         else:
             stats = torch.cuda.memory_stats(dev)
             mem_active = stats['active_bytes.all.current']
@@ -1171,6 +1202,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_mlu():
         return True
 
+    if is_musa():
+        return True
+
     if torch.version.hip:
         return True
 
@@ -1231,6 +1265,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_ascend_npu():
         return True
 
+    if is_musa():
+        return True
+
     if is_amd():
         arch = torch.cuda.get_device_properties(device).gcnArchName
         if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16

From b06b4f67cdd3e8e2bd9eff331623ab9c3674136f Mon Sep 17 00:00:00 2001
From: hben35096 <139383150+hben35096@users.noreply.github.com>
Date: Fri, 9 May 2025 23:27:23 +0800
Subject: [PATCH 3/6] Solve the error "RuntimeError: BinaryCall MUDNN failed in: Run PowTensorOut" of the flux KSampler.

---
 comfy/ldm/flux/math.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index 3e09781768a7..c1bfa32c8a6a 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -27,9 +27,14 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
         device = torch.device("cpu")
     else:
         device = pos.device
-
-    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
-    omega = 1.0 / (theta**scale)
+    if device.type == "musa":
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32, device=device)
+        if not isinstance(theta, torch.Tensor):
+            theta = torch.tensor(theta, dtype=torch.float32, device=device)
+        omega = torch.exp(-scale * torch.log(theta + 1e-6))
+    else:
+        scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
+        omega = 1.0 / (theta**scale)
     out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
     out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
     out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)

From 31914c9beb5d1b2d0e4062d1fa6a6304458ec9d9 Mon Sep 17 00:00:00 2001
From: hben35096 <139383150+hben35096@users.noreply.github.com>
Date: Sat, 10 May 2025 13:47:35 +0800
Subject: [PATCH 4/6] Use the CPU for interpolation.

---
 comfy/clip_vision.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index 8f592652e798..92cd6b179342 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -29,7 +29,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
         else:
             scale_size = (size, size)
 
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bilinear")
+        if image.device.type == 'musa':
+            image = image.cpu()
+            image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+            image = image.to('musa')
+        else:
+            image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
         h = (image.shape[2] - size)//2
         w = (image.shape[3] - size)//2
         image = image[:,:,h:h+size,w:w+size]

From 7e1d23356a31172a195c7181c6ffedf9ff055ed1 Mon Sep 17 00:00:00 2001
From: mccxadmin
Date: Thu, 31 Jul 2025 10:09:22 +0800
Subject: [PATCH 5/6] changes to run svd_photo_to_video.json

---
 comfy/ldm/modules/attention.py               | 2 ++
 comfy/ldm/modules/sub_quadratic_attention.py | 8 +++++++-
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 35d2270ee98e..8851c0e0cf0c 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -632,6 +632,8 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
     def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
+        if context.is_contiguous() is False:
+            context = context.contiguous()
         k = self.to_k(context)
         if value is not None:
             v = self.to_v(value)

diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index fab145f1c208..15a509136c97 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -16,6 +16,12 @@
 import math
 import logging
 
+try:
+    import torch_musa
+    musa_available = torch_musa.is_available()
+except:
+    musa_available = False
+
 try:
     from typing import Optional, NamedTuple, List, Protocol
 except ImportError:
@@ -145,7 +151,7 @@ def _get_attention_scores_no_kv_chunking(
     mask,
 ) -> Tensor:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'cuda'):
+        with torch.autocast(enabled=False, device_type = 'musa'):
             query = query.float()
             key_t = key_t.float()
             attn_scores = torch.baddbmm(

From c0c5d3d1fb08aae32227df09682727eae6266ce8 Mon Sep 17 00:00:00 2001
From: Ldpe2G
Date: Fri, 8 Aug 2025 12:14:34 +0800
Subject: [PATCH 6/6] refine

---
 comfy/ldm/modules/attention.py               | 2 --
 comfy/ldm/modules/sub_quadratic_attention.py | 8 +-------
 2 files changed, 1 insertion(+), 9 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 8851c0e0cf0c..35d2270ee98e 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -632,8 +632,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
     def forward(self, x, context=None, value=None, mask=None):
         q = self.to_q(x)
         context = default(context, x)
-        if context.is_contiguous() is False:
-            context = context.contiguous()
         k = self.to_k(context)
         if value is not None:
             v = self.to_v(value)

diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py
index 15a509136c97..fab145f1c208 100644
--- a/comfy/ldm/modules/sub_quadratic_attention.py
+++ b/comfy/ldm/modules/sub_quadratic_attention.py
@@ -16,12 +16,6 @@
 import math
 import logging
 
-try:
-    import torch_musa
-    musa_available = torch_musa.is_available()
-except:
-    musa_available = False
-
 try:
     from typing import Optional, NamedTuple, List, Protocol
 except ImportError:
@@ -151,7 +145,7 @@ def _get_attention_scores_no_kv_chunking(
     mask,
 ) -> Tensor:
     if upcast_attention:
-        with torch.autocast(enabled=False, device_type = 'musa'):
+        with torch.autocast(enabled=False, device_type = 'cuda'):
             query = query.float()
             key_t = key_t.float()
             attn_scores = torch.baddbmm(
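Note on PATCH 3/6 (not part of the patches above): the MUSA branch added to rope() avoids the failing float64 pow kernel by rewriting theta**scale as exp(-scale * log(theta)) in float32. A minimal standalone sketch of that identity follows; the values dim=128 and theta=10000 are illustrative assumptions, not taken from the patch.

    import torch

    # Illustrative values only; the actual dim and theta come from the model.
    dim, theta = 128, 10000
    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32)

    # Reference path: the non-MUSA branch computes 1 / theta**scale (upstream in float64).
    omega_pow = 1.0 / (theta ** scale.to(torch.float64))

    # MUSA-friendly path from the patch: pow rewritten as exp(-scale * log(theta)) in float32.
    theta_t = torch.tensor(theta, dtype=torch.float32)
    omega_exp = torch.exp(-scale * torch.log(theta_t + 1e-6))

    # The two formulations agree to float32 precision.
    print(torch.allclose(omega_pow.to(torch.float32), omega_exp, rtol=1e-4, atol=1e-6))

The 1e-6 added to theta in the patch only guards the logarithm against a zero value; for a theta on the order of 10000 it changes omega by a negligible amount.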