Compare commits

..

1 Commits

Author SHA1 Message Date
Yoland Y
1d47ec38d8 Set torch version to be 2.3.1 for v0.0.3 2024-07-26 18:54:29 -07:00
70 changed files with 231 additions and 49551 deletions

View File

@@ -62,15 +62,8 @@ except:
print("checking out master branch")
branch = repo.lookup_branch('master')
if branch is None:
ref = repo.lookup_reference('refs/remotes/origin/master')
repo.checkout(ref)
branch = repo.lookup_branch('master')
if branch is None:
repo.create_branch('master', repo.get(ref.target))
else:
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
ref = repo.lookup_reference(branch.name)
repo.checkout(ref)
print("pulling latest changes")
pull(repo)

View File

@@ -2,28 +2,9 @@
name: "Release Stable Version"
on:
workflow_dispatch:
inputs:
git_tag:
description: 'Git tag'
required: true
type: string
cu:
description: 'CUDA version'
required: true
type: string
default: "121"
python_minor:
description: 'Python minor version'
required: true
type: string
default: "11"
python_patch:
description: 'Python patch version'
required: true
type: string
default: "9"
push:
tags:
- 'v*'
jobs:
package_comfy_windows:
@@ -32,44 +13,69 @@ jobs:
packages: "write"
pull-requests: "read"
runs-on: windows-latest
strategy:
matrix:
python_version: [3.11.8]
cuda_version: [121]
steps:
- name: Calculate Minor Version
shell: bash
run: |
# Extract the minor version from the Python version
MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2)
echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
- uses: actions/checkout@v4
with:
ref: ${{ inputs.git_tag }}
fetch-depth: 0
persist-credentials: false
- uses: actions/cache/restore@v4
id: cache
with:
path: |
cu${{ inputs.cu }}_python_deps.tar
update_comfyui_and_python_dependencies.bat
key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
- shell: bash
run: |
mv cu${{ inputs.cu }}_python_deps.tar ../
echo "@echo off
call update_comfyui.bat nopause
echo -
echo This will try to update pytorch and all python dependencies.
echo -
echo If you just want to update normally, close this and run update_comfyui.bat instead.
echo -
pause
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir
mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps
mv cu${{ matrix.cuda_version }}_python_deps ../
mv update_comfyui_and_python_dependencies.bat ../
cd ..
tar xf cu${{ inputs.cu }}_python_deps.tar
pwd
ls
- shell: bash
run: |
cd ..
cp -r ComfyUI ComfyUI_copy
curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip
unzip python_embeded.zip -d python_embeded
cd python_embeded
echo ${{ env.MINOR_VERSION }}
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
./python.exe --version
echo "Pip version:"
./python.exe -m pip --version
git clone --depth 1 https://github.com/comfyanonymous/taesd
set PATH=$PWD/Scripts:$PATH
echo $PATH
./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/*
sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth
cd ..
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable
@@ -98,7 +104,6 @@ jobs:
with:
repo_token: ${{ secrets.GITHUB_TOKEN }}
file: ComfyUI_windows_portable_nvidia.7z
tag: ${{ inputs.git_tag }}
tag: ${{ github.ref }}
overwrite: true
prerelease: true
make_latest: false

View File

@@ -32,7 +32,7 @@ jobs:
node-version: lts/*
- uses: actions/setup-python@v4
with:
python-version: '3.8'
python-version: '3.10'
- name: Install requirements
run: |
python -m pip install --upgrade pip

View File

@@ -8,16 +8,11 @@ on:
required: false
type: string
default: ""
extra_dependencies:
description: 'extra dependencies'
required: false
type: string
default: "\"numpy<2\""
cu:
description: 'cuda version'
required: true
type: string
default: "124"
default: "121"
python_minor:
description: 'python minor version'
@@ -29,7 +24,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "8"
# push:
# branches:
# - master
@@ -56,7 +51,7 @@ jobs:
..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
pause" > update_comfyui_and_python_dependencies.bat
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
python -m pip install --no-cache-dir ./temp_wheel_dir/*
echo installed basic
ls -lah temp_wheel_dir

View File

@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "4"
default: "3"
# push:
# branches:
# - master
@@ -49,13 +49,13 @@ jobs:
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
./python.exe get-pip.py
python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
ls ../temp_wheel_dir
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable_nightly_pytorch

View File

@@ -7,7 +7,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "124"
default: "121"
python_minor:
description: 'python minor version'
@@ -19,7 +19,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "9"
default: "8"
# push:
# branches:
# - master
@@ -66,7 +66,7 @@ jobs:
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
cd ..
git clone --depth 1 https://github.com/comfyanonymous/taesd
git clone https://github.com/comfyanonymous/taesd
cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
mkdir ComfyUI_windows_portable

View File

@@ -12,7 +12,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -34,7 +33,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
@@ -79,7 +77,7 @@ Ctrl can also be replaced with Cmd instead for macOS users
There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu121_or_cpu.7z)
Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
@@ -165,6 +163,20 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
```source path_to_other_sd_gui/venv/bin/activate```
or on Windows:
With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```
With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```
And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
# Running
```python main.py```

View File

@@ -5,7 +5,7 @@
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1280,
"initializer_factor": 1.0,

View File

@@ -1,6 +1,5 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class CLIPAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
@@ -72,13 +71,13 @@ class CLIPEncoder(torch.nn.Module):
return x, intermediate
class CLIPEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
super().__init__()
self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, dtype=torch.float32):
return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
def forward(self, input_tokens):
return self.token_embedding(input_tokens) + self.position_embedding.weight
class CLIPTextModel_(torch.nn.Module):
@@ -88,15 +87,14 @@ class CLIPTextModel_(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
x = self.embeddings(input_tokens, dtype=dtype)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
x = self.embeddings(input_tokens)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -113,7 +111,7 @@ class CLIPTextModel_(torch.nn.Module):
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
@@ -155,11 +153,11 @@ class CLIPVisionEmbeddings(torch.nn.Module):
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
class CLIPVision(torch.nn.Module):
@@ -171,7 +169,7 @@ class CLIPVision(torch.nn.Module):
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)

View File

@@ -139,32 +139,3 @@ class SD3(LatentFormat):
class StableAudio1(LatentFormat):
latent_channels = 64
class Flux(SD3):
def __init__(self):
self.scale_factor = 0.3611
self.shift_factor = 0.1159
self.latent_rgb_factors =[
[-0.0404, 0.0159, 0.0609],
[ 0.0043, 0.0298, 0.0850],
[ 0.0328, -0.0749, -0.0503],
[-0.0245, 0.0085, 0.0549],
[ 0.0966, 0.0894, 0.0530],
[ 0.0035, 0.0399, 0.0123],
[ 0.0583, 0.1184, 0.1262],
[-0.0191, -0.0206, -0.0306],
[-0.0324, 0.0055, 0.1001],
[ 0.0955, 0.0659, -0.0545],
[-0.0504, 0.0231, -0.0013],
[ 0.0500, -0.0008, -0.0088],
[ 0.0982, 0.0941, 0.0976],
[-0.1233, -0.0280, -0.0897],
[-0.0005, -0.0530, -0.0020],
[-0.1273, -0.0932, -0.0680]
]
def process_in(self, latent):
return (latent - self.shift_factor) * self.scale_factor
def process_out(self, latent):
return (latent / self.scale_factor) + self.shift_factor

View File

@@ -9,7 +9,6 @@ from einops import rearrange
from torch import nn
from torch.nn import functional as F
import math
import comfy.ops
class FourierFeatures(nn.Module):
def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@@ -19,7 +18,7 @@ class FourierFeatures(nn.Module):
[out_features // 2, in_features], dtype=dtype, device=device))
def forward(self, input):
f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
return torch.cat([f.cos(), f.sin()], dim=-1)
# norms
@@ -39,9 +38,9 @@ class LayerNorm(nn.Module):
def forward(self, x):
beta = self.beta
if beta is not None:
beta = comfy.ops.cast_to_input(beta, x)
return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
if self.beta is not None:
beta = beta.to(dtype=x.dtype, device=x.device)
return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
class GLU(nn.Module):
def __init__(
@@ -124,9 +123,7 @@ class RotaryEmbedding(nn.Module):
scale_base = 512,
interpolation_factor = 1.,
base = 10000,
base_rescale_factor = 1.,
dtype=None,
device=None,
base_rescale_factor = 1.
):
super().__init__()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@@ -134,8 +131,8 @@ class RotaryEmbedding(nn.Module):
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
base *= base_rescale_factor ** (dim / (dim - 2))
# inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
assert interpolation_factor >= 1.
self.interpolation_factor = interpolation_factor
@@ -164,14 +161,14 @@ class RotaryEmbedding(nn.Module):
t = t / self.interpolation_factor
freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
freqs = torch.cat((freqs, freqs), dim = -1)
if self.scale is None:
return freqs, 1.
power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
return freqs, scale
@@ -571,7 +568,7 @@ class ContinuousTransformer(nn.Module):
self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
if rotary_pos_emb:
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
else:
self.rotary_pos_emb = None

View File

@@ -8,8 +8,6 @@ import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
import comfy.ldm.common_dit
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -408,7 +406,10 @@ class MMDiT(nn.Module):
def patchify(self, x):
B, C, H, W = x.size()
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
x = x.view(
B,
C,
@@ -426,7 +427,7 @@ class MMDiT(nn.Module):
max_dim = max(h, w)
cur_dim = self.h_max
pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
if max_dim > cur_dim:
pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
@@ -454,7 +455,7 @@ class MMDiT(nn.Module):
t = timestep
c = self.cond_seq_linear(c_seq) # B, T_c, D
c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
global_cond = self.t_embedder(t, x.dtype) # B, D

View File

@@ -19,7 +19,14 @@
import torch
import torch.nn as nn
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
class Linear(torch.nn.Linear):
def reset_parameters(self):
return None
class Conv2d(torch.nn.Conv2d):
def reset_parameters(self):
return None
class OptimizedAttention(nn.Module):
def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
@@ -71,13 +78,13 @@ class GlobalResponseNorm(nn.Module):
"from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
def __init__(self, dim, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
class ResBlock(nn.Module):

View File

@@ -1,8 +0,0 @@
import torch
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)

View File

@@ -1,256 +0,0 @@
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.silu = nn.SiLU()
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

View File

@@ -1,35 +0,0 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True)
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
device = torch.device("cpu")
else:
device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -1,142 +0,0 @@
#Original code can be found on: https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@@ -1,219 +0,0 @@
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional
from comfy.ldm.modules.attention import optimized_attention
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: Optional[torch.Tensor],
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
if xk is not None:
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
class CrossAttention(nn.Module):
"""
Use QK Normalization.
"""
def __init__(self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
attn_drop=0.0,
proj_drop=0.0,
attn_precision=None,
device=None,
dtype=None,
operations=None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.attn_precision = attn_precision
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
self.head_dim = self.qdim // num_heads
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, y, freqs_cis_img=None):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
y: torch.Tensor
(batch, seqlen2, hidden_dim2)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // 2), RoPE for image
"""
b, s1, c = x.shape # [b, s1, D]
_, s2, c = y.shape # [b, s2, 1024]
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
k, v = kv.unbind(dim=2) # [b, s, h, d]
q = self.q_norm(q)
k = self.k_norm(k)
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
q = qq
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
v = v.transpose(-2, -3).contiguous()
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
out = self.out_proj(context) # context.reshape - B, L1, -1
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
class Attention(nn.Module):
"""
We rename some layer names to align with flash attention
"""
def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
self.head_dim = self.dim // num_heads
# This assertion is aligned with flash attention
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
# qkv --> Wqkv
self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, freqs_cis_img=None):
B, N, C = x.shape
qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
q, k, v = qkv.unbind(0) # [b, h, s, d]
q = self.q_norm(q) # [b, h, s, d]
k = self.k_norm(k) # [b, h, s, d]
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
assert qq.shape == q.shape and kk.shape == k.shape, \
f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
q, k = qq, kk
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
x = self.out_proj(x)
x = self.proj_drop(x)
out_tuple = (x,)
return out_tuple

View File

@@ -1,405 +0,0 @@
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint
from .attn_layers import Attention, CrossAttention
from .poolers import AttentionPool
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
def calc_rope(x, patch_size, head_size):
th = (x.shape[2] + (patch_size // 2)) // patch_size
tw = (x.shape[3] + (patch_size // 2)) // patch_size
base_size = 512 // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class HunYuanDiTBlock(nn.Module):
"""
A HunYuanDiT block with `add` conditioning.
"""
def __init__(self,
hidden_size,
c_emb_size,
num_heads,
mlp_ratio=4.0,
text_states_dim=1024,
qk_norm=False,
norm_type="layer",
skip=False,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
use_ele_affine = True
if norm_type == "layer":
norm_layer = operations.LayerNorm
elif norm_type == "rms":
norm_layer = RMSNorm
else:
raise ValueError(f"Unknown norm_type: {norm_type}")
# ========================= Self-Attention =========================
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
# ========================= FFN =========================
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)
# ========================= Add =========================
# Simply use add like SDXL.
self.default_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
)
# ========================= Cross-Attention =========================
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
# ========================= Skip Connection =========================
if skip:
self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
else:
self.skip_linear = None
self.gradient_checkpointing = False
def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
# Self-Attention
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
attn_inputs = (
self.norm1(x) + shift_msa, freq_cis_img,
)
x = x + self.attn1(*attn_inputs)[0]
# Cross-Attention
cross_inputs = (
self.norm3(x), text_states, freq_cis_img
)
x = x + self.attn2(*cross_inputs)[0]
# FFN Layer
mlp_inputs = self.norm2(x)
x = x + self.mlp(mlp_inputs)
return x
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
if self.gradient_checkpointing and self.training:
return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
return self._forward(x, c, text_states, freq_cis_img, skip)
class FinalLayer(nn.Module):
"""
The final layer of HunYuanDiT.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class HunYuanDiT(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
#@register_to_config
def __init__(self,
input_size: tuple = 32,
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1152,
depth: int = 28,
num_heads: int = 16,
mlp_ratio: float = 4.0,
text_states_dim = 1024,
text_states_dim_t5 = 2048,
text_len = 77,
text_len_t5 = 256,
qk_norm = True,# See http://arxiv.org/abs/2302.05442 for details.
size_cond = False,
use_style_cond = False,
learn_sigma = True,
norm = "layer",
log_fn: callable = print,
attn_precision=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = text_states_dim
self.text_states_dim_t5 = text_states_dim_t5
self.text_len = text_len
self.text_len_t5 = text_len_t5
self.size_cond = size_cond
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
#import pdb
#pdb.set_trace()
self.mlp_t5 = nn.Sequential(
operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
)
# learnable replace
self.text_embedding_padding = nn.Parameter(
torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if self.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
self.extra_embedder = nn.Sequential(
operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
qk_norm=qk_norm,
norm_type=self.norm,
skip=layer > depth // 2,
attn_precision=attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
for layer in range(depth)
])
self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
self.unpatchify_channels = self.out_channels
def forward(self,
x,
t,
context,#encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
return_dict=False,
control=None,
transformer_options=None,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,2051024
# clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
_, _, oh, ow = x.shape
th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size
# Get image RoPE embedding according to `reso`lution.
freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads) #(cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(t, dtype=x.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens if applicable
if self.size_cond:
image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype) # [B * 6, 256]
image_meta_size = image_meta_size.view(-1, 6 * 256)
extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
if self.use_style_cond:
if style is None:
style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
style_embedding = self.style_embedder(style, out_dtype=x.dtype)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
controls = None
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)
if controls is not None and len(controls) != 0:
raise ValueError("The number of controls is not equal to the number of skip connections.")
# ========================= Final layer =========================
x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
if return_dict:
return {'x': x}
if self.learn_sigma:
return x[:,:self.out_channels // 2,:oh,:ow]
return x[:,:,:oh,:ow]
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
p = self.x_embedder.patch_size[0]
# h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs

View File

@@ -1,37 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
class AttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
self.num_heads = num_heads
self.embed_dim = embed_dim
def forward(self, x):
x = x[:,:self.positional_embedding.shape[0] - 1]
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
q = self.q_proj(x[:1])
k = self.k_proj(x)
v = self.v_proj(x)
batch_size = q.shape[1]
head_dim = self.embed_dim // self.num_heads
q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)
attn_output = self.c_proj(attn_output)
return attn_output.squeeze(0)

View File

@@ -1,224 +0,0 @@
import torch
import numpy as np
from typing import Union
def _to_tuple(x):
if isinstance(x, int):
return x, x
else:
return x
def get_fill_resize_and_crop(src, tgt):
th, tw = _to_tuple(tgt)
h, w = _to_tuple(src)
tr = th / tw # base resolution
r = h / w # target resolution
# resize
if r > tr:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h)) # resize the target resolution down based on the base resolution
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
def get_meshgrid(start, *args):
if len(args) == 0:
# start is grid_size
num = _to_tuple(start)
start = (0, 0)
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = (stop[0] - start[0], stop[1] - start[1])
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = _to_tuple(args[1])
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
return grid
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid = get_meshgrid(start, *args) # [2, H, w]
# grid_h = np.arange(grid_size, dtype=np.float32)
# grid_w = np.arange(grid_size, dtype=np.float32)
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (W,H)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
"""
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
Parameters
----------
embed_dim: int
embedding dimension size
start: int or tuple of int
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
use_real: bool
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns
-------
pos_embed: torch.Tensor
[HW, D/2]
"""
grid = get_meshgrid(start, *args) # [2, H, w]
grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
def calc_sizes(rope_img, patch_size, th, tw):
if rope_img == 'extend':
# Expansion mode
sub_args = [(th, tw)]
elif rope_img.startswith('base'):
# Based on the specified dimensions, other dimensions are obtained through interpolation.
base_size = int(rope_img[4:]) // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
else:
raise ValueError(f"Unknown rope_img: {rope_img}")
return sub_args
def init_image_posemb(rope_img,
resolutions,
patch_size,
hidden_size,
num_heads,
log_fn,
rope_real=True,
):
freqs_cis_img = {}
for reso in resolutions:
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
sub_args = calc_sizes(rope_img, patch_size, th, tw)
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
return freqs_cis_img

View File

@@ -8,8 +8,6 @@ import torch.nn as nn
from .. import attention
from einops import rearrange, repeat
from .util import timestep_embedding
import comfy.ops
import comfy.ldm.common_dit
def default(x, y):
if x is not None:
@@ -71,14 +69,12 @@ class PatchEmbed(nn.Module):
bias: bool = True,
strict_img_size: bool = True,
dynamic_img_pad: bool = True,
padding_mode='circular',
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.patch_size = (patch_size, patch_size)
self.padding_mode = padding_mode
if img_size is not None:
self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
@@ -112,7 +108,9 @@ class PatchEmbed(nn.Module):
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# )
if self.dynamic_img_pad:
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
@@ -926,7 +924,7 @@ class MMDiT(nn.Module):
context = self.context_processor(context)
hw = x.shape[-2:]
x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x)
x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D)

View File

@@ -809,7 +809,7 @@ class UNetModel(nn.Module):
self.out = nn.Sequential(
operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(),
operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device),
zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(

View File

@@ -282,18 +282,4 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to
if isinstance(model, comfy.model_base.HunyuanDiT):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to
return key_map

View File

@@ -7,11 +7,8 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.aura.mmdit
import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.model_management
import comfy.conds
import comfy.ops
@@ -28,7 +25,6 @@ class ModelType(Enum):
EDM = 5
FLOW = 6
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
@@ -56,9 +52,6 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
c = V_PREDICTION
s = ModelSamplingContinuousV
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
class ModelSampling(s, c):
pass
@@ -74,7 +67,6 @@ class BaseModel(torch.nn.Module):
self.latent_format = model_config.latent_format
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
if not unet_config.get("disable_unet_model_creation", False):
if self.manual_cast_dtype is not None:
@@ -85,7 +77,6 @@ class BaseModel(torch.nn.Module):
if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type)
@@ -96,7 +87,6 @@ class BaseModel(torch.nn.Module):
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
@@ -255,11 +245,11 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:])
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:])
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
@@ -357,7 +347,6 @@ class SDXL(BaseModel):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
@@ -598,6 +587,17 @@ class SD3(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
else:
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)
class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -648,50 +648,3 @@ class StableAudio1(BaseModel):
for l in s:
sd["{}{}".format(k, l)] = s[l]
return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['text_embedding_mask'] = comfy.conds.CONDRegular(attention_mask)
conditioning_mt5xl = kwargs.get("conditioning_mt5xl", None)
if conditioning_mt5xl is not None:
out['encoder_hidden_states_t5'] = comfy.conds.CONDRegular(conditioning_mt5xl)
attention_mask_mt5xl = kwargs.get("attention_mask_mt5xl", None)
if attention_mask_mt5xl is not None:
out['text_embedding_mask_t5'] = comfy.conds.CONDRegular(attention_mask_mt5xl)
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out

View File

@@ -115,36 +115,6 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["n_layers"] = double_layers + single_layers
return unet_config
if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
unet_config["mlp_ratio"] = 4.3637
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
unet_config["size_cond"] = True
unet_config["use_style_cond"] = True
unet_config["image_model"] = "hydit1"
return unet_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = 19
dit_config["depth_single_blocks"] = 38
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@@ -274,7 +274,7 @@ class LoadedModel:
return self.model.model_size()
def model_memory_required(self, device):
if device == self.model.current_loaded_device():
if device == self.model.current_device:
return 0
else:
return self.model_memory()
@@ -318,7 +318,7 @@ class LoadedModel:
return self.model is other.model
def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
return (1024 * 1024 * 1024)
def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = []
@@ -352,7 +352,6 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model = []
can_unload = []
unloaded_models = []
for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i]
@@ -370,7 +369,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
current_loaded_models.pop(i)
if len(unloaded_model) > 0:
soft_empty_cache()
@@ -379,17 +378,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required)
models = set(models)
@@ -423,13 +417,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
for d in devs:
if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded)
free_mem = get_free_memory(d)
if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load)))
if len(models_to_load) == 0:
return
return
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
@@ -458,8 +446,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
@@ -535,9 +523,6 @@ def unet_inital_load_device(parameters, dtype):
else:
return cpu_dev
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet:
return torch.bfloat16
@@ -547,21 +532,6 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
fp8_dtype = None
try:
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if dtype in supported_dtypes:
fp8_dtype = dtype
break
except:
pass
if fp8_dtype is not None:
free_model_memory = maximum_vram_for_weights(device)
if model_params * 2 > free_model_memory:
return fp8_dtype
if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
@@ -679,29 +649,18 @@ def supports_cast(device, dtype): #TODO
return True
if dtype == torch.float16:
return True
if is_device_mps(device):
return False
if directml_enabled: #TODO: test this
return False
if dtype == torch.bfloat16:
return True
if is_device_mps(device):
return False
if dtype == torch.float8_e4m3fn:
return True
if dtype == torch.float8_e5m2:
return True
return False
def pick_weight_dtype(dtype, fallback_dtype, device=None):
if dtype is None:
dtype = fallback_dtype
elif dtype_size(dtype) > dtype_size(fallback_dtype):
dtype = fallback_dtype
if not supports_cast(device, dtype):
dtype = fallback_dtype
return dtype
def device_supports_non_blocking(device):
if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking
@@ -897,7 +856,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
fp16_works = True
if fp16_works or manual_cast:
free_model_memory = maximum_vram_for_weights(device)
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@@ -917,9 +876,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False
if device is not None:
if device is not None: #TODO not sure about mps bf16 support
if is_device_mps(device):
return True
return False
if FORCE_FP32:
return False
@@ -927,23 +886,23 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if directml_enabled:
return False
if mps_mode():
return True
if cpu_mode():
if cpu_mode() or mps_mode():
return False
if is_intel_xpu():
return True
props = torch.cuda.get_device_properties("cuda")
if device is None:
device = torch.device("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast:
free_model_memory = maximum_vram_for_weights(device)
free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True

View File

@@ -64,15 +64,9 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
return model_options
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size
self.model = model
if not hasattr(self.model, 'device'):
logging.info("Model doesn't have a device attribute.")
self.model.device = offload_device
elif self.model.device is None:
self.model.device = offload_device
self.patches = {}
self.backup = {}
self.object_patches = {}
@@ -81,6 +75,11 @@ class ModelPatcher:
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False
self.lowvram_patch_counter = 0
@@ -93,7 +92,7 @@ class ModelPatcher:
return self.size
def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {}
for k in self.patches:
n.patches[k] = self.patches[k][:]
@@ -303,7 +302,7 @@ class ModelPatcher:
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.current_device = device_to
return self.model
@@ -356,7 +355,6 @@ class ModelPatcher:
self.model_lowvram = True
self.lowvram_patch_counter = patch_counter
self.model.device = device_to
return self.model
def calculate_weight(self, patches, weight, key):
@@ -553,13 +551,10 @@ class ModelPatcher:
if device_to is not None:
self.model.to(device_to)
self.model.device = device_to
self.current_device = device_to
keys = list(self.object_patches_backup.keys())
for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup.clear()
def current_loaded_device(self):
return self.model.device

View File

@@ -272,43 +272,3 @@ class StableCascadeSampling(ModelSamplingDiscrete):
percent = 1.0 - percent
return self.sigma(torch.tensor(percent))
def flux_time_shift(mu: float, sigma: float, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
class ModelSamplingFlux(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.15))
def set_parameters(self, shift=1.15, timesteps=10000):
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma
def sigma(self, timestep):
return flux_time_shift(self.shift, 1.0, timestep)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent

View File

@@ -19,27 +19,14 @@
import torch
import comfy.model_management
def cast_to(weight, dtype=None, device=None, non_blocking=False):
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
def cast_to_input(weight, input, non_blocking=False):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
def cast_bias_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
def cast_bias_weight(s, input):
bias = None
non_blocking = comfy.model_management.device_should_use_non_blocking(device)
non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
if s.bias is not None:
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.bias_function is not None:
bias = s.bias_function(bias)
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.weight_function is not None:
weight = s.weight_function(weight)
return weight, bias
@@ -181,26 +168,6 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input, out_dtype=None):
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
if "out_dtype" in kwargs:
kwargs.pop("out_dtype")
return super().forward(*args, **kwargs)
@classmethod
def conv_nd(s, dims, *args, **kwargs):
if dims == 2:
@@ -235,6 +202,3 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True

View File

@@ -61,9 +61,7 @@ def prepare_sampling(model, noise_shape, conds):
device = model.load_device
real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
real_model = model.model
return real_model, conds, models

View File

@@ -171,7 +171,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount
break

View File

@@ -17,13 +17,11 @@ from . import diffusers_convert
from . import model_detection
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.model_patcher
import comfy.lora
@@ -387,8 +385,6 @@ class CLIPType(Enum):
STABLE_CASCADE = 2
SD3 = 3
STABLE_AUDIO = 4
HUNYUAN_DIT = 5
FLUX = 6
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = []
@@ -416,8 +412,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype
@@ -437,18 +433,6 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = comfy.text_encoders.hydit.HyditModel
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
elif clip_type == CLIPType.FLUX:
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
dtype_t5 = None
if weight is not None:
dtype_t5 = weight.dtype
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -510,18 +494,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
unet_weight_dtype.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@@ -564,7 +543,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
logging.debug("left over keys: {}".format(left_over))
if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher)
@@ -572,7 +551,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
def load_unet_state_dict(sd): #load unet in diffusers or regular format
#Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -581,6 +560,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
sd = temp_sd
parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "")
@@ -607,11 +587,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device()
if dtype is None:
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
else:
unet_dtype = dtype
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "")
@@ -622,9 +598,9 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path, dtype=None):
def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd, dtype=dtype)
model = load_unet_state_dict(sd)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))

View File

@@ -94,8 +94,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
self.operations = comfy.ops.manual_cast
self.transformer = model_class(config, dtype, device, self.operations)
self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
self.num_layers = self.transformer.num_layers
self.max_length = max_length
@@ -141,13 +140,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0]
next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
embedding_weights = []
for x in tokens:
tokens_temp = []
for y in x:
if isinstance(y, numbers.Integral):
if y == token_dict_size: #EOS token
y = -1
tokens_temp += [int(y)]
else:
if y.shape[0] == current_embeds.weight.shape[1]:
@@ -162,11 +163,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
n = token_dict_size
if len(embedding_weights) > 0:
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight
new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
for x in embedding_weights:
new_embedding.weight[n] = x
n += 1
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
self.transformer.set_input_embeddings(new_embedding)
processed_tokens = []
@@ -195,7 +197,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks:
attention_mask_model = attention_mask
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last":

View File

@@ -6,7 +6,7 @@
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"eos_token_id": 2,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,

View File

@@ -5,7 +5,7 @@
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_size": 1024,
"initializer_factor": 1.0,

View File

@@ -3,13 +3,11 @@ from . import model_base
from . import utils
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
from . import supported_models_base
from . import latent_formats
@@ -31,7 +29,6 @@ class SD15(supported_models_base.BASE):
}
latent_format = latent_formats.SD15
memory_usage_factor = 1.0
def process_clip_state_dict(self, state_dict):
k = list(state_dict.keys())
@@ -78,7 +75,6 @@ class SD20(supported_models_base.BASE):
}
latent_format = latent_formats.SD15
memory_usage_factor = 1.0
def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -104,7 +100,7 @@ class SD20(supported_models_base.BASE):
return state_dict
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel)
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
class SD21UnclipL(SD20):
unet_config = {
@@ -142,7 +138,6 @@ class SDXLRefiner(supported_models_base.BASE):
}
latent_format = latent_formats.SDXL
memory_usage_factor = 1.0
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXLRefiner(self, device=device)
@@ -181,8 +176,6 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL
memory_usage_factor = 0.7
def model_type(self, state_dict, prefix=""):
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5()
@@ -510,9 +503,6 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.SD3
memory_usage_factor = 1.2
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
@@ -590,90 +580,6 @@ class AuraFlow(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
class HunyuanDiT(supported_models_base.BASE):
unet_config = {
"image_model": "hydit",
}
unet_extra_config = {
"attn_precision": torch.float32,
}
sampling_settings = {
"linear_start": 0.00085,
"linear_end": 0.018,
}
latent_format = latent_formats.SDXL
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanDiT(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
class HunyuanDiT1(HunyuanDiT):
unet_config = {
"image_model": "hydit1",
}
unet_extra_config = {}
sampling_settings = {
"linear_start" : 0.00085,
"linear_end" : 0.03,
}
class Flux(supported_models_base.BASE):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
}
sampling_settings = {
}
unet_extra_config = {}
latent_format = latent_formats.Flux
memory_usage_factor = 2.8
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
class FluxSchnell(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": False,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
models += [SVD_img2vid]

View File

@@ -27,8 +27,6 @@ class BASE:
text_encoder_key_prefix = ["cond_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
memory_usage_factor = 2.0
manual_cast_dtype = None
@classmethod

View File

@@ -1,140 +0,0 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class BertAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.query(x)
k = self.key(x)
v = self.value(x)
out = optimized_attention(q, k, v, self.heads, mask)
return out
class BertOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
# self.dropout = nn.Dropout(0.0)
def forward(self, x, y):
x = self.dense(x)
# hidden_states = self.dropout(hidden_states)
x = self.LayerNorm(x + y)
return x
class BertAttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.self = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
y = self.self(x, mask, optimized_attention)
return self.output(y, x)
class BertIntermediate(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)
def forward(self, x):
x = self.dense(x)
return torch.nn.functional.gelu(x)
class BertBlock(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
x = self.attention(x, mask, optimized_attention)
y = self.intermediate(x)
return self.output(y, x)
class BertEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class BertEmbeddings(torch.nn.Module):
def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.word_embeddings = operations.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, input_tokens, token_type_ids=None, dtype=None):
x = self.word_embeddings(input_tokens, out_dtype=dtype)
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
if token_type_ids is not None:
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
else:
x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
x = self.LayerNorm(x)
return x
class BertModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
embed_dim = config_dict["hidden_size"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embeddings(input_tokens, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
x, i = self.encoder(x, mask, intermediate_output)
return x, i
class BertModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.bert = BertModel_(config_dict, dtype, device, operations)
self.num_layers = config_dict["num_hidden_layers"]
def get_input_embeddings(self):
return self.bert.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.bert.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.bert(*args, **kwargs)

View File

@@ -1,71 +0,0 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.model_management
from transformers import T5TokenizerFast
import torch
import os
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_l.untokenize(token_weight_pair)
def state_dict(self):
return {}
class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.t5xxl.set_clip_options(options)
def reset_clip_options(self):
self.clip_l.reset_clip_options()
self.t5xxl.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return t5_out, l_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
return FluxClipModel_

View File

@@ -1,79 +0,0 @@
from comfy import sd1_clip
from transformers import BertTokenizer
from .spiece_tokenizer import SPieceTokenizer
from .bert import BertModel
import comfy.text_encoders.t5
import os
import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class HyditTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.hydit_clip.untokenize(token_weight_pair)
def state_dict(self):
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.hydit_clip = HyditBertModel(dtype=dtype)
self.mt5xl = MT5XLModel(dtype=dtype)
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def encode_token_weights(self, token_weight_pairs):
hydit_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
return hydit_out[0], hydit_out[1], {"attention_mask": hydit_out[2]["attention_mask"], "conditioning_mt5xl": mt5_out[0], "attention_mask_mt5xl": mt5_out[2]["attention_mask"]}
def load_sd(self, sd):
if "bert.encoder.layer.0.attention.self.query.weight" in sd:
return self.hydit_clip.load_sd(sd)
else:
return self.mt5xl.load_sd(sd)
def set_clip_options(self, options):
self.hydit_clip.set_clip_options(options)
self.mt5xl.set_clip_options(options)
def reset_clip_options(self):
self.hydit_clip.reset_clip_options()
self.mt5xl.reset_clip_options()

View File

@@ -1,35 +0,0 @@
{
"_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": null,
"directionality": "bidi",
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"output_past": true,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"position_embedding_type": "absolute",
"torch_dtype": "float32",
"transformers_version": "4.22.1",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 47020
}

View File

@@ -1,7 +0,0 @@
{
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"unk_token": "[UNK]"
}

View File

@@ -1,16 +0,0 @@
{
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": true,
"mask_token": "[MASK]",
"name_or_path": "hfl/chinese-roberta-wwm-ext",
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]",
"model_max_length": 77
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,22 +0,0 @@
{
"d_ff": 5120,
"d_kv": 64,
"d_model": 2048,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "mt5",
"num_decoder_layers": 24,
"num_heads": 32,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 250112
}

View File

@@ -54,7 +54,14 @@ class SD3ClipModel(torch.nn.Module):
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
if dtype_t5 is None:
dtype_t5 = dtype
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
dtype_t5 = dtype
if not comfy.model_management.supports_cast(device, dtype_t5):
dtype_t5 = dtype
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes.add(dtype_t5)
else:
@@ -81,7 +88,7 @@ class SD3ClipModel(torch.nn.Module):
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
token_weight_pars_t5 = token_weight_pairs["t5xxl"]
lg_out = None
pooled = None
out = None
@@ -108,7 +115,7 @@ class SD3ClipModel(torch.nn.Module):
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2)
else:

View File

@@ -27,6 +27,3 @@ class SPieceTokenizer:
def __call__(self, string):
out = self.tokenizer.encode(string)
return {"input_ids": out}
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))

View File

@@ -1,7 +1,6 @@
import torch
import math
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
@@ -12,7 +11,7 @@ class T5LayerNorm(torch.nn.Module):
def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
return comfy.ops.cast_to_input(self.weight, x) * x
return self.weight.to(device=x.device, dtype=x.dtype) * x
activations = {
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
@@ -83,7 +82,7 @@ class T5Attention(torch.nn.Module):
if relative_attention_bias:
self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128
self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)
self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
@staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -133,7 +132,7 @@ class T5Attention(torch.nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets
def compute_bias(self, query_length, key_length, device, dtype):
def compute_bias(self, query_length, key_length, device):
"""Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
@@ -144,7 +143,7 @@ class T5Attention(torch.nn.Module):
num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance,
)
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads)
values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values
@@ -153,7 +152,7 @@ class T5Attention(torch.nn.Module):
k = self.k(x)
v = self.v(x)
if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
if past_bias is not None:
if mask is not None:
@@ -200,7 +199,7 @@ class T5Stack(torch.nn.Module):
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -226,7 +225,7 @@ class T5(torch.nn.Module):
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.dtype = dtype
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)
self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
def get_input_embeddings(self):
return self.shared
@@ -235,7 +234,5 @@ class T5(torch.nn.Module):
self.shared = embeddings
def forward(self, input_ids, *args, **kwargs):
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
x = torch.nan_to_num(x) #Fix for fp8 T5 base
x = self.shared(input_ids)
return self.encoder(x, *args, **kwargs)

View File

@@ -11,7 +11,7 @@ import itertools
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device=device.type)
else:
if safe_load:
@@ -40,22 +40,9 @@ def calculate_parameters(sd, prefix=""):
params = 0
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
params += w.nelement()
params += sd[k].nelement()
return params
def weight_dtype(sd, prefix=""):
dtypes = {}
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
if len(dtypes) == 0:
return None
return max(dtypes, key=dtypes.get)
def state_dict_key_replace(state_dict, keys_to_replace):
for x in keys_to_replace:
if x in state_dict:
@@ -415,59 +402,6 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def flux_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("depth", 0)
n_single_layers = mmdit_config.get("depth_single_blocks", 0)
hidden_size = mmdit_config.get("hidden_size", 0)
key_map = {}
for index in range(n_double_layers):
prefix_from = "transformer_blocks.{}".format(index)
prefix_to = "{}double_blocks.{}".format(output_prefix, index)
for end in ("weight", "bias"):
k = "{}.attn.".format(prefix_from)
qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
}
for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
for index in range(n_single_layers):
prefix_from = "single_transformer_blocks.{}".format(index)
prefix_to = "{}single_blocks.{}".format(output_prefix, index)
for end in ("weight", "bias"):
k = "{}.attn.".format(prefix_from)
qkv = "{}.linear1.{}".format(prefix_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
block_map = {#TODO
}
for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
MAP_BASIC = { #TODO
}
for k in MAP_BASIC:
if len(k) > 2:
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
else:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size)

View File

@@ -7,7 +7,6 @@ import io
import json
import struct
import random
import hashlib
from comfy.cli_args import args
class EmptyLatentAudio:

View File

@@ -295,23 +295,6 @@ class SamplerDPMPP_SDE:
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, )
class SamplerDPMPP_2S_Ancestral:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise):
sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerEulerAncestral:
@classmethod
def INPUT_TYPES(s):
@@ -683,7 +666,6 @@ NODE_CLASS_MAPPINGS = {
"SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
"SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise,
@@ -700,4 +682,4 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++",
}
}

View File

@@ -1,47 +0,0 @@
import node_helpers
class CLIPTextEncodeFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning/flux"
def encode(self, clip, clip_l, t5xxl, guidance):
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
output["guidance"] = guidance
return ([[cond, output]], )
class FluxGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, guidance):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
}

View File

@@ -1,24 +0,0 @@
class CLIPTextEncodeHunyuanDiT:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, bert, mt5xl):
tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
return ([[cond, output]], )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}

View File

@@ -2,7 +2,6 @@ import folder_paths
import comfy.sd
import comfy.model_sampling
import comfy.latent_formats
import nodes
import torch
class LCM(comfy.model_sampling.EPS):
@@ -171,42 +170,6 @@ class ModelSamplingAuraFlow(ModelSamplingSD3):
def patch_aura(self, model, shift):
return self.patch(model, shift, multiplier=1.0)
class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, width, height):
m = model.clone()
x1 = 256
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class ModelSamplingContinuousEDM:
@classmethod
def INPUT_TYPES(s):
@@ -321,6 +284,5 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingStableCascade": ModelSamplingStableCascade,
"ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"RescaleCFG": RescaleCFG,
}

View File

@@ -264,7 +264,6 @@ class CLIPSave:
metadata = {}
if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info
if extra_pnginfo is not None:
for x in extra_pnginfo:

View File

@@ -75,36 +75,9 @@ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict}
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeFlux1": ModelMergeFlux1,
}

View File

@@ -12,7 +12,7 @@ class PerturbedAttentionGuidance:
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
}
}

View File

@@ -96,7 +96,7 @@ class SelfAttentionGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}}
RETURN_TYPES = ("MODEL",)

View File

@@ -27,8 +27,8 @@ class EmptySD3LatentImage:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"

View File

@@ -188,11 +188,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
"current_inputs": input_data_formatted,
"current_outputs": output_data_formatted
}
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (False, error_details, ex)
executed.add(unique_id)

View File

@@ -3,7 +3,7 @@ import time
import logging
from typing import Set, List, Dict, Tuple
supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft'])
supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
SupportedFileExtensionsType = Set[str]
ScanPathType = List[str]

View File

@@ -25,7 +25,7 @@ def pillow(fn, arg):
finally:
if prev_value is not None:
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
return x
return x
def hasher():
hashfuncs = {

View File

@@ -818,22 +818,15 @@ class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype):
dtype = None
if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2
def load_unet(self, unet_name):
unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path, dtype=dtype)
model = comfy.sd.load_unet(unet_path)
return (model,)
class CLIPLoader:
@@ -866,7 +859,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
"clip_name2": (folder_paths.get_filename_list("clip"), ),
"type": (["sdxl", "sd3", "flux"], ),
"type": (["sdxl", "sd3"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@@ -880,8 +873,6 @@ class DualCLIPLoader:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
@@ -1852,7 +1843,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision",
"UpscaleModelLoader": "Load Upscale Model",
"UNETLoader": "Load Diffusion Model",
# Conditioning
"CLIPVisionEncode": "CLIP Vision Encode",
"StyleModelApply": "Apply Style Model",
@@ -2047,8 +2037,6 @@ def init_builtin_extra_nodes():
"nodes_sd3.py",
"nodes_gits.py",
"nodes_controlnet.py",
"nodes_hunyuan.py",
"nodes_flux.py",
]
import_failed = []

View File

@@ -1,7 +1,7 @@
torch
torch==2.3.1
torchsde
torchvision
torchaudio
torchvision==0.18.1
torchaudio==2.3.1
einops
transformers>=4.28.1
tokenizers>=0.13.3
@@ -13,6 +13,7 @@ Pillow
scipy
tqdm
psutil
numpy<2.0.0
#non essential dependencies:
kornia>=0.7.1

View File

@@ -127,11 +127,7 @@ class PromptServer():
@routes.get("/")
async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-cache'
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response
return web.FileResponse(os.path.join(self.web_root, "index.html"))
@routes.get("/embeddings")
def get_embeddings(self):

View File

@@ -105,16 +105,15 @@ export class ChangeTracker {
window.addEventListener(
"keydown",
(e) => {
const activeEl = document.activeElement;
requestAnimationFrame(async () => {
let bindInputEl;
let activeEl;
// If we are auto queue in change mode then we do want to trigger on inputs
if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") {
activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.["type"] === "textarea") {
// Ignore events on inputs, they have their native history
return;
}
bindInputEl = activeEl;
}
keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta";
@@ -124,7 +123,7 @@ export class ChangeTracker {
if (await changeTracker().undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done
if (ChangeTracker.bindInput(bindInputEl)) return;
if (ChangeTracker.bindInput(activeEl)) return;
changeTracker().checkState();
});
},

View File

@@ -190,10 +190,9 @@ function parseVorbisComment(dataView) {
const comment = getString(dataView, offset, commentLength);
offset += commentLength;
const ind = comment.indexOf('=')
const key = comment.substring(0, ind);
const [key, value] = comment.split('=');
comments[key] = comment.substring(ind+1);
comments[key] = value;
}
return comments;

View File

@@ -330,7 +330,6 @@
.comfyui-workflows-open .active {
font-weight: bold;
color: var(--primary-fg);
}
.comfyui-workflows-favorites:empty {
@@ -418,10 +417,6 @@
padding: 2px 4px;
}
.comfyui-workflows-tree-file.active .comfyui-workflows-file-action {
color: var(--primary-fg);
}
.lg ~ .comfyui-workflows-popup .comfyui-workflows-tree-file:not(:hover) .comfyui-workflows-file-action {
opacity: 0;
}