Skip to content
72 changes: 70 additions & 2 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import os
import importlib
import platform
import weakref
Expand Down Expand Up @@ -81,6 +82,56 @@ def get_supported_float8_types():

FLOAT8_TYPES = get_supported_float8_types()

def get_docker_memory_limit(cgroup_v2_path='/sys/fs/cgroup/memory.max',
                            cgroup_v1_path='/sys/fs/cgroup/memory/memory.limit_in_bytes'):
    """
    Get the container memory limit in bytes from the cgroup filesystem.

    The cgroup v2 file is consulted first, then the cgroup v1 file. Each file
    is probed independently, so a missing, unreadable or malformed v2 file
    does not prevent the v1 file from being tried.

    Args:
        cgroup_v2_path: Path of the cgroup v2 limit file (contains either a
            byte count or the literal 'max').
        cgroup_v1_path: Path of the cgroup v1 limit file (contains a byte
            count).

    Returns:
        The limit in bytes as an int, or None if not in a container, no limit
        is set, or the limit could not be read.
    """
    # cgroup v1 reports "no limit" as a very large number (typically
    # 9223372036854771712 or similar). Anything at or above ~1 exabyte is
    # treated as "unlimited" rather than as a real limit.
    no_limit_threshold = 1 << 60

    for path in (cgroup_v2_path, cgroup_v1_path):
        try:
            with open(path, 'r') as f:
                raw = f.read().strip()
        except FileNotFoundError:
            # This cgroup layout is simply not present; try the next one.
            continue
        except OSError as e:
            logging.debug("Could not read cgroup memory limit: %s", e)
            continue

        if raw == 'max':  # 'max' means no limit
            continue

        try:
            limit = int(raw)
        except ValueError as e:
            logging.debug("Could not read cgroup memory limit: %s", e)
            continue

        # Reject nonsensical (<= 0) and "unlimited" sentinel values so that
        # callers fall back to the system-wide memory figures.
        if 0 < limit < no_limit_threshold:
            return limit

    return None

def get_docker_memory_usage(cgroup_v2_path='/sys/fs/cgroup/memory.current',
                            cgroup_v1_path='/sys/fs/cgroup/memory/memory.usage_in_bytes'):
    """
    Get the current memory usage in bytes from the cgroup filesystem.

    The cgroup v2 file is consulted first, then the cgroup v1 file. Each file
    is probed independently, so a missing, unreadable or malformed v2 file
    does not prevent the v1 file from being tried.

    Args:
        cgroup_v2_path: Path of the cgroup v2 usage file.
        cgroup_v1_path: Path of the cgroup v1 usage file.

    Returns:
        The usage in bytes as an int, or None if not available.

    NOTE(review): cgroup usage counters are generally understood to include
    page cache, so "limit - usage" may understate the memory that is actually
    obtainable -- confirm against the kernel cgroup documentation before
    relying on it for tight budgeting.
    """
    for path in (cgroup_v2_path, cgroup_v1_path):
        try:
            with open(path, 'r') as f:
                usage = int(f.read().strip())
        except FileNotFoundError:
            # This cgroup layout is simply not present; try the next one.
            continue
        except (OSError, ValueError) as e:
            logging.debug("Could not read cgroup memory usage: %s", e)
            continue

        # A negative byte count is not meaningful; treat it as unavailable.
        if usage >= 0:
            return usage

    return None

xpu_available = False
torch_version = ""
try:
Expand Down Expand Up @@ -200,7 +251,13 @@ def get_total_memory(dev=None, torch_total_too=False):
dev = get_torch_device()

if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_total = psutil.virtual_memory().total
# Check if running in Docker with memory limit
docker_limit = get_docker_memory_limit()
if docker_limit is not None:
mem_total = docker_limit
logging.debug(f"Docker total memory limit: {docker_limit/(1024**3):.2f}GB")
else:
mem_total = psutil.virtual_memory().total
mem_total_torch = mem_total
else:
if directml_enabled:
Expand Down Expand Up @@ -1214,7 +1271,18 @@ def get_free_memory(dev=None, torch_free_too=False):
dev = get_torch_device()

if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
mem_free_total = psutil.virtual_memory().available
# Check if running in Docker with memory limit
docker_limit = get_docker_memory_limit()
docker_usage = get_docker_memory_usage()

if docker_limit is not None and docker_usage is not None:
# Running in Docker with memory limit
mem_free_total = docker_limit - docker_usage
logging.debug(f"Docker memory: limit={docker_limit/(1024**3):.2f}GB, usage={docker_usage/(1024**3):.2f}GB, free={mem_free_total/(1024**3):.2f}GB")
else:
# Not in Docker or no limit set, use system memory
mem_free_total = psutil.virtual_memory().available

mem_free_torch = mem_free_total
else:
if directml_enabled:
Expand Down
Loading