diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 00aab9164e5e..92cd6b179342 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -29,7 +29,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s else: scale_size = (size, size) - image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + if image.device.type == 'musa': + image = image.cpu() + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) + image = image.to('musa') + else: + image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True) h = (image.shape[2] - size)//2 w = (image.shape[3] - size)//2 image = image[:,:,h:h+size,w:w+size] diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 3e09781768a7..c1bfa32c8a6a 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -27,9 +27,14 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor: device = torch.device("cpu") else: device = pos.device - - scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device) - omega = 1.0 / (theta**scale) + if device.type == "musa": + scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float32, device=device) + if not isinstance(theta, torch.Tensor): + theta = torch.tensor(theta, dtype=torch.float32, device=device) + omega = torch.exp(-scale * torch.log(theta + 1e-6)) + else: + scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device) + omega = 1.0 / (theta**scale) out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega) out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1) out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2) diff --git a/comfy/model_management.py b/comfy/model_management.py index d08aee1fe10a..8478e50d89eb 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -131,6 +131,15 @@ def get_supported_float8_types(): except: mlu_available = False +try: + import torch_musa + _ = torch_musa.device_count() + musa_available = torch_musa.is_available() + if musa_available: + logging.info("MUSA device detected: {}".format(torch_musa.get_device_name(0))) +except: + musa_available = False + try: ixuca_available = hasattr(torch, "corex") except: @@ -159,6 +168,12 @@ def is_mlu(): return True return False +def is_musa(): + global musa_available + if musa_available: + return True + return False + def is_ixuca(): global ixuca_available if ixuca_available: @@ -182,6 +197,8 @@ def get_torch_device(): return torch.device("npu", torch.npu.current_device()) elif is_mlu(): return torch.device("mlu", torch.mlu.current_device()) + elif is_musa(): + return torch.device('musa', torch.musa.current_device()) else: return torch.device(torch.cuda.current_device()) @@ -215,6 +232,12 @@ def get_total_memory(dev=None, torch_total_too=False): _, mem_total_mlu = torch.mlu.mem_get_info(dev) mem_total_torch = mem_reserved mem_total = mem_total_mlu + elif is_musa(): + stats = torch.musa.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total = torch.musa.mem_get_info(dev) + mem_total_torch = mem_reserved + else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -1157,6 +1180,14 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_mlu, _ = torch.mlu.mem_get_info(dev) mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_mlu + mem_free_torch + elif is_musa(): + stats = torch.musa.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_musa, _ = torch.musa.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_musa + mem_free_torch + else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -1235,6 +1266,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_mlu(): return True + if is_musa(): + return True + if is_ixuca(): return True @@ -1301,6 +1335,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_ascend_npu(): return True + if is_musa(): + return True + if is_ixuca(): return True