Compare commits

...

47 Commits

Author SHA1 Message Date
comfyanonymous
b8ffb2937f Memory tweaks. 2024-08-12 15:07:11 -04:00
Vladimir Semyonov
ce37c11164 add DS_Store to gitignore (#4324) 2024-08-12 12:32:34 -04:00
Alex "mcmonkey" Goodwin
b5c3906b38 Automatically link the Comfy CI page on PRs (#4326)
also use_prior_commit so it doesn't get a janked merge commit instead of the real one
2024-08-12 12:32:16 -04:00
comfyanonymous
5d43e75e5b Fix some issues with the model sometimes not getting patched. 2024-08-12 12:27:54 -04:00
comfyanonymous
517f4a94e4 Fix some lora loading slowdowns. 2024-08-12 11:50:32 -04:00
comfyanonymous
52a471c5c7 Change name of log. 2024-08-12 10:35:06 -04:00
comfyanonymous
ad76574cb8 Fix some potential issues with the previous commits. 2024-08-12 00:23:29 -04:00
comfyanonymous
9acfe4df41 Support loading directly to vram with CLIPLoader node. 2024-08-12 00:06:01 -04:00
comfyanonymous
9829b013ea Fix mistake in last commit. 2024-08-12 00:00:17 -04:00
comfyanonymous
5c69cde037 Load TE model straight to vram if certain conditions are met. 2024-08-11 23:52:43 -04:00
comfyanonymous
e9589d6d92 Add a way to set model dtype and ops from load_checkpoint_guess_config. 2024-08-11 08:50:34 -04:00
comfyanonymous
0d82a798a5 Remove the ckpt_path from load_state_dict_guess_config. 2024-08-11 08:37:35 -04:00
ljleb
925fff26fd alternative to load_checkpoint_guess_config that accepts a loaded state dict (#4249)
* make alternative fn

* add back ckpt path as 2nd argument?
2024-08-11 08:36:52 -04:00
comfyanonymous
75b9b55b22 Fix issues with #4302 and support loading diffusers format flux. 2024-08-10 21:28:24 -04:00
Jaret Burkett
1765f1c60c FLUX: Added full diffusers mapping for FLUX.1 schnell and dev. Adds full LoRA support from diffusers LoRAs. (#4302) 2024-08-10 21:26:41 -04:00
comfyanonymous
1de69fe4d5 Fix some issues with inference slowing down. 2024-08-10 16:21:25 -04:00
comfyanonymous
ae197f651b Speed up hunyuan dit inference a bit. 2024-08-10 07:36:27 -04:00
comfyanonymous
1b5b8ca81a Fix regression. 2024-08-09 21:45:21 -04:00
comfyanonymous
6678d5cf65 Fix regression. 2024-08-09 14:02:38 -04:00
TTPlanetPig
e172564eea Update controlnet.py to fix the default controlnet weight as constant (#4285) 2024-08-09 13:40:05 -04:00
comfyanonymous
a3cc326748 Better fix for lowvram issue. 2024-08-09 12:16:25 -04:00
comfyanonymous
86a97e91fc Fix controlnet regression. 2024-08-09 12:08:58 -04:00
comfyanonymous
5acdadc9f3 Fix issue with some lowvram weights. 2024-08-09 03:58:28 -04:00
comfyanonymous
55ad9d5f8c Fix regression. 2024-08-09 03:36:40 -04:00
comfyanonymous
a9f04edc58 Implement text encoder part of HunyuanDiT loras. 2024-08-09 03:21:10 -04:00
comfyanonymous
a475ec2300 Cleanup HunyuanDit controlnets.
Use the: ControlNetApply SD3 and HunyuanDiT node.
2024-08-09 02:59:34 -04:00
来新璐
06eb9fb426 feat: add support for HunYuanDit ControlNet (#4245)
* add support for HunYuanDit ControlNet

* fix hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix typo in hunyuandit controlnet

* fix code format style

* add control_weight support for HunyuanDit Controlnet

* use control_weights in HunyuanDit Controlnet

* fix typo
2024-08-09 02:59:24 -04:00
comfyanonymous
413322645e Raw torch is faster than einops? 2024-08-08 22:09:29 -04:00
comfyanonymous
11200de970 Cleaner code. 2024-08-08 20:07:09 -04:00
comfyanonymous
037c38eb0f Try to improve inference speed on some machines. 2024-08-08 17:29:27 -04:00
comfyanonymous
1e11d2d1f5 Better prints. 2024-08-08 17:29:27 -04:00
Alex "mcmonkey" Goodwin
65ea6be38f PullRequest CI Run: use pull_request_target to allow the CI Dashboard to work (#4277)
'_target' allows secrets to pass through, and we're just using the secret that allows uploading to the dashboard and are manually vetting PRs before running this workflow anyway
2024-08-08 17:20:48 -04:00
Alex "mcmonkey" Goodwin
5df6f57b5d minor fix on copypasta action name (#4276)
my bad sorry
2024-08-08 16:30:59 -04:00
Alex "mcmonkey" Goodwin
6588bfdef9 add GitHub workflow for CI tests of PRs (#4275)
When the 'Run-CI-Test' label is added to a PR, it will be tested by the CI, on a small matrix of stable versions.
2024-08-08 16:24:49 -04:00
Alex "mcmonkey" Goodwin
50ed2879ef Add full CI test matrix GitHub Workflow (#4274)
automatically runs a matrix of full GPU-enabled tests on all new commits to the ComfyUI master branch
2024-08-08 15:40:07 -04:00
comfyanonymous
66d4233210 Fix. 2024-08-08 15:16:51 -04:00
comfyanonymous
591010b7ef Support diffusers text attention flux loras. 2024-08-08 14:45:52 -04:00
comfyanonymous
08f92d55e9 Partial model shift support. 2024-08-08 14:45:06 -04:00
comfyanonymous
8115d8cce9 Add Flux fp16 support hack. 2024-08-07 15:08:39 -04:00
comfyanonymous
6969fc9ba4 Make supported_dtypes a priority list. 2024-08-07 15:00:06 -04:00
comfyanonymous
cb7c4b4be3 Workaround for lora OOM on lowvram mode. 2024-08-07 14:30:54 -04:00
comfyanonymous
1208863eca Fix "Comfy" lora keys.
They are in this format now:
diffusion_model.full.model.key.name.lora_up.weight
2024-08-07 13:49:31 -04:00
comfyanonymous
e1c528196e Fix bundled embed. 2024-08-07 13:30:45 -04:00
comfyanonymous
17030fd4c0 Support for "Comfy" lora format.
The keys are just: model.full.model.key.name.lora_up.weight

It is supported by all comfyui supported models.

Now people can just convert loras to this format instead of having to ask
for me to implement them.
2024-08-07 13:18:32 -04:00
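
Note: the two "Comfy" lora format entries above describe a key scheme where the lora tensors reuse the model's own weight names. A minimal sketch of the mapping (the helper and the example key below are illustrative, not ComfyUI's loader code):

def comfy_lora_keys(model_key):
    # per commit 1208863eca, a model weight "diffusion_model.<path>.weight" is patched
    # by lora tensors named "diffusion_model.<path>.lora_up.weight" / ".lora_down.weight"
    base = model_key[:-len(".weight")]
    return base + ".lora_up.weight", base + ".lora_down.weight"

print(comfy_lora_keys("diffusion_model.input_blocks.0.0.weight"))
# ('diffusion_model.input_blocks.0.0.lora_up.weight', 'diffusion_model.input_blocks.0.0.lora_down.weight')
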
comfyanonymous
c19dcd362f Controlnet code refactor. 2024-08-07 12:59:28 -04:00
comfyanonymous
1c08bf35b4 Support format for embeddings bundled in loras. 2024-08-07 03:45:25 -04:00
PhilWun
2a02546e20 Add type hints to folder_paths.py (#4191)
* add type hints to folder_paths.py

* replace deprecated standard collections type hints

* fix type error when using Python 3.8
2024-08-06 21:59:34 -04:00
21 changed files with 1103 additions and 178 deletions

@@ -0,0 +1,53 @@
# This is the GitHub Workflow that drives full-GPU-enabled tests of pull requests to ComfyUI, when the 'Run-CI-Test' label is added
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Pull Request CI Workflow Runs
on:
  pull_request_target:
    types: [labeled]
jobs:
  pr-test-stable:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux, windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
          use_prior_commit: 'true'
  comment:
    if: ${{ github.event.label.name == 'Run-CI-Test' }}
    runs-on: ubuntu-latest
    permissions:
      pull-requests: write
    steps:
      - uses: actions/github-script@v6
        with:
          script: |
            github.rest.issues.createComment({
              issue_number: context.issue.number,
              owner: context.repo.owner,
              repo: context.repo.repo,
              body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
            })

.github/workflows/test-ci.yml

@@ -0,0 +1,95 @@
# This is the GitHub Workflow that drives automatic full-GPU-enabled tests of all new commits to the master branch of ComfyUI
# Results are reported as checkmarks on the commits, as well as onto https://ci.comfy.org/
name: Full Comfy CI Workflow Runs
on:
  push:
    branches:
      - master
    paths-ignore:
      - 'app/**'
      - 'input/**'
      - 'output/**'
      - 'notebooks/**'
      - 'script_examples/**'
      - '.github/**'
      - 'web/**'
  workflow_dispatch:
jobs:
  test-stable:
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux, windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["stable"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
  test-win-nightly:
    strategy:
      fail-fast: true
      matrix:
        os: [windows]
        python_version: ["3.9", "3.10", "3.11", "3.12"]
        cuda_version: ["12.1"]
        torch_version: ["nightly"]
        include:
          - os: windows
            runner_label: [self-hosted, win]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}
  test-unix-nightly:
    strategy:
      fail-fast: false
      matrix:
        os: [macos, linux]
        python_version: ["3.11"]
        cuda_version: ["12.1"]
        torch_version: ["nightly"]
        include:
          - os: macos
            runner_label: [self-hosted, macOS]
            flags: "--use-pytorch-cross-attention"
          - os: linux
            runner_label: [self-hosted, Linux]
            flags: ""
    runs-on: ${{ matrix.runner_label }}
    steps:
      - name: Test Workflows
        uses: comfy-org/comfy-action@main
        with:
          os: ${{ matrix.os }}
          python_version: ${{ matrix.python_version }}
          torch_version: ${{ matrix.torch_version }}
          google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
          comfyui_flags: ${{ matrix.flags }}

.gitignore

@@ -19,3 +19,4 @@ venv/
 /user/
 *.log
 web_custom_versions/
+.DS_Store

@@ -1,4 +1,24 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
+from enum import Enum
 import math
 import os
 import logging
@@ -13,7 +33,7 @@ import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
+import comfy.ldm.hydit.controlnet

 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -33,6 +53,10 @@ def broadcast_image_to(tensor, target_batch_size, batched_number):
     else:
         return torch.cat([tensor] * batched_number, dim=0)

+class StrengthType(Enum):
+    CONSTANT = 1
+    LINEAR_UP = 2
+
 class ControlBase:
     def __init__(self, device=None):
         self.cond_hint_original = None
@@ -51,6 +75,8 @@ class ControlBase:
             device = comfy.model_management.get_torch_device()
         self.device = device
         self.previous_controlnet = None
+        self.extra_conds = []
+        self.strength_type = StrengthType.CONSTANT

     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
         self.cond_hint_original = cond_hint
@@ -93,6 +119,8 @@ class ControlBase:
         c.latent_format = self.latent_format
         c.extra_args = self.extra_args.copy()
         c.vae = self.vae
+        c.extra_conds = self.extra_conds.copy()
+        c.strength_type = self.strength_type

     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -113,7 +141,10 @@ class ControlBase:
                 if x not in applied_to: #memory saving strategy, allow shared tensors and only apply strength to shared tensors once
                     applied_to.add(x)
-                    x *= self.strength
+                    if self.strength_type == StrengthType.CONSTANT:
+                        x *= self.strength
+                    elif self.strength_type == StrengthType.LINEAR_UP:
+                        x *= (self.strength ** float(len(control_output) - i))

                 if x.dtype != output_dtype:
                     x = x.to(output_dtype)
@@ -142,7 +173,7 @@ class ControlBase:

 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
@@ -154,6 +185,8 @@ class ControlNet(ControlBase):
         self.model_sampling_current = None
         self.manual_cast_dtype = manual_cast_dtype
         self.latent_format = latent_format
+        self.extra_conds += extra_conds
+        self.strength_type = strength_type

     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -191,13 +224,16 @@ class ControlNet(ControlBase):
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

         context = cond.get('crossattn_controlnet', cond['c_crossattn'])
-        y = cond.get('y', None)
-        if y is not None:
-            y = y.to(dtype)
+        extra = self.extra_args.copy()
+        for c in self.extra_conds:
+            temp = cond.get(c, None)
+            if temp is not None:
+                extra[c] = temp.to(dtype)
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
         return self.control_merge(control, control_prev, output_dtype)

     def copy(self):
@@ -286,6 +322,7 @@ class ControlLora(ControlNet):
         ControlBase.__init__(self, device)
         self.control_weights = control_weights
         self.global_average_pooling = global_average_pooling
+        self.extra_conds += ["y"]

     def pre_run(self, model, percent_to_timestep_function):
         super().pre_run(model, percent_to_timestep_function)
@@ -338,12 +375,8 @@ class ControlLora(ControlNet):
     def inference_memory_requirements(self, dtype):
         return comfy.utils.calculate_parameters(self.control_weights) * comfy.model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

-def load_controlnet_mmdit(sd):
-    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
-    model_config = comfy.model_detection.model_config_from_unet(new_sd, "", True)
-    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
-    for k in sd:
-        new_sd[k] = sd[k]
+def controlnet_config(sd):
+    model_config = comfy.model_detection.model_config_from_unet(sd, "", True)

     supported_inference_dtypes = model_config.supported_inference_dtypes
@@ -356,23 +389,49 @@ def load_controlnet_mmdit(sd):
     else:
         operations = comfy.ops.disable_weight_init

-    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **controlnet_config)
-    missing, unexpected = control_model.load_state_dict(new_sd, strict=False)
+    return model_config, operations, load_device, unet_dtype, manual_cast_dtype
+
+def controlnet_load_state_dict(control_model, sd):
+    missing, unexpected = control_model.load_state_dict(sd, strict=False)

     if len(missing) > 0:
         logging.warning("missing controlnet keys: {}".format(missing))

     if len(unexpected) > 0:
         logging.debug("unexpected controlnet keys: {}".format(unexpected))
+    return control_model
+
+def load_controlnet_mmdit(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)

     latent_format = comfy.latent_formats.SD3()
     latent_format.shift_factor = 0 #SD3 controlnet weirdness
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control

+def load_controlnet_hunyuandit(controlnet_data):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
+
+    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    control_model = controlnet_load_state_dict(control_model, controlnet_data)
+
+    latent_format = comfy.latent_formats.SDXL()
+    extra_conds = ['text_embedding_mask', 'encoder_hidden_states_t5', 'text_embedding_mask_t5', 'image_meta_size', 'style', 'cos_cis_img', 'sin_cis_img']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
+    return control
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
+    if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
+        return load_controlnet_hunyuandit(controlnet_data)
+
     if "lora_controlnet" in controlnet_data:
         return ControlLora(controlnet_data)
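
Note: a standalone sketch of the two strength modes introduced in the control_merge hunk above. CONSTANT multiplies every control output by the same strength; LINEAR_UP ramps the exponent down toward the deepest output (the list and values below are made up for illustration):

import torch
from enum import Enum

class StrengthType(Enum):
    CONSTANT = 1
    LINEAR_UP = 2

def apply_strength(control_output, strength, strength_type):
    out = []
    for i, x in enumerate(control_output):
        if strength_type == StrengthType.CONSTANT:
            x = x * strength
        elif strength_type == StrengthType.LINEAR_UP:
            # earlier (shallower) outputs get a larger exponent, so they are damped harder
            x = x * (strength ** float(len(control_output) - i))
        out.append(x)
    return out

outputs = [torch.ones(1), torch.ones(1), torch.ones(1)]
print([o.item() for o in apply_strength(outputs, 0.5, StrengthType.LINEAR_UP)])
# [0.125, 0.25, 0.5]
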

@@ -2,12 +2,12 @@ import math
 from dataclasses import dataclass

 import torch
-from einops import rearrange
 from torch import Tensor, nn

 from .math import attention, rope
 import comfy.ops

 class EmbedND(nn.Module):
     def __init__(self, dim: int, theta: int, axes_dim: list):
         super().__init__()
@@ -36,9 +36,7 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     """
     t = time_factor * t
     half = dim // 2
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
-        t.device
-    )
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)

     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
@@ -48,7 +46,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
         embedding = embedding.to(t)
     return embedding

 class MLPEmbedder(nn.Module):
     def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
         super().__init__()
@@ -94,14 +91,6 @@ class SelfAttention(nn.Module):
         self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
         self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

-    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
-        qkv = self.qkv(x)
-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
-        q, k = self.norm(q, k, v)
-        x = attention(q, k, v, pe=pe)
-        x = self.proj(x)
-        return x

 @dataclass
 class ModulationOut:
@@ -163,31 +152,34 @@ class DoubleStreamBlock(nn.Module):
         img_modulated = self.img_norm1(img)
         img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
         img_qkv = self.img_attn.qkv(img_modulated)
-        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
         txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
         txt_qkv = self.txt_attn.qkv(txt_modulated)
-        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

         # run actual attention
-        q = torch.cat((txt_q, img_q), dim=2)
-        k = torch.cat((txt_k, img_k), dim=2)
-        v = torch.cat((txt_v, img_v), dim=2)
-
-        attn = attention(q, k, v, pe=pe)
+        attn = attention(torch.cat((txt_q, img_q), dim=2),
+                         torch.cat((txt_k, img_k), dim=2),
+                         torch.cat((txt_v, img_v), dim=2), pe=pe)
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img += img_mod1.gate * self.img_attn.proj(img_attn)
+        img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+        if txt.dtype == torch.float16:
+            txt = txt.clip(-65504, 65504)

         return img, txt
@@ -232,14 +224,17 @@ class SingleStreamBlock(nn.Module):
         x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
         qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

-        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
+        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k = self.norm(q, k, v)

         # compute attention
         attn = attention(q, k, v, pe=pe)
         # compute activation in mlp stream, cat again and run second linear layer
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
-        return x + mod.gate * output
+        x += mod.gate * output
+        if x.dtype == torch.float16:
+            x = x.clip(-65504, 65504)
+        return x

 class LastLayer(nn.Module):
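
Note: the rearrange-to-raw-torch change above ("Raw torch is faster than einops?") is a pure layout refactor; a quick equivalence check, assuming einops is still installed for the comparison:

import torch
from einops import rearrange

B, L, H, D = 2, 7, 4, 8
qkv = torch.randn(B, L, 3 * H * D)

a = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=H)
b = qkv.view(qkv.shape[0], qkv.shape[1], 3, H, -1).permute(2, 0, 3, 1, 4)
print(torch.equal(a, b))  # True: identical values, one fewer dependency in the hot path
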

@@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x

 def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
     return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
@@ -78,10 +78,9 @@ def apply_rotary_emb(
     xk_out = None
     if isinstance(freqs_cis, tuple):
         cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+        xq_out = (xq * cos + rotate_half(xq) * sin)
         if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+            xk_out = (xk * cos + rotate_half(xk) * sin)
     else:
         xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
         freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]

@@ -0,0 +1,321 @@
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import checkpoint

from comfy.ldm.modules.diffusionmodules.mmdit import (
    Mlp,
    TimestepEmbedder,
    PatchEmbed,
    RMSNorm,
)
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from .poolers import AttentionPool

import comfy.latent_formats
from .models import HunYuanDiTBlock, calc_rope

from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop


class HunYuanControlNet(nn.Module):
    """
    HunYuanDiT: Diffusion model with a Transformer backbone.

    Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.

    Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    input_size: tuple
        The size of the input image.
    patch_size: int
        The size of the patch.
    in_channels: int
        The number of input channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    depth: int
        The number of transformer blocks.
    num_heads: int
        The number of attention heads.
    mlp_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    log_fn: callable
        The logging function.
    """

    def __init__(
        self,
        input_size: tuple = 128,
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1408,
        depth: int = 40,
        num_heads: int = 16,
        mlp_ratio: float = 4.3637,
        text_states_dim=1024,
        text_states_dim_t5=2048,
        text_len=77,
        text_len_t5=256,
        qk_norm=True,  # See http://arxiv.org/abs/2302.05442 for details.
        size_cond=False,
        use_style_cond=False,
        learn_sigma=True,
        norm="layer",
        log_fn: callable = print,
        attn_precision=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        super().__init__()
        self.log_fn = log_fn
        self.depth = depth
        self.learn_sigma = learn_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.text_states_dim = text_states_dim
        self.text_states_dim_t5 = text_states_dim_t5
        self.text_len = text_len
        self.text_len_t5 = text_len_t5
        self.size_cond = size_cond
        self.use_style_cond = use_style_cond
        self.norm = norm
        self.dtype = dtype
        self.latent_format = comfy.latent_formats.SDXL

        self.mlp_t5 = nn.Sequential(
            nn.Linear(
                self.text_states_dim_t5,
                self.text_states_dim_t5 * 4,
                bias=True,
                dtype=dtype,
                device=device,
            ),
            nn.SiLU(),
            nn.Linear(
                self.text_states_dim_t5 * 4,
                self.text_states_dim,
                bias=True,
                dtype=dtype,
                device=device,
            ),
        )
        # learnable replace
        self.text_embedding_padding = nn.Parameter(
            torch.randn(
                self.text_len + self.text_len_t5,
                self.text_states_dim,
                dtype=dtype,
                device=device,
            )
        )

        # Attention pooling
        pooler_out_dim = 1024
        self.pooler = AttentionPool(
            self.text_len_t5,
            self.text_states_dim_t5,
            num_heads=8,
            output_dim=pooler_out_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )

        # Dimension of the extra input vectors
        self.extra_in_dim = pooler_out_dim

        if self.size_cond:
            # Image size and crop size conditions
            self.extra_in_dim += 6 * 256

        if self.use_style_cond:
            # Here we use a default learned embedder layer for future extension.
            self.style_embedder = nn.Embedding(
                1, hidden_size, dtype=dtype, device=device
            )
            self.extra_in_dim += hidden_size

        # Text embedding for `add`
        self.x_embedder = PatchEmbed(
            input_size,
            patch_size,
            in_channels,
            hidden_size,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations
        )
        self.extra_embedder = nn.Sequential(
            operations.Linear(
                self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device
            ),
            nn.SiLU(),
            operations.Linear(
                hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device
            ),
        )

        # Image embedding
        num_patches = self.x_embedder.num_patches

        # HUnYuanDiT Blocks
        self.blocks = nn.ModuleList(
            [
                HunYuanDiTBlock(
                    hidden_size=hidden_size,
                    c_emb_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    text_states_dim=self.text_states_dim,
                    qk_norm=qk_norm,
                    norm_type=self.norm,
                    skip=False,
                    attn_precision=attn_precision,
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
                for _ in range(19)
            ]
        )

        # Input zero linear for the first block
        self.before_proj = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)

        # Output zero linear for the every block
        self.after_proj_list = nn.ModuleList(
            [
                operations.Linear(
                    self.hidden_size, self.hidden_size, dtype=dtype, device=device
                )
                for _ in range(len(self.blocks))
            ]
        )

    def forward(
        self,
        x,
        hint,
        timesteps,
        context,  # encoder_hidden_states=None,
        text_embedding_mask=None,
        encoder_hidden_states_t5=None,
        text_embedding_mask_t5=None,
        image_meta_size=None,
        style=None,
        return_dict=False,
        **kwarg,
    ):
        """
        Forward pass of the encoder.

        Parameters
        ----------
        x: torch.Tensor
            (B, D, H, W)
        t: torch.Tensor
            (B)
        encoder_hidden_states: torch.Tensor
            CLIP text embedding, (B, L_clip, D)
        text_embedding_mask: torch.Tensor
            CLIP text embedding mask, (B, L_clip)
        encoder_hidden_states_t5: torch.Tensor
            T5 text embedding, (B, L_t5, D)
        text_embedding_mask_t5: torch.Tensor
            T5 text embedding mask, (B, L_t5)
        image_meta_size: torch.Tensor
            (B, 6)
        style: torch.Tensor
            (B)
        cos_cis_img: torch.Tensor
        sin_cis_img: torch.Tensor
        return_dict: bool
            Whether to return a dictionary.
        """
        condition = hint
        if condition.shape[0] == 1:
            condition = torch.repeat_interleave(condition, x.shape[0], dim=0)

        text_states = context  # 2,77,1024
        text_states_t5 = encoder_hidden_states_t5  # 2,256,2048
        text_states_mask = text_embedding_mask.bool()  # 2,77
        text_states_t5_mask = text_embedding_mask_t5.bool()  # 2,256
        b_t5, l_t5, c_t5 = text_states_t5.shape
        text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)

        padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)

        text_states[:, -self.text_len :] = torch.where(
            text_states_mask[:, -self.text_len :].unsqueeze(2),
            text_states[:, -self.text_len :],
            padding[: self.text_len],
        )
        text_states_t5[:, -self.text_len_t5 :] = torch.where(
            text_states_t5_mask[:, -self.text_len_t5 :].unsqueeze(2),
            text_states_t5[:, -self.text_len_t5 :],
            padding[self.text_len :],
        )

        text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024

        # _, _, oh, ow = x.shape
        # th, tw = oh // self.patch_size, ow // self.patch_size

        # Get image RoPE embedding according to `reso`lution.
        freqs_cis_img = calc_rope(
            x, self.patch_size, self.hidden_size // self.num_heads
        )  # (cos_cis_img, sin_cis_img)

        # ========================= Build time and image embedding =========================
        t = self.t_embedder(timesteps, dtype=self.dtype)
        x = self.x_embedder(x)

        # ========================= Concatenate all extra vectors =========================
        # Build text tokens with pooling
        extra_vec = self.pooler(encoder_hidden_states_t5)

        # Build image meta size tokens if applicable
        # if image_meta_size is not None:
        #     image_meta_size = timestep_embedding(image_meta_size.view(-1), 256)   # [B * 6, 256]
        #     if image_meta_size.dtype != self.dtype:
        #         image_meta_size = image_meta_size.half()
        #     image_meta_size = image_meta_size.view(-1, 6 * 256)
        #     extra_vec = torch.cat([extra_vec, image_meta_size], dim=1)  # [B, D + 6 * 256]

        # Build style tokens
        if style is not None:
            style_embedding = self.style_embedder(style)
            extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

        # Concatenate all extra vectors
        c = t + self.extra_embedder(extra_vec)  # [B, D]

        # ========================= Deal with Condition =========================
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        controls = []
        x = x + self.before_proj(condition)  # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x))  # zero linear for output

        return {"output": controls}

@@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
     sub_args = [start, stop, (th, tw)]
     # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
     rope = get_2d_rotary_pos_embed(head_size, *sub_args)
+    rope = (rope[0].to(x), rope[1].to(x))
     return rope
@@ -91,6 +92,8 @@ class HunYuanDiTBlock(nn.Module):
         # Long Skip Connection
         if self.skip_linear is not None:
             cat = torch.cat([x, skip], dim=-1)
+            if cat.dtype != x.dtype:
+                cat = cat.to(x.dtype)
             cat = self.skip_norm(cat)
             x = self.skip_linear(cat)
@@ -362,6 +365,8 @@ class HunYuanDiT(nn.Module):
         c = t + self.extra_embedder(extra_vec)  # [B, D]

         controls = None
+        if control:
+            controls = control.get("output", None)
         # ========================= Forward pass through HunYuanDiT blocks =========================
         skips = []
         for layer, block in enumerate(self.blocks):

@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import comfy.utils
 import logging
@@ -218,11 +236,17 @@ def model_lora_keys_clip(model, key_map={}):
             lora_key = "lora_prior_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #cascade lora: TODO put lora key prefix in the model config
             key_map[lora_key] = k

-    for k in sdk: #OneTrainer SD3 lora
-        if k.startswith("t5xxl.transformer.") and k.endswith(".weight"):
-            l_key = k[len("t5xxl.transformer."):-len(".weight")]
-            lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
-            key_map[lora_key] = k
+    for k in sdk:
+        if k.endswith(".weight"):
+            if k.startswith("t5xxl.transformer."):#OneTrainer SD3 lora
+                l_key = k[len("t5xxl.transformer."):-len(".weight")]
+                lora_key = "lora_te3_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k
+            elif k.startswith("hydit_clip.transformer.bert."): #HunyuanDiT Lora
+                l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
+                lora_key = "lora_te1_{}".format(l_key.replace(".", "_"))
+                key_map[lora_key] = k

     k = "clip_g.transformer.text_projection.weight"
     if k in sdk:
@@ -245,6 +269,7 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
             key_map["lora_unet_{}".format(key_lora)] = k
             key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
+            key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names

     diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
     for k in diffusers_keys:
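
Note: a toy walk-through of the HunyuanDiT text-encoder branch added above, showing the key_map entry it produces (the key below is a made-up example in the hydit_clip naming style, not taken from a real checkpoint):

k = "hydit_clip.transformer.bert.encoder.layer.0.attention.self.query.weight"
key_map = {}
if k.endswith(".weight") and k.startswith("hydit_clip.transformer.bert."):
    l_key = k[len("hydit_clip.transformer.bert."):-len(".weight")]
    key_map["lora_te1_{}".format(l_key.replace(".", "_"))] = k

print(key_map)
# {'lora_te1_encoder_layer_0_attention_self_query': 'hydit_clip.transformer.bert.encoder.layer.0.attention.self.query.weight'}
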

@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 import logging
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -77,10 +95,13 @@ class BaseModel(torch.nn.Module):
         self.device = device

         if not unet_config.get("disable_unet_model_creation", False):
-            if self.manual_cast_dtype is not None:
-                operations = comfy.ops.manual_cast
-            else:
-                operations = comfy.ops.disable_weight_init
+            if model_config.custom_operations is None:
+                if self.manual_cast_dtype is not None:
+                    operations = comfy.ops.manual_cast
+                else:
+                    operations = comfy.ops.disable_weight_init
+            else:
+                operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if comfy.model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)

@@ -137,8 +137,8 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["hidden_size"] = 3072
         dit_config["mlp_ratio"] = 4.0
         dit_config["num_heads"] = 24
-        dit_config["depth"] = 19
-        dit_config["depth_single_blocks"] = 38
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+        dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
         dit_config["axes_dim"] = [16, 56, 56]
         dit_config["theta"] = 10000
         dit_config["qkv_bias"] = True
@@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict):

 def convert_diffusers_mmdit(state_dict, output_prefix=""):
     out_sd = {}
-    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+    if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        hidden_size = state_dict["x_embedder.bias"].shape[0]
+        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
+    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
@@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
             old_weight = out_sd.get(t[0], None)
             if old_weight is None:
                 old_weight = torch.empty_like(weight)
-                old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+            if old_weight.shape[offset[0]] < offset[1] + offset[2]:
+                exp = list(weight.shape)
+                exp[offset[0]] = offset[1] + offset[2]
+                new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
+                new[:old_weight.shape[0]] = old_weight
+                old_weight = new

             w = old_weight.narrow(offset[0], offset[1], offset[2])
         else:
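
Note: the grow-then-narrow hunk above replaces a fixed repeat([3, ...]) with on-demand growth, so fused tensors can be assembled slice by slice whatever their final size. A self-contained sketch of the pattern, with toy 4x4 weights standing in for diffusers q/k/v projections:

import torch

hidden = 4
offsets = {"q": (0, 0, hidden), "k": (0, hidden, hidden), "v": (0, 2 * hidden, hidden)}  # (dim, start, length)
out_sd = {}

for i, (name, offset) in enumerate(offsets.items()):
    weight = torch.full((hidden, hidden), float(i))  # stand-in for the source tensor
    old_weight = out_sd.get("qkv.weight", None)
    if old_weight is None:
        old_weight = torch.empty_like(weight)
    if old_weight.shape[offset[0]] < offset[1] + offset[2]:
        # grow the destination along dim offset[0] until the new slice fits
        exp = list(weight.shape)
        exp[offset[0]] = offset[1] + offset[2]
        new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
        new[:old_weight.shape[0]] = old_weight
        old_weight = new
    old_weight.narrow(offset[0], offset[1], offset[2])[:] = weight
    out_sd["qkv.weight"] = old_weight

print(out_sd["qkv.weight"].shape)  # torch.Size([12, 4]): q, k and v stacked along dim 0
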

View File

@@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import psutil import psutil
import logging import logging
from enum import Enum from enum import Enum
@@ -273,9 +291,12 @@ class LoadedModel:
def model_memory(self): def model_memory(self):
return self.model.model_size() return self.model.model_size()
def model_offloaded_memory(self):
return self.model.model_size() - self.model.loaded_size()
def model_memory_required(self, device): def model_memory_required(self, device):
if device == self.model.current_loaded_device(): if device == self.model.current_loaded_device():
return 0 return self.model_offloaded_memory()
else: else:
return self.model_memory() return self.model_memory()
@@ -287,15 +308,21 @@ class LoadedModel:
load_weights = not self.weights_loaded load_weights = not self.weights_loaded
try: if self.model.loaded_size() > 0:
if lowvram_model_memory > 0 and load_weights: use_more_vram = lowvram_model_memory
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights) if use_more_vram == 0:
else: use_more_vram = 1e32
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights) self.model_use_more_vram(use_more_vram)
except Exception as e: else:
self.model.unpatch_model(self.model.offload_device) try:
self.model_unload() if lowvram_model_memory > 0 and load_weights:
raise e self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize: if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True) self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
@@ -304,19 +331,41 @@ class LoadedModel:
return self.real_model return self.real_model
def should_reload_model(self, force_patch_weights=False): def should_reload_model(self, force_patch_weights=False):
if force_patch_weights and self.model.lowvram_patch_counter > 0: if force_patch_weights and self.model.lowvram_patch_counter() > 0:
return True return True
return False return False
def model_unload(self, unpatch_weights=True): def model_unload(self, memory_to_free=None, unpatch_weights=True):
if memory_to_free is not None:
if memory_to_free < self.model.loaded_size():
self.model.partially_unload(self.model.offload_device, memory_to_free)
return False
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights) self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
self.model.model_patches_to(self.model.offload_device) self.model.model_patches_to(self.model.offload_device)
self.weights_loaded = self.weights_loaded and not unpatch_weights self.weights_loaded = self.weights_loaded and not unpatch_weights
self.real_model = None self.real_model = None
return True
def model_use_more_vram(self, extra_memory):
return self.model.partially_load(self.device, extra_memory)
def __eq__(self, other): def __eq__(self, other):
return self.model is other.model return self.model is other.model
def use_more_memory(extra_memory, loaded_models, device):
for m in loaded_models:
if m.device == device:
extra_memory -= m.model_use_more_vram(extra_memory)
if extra_memory <= 0:
break
def offloaded_memory(loaded_models, device):
offloaded_mem = 0
for m in loaded_models:
if m.device == device:
offloaded_mem += m.model_offloaded_memory()
return offloaded_mem
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2 return (1024 * 1024 * 1024) * 1.2
@@ -363,11 +412,15 @@ def free_memory(memory_required, device, keep_loaded=[]):
for x in sorted(can_unload): for x in sorted(can_unload):
i = x[-1] i = x[-1]
memory_to_free = None
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY:
if get_free_memory(device) > memory_required: free_mem = get_free_memory(device)
if free_mem > memory_required:
break break
current_loaded_models[i].model_unload() memory_to_free = memory_required - free_mem
unloaded_model.append(i) logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
if current_loaded_models[i].model_unload(memory_to_free):
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded_models.append(current_loaded_models.pop(i))
@@ -385,11 +438,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
global vram_state global vram_state
inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required) extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
if minimum_memory_required is None: if minimum_memory_required is None:
minimum_memory_required = extra_mem minimum_memory_required = extra_mem
else: else:
minimum_memory_required = max(inference_memory, minimum_memory_required) minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
models = set(models) models = set(models)
@@ -422,12 +475,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
devs = set(map(lambda a: a.device, models_already_loaded)) devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs: for d in devs:
if d != torch.device("cpu"): if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded) free_memory(extra_mem + offloaded_memory(models_already_loaded, d), d, models_already_loaded)
free_mem = get_free_memory(d) free_mem = get_free_memory(d)
if free_mem < minimum_memory_required: if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed. logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d) models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load))) logging.info("{} models unloaded.".format(len(models_to_load)))
else:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
if len(models_to_load) == 0: if len(models_to_load) == 0:
return return
@@ -435,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = {} total_memory_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
for device in total_memory_required: for loaded_model in models_already_loaded:
if device != torch.device("cpu"): total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load: for loaded_model in models_to_load:
weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
if weights_unloaded is not None: if weights_unloaded is not None:
loaded_model.weights_loaded = not weights_unloaded loaded_model.weights_loaded = not weights_unloaded
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
for loaded_model in models_to_load: for loaded_model in models_to_load:
model = loaded_model.model model = loaded_model.model
torch_dev = model.load_device torch_dev = model.load_device
@@ -467,6 +525,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model) current_loaded_models.insert(0, loaded_model)
devs = set(map(lambda a: a.device, models_already_loaded))
for d in devs:
if d != torch.device("cpu"):
free_mem = get_free_memory(d)
if free_mem > minimum_memory_required:
use_more_memory(free_mem - minimum_memory_required, models_already_loaded, d)
return return
@@ -562,12 +628,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory: if model_params * 2 > free_model_memory:
return fp8_dtype return fp8_dtype
if should_use_fp16(device=device, model_params=model_params, manual_cast=True): for dt in supported_dtypes:
if torch.float16 in supported_dtypes: if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
return torch.float16 if torch.float16 in supported_dtypes:
if should_use_bf16(device, model_params=model_params, manual_cast=True): return torch.float16
if torch.bfloat16 in supported_dtypes: if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
return torch.bfloat16 if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
return torch.float32 return torch.float32
# None means no manual cast # None means no manual cast
@@ -583,13 +659,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
if bf16_supported and weight_dtype == torch.bfloat16: if bf16_supported and weight_dtype == torch.bfloat16:
return None return None
if fp16_supported and torch.float16 in supported_dtypes: for dt in supported_dtypes:
return torch.float16 if dt == torch.float16 and fp16_supported:
return torch.float16
if dt == torch.bfloat16 and bf16_supported:
return torch.bfloat16
elif bf16_supported and torch.bfloat16 in supported_dtypes: return torch.float32
return torch.bfloat16
else:
return torch.float32
def text_encoder_offload_device(): def text_encoder_offload_device():
if args.gpu_only: if args.gpu_only:
@@ -608,6 +684,20 @@ def text_encoder_device():
    else:
        return torch.device("cpu")

+def text_encoder_initial_device(load_device, offload_device, model_size=0):
+    if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
+        return offload_device
+
+    if is_device_mps(load_device):
+        return offload_device
+
+    mem_l = get_free_memory(load_device)
+    mem_o = get_free_memory(offload_device)
+    if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
+        return load_device
+    else:
+        return offload_device
+
def text_encoder_dtype(device=None):
    if args.fp8_e4m3fn_text_enc:
        return torch.float8_e4m3fn
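A rough walk-through of the new text_encoder_initial_device heuristic, with invented numbers:

    GB = 1024 * 1024 * 1024
    model_size = 5 * GB   # e.g. an fp16 T5-class encoder (invented figure)
    mem_l = 20 * GB       # free memory on the load (GPU) device
    mem_o = 30 * GB       # free memory on the offload (CPU) device

    # Anything at or under 1 GB always starts on the offload device.
    # Otherwise: mem_l > mem_o * 0.5   ->  20 > 15  (True)
    #            model_size * 1.2 < mem_l ->  6 < 20  (True)
    # Both hold, so the encoder is built directly on the GPU, skipping the
    # usual build-on-CPU-then-transfer round trip.
    assert mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l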
@@ -1,8 +1,27 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch import torch
import copy import copy
import inspect import inspect
import logging import logging
import uuid import uuid
import collections
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
@@ -63,12 +82,27 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
model_options["disable_cfg1_optimization"] = True model_options["disable_cfg1_optimization"] = True
return model_options return model_options
def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
m.weight_function = None
m.bias_function = None
class LowVramPatch:
def __init__(self, key, model_patcher):
self.key = key
self.model_patcher = model_patcher
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size self.size = size
self.model = model self.model = model
if not hasattr(self.model, 'device'): if not hasattr(self.model, 'device'):
logging.info("Model doesn't have a device attribute.") logging.debug("Model doesn't have a device attribute.")
self.model.device = offload_device self.model.device = offload_device
elif self.model.device is None: elif self.model.device is None:
self.model.device = offload_device self.model.device = offload_device
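The hoisted LowVramPatch is consumed by cast-on-use layers as a weight_function hook. A simplified sketch of that consumption; the real logic lives in comfy.ops, and CastLinear here is purely illustrative:

    import torch
    import torch.nn as nn

    class CastLinear(nn.Linear):
        weight_function = None  # set to LowVramPatch(key, patcher) when patched

        def forward(self, x):
            # Weight lives on the offload device; cast it in for this call only.
            w = self.weight.to(device=x.device, dtype=x.dtype)
            if self.weight_function is not None:
                w = self.weight_function(w)  # LoRA math applied just-in-time
            b = None if self.bias is None else self.bias.to(device=x.device, dtype=x.dtype)
            return nn.functional.linear(x, w, b)

Because the patch is applied per forward pass instead of baked into the stored weight, partially offloaded modules can carry LoRAs without their weights ever being materialized on the GPU.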
@@ -82,16 +116,29 @@ class ModelPatcher:
        self.load_device = load_device
        self.offload_device = offload_device
        self.weight_inplace_update = weight_inplace_update
-        self.model_lowvram = False
-        self.lowvram_patch_counter = 0
        self.patches_uuid = uuid.uuid4()

+        if not hasattr(self.model, 'model_loaded_weight_memory'):
+            self.model.model_loaded_weight_memory = 0
+
+        if not hasattr(self.model, 'lowvram_patch_counter'):
+            self.model.lowvram_patch_counter = 0
+
+        if not hasattr(self.model, 'model_lowvram'):
+            self.model.model_lowvram = False
+
    def model_size(self):
        if self.size > 0:
            return self.size
        self.size = comfy.model_management.module_size(self.model)
        return self.size

+    def loaded_size(self):
+        return self.model.model_loaded_weight_memory
+
+    def lowvram_patch_counter(self):
+        return self.model.lowvram_patch_counter
+
    def clone(self):
        n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
        n.patches = {}
@@ -265,16 +312,16 @@ class ModelPatcher:
            sd.pop(k)
        return sd

-    def patch_weight_to_device(self, key, device_to=None):
+    def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
        if key not in self.patches:
            return

        weight = comfy.utils.get_attr(self.model, key)

-        inplace_update = self.weight_inplace_update
+        inplace_update = self.weight_inplace_update or inplace_update

        if key not in self.backup:
-            self.backup[key] = weight.to(device=self.offload_device, copy=inplace_update)
+            self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

        if device_to is not None:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
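Each backup entry now records how its weight must be restored rather than assuming one policy for the whole patcher. A small illustration (tensor and flag invented):

    import collections
    import torch

    Backup = collections.namedtuple('Dimension', ['weight', 'inplace_update'])
    w = torch.zeros(4, 4)
    entry = Backup(weight=w.to(device='cpu', copy=True), inplace_update=True)

    # On unpatch: an in-place copy preserves the parameter object, so views
    # and references stay valid; attribute assignment is the cheap path.
    if entry.inplace_update:
        w.copy_(entry.weight)   # stand-in for comfy.utils.copy_to_param
    else:
        w = entry.weight        # stand-in for comfy.utils.set_attr_param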
@@ -304,28 +351,24 @@ class ModelPatcher:
        if device_to is not None:
            self.model.to(device_to)
            self.model.device = device_to
+            self.model.model_loaded_weight_memory = self.model_size()

        return self.model

-    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
-        self.patch_model(device_to, patch_weights=False)
-
-        logging.info("loading in lowvram mode {}".format(lowvram_model_memory/(1024 * 1024)))
-        class LowVramPatch:
-            def __init__(self, key, model_patcher):
-                self.key = key
-                self.model_patcher = model_patcher
-            def __call__(self, weight):
-                return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
+    def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
        mem_counter = 0
        patch_counter = 0
+        lowvram_counter = 0
        for n, m in self.model.named_modules():
            lowvram_weight = False

-            if hasattr(m, "comfy_cast_weights"):
+            if not full_load and hasattr(m, "comfy_cast_weights"):
                module_mem = comfy.model_management.module_size(m)
                if mem_counter + module_mem >= lowvram_model_memory:
                    lowvram_weight = True
+                    lowvram_counter += 1
+                    if m.comfy_cast_weights:
+                        continue

            weight_key = "{}.weight".format(n)
            bias_key = "{}.bias".format(n)
@@ -347,16 +390,40 @@ class ModelPatcher:
                    m.prev_comfy_cast_weights = m.comfy_cast_weights
                    m.comfy_cast_weights = True
            else:
+                if hasattr(m, "comfy_cast_weights"):
+                    if m.comfy_cast_weights:
+                        wipe_lowvram_weight(m)
+
                if hasattr(m, "weight"):
-                    self.patch_weight_to_device(weight_key, device_to)
-                    self.patch_weight_to_device(bias_key, device_to)
-                    m.to(device_to)
                    mem_counter += comfy.model_management.module_size(m)
+                    param = list(m.parameters())
+                    if len(param) > 0:
+                        weight = param[0]
+                        if weight.device == device_to:
+                            continue
+
+                        weight_to = None
+                        if full_load:#TODO
+                            weight_to = device_to
+
+                        self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
+                        self.patch_weight_to_device(bias_key, device_to=weight_to)
+                        m.to(device_to)
                    logging.debug("lowvram: loaded module regularly {} {}".format(n, m))

-        self.model_lowvram = True
-        self.lowvram_patch_counter = patch_counter
+        if lowvram_counter > 0:
+            logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter))
+            self.model.model_lowvram = True
+        else:
+            logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
+            self.model.model_lowvram = False
+
+        self.model.lowvram_patch_counter += patch_counter
        self.model.device = device_to
+        self.model.model_loaded_weight_memory = mem_counter
+
+    def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
+        self.patch_model(device_to, patch_weights=False)
+        self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
        return self.model

    def calculate_weight(self, patches, weight, key):
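The core of lowvram_load is a byte-budget walk over named_modules: modules that fit under lowvram_model_memory are loaded fully, the rest keep cast-on-use weights. A minimal sketch of that planning step, with module_size as an assumed callable, not the real implementation:

    def plan_lowvram(named_modules, module_size, budget_bytes):
        fully_loaded, hooked = [], []
        used = 0
        for name, module in named_modules:
            size = module_size(module)
            if used + size >= budget_bytes:
                hooked.append(name)        # stays offloaded; cast per forward pass
            else:
                fully_loaded.append(name)  # moved onto the device for good
                used += size
        return fully_loaded, hooked, used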
@@ -529,31 +596,28 @@ class ModelPatcher:
    def unpatch_model(self, device_to=None, unpatch_weights=True):
        if unpatch_weights:
-            if self.model_lowvram:
+            if self.model.model_lowvram:
                for m in self.model.modules():
-                    if hasattr(m, "prev_comfy_cast_weights"):
-                        m.comfy_cast_weights = m.prev_comfy_cast_weights
-                        del m.prev_comfy_cast_weights
-
-                    m.weight_function = None
-                    m.bias_function = None
+                    wipe_lowvram_weight(m)

-                self.model_lowvram = False
-                self.lowvram_patch_counter = 0
+                self.model.model_lowvram = False
+                self.model.lowvram_patch_counter = 0

            keys = list(self.backup.keys())
-            if self.weight_inplace_update:
-                for k in keys:
-                    comfy.utils.copy_to_param(self.model, k, self.backup[k])
-            else:
-                for k in keys:
-                    comfy.utils.set_attr_param(self.model, k, self.backup[k])
+            for k in keys:
+                bk = self.backup[k]
+                if bk.inplace_update:
+                    comfy.utils.copy_to_param(self.model, k, bk.weight)
+                else:
+                    comfy.utils.set_attr_param(self.model, k, bk.weight)

            self.backup.clear()

            if device_to is not None:
                self.model.to(device_to)
                self.model.device = device_to
+            self.model.model_loaded_weight_memory = 0

        keys = list(self.object_patches_backup.keys())
        for k in keys:
@@ -561,5 +625,60 @@ class ModelPatcher:
        self.object_patches_backup.clear()

+    def partially_unload(self, device_to, memory_to_free=0):
+        memory_freed = 0
+        patch_counter = 0
+
+        for n, m in list(self.model.named_modules())[::-1]:
+            if memory_to_free < memory_freed:
+                break
+
+            shift_lowvram = False
+            if hasattr(m, "comfy_cast_weights"):
+                module_mem = comfy.model_management.module_size(m)
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+
+                if m.weight is not None and m.weight.device != device_to:
+                    for key in [weight_key, bias_key]:
+                        bk = self.backup.get(key, None)
+                        if bk is not None:
+                            if bk.inplace_update:
+                                comfy.utils.copy_to_param(self.model, key, bk.weight)
+                            else:
+                                comfy.utils.set_attr_param(self.model, key, bk.weight)
+                            self.backup.pop(key)
+
+                    m.to(device_to)
+                    if weight_key in self.patches:
+                        m.weight_function = LowVramPatch(weight_key, self)
+                        patch_counter += 1
+                    if bias_key in self.patches:
+                        m.bias_function = LowVramPatch(bias_key, self)
+                        patch_counter += 1
+
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    memory_freed += module_mem
+                    logging.debug("freed {}".format(n))
+
+        self.model.model_lowvram = True
+        self.model.lowvram_patch_counter += patch_counter
+        self.model.model_loaded_weight_memory -= memory_freed
+        return memory_freed
+
+    def partially_load(self, device_to, extra_memory=0):
+        self.unpatch_model(unpatch_weights=False)
+        self.patch_model(patch_weights=False)
+        full_load = False
+        if self.model.model_lowvram == False:
+            return 0
+        if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
+            full_load = True
+        current_used = self.model.model_loaded_weight_memory
+        self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
+        return self.model.model_loaded_weight_memory - current_used
+
    def current_loaded_device(self):
        return self.model.device
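Taken together, partially_unload and partially_load give the memory manager a byte-granular dial rather than an all-or-nothing toggle. A hypothetical usage, with the patcher and sizes invented; both calls return byte counts that mirror the model_loaded_weight_memory deltas:

    GB = 1024 * 1024 * 1024
    # freed = patcher.partially_unload(torch.device("cpu"), memory_to_free=2 * GB)
    # gained = patcher.partially_load(torch.device("cuda"), extra_memory=freed)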
@@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
        if no_init:
            return

        params = target.params.copy()
@@ -71,20 +71,24 @@ class CLIP:
        load_device = model_management.text_encoder_device()
        offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
        dtype = model_management.text_encoder_dtype(load_device)
        params['dtype'] = dtype
+        params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
        self.cond_stage_model = clip(**(params))

        for dt in self.cond_stage_model.dtypes:
            if not model_management.supports_cast(load_device, dt):
                load_device = offload_device
+                if params['device'] != offload_device:
+                    self.cond_stage_model.to(offload_device)
+                    logging.warning("Had to shift TE back.")

        self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        if params['device'] == load_device:
+            model_management.load_model_gpu(self.patcher)
        self.layer_idx = None
-        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
+        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))

    def clone(self):
        n = CLIP(no_init=True)
@@ -456,7 +460,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
        clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
        clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer

-    clip = CLIP(clip_target, embedding_directory=embedding_directory)
+    parameters = 0
+    for c in clip_data:
+        parameters += comfy.utils.calculate_parameters(c)
+
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
    for c in clip_data:
        m, u = clip.load_sd(c)
        if len(m) > 0:
@@ -498,15 +506,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
    return (model, clip, vae)

-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
    sd = comfy.utils.load_torch_file(ckpt_path)
-    sd_keys = sd.keys()
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
+    if out is None:
+        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+    return out
+
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
    clip = None
    clipvision = None
    vae = None
    model = None
    model_patcher = None
+    clip_target = None

    diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
    parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
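With this split, callers that already hold a state dict (say, a merge produced in memory) can skip the file read entirely. A sketch, assuming the checkpoint loader's usual (model, clip, vae, clipvision) return shape and a placeholder path:

    import comfy.utils
    import comfy.sd

    sd = comfy.utils.load_torch_file("model.safetensors")  # or any dict built in memory
    out = comfy.sd.load_state_dict_guess_config(sd, output_vae=True, output_clip=True)
    if out is None:
        raise RuntimeError("unrecognized checkpoint layout")
    model_patcher, clip, vae, clipvision = out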
@@ -515,13 +527,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
    if model_config is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+        return None

    unet_weight_dtype = list(model_config.supported_inference_dtypes)
    if weight_dtype is not None:
        unet_weight_dtype.append(weight_dtype)

-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
+    unet_dtype = model_options.get("weight_dtype", None)
+
+    if unet_dtype is None:
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+
    manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
    model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@@ -545,7 +562,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
    if clip_target is not None:
        clip_sd = model_config.process_clip_state_dict(sd)
        if len(clip_sd) > 0:
-            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
+            parameters = comfy.utils.calculate_parameters(clip_sd)
+            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
            m, u = clip.load_sd(clip_sd, full_model=True)
            if len(m) > 0:
                m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@@ -313,6 +313,17 @@ def expand_directory_list(directories):
            dirs.add(root)
    return list(dirs)

+def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
+    i = 0
+    out_list = []
+    for k in embed:
+        if k.startswith(prefix) and k.endswith(suffix):
+            out_list.append(embed[k])
+
+    if len(out_list) == 0:
+        return None
+
+    return torch.cat(out_list, dim=0)
+
def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=None):
    if isinstance(embedding_directory, str):
        embedding_directory = [embedding_directory]
@@ -379,8 +390,12 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
        elif embed_key is not None and embed_key in embed:
            embed_out = embed[embed_key]
        else:
-            values = embed.values()
-            embed_out = next(iter(values))
+            embed_out = bundled_embed(embed, 'bundle_emb.', '.string_to_param.*')
+            if embed_out is None:
+                embed_out = bundled_embed(embed, 'bundle_emb.', '.{}'.format(embed_key))
+            if embed_out is None:
+                values = embed.values()
+                embed_out = next(iter(values))
    return embed_out

class SDTokenizer:
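bundled_embed targets files that ship textual-inversion vectors alongside LoRA weights under bundle_emb.* keys. An illustrative layout, with key names and shapes invented:

    import torch

    embed = {
        "bundle_emb.myword.string_to_param.*": torch.zeros(2, 768),
        "bundle_emb.other.string_to_param.*": torch.zeros(1, 768),
    }
    # bundled_embed(embed, 'bundle_emb.', '.string_to_param.*') concatenates
    # the matching tensors along dim 0, yielding a (3, 768) embedding here.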
@@ -642,7 +642,7 @@ class Flux(supported_models_base.BASE):
    memory_usage_factor = 2.8

-    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

    vae_key_prefix = ["vae."]
    text_encoder_key_prefix = ["text_encoders."]
@@ -1,3 +1,21 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch import torch
from . import model_base from . import model_base
from . import utils from . import utils
@@ -30,6 +48,7 @@ class BASE:
    memory_usage_factor = 2.0

    manual_cast_dtype = None
+    custom_operations = None

    @classmethod
    def matches(s, unet_config, state_dict=None):
@@ -1,3 +1,22 @@
"""
This file is part of ComfyUI.
Copyright (C) 2024 Comfy
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
import torch import torch
import math import math
import struct import struct
@@ -432,8 +451,33 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {"attn.to_out.0.weight": "img_attn.proj.weight", k = "{}.attn.".format(prefix_from)
"attn.to_out.0.bias": "img_attn.proj.bias", qkv = "{}.txt_attn.qkv.{}".format(prefix_to, end)
key_map["{}add_q_proj.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
} }
for k in block_map: for k in block_map:
@@ -449,15 +493,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size)) key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
block_map = {#TODO block_map = {
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
} }
for k in block_map: for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k]) key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
MAP_BASIC = { #TODO MAP_BASIC = {
("final_layer.linear.bias", "proj_out.bias"),
("final_layer.linear.weight", "proj_out.weight"),
("img_in.bias", "x_embedder.bias"),
("img_in.weight", "x_embedder.weight"),
("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
("txt_in.bias", "context_embedder.bias"),
("txt_in.weight", "context_embedder.weight"),
("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
} }
for k in MAP_BASIC: for k in MAP_BASIC:
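The tuple values in key_map appear to read as (target_key, (dim, offset, size)), which lets several diffusers keys address slices of one fused flux tensor. A toy version of the split, with hidden_size invented:

    import torch

    hidden_size = 8
    qkv = torch.randn(3 * hidden_size, hidden_size)     # fused q/k/v weight
    to_q = qkv.narrow(0, 0, hidden_size)                # (0, 0, hidden_size)
    to_k = qkv.narrow(0, hidden_size, hidden_size)      # (0, hidden_size, hidden_size)
    to_v = qkv.narrow(0, 2 * hidden_size, hidden_size)  # (0, hidden_size * 2, hidden_size)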
@@ -19,6 +19,7 @@ class CLIPTextEncodeHunyuanDiT:
        cond = output.pop("cond")
        return ([[cond, output]], )

NODE_CLASS_MAPPINGS = {
    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}
@@ -100,3 +100,8 @@ NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeSD3": CLIPTextEncodeSD3, "CLIPTextEncodeSD3": CLIPTextEncodeSD3,
"ControlNetApplySD3": ControlNetApplySD3, "ControlNetApplySD3": ControlNetApplySD3,
} }
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
}
@@ -1,13 +1,13 @@
+from __future__ import annotations

import os
import time
import logging
-from typing import Set, List, Dict, Tuple
+from collections.abc import Collection

-supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
+supported_pt_extensions: set[str] = {'.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'}

-SupportedFileExtensionsType = Set[str]
-ScanPathType = List[str]
-folder_names_and_paths: Dict[str, Tuple[ScanPathType, SupportedFileExtensionsType]] = {}
+folder_names_and_paths: dict[str, tuple[list[str], set[str]]] = {}

base_path = os.path.dirname(os.path.realpath(__file__))
models_dir = os.path.join(base_path, "models")
@@ -42,7 +42,7 @@ temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp
input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input")
user_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "user")

-filename_list_cache = {}
+filename_list_cache: dict[str, tuple[list[str], dict[str, float], float]] = {}

if not os.path.exists(input_directory):
    try:
@@ -50,33 +50,33 @@ if not os.path.exists(input_directory):
    except:
        logging.error("Failed to create input directory")

-def set_output_directory(output_dir):
+def set_output_directory(output_dir: str) -> None:
    global output_directory
    output_directory = output_dir

-def set_temp_directory(temp_dir):
+def set_temp_directory(temp_dir: str) -> None:
    global temp_directory
    temp_directory = temp_dir

-def set_input_directory(input_dir):
+def set_input_directory(input_dir: str) -> None:
    global input_directory
    input_directory = input_dir

-def get_output_directory():
+def get_output_directory() -> str:
    global output_directory
    return output_directory

-def get_temp_directory():
+def get_temp_directory() -> str:
    global temp_directory
    return temp_directory

-def get_input_directory():
+def get_input_directory() -> str:
    global input_directory
    return input_directory

#NOTE: used in http server so don't put folders that should not be accessed remotely
-def get_directory_by_type(type_name):
+def get_directory_by_type(type_name: str) -> str | None:
    if type_name == "output":
        return get_output_directory()
    if type_name == "temp":
@@ -88,7 +88,7 @@ def get_directory_by_type(type_name):
# determine base_dir rely on annotation if name is 'filename.ext [annotation]' format
# otherwise use default_path as base_dir
-def annotated_filepath(name):
+def annotated_filepath(name: str) -> tuple[str, str | None]:
    if name.endswith("[output]"):
        base_dir = get_output_directory()
        name = name[:-9]
@@ -104,7 +104,7 @@ def annotated_filepath(name):
    return name, base_dir

-def get_annotated_filepath(name, default_dir=None):
+def get_annotated_filepath(name: str, default_dir: str | None=None) -> str:
    name, base_dir = annotated_filepath(name)
    if base_dir is None:
@@ -116,7 +116,7 @@ def get_annotated_filepath(name, default_dir=None):
    return os.path.join(base_dir, name)

-def exists_annotated_filepath(name):
+def exists_annotated_filepath(name) -> bool:
    name, base_dir = annotated_filepath(name)
    if base_dir is None:
@@ -126,17 +126,17 @@ def exists_annotated_filepath(name):
    return os.path.exists(filepath)

-def add_model_folder_path(folder_name, full_folder_path):
+def add_model_folder_path(folder_name: str, full_folder_path: str) -> None:
    global folder_names_and_paths
    if folder_name in folder_names_and_paths:
        folder_names_and_paths[folder_name][0].append(full_folder_path)
    else:
        folder_names_and_paths[folder_name] = ([full_folder_path], set())

-def get_folder_paths(folder_name):
+def get_folder_paths(folder_name: str) -> list[str]:
    return folder_names_and_paths[folder_name][0][:]

-def recursive_search(directory, excluded_dir_names=None):
+def recursive_search(directory: str, excluded_dir_names: list[str] | None=None) -> tuple[list[str], dict[str, float]]:
    if not os.path.isdir(directory):
        return [], {}
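The tightened annotations pin the registry's shape: each folder name maps to a (search paths, allowed extensions) pair. Illustrative usage, with folder name and path invented:

    add_model_folder_path("loras_extra", "/mnt/models/loras")  # hypothetical path
    paths = get_folder_paths("loras_extra")                    # -> ["/mnt/models/loras"]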
@@ -153,6 +153,10 @@ def recursive_search(directory, excluded_dir_names=None):
        logging.warning(f"Warning: Unable to access {directory}. Skipping this path.")

    logging.debug("recursive file list on directory {}".format(directory))
+    dirpath: str
+    subdirs: list[str]
+    filenames: list[str]
    for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
        subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
        for file_name in filenames:
@@ -160,7 +164,7 @@ def recursive_search(directory, excluded_dir_names=None):
            result.append(relative_path)

        for d in subdirs:
-            path = os.path.join(dirpath, d)
+            path: str = os.path.join(dirpath, d)
            try:
                dirs[path] = os.path.getmtime(path)
            except FileNotFoundError:
@@ -169,12 +173,12 @@ def recursive_search(directory, excluded_dir_names=None):
    logging.debug("found {} files".format(len(result)))
    return result, dirs

-def filter_files_extensions(files, extensions):
+def filter_files_extensions(files: Collection[str], extensions: Collection[str]) -> list[str]:
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))

-def get_full_path(folder_name, filename):
+def get_full_path(folder_name: str, filename: str) -> str | None:
    global folder_names_and_paths
    if folder_name not in folder_names_and_paths:
        return None
@@ -189,7 +193,7 @@ def get_full_path(folder_name, filename):
    return None

-def get_filename_list_(folder_name):
+def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
    global folder_names_and_paths
    output_list = set()
    folders = folder_names_and_paths[folder_name]
@@ -199,9 +203,9 @@ def get_filename_list_(folder_name):
        output_list.update(filter_files_extensions(files, folders[1]))
        output_folders = {**output_folders, **folders_all}

-    return (sorted(list(output_list)), output_folders, time.perf_counter())
+    return sorted(list(output_list)), output_folders, time.perf_counter()

-def cached_filename_list_(folder_name):
+def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
    global filename_list_cache
    global folder_names_and_paths
    if folder_name not in filename_list_cache:
@@ -222,7 +226,7 @@ def cached_filename_list_(folder_name):
    return out

-def get_filename_list(folder_name):
+def get_filename_list(folder_name: str) -> list[str]:
    out = cached_filename_list_(folder_name)
    if out is None:
        out = get_filename_list_(folder_name)
@@ -230,17 +234,17 @@ def get_filename_list(folder_name):
        filename_list_cache[folder_name] = out
    return list(out[0])

-def get_save_image_path(filename_prefix, output_dir, image_width=0, image_height=0):
-    def map_filename(filename):
+def get_save_image_path(filename_prefix: str, output_dir: str, image_width=0, image_height=0) -> tuple[str, str, int, str, str]:
+    def map_filename(filename: str) -> tuple[int, str]:
        prefix_len = len(os.path.basename(filename_prefix))
        prefix = filename[:prefix_len + 1]
        try:
            digits = int(filename[prefix_len + 1:].split('_')[0])
        except:
            digits = 0
-        return (digits, prefix)
+        return digits, prefix

-    def compute_vars(input, image_width, image_height):
+    def compute_vars(input: str, image_width: int, image_height: int) -> str:
        input = input.replace("%width%", str(image_width))
        input = input.replace("%height%", str(image_height))
        return input
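For reference, a standalone copy of map_filename showing what the counter parse yields; the example filenames are invented, and the bare except is narrowed to ValueError here for clarity:

    import os

    def map_filename(filename: str, filename_prefix: str = "ComfyUI") -> tuple[int, str]:
        prefix_len = len(os.path.basename(filename_prefix))
        prefix = filename[:prefix_len + 1]
        try:
            digits = int(filename[prefix_len + 1:].split('_')[0])
        except ValueError:
            digits = 0  # no counter in the name; start numbering from zero
        return digits, prefix

    print(map_filename("ComfyUI_00012_.png"))  # (12, 'ComfyUI_')
    print(map_filename("ComfyUI.png"))         # (0, 'ComfyUI.') via the fallback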