Compare commits

..

1 Commit

Author: Yoland Y
SHA1: 1d47ec38d8
Message: Set torch version to be 2.3.1 for v0.0.3
Date: 2024-07-26 18:54:29 -07:00
70 changed files with 231 additions and 49551 deletions

View File

@@ -62,15 +62,8 @@ except:
 print("checking out master branch")
 branch = repo.lookup_branch('master')
-if branch is None:
-    ref = repo.lookup_reference('refs/remotes/origin/master')
-    repo.checkout(ref)
-    branch = repo.lookup_branch('master')
-    if branch is None:
-        repo.create_branch('master', repo.get(ref.target))
-else:
-    ref = repo.lookup_reference(branch.name)
-    repo.checkout(ref)
+ref = repo.lookup_reference(branch.name)
+repo.checkout(ref)
 
 print("pulling latest changes")
 pull(repo)
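For orientation, here is a minimal sketch of the simplified checkout path on the "+" side of the hunk above, assuming `repo` is a `pygit2.Repository` (the surrounding workflows install pygit2, and `lookup_branch`, `lookup_reference` and `checkout` are pygit2 Repository methods). The repository path is illustrative only.

```python
import pygit2

# Illustrative path; the real update script operates on its own checkout.
repo = pygit2.Repository("ComfyUI")

# Simplified flow from the "+" side: assumes a local 'master' branch already exists.
branch = repo.lookup_branch('master')      # local branch object
ref = repo.lookup_reference(branch.name)   # resolves to 'refs/heads/master'
repo.checkout(ref)                         # switch the working tree to that ref
```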

View File

@@ -2,28 +2,9 @@
 name: "Release Stable Version"
 
 on:
-  workflow_dispatch:
-    inputs:
-      git_tag:
-        description: 'Git tag'
-        required: true
-        type: string
-      cu:
-        description: 'CUDA version'
-        required: true
-        type: string
-        default: "121"
-      python_minor:
-        description: 'Python minor version'
-        required: true
-        type: string
-        default: "11"
-      python_patch:
-        description: 'Python patch version'
-        required: true
-        type: string
-        default: "9"
+  push:
+    tags:
+      - 'v*'
 
 jobs:
   package_comfy_windows:
@@ -32,44 +13,69 @@ jobs:
       packages: "write"
       pull-requests: "read"
     runs-on: windows-latest
+    strategy:
+      matrix:
+        python_version: [3.11.8]
+        cuda_version: [121]
     steps:
+      - name: Calculate Minor Version
+        shell: bash
+        run: |
+          # Extract the minor version from the Python version
+          MINOR_VERSION=$(echo "${{ matrix.python_version }}" | cut -d'.' -f2)
+          echo "MINOR_VERSION=$MINOR_VERSION" >> $GITHUB_ENV
+      - name: Setup Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python_version }}
       - uses: actions/checkout@v4
         with:
-          ref: ${{ inputs.git_tag }}
           fetch-depth: 0
           persist-credentials: false
-      - uses: actions/cache/restore@v4
-        id: cache
-        with:
-          path: |
-            cu${{ inputs.cu }}_python_deps.tar
-            update_comfyui_and_python_dependencies.bat
-          key: ${{ runner.os }}-build-cu${{ inputs.cu }}-${{ inputs.python_minor }}
      - shell: bash
        run: |
-          mv cu${{ inputs.cu }}_python_deps.tar ../
+          echo "@echo off
+          call update_comfyui.bat nopause
+          echo -
+          echo This will try to update pytorch and all python dependencies.
+          echo -
+          echo If you just want to update normally, close this and run update_comfyui.bat instead.
+          echo -
+          pause
+          ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r ../ComfyUI/requirements.txt pygit2
+          pause" > update_comfyui_and_python_dependencies.bat
+          python -m pip wheel --no-cache-dir torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu${{ matrix.cuda_version }} -r requirements.txt pygit2 -w ./temp_wheel_dir
+          python -m pip install --no-cache-dir ./temp_wheel_dir/*
+          echo installed basic
+          ls -lah temp_wheel_dir
+          mv temp_wheel_dir cu${{ matrix.cuda_version }}_python_deps
+          mv cu${{ matrix.cuda_version }}_python_deps ../
          mv update_comfyui_and_python_dependencies.bat ../
          cd ..
-          tar xf cu${{ inputs.cu }}_python_deps.tar
          pwd
          ls
-      - shell: bash
-        run: |
-          cd ..
          cp -r ComfyUI ComfyUI_copy
-          curl https://www.python.org/ftp/python/3.${{ inputs.python_minor }}.${{ inputs.python_patch }}/python-3.${{ inputs.python_minor }}.${{ inputs.python_patch }}-embed-amd64.zip -o python_embeded.zip
+          curl https://www.python.org/ftp/python/${{ matrix.python_version }}/python-${{ matrix.python_version }}-embed-amd64.zip -o python_embeded.zip
          unzip python_embeded.zip -d python_embeded
          cd python_embeded
-          echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
+          echo ${{ env.MINOR_VERSION }}
+          echo 'import site' >> ./python3${{ env.MINOR_VERSION }}._pth
          curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
          ./python.exe get-pip.py
-          ./python.exe -s -m pip install ../cu${{ inputs.cu }}_python_deps/*
-          sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
-          cd ..
-          git clone --depth 1 https://github.com/comfyanonymous/taesd
+          ./python.exe --version
+          echo "Pip version:"
+          ./python.exe -m pip --version
+          set PATH=$PWD/Scripts:$PATH
+          echo $PATH
+          ./python.exe -s -m pip install ../cu${{ matrix.cuda_version }}_python_deps/*
+          sed -i '1i../ComfyUI' ./python3${{ env.MINOR_VERSION }}._pth
+          cd ..
+          git clone https://github.com/comfyanonymous/taesd
          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
 
          mkdir ComfyUI_windows_portable
@@ -98,7 +104,6 @@ jobs:
        with:
          repo_token: ${{ secrets.GITHUB_TOKEN }}
          file: ComfyUI_windows_portable_nvidia.7z
-          tag: ${{ inputs.git_tag }}
+          tag: ${{ github.ref }}
          overwrite: true
-          prerelease: true
-          make_latest: false

View File

@@ -32,7 +32,7 @@ jobs:
          node-version: lts/*
      - uses: actions/setup-python@v4
        with:
-          python-version: '3.8'
+          python-version: '3.10'
      - name: Install requirements
        run: |
          python -m pip install --upgrade pip

View File

@@ -8,16 +8,11 @@ on:
        required: false
        type: string
        default: ""
-      extra_dependencies:
-        description: 'extra dependencies'
-        required: false
-        type: string
-        default: "\"numpy<2\""
      cu:
        description: 'cuda version'
        required: true
        type: string
-        default: "124"
+        default: "121"
 
      python_minor:
        description: 'python minor version'
@@ -29,7 +24,7 @@ on:
        description: 'python patch version'
        required: true
        type: string
-        default: "9"
+        default: "8"
 #  push:
 #    branches:
 #      - master
@@ -56,7 +51,7 @@ jobs:
          ..\python_embeded\python.exe -s -m pip install --upgrade torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
          pause" > update_comfyui_and_python_dependencies.bat
 
-          python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} ${{ inputs.extra_dependencies }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
+          python -m pip wheel --no-cache-dir torch torchvision torchaudio ${{ inputs.xformers }} --extra-index-url https://download.pytorch.org/whl/cu${{ inputs.cu }} -r requirements.txt pygit2 -w ./temp_wheel_dir
          python -m pip install --no-cache-dir ./temp_wheel_dir/*
          echo installed basic
          ls -lah temp_wheel_dir

View File

@@ -19,7 +19,7 @@ on:
        description: 'python patch version'
        required: true
        type: string
-        default: "4"
+        default: "3"
 #  push:
 #    branches:
 #      - master
@@ -49,13 +49,13 @@ jobs:
          echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
          curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
          ./python.exe get-pip.py
-          python -m pip wheel torch torchvision torchaudio --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
+          python -m pip wheel torch torchvision torchaudio mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
          ls ../temp_wheel_dir
          ./python.exe -s -m pip install --pre ../temp_wheel_dir/*
          sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
          cd ..
-          git clone --depth 1 https://github.com/comfyanonymous/taesd
+          git clone https://github.com/comfyanonymous/taesd
          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
 
          mkdir ComfyUI_windows_portable_nightly_pytorch

View File

@@ -7,7 +7,7 @@ on:
        description: 'cuda version'
        required: true
        type: string
-        default: "124"
+        default: "121"
 
      python_minor:
        description: 'python minor version'
@@ -19,7 +19,7 @@ on:
        description: 'python patch version'
        required: true
        type: string
-        default: "9"
+        default: "8"
 #  push:
 #    branches:
 #      - master
@@ -66,7 +66,7 @@ jobs:
          sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
          cd ..
-          git clone --depth 1 https://github.com/comfyanonymous/taesd
+          git clone https://github.com/comfyanonymous/taesd
          cp taesd/*.pth ./ComfyUI_copy/models/vae_approx/
 
          mkdir ComfyUI_windows_portable

View File

@@ -12,7 +12,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
 - Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
-- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -34,7 +33,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
 - [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
 - [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
-- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.
@@ -79,7 +77,7 @@ Ctrl can also be replaced with Cmd instead for macOS users
 There is a portable standalone build for Windows that should work for running on Nvidia GPUs or for running on your CPU only on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases).
 
-### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_nvidia.7z)
+### [Direct link to download](https://github.com/comfyanonymous/ComfyUI/releases/download/latest/ComfyUI_windows_portable_nvidia_cu121_or_cpu.7z)
 
 Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you put your Stable Diffusion checkpoints/models (the huge ckpt/safetensors files) in: ComfyUI\models\checkpoints
@@ -165,6 +163,20 @@ You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS ve
 ```pip install torch-directml``` Then you can launch ComfyUI with: ```python main.py --directml```
 
+### I already have another UI for Stable Diffusion installed do I really have to install all of these dependencies?
+
+You don't. If you have another UI installed and working with its own python venv you can use that venv to run ComfyUI. You can open up your favorite terminal and activate it:
+
+```source path_to_other_sd_gui/venv/bin/activate```
+
+or on Windows:
+
+With Powershell: ```"path_to_other_sd_gui\venv\Scripts\Activate.ps1"```
+
+With cmd.exe: ```"path_to_other_sd_gui\venv\Scripts\activate.bat"```
+
+And then you can use that terminal to run ComfyUI without installing any dependencies. Note that the venv folder might be called something else depending on the SD UI.
+
 # Running
 
 ```python main.py```

View File

@@ -5,7 +5,7 @@
   "attention_dropout": 0.0,
   "bos_token_id": 0,
   "dropout": 0.0,
-  "eos_token_id": 49407,
+  "eos_token_id": 2,
   "hidden_act": "gelu",
   "hidden_size": 1280,
   "initializer_factor": 1.0,

View File

@@ -1,6 +1,5 @@
 import torch
 from comfy.ldm.modules.attention import optimized_attention_for_device
-import comfy.ops
 
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
@@ -72,13 +71,13 @@ class CLIPEncoder(torch.nn.Module):
         return x, intermediate
 
 class CLIPEmbeddings(torch.nn.Module):
-    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None, operations=None):
+    def __init__(self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None):
         super().__init__()
-        self.token_embedding = operations.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
-        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
 
-    def forward(self, input_tokens, dtype=torch.float32):
-        return self.token_embedding(input_tokens, out_dtype=dtype) + comfy.ops.cast_to(self.position_embedding.weight, dtype=dtype, device=input_tokens.device)
+    def forward(self, input_tokens):
+        return self.token_embedding(input_tokens) + self.position_embedding.weight
 
 class CLIPTextModel_(torch.nn.Module):
@@ -88,15 +87,14 @@ class CLIPTextModel_(torch.nn.Module):
         heads = config_dict["num_attention_heads"]
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
-        self.eos_token_id = config_dict["eos_token_id"]
 
         super().__init__()
-        self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device, operations=operations)
+        self.embeddings = CLIPEmbeddings(embed_dim, dtype=torch.float32, device=device)
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
 
-    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
-        x = self.embeddings(input_tokens, dtype=dtype)
+    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
+        x = self.embeddings(input_tokens)
         mask = None
         if attention_mask is not None:
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -113,7 +111,7 @@ class CLIPTextModel_(torch.nn.Module):
         if i is not None and final_layer_norm_intermediate:
             i = self.final_layer_norm(i)
 
-        pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
+        pooled_output = x[torch.arange(x.shape[0], device=x.device), input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1),]
         return x, i, pooled_output
 
 class CLIPTextModel(torch.nn.Module):
@@ -155,11 +153,11 @@ class CLIPVisionEmbeddings(torch.nn.Module):
         num_patches = (image_size // patch_size) ** 2
         num_positions = num_patches + 1
-        self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
+        self.position_embedding = torch.nn.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
 
     def forward(self, pixel_values):
         embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
-        return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
+        return torch.cat([self.class_embedding.to(embeds.device).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + self.position_embedding.weight.to(embeds.device)
 
 class CLIPVision(torch.nn.Module):
@@ -171,7 +169,7 @@ class CLIPVision(torch.nn.Module):
         intermediate_size = config_dict["intermediate_size"]
         intermediate_activation = config_dict["hidden_act"]
 
-        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
+        self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=torch.float32, device=device, operations=operations)
         self.pre_layrnorm = operations.LayerNorm(embed_dim)
         self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
         self.post_layernorm = operations.LayerNorm(embed_dim)

View File

@@ -139,32 +139,3 @@ class SD3(LatentFormat):
 
 class StableAudio1(LatentFormat):
     latent_channels = 64
-
-class Flux(SD3):
-    def __init__(self):
-        self.scale_factor = 0.3611
-        self.shift_factor = 0.1159
-        self.latent_rgb_factors =[
-            [-0.0404, 0.0159, 0.0609],
-            [ 0.0043, 0.0298, 0.0850],
-            [ 0.0328, -0.0749, -0.0503],
-            [-0.0245, 0.0085, 0.0549],
-            [ 0.0966, 0.0894, 0.0530],
-            [ 0.0035, 0.0399, 0.0123],
-            [ 0.0583, 0.1184, 0.1262],
-            [-0.0191, -0.0206, -0.0306],
-            [-0.0324, 0.0055, 0.1001],
-            [ 0.0955, 0.0659, -0.0545],
-            [-0.0504, 0.0231, -0.0013],
-            [ 0.0500, -0.0008, -0.0088],
-            [ 0.0982, 0.0941, 0.0976],
-            [-0.1233, -0.0280, -0.0897],
-            [-0.0005, -0.0530, -0.0020],
-            [-0.1273, -0.0932, -0.0680]
-        ]
-
-    def process_in(self, latent):
-        return (latent - self.shift_factor) * self.scale_factor
-
-    def process_out(self, latent):
-        return (latent / self.scale_factor) + self.shift_factor

View File

@@ -9,7 +9,6 @@ from einops import rearrange
 from torch import nn
 from torch.nn import functional as F
 import math
-import comfy.ops
 
 class FourierFeatures(nn.Module):
     def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
@@ -19,7 +18,7 @@ class FourierFeatures(nn.Module):
             [out_features // 2, in_features], dtype=dtype, device=device))
 
     def forward(self, input):
-        f = 2 * math.pi * input @ comfy.ops.cast_to_input(self.weight.T, input)
+        f = 2 * math.pi * input @ self.weight.T.to(dtype=input.dtype, device=input.device)
         return torch.cat([f.cos(), f.sin()], dim=-1)
 
 # norms
@@ -39,9 +38,9 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         beta = self.beta
-        if beta is not None:
-            beta = comfy.ops.cast_to_input(beta, x)
-        return F.layer_norm(x, x.shape[-1:], weight=comfy.ops.cast_to_input(self.gamma, x), bias=beta)
+        if self.beta is not None:
+            beta = beta.to(dtype=x.dtype, device=x.device)
+        return F.layer_norm(x, x.shape[-1:], weight=self.gamma.to(dtype=x.dtype, device=x.device), bias=beta)
 
 class GLU(nn.Module):
     def __init__(
@@ -124,9 +123,7 @@ class RotaryEmbedding(nn.Module):
         scale_base = 512,
         interpolation_factor = 1.,
         base = 10000,
-        base_rescale_factor = 1.,
-        dtype=None,
-        device=None,
+        base_rescale_factor = 1.
     ):
         super().__init__()
         # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
@@ -134,8 +131,8 @@ class RotaryEmbedding(nn.Module):
         # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
         base *= base_rescale_factor ** (dim / (dim - 2))
 
-        # inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', torch.empty((dim // 2,), device=device, dtype=dtype))
+        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
 
         assert interpolation_factor >= 1.
         self.interpolation_factor = interpolation_factor
@@ -164,14 +161,14 @@ class RotaryEmbedding(nn.Module):
         t = t / self.interpolation_factor
 
-        freqs = torch.einsum('i , j -> i j', t, comfy.ops.cast_to_input(self.inv_freq, t))
+        freqs = torch.einsum('i , j -> i j', t, self.inv_freq.to(dtype=dtype, device=device))
         freqs = torch.cat((freqs, freqs), dim = -1)
 
         if self.scale is None:
             return freqs, 1.
 
         power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
-        scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
+        scale = self.scale.to(dtype=dtype, device=device) ** rearrange(power, 'n -> n 1')
         scale = torch.cat((scale, scale), dim = -1)
 
         return freqs, scale
@@ -571,7 +568,7 @@ class ContinuousTransformer(nn.Module):
         self.project_out = operations.Linear(dim, dim_out, bias=False, dtype=dtype, device=device) if dim_out is not None else nn.Identity()
 
         if rotary_pos_emb:
-            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32), device=device, dtype=dtype)
+            self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
         else:
             self.rotary_pos_emb = None

View File

@@ -8,8 +8,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from comfy.ldm.modules.attention import optimized_attention
-import comfy.ops
-import comfy.ldm.common_dit
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
@@ -408,7 +406,10 @@ class MMDiT(nn.Module):
     def patchify(self, x):
         B, C, H, W = x.size()
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+        pad_h = (self.patch_size - H % self.patch_size) % self.patch_size
+        pad_w = (self.patch_size - W % self.patch_size) % self.patch_size
+        x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
 
         x = x.view(
             B,
             C,
@@ -426,7 +427,7 @@ class MMDiT(nn.Module):
         max_dim = max(h, w)
         cur_dim = self.h_max
-        pos_encoding = comfy.ops.cast_to_input(self.positional_encoding.reshape(1, cur_dim, cur_dim, -1), x)
+        pos_encoding = self.positional_encoding.reshape(1, cur_dim, cur_dim, -1).to(device=x.device, dtype=x.dtype)
 
         if max_dim > cur_dim:
             pos_encoding = F.interpolate(pos_encoding.movedim(-1, 1), (max_dim, max_dim), mode="bilinear").movedim(1, -1)
@@ -454,7 +455,7 @@ class MMDiT(nn.Module):
         t = timestep
 
         c = self.cond_seq_linear(c_seq) # B, T_c, D
-        c = torch.cat([comfy.ops.cast_to_input(self.register_tokens, c).repeat(c.size(0), 1, 1), c], dim=1)
+        c = torch.cat([self.register_tokens.to(device=c.device, dtype=c.dtype).repeat(c.size(0), 1, 1), c], dim=1)
 
         global_cond = self.t_embedder(t, x.dtype) # B, D

View File

@@ -19,7 +19,14 @@
 import torch
 import torch.nn as nn
 from comfy.ldm.modules.attention import optimized_attention
-import comfy.ops
+
+class Linear(torch.nn.Linear):
+    def reset_parameters(self):
+        return None
+
+class Conv2d(torch.nn.Conv2d):
+    def reset_parameters(self):
+        return None
 
 class OptimizedAttention(nn.Module):
     def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
@@ -71,13 +78,13 @@ class GlobalResponseNorm(nn.Module):
     "from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"
     def __init__(self, dim, dtype=None, device=None):
         super().__init__()
-        self.gamma = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
-        self.beta = nn.Parameter(torch.empty(1, 1, 1, dim, dtype=dtype, device=device))
+        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
+        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim, dtype=dtype, device=device))
 
     def forward(self, x):
         Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
-        return comfy.ops.cast_to_input(self.gamma, x) * (x * Nx) + comfy.ops.cast_to_input(self.beta, x) + x
+        return self.gamma.to(device=x.device, dtype=x.dtype) * (x * Nx) + self.beta.to(device=x.device, dtype=x.dtype) + x
 
 class ResBlock(nn.Module):
View File

@@ -1,8 +0,0 @@
import torch
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
padding_mode = "reflect"
pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
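As a quick orientation for the removed helper above, here is a hedged usage sketch. The `comfy.ldm.common_dit` module path is taken from the `comfy.ldm.common_dit.pad_to_patch_size(...)` call that appears elsewhere in this compare; the tensor sizes are made up for illustration.

```python
import torch
import comfy.ldm.common_dit  # present on the "-" side of this compare, removed on the "+" side

# Pad a 1x3x30x31 latent so both spatial dimensions become multiples of the 2x2 patch size.
img = torch.randn(1, 3, 30, 31)
padded = comfy.ldm.common_dit.pad_to_patch_size(img, patch_size=(2, 2))
print(padded.shape)  # torch.Size([1, 3, 30, 32]); only the width needed one column of padding
```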

View File

@@ -1,256 +0,0 @@
import math
from dataclasses import dataclass
import torch
from einops import rearrange
from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def forward(self, ids: Tensor) -> Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
"""
t = time_factor * t
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
t.device
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
if torch.is_floating_point(t):
embedding = embedding.to(t)
return embedding
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.silu = nn.SiLU()
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
def forward(self, x: Tensor) -> Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
x_dtype = x.dtype
x = x.float()
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms).to(dtype=x_dtype) * comfy.ops.cast_to(self.scale, dtype=x_dtype, device=x.device)
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
k = self.key_norm(k)
return q.to(v), k.to(v)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
def forward(self, x: Tensor, pe: Tensor) -> Tensor:
qkv = self.qkv(x)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
x = attention(q, k, v, pe=pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: Tensor
scale: Tensor
gate: Tensor
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
ModulationOut(*out[3:]) if self.is_double else None,
)
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.img_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.txt_mlp = nn.Sequential(
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
# run actual attention
q = torch.cat((txt_q, img_q), dim=2)
k = torch.cat((txt_k, img_k), dim=2)
v = torch.cat((txt_v, img_v), dim=2)
attn = attention(q, k, v, pe=pe)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
return img, txt
class SingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: float = None,
dtype=None,
device=None,
operations=None
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
# proj and mlp_out
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
self.hidden_size = hidden_size
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
q, k = self.norm(q, k, v)
# compute attention
attn = attention(q, k, v, pe=pe)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + mod.gate * output
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

View File

@@ -1,35 +0,0 @@
import torch
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True)
return x
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
device = torch.device("cpu")
else:
device = pos.device
scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
omega = 1.0 / (theta**scale)
out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
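A small, hedged shape check for the `rope` and `apply_rope` helpers in the removed file above. The compare view does not show file names, so the import path below is a guess (based on the `from .math import attention, rope` line in the neighbouring file) and the tensor shapes are illustrative only.

```python
import torch
from comfy.ldm.flux.math import rope, apply_rope  # assumed module path; adjust if it differs

pos = torch.arange(16).unsqueeze(0)   # (1, 16) token position ids
pe = rope(pos, dim=64, theta=10000)   # (1, 16, 32, 2, 2): one 2x2 rotation matrix per frequency pair
q = torch.randn(1, 8, 16, 64)         # (batch, heads, tokens, head_dim)
k = torch.randn(1, 8, 16, 64)
q_rot, k_rot = apply_rope(q, k, pe)   # rotary embedding applied; shapes are unchanged
print(pe.shape, q_rot.shape, k_rot.shape)
```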

View File

@@ -1,142 +0,0 @@
#Original code can be found on: https://github.com/black-forest-labs/flux
from dataclasses import dataclass
import torch
from torch import Tensor, nn
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
from einops import rearrange, repeat
import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
"""
Transformer model for flow matching on sequences.
"""
def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
)
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
self.double_blocks = nn.ModuleList(
[
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
]
)
self.single_blocks = nn.ModuleList(
[
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
for _ in range(params.depth_single_blocks)
]
)
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
def forward_orig(
self,
img: Tensor,
img_ids: Tensor,
txt: Tensor,
txt_ids: Tensor,
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
) -> Tensor:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
# running on sequences img
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = torch.cat((txt, img), 1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
def forward(self, x, timestep, context, y, guidance, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]

View File

@@ -1,219 +0,0 @@
import torch
import torch.nn as nn
from typing import Tuple, Union, Optional
from comfy.ldm.modules.attention import optimized_attention
def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}'
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}'
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: Optional[torch.Tensor],
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
if xk is not None:
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
class CrossAttention(nn.Module):
"""
Use QK Normalization.
"""
def __init__(self,
qdim,
kdim,
num_heads,
qkv_bias=True,
qk_norm=False,
attn_drop=0.0,
proj_drop=0.0,
attn_precision=None,
device=None,
dtype=None,
operations=None,
):
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.attn_precision = attn_precision
self.qdim = qdim
self.kdim = kdim
self.num_heads = num_heads
assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads"
self.head_dim = self.qdim // num_heads
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
self.q_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.kv_proj = operations.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = operations.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, y, freqs_cis_img=None):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim)
y: torch.Tensor
(batch, seqlen2, hidden_dim2)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // 2), RoPE for image
"""
b, s1, c = x.shape # [b, s1, D]
_, s2, c = y.shape # [b, s2, 1024]
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
k, v = kv.unbind(dim=2) # [b, s, h, d]
q = self.q_norm(q)
k = self.k_norm(k)
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
q = qq
q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C
k = k.transpose(-2, -3).contiguous() # k -> B, L2, H, C - B, H, C, L2
v = v.transpose(-2, -3).contiguous()
context = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
out = self.out_proj(context) # context.reshape - B, L1, -1
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
class Attention(nn.Module):
"""
We rename some layer names to align with flash attention
"""
def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, 'dim should be divisible by num_heads'
self.head_dim = self.dim // num_heads
# This assertion is aligned with flash attention
assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8"
self.scale = self.head_dim ** -0.5
# qkv --> Wqkv
self.Wqkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
# TODO: eps should be 1 / 65530 if using fp16
self.q_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.k_norm = operations.LayerNorm(self.head_dim, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.out_proj = operations.Linear(dim, dim, dtype=dtype, device=device)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, freqs_cis_img=None):
B, N, C = x.shape
qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d]
q, k, v = qkv.unbind(0) # [b, h, s, d]
q = self.q_norm(q) # [b, h, s, d]
k = self.k_norm(k) # [b, h, s, d]
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True)
assert qq.shape == q.shape and kk.shape == k.shape, \
f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
q, k = qq, kk
x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True, attn_precision=self.attn_precision)
x = self.out_proj(x)
x = self.proj_drop(x)
out_tuple = (x,)
return out_tuple

View File

@@ -1,405 +0,0 @@
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint
from .attn_layers import Attention, CrossAttention
from .poolers import AttentionPool
from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
def calc_rope(x, patch_size, head_size):
th = (x.shape[2] + (patch_size // 2)) // patch_size
tw = (x.shape[3] + (patch_size // 2)) // patch_size
base_size = 512 // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
# head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
rope = get_2d_rotary_pos_embed(head_size, *sub_args)
return rope
def modulate(x, shift, scale):
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class HunYuanDiTBlock(nn.Module):
"""
A HunYuanDiT block with `add` conditioning.
"""
def __init__(self,
hidden_size,
c_emb_size,
num_heads,
mlp_ratio=4.0,
text_states_dim=1024,
qk_norm=False,
norm_type="layer",
skip=False,
attn_precision=None,
dtype=None,
device=None,
operations=None,
):
super().__init__()
use_ele_affine = True
if norm_type == "layer":
norm_layer = operations.LayerNorm
elif norm_type == "rms":
norm_layer = RMSNorm
else:
raise ValueError(f"Unknown norm_type: {norm_type}")
# ========================= Self-Attention =========================
self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
# ========================= FFN =========================
self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6, dtype=dtype, device=device)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0, dtype=dtype, device=device, operations=operations)
# ========================= Add =========================
# Simply use add like SDXL.
self.default_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, hidden_size, bias=True, dtype=dtype, device=device)
)
# ========================= Cross-Attention =========================
self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True,
qk_norm=qk_norm, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
# ========================= Skip Connection =========================
if skip:
self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.skip_linear = operations.Linear(2 * hidden_size, hidden_size, dtype=dtype, device=device)
else:
self.skip_linear = None
self.gradient_checkpointing = False
def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
# Self-Attention
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
attn_inputs = (
self.norm1(x) + shift_msa, freq_cis_img,
)
x = x + self.attn1(*attn_inputs)[0]
# Cross-Attention
cross_inputs = (
self.norm3(x), text_states, freq_cis_img
)
x = x + self.attn2(*cross_inputs)[0]
# FFN Layer
mlp_inputs = self.norm2(x)
x = x + self.mlp(mlp_inputs)
return x
def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
if self.gradient_checkpointing and self.training:
return checkpoint.checkpoint(self._forward, x, c, text_states, freq_cis_img, skip)
return self._forward(x, c, text_states, freq_cis_img, skip)
class FinalLayer(nn.Module):
"""
The final layer of HunYuanDiT.
"""
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
return x
class HunYuanDiT(nn.Module):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Inherit PeftAdapterMixin to be compatible with the PEFT training pipeline.
Parameters
----------
args: argparse.Namespace
The arguments parsed by argparse.
input_size: tuple
The size of the input image.
patch_size: int
The size of the patch.
in_channels: int
The number of input channels.
hidden_size: int
The hidden size of the transformer backbone.
depth: int
The number of transformer blocks.
num_heads: int
The number of attention heads.
mlp_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
log_fn: callable
The logging function.
"""
#@register_to_config
def __init__(self,
input_size: tuple = 32,
patch_size: int = 2,
in_channels: int = 4,
hidden_size: int = 1152,
depth: int = 28,
num_heads: int = 16,
mlp_ratio: float = 4.0,
text_states_dim = 1024,
text_states_dim_t5 = 2048,
text_len = 77,
text_len_t5 = 256,
qk_norm = True,  # See http://arxiv.org/abs/2302.05442 for details.
size_cond = False,
use_style_cond = False,
learn_sigma = True,
norm = "layer",
log_fn: callable = print,
attn_precision=None,
dtype=None,
device=None,
operations=None,
**kwargs,
):
super().__init__()
self.log_fn = log_fn
self.depth = depth
self.learn_sigma = learn_sigma
self.in_channels = in_channels
self.out_channels = in_channels * 2 if learn_sigma else in_channels
self.patch_size = patch_size
self.num_heads = num_heads
self.hidden_size = hidden_size
self.text_states_dim = text_states_dim
self.text_states_dim_t5 = text_states_dim_t5
self.text_len = text_len
self.text_len_t5 = text_len_t5
self.size_cond = size_cond
self.use_style_cond = use_style_cond
self.norm = norm
self.dtype = dtype
#import pdb
#pdb.set_trace()
self.mlp_t5 = nn.Sequential(
operations.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True, dtype=dtype, device=device),
)
# Learnable padding tokens used to replace masked-out text embeddings
self.text_embedding_padding = nn.Parameter(
torch.empty(self.text_len + self.text_len_t5, self.text_states_dim, dtype=dtype, device=device))
# Attention pooling
pooler_out_dim = 1024
self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=pooler_out_dim, dtype=dtype, device=device, operations=operations)
# Dimension of the extra input vectors
self.extra_in_dim = pooler_out_dim
if self.size_cond:
# Image size and crop size conditions
self.extra_in_dim += 6 * 256
if self.use_style_cond:
# Here we use a default learned embedder layer for future extension.
self.style_embedder = operations.Embedding(1, hidden_size, dtype=dtype, device=device)
self.extra_in_dim += hidden_size
# Text embedding for `add`
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, dtype=dtype, device=device, operations=operations)
self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device, operations=operations)
self.extra_embedder = nn.Sequential(
operations.Linear(self.extra_in_dim, hidden_size * 4, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
)
# Image embedding
num_patches = self.x_embedder.num_patches
# HunYuanDiT blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,
c_emb_size=hidden_size,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
text_states_dim=self.text_states_dim,
qk_norm=qk_norm,
norm_type=self.norm,
skip=layer > depth // 2,
attn_precision=attn_precision,
dtype=dtype,
device=device,
operations=operations,
)
for layer in range(depth)
])
self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
self.unpatchify_channels = self.out_channels
def forward(self,
x,
t,
context,  # encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
return_dict=False,
control=None,
transformer_options=None,
):
"""
Forward pass of the encoder.
Parameters
----------
x: torch.Tensor
(B, D, H, W)
t: torch.Tensor
(B)
encoder_hidden_states: torch.Tensor
CLIP text embedding, (B, L_clip, D)
text_embedding_mask: torch.Tensor
CLIP text embedding mask, (B, L_clip)
encoder_hidden_states_t5: torch.Tensor
T5 text embedding, (B, L_t5, D)
text_embedding_mask_t5: torch.Tensor
T5 text embedding mask, (B, L_t5)
image_meta_size: torch.Tensor
(B, 6)
style: torch.Tensor
(B)
cos_cis_img: torch.Tensor
sin_cis_img: torch.Tensor
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
text_states_mask = text_embedding_mask.bool() # 2,77
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
b_t5, l_t5, c_t5 = text_states_t5.shape
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
padding = comfy.ops.cast_to_input(self.text_embedding_padding, text_states)
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
text_states = torch.cat([text_states, text_states_t5], dim=1)  # 2,205,1024
# clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
_, _, oh, ow = x.shape
th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size
# Get image RoPE embedding according to the resolution.
freqs_cis_img = calc_rope(x, self.patch_size, self.hidden_size // self.num_heads) #(cos_cis_img, sin_cis_img)
# ========================= Build time and image embedding =========================
t = self.t_embedder(t, dtype=x.dtype)
x = self.x_embedder(x)
# ========================= Concatenate all extra vectors =========================
# Build text tokens with pooling
extra_vec = self.pooler(encoder_hidden_states_t5)
# Build image meta size tokens if applicable
if self.size_cond:
image_meta_size = timestep_embedding(image_meta_size.view(-1), 256).to(x.dtype) # [B * 6, 256]
image_meta_size = image_meta_size.view(-1, 6 * 256)
extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256]
# Build style tokens
if self.use_style_cond:
if style is None:
style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
style_embedding = self.style_embedder(style, out_dtype=x.dtype)
extra_vec = torch.cat([extra_vec, style_embedding], dim=1)
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
controls = None
# ========================= Forward pass through HunYuanDiT blocks =========================
skips = []
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)
if controls is not None and len(controls) != 0:
raise ValueError("The number of controls is not equal to the number of skip connections.")
# ========================= Final layer =========================
x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels)
x = self.unpatchify(x, th, tw) # (N, out_channels, H, W)
if return_dict:
return {'x': x}
if self.learn_sigma:
return x[:,:self.out_channels // 2,:oh,:ow]
return x[:,:,:oh,:ow]
def unpatchify(self, x, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
p = self.x_embedder.patch_size[0]
# h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
return imgs
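To make the tensor bookkeeping in unpatchify easier to follow, here is a minimal standalone sketch with toy sizes (N=1, h=w=2, p=2, c=3 are assumptions, not model values):

import torch

N, h, w, p, c = 1, 2, 2, 2, 3
x = torch.randn(N, h * w, p * p * c)              # (N, T, patch_size**2 * C) with T = h*w
x = x.reshape(N, h, w, p, p, c)
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(N, c, h * p, w * p)              # (N, C, H, W) = (1, 3, 4, 4)
assert imgs.shape == (1, 3, 4, 4)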


@@ -1,37 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.ops
class AttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None, dtype=None, device=None, operations=None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.empty(spacial_dim + 1, embed_dim, dtype=dtype, device=device))
self.k_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.q_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.v_proj = operations.Linear(embed_dim, embed_dim, dtype=dtype, device=device)
self.c_proj = operations.Linear(embed_dim, output_dim or embed_dim, dtype=dtype, device=device)
self.num_heads = num_heads
self.embed_dim = embed_dim
def forward(self, x):
x = x[:,:self.positional_embedding.shape[0] - 1]
x = x.permute(1, 0, 2) # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC
x = x + comfy.ops.cast_to_input(self.positional_embedding[:, None, :], x) # (L+1)NC
q = self.q_proj(x[:1])
k = self.k_proj(x)
v = self.v_proj(x)
batch_size = q.shape[1]
head_dim = self.embed_dim // self.num_heads
q = q.view(1, batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
k = k.view(k.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
v = v.view(v.shape[0], batch_size * self.num_heads, head_dim).transpose(0, 1).view(batch_size, self.num_heads, -1, head_dim)
attn_output = optimized_attention(q, k, v, self.num_heads, skip_reshape=True).transpose(0, 1)
attn_output = self.c_proj(attn_output)
return attn_output.squeeze(0)
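The pooler above reduces the (padded) T5 token sequence to a single conditioning vector by attending from a prepended mean token. A minimal sketch of that shape flow under assumed sizes, with plain single-head attention standing in for the projected multi-head version:

import torch
import torch.nn.functional as F

B, L, C, out_dim = 2, 256, 2048, 1024
x = torch.randn(B, L, C)
x = x.permute(1, 0, 2)                                   # NLC -> LNC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)   # prepend mean token -> (L+1)NC
q = x[:1]                                                 # only the mean token queries
attn = F.scaled_dot_product_attention(q.transpose(0, 1), x.transpose(0, 1), x.transpose(0, 1))
pooled = torch.nn.Linear(C, out_dim)(attn.squeeze(1))     # (B, out_dim)
assert pooled.shape == (B, out_dim)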


@@ -1,224 +0,0 @@
import torch
import numpy as np
from typing import Union
def _to_tuple(x):
if isinstance(x, int):
return x, x
else:
return x
def get_fill_resize_and_crop(src, tgt):
th, tw = _to_tuple(tgt)
h, w = _to_tuple(src)
tr = th / tw # base resolution
r = h / w # target resolution
# resize
if r > tr:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h)) # resize the target resolution down based on the base resolution
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
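def _example_fill_resize_and_crop():
    # Illustrative only (not part of the original file). With an assumed 64x64 token
    # grid and a 32x32 base grid the aspect ratios match, so the resized region fills
    # the whole target and nothing is cropped; a tall 64x32 grid (h=64, w=32) is
    # instead centered horizontally within the base grid.
    assert get_fill_resize_and_crop((64, 64), 32) == ((0, 0), (32, 32))
    assert get_fill_resize_and_crop((64, 32), 32) == ((0, 8), (32, 24))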
def get_meshgrid(start, *args):
if len(args) == 0:
# start is grid_size
num = _to_tuple(start)
start = (0, 0)
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = (stop[0] - start[0], stop[1] - start[1])
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start)
stop = _to_tuple(args[0])
num = _to_tuple(args[1])
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)
grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0) # [2, W, H]
return grid
#################################################################################
# Sine/Cosine Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid = get_meshgrid(start, *args) # [2, H, w]
# grid_h = np.arange(grid_size, dtype=np.float32)
# grid_w = np.arange(grid_size, dtype=np.float32)
# grid = np.meshgrid(grid_w, grid_h) # here w goes first
# grid = np.stack(grid, axis=0) # [2, W, H]
grid = grid.reshape([2, 1, *grid.shape[1:]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (W,H)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
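def _example_sincos_embedding():
    # Illustrative only (not part of the original file). For an assumed embed_dim of 4
    # and a single position pos=2, omega = [1, 1/100], so the embedding row is
    # [sin(2), sin(0.02), cos(2), cos(0.02)].
    emb = get_1d_sincos_pos_embed_from_grid(4, np.array([2.0]))
    assert emb.shape == (1, 4)
    return emb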
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
"""
This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
Parameters
----------
embed_dim: int
embedding dimension size
start: int or tuple of int
If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
If len(args) == 2, start is start, args[0] is stop, args[1] is num.
use_real: bool
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Returns
-------
pos_embed: torch.Tensor
[HW, D/2] complex tensor if use_real is False; otherwise a (cos, sin) pair of real tensors, each [HW, D]
"""
grid = get_meshgrid(start, *args) # [2, H, w]
grid = grid.reshape([2, 1, *grid.shape[1:]]) # Returns a sampling matrix with the same resolution as the target resolution
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
return pos_embed
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
assert embed_dim % 4 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real) # (H*W, D/4)
emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real) # (H*W, D/4)
if use_real:
cos = torch.cat([emb_h[0], emb_w[0]], dim=1) # (H*W, D/2)
sin = torch.cat([emb_h[1], emb_w[1]], dim=1) # (H*W, D/2)
return cos, sin
else:
emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D/2)
return emb
def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
"""
Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
Returns:
torch.Tensor: Precomputed frequency tensor with complex exponentials, [S, D/2]; if use_real is True, a (cos, sin) pair of real tensors, each [S, D].
"""
if isinstance(pos, int):
pos = np.arange(pos)
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
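def _example_rotary_freqs():
    # Illustrative only (not part of the original file). For an assumed dim of 4 and
    # two positions the base frequencies are [1, 1/100] (theta=10000), so with
    # use_real=True each row is [cos(p), cos(p), cos(p/100), cos(p/100)] plus the
    # matching sines, each tensor of shape [2, 4].
    cos, sin = get_1d_rotary_pos_embed(4, np.arange(2, dtype=np.float32), use_real=True)
    assert cos.shape == sin.shape == (2, 4)
    return cos, sin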
def calc_sizes(rope_img, patch_size, th, tw):
if rope_img == 'extend':
# Expansion mode
sub_args = [(th, tw)]
elif rope_img.startswith('base'):
# Based on the specified dimensions, other dimensions are obtained through interpolation.
base_size = int(rope_img[4:]) // 8 // patch_size
start, stop = get_fill_resize_and_crop((th, tw), base_size)
sub_args = [start, stop, (th, tw)]
else:
raise ValueError(f"Unknown rope_img: {rope_img}")
return sub_args
def init_image_posemb(rope_img,
resolutions,
patch_size,
hidden_size,
num_heads,
log_fn,
rope_real=True,
):
freqs_cis_img = {}
for reso in resolutions:
th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
sub_args = calc_sizes(rope_img, patch_size, th, tw)
freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
return freqs_cis_img


@@ -8,8 +8,6 @@ import torch.nn as nn
from .. import attention from .. import attention
from einops import rearrange, repeat from einops import rearrange, repeat
from .util import timestep_embedding from .util import timestep_embedding
import comfy.ops
import comfy.ldm.common_dit
def default(x, y): def default(x, y):
if x is not None: if x is not None:
@@ -71,14 +69,12 @@ class PatchEmbed(nn.Module):
bias: bool = True, bias: bool = True,
strict_img_size: bool = True, strict_img_size: bool = True,
dynamic_img_pad: bool = True, dynamic_img_pad: bool = True,
padding_mode='circular',
dtype=None, dtype=None,
device=None, device=None,
operations=None, operations=None,
): ):
super().__init__() super().__init__()
self.patch_size = (patch_size, patch_size) self.patch_size = (patch_size, patch_size)
self.padding_mode = padding_mode
if img_size is not None: if img_size is not None:
self.img_size = (img_size, img_size) self.img_size = (img_size, img_size)
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)]) self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
@@ -112,7 +108,9 @@ class PatchEmbed(nn.Module):
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})." # f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
# ) # )
if self.dynamic_img_pad: if self.dynamic_img_pad:
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode) pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='reflect')
x = self.proj(x) x = self.proj(x)
if self.flatten: if self.flatten:
x = x.flatten(2).transpose(1, 2) # NCHW -> NLC x = x.flatten(2).transpose(1, 2) # NCHW -> NLC
@@ -926,7 +924,7 @@ class MMDiT(nn.Module):
context = self.context_processor(context) context = self.context_processor(context)
hw = x.shape[-2:] hw = x.shape[-2:]
x = self.x_embedder(x) + comfy.ops.cast_to_input(self.cropped_pos_embed(hw, device=x.device), x) x = self.x_embedder(x) + self.cropped_pos_embed(hw, device=x.device).to(dtype=x.dtype, device=x.device)
c = self.t_embedder(t, dtype=x.dtype) # (N, D) c = self.t_embedder(t, dtype=x.dtype) # (N, D)
if y is not None and self.y_embedder is not None: if y is not None and self.y_embedder is not None:
y = self.y_embedder(y) # (N, D) y = self.y_embedder(y) # (N, D)


@@ -809,7 +809,7 @@ class UNetModel(nn.Module):
self.out = nn.Sequential( self.out = nn.Sequential(
operations.GroupNorm(32, ch, dtype=self.dtype, device=device), operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
nn.SiLU(), nn.SiLU(),
operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device), zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
) )
if self.predict_codebook_ids: if self.predict_codebook_ids:
self.id_predictor = nn.Sequential( self.id_predictor = nn.Sequential(


@@ -282,18 +282,4 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
key_map[key_lora] = to key_map[key_lora] = to
if isinstance(model, comfy.model_base.HunyuanDiT):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["base_model.model.{}".format(key_lora)] = k #official hunyuan lora format
if isinstance(model, comfy.model_base.Flux): #Diffusers lora Flux
diffusers_keys = comfy.utils.flux_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to
return key_map return key_map


@@ -7,11 +7,8 @@ from comfy.ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugme
from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
import comfy.ldm.aura.mmdit import comfy.ldm.aura.mmdit
import comfy.ldm.hydit.models
import comfy.ldm.audio.dit import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.model_management import comfy.model_management
import comfy.conds import comfy.conds
import comfy.ops import comfy.ops
@@ -28,7 +25,6 @@ class ModelType(Enum):
EDM = 5 EDM = 5
FLOW = 6 FLOW = 6
V_PREDICTION_CONTINUOUS = 7 V_PREDICTION_CONTINUOUS = 7
FLUX = 8
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
@@ -56,9 +52,6 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.V_PREDICTION_CONTINUOUS: elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
c = V_PREDICTION c = V_PREDICTION
s = ModelSamplingContinuousV s = ModelSamplingContinuousV
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
class ModelSampling(s, c): class ModelSampling(s, c):
pass pass
@@ -74,7 +67,6 @@ class BaseModel(torch.nn.Module):
self.latent_format = model_config.latent_format self.latent_format = model_config.latent_format
self.model_config = model_config self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
if not unet_config.get("disable_unet_model_creation", False): if not unet_config.get("disable_unet_model_creation", False):
if self.manual_cast_dtype is not None: if self.manual_cast_dtype is not None:
@@ -85,7 +77,6 @@ class BaseModel(torch.nn.Module):
if comfy.model_management.force_channels_last(): if comfy.model_management.force_channels_last():
self.diffusion_model.to(memory_format=torch.channels_last) self.diffusion_model.to(memory_format=torch.channels_last)
logging.debug("using channels last mode for diffusion model") logging.debug("using channels last mode for diffusion model")
logging.info("model weight dtype {}, manual cast: {}".format(self.get_dtype(), self.manual_cast_dtype))
self.model_type = model_type self.model_type = model_type
self.model_sampling = model_sampling(model_config, model_type) self.model_sampling = model_sampling(model_config, model_type)
@@ -96,7 +87,6 @@ class BaseModel(torch.nn.Module):
self.concat_keys = () self.concat_keys = ()
logging.info("model_type {}".format(model_type.name)) logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels)) logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t sigma = t
@@ -255,11 +245,11 @@ class BaseModel(torch.nn.Module):
dtype = self.manual_cast_dtype dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked #TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:]) area = input_shape[0] * math.prod(input_shape[2:])
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) return (area * comfy.model_management.dtype_size(dtype) / 50) * (1024 * 1024)
else: else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory. #TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:]) area = input_shape[0] * math.prod(input_shape[2:])
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024) return (((area * 0.6) / 0.9) + 1024) * (1024 * 1024)
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None): def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
@@ -357,7 +347,6 @@ class SDXL(BaseModel):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1) flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1) return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
class SVD_img2vid(BaseModel): class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None): def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
@@ -598,6 +587,17 @@ class SD3(BaseModel):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out return out
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
else:
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)
class AuraFlow(BaseModel): class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@@ -648,50 +648,3 @@ class StableAudio1(BaseModel):
for l in s: for l in s:
sd["{}{}".format(k, l)] = s[l] sd["{}{}".format(k, l)] = s[l]
return sd return sd
class HunyuanDiT(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hydit.models.HunYuanDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['text_embedding_mask'] = comfy.conds.CONDRegular(attention_mask)
conditioning_mt5xl = kwargs.get("conditioning_mt5xl", None)
if conditioning_mt5xl is not None:
out['encoder_hidden_states_t5'] = comfy.conds.CONDRegular(conditioning_mt5xl)
attention_mask_mt5xl = kwargs.get("attention_mask_mt5xl", None)
if attention_mask_mt5xl is not None:
out['text_embedding_mask_t5'] = comfy.conds.CONDRegular(attention_mask_mt5xl)
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
crop_w = kwargs.get("crop_w", 0)
crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out


@@ -115,36 +115,6 @@ def detect_unet_config(state_dict, key_prefix):
unet_config["n_layers"] = double_layers + single_layers unet_config["n_layers"] = double_layers + single_layers
return unet_config return unet_config
if '{}mlp_t5.0.weight'.format(key_prefix) in state_dict_keys: #Hunyuan DiT
unet_config = {}
unet_config["image_model"] = "hydit"
unet_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
unet_config["hidden_size"] = state_dict['{}x_embedder.proj.weight'.format(key_prefix)].shape[0]
if unet_config["hidden_size"] == 1408 and unet_config["depth"] == 40: #DiT-g/2
unet_config["mlp_ratio"] = 4.3637
if state_dict['{}extra_embedder.0.weight'.format(key_prefix)].shape[1] == 3968:
unet_config["size_cond"] = True
unet_config["use_style_cond"] = True
unet_config["image_model"] = "hydit1"
return unet_config
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
dit_config["num_heads"] = 24
dit_config["depth"] = 19
dit_config["depth_single_blocks"] = 38
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None return None


@@ -274,7 +274,7 @@ class LoadedModel:
return self.model.model_size() return self.model.model_size()
def model_memory_required(self, device): def model_memory_required(self, device):
if device == self.model.current_loaded_device(): if device == self.model.current_device:
return 0 return 0
else: else:
return self.model_memory() return self.model_memory()
@@ -318,7 +318,7 @@ class LoadedModel:
return self.model is other.model return self.model is other.model
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2 return (1024 * 1024 * 1024)
def unload_model_clones(model, unload_weights_only=True, force_unload=True): def unload_model_clones(model, unload_weights_only=True, force_unload=True):
to_unload = [] to_unload = []
@@ -352,7 +352,6 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
def free_memory(memory_required, device, keep_loaded=[]): def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
unloaded_models = []
for i in range(len(current_loaded_models) -1, -1, -1): for i in range(len(current_loaded_models) -1, -1, -1):
shift_model = current_loaded_models[i] shift_model = current_loaded_models[i]
@@ -370,7 +369,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
unloaded_model.append(i) unloaded_model.append(i)
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) current_loaded_models.pop(i)
if len(unloaded_model) > 0: if len(unloaded_model) > 0:
soft_empty_cache() soft_empty_cache()
@@ -379,17 +378,12 @@ def free_memory(memory_required, device, keep_loaded=[]):
mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True) mem_free_total, mem_free_torch = get_free_memory(device, torch_free_too=True)
if mem_free_torch > mem_free_total * 0.25: if mem_free_torch > mem_free_total * 0.25:
soft_empty_cache() soft_empty_cache()
return unloaded_models
def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None): def load_models_gpu(models, memory_required=0, force_patch_weights=False):
global vram_state global vram_state
inference_memory = minimum_inference_memory() inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required) extra_mem = max(inference_memory, memory_required)
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required)
models = set(models) models = set(models)
@@ -423,13 +417,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
for d in devs: for d in devs:
if d != torch.device("cpu"): if d != torch.device("cpu"):
free_memory(extra_mem, d, models_already_loaded) free_memory(extra_mem, d, models_already_loaded)
free_mem = get_free_memory(d) return
if free_mem < minimum_memory_required:
logging.info("Unloading models for lowram load.") #TODO: partial model unloading when this case happens, also handle the opposite case where models can be unlowvramed.
models_to_load = free_memory(minimum_memory_required, d)
logging.info("{} models unloaded.".format(len(models_to_load)))
if len(models_to_load) == 0:
return
logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") logging.info(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}")
@@ -458,8 +446,8 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM): if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
model_size = loaded_model.model_memory_required(torch_dev) model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev) current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory())) lowvram_model_memory = int(max(64 * (1024 * 1024), (current_free_mem - 1024 * (1024 * 1024)) / 1.3 ))
if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary if model_size <= (current_free_mem - inference_memory): #only switch to lowvram if really necessary
lowvram_model_memory = 0 lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM: if vram_set_state == VRAMState.NO_VRAM:
@@ -535,9 +523,6 @@ def unet_inital_load_device(parameters, dtype):
else: else:
return cpu_dev return cpu_dev
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]): def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if args.bf16_unet: if args.bf16_unet:
return torch.bfloat16 return torch.bfloat16
@@ -547,21 +532,6 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e4m3fn return torch.float8_e4m3fn
if args.fp8_e5m2_unet: if args.fp8_e5m2_unet:
return torch.float8_e5m2 return torch.float8_e5m2
fp8_dtype = None
try:
for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
if dtype in supported_dtypes:
fp8_dtype = dtype
break
except:
pass
if fp8_dtype is not None:
free_model_memory = maximum_vram_for_weights(device)
if model_params * 2 > free_model_memory:
return fp8_dtype
if should_use_fp16(device=device, model_params=model_params, manual_cast=True): if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes: if torch.float16 in supported_dtypes:
return torch.float16 return torch.float16
@@ -679,29 +649,18 @@ def supports_cast(device, dtype): #TODO
return True return True
if dtype == torch.float16: if dtype == torch.float16:
return True return True
if is_device_mps(device):
return False
if directml_enabled: #TODO: test this if directml_enabled: #TODO: test this
return False return False
if dtype == torch.bfloat16: if dtype == torch.bfloat16:
return True return True
if is_device_mps(device):
return False
if dtype == torch.float8_e4m3fn: if dtype == torch.float8_e4m3fn:
return True return True
if dtype == torch.float8_e5m2: if dtype == torch.float8_e5m2:
return True return True
return False return False
def pick_weight_dtype(dtype, fallback_dtype, device=None):
if dtype is None:
dtype = fallback_dtype
elif dtype_size(dtype) > dtype_size(fallback_dtype):
dtype = fallback_dtype
if not supports_cast(device, dtype):
dtype = fallback_dtype
return dtype
def device_supports_non_blocking(device): def device_supports_non_blocking(device):
if is_device_mps(device): if is_device_mps(device):
return False #pytorch bug? mps doesn't support non blocking return False #pytorch bug? mps doesn't support non blocking
@@ -897,7 +856,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
fp16_works = True fp16_works = True
if fp16_works or manual_cast: if fp16_works or manual_cast:
free_model_memory = maximum_vram_for_weights(device) free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory: if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True return True
@@ -917,9 +876,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False return False
if device is not None: if device is not None: #TODO not sure about mps bf16 support
if is_device_mps(device): if is_device_mps(device):
return True return False
if FORCE_FP32: if FORCE_FP32:
return False return False
@@ -927,23 +886,23 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if directml_enabled: if directml_enabled:
return False return False
if mps_mode(): if cpu_mode() or mps_mode():
return True
if cpu_mode():
return False return False
if is_intel_xpu(): if is_intel_xpu():
return True return True
props = torch.cuda.get_device_properties("cuda") if device is None:
device = torch.device("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8: if props.major >= 8:
return True return True
bf16_works = torch.cuda.is_bf16_supported() bf16_works = torch.cuda.is_bf16_supported()
if bf16_works or manual_cast: if bf16_works or manual_cast:
free_model_memory = maximum_vram_for_weights(device) free_model_memory = (get_free_memory() * 0.9 - minimum_inference_memory())
if (not prioritize_performance) or model_params * 4 > free_model_memory: if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True return True


@@ -64,15 +64,9 @@ def set_model_options_pre_cfg_function(model_options, pre_cfg_function, disable_
return model_options return model_options
class ModelPatcher: class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False): def __init__(self, model, load_device, offload_device, size=0, current_device=None, weight_inplace_update=False):
self.size = size self.size = size
self.model = model self.model = model
if not hasattr(self.model, 'device'):
logging.info("Model doesn't have a device attribute.")
self.model.device = offload_device
elif self.model.device is None:
self.model.device = offload_device
self.patches = {} self.patches = {}
self.backup = {} self.backup = {}
self.object_patches = {} self.object_patches = {}
@@ -81,6 +75,11 @@ class ModelPatcher:
self.model_size() self.model_size()
self.load_device = load_device self.load_device = load_device
self.offload_device = offload_device self.offload_device = offload_device
if current_device is None:
self.current_device = self.offload_device
else:
self.current_device = current_device
self.weight_inplace_update = weight_inplace_update self.weight_inplace_update = weight_inplace_update
self.model_lowvram = False self.model_lowvram = False
self.lowvram_patch_counter = 0 self.lowvram_patch_counter = 0
@@ -93,7 +92,7 @@ class ModelPatcher:
return self.size return self.size
def clone(self): def clone(self):
n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) n = ModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
n.patches = {} n.patches = {}
for k in self.patches: for k in self.patches:
n.patches[k] = self.patches[k][:] n.patches[k] = self.patches[k][:]
@@ -303,7 +302,7 @@ class ModelPatcher:
if device_to is not None: if device_to is not None:
self.model.to(device_to) self.model.to(device_to)
self.model.device = device_to self.current_device = device_to
return self.model return self.model
@@ -356,7 +355,6 @@ class ModelPatcher:
self.model_lowvram = True self.model_lowvram = True
self.lowvram_patch_counter = patch_counter self.lowvram_patch_counter = patch_counter
self.model.device = device_to
return self.model return self.model
def calculate_weight(self, patches, weight, key): def calculate_weight(self, patches, weight, key):
@@ -553,13 +551,10 @@ class ModelPatcher:
if device_to is not None: if device_to is not None:
self.model.to(device_to) self.model.to(device_to)
self.model.device = device_to self.current_device = device_to
keys = list(self.object_patches_backup.keys()) keys = list(self.object_patches_backup.keys())
for k in keys: for k in keys:
comfy.utils.set_attr(self.model, k, self.object_patches_backup[k]) comfy.utils.set_attr(self.model, k, self.object_patches_backup[k])
self.object_patches_backup.clear() self.object_patches_backup.clear()
def current_loaded_device(self):
return self.model.device


@@ -272,43 +272,3 @@ class StableCascadeSampling(ModelSamplingDiscrete):
percent = 1.0 - percent percent = 1.0 - percent
return self.sigma(torch.tensor(percent)) return self.sigma(torch.tensor(percent))
def flux_time_shift(mu: float, sigma: float, t):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
class ModelSamplingFlux(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.15))
def set_parameters(self, shift=1.15, timesteps=10000):
self.shift = shift
ts = self.sigma((torch.arange(1, timesteps + 1, 1) / timesteps))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma
def sigma(self, timestep):
return flux_time_shift(self.shift, 1.0, timestep)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
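For context on the removed schedule: flux_time_shift is a simple logistic-style remapping of the timestep, with the configured shift passed in as mu and sigma fixed at 1.0 by sigma() above. A minimal standalone sketch using the default shift of 1.15 (a worked example, not part of the original file):

import math

def flux_time_shift(mu: float, sigma: float, t):
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

# With shift=1.15, t=0.5 maps to exp(1.15) / (exp(1.15) + 1) ~= 0.76, so the
# remapping biases sampling toward higher sigmas than the identity schedule.
for t in (0.25, 0.5, 0.75):
    print(t, round(flux_time_shift(1.15, 1.0, t), 4))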


@@ -19,27 +19,14 @@
import torch import torch
import comfy.model_management import comfy.model_management
def cast_bias_weight(s, input):
def cast_to(weight, dtype=None, device=None, non_blocking=False):
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
def cast_to_input(weight, input, non_blocking=False):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
def cast_bias_weight(s, input=None, dtype=None, device=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if device is None:
device = input.device
bias = None bias = None
non_blocking = comfy.model_management.device_should_use_non_blocking(device) non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
if s.bias is not None: if s.bias is not None:
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking) bias = s.bias.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.bias_function is not None: if s.bias_function is not None:
bias = s.bias_function(bias) bias = s.bias_function(bias)
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking) weight = s.weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
if s.weight_function is not None: if s.weight_function is not None:
weight = s.weight_function(weight) weight = s.weight_function(weight)
return weight, bias return weight, bias
@@ -181,26 +168,6 @@ class disable_weight_init:
else: else:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
def forward_comfy_cast_weights(self, input, out_dtype=None):
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
if self.comfy_cast_weights:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
if "out_dtype" in kwargs:
kwargs.pop("out_dtype")
return super().forward(*args, **kwargs)
@classmethod @classmethod
def conv_nd(s, dims, *args, **kwargs): def conv_nd(s, dims, *args, **kwargs):
if dims == 2: if dims == 2:
@@ -235,6 +202,3 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d): class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True comfy_cast_weights = True
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True


@@ -61,9 +61,7 @@ def prepare_sampling(model, noise_shape, conds):
device = model.load_device device = model.load_device
real_model = None real_model = None
models, inference_memory = get_additional_models(conds, model.model_dtype()) models, inference_memory = get_additional_models(conds, model.model_dtype())
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory comfy.model_management.load_models_gpu([model] + models, model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory)
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
real_model = model.model real_model = model.model
return real_model, conds, models return real_model, conds, models


@@ -171,7 +171,7 @@ def calc_cond_batch(model, conds, x_in, timestep, model_options):
for i in range(1, len(to_batch_temp) + 1): for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i] batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:] input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory: if model.memory_required(input_shape) < free_memory:
to_batch = batch_amount to_batch = batch_amount
break break


@@ -17,13 +17,11 @@ from . import diffusers_convert
from . import model_detection from . import model_detection
from . import sd1_clip from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip from . import sdxl_clip
import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5 import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5 import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@@ -387,8 +385,6 @@ class CLIPType(Enum):
STABLE_CASCADE = 2 STABLE_CASCADE = 2
SD3 = 3 SD3 = 3
STABLE_AUDIO = 4 STABLE_AUDIO = 4
HUNYUAN_DIT = 5
FLUX = 6
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION): def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = [] clip_data = []
@@ -416,8 +412,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = sdxl_clip.SDXLRefinerClipModel clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]: elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = comfy.text_encoders.sd2_clip.SD2ClipModel clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = comfy.text_encoders.sd2_clip.SD2Tokenizer clip_target.tokenizer = sd2_clip.SD2Tokenizer
elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]: elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in clip_data[0]:
weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"] weight = clip_data[0]["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
dtype_t5 = weight.dtype dtype_t5 = weight.dtype
@@ -437,18 +433,6 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
if clip_type == CLIPType.SD3: if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False) clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.HUNYUAN_DIT:
clip_target.clip = comfy.text_encoders.hydit.HyditModel
clip_target.tokenizer = comfy.text_encoders.hydit.HyditTokenizer
elif clip_type == CLIPType.FLUX:
weight_name = "encoder.block.23.layer.1.DenseReluDense.wi_1.weight"
weight = clip_data[0].get(weight_name, clip_data[1].get(weight_name, None))
dtype_t5 = None
if weight is not None:
dtype_t5 = weight.dtype
clip_target.clip = comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5)
clip_target.tokenizer = comfy.text_encoders.flux.FluxTokenizer
else: else:
clip_target.clip = sdxl_clip.SDXLClipModel clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -510,18 +494,13 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix) parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device() load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix) model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
if model_config is None: if model_config is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
unet_weight_dtype = list(model_config.supported_inference_dtypes) unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
if weight_dtype is not None:
unet_weight_dtype.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@@ -564,7 +543,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
logging.debug("left over keys: {}".format(left_over)) logging.debug("left over keys: {}".format(left_over))
if output_model: if output_model:
model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device()) model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device(), current_device=inital_load_device)
if inital_load_device != torch.device("cpu"): if inital_load_device != torch.device("cpu"):
logging.info("loaded straight to GPU") logging.info("loaded straight to GPU")
model_management.load_model_gpu(model_patcher) model_management.load_model_gpu(model_patcher)
@@ -572,7 +551,7 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
return (model_patcher, clip, vae, clipvision) return (model_patcher, clip, vae, clipvision)
def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format def load_unet_state_dict(sd): #load unet in diffusers or regular format
#Allow loading unets from checkpoint files #Allow loading unets from checkpoint files
diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd) diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -581,6 +560,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
sd = temp_sd sd = temp_sd
parameters = comfy.utils.calculate_parameters(sd) parameters = comfy.utils.calculate_parameters(sd)
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device() load_device = model_management.get_torch_device()
model_config = model_detection.model_config_from_unet(sd, "") model_config = model_detection.model_config_from_unet(sd, "")
@@ -607,11 +587,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.warning("{} {}".format(diffusers_keys[k], k)) logging.warning("{} {}".format(diffusers_keys[k], k))
offload_device = model_management.unet_offload_device() offload_device = model_management.unet_offload_device()
if dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=model_config.supported_inference_dtypes)
else:
unet_dtype = dtype
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
@@ -622,9 +598,9 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
logging.info("left over keys in unet: {}".format(left_over)) logging.info("left over keys in unet: {}".format(left_over))
return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
def load_unet(unet_path, dtype=None): def load_unet(unet_path):
sd = comfy.utils.load_torch_file(unet_path) sd = comfy.utils.load_torch_file(unet_path)
model = load_unet_state_dict(sd, dtype=dtype) model = load_unet_state_dict(sd)
if model is None: if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))

View File

@@ -94,8 +94,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f: with open(textmodel_json_config) as f:
config = json.load(f) config = json.load(f)
self.operations = comfy.ops.manual_cast self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
self.transformer = model_class(config, dtype, device, self.operations)
self.num_layers = self.transformer.num_layers self.num_layers = self.transformer.num_layers
self.max_length = max_length self.max_length = max_length
@@ -141,13 +140,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_up_textual_embeddings(self, tokens, current_embeds): def set_up_textual_embeddings(self, tokens, current_embeds):
out_tokens = [] out_tokens = []
next_new_token = token_dict_size = current_embeds.weight.shape[0] next_new_token = token_dict_size = current_embeds.weight.shape[0] - 1
embedding_weights = [] embedding_weights = []
for x in tokens: for x in tokens:
tokens_temp = [] tokens_temp = []
for y in x: for y in x:
if isinstance(y, numbers.Integral): if isinstance(y, numbers.Integral):
if y == token_dict_size: #EOS token
y = -1
tokens_temp += [int(y)] tokens_temp += [int(y)]
else: else:
if y.shape[0] == current_embeds.weight.shape[1]: if y.shape[0] == current_embeds.weight.shape[1]:
@@ -162,11 +163,12 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
n = token_dict_size n = token_dict_size
if len(embedding_weights) > 0: if len(embedding_weights) > 0:
new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype) new_embedding = torch.nn.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
new_embedding.weight[:token_dict_size] = current_embeds.weight new_embedding.weight[:token_dict_size] = current_embeds.weight[:-1]
for x in embedding_weights: for x in embedding_weights:
new_embedding.weight[n] = x new_embedding.weight[n] = x
n += 1 n += 1
new_embedding.weight[n] = current_embeds.weight[-1] #EOS embedding
self.transformer.set_input_embeddings(new_embedding) self.transformer.set_input_embeddings(new_embedding)
processed_tokens = [] processed_tokens = []
@@ -195,7 +197,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if self.enable_attention_masks: if self.enable_attention_masks:
attention_mask_model = attention_mask attention_mask_model = attention_mask
outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32) outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
self.transformer.set_input_embeddings(backup_embeds) self.transformer.set_input_embeddings(backup_embeds)
if self.layer == "last": if self.layer == "last":
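
Both sides of the sd1_clip.py hunks above extend the CLIP token embedding table with textual-inversion vectors; they differ only in how the EOS row stays addressable while the extra rows are appended (the left maps the EOS id to -1, the right re-appends the EOS embedding after the new rows). A minimal standalone sketch of the shared pattern, assuming a plain torch.nn.Embedding and illustrative names:

import torch

def extend_embedding_table(current: torch.nn.Embedding, extra: list) -> torch.nn.Embedding:
    # Append custom (textual inversion) vectors after the original vocabulary rows.
    vocab, dim = current.weight.shape
    new = torch.nn.Embedding(vocab + len(extra), dim,
                             device=current.weight.device, dtype=current.weight.dtype)
    with torch.no_grad():
        new.weight[:vocab] = current.weight
        for i, vec in enumerate(extra):
            new.weight[vocab + i] = vec
    return new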

View File

@@ -6,7 +6,7 @@
"attention_dropout": 0.0, "attention_dropout": 0.0,
"bos_token_id": 0, "bos_token_id": 0,
"dropout": 0.0, "dropout": 0.0,
"eos_token_id": 49407, "eos_token_id": 2,
"hidden_act": "quick_gelu", "hidden_act": "quick_gelu",
"hidden_size": 768, "hidden_size": 768,
"initializer_factor": 1.0, "initializer_factor": 1.0,

View File

@@ -5,7 +5,7 @@
"attention_dropout": 0.0, "attention_dropout": 0.0,
"bos_token_id": 0, "bos_token_id": 0,
"dropout": 0.0, "dropout": 0.0,
"eos_token_id": 49407, "eos_token_id": 2,
"hidden_act": "gelu", "hidden_act": "gelu",
"hidden_size": 1024, "hidden_size": 1024,
"initializer_factor": 1.0, "initializer_factor": 1.0,

View File

@@ -3,13 +3,11 @@ from . import model_base
from . import utils from . import utils
from . import sd1_clip from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip from . import sdxl_clip
import comfy.text_encoders.sd2_clip
import comfy.text_encoders.sd3_clip import comfy.text_encoders.sd3_clip
import comfy.text_encoders.sa_t5 import comfy.text_encoders.sa_t5
import comfy.text_encoders.aura_t5 import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
from . import supported_models_base from . import supported_models_base
from . import latent_formats from . import latent_formats
@@ -31,7 +29,6 @@ class SD15(supported_models_base.BASE):
} }
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15
memory_usage_factor = 1.0
def process_clip_state_dict(self, state_dict): def process_clip_state_dict(self, state_dict):
k = list(state_dict.keys()) k = list(state_dict.keys())
@@ -78,7 +75,6 @@ class SD20(supported_models_base.BASE):
} }
latent_format = latent_formats.SD15 latent_format = latent_formats.SD15
memory_usage_factor = 1.0
def model_type(self, state_dict, prefix=""): def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
@@ -104,7 +100,7 @@ class SD20(supported_models_base.BASE):
return state_dict return state_dict
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.sd2_clip.SD2Tokenizer, comfy.text_encoders.sd2_clip.SD2ClipModel) return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
class SD21UnclipL(SD20): class SD21UnclipL(SD20):
unet_config = { unet_config = {
@@ -142,7 +138,6 @@ class SDXLRefiner(supported_models_base.BASE):
} }
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
memory_usage_factor = 1.0
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXLRefiner(self, device=device) return model_base.SDXLRefiner(self, device=device)
@@ -181,8 +176,6 @@ class SDXL(supported_models_base.BASE):
latent_format = latent_formats.SDXL latent_format = latent_formats.SDXL
memory_usage_factor = 0.7
def model_type(self, state_dict, prefix=""): def model_type(self, state_dict, prefix=""):
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5 if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5() self.latent_format = latent_formats.SDXL_Playground_2_5()
@@ -510,9 +503,6 @@ class SD3(supported_models_base.BASE):
unet_extra_config = {} unet_extra_config = {}
latent_format = latent_formats.SD3 latent_format = latent_formats.SD3
memory_usage_factor = 1.2
text_encoder_key_prefix = ["text_encoders."] text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None): def get_model(self, state_dict, prefix="", device=None):
@@ -590,90 +580,6 @@ class AuraFlow(supported_models_base.BASE):
def clip_target(self, state_dict={}): def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model) return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)
class HunyuanDiT(supported_models_base.BASE): models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow]
unet_config = {
"image_model": "hydit",
}
unet_extra_config = {
"attn_precision": torch.float32,
}
sampling_settings = {
"linear_start": 0.00085,
"linear_end": 0.018,
}
latent_format = latent_formats.SDXL
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.HunyuanDiT(self, device=device)
return out
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.hydit.HyditTokenizer, comfy.text_encoders.hydit.HyditModel)
class HunyuanDiT1(HunyuanDiT):
unet_config = {
"image_model": "hydit1",
}
unet_extra_config = {}
sampling_settings = {
"linear_start" : 0.00085,
"linear_end" : 0.03,
}
class Flux(supported_models_base.BASE):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
}
sampling_settings = {
}
unet_extra_config = {}
latent_format = latent_formats.Flux
memory_usage_factor = 2.8
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
if t5_key in state_dict:
dtype_t5 = state_dict[t5_key].dtype
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
class FluxSchnell(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": False,
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
return out
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell]
models += [SVD_img2vid] models += [SVD_img2vid]

View File

@@ -27,8 +27,6 @@ class BASE:
text_encoder_key_prefix = ["cond_stage_model."] text_encoder_key_prefix = ["cond_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
memory_usage_factor = 2.0
manual_cast_dtype = None manual_cast_dtype = None
@classmethod @classmethod

View File

@@ -1,140 +0,0 @@
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class BertAttention(torch.nn.Module):
def __init__(self, embed_dim, heads, dtype, device, operations):
super().__init__()
self.heads = heads
self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
def forward(self, x, mask=None, optimized_attention=None):
q = self.query(x)
k = self.key(x)
v = self.value(x)
out = optimized_attention(q, k, v, self.heads, mask)
return out
class BertOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
# self.dropout = nn.Dropout(0.0)
def forward(self, x, y):
x = self.dense(x)
# hidden_states = self.dropout(hidden_states)
x = self.LayerNorm(x + y)
return x
class BertAttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.self = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
y = self.self(x, mask, optimized_attention)
return self.output(y, x)
class BertIntermediate(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)
def forward(self, x):
x = self.dense(x)
return torch.nn.functional.gelu(x)
class BertBlock(torch.nn.Module):
def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
x = self.attention(x, mask, optimized_attention)
y = self.intermediate(x)
return self.output(y, x)
class BertEncoder(torch.nn.Module):
def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)])
def forward(self, x, mask=None, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, mask, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class BertEmbeddings(torch.nn.Module):
def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.word_embeddings = operations.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
self.position_embeddings = operations.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
self.token_type_embeddings = operations.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, input_tokens, token_type_ids=None, dtype=None):
x = self.word_embeddings(input_tokens, out_dtype=dtype)
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
if token_type_ids is not None:
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
else:
x += comfy.ops.cast_to_input(self.token_type_embeddings.weight[0], x)
x = self.LayerNorm(x)
return x
class BertModel_(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
embed_dim = config_dict["hidden_size"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
x = self.embeddings(input_tokens, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
x, i = self.encoder(x, mask, intermediate_output)
return x, i
class BertModel(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.bert = BertModel_(config_dict, dtype, device, operations)
self.num_layers = config_dict["num_hidden_layers"]
def get_input_embeddings(self):
return self.bert.embeddings.word_embeddings
def set_input_embeddings(self, embeddings):
self.bert.embeddings.word_embeddings = embeddings
def forward(self, *args, **kwargs):
return self.bert(*args, **kwargs)

View File

@@ -1,71 +0,0 @@
from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.model_management
from transformers import T5TokenizerFast
import torch
import os
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.clip_l.untokenize(token_weight_pair)
def state_dict(self):
return {}
class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
self.t5xxl.set_clip_options(options)
def reset_clip_options(self):
self.clip_l.reset_clip_options()
self.t5xxl.reset_clip_options()
def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return t5_out, l_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
return self.clip_l.load_sd(sd)
else:
return self.t5xxl.load_sd(sd)
def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
return FluxClipModel_
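
The flux_clip factory at the end of the deleted module above bakes dtype_t5 into a subclass so the rest of the loader can construct the text encoder with the usual (device, dtype) signature. The same pattern in isolation, with illustrative names:

def bake_kwargs(base_cls, **baked):
    # Return a subclass whose constructor always injects the baked keyword arguments.
    class Configured(base_cls):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **{**baked, **kwargs})
    return Configured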

View File

@@ -1,79 +0,0 @@
from comfy import sd1_clip
from transformers import BertTokenizer
from .spiece_tokenizer import SPieceTokenizer
from .bert import BertModel
import comfy.text_encoders.t5
import os
import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class HyditTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
def tokenize_with_weights(self, text:str, return_word_ids=False):
out = {}
out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
return out
def untokenize(self, token_weight_pair):
return self.hydit_clip.untokenize(token_weight_pair)
def state_dict(self):
return {"mt5xl.spiece_model": self.mt5xl.state_dict()["spiece_model"]}
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
super().__init__()
self.hydit_clip = HyditBertModel(dtype=dtype)
self.mt5xl = MT5XLModel(dtype=dtype)
self.dtypes = set()
if dtype is not None:
self.dtypes.add(dtype)
def encode_token_weights(self, token_weight_pairs):
hydit_out = self.hydit_clip.encode_token_weights(token_weight_pairs["hydit_clip"])
mt5_out = self.mt5xl.encode_token_weights(token_weight_pairs["mt5xl"])
return hydit_out[0], hydit_out[1], {"attention_mask": hydit_out[2]["attention_mask"], "conditioning_mt5xl": mt5_out[0], "attention_mask_mt5xl": mt5_out[2]["attention_mask"]}
def load_sd(self, sd):
if "bert.encoder.layer.0.attention.self.query.weight" in sd:
return self.hydit_clip.load_sd(sd)
else:
return self.mt5xl.load_sd(sd)
def set_clip_options(self, options):
self.hydit_clip.set_clip_options(options)
self.mt5xl.set_clip_options(options)
def reset_clip_options(self):
self.hydit_clip.reset_clip_options()
self.mt5xl.reset_clip_options()

View File

@@ -1,35 +0,0 @@
{
"_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
"architectures": [
"BertModel"
],
"attention_probs_dropout_prob": 0.1,
"bos_token_id": 0,
"classifier_dropout": null,
"directionality": "bidi",
"eos_token_id": 2,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 1024,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 16,
"num_hidden_layers": 24,
"output_past": true,
"pad_token_id": 0,
"pooler_fc_size": 768,
"pooler_num_attention_heads": 12,
"pooler_num_fc_layers": 3,
"pooler_size_per_head": 128,
"pooler_type": "first_token_transform",
"position_embedding_type": "absolute",
"torch_dtype": "float32",
"transformers_version": "4.22.1",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 47020
}

View File

@@ -1,7 +0,0 @@
{
"cls_token": "[CLS]",
"mask_token": "[MASK]",
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"unk_token": "[UNK]"
}

View File

@@ -1,16 +0,0 @@
{
"cls_token": "[CLS]",
"do_basic_tokenize": true,
"do_lower_case": true,
"mask_token": "[MASK]",
"name_or_path": "hfl/chinese-roberta-wwm-ext",
"never_split": null,
"pad_token": "[PAD]",
"sep_token": "[SEP]",
"special_tokens_map_file": "/home/chenweifeng/.cache/huggingface/hub/models--hfl--chinese-roberta-wwm-ext/snapshots/5c58d0b8ec1d9014354d691c538661bf00bfdb44/special_tokens_map.json",
"strip_accents": null,
"tokenize_chinese_chars": true,
"tokenizer_class": "BertTokenizer",
"unk_token": "[UNK]",
"model_max_length": 77
}

File diff suppressed because it is too large

View File

@@ -1,22 +0,0 @@
{
"d_ff": 5120,
"d_kv": 64,
"d_model": 2048,
"decoder_start_token_id": 0,
"dropout_rate": 0.1,
"eos_token_id": 1,
"dense_act_fn": "gelu_pytorch_tanh",
"initializer_factor": 1.0,
"is_encoder_decoder": true,
"is_gated_act": true,
"layer_norm_epsilon": 1e-06,
"model_type": "mt5",
"num_decoder_layers": 24,
"num_heads": 32,
"num_layers": 24,
"output_past": true,
"pad_token_id": 0,
"relative_attention_num_buckets": 32,
"tie_word_embeddings": false,
"vocab_size": 250112
}

View File

@@ -54,7 +54,14 @@ class SD3ClipModel(torch.nn.Module):
self.clip_g = None self.clip_g = None
if t5: if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device) if dtype_t5 is None:
dtype_t5 = dtype
elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
dtype_t5 = dtype
if not comfy.model_management.supports_cast(device, dtype_t5):
dtype_t5 = dtype
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5) self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.dtypes.add(dtype_t5) self.dtypes.add(dtype_t5)
else: else:
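
In the first sd3_clip.py hunk above, the left-hand side folds the explicit dtype_t5 fallback chain of the right-hand side into comfy.model_management.pick_weight_dtype. A standalone restatement of that right-hand-side logic (dtype_size and supports_cast are replaced by a local helper and a caller-supplied callable, so this is a sketch rather than the actual helper):

import torch

def pick_t5_dtype(dtype_t5, dtype, device_supports_cast):
    def nbytes(dt: torch.dtype) -> int:
        info = torch.finfo(dt) if dt.is_floating_point else torch.iinfo(dt)
        return info.bits // 8
    if dtype_t5 is None or nbytes(dtype_t5) > nbytes(dtype):
        dtype_t5 = dtype      # never keep T5 weights in a wider dtype than the main one
    if not device_supports_cast(dtype_t5):
        dtype_t5 = dtype      # fall back when the device cannot cast from dtype_t5
    return dtype_t5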
@@ -81,7 +88,7 @@ class SD3ClipModel(torch.nn.Module):
def encode_token_weights(self, token_weight_pairs): def encode_token_weights(self, token_weight_pairs):
token_weight_pairs_l = token_weight_pairs["l"] token_weight_pairs_l = token_weight_pairs["l"]
token_weight_pairs_g = token_weight_pairs["g"] token_weight_pairs_g = token_weight_pairs["g"]
token_weight_pairs_t5 = token_weight_pairs["t5xxl"] token_weight_pars_t5 = token_weight_pairs["t5xxl"]
lg_out = None lg_out = None
pooled = None pooled = None
out = None out = None
@@ -108,7 +115,7 @@ class SD3ClipModel(torch.nn.Module):
pooled = torch.cat((l_pooled, g_pooled), dim=-1) pooled = torch.cat((l_pooled, g_pooled), dim=-1)
if self.t5xxl is not None: if self.t5xxl is not None:
t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pairs_t5) t5_out, t5_pooled = self.t5xxl.encode_token_weights(token_weight_pars_t5)
if lg_out is not None: if lg_out is not None:
out = torch.cat([lg_out, t5_out], dim=-2) out = torch.cat([lg_out, t5_out], dim=-2)
else: else:

View File

@@ -27,6 +27,3 @@ class SPieceTokenizer:
def __call__(self, string): def __call__(self, string):
out = self.tokenizer.encode(string) out = self.tokenizer.encode(string)
return {"input_ids": out} return {"input_ids": out}
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))

View File

@@ -1,7 +1,6 @@
import torch import torch
import math import math
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.ops
class T5LayerNorm(torch.nn.Module): class T5LayerNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None): def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
@@ -12,7 +11,7 @@ class T5LayerNorm(torch.nn.Module):
def forward(self, x): def forward(self, x):
variance = x.pow(2).mean(-1, keepdim=True) variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon) x = x * torch.rsqrt(variance + self.variance_epsilon)
return comfy.ops.cast_to_input(self.weight, x) * x return self.weight.to(device=x.device, dtype=x.dtype) * x
activations = { activations = {
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"), "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
@@ -83,7 +82,7 @@ class T5Attention(torch.nn.Module):
if relative_attention_bias: if relative_attention_bias:
self.relative_attention_num_buckets = 32 self.relative_attention_num_buckets = 32
self.relative_attention_max_distance = 128 self.relative_attention_max_distance = 128
self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype) self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
@staticmethod @staticmethod
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -133,7 +132,7 @@ class T5Attention(torch.nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device, dtype): def compute_bias(self, query_length, key_length, device):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
@@ -144,7 +143,7 @@ class T5Attention(torch.nn.Module):
num_buckets=self.relative_attention_num_buckets, num_buckets=self.relative_attention_num_buckets,
max_distance=self.relative_attention_max_distance, max_distance=self.relative_attention_max_distance,
) )
values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype) # shape (query_length, key_length, num_heads) values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values return values
@@ -153,7 +152,7 @@ class T5Attention(torch.nn.Module):
k = self.k(x) k = self.k(x)
v = self.v(x) v = self.v(x)
if self.relative_attention_bias is not None: if self.relative_attention_bias is not None:
past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype) past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
if past_bias is not None: if past_bias is not None:
if mask is not None: if mask is not None:
@@ -200,7 +199,7 @@ class T5Stack(torch.nn.Module):
self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations) self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
# self.dropout = nn.Dropout(config.dropout_rate) # self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None): def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
mask = None mask = None
if attention_mask is not None: if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]) mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
@@ -226,7 +225,7 @@ class T5(torch.nn.Module):
self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations) self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
self.dtype = dtype self.dtype = dtype
self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype) self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
def get_input_embeddings(self): def get_input_embeddings(self):
return self.shared return self.shared
@@ -235,7 +234,5 @@ class T5(torch.nn.Module):
self.shared = embeddings self.shared = embeddings
def forward(self, input_ids, *args, **kwargs): def forward(self, input_ids, *args, **kwargs):
x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32)) x = self.shared(input_ids)
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
x = torch.nan_to_num(x) #Fix for fp8 T5 base
return self.encoder(x, *args, **kwargs) return self.encoder(x, *args, **kwargs)
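
In the T5LayerNorm hunk above, comfy.ops.cast_to_input replaces an explicit weight.to(device=..., dtype=...) call. Judging from the two sides of that hunk, a helper of roughly this shape would behave the same way; this is a sketch, not the actual comfy.ops implementation:

import torch

def cast_to_input(weight: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Match the stored weight to the activation tensor's device and dtype.
    return weight.to(device=x.device, dtype=x.dtype)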

View File

@@ -11,7 +11,7 @@ import itertools
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
device = torch.device("cpu") device = torch.device("cpu")
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): if ckpt.lower().endswith(".safetensors"):
sd = safetensors.torch.load_file(ckpt, device=device.type) sd = safetensors.torch.load_file(ckpt, device=device.type)
else: else:
if safe_load: if safe_load:
@@ -40,22 +40,9 @@ def calculate_parameters(sd, prefix=""):
params = 0 params = 0
for k in sd.keys(): for k in sd.keys():
if k.startswith(prefix): if k.startswith(prefix):
w = sd[k] params += sd[k].nelement()
params += w.nelement()
return params return params
def weight_dtype(sd, prefix=""):
dtypes = {}
for k in sd.keys():
if k.startswith(prefix):
w = sd[k]
dtypes[w.dtype] = dtypes.get(w.dtype, 0) + 1
if len(dtypes) == 0:
return None
return max(dtypes, key=dtypes.get)
def state_dict_key_replace(state_dict, keys_to_replace): def state_dict_key_replace(state_dict, keys_to_replace):
for x in keys_to_replace: for x in keys_to_replace:
if x in state_dict: if x in state_dict:
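
The weight_dtype helper deleted in this hunk returns the most common parameter dtype under a prefix, so the checkpoint's own storage dtype can be offered as a candidate inference dtype (see the load_checkpoint_guess_config hunk earlier). A small illustration with a hypothetical state dict:

import torch

sd = {
    "model.diffusion_model.a.weight": torch.zeros(4, dtype=torch.float16),
    "model.diffusion_model.b.weight": torch.zeros(4, dtype=torch.float16),
    "model.diffusion_model.c.bias": torch.zeros(4, dtype=torch.float32),
}
# weight_dtype(sd, "model.diffusion_model.") would return torch.float16 (2 of 3 tensors).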
@@ -415,59 +402,6 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
return key_map return key_map
def flux_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("depth", 0)
n_single_layers = mmdit_config.get("depth_single_blocks", 0)
hidden_size = mmdit_config.get("hidden_size", 0)
key_map = {}
for index in range(n_double_layers):
prefix_from = "transformer_blocks.{}".format(index)
prefix_to = "{}double_blocks.{}".format(output_prefix, index)
for end in ("weight", "bias"):
k = "{}.attn.".format(prefix_from)
qkv = "{}.img_attn.qkv.{}".format(prefix_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
}
for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
for index in range(n_single_layers):
prefix_from = "single_transformer_blocks.{}".format(index)
prefix_to = "{}single_blocks.{}".format(output_prefix, index)
for end in ("weight", "bias"):
k = "{}.attn.".format(prefix_from)
qkv = "{}.linear1.{}".format(prefix_to, end)
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
block_map = {#TODO
}
for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
MAP_BASIC = { #TODO
}
for k in MAP_BASIC:
if len(k) > 2:
key_map[k[1]] = ("{}{}".format(output_prefix, k[0]), None, k[2])
else:
key_map[k[1]] = "{}{}".format(output_prefix, k[0])
return key_map
def repeat_to_batch_size(tensor, batch_size, dim=0): def repeat_to_batch_size(tensor, batch_size, dim=0):
if tensor.shape[dim] > batch_size: if tensor.shape[dim] > batch_size:
return tensor.narrow(dim, 0, batch_size) return tensor.narrow(dim, 0, batch_size)
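
In the flux_to_diffusers map deleted above, a value of the form (dst_key, (dim, offset, size)) marks a source tensor that must be copied into a slice of a fused weight (the qkv or linear1 projections) instead of simply being renamed. A hedged sketch of how such an entry could be consumed; the function name and the two-form assumption are illustrative, not ComfyUI's loader:

import torch

def apply_key_mapping(dst_sd: dict, mapping, tensor: torch.Tensor):
    # Two forms appear in the map above: a plain destination key (rename),
    # or (dst_key, (dim, offset, size)) for a slice copy into a fused weight.
    if isinstance(mapping, str):
        dst_sd[mapping] = tensor
    else:
        dst_key, (dim, offset, size) = mapping
        dst_sd[dst_key].narrow(dim, offset, size).copy_(tensor)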

View File

@@ -7,7 +7,6 @@ import io
import json import json
import struct import struct
import random import random
import hashlib
from comfy.cli_args import args from comfy.cli_args import args
class EmptyLatentAudio: class EmptyLatentAudio:

View File

@@ -295,23 +295,6 @@ class SamplerDPMPP_SDE:
sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
return (sampler, ) return (sampler, )
class SamplerDPMPP_2S_Ancestral:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
"s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}),
}
}
RETURN_TYPES = ("SAMPLER",)
CATEGORY = "sampling/custom_sampling/samplers"
FUNCTION = "get_sampler"
def get_sampler(self, eta, s_noise):
sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise})
return (sampler, )
class SamplerEulerAncestral: class SamplerEulerAncestral:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@@ -683,7 +666,6 @@ NODE_CLASS_MAPPINGS = {
"SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE,
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
"SamplerDPMPP_SDE": SamplerDPMPP_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
"SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral,
"SamplerDPMAdaptative": SamplerDPMAdaptative, "SamplerDPMAdaptative": SamplerDPMAdaptative,
"SplitSigmas": SplitSigmas, "SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise, "SplitSigmasDenoise": SplitSigmasDenoise,
@@ -700,4 +682,4 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
"SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++", "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++",
} }

View File

@@ -1,47 +0,0 @@
import node_helpers
class CLIPTextEncodeFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning/flux"
def encode(self, clip, clip_l, t5xxl, guidance):
tokens = clip.tokenize(clip_l)
tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
output["guidance"] = guidance
return ([[cond, output]], )
class FluxGuidance:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"conditioning": ("CONDITIONING", ),
"guidance": ("FLOAT", {"default": 3.5, "min": 0.0, "max": 100.0, "step": 0.1}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "append"
CATEGORY = "advanced/conditioning/flux"
def append(self, conditioning, guidance):
c = node_helpers.conditioning_set_values(conditioning, {"guidance": guidance})
return (c, )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
}

View File

@@ -1,24 +0,0 @@
class CLIPTextEncodeHunyuanDiT:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"clip": ("CLIP", ),
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
"mt5xl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"
CATEGORY = "advanced/conditioning"
def encode(self, clip, bert, mt5xl):
tokens = clip.tokenize(bert)
tokens["mt5xl"] = clip.tokenize(mt5xl)["mt5xl"]
output = clip.encode_from_tokens(tokens, return_pooled=True, return_dict=True)
cond = output.pop("cond")
return ([[cond, output]], )
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
}

View File

@@ -2,7 +2,6 @@ import folder_paths
import comfy.sd import comfy.sd
import comfy.model_sampling import comfy.model_sampling
import comfy.latent_formats import comfy.latent_formats
import nodes
import torch import torch
class LCM(comfy.model_sampling.EPS): class LCM(comfy.model_sampling.EPS):
@@ -171,42 +170,6 @@ class ModelSamplingAuraFlow(ModelSamplingSD3):
def patch_aura(self, model, shift): def patch_aura(self, model, shift):
return self.patch(model, shift, multiplier=1.0) return self.patch(model, shift, multiplier=1.0)
class ModelSamplingFlux:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 1.15, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01}),
"width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, width, height):
m = model.clone()
x1 = 256
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
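
For reference, the shift computed by the deleted ModelSamplingFlux.patch above is linear in the number of latent tokens (width * height / 256), pinned to base_shift at 256 tokens and max_shift at 4096 tokens. With the default values and a 1024x1024 image:

max_shift, base_shift = 1.15, 0.5
width, height = 1024, 1024
mm = (max_shift - base_shift) / (4096 - 256)
b = base_shift - mm * 256
shift = (width * height / (8 * 8 * 2 * 2)) * mm + b   # 4096 latent tokens
print(round(shift, 4))  # 1.15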
class ModelSamplingContinuousEDM: class ModelSamplingContinuousEDM:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@@ -321,6 +284,5 @@ NODE_CLASS_MAPPINGS = {
"ModelSamplingStableCascade": ModelSamplingStableCascade, "ModelSamplingStableCascade": ModelSamplingStableCascade,
"ModelSamplingSD3": ModelSamplingSD3, "ModelSamplingSD3": ModelSamplingSD3,
"ModelSamplingAuraFlow": ModelSamplingAuraFlow, "ModelSamplingAuraFlow": ModelSamplingAuraFlow,
"ModelSamplingFlux": ModelSamplingFlux,
"RescaleCFG": RescaleCFG, "RescaleCFG": RescaleCFG,
} }

View File

@@ -264,7 +264,6 @@ class CLIPSave:
metadata = {} metadata = {}
if not args.disable_metadata: if not args.disable_metadata:
metadata["format"] = "pt"
metadata["prompt"] = prompt_info metadata["prompt"] = prompt_info
if extra_pnginfo is not None: if extra_pnginfo is not None:
for x in extra_pnginfo: for x in extra_pnginfo:

View File

@@ -75,36 +75,9 @@ class ModelMergeSD3_2B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict} return {"required": arg_dict}
class ModelMergeFlux1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["img_in."] = argument
arg_dict["time_in."] = argument
arg_dict["guidance_in"] = argument
arg_dict["vector_in."] = argument
arg_dict["txt_in."] = argument
for i in range(19):
arg_dict["double_blocks.{}.".format(i)] = argument
for i in range(38):
arg_dict["single_blocks.{}.".format(i)] = argument
arg_dict["final_layer."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1, "ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
"ModelMergeSDXL": ModelMergeSDXL, "ModelMergeSDXL": ModelMergeSDXL,
"ModelMergeSD3_2B": ModelMergeSD3_2B, "ModelMergeSD3_2B": ModelMergeSD3_2B,
"ModelMergeFlux1": ModelMergeFlux1,
} }

View File

@@ -12,7 +12,7 @@ class PerturbedAttentionGuidance:
return { return {
"required": { "required": {
"model": ("MODEL",), "model": ("MODEL",),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": 0.01}), "scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
} }
} }

View File

@@ -96,7 +96,7 @@ class SelfAttentionGuidance:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",), return {"required": { "model": ("MODEL",),
"scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.01}), "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
"blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}), "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)

View File

@@ -27,8 +27,8 @@ class EmptySD3LatentImage:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), return {"required": { "width": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}), "height": ("INT", {"default": 1024, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}} "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",) RETURN_TYPES = ("LATENT",)
FUNCTION = "generate" FUNCTION = "generate"

View File

@@ -188,11 +188,6 @@ def recursive_execute(server, prompt, outputs, current_item, extra_data, execute
"current_inputs": input_data_formatted, "current_inputs": input_data_formatted,
"current_outputs": output_data_formatted "current_outputs": output_data_formatted
} }
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
return (False, error_details, ex) return (False, error_details, ex)
executed.add(unique_id) executed.add(unique_id)
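
The left-hand side above adds an out-of-memory recovery path to recursive_execute: if a node fails with an OOM error, every loaded model is unloaded before the failure is reported. The general shape of that pattern, using torch.cuda.OutOfMemoryError as a stand-in for comfy.model_management.OOM_EXCEPTION and an injected unload callback:

import torch

def run_with_oom_cleanup(fn, unload_all_models):
    try:
        return True, fn(), None
    except torch.cuda.OutOfMemoryError as ex:
        # Free every loaded model so the next prompt does not start with exhausted VRAM.
        unload_all_models()
        return False, None, ex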

View File

@@ -3,7 +3,7 @@ import time
import logging import logging
from typing import Set, List, Dict, Tuple from typing import Set, List, Dict, Tuple
supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl', '.sft']) supported_pt_extensions: Set[str] = set(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
SupportedFileExtensionsType = Set[str] SupportedFileExtensionsType = Set[str]
ScanPathType = List[str] ScanPathType = List[str]

View File

@@ -25,7 +25,7 @@ def pillow(fn, arg):
finally: finally:
if prev_value is not None: if prev_value is not None:
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
return x return x
def hasher(): def hasher():
hashfuncs = { hashfuncs = {

View File

@@ -818,22 +818,15 @@ class UNETLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ), return {"required": { "unet_name": (folder_paths.get_filename_list("unet"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet" FUNCTION = "load_unet"
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_unet(self, unet_name, weight_dtype): def load_unet(self, unet_name):
dtype = None
if weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
elif weight_dtype == "fp8_e5m2":
dtype = torch.float8_e5m2
unet_path = folder_paths.get_full_path("unet", unet_name) unet_path = folder_paths.get_full_path("unet", unet_name)
model = comfy.sd.load_unet(unet_path, dtype=dtype) model = comfy.sd.load_unet(unet_path)
return (model,) return (model,)
class CLIPLoader: class CLIPLoader:
@@ -866,7 +859,7 @@ class DualCLIPLoader:
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ), return {"required": { "clip_name1": (folder_paths.get_filename_list("clip"), ),
"clip_name2": (folder_paths.get_filename_list("clip"), ), "clip_name2": (folder_paths.get_filename_list("clip"), ),
"type": (["sdxl", "sd3", "flux"], ), "type": (["sdxl", "sd3"], ),
}} }}
RETURN_TYPES = ("CLIP",) RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip" FUNCTION = "load_clip"
@@ -880,8 +873,6 @@ class DualCLIPLoader:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
elif type == "sd3": elif type == "sd3":
clip_type = comfy.sd.CLIPType.SD3 clip_type = comfy.sd.CLIPType.SD3
elif type == "flux":
clip_type = comfy.sd.CLIPType.FLUX
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)
@@ -1852,7 +1843,6 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"StyleModelLoader": "Load Style Model", "StyleModelLoader": "Load Style Model",
"CLIPVisionLoader": "Load CLIP Vision", "CLIPVisionLoader": "Load CLIP Vision",
"UpscaleModelLoader": "Load Upscale Model", "UpscaleModelLoader": "Load Upscale Model",
"UNETLoader": "Load Diffusion Model",
# Conditioning # Conditioning
"CLIPVisionEncode": "CLIP Vision Encode", "CLIPVisionEncode": "CLIP Vision Encode",
"StyleModelApply": "Apply Style Model", "StyleModelApply": "Apply Style Model",
@@ -2047,8 +2037,6 @@ def init_builtin_extra_nodes():
"nodes_sd3.py", "nodes_sd3.py",
"nodes_gits.py", "nodes_gits.py",
"nodes_controlnet.py", "nodes_controlnet.py",
"nodes_hunyuan.py",
"nodes_flux.py",
] ]
import_failed = [] import_failed = []

View File

@@ -1,7 +1,7 @@
torch torch==2.3.1
torchsde torchsde
torchvision torchvision==0.18.1
torchaudio torchaudio==2.3.1
einops einops
transformers>=4.28.1 transformers>=4.28.1
tokenizers>=0.13.3 tokenizers>=0.13.3
@@ -13,6 +13,7 @@ Pillow
scipy scipy
tqdm tqdm
psutil psutil
numpy<2.0.0
#non essential dependencies: #non essential dependencies:
kornia>=0.7.1 kornia>=0.7.1

View File

@@ -127,11 +127,7 @@ class PromptServer():
@routes.get("/") @routes.get("/")
async def get_root(request): async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html")) return web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-cache'
response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0"
return response
@routes.get("/embeddings") @routes.get("/embeddings")
def get_embeddings(self): def get_embeddings(self):

View File

@@ -105,16 +105,15 @@ export class ChangeTracker {
window.addEventListener( window.addEventListener(
"keydown", "keydown",
(e) => { (e) => {
const activeEl = document.activeElement;
requestAnimationFrame(async () => { requestAnimationFrame(async () => {
let bindInputEl; let activeEl;
// If we are auto queue in change mode then we do want to trigger on inputs // If we are auto queue in change mode then we do want to trigger on inputs
if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") { if (!app.ui.autoQueueEnabled || app.ui.autoQueueMode === "instant") {
activeEl = document.activeElement;
if (activeEl?.tagName === "INPUT" || activeEl?.["type"] === "textarea") { if (activeEl?.tagName === "INPUT" || activeEl?.["type"] === "textarea") {
// Ignore events on inputs, they have their native history // Ignore events on inputs, they have their native history
return; return;
} }
bindInputEl = activeEl;
} }
keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta"; keyIgnored = e.key === "Control" || e.key === "Shift" || e.key === "Alt" || e.key === "Meta";
@@ -124,7 +123,7 @@ export class ChangeTracker {
if (await changeTracker().undoRedo(e)) return; if (await changeTracker().undoRedo(e)) return;
// If our active element is some type of input then handle changes after they're done // If our active element is some type of input then handle changes after they're done
if (ChangeTracker.bindInput(bindInputEl)) return; if (ChangeTracker.bindInput(activeEl)) return;
changeTracker().checkState(); changeTracker().checkState();
}); });
}, },

View File

@@ -190,10 +190,9 @@ function parseVorbisComment(dataView) {
const comment = getString(dataView, offset, commentLength); const comment = getString(dataView, offset, commentLength);
offset += commentLength; offset += commentLength;
const ind = comment.indexOf('=') const [key, value] = comment.split('=');
const key = comment.substring(0, ind);
comments[key] = comment.substring(ind+1); comments[key] = value;
} }
return comments; return comments;
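
The left-hand change above splits each Vorbis comment only on the first '=', so values that themselves contain '=' survive intact, whereas split('=') truncated them. The same idea expressed in Python:

comment = "DESCRIPTION=gain=-14 LUFS"
key, _, value = comment.partition("=")
assert (key, value) == ("DESCRIPTION", "gain=-14 LUFS")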

View File

@@ -330,7 +330,6 @@
.comfyui-workflows-open .active { .comfyui-workflows-open .active {
font-weight: bold; font-weight: bold;
color: var(--primary-fg);
} }
.comfyui-workflows-favorites:empty { .comfyui-workflows-favorites:empty {
@@ -418,10 +417,6 @@
padding: 2px 4px; padding: 2px 4px;
} }
.comfyui-workflows-tree-file.active .comfyui-workflows-file-action {
color: var(--primary-fg);
}
.lg ~ .comfyui-workflows-popup .comfyui-workflows-tree-file:not(:hover) .comfyui-workflows-file-action { .lg ~ .comfyui-workflows-popup .comfyui-workflows-tree-file:not(:hover) .comfyui-workflows-file-action {
opacity: 0; opacity: 0;
} }