Initial exploration of weight zipper
This commit is contained in:
@@ -16,6 +16,7 @@
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import torch
|
||||
import logging
|
||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
@@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module):
|
||||
self.model_config = model_config
|
||||
self.manual_cast_dtype = model_config.manual_cast_dtype
|
||||
self.device = device
|
||||
self.current_patcher: 'ModelPatcher' = None
|
||||
self.current_patcher: ModelPatcher = None
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
|
||||
logging.info("model_type {}".format(model_type.name))
|
||||
logging.debug("adm {}".format(self.adm_channels))
|
||||
self.memory_usage_factor = model_config.memory_usage_factor
|
||||
self.zipper_initialized = False
|
||||
|
||||
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
|
||||
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
# handle lowvram zipper initialization, if required
|
||||
if self.model_lowvram and not self.zipper_initialized:
|
||||
if self.current_patcher:
|
||||
self.zipper_initialized = True
|
||||
with self.current_patcher.use_ejected():
|
||||
loading = self.current_patcher._load_list_lowvram_only()
|
||||
|
||||
return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
|
||||
|
||||
def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
|
||||
sigma = t
|
||||
xc = self.model_sampling.calculate_input(sigma, x)
|
||||
if c_concat is not None:
|
||||
|
||||
Reference in New Issue
Block a user