Initial exploration of weight zipper
This commit is contained in:
@@ -6,6 +6,7 @@ if TYPE_CHECKING:
|
||||
from comfy.model_patcher import ModelPatcher
|
||||
from comfy.model_base import BaseModel
|
||||
from comfy.controlnet import ControlBase
|
||||
from comfy.ops import CastWeightBiasOp
|
||||
import torch
|
||||
from functools import partial
|
||||
import collections
|
||||
@@ -18,6 +19,7 @@ import comfy.patcher_extension
|
||||
import comfy.hooks
|
||||
import scipy.stats
|
||||
import numpy
|
||||
import comfy.ops
|
||||
|
||||
|
||||
def add_area_dims(area, num_dims):
|
||||
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
|
||||
|
||||
#The main sampling function shared by all the samplers
|
||||
#Returns denoised
|
||||
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
|
||||
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
||||
uncond_ = None
|
||||
else:
|
||||
uncond_ = uncond
|
||||
|
||||
do_cleanup = False
|
||||
if "weight_zipper" not in model_options:
|
||||
do_cleanup = True
|
||||
#zipper_dict = {}
|
||||
model_options["weight_zipper"] = True
|
||||
loaded_modules = model.current_patcher._load_list_lowvram_only()
|
||||
low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
|
||||
sum_m = sum([x[0] for x in low_m])
|
||||
for l in loaded_modules:
|
||||
m: CastWeightBiasOp = l[2]
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
m.zipper_tooth = comfy.ops.ZipperTooth
|
||||
#m.zipper_dict = zipper_dict
|
||||
m.zipper_key = l[1]
|
||||
|
||||
conds = [cond, uncond_]
|
||||
out = calc_cond_batch(model, conds, x, timestep, model_options)
|
||||
|
||||
if do_cleanup:
|
||||
zzz = 20
|
||||
for l in loaded_modules:
|
||||
m: CastWeightBiasOp = l[2]
|
||||
if hasattr(l[2], "comfy_cast_weights"):
|
||||
#m.zipper_dict = None
|
||||
m.zipper_key = None
|
||||
|
||||
for fn in model_options.get("sampler_pre_cfg_function", []):
|
||||
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
|
||||
"input": x, "sigma": timestep, "model": model, "model_options": model_options}
|
||||
|
||||
Reference in New Issue
Block a user