Initial exploration of weight zipper
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional, Callable, TYPE_CHECKING
|
||||
import torch
|
||||
import copy
|
||||
import inspect
|
||||
@@ -26,6 +26,7 @@ import uuid
|
||||
import collections
|
||||
import math
|
||||
|
||||
import comfy.ops
|
||||
import comfy.utils
|
||||
import comfy.float
|
||||
import comfy.model_management
|
||||
@@ -34,6 +35,9 @@ import comfy.hooks
|
||||
import comfy.patcher_extension
|
||||
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
if TYPE_CHECKING:
|
||||
from comfy.model_base import BaseModel
|
||||
|
||||
|
||||
def string_to_seed(data):
|
||||
crc = 0xFFFFFFFF
|
||||
@@ -201,7 +205,7 @@ class MemoryCounter:
|
||||
class ModelPatcher:
|
||||
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
|
||||
self.size = size
|
||||
self.model = model
|
||||
self.model: BaseModel = model
|
||||
if not hasattr(self.model, 'device'):
|
||||
logging.debug("Model doesn't have a device attribute.")
|
||||
self.model.device = offload_device
|
||||
@@ -568,6 +572,14 @@ class ModelPatcher:
|
||||
else:
|
||||
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
|
||||
|
||||
def _zipper_dict_lowvram_only(self):
|
||||
loading = self._load_list_lowvram_only()
|
||||
|
||||
|
||||
def _load_list_lowvram_only(self):
|
||||
loading = self._load_list()
|
||||
return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
|
||||
|
||||
def _load_list(self):
|
||||
loading = []
|
||||
for n, m in self.model.named_modules():
|
||||
@@ -583,6 +595,35 @@ class ModelPatcher:
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
return loading
|
||||
|
||||
def prepare_teeth(self):
    """Create a zipper tooth on every low-VRAM module and link the teeth in order.

    The modules come from `_load_list_lowvram_only()`; each entry is indexed
    as a sequence with the module name at index 1 and the module at index 2.

    After this call:
      * `init_tooth(load_device, offload_device, name)` has been invoked once
        on every listed module, in list order;
      * the first module's tooth has `start = True`;
      * every tooth except the first has `prev_tooth` set to the preceding
        module's tooth, and every tooth except the last has `next_tooth` set
        to the following module's tooth.

    Attributes that are not assigned here (`prev_tooth` on the first tooth,
    `next_tooth` on the last, `start` on the others) keep whatever value the
    tooth already had. With an empty list nothing happens.
    """
    ordered_list = self._load_list_lowvram_only()

    # First pass: create a tooth on every module in the list.
    # NOTE(review): `init_tooth` is expected to provide `zipper_tooth` on the
    # module, since that attribute is read right below — confirm in comfy.ops.
    for entry in ordered_list:
        module: comfy.ops.CastWeightBiasOp = entry[2]
        module.init_tooth(self.load_device, self.offload_device, entry[1])

    if not ordered_list:
        return

    # Second pass: chain the teeth into a doubly linked list.
    teeth = [entry[2].zipper_tooth for entry in ordered_list]
    teeth[0].start = True
    for earlier, later in zip(teeth, teeth[1:]):
        earlier.next_tooth = later
        later.prev_tooth = earlier
|
||||
|
||||
def clean_teeth(self):
    """Call `clean_tooth()` on every module listed by `_load_list_lowvram_only()`.

    The module is element 2 of each list entry; modules are visited in list
    order and nothing is returned.
    """
    for entry in self._load_list_lowvram_only():
        module: comfy.ops.CastWeightBiasOp = entry[2]
        module.clean_tooth()
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
with self.use_ejected():
|
||||
self.unpatch_hooks()
|
||||
@@ -591,6 +632,8 @@ class ModelPatcher:
|
||||
lowvram_counter = 0
|
||||
loading = self._load_list()
|
||||
|
||||
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
|
||||
|
||||
load_completely = []
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
@@ -672,6 +715,7 @@ class ModelPatcher:
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
self.model.zipper_initialized = False
|
||||
else:
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self.model.model_lowvram = False
|
||||
@@ -684,6 +728,9 @@ class ModelPatcher:
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
if self.model.model_lowvram:
|
||||
self.prepare_teeth()
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
|
||||
|
||||
@@ -715,6 +762,7 @@ class ModelPatcher:
|
||||
move_weight_functions(m, device_to)
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
self.clean_teeth()
|
||||
self.model.model_lowvram = False
|
||||
self.model.lowvram_patch_counter = 0
|
||||
|
||||
@@ -804,8 +852,10 @@ class ModelPatcher:
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.zipper_initialized = False
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
self.prepare_teeth()
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
|
||||
Reference in New Issue
Block a user