Compare commits
52 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 39fb74c5bd | |
| | 74e124f4d7 | |
| | a562c17e8a | |
| | 5942c17d55 | |
| | c032b11e07 | |
| | b8ffb2937f | |
| | ce37c11164 | |
| | b5c3906b38 | |
| | 5d43e75e5b | |
| | 517f4a94e4 | |
| | 52a471c5c7 | |
| | ad76574cb8 | |
| | 9acfe4df41 | |
| | 9829b013ea | |
| | 5c69cde037 | |
| | e9589d6d92 | |
| | 0d82a798a5 | |
| | 925fff26fd | |
| | 75b9b55b22 | |
| | 1765f1c60c | |
| | 1de69fe4d5 | |
| | ae197f651b | |
| | 1b5b8ca81a | |
| | 6678d5cf65 | |
| | e172564eea | |
| | a3cc326748 | |
| | 86a97e91fc | |
| | 5acdadc9f3 | |
| | 55ad9d5f8c | |
| | a9f04edc58 | |
| | a475ec2300 | |
| | 06eb9fb426 | |
| | 413322645e | |
| | 11200de970 | |
| | 037c38eb0f | |
| | 1e11d2d1f5 | |
| | 65ea6be38f | |
| | 5df6f57b5d | |
| | 6588bfdef9 | |
| | 50ed2879ef | |
| | 66d4233210 | |
| | 591010b7ef | |
| | 08f92d55e9 | |
| | 8115d8cce9 | |
| | 6969fc9ba4 | |
| | cb7c4b4be3 | |
| | 1208863eca | |
| | e1c528196e | |
| | 17030fd4c0 | |
| | c19dcd362f | |
| | 1c08bf35b4 | |
| | 2a02546e20 | |
53  .github/workflows/pullrequest-ci-run.yml  vendored  Normal file

@@ -0,0 +1,53 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
  pull_request_target:
    types: [labeled]

jobs:
  pr-test-stable:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux, windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
          use_prior_commit: 'true'
  comment:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    runs-on: ubuntu-latest
    permissions:
      pull-requests: write
    steps:
      - uses: actions/github-script@v6
        with:
          script: |
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
            })

95  .github/workflows/test-ci.yml  vendored  Normal file

@@ -0,0 +1,95 @@
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Full Comfy CI Workflow Runs
on:
  push:
    branches:
      - master
    paths-ignore:
      - 'app/**'
      - 'input/**'
      - 'output/**'
      - 'notebooks/**'
      - 'script_examples/**'
      - '.github/**'
      - 'web/**'
  workflow_dispatch:

jobs:
  test-stable:
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux, windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}

  test-win-nightly:
    strategy:
      fail-fast: true
      matrix:
        os: [windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["nightly"]
        include:
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}

  test-unix-nightly:
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux]
        python_version: ["3.11"]
        cuda_version: ["12.1"]
        torch_version: ["nightly"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}

3  .gitignore  vendored

@@ -18,4 +18,5 @@ venv/
 /tests-ui/data/object_info.json
 /user/
 *.log
 web_custom_versions/
+.DS_Store

comfy/controlnet.py

@@ -1,4 +1,24 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+
 import torch
+from enum import Enum
 import math
 import os
 import logging
@@ -13,6 +33,8 @@ import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
+import comfy.ldm.hydit.controlnet
+import comfy.ldm.flux.controlnet_xlabs
 
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -33,6 +55,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)
 
+class StrengthType(Enum):
+    CONSTANT = 1
+    LINEAR_UP = 2
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
@@ -51,6 +77,8 @@ class ControlBase:
             device = comfy.model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+        self.extra_conds = []
+        self.strength_type = StrengthType.CONSTANT
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
         self.cond_hint_original = cond_hint
@@ -93,6 +121,8 @@ class ControlBase:
         c.latent_format = self.latent_format
         c.extra_args = self.extra_args.copy()
         c.vae = self.vae
+        c.extra_conds = self.extra_conds.copy()
+        c.strength_type = self.strength_type
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -113,7 +143,10 @@ class ControlBase:
 
                 if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                     applied_to.add(x)
-                    x *= self.strength
+                    if self.strength_type == StrengthType.CONSTANT:
+                        x *= self.strength
+                    elif self.strength_type == StrengthType.LINEAR_UP:
+                        x *= (self.strength ** float(len(control_output) - i))
 
                 if x.dtype != output_dtype:
                     x = x.to(output_dtype)
@@ -142,7 +175,7 @@ class ControlBase:
 
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
@@ -154,6 +187,8 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
         self.latent_format = latent_format
+        self.extra_conds += extra_conds
+        self.strength_type = strength_type
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -191,13 +226,16 @@ class ControlNet(ControlBase):
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
 
         context = cond.get('crossattn_controlnet', cond['c_crossattn'])
-        y = cond.get('y', None)
-        if y is not None:
-            y = y.to(dtype)
+        extra = self.extra_args.copy()
+        for c in self.extra_conds:
+            temp = cond.get(c, None)
+            if temp is not None:
+                extra[c] = temp.to(dtype)
+
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
 
-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
         return self.control_merge(control, control_prev, output_dtype)
 
     def copy(self):
@@ -286,6 +324,7 @@ class ControlLora(ControlNet):
         ControlBase.__init__(self, device)
         self.control_weights = control_weights
         self.global_average_pooling = global_average_pooling
+        self.extra_conds += ["y"]
 
     def pre_run(self, model, percent_to_timestep_function):
         super().pre_run(model, percent_to_timestep_function)
@@ -338,12 +377,8 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
 
-def load_controlnet_mmdit(sd):
-    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
-    model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
-    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
-    for k in sd:
-        new_sd[k] = sd[k]
+def controlnet_config(sd):
+    model_config = comfy.model_detection.model_config_from_unet(sd, "", True)
 
     supported_inference_dtypes = model_config.supported_inference_dtypes
 
@@ -356,14 +391,27 @@ def load_controlnet_mmdit(sd):
     else:
         operations = comfy.ops.disable_weight_init
 
-    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
-    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
+    return model_config, operations, load_device, unet_dtype, manual_cast_dtype
+
+def controlnet_load_state_dict(control_model, sd):
+    missing, unexpected = control_model.load_state_dict(sd, strict=False)
 
     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))
 
     if len(unexpected) > 0:
         logging.debug("unexpected controlnet keys: {}".format(unexpected))
+    return control_model
+
+def load_controlnet_mmdit(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
 
     latent_format = comfy.latent_formats.SD3()
     latent_format.shift_factor = 0 #SD3 controlnet weirdness
@@ -371,8 +419,31 @@ def load_controlnet_mmdit(sd):
     return control
 
 
+def load_controlnet_hunyuandit(controlnet_data):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
+
+    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    control_model = controlnet_load_state_dict(control_model, controlnet_data)
+
+    latent_format = comfy.latent_formats.SDXL()
+    extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
+    return control
+
+def load_controlnet_flux_xlabs(sd):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
+    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, sd)
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+    if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
+        return load_controlnet_hunyuandit(controlnet_data)
+
     if "lora_controlnet" in controlnet_data:
         return ControlLora(controlnet_data)
 
@@ -430,7 +501,10 @@ def load_controlnet(ckpt_path, model=None):
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
-        return load_controlnet_mmdit(controlnet_data)
+        if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
+            return load_controlnet_flux_xlabs(controlnet_data)
+        else:
+            return load_controlnet_mmdit(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False

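Aside (illustrative sketch, not part of the diff; the variable names below are stand-ins): the new `StrengthType` enum changes how `ControlBase.control_merge` scales each per-layer ControlNet output. `CONSTANT` multiplies every output by the same strength, while `LINEAR_UP` raises the strength to a power that decreases toward 1 for the last layers, so for a strength below 1 the earlier outputs are damped harder than the later ones:

```python
# Minimal illustration of the two strength modes added above (not ComfyUI API).
import torch

control_output = [torch.ones(1), torch.ones(1), torch.ones(1)]  # stand-in per-layer outputs
strength = 0.5

for i, x in enumerate(control_output):
    constant = x * strength                                       # StrengthType.CONSTANT
    linear_up = x * (strength ** float(len(control_output) - i))  # StrengthType.LINEAR_UP
    # i=0 -> 0.5**3 = 0.125, i=2 -> 0.5**1 = 0.5: earlier outputs are damped more strongly.
```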
comfy/diffusers_load.py

@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     if text_encoder2_path is not None:
         text_encoder_paths.append(text_encoder2_path)
 
-    unet = comfy.sd.load_unet(unet_path)
+    unet = comfy.sd.load_diffusion_model(unet_path)
 
     clip = None
     if output_clip:

104  comfy/ldm/flux/controlnet_xlabs.py  Normal file

@@ -0,0 +1,104 @@
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py

import torch
from torch import Tensor, nn
from einops import rearrange, repeat

from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
                     MLPEmbedder, SingleStreamBlock,
                     timestep_embedding)

from .model import Flux
import comfy.ldm.common_dit


class ControlNetFlux(Flux):
    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # add ControlNet blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(self.params.depth):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            # controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)
        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.gradient_checkpointing = False
        self.input_hint_block = nn.Sequential(
            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
        )

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        controlnet_cond = self.input_hint_block(controlnet_cond)
        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples = ()

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            block_res_samples = block_res_samples + (img,)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        return {"output": (controlnet_block_res_samples * 10)[:19]}

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
        hint = hint * 2.0 - 1.0

        bs, c, h, w = x.shape
        patch_size = 2
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)

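Aside (illustrative, with placeholder values): the `(controlnet_block_res_samples * 10)[:19]` expression at the end of `forward_orig` cycles however many per-block residuals the controlnet produced and truncates the result to 19 entries, one per double block of the base Flux model in this configuration:

```python
# Placeholder sketch of the tuple repeat-and-truncate used above (plain Python, no torch needed).
block_res_samples = ("r0", "r1")          # e.g. residuals from two controlnet double blocks
expanded = (block_res_samples * 10)[:19]  # cycle them, keep one entry per base double block
assert len(expanded) == 19
assert expanded[:4] == ("r0", "r1", "r0", "r1")
```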
comfy/ldm/flux/layers.py

@@ -2,12 +2,12 @@ import math
 from dataclasses import dataclass
 
 import torch
-from einops import rearrange
 from torch import Tensor, nn
 
 from .math import attention, rope
 import comfy.ops
 
+
 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
@@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        t.device
-    )
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
 
     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -48,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
         embedding = embedding.to(t)
     return embedding
 
-
 class MLPEmbedder(nn.Module):
     def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
         super().__init__()
@@ -94,14 +91,6 @@ class SelfAttention(nn.Module):
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
         self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
 
-    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
-        qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-        x = attention(q, k, v, pe=pe)
-        x = self.proj(x)
-        return x
-
 
 @dataclass
 class ModulationOut:
@@ -163,22 +152,21 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
 
         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
 
         # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
-
-        attn = attention(q, k, v, pe=pe)
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2), pe=pe)
+
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
@@ -186,8 +174,12 @@ class DoubleStreamBlock(nn.Module):
         img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = txt.clip(-65504, 65504)
 
         return img, txt
 
@@ -232,14 +224,17 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
 
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)
 
         # compute attention
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        return x + mod.gate * output
+        x += mod.gate * output
+        if x.dtype == torch.float16:
+            x = x.clip(-65504, 65504)
+        return x
 
 
 class LastLayer(nn.Module):

comfy/ldm/flux/model.py

@@ -38,7 +38,7 @@ class Flux(nn.Module):
     Transformer model for flow matching on sequences.
     """
 
-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.dtype = dtype
         params = FluxParams(**kwargs)
@@ -83,7 +83,8 @@ class Flux(nn.Module):
             ]
         )
 
-        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
 
     def forward_orig(
         self,
@@ -94,6 +95,7 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        control=None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -112,8 +114,15 @@ class Flux(nn.Module):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+
+            if control is not None: #Controlnet
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        img += add
 
         img = torch.cat((txt, img), 1)
         for block in self.single_blocks:
@@ -123,7 +132,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance, **kwargs):
+    def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -138,5 +147,5 @@ class Flux(nn.Module):
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

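Aside (editor's sketch of the consumption side; names are illustrative, not ComfyUI API): the `control` dict produced by the controlnet (`{"output": [...]}`) is threaded through `Flux.forward` into `forward_orig`, where each entry is added to the image stream right after the matching double block:

```python
# Illustrative stand-in for the per-block injection added to Flux.forward_orig.
def inject_control(img, block_index, control):
    if control is not None:
        control_o = control.get("output")
        if block_index < len(control_o):
            add = control_o[block_index]
            if add is not None:
                img = img + add
    return img
```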
comfy/ldm/hydit/posemb_layers.py

@@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x
 
 
 def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
     return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
 
 
@@ -78,10 +78,9 @@ def apply_rotary_emb(
     xk_out = None
     if isinstance(freqs_cis, tuple):
         cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+        xq_out = (xq * cos + rotate_half(xq) * sin)
         if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+            xk_out = (xk * cos + rotate_half(xk) * sin)
     else:
         xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
         freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]

321  comfy/ldm/hydit/controlnet.py  Normal file

@@ -0,0 +1,321 @@
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils import checkpoint

from comfy.ldm.modules.diffusionmodules.mmdit import (
    Mlp,
    TimestepEmbedder,
    PatchEmbed,
    RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool

import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope

from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop


class HunYuanControlNet(nn.Module):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    input_size: tuple
        The size of the input image.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    depth: int
        The number of transformer blocks.
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    log_fn: callable
        The logging function.
    """

    def __init__(
        self,
        input_size: tuple = 128,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1408,
        depth: int = 40,
        num_heads: int = 16,
        mlp_ratio: float = 4.3637,
        text_states_dim=1024,
        text_states_dim_t5=2048,
        text_len=77,
        text_len_t5=256,
        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
        size_cond=False,
        use_style_cond=False,
        learn_sigma=True,
        norm="layer",
        log_fn: callable = print,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype
        self.latent_format = comfy.latent_formats.SDXL

        self.mlp_t5 = nn.Sequential(
            nn.Linear(
                self.text_states_dim_t5,
                self.text_states_dim_t5 * 4,
                bias=True,
                dtype=dtype,
                device=device,
            ),
            nn.SiLU(),
            nn.Linear(
                self.text_states_dim_t5 * 4,
                self.text_states_dim,
                bias=True,
                dtype=dtype,
                device=device,
            ),
        )
        # learnable replace
        self.text_embedding_padding = nn.Parameter(
            torch.randn(
                self.text_len + self.text_len_t5,
                self.text_states_dim,
                dtype=dtype,
                device=device,
            )
        )

        # Attention pooling
        pooler_out_dim = 1024
        self.pooler = AttentionPool(
            self.text_len_t5,
            self.text_states_dim_t5,
            num_heads=8,
            output_dim=pooler_out_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Dimension of the extra input vectors
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # Image size and crop size conditions
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            # Here we use a default learned embedder layer for future extension.
            self.style_embedder = nn.Embedding(
                1, hidden_size, dtype=dtype, device=device
            )
            self.extra_in_dim += hidden_size

        # Text embedding for `add`
        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations
        )
        self.extra_embedder = nn.Sequential(
            operations.Linear(
                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
            ),
        )

        # Image embedding
        num_patches = self.x_embedder.num_patches

        # HUnYuanDiT Blocks
        self.blocks = nn.ModuleList(
            [
                HunYuanDiTBlock(
                    hidden_size=hidden_size,
                    c_emb_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    text_states_dim=self.text_states_dim,
                    qk_norm=qk_norm,
                    norm_type=self.norm,
                    skip=False,
                    attn_precision=attn_precision,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(19)
            ]
        )

        # Input zero linear for the first block
        self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)

        # Output zero linear for the every block
        self.after_proj_list = nn.ModuleList(
            [
                operations.Linear(
                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
                )
                for _ in range(len(self.blocks))
            ]
        )

    def forward(
        self,
        x,
        hint,
        timesteps,
        context,#encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        return_dict=False,
        **kwarg,
    ):
        """
        Forward pass of the encoder.

        Parameters
        ----------
        x: torch.Tensor
            (B, D, H, W)
        t: torch.Tensor
            (B)
        encoder_hidden_states: torch.Tensor
            CLIP text embedding, (B, L_clip, D)
        text_embedding_mask: torch.Tensor
            CLIP text embedding mask, (B, L_clip)
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, D)
        text_embedding_mask_t5: torch.Tensor
            T5 text embedding mask, (B, L_t5)
        image_meta_size: torch.Tensor
            (B, 6)
        style: torch.Tensor
            (B)
        cos_cis_img: torch.Tensor
        sin_cis_img: torch.Tensor
        return_dict: bool
            Whether to return a dictionary.
        """
        condition = hint
        if condition.shape[0] == 1:
            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)

        text_states = context  # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5  # 2,256,2048
        text_states_mask = text_embedding_mask.bool()  # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool()  # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        text_states[:, -self.text_len :] = torch.where(
            text_states_mask[:, -self.text_len :].unsqueeze(2),
            text_states[:, -self.text_len :],
            padding[: self.text_len],
        )
        text_states_t5[:, -self.text_len_t5 :] = torch.where(
            text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
            text_states_t5[:, -self.text_len_t5 :],
            padding[self.text_len :],
        )

        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024

        # _, _, oh, ow = x.shape
        # th, tw = oh // self.patch_size, ow // self.patch_size

        # Get image RoPE embedding according to `reso`lution.
        freqs_cis_img = calc_rope(
            x, self.patch_size, self.hidden_size // self.num_heads
        )  # (cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(timesteps, dtype=self.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable
        # if image_meta_size is not None:
        #     image_meta_size = timestep_embedding(image_meta_size.view(-1), 256)   # [B * 6, 256]
        #     if image_meta_size.dtype != self.dtype:
        #         image_meta_size = image_meta_size.half()
        #     image_meta_size = image_meta_size.view(-1, 6 * 256)
        #     extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]

        # Build style tokens
        if style is not None:
            style_embedding = self.style_embedder(style)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Concatenate all extra vectors
        c = t + self.extra_embedder(extra_vec)  # [B, D]

        # ========================= Deal with Condition =========================
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        controls = []
        x = x + self.before_proj(condition)  # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x))  # zero linear for output

        return {"output": controls}

comfy/ldm/hydit/models.py

@@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
     sub_args = [start, stop, (th, tw)]
     # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
     rope = get_2d_rotary_pos_embed(head_size, *sub_args)
+    rope = (rope[0].to(x), rope[1].to(x))
     return rope
 
 
@@ -91,6 +92,8 @@ class HunYuanDiTBlock(nn.Module):
         # Long Skip Connection
         if self.skip_linear is not None:
             cat = torch.cat([x, skip], dim=-1)
+            if cat.dtype != x.dtype:
+                cat = cat.to(x.dtype)
             cat = self.skip_norm(cat)
             x = self.skip_linear(cat)
 
@@ -362,6 +365,8 @@ class HunYuanDiT(nn.Module):
         c = t + self.extra_embedder(extra_vec) # [B, D]
 
         controls = None
+        if control:
+            controls = control.get("output", None)
         # ========================= Forward pass through HunYuanDiT blocks =========================
         skips = []
         for layer, block in enumerate(self.blocks):

comfy/lora.py

@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import comfy.utils
 import logging
 
@@ -218,11 +236,17 @@ def model_lora_keys_clip(model, key_map={}):
             lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
             key_map[lora_key] = k
 
-    for k in sdk: #OneTrainer SD3 lora
-        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
-            l_key = k[len("t5xxl.transformer."):-len(".weight")]
-            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
-            key_map[lora_key] = k
+    for k in sdk:
+        if k.endswith(".weight"):
+            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
+                l_key = k[len("t5xxl.transformer."):-len(".weight")]
+                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k
+            elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
+                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
+                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k
+
 
     k = "clip_g.transformer.text_projection.weight"
     if k in sdk:
@@ -245,6 +269,7 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
             key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
+            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
 
     diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
     for k in diffusers_keys:

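Aside (hypothetical key, for illustration only): the extended loop maps HunyuanDiT text-encoder weights to `lora_te1_*` keys in the same way the existing OneTrainer SD3 branch maps `t5xxl` weights to `lora_te3_*`:

```python
# Example of the key translation performed above, using a made-up BERT-style weight name.
k = "hydit_clip.transformer.bert.encoder.layer.0.attention.self.query.weight"
l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
assert lora_key == "lora_te1_encoder_layer_0_attention_self_query"
```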
comfy/model_base.py

@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 import logging
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -77,10 +95,13 @@ class BaseModel(torch.nn.Module):
         self.device = device
 
         if not unet_config.get("disable_unet_model_creation", False):
-            if self.manual_cast_dtype is not None:
-                operations = comfy.ops.manual_cast
+            if model_config.custom_operations is None:
+                if self.manual_cast_dtype is not None:
+                    operations = comfy.ops.manual_cast
+                else:
+                    operations = comfy.ops.disable_weight_init
             else:
-                operations = comfy.ops.disable_weight_init
+                operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if comfy.model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)

|
|||||||
@@ -137,8 +137,8 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["hidden_size"] = 3072
         dit_config["mlp_ratio"] = 4.0
         dit_config["num_heads"] = 24
-        dit_config["depth"] = 19
-        dit_config["depth_single_blocks"] = 38
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
         dit_config["axes_dim"] = [16, 56, 56]
         dit_config["theta"] = 10000
         dit_config["qkv_bias"] = True
@@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict):
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
     out_sd = {}

-    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+    if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        hidden_size = state_dict["x_embedder.bias"].shape[0]
+        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
+    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
@@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
             old_weight = out_sd.get(t[0], None)
             if old_weight is None:
                 old_weight = torch.empty_like(weight)
-                old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+            if old_weight.shape[offset[0]] < offset[1] + offset[2]:
+                exp = list(weight.shape)
+                exp[offset[0]] = offset[1] + offset[2]
+                new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
+                new[:old_weight.shape[0]] = old_weight
+                old_weight = new

             w = old_weight.narrow(offset[0], offset[1], offset[2])
         else:
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import psutil
 import logging
 from enum import Enum
@@ -273,9 +291,12 @@ class LoadedModel:
     def model_memory(self):
         return self.model.model_size()

+    def model_offloaded_memory(self):
+        return self.model.model_size() - self.model.loaded_size()
+
     def model_memory_required(self, device):
         if device == self.model.current_loaded_device():
-            return 0
+            return self.model_offloaded_memory()
         else:
             return self.model_memory()

@@ -287,15 +308,21 @@ class LoadedModel:

         load_weights = not self.weights_loaded

-        try:
-            if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
-            else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
-        except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
-            self.model_unload()
-            raise e
+        if self.model.loaded_size() > 0:
+            use_more_vram = lowvram_model_memory
+            if use_more_vram == 0:
+                use_more_vram = 1e32
+            self.model_use_more_vram(use_more_vram)
+        else:
+            try:
+                if lowvram_model_memory > 0 and load_weights:
+                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+                else:
+                    self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
+            except Exception as e:
+                self.model.unpatch_model(self.model.offload_device)
+                self.model_unload()
+                raise e

         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
@@ -304,19 +331,42 @@ class LoadedModel:
         return self.real_model

     def should_reload_model(self, force_patch_weights=False):
-        if force_patch_weights and self.model.lowvram_patch_counter > 0:
+        if force_patch_weights and self.model.lowvram_patch_counter() > 0:
             return True
         return False

-    def model_unload(self, unpatch_weights=True):
+    def model_unload(self, memory_to_free=None, unpatch_weights=True):
+        if memory_to_free is not None:
+            if memory_to_free < self.model.loaded_size():
+                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                if freed >= memory_to_free:
+                    return False
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
         self.real_model = None
+        return True
+
+    def model_use_more_vram(self, extra_memory):
+        return self.model.partially_load(self.device, extra_memory)

     def __eq__(self, other):
         return self.model is other.model

+def use_more_memory(extra_memory, loaded_models, device):
+    for m in loaded_models:
+        if m.device == device:
+            extra_memory -= m.model_use_more_vram(extra_memory)
+        if extra_memory <= 0:
+            break
+
+def offloaded_memory(loaded_models, device):
+    offloaded_mem = 0
+    for m in loaded_models:
+        if m.device == device:
+            offloaded_mem += m.model_offloaded_memory()
+    return offloaded_mem
+
 def minimum_inference_memory():
     return (1024 * 1024 * 1024) * 1.2

@@ -363,11 +413,15 @@ def free_memory(memory_required, device, keep_loaded=[]):

     for x in sorted(can_unload):
         i = x[-1]
+        memory_to_free = None
         if not DISABLE_SMART_MEMORY:
-            if get_free_memory(device) > memory_required:
+            free_mem = get_free_memory(device)
+            if free_mem > memory_required:
                 break
-        current_loaded_models[i].model_unload()
-        unloaded_model.append(i)
+            memory_to_free = memory_required - free_mem
+        logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
+        if current_loaded_models[i].model_unload(memory_to_free):
+            unloaded_model.append(i)

     for i in sorted(unloaded_model, reverse=True):
         unloaded_models.append(current_loaded_models.pop(i))
@@ -381,15 +435,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
         soft_empty_cache()
     return unloaded_models

-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     global vram_state

     inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required)
+    extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
     if minimum_memory_required is None:
         minimum_memory_required = extra_mem
     else:
-        minimum_memory_required = max(inference_memory, minimum_memory_required)
+        minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)

     models = set(models)

@@ -422,12 +476,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         devs = set(map(lambda a: a.device, models_already_loaded))
         for d in devs:
             if d != torch.device("cpu"):
-                free_memory(extra_mem, d, models_already_loaded)
+                free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
                 free_mem = get_free_memory(d)
                 if free_mem < minimum_memory_required:
                     logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
                     models_to_load = free_memory(minimum_memory_required, d)
                     logging.info("{} models unloaded.".format(len(models_to_load)))
+                else:
+                    use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
         if len(models_to_load) == 0:
             return

@@ -435,18 +491,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

     total_memory_required = {}
     for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
-            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+    for loaded_model in models_already_loaded:
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)

     for loaded_model in models_to_load:
         weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
         if weights_unloaded is not None:
             loaded_model.weights_loaded = not weights_unloaded

+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
@@ -455,7 +514,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         else:
             vram_set_state = vram_state
         lowvram_model_memory = 0
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
             lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
@@ -467,6 +526,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu

         cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
+
+
+    devs = set(map(lambda a: a.device, models_already_loaded))
+    for d in devs:
+        if d != torch.device("cpu"):
+            free_mem = get_free_memory(d)
+            if free_mem > minimum_memory_required:
+                use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
     return


@@ -562,12 +629,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         if model_params * 2 > free_model_memory:
             return fp8_dtype

-    if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
-        if torch.float16 in supported_dtypes:
-            return torch.float16
-    if should_use_bf16(device, model_params=model_params, manual_cast=True):
-        if torch.bfloat16 in supported_dtypes:
-            return torch.bfloat16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
+    for dt in supported_dtypes:
+        if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
+            if torch.float16 in supported_dtypes:
+                return torch.float16
+        if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
+            if torch.bfloat16 in supported_dtypes:
+                return torch.bfloat16
+
     return torch.float32

 # None means no manual cast
@@ -583,13 +660,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None

-    if fp16_supported and torch.float16 in supported_dtypes:
-        return torch.float16
+    for dt in supported_dtypes:
+        if dt == torch.float16 and fp16_supported:
+            return torch.float16
+        if dt == torch.bfloat16 and bf16_supported:
+            return torch.bfloat16

-    elif bf16_supported and torch.bfloat16 in supported_dtypes:
-        return torch.bfloat16
-    else:
-        return torch.float32
+    return torch.float32

 def text_encoder_offload_device():
     if args.gpu_only:
@@ -608,6 +685,20 @@ def text_encoder_device():
     else:
         return torch.device("cpu")

+def text_encoder_initial_device(load_device, offload_device, model_size=0):
+    if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
+        return offload_device
+
+    if is_device_mps(load_device):
+        return offload_device
+
+    mem_l = get_free_memory(load_device)
+    mem_o = get_free_memory(offload_device)
+    if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
+        return load_device
+    else:
+        return offload_device
+
 def text_encoder_dtype(device=None):
     if args.fp8_e4m3fn_text_enc:
         return torch.float8_e4m3fn
@@ -1,8 +1,27 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 import copy
 import inspect
 import logging
 import uuid
+import collections

 import comfy.utils
 import comfy.model_management
@@ -63,12 +82,27 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
         model_options["disable_cfg1_optimization"] = True
     return model_options

+def wipe_lowvram_weight(m):
+    if hasattr(m, "prev_comfy_cast_weights"):
+        m.comfy_cast_weights = m.prev_comfy_cast_weights
+        del m.prev_comfy_cast_weights
+    m.weight_function = None
+    m.bias_function = None
+
+class LowVramPatch:
+    def __init__(self, key, model_patcher):
+        self.key = key
+        self.model_patcher = model_patcher
+    def __call__(self, weight):
+        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+
+
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
         self.size = size
         self.model = model
         if not hasattr(self.model, 'device'):
-            logging.info("Model doesn't have a device attribute.")
+            logging.debug("Model doesn't have a device attribute.")
             self.model.device = offload_device
         elif self.model.device is None:
             self.model.device = offload_device
|
|||||||
self.load_device = load_device
|
self.load_device = load_device
|
||||||
self.offload_device = offload_device
|
self.offload_device = offload_device
|
||||||
self.weight_inplace_update = weight_inplace_update
|
self.weight_inplace_update = weight_inplace_update
|
||||||
self.model_lowvram = False
|
|
||||||
self.lowvram_patch_counter = 0
|
|
||||||
self.patches_uuid = uuid.uuid4()
|
self.patches_uuid = uuid.uuid4()
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'model_loaded_weight_memory'):
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'lowvram_patch_counter'):
|
||||||
|
self.model.lowvram_patch_counter = 0
|
||||||
|
|
||||||
|
if not hasattr(self.model, 'model_lowvram'):
|
||||||
|
self.model.model_lowvram = False
|
||||||
|
|
||||||
def model_size(self):
|
def model_size(self):
|
||||||
if self.size > 0:
|
if self.size > 0:
|
||||||
return self.size
|
return self.size
|
||||||
self.size = comfy.model_management.module_size(self.model)
|
self.size = comfy.model_management.module_size(self.model)
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
def loaded_size(self):
|
||||||
|
return self.model.model_loaded_weight_memory
|
||||||
|
|
||||||
|
def lowvram_patch_counter(self):
|
||||||
|
return self.model.lowvram_patch_counter
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||||
n.patches = {}
|
n.patches = {}
|
||||||
@@ -265,16 +312,16 @@ class ModelPatcher:
|
|||||||
sd.pop(k)
|
sd.pop(k)
|
||||||
return sd
|
return sd
|
||||||
|
|
||||||
def patch_weight_to_device(self, key, device_to=None):
|
def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
|
||||||
if key not in self.patches:
|
if key not in self.patches:
|
||||||
return
|
return
|
||||||
|
|
||||||
weight = comfy.utils.get_attr(self.model, key)
|
weight = comfy.utils.get_attr(self.model, key)
|
||||||
|
|
||||||
inplace_update = self.weight_inplace_update
|
inplace_update = self.weight_inplace_update or inplace_update
|
||||||
|
|
||||||
if key not in self.backup:
|
if key not in self.backup:
|
||||||
self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
|
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||||
@@ -304,28 +351,24 @@ class ModelPatcher:
|
|||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
|
self.model.model_loaded_weight_memory = self.model_size()
|
||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
self.patch_model(device_to, patch_weights=False)
|
|
||||||
|
|
||||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
|
|
||||||
class LowVramPatch:
|
|
||||||
def __init__(self, key, model_patcher):
|
|
||||||
self.key = key
|
|
||||||
self.model_patcher = model_patcher
|
|
||||||
def __call__(self, weight):
|
|
||||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
|
||||||
|
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
|
lowvram_counter = 0
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
|
||||||
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
|
lowvram_counter += 1
|
||||||
|
if m.comfy_cast_weights:
|
||||||
|
continue
|
||||||
|
|
||||||
weight_key = "{}.weight".format(n)
|
weight_key = "{}.weight".format(n)
|
||||||
bias_key = "{}.bias".format(n)
|
bias_key = "{}.bias".format(n)
|
||||||
@@ -347,16 +390,40 @@ class ModelPatcher:
|
|||||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
m.comfy_cast_weights = True
|
m.comfy_cast_weights = True
|
||||||
else:
|
else:
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
if m.comfy_cast_weights:
|
||||||
|
wipe_lowvram_weight(m)
|
||||||
|
|
||||||
if hasattr(m, "weight"):
|
if hasattr(m, "weight"):
|
||||||
self.patch_weight_to_device(weight_key, device_to)
|
|
||||||
self.patch_weight_to_device(bias_key, device_to)
|
|
||||||
m.to(device_to)
|
|
||||||
mem_counter += comfy.model_management.module_size(m)
|
mem_counter += comfy.model_management.module_size(m)
|
||||||
|
param = list(m.parameters())
|
||||||
|
if len(param) > 0:
|
||||||
|
weight = param[0]
|
||||||
|
if weight.device == device_to:
|
||||||
|
continue
|
||||||
|
|
||||||
|
weight_to = None
|
||||||
|
if full_load:#TODO
|
||||||
|
weight_to = device_to
|
||||||
|
self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
|
||||||
|
self.patch_weight_to_device(bias_key, device_to=weight_to)
|
||||||
|
m.to(device_to)
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
|
||||||
self.model_lowvram = True
|
if lowvram_counter > 0:
|
||||||
self.lowvram_patch_counter = patch_counter
|
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||||
|
self.model.model_lowvram = True
|
||||||
|
else:
|
||||||
|
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||||
|
self.model.model_lowvram = False
|
||||||
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
|
self.model.model_loaded_weight_memory = mem_counter
|
||||||
|
|
||||||
|
|
||||||
|
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||||
|
self.patch_model(device_to, patch_weights=False)
|
||||||
|
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def calculate_weight(self, patches, weight, key):
|
def calculate_weight(self, patches, weight, key):
|
||||||
@@ -529,31 +596,28 @@ class ModelPatcher:
|
|||||||
|
|
||||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||||
if unpatch_weights:
|
if unpatch_weights:
|
||||||
if self.model_lowvram:
|
if self.model.model_lowvram:
|
||||||
for m in self.model.modules():
|
for m in self.model.modules():
|
||||||
if hasattr(m, "prev_comfy_cast_weights"):
|
wipe_lowvram_weight(m)
|
||||||
m.comfy_cast_weights = m.prev_comfy_cast_weights
|
|
||||||
del m.prev_comfy_cast_weights
|
|
||||||
m.weight_function = None
|
|
||||||
m.bias_function = None
|
|
||||||
|
|
||||||
self.model_lowvram = False
|
self.model.model_lowvram = False
|
||||||
self.lowvram_patch_counter = 0
|
self.model.lowvram_patch_counter = 0
|
||||||
|
|
||||||
keys = list(self.backup.keys())
|
keys = list(self.backup.keys())
|
||||||
|
|
||||||
if self.weight_inplace_update:
|
for k in keys:
|
||||||
for k in keys:
|
bk = self.backup[k]
|
||||||
comfy.utils.copy_to_param(self.model, k, self.backup[k])
|
if bk.inplace_update:
|
||||||
else:
|
comfy.utils.copy_to_param(self.model, k, bk.weight)
|
||||||
for k in keys:
|
else:
|
||||||
comfy.utils.set_attr_param(self.model, k, self.backup[k])
|
comfy.utils.set_attr_param(self.model, k, bk.weight)
|
||||||
|
|
||||||
self.backup.clear()
|
self.backup.clear()
|
||||||
|
|
||||||
if device_to is not None:
|
if device_to is not None:
|
||||||
self.model.to(device_to)
|
self.model.to(device_to)
|
||||||
self.model.device = device_to
|
self.model.device = device_to
|
||||||
|
self.model.model_loaded_weight_memory = 0
|
||||||
|
|
||||||
keys = list(self.object_patches_backup.keys())
|
keys = list(self.object_patches_backup.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
@@ -561,5 +625,60 @@ class ModelPatcher:
|
|||||||
|
|
||||||
self.object_patches_backup.clear()
|
self.object_patches_backup.clear()
|
||||||
|
|
||||||
|
def partially_unload(self, device_to, memory_to_free=0):
|
||||||
|
memory_freed = 0
|
||||||
|
patch_counter = 0
|
||||||
|
|
||||||
|
for n, m in list(self.model.named_modules())[::-1]:
|
||||||
|
if memory_to_free < memory_freed:
|
||||||
|
break
|
||||||
|
|
||||||
|
shift_lowvram = False
|
||||||
|
if hasattr(m, "comfy_cast_weights"):
|
||||||
|
module_mem = comfy.model_management.module_size(m)
|
||||||
|
weight_key = "{}.weight".format(n)
|
||||||
|
bias_key = "{}.bias".format(n)
|
||||||
|
|
||||||
|
|
||||||
|
if m.weight is not None and m.weight.device != device_to:
|
||||||
|
for key in [weight_key, bias_key]:
|
||||||
|
bk = self.backup.get(key, None)
|
||||||
|
if bk is not None:
|
||||||
|
if bk.inplace_update:
|
||||||
|
comfy.utils.copy_to_param(self.model, key, bk.weight)
|
||||||
|
else:
|
||||||
|
comfy.utils.set_attr_param(self.model, key, bk.weight)
|
||||||
|
self.backup.pop(key)
|
||||||
|
|
||||||
|
m.to(device_to)
|
||||||
|
if weight_key in self.patches:
|
||||||
|
m.weight_function = LowVramPatch(weight_key, self)
|
||||||
|
patch_counter += 1
|
||||||
|
if bias_key in self.patches:
|
||||||
|
m.bias_function = LowVramPatch(bias_key, self)
|
||||||
|
patch_counter += 1
|
||||||
|
|
||||||
|
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||||
|
m.comfy_cast_weights = True
|
||||||
|
memory_freed += module_mem
|
||||||
|
logging.debug("freed {}".format(n))
|
||||||
|
|
||||||
|
self.model.model_lowvram = True
|
||||||
|
self.model.lowvram_patch_counter += patch_counter
|
||||||
|
self.model.model_loaded_weight_memory -= memory_freed
|
||||||
|
return memory_freed
|
||||||
|
|
||||||
|
def partially_load(self, device_to, extra_memory=0):
|
||||||
|
self.unpatch_model(unpatch_weights=False)
|
||||||
|
self.patch_model(patch_weights=False)
|
||||||
|
full_load = False
|
||||||
|
if self.model.model_lowvram == False:
|
||||||
|
return 0
|
||||||
|
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
|
full_load = True
|
||||||
|
current_used = self.model.model_loaded_weight_memory
|
||||||
|
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||||
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
def current_loaded_device(self):
|
def current_loaded_device(self):
|
||||||
return self.model.device
|
return self.model.device
|
||||||
|
|||||||
59
comfy/sd.py
@@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):


 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
         if no_init:
             return
         params = target.params.copy()
@@ -71,20 +71,24 @@ class CLIP:

         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
         dtype = model_management.text_encoder_dtype(load_device)
         params['dtype'] = dtype
+        params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
         self.cond_stage_model = clip(**(params))

         for dt in self.cond_stage_model.dtypes:
             if not model_management.supports_cast(load_device, dt):
                 load_device = offload_device
+                if params['device'] != offload_device:
+                    self.cond_stage_model.to(offload_device)
+                    logging.warning("Had to shift TE back.")

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        if params['device'] == load_device:
+            model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
-        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
+        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))

     def clone(self):
         n = CLIP(no_init=True)
@@ -456,7 +460,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

-    clip = CLIP(clip_target, embedding_directory=embedding_directory)
+    parameters = 0
+    for c in clip_data:
+        parameters += comfy.utils.calculate_parameters(c)
+
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@@ -498,15 +506,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl

     return (model, clip, vae)

-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
     sd = comfy.utils.load_torch_file(ckpt_path)
-    sd_keys = sd.keys()
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
+    if out is None:
+        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+    return out
+
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
     clip = None
     clipvision = None
     vae = None
     model = None
     model_patcher = None
-    clip_target = None

     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
@@ -515,13 +527,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o

     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+        return None

     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if weight_dtype is not None:
         unet_weight_dtype.append(weight_dtype)

-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
+    unet_dtype = model_options.get("weight_dtype", None)
+
+    if unet_dtype is None:
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)

@@ -545,7 +562,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if clip_target is not None:
         clip_sd = model_config.process_clip_state_dict(sd)
         if len(clip_sd) > 0:
-            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
+            parameters = comfy.utils.calculate_parameters(clip_sd)
+            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
             m, u = clip.load_sd(clip_sd, full_model=True)
             if len(m) > 0:
                 m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@@ -567,12 +585,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
         model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
         if inital_load_device != torch.device("cpu"):
             logging.info("loaded straight to GPU")
-            model_management.load_model_gpu(model_patcher)
+            model_management.load_models_gpu([model_patcher], force_full_load=True)

     return (model_patcher, clip, vae, clipvision)


-def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+    dtype = model_options.get("dtype", None)

     #Allow loading unets from checkpoint files
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -614,6 +633,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for

     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
@@ -622,14 +642,23 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
         logging.info("left over keys in unet: {}".format(left_over))
     return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)

-def load_unet(unet_path, dtype=None):
+
+def load_diffusion_model(unet_path, model_options={}):
     sd = comfy.utils.load_torch_file(unet_path)
-    model = load_unet_state_dict(sd, dtype=dtype)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options)
     if model is None:
         logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model

+def load_unet(unet_path, dtype=None):
+    print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
+    return load_diffusion_model(unet_path, model_options={"dtype": dtype})
+
+def load_unet_state_dict(sd, dtype=None):
+    print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
+    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
+
 def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
     clip_sd = None
     load_models = [model]
@@ -313,6 +313,17 @@ def expand_directory_list(directories):
             dirs.add(root)
     return list(dirs)

+def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
+    i = 0
+    out_list = []
+    for k in embed:
+        if k.startswith(prefix) and k.endswith(suffix):
+            out_list.append(embed[k])
+    if len(out_list) == 0:
+        return None
+
+    return torch.cat(out_list, dim=0)
+
 def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
     if isinstance(embedding_directory, str):
         embedding_directory = [embedding_directory]
@@ -379,8 +390,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
         elif embed_key is not None and embed_key in embed:
             embed_out = embed[embed_key]
         else:
-            values = embed.values()
-            embed_out = next(iter(values))
+            embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
+            if embed_out is None:
+                embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
+            if embed_out is None:
+                values = embed.values()
+                embed_out = next(iter(values))
     return embed_out

 class SDTokenizer:
@@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE):

     memory_usage_factor = 2.8

-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 from . import model_base
 from . import utils
@@ -30,6 +48,7 @@ class BASE:
     memory_usage_factor = 2.0

     manual_cast_dtype = None
+    custom_operations = None

     @classmethod
     def matches(s, unet_config, state_dict=None):
@@ -1,3 +1,22 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
+
 import torch
 import math
 import struct
@@ -432,8 +451,33 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
             key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

-        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
-                     "attn.to_out.0.bias": "img_attn.proj.bias",
+            k = "{}.attn.".format(prefix_from)
+            qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
+            key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
+            key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
+            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
+
+        block_map = {
+                     "attn.to_out.0.weight": "img_attn.proj.weight",
+                     "attn.to_out.0.bias": "img_attn.proj.bias",
+                     "norm1.linear.weight": "img_mod.lin.weight",
+                     "norm1.linear.bias": "img_mod.lin.bias",
+                     "norm1_context.linear.weight": "txt_mod.lin.weight",
+                     "norm1_context.linear.bias": "txt_mod.lin.bias",
+                     "attn.to_add_out.weight": "txt_attn.proj.weight",
+                     "attn.to_add_out.bias": "txt_attn.proj.bias",
+                     "ff.net.0.proj.weight": "img_mlp.0.weight",
+                     "ff.net.0.proj.bias": "img_mlp.0.bias",
+                     "ff.net.2.weight": "img_mlp.2.weight",
+                     "ff.net.2.bias": "img_mlp.2.bias",
+                     "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
+                     "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
+                     "ff_context.net.2.weight": "txt_mlp.2.weight",
+                     "ff_context.net.2.bias": "txt_mlp.2.bias",
+                     "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
+                     "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
+                     "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
+                     "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
                      }

         for k in block_map:
@@ -449,15 +493,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
-key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
+key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

-block_map = {#TODO
+block_map = {
    "norm.linear.weight": "modulation.lin.weight",
    "norm.linear.bias": "modulation.lin.bias",
    "proj_out.weight": "linear2.weight",
    "proj_out.bias": "linear2.bias",
    "attn.norm_q.weight": "norm.query_norm.scale",
    "attn.norm_k.weight": "norm.key_norm.scale",
}

for k in block_map:
    key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

-MAP_BASIC = { #TODO
+MAP_BASIC = {
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
    ("img_in.bias", "x_embedder.bias"),
    ("img_in.weight", "x_embedder.weight"),
    ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
    ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
    ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
    ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
    ("txt_in.bias", "context_embedder.bias"),
    ("txt_in.weight", "context_embedder.weight"),
    ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
    ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
    ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
    ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
    ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
    ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
    ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
    ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
    ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
    ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
}

for k in MAP_BASIC:
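The hunks above extend the flux_to_diffusers key map: plain block_map entries are simple renames from Diffusers-style keys to ComfyUI's internal names, while the qkv-style entries map several source tensors onto (offset, size) slices of one fused tensor. As a rough illustration only (not part of the diff, and using a hypothetical key name), a plain-rename map of this shape could be applied to a state dict like this:

import torch

def rename_keys(state_dict, key_map):
    # Rename keys the map knows about; keep everything else unchanged.
    return {key_map.get(k, k): v for k, v in state_dict.items()}

key_map = {"transformer_blocks.0.norm1.linear.weight": "double_blocks.0.img_mod.lin.weight"}
sd = {"transformer_blocks.0.norm1.linear.weight": torch.zeros(3072)}
print(rename_keys(sd, key_map))  # key renamed, tensor untouched

The tuple-valued qkv entries would need extra handling (copying into the right slice of the destination tensor), which this sketch deliberately skips.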
@@ -19,6 +19,7 @@ class CLIPTextEncodeHunyuanDiT:
cond = output.pop("cond")
return ([[cond, output]], )


NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}
@@ -100,3 +100,8 @@ NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
    "ControlNetApplySD3": ControlNetApplySD3,
}

+NODE_DISPLAY_NAME_MAPPINGS = {
+    # Sampling
+    "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
+}
@@ -1,13 +1,13 @@
+from __future__ import annotations

import os
import time
import logging
-from typing import Set, List, Dict, Tuple
+from collections.abc import Collection

-supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
+supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

-SupportedFileExtensionsType = Set[str]
-ScanPathType = List[str]
-folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {}
+folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}

base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
@@ -42,7 +42,7 @@ temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")

-filename_list_cache = {}
+filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

if not os.path.exists(input_directory):
    try:
@@ -50,33 +50,33 @@ if not os.path.exists(input_directory):
    except:
        logging.error("Failed to create input directory")

-def set_output_directory(output_dir):
+def set_output_directory(output_dir: str) -> None:
    global output_directory
    output_directory = output_dir

-def set_temp_directory(temp_dir):
+def set_temp_directory(temp_dir: str) -> None:
    global temp_directory
    temp_directory = temp_dir

-def set_input_directory(input_dir):
+def set_input_directory(input_dir: str) -> None:
    global input_directory
    input_directory = input_dir

-def get_output_directory():
+def get_output_directory() -> str:
    global output_directory
    return output_directory

-def get_temp_directory():
+def get_temp_directory() -> str:
    global temp_directory
    return temp_directory

-def get_input_directory():
+def get_input_directory() -> str:
    global input_directory
    return input_directory


#NOTE: used in http server so don't put folders that should not be accessed remotely
-def get_directory_by_type(type_name):
+def get_directory_by_type(type_name: str) -> str | None:
    if type_name == "output":
        return get_output_directory()
    if type_name == "temp":
@@ -88,7 +88,7 @@ def get_directory_by_type(type_name):
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
-def annotated_filepath(name):
+def annotated_filepath(name: str) -> tuple[str, str | None]:
    if name.endswith("[output]"):
        base_dir = get_output_directory()
        name = name[:-9]
@@ -104,7 +104,7 @@ def annotated_filepath(name):
    return name, base_dir


-def get_annotated_filepath(name, default_dir=None):
+def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
    name, base_dir = annotated_filepath(name)

    if base_dir is None:
@@ -116,7 +116,7 @@ def get_annotated_filepath(name, default_dir=None):
    return os.path.join(base_dir, name)


-def exists_annotated_filepath(name):
+def exists_annotated_filepath(name) -> bool:
    name, base_dir = annotated_filepath(name)

    if base_dir is None:
@@ -126,17 +126,17 @@ def exists_annotated_filepath(name):
    return os.path.exists(filepath)


-def add_model_folder_path(folder_name, full_folder_path):
+def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
    global folder_names_and_paths
    if folder_name in folder_names_and_paths:
        folder_names_and_paths[folder_name][0].append(full_folder_path)
    else:
        folder_names_and_paths[folder_name] = ([full_folder_path], set())

-def get_folder_paths(folder_name):
+def get_folder_paths(folder_name: str) -> list[str]:
    return folder_names_and_paths[folder_name][0][:]

-def recursive_search(directory, excluded_dir_names=None):
+def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
    if not os.path.isdir(directory):
        return [], {}
@@ -153,6 +153,10 @@ def recursive_search(directory, excluded_dir_names=None):
        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

    logging.debug("recursive file list on directory {}".format(directory))
+    dirpath: str
+    subdirs: list[str]
+    filenames: list[str]

    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
        subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
        for file_name in filenames:
@@ -160,7 +164,7 @@ def recursive_search(directory, excluded_dir_names=None):
            result.append(relative_path)

        for d in subdirs:
-            path = os.path.join(dirpath, d)
+            path: str = os.path.join(dirpath, d)
            try:
                dirs[path] = os.path.getmtime(path)
            except FileNotFoundError:
@@ -169,12 +173,12 @@ def recursive_search(directory, excluded_dir_names=None):
    logging.debug("found {} files".format(len(result)))
    return result, dirs

-def filter_files_extensions(files, extensions):
+def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))



-def get_full_path(folder_name, filename):
+def get_full_path(folder_name: str, filename: str) -> str | None:
    global folder_names_and_paths
    if folder_name not in folder_names_and_paths:
        return None
@@ -189,7 +193,7 @@ def get_full_path(folder_name, filename):
    return None

-def get_filename_list_(folder_name):
+def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
    global folder_names_and_paths
    output_list = set()
    folders = folder_names_and_paths[folder_name]
@@ -199,9 +203,9 @@ def get_filename_list_(folder_name):
        output_list.update(filter_files_extensions(files, folders[1]))
        output_folders = {**output_folders, **folders_all}

-    return (sorted(list(output_list)), output_folders, time.perf_counter())
+    return sorted(list(output_list)), output_folders, time.perf_counter()

-def cached_filename_list_(folder_name):
+def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
    global filename_list_cache
    global folder_names_and_paths
    if folder_name not in filename_list_cache:
@@ -222,7 +226,7 @@ def cached_filename_list_(folder_name):
    return out

-def get_filename_list(folder_name):
+def get_filename_list(folder_name: str) -> list[str]:
    out = cached_filename_list_(folder_name)
    if out is None:
        out = get_filename_list_(folder_name)
@@ -230,17 +234,17 @@ def get_filename_list(folder_name):
        filename_list_cache[folder_name] = out
    return list(out[0])

-def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
+def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
-    def map_filename(filename):
+    def map_filename(filename: str) -> tuple[int, str]:
        prefix_len = len(os.path.basename(filename_prefix))
        prefix = filename[:prefix_len + 1]
        try:
            digits = int(filename[prefix_len + 1:].split('_')[0])
        except:
            digits = 0
-        return (digits, prefix)
+        return digits, prefix

-    def compute_vars(input, image_width, image_height):
+    def compute_vars(input: str, image_width: int, image_height: int) -> str:
        input = input.replace("%width%", str(image_width))
        input = input.replace("%height%", str(image_height))
        return input
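As the folder_paths.py hunks above note, names of the form 'filename.ext [annotation]' are resolved against the directory named by the annotation, otherwise against a default directory. A minimal usage sketch (assuming a standard ComfyUI checkout so that folder_paths imports cleanly; the file name is hypothetical):

import folder_paths

# "[output]" routes the lookup to the output directory and is stripped from the name.
path = folder_paths.get_annotated_filepath("example.png [output]")
print(path)  # <output directory>/example.png

# Without an annotation, the optional default_dir argument is used as the base directory.
path = folder_paths.get_annotated_filepath("example.png", default_dir=folder_paths.get_input_directory())
print(path)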
8
nodes.py
@@ -826,14 +826,14 @@ class UNETLoader:
    CATEGORY = "advanced/loaders"

    def load_unet(self, unet_name, weight_dtype):
-        dtype = None
+        model_options = {}
        if weight_dtype == "fp8_e4m3fn":
-            dtype = torch.float8_e4m3fn
+            model_options["dtype"] = torch.float8_e4m3fn
        elif weight_dtype == "fp8_e5m2":
-            dtype = torch.float8_e5m2
+            model_options["dtype"] = torch.float8_e5m2

        unet_path = folder_paths.get_full_path("unet", unet_name)
-        model = comfy.sd.load_unet(unet_path, dtype=dtype)
+        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
        return (model,)

class CLIPLoader:
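The UNETLoader change above routes the weight dtype through a model_options dict and switches from comfy.sd.load_unet to comfy.sd.load_diffusion_model. A sketch of how a caller might be updated accordingly, based only on the call shown in this hunk (the checkpoint file name is hypothetical):

import torch
import comfy.sd
import folder_paths

unet_path = folder_paths.get_full_path("unet", "example_model.safetensors")  # hypothetical file

# before: model = comfy.sd.load_unet(unet_path, dtype=torch.float8_e4m3fn)
model_options = {"dtype": torch.float8_e4m3fn}
model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)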