Initial exploration of weight zipper

This commit is contained in:
Jedrzej Kosinski
2025-03-24 03:34:42 -05:00
parent 3b19fc76e3
commit c8037ab667
4 changed files with 202 additions and 13 deletions

View File

@@ -17,7 +17,7 @@
"""
from __future__ import annotations
from typing import Optional, Callable
from typing import Optional, Callable, TYPE_CHECKING
import torch
import copy
import inspect
@@ -26,6 +26,7 @@ import uuid
import collections
import math
import comfy.ops
import comfy.utils
import comfy.float
import comfy.model_management
@@ -34,6 +35,9 @@ import comfy.hooks
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction
if TYPE_CHECKING:
from comfy.model_base import BaseModel
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -201,7 +205,7 @@ class MemoryCounter:
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
self.model = model
self.model: BaseModel = model
if not hasattr(self.model, 'device'):
logging.debug("Model doesn't have a device attribute.")
self.model.device = offload_device
@@ -568,6 +572,14 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _zipper_dict_lowvram_only(self):
    """Fetch the lowvram-only load list; currently produces no result.

    NOTE(review): the list returned by ``_load_list_lowvram_only()`` is
    bound to a local and then dropped, and the method implicitly returns
    ``None``. This looks like an unfinished stub -- confirm the intended
    return value before relying on it.
    """
    lowvram_entries = self._load_list_lowvram_only()  # fetched but not used yet
    return None
def _load_list_lowvram_only(self):
    """Return the ``_load_list()`` entries whose module carries ``prev_comfy_cast_weights``.

    Each entry is passed through unchanged and the original order of
    ``_load_list()`` is kept. The module is read from index 2 of each
    entry; only the presence of the attribute is tested, not its value.
    """
    lowvram_entries = []
    for entry in self._load_list():
        module = entry[2]
        if hasattr(module, "prev_comfy_cast_weights"):
            lowvram_entries.append(entry)
    return lowvram_entries
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -583,6 +595,35 @@ class ModelPatcher:
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
def prepare_teeth(self):
    """Create a zipper tooth on every lowvram module and chain the teeth together.

    Works on the entries of ``_load_list_lowvram_only()`` in the order
    that method returns them. Each module first gets ``init_tooth`` called
    with this patcher's load device, offload device and the entry's name
    (index 1). The resulting ``zipper_tooth`` objects are then linked:
    the first tooth is flagged with ``start = True``, every later tooth
    gets ``prev_tooth``, and every tooth except the last gets
    ``next_tooth``. Attributes that are not assigned here keep whatever
    ``init_tooth`` gave them.
    """
    entries = self._load_list_lowvram_only()

    # Pass 1: every module gets its own tooth before any linking happens,
    # so neighbours can be referenced in either direction below.
    for entry in entries:
        module: comfy.ops.CastWeightBiasOp = entry[2]
        module.init_tooth(self.load_device, self.offload_device, entry[1])

    # Pass 2: wire each tooth to its neighbours.
    final_index = len(entries) - 1
    for index, entry in enumerate(entries):
        tooth = entry[2].zipper_tooth
        if index == 0:
            tooth.start = True
        else:
            tooth.prev_tooth = entries[index - 1][2].zipper_tooth
        if index < final_index:
            tooth.next_tooth = entries[index + 1][2].zipper_tooth
def clean_teeth(self):
    """Call ``clean_tooth()`` on every module listed by ``_load_list_lowvram_only()``.

    Modules are visited in the order the list returns them; the module
    is read from index 2 of each entry.
    """
    for entry in self._load_list_lowvram_only():
        module: comfy.ops.CastWeightBiasOp = entry[2]
        module.clean_tooth()
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
with self.use_ejected():
self.unpatch_hooks()
@@ -591,6 +632,8 @@ class ModelPatcher:
lowvram_counter = 0
loading = self._load_list()
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
load_completely = []
loading.sort(reverse=True)
for x in loading:
@@ -672,6 +715,7 @@ class ModelPatcher:
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
self.model.zipper_initialized = False
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
@@ -684,6 +728,9 @@ class ModelPatcher:
self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid
if self.model.model_lowvram:
self.prepare_teeth()
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
@@ -715,6 +762,7 @@ class ModelPatcher:
move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
self.clean_teeth()
self.model.model_lowvram = False
self.model.lowvram_patch_counter = 0
@@ -804,8 +852,10 @@ class ModelPatcher:
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.zipper_initialized = False
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.prepare_teeth()
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):