Compare commits
27 Commits
| SHA1 |
|---|
| ee9547ba31 |
| 19a64d6291 |
| b486885e08 |
| 0229228f3f |
| 1ed75ab30e |
| 99a1fb6027 |
| 73e04987f7 |
| 5388df784a |
| 26e0ba8f8c |
| bc6dac4327 |
| f18ebbd316 |
| 15564688ed |
| c6b9c11ef6 |
| e44d0ac7f7 |
| 56bc64f351 |
| f7d83b72e0 |
| 80f07952d2 |
| 57f330caf9 |
| 601ff9e3db |
| 341667c4d5 |
| 1419dee915 |
| da13b6b827 |
| c86cd58573 |
| b5fe39211a |
| e946667216 |
| d7969cb070 |
| bddb02660c |
```diff
@@ -28,7 +28,7 @@ def pull(repo, remote_name='origin', branch='master'):
         if repo.index.conflicts is not None:
             for conflict in repo.index.conflicts:
-                print('Conflicts found in:', conflict[0].path)
+                print('Conflicts found in:', conflict[0].path) # noqa: T201
             raise AssertionError('Conflicts, ahhhhh!!')

         user = repo.default_signature
@@ -49,18 +49,18 @@ repo_path = str(sys.argv[1])
 repo = pygit2.Repository(repo_path)
 ident = pygit2.Signature('comfyui', 'comfy@ui')
 try:
-    print("stashing current changes")
+    print("stashing current changes") # noqa: T201
     repo.stash(ident)
 except KeyError:
-    print("nothing to stash")
+    print("nothing to stash") # noqa: T201
 backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
-print("creating backup branch: {}".format(backup_branch_name))
+print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
 try:
     repo.branches.local.create(backup_branch_name, repo.head.peel())
 except:
     pass

-print("checking out master branch")
+print("checking out master branch") # noqa: T201
 branch = repo.lookup_branch('master')
 if branch is None:
     ref = repo.lookup_reference('refs/remotes/origin/master')
@@ -72,7 +72,7 @@ else:
     ref = repo.lookup_reference(branch.name)
 repo.checkout(ref)

-print("pulling latest changes")
+print("pulling latest changes") # noqa: T201
 pull(repo)

 if "--stable" in sys.argv:
@@ -94,7 +94,7 @@ if "--stable" in sys.argv:
     if latest_tag is not None:
         repo.checkout(latest_tag)

-print("Done!")
+print("Done!") # noqa: T201

 self_update = True
 if len(sys.argv) > 2:
```
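The `# noqa: T201` comments added above suppress the T201 ("print found") lint rule on lines where console output from the update script is intentional. A minimal illustration of the marker (the linter, for example flake8-print or Ruff, and its configuration are assumptions, not part of this changeset):

```python
# Illustrative only: assumes a linter that implements rule T201 is enabled.
print("pulling latest changes")                # would be flagged as T201
print("pulling latest changes")  # noqa: T201  # suppressed: output is intentional
```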
README.md (60 lines changed)

````diff
@@ -38,10 +38,21 @@ This ui will let you design and execute advanced stable diffusion pipelines usin

 ## Features
 - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
-- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
-- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
-- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
-- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
+- Image Models
+   - SD1.x, SD2.x,
+   - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
+   - [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
+   - [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
+   - Pixart Alpha and Sigma
+   - [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
+   - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
+   - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
+- Video Models
+   - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
+   - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
+   - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
+   - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
+- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
 - Asynchronous Queue system
 - Many optimizations: Only re-executes the parts of the workflow that changes between executions.
 - Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -61,9 +72,6 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
 - [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
 - [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
 - [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
-- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
-- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
-- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
 - Latent previews with [TAESD](#how-to-show-high-quality-previews)
 - Starts up very fast.
 - Works fully offline: will never download anything.
@@ -149,6 +157,30 @@ This is the command to install the nightly with ROCm 6.2 which might have some p

 ```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```

+### Intel GPUs (Windows and Linux)
+
+(Option 1) Intel Arc GPU users can install native PyTorch with torch.xpu support using pip (currently available in PyTorch nightly builds). More information can be found [here](https://pytorch.org/docs/main/notes/get_start_xpu.html)
+
+1. To install PyTorch nightly, use the following command:
+
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu```
+
+2. Launch ComfyUI by running `python main.py`
+
+
+(Option 2) Alternatively, Intel GPUs supported by Intel Extension for PyTorch (IPEX) can leverage IPEX for improved performance.
+
+1. For Intel® Arc™ A-Series Graphics utilizing IPEX, create a conda environment and use the commands below:
+
+```
+conda install libuv
+pip install torch==2.3.1.post0+cxx11.abi torchvision==0.18.1.post0+cxx11.abi torchaudio==2.3.1.post0+cxx11.abi intel-extension-for-pytorch==2.3.110.post0+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/cn/
+```
+
+For other supported Intel GPUs with IPEX, visit [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) for more information.
+
+Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
+
 ### NVIDIA

 Nvidia users should install stable pytorch using this command:
@@ -157,7 +189,7 @@ Nvidia users should install stable pytorch using this command:

 This is the command to install pytorch nightly instead which might have performance improvements:

-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu124```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```

 #### Troubleshooting
@@ -177,17 +209,6 @@ After this you should have everything installed and can proceed to running Comfy

 ### Others:

-#### Intel GPUs
-
-Intel GPU support is available for all Intel GPUs supported by Intel's Extension for Pytorch (IPEX) with the support requirements listed in the [Installation](https://intel.github.io/intel-extension-for-pytorch/index.html#installation?platform=gpu) page. Choose your platform and method of install and follow the instructions. The steps are as follows:
-
-1. Start by installing the drivers or kernel listed or newer in the Installation page of IPEX linked above for Windows and Linux if needed.
-1. Follow the instructions to install [Intel's oneAPI Basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit-download.html) for your platform.
-1. Install the packages for IPEX using the instructions provided in the Installation page for your platform.
-1. Follow the [ComfyUI manual installation](#manual-install-windows-linux) instructions for Windows and Linux and run ComfyUI normally as described above after everything is installed.
-
-Additional discussion and help can be found [here](https://github.com/comfyanonymous/ComfyUI/discussions/476).
-
 #### Apple Mac silicon

 You can install ComfyUI in Apple Mac silicon (M1 or M2) with any recent macOS version.
@@ -308,4 +329,3 @@ This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy
 ### Which GPU should I buy for this?

 [See this page for some recommendations](https://github.com/comfyanonymous/ComfyUI/wiki/Which-GPU-should-I-buy-for-ComfyUI)
-
````
```diff
@@ -38,8 +38,8 @@ class UserManager():
         if not os.path.exists(user_directory):
             os.makedirs(user_directory, exist_ok=True)
             if not args.multi_user:
-                print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
-                print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
+                logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
+                logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

         if args.multi_user:
             if os.path.isfile(self.get_users_file()):
```
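These messages now go through the standard `logging` module instead of `print`, so they pick up whatever handlers, levels, and formatting the application configures. A minimal, self-contained sketch of the pattern (the `basicConfig` call is an illustration, not something this diff adds):

```python
import logging

# Illustrative setup only; the application is assumed to configure logging itself.
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
```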
```diff
@@ -160,7 +160,6 @@ class ControlNet(nn.Module):
         if isinstance(self.num_classes, int):
             self.label_emb = nn.Embedding(num_classes, time_embed_dim)
         elif self.num_classes == "continuous":
-            print("setting up linear c_adm embedding layer")
             self.label_emb = nn.Linear(1, time_embed_dim)
         elif self.num_classes == "sequential":
             assert adm_in_channels is not None
```
```diff
@@ -84,7 +84,8 @@ parser.add_argument("--force-channels-last", action="store_true", help="Force ch

 parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")

-parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
+parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
+parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")

 class LatentPreviewMethod(enum.Enum):
     NoPreviews = "none"
```
```diff
@@ -2,6 +2,7 @@

 import torch
 import math
+import logging

 from tqdm.auto import trange

@@ -474,7 +475,7 @@ class UniPC:
             return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

     def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
-        print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
+        logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
         ns = self.noise_schedule
         assert order <= len(model_prev_list)

@@ -518,7 +519,6 @@ class UniPC:
             A_p = C_inv_p

         if use_corrector:
-            print('using corrector')
             C_inv = torch.linalg.inv(C)
             A_c = C_inv

```
```diff
@@ -5,6 +5,7 @@ import math
 import torch
 import numpy as np
 import itertools
+import logging

 if TYPE_CHECKING:
     from comfy.model_patcher import ModelPatcher, PatcherInjection
@@ -575,7 +576,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st
     k1 = set(k1)
     for x in loaded:
         if (x not in k) and (x not in k1):
-            print(f"NOT LOADED {x}")
+            logging.warning(f"NOT LOADED {x}")
     return (new_modelpatcher, new_clip, hook_group)

 def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]):
```
```diff
@@ -381,7 +381,6 @@ class MMDiT(nn.Module):
         pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
         self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
         self.h_max, self.w_max = target_dim
-        print("PE extended to", target_dim)

     def pe_selection_index_based_on_dim(self, h, w):
         h_p, w_p = h // self.patch_size, w // self.patch_size
```
```diff
@@ -53,7 +53,7 @@ class Patchifier(ABC):
         grid_h = torch.arange(h, dtype=torch.float32, device=device)
         grid_w = torch.arange(w, dtype=torch.float32, device=device)
         grid_f = torch.arange(f, dtype=torch.float32, device=device)
-        grid = torch.meshgrid(grid_f, grid_h, grid_w)
+        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
         grid = torch.stack(grid, dim=0)
         grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)

```
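Passing `indexing='ij'` makes the existing matrix-style indexing explicit and avoids the PyTorch warning about the future default; the grids themselves are unchanged. A small standalone check (the sizes are chosen only for illustration):

```python
import torch

f, h, w = 2, 3, 4
grid_f = torch.arange(f, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w = torch.arange(w, dtype=torch.float32)

# 'ij' (matrix) indexing matches the historical torch.meshgrid behaviour:
# every output tensor has shape (f, h, w).
a = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
# 'xy' (cartesian) indexing would swap the first two dimensions instead.
b = torch.meshgrid(grid_f, grid_h, grid_w, indexing='xy')

print([t.shape for t in a])  # three tensors of shape (2, 3, 4)
print([t.shape for t in b])  # three tensors of shape (3, 2, 4)
```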
```diff
@@ -378,7 +378,7 @@ class Decoder(nn.Module):
             assert (
                 timestep is not None
             ), "should pass timestep with timestep_conditioning=True"
-            scaled_timestep = timestep * self.timestep_scale_multiplier
+            scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)

         for up_block in self.up_blocks:
             if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
@@ -403,7 +403,7 @@ class Decoder(nn.Module):
                 )
                 ada_values = self.last_scale_shift_table[
                     None, ..., None, None, None
-                ] + embedded_timestep.reshape(
+                ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
                     batch_size,
                     2,
                     -1,
@@ -697,7 +697,7 @@ class ResnetBlock3D(nn.Module):
             ), "should pass timestep with timestep_conditioning=True"
             ada_values = self.scale_shift_table[
                 None, ..., None, None, None
-            ] + timestep.reshape(
+            ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
                 batch_size,
                 4,
                 -1,
@@ -715,7 +715,7 @@ class ResnetBlock3D(nn.Module):

         if self.inject_noise:
             hidden_states = self._feed_spatial_noise(
-                hidden_states, self.per_channel_scale1
+                hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
             )

         hidden_states = self.norm2(hidden_states)
@@ -731,7 +731,7 @@ class ResnetBlock3D(nn.Module):

         if self.inject_noise:
             hidden_states = self._feed_spatial_noise(
-                hidden_states, self.per_channel_scale2
+                hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
             )

         input_tensor = self.norm3(input_tensor)
```
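The added `.to(device=..., dtype=...)` calls move the small learned tables and scales onto the same device and dtype as the activations before they are combined, which avoids mixed-device errors and unwanted dtype promotion when the decoder runs in half precision. A minimal sketch of the failure mode and the fix (shapes and dtypes are illustrative only):

```python
import torch

sample = torch.randn(1, 4, 8, dtype=torch.float16)          # activations, e.g. fp16
scale_shift_table = torch.randn(4, 1, dtype=torch.float32)  # parameter kept in fp32

# Without casting, fp32 + fp16 silently promotes the result to fp32
# (and raises outright if the tensors live on different devices).
promoted = scale_shift_table + sample
print(promoted.dtype)  # torch.float32

# Casting the table to the activations' dtype/device keeps the computation in fp16.
matched = scale_shift_table.to(device=sample.device, dtype=sample.dtype) + sample
print(matched.dtype)   # torch.float16
```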
```diff
@@ -9,6 +9,7 @@


 import math
+import logging
 import torch
 import torch.nn as nn
 import numpy as np
@@ -130,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
     # add one to get the final alpha values right (the ones from first scale to data during sampling)
     steps_out = ddim_timesteps + 1
     if verbose:
-        print(f'Selected timesteps for ddim sampler: {steps_out}')
+        logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
     return steps_out


@@ -142,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
     # according the the formula provided in https://arxiv.org/abs/2010.02502
     sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
     if verbose:
-        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
-        print(f'For the chosen value of eta, which is {eta}, '
+        logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
+        logging.info(f'For the chosen value of eta, which is {eta}, '
               f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
     return sigmas, alphas, alphas_prev

```
comfy/ldm/pixart/blocks.py (new file, 380 lines)

```python
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
from comfy.ldm.modules.attention import optimized_attention

# if model_management.xformers_enabled():
#     import xformers.ops
#     if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
#     else:
#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens


def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


def t2i_modulate(x, shift, scale):
    return x * (1 + scale) + shift


class MultiHeadCrossAttention(nn.Module):
    def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
        super(MultiHeadCrossAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, cond, mask=None):
        # query/value: img tokens; key: condition; mask: if padding tokens
        B, N, C = x.shape

        q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
        kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
        k, v = kv.unbind(2)

        assert mask is None  # TODO?
        # # TODO: xformers needs separate mask logic here
        # if model_management.xformers_enabled():
        #     attn_bias = None
        #     if mask is not None:
        #         attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
        #     x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
        # else:
        #     q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
        #     attn_mask = None
        #     mask = torch.ones(())
        #     if mask is not None and len(mask) > 1:
        #         # Create equivalent of xformer diagonal block mask, still only correct for square masks
        #         # But depth doesn't matter as tensors can expand in that dimension
        #         attn_mask_template = torch.ones(
        #             [q.shape[2] // B, mask[0]],
        #             dtype=torch.bool,
        #             device=q.device
        #         )
        #         attn_mask = torch.block_diag(attn_mask_template)
        #
        #         # create a mask on the diagonal for each mask in the batch
        #         for _ in range(B - 1):
        #             attn_mask = torch.block_diag(attn_mask, attn_mask_template)
        #     x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)

        x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AttentionKVCompress(nn.Module):
    """Multi-head Attention block with KV token compression and qk norm."""
    def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool: If True, add a learnable bias to query, key, value.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)

        self.sampling=sampling  # ['conv', 'ave', 'uniform', 'uniform_every']
        self.sr_ratio = sr_ratio
        if sr_ratio > 1 and sampling == 'conv':
            # Avg Conv Init.
            self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
            # self.sr.weight.data.fill_(1/sr_ratio**2)
            # self.sr.bias.data.zero_()
            self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
        if qk_norm:
            self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
            self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
        else:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()

    def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
        if sampling is None or scale_factor == 1:
            return tensor
        B, N, C = tensor.shape

        if sampling == 'uniform_every':
            return tensor[:, ::scale_factor], int(N // scale_factor)

        tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
        new_H, new_W = int(H / scale_factor), int(W / scale_factor)
        new_N = new_H * new_W

        if sampling == 'ave':
            tensor = F.interpolate(
                tensor, scale_factor=1 / scale_factor, mode='nearest'
            ).permute(0, 2, 3, 1)
        elif sampling == 'uniform':
            tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
        elif sampling == 'conv':
            tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
            tensor = self.norm(tensor)
        else:
            raise ValueError

        return tensor.reshape(B, new_N, C).contiguous(), new_N

    def forward(self, x, mask=None, HW=None, block_id=None):
        B, N, C = x.shape  # 2 4096 1152
        new_N = N
        if HW is None:
            H = W = int(N ** 0.5)
        else:
            H, W = HW
        qkv = self.qkv(x).reshape(B, N, 3, C)

        q, k, v = qkv.unbind(2)
        q = self.q_norm(q)
        k = self.k_norm(k)

        # KV compression
        if self.sr_ratio > 1:
            k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
            v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)

        q = q.reshape(B, N, self.num_heads, C // self.num_heads)
        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)

        if mask is not None:
            raise NotImplementedError("Attn mask logic not added for self attention")

        # This is never called at the moment
        # attn_bias = None
        # if mask is not None:
        #     attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
        #     attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))

        # attention 2
        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
        x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)

        x = x.view(B, N, C)
        x = self.proj(x)
        return x


class FinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """
    def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x

class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """
    def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
        self.out_channels = out_channels

    def forward(self, x, t):
        shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
        x = t2i_modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class MaskFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """
    def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
        )
    def forward(self, x, t):
        shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class DecoderLayer(nn.Module):
    """
    The final layer of PixArt.
    """
    def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
        )
    def forward(self, x, t):
        shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
        x = modulate(self.norm_decoder(x), shift, scale)
        x = self.linear(x)
        return x


class SizeEmbedder(TimestepEmbedder):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
        super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
        self.mlp = nn.Sequential(
            operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.frequency_embedding_size = frequency_embedding_size
        self.outdim = hidden_size

    def forward(self, s, bs):
        if s.ndim == 1:
            s = s[:, None]
        assert s.ndim == 2
        if s.shape[0] != bs:
            s = s.repeat(bs//s.shape[0], 1)
            assert s.shape[0] == bs
        b, dims = s.shape[0], s.shape[1]
        s = rearrange(s, "b d -> (b d)")
        s_freq = timestep_embedding(s, self.frequency_embedding_size)
        s_emb = self.mlp(s_freq.to(s.dtype))
        s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
        return s_emb


class LabelEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
        super().__init__()
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
        else:
            drop_ids = force_drop_ids == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels, train, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


class CaptionEmbedder(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
        super().__init__()
        self.y_proj = Mlp(
            in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
            dtype=dtype, device=device, operations=operations,
        )
        self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
        self.uncond_prob = uncond_prob

    def token_drop(self, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return caption

    def forward(self, caption, train, force_drop_ids=None):
        if train:
            assert caption.shape[2:] == self.y_embedding.shape
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            caption = self.token_drop(caption, force_drop_ids)
        caption = self.y_proj(caption)
        return caption


class CaptionEmbedderDoubleBr(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
    """
    def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
        super().__init__()
        self.proj = Mlp(
            in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
            dtype=dtype, device=device, operations=operations,
        )
        self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
        self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
        self.uncond_prob = uncond_prob

    def token_drop(self, global_caption, caption, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
        else:
            drop_ids = force_drop_ids == 1
        global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
        caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
        return global_caption, caption

    def forward(self, caption, train, force_drop_ids=None):
        assert caption.shape[2: ] == self.y_embedding.shape
        global_caption = caption.mean(dim=2).squeeze()
        use_dropout = self.uncond_prob > 0
        if (train and use_dropout) or (force_drop_ids is not None):
            global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
        y_embed = self.proj(global_caption)
        return y_embed, caption
```
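For orientation, `t2i_modulate` applies a per-sample affine modulation: the shift and scale derived from the timestep embedding are broadcast over the token dimension. A small shape check using a local copy of the one-line function above (tensor sizes are arbitrary):

```python
import torch

def t2i_modulate(x, shift, scale):
    # same definition as in blocks.py above
    return x * (1 + scale) + shift

B, N, C = 2, 16, 1152          # batch, tokens, hidden size (illustrative)
x = torch.randn(B, N, C)
shift = torch.randn(B, 1, C)   # one shift/scale vector per sample, broadcast over the N tokens
scale = torch.randn(B, 1, C)

print(t2i_modulate(x, shift, scale).shape)  # torch.Size([2, 16, 1152])
```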
comfy/ldm/pixart/pixartms.py (new file, 256 lines)

```python
# Based on:
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
import torch
import torch.nn as nn

from .blocks import (
    t2i_modulate,
    CaptionEmbedder,
    AttentionKVCompress,
    MultiHeadCrossAttention,
    T2IFinalLayer,
    SizeEmbedder,
)
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch


def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
    grid_h, grid_w = torch.meshgrid(
        torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
        torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
        indexing='ij'
    )
    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
    emb = torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
    return emb

class PixArtMSBlock(nn.Module):
    """
    A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
                 sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.attn = AttentionKVCompress(
            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
            qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        self.cross_attn = MultiHeadCrossAttention(
            hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
        )
        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        # to be compatible with lower version pytorch
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
            dtype=dtype, device=device, operations=operations
        )
        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)

    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
        B, N, C = x.shape

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
        x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
        x = x + self.cross_attn(x, y, mask)
        x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))

        return x


### Core PixArt Model ###
class PixArtMS(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        learn_sigma=True,
        pred_sigma=True,
        drop_path: float = 0.,
        caption_channels=4096,
        pe_interpolation=None,
        pe_precision=None,
        config=None,
        model_max_length=120,
        micro_condition=True,
        qk_norm=False,
        kv_compress_config=None,
        dtype=None,
        device=None,
        operations=None,
        **kwargs,
    ):
        nn.Module.__init__(self)
        self.dtype = dtype
        self.pred_sigma = pred_sigma
        self.in_channels = in_channels
        self.out_channels = in_channels * 2 if pred_sigma else in_channels
        self.patch_size = patch_size
        self.num_heads = num_heads
        self.pe_interpolation = pe_interpolation
        self.pe_precision = pe_precision
        self.hidden_size = hidden_size
        self.depth = depth

        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.t_block = nn.Sequential(
            nn.SiLU(),
            operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
        )
        self.x_embedder = PatchEmbed(
            patch_size=patch_size,
            in_chans=in_channels,
            embed_dim=hidden_size,
            bias=True,
            dtype=dtype,
            device=device,
            operations=operations
        )
        self.t_embedder = TimestepEmbedder(
            hidden_size, dtype=dtype, device=device, operations=operations,
        )
        self.y_embedder = CaptionEmbedder(
            in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
            act_layer=approx_gelu, token_num=model_max_length,
            dtype=dtype, device=device, operations=operations,
        )

        self.micro_conditioning = micro_condition
        if self.micro_conditioning:
            self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
            self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)

        # For fixed sin-cos embedding:
        # num_patches = (input_size // patch_size) * (input_size // patch_size)
        # self.base_size = input_size // self.patch_size
        # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))

        drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)]  # stochastic depth decay rule
        if kv_compress_config is None:
            kv_compress_config = {
                'sampling': None,
                'scale_factor': 1,
                'kv_compress_layer': [],
            }
        self.blocks = nn.ModuleList([
            PixArtMSBlock(
                hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
                sampling=kv_compress_config['sampling'],
                sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
                qk_norm=qk_norm,
                dtype=dtype,
                device=device,
                operations=operations,
            )
            for i in range(depth)
        ])
        self.final_layer = T2IFinalLayer(
            hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
        )

    def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
        """
        Original forward pass of PixArt.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N, 1, 120, C) conditioning
        ar: (N, 1): aspect ratio
        cs: (N ,2) size conditioning for height/width
        """
        B, C, H, W = x.shape
        c_res = (H + W) // 2
        pe_interpolation = self.pe_interpolation
        if pe_interpolation is None or self.pe_precision is not None:
            # calculate pe_interpolation on-the-fly
            pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)

        pos_embed = get_2d_sincos_pos_embed_torch(
            self.hidden_size,
            h=(H // self.patch_size),
            w=(W // self.patch_size),
            pe_interpolation=pe_interpolation,
            base_size=((round(c_res / 64) * 64) // self.patch_size),
            device=x.device,
            dtype=x.dtype,
        ).unsqueeze(0)

        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        t = self.t_embedder(timestep, x.dtype)  # (N, D)

        if self.micro_conditioning and (c_size is not None and c_ar is not None):
            bs = x.shape[0]
            c_size = self.csize_embedder(c_size, bs)  # (N, D)
            c_ar = self.ar_embedder(c_ar, bs)  # (N, D)
            t = t + torch.cat([c_size, c_ar], dim=1)

        t0 = self.t_block(t)
        y = self.y_embedder(y, self.training)  # (N, D)

        if mask is not None:
            if mask.shape[0] != y.shape[0]:
                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
            mask = mask.squeeze(1).squeeze(1)
            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
            y_lens = mask.sum(dim=1).tolist()
        else:
            y_lens = None
            y = y.squeeze(1).view(1, -1, x.shape[-1])
        for block in self.blocks:
            x = block(x, y, t0, y_lens, (H, W), **kwargs)  # (N, T, D)

        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
        x = self.unpatchify(x, H, W)  # (N, out_channels, H, W)

        return x

    def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
        B, C, H, W = x.shape

        # Fallback for missing microconds
        if self.micro_conditioning:
            if c_size is None:
                c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)

            if c_ar is None:
                c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)

        ## Still accepts the input w/o that dim but returns garbage
        if len(context.shape) == 3:
            context = context.unsqueeze(1)

        ## run original forward pass
        out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)

        ## only return EPS
        if self.pred_sigma:
            return out[:, :self.in_channels]
        return out

    def unpatchify(self, x, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = h // self.patch_size
        w = w // self.patch_size
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
        return imgs
```
```diff
@@ -1,4 +1,5 @@
 import importlib
+import logging

 import torch
 from torch import optim
@@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10):
         try:
             draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            print("Cant encode string for logging. Skipping.")
+            logging.warning("Cant encode string for logging. Skipping.")

         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)
@@ -65,7 +66,7 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
+        logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
     return total_params


```
```diff
@@ -344,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format
             key_map[key_lora] = to

-
     if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow
         diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
         for k in diffusers_keys:
@@ -353,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}):
             key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers lora format
             key_map[key_lora] = to

+    if isinstance(model, comfy.model_base.PixArt):
+        diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
+        for k in diffusers_keys:
+            if k.endswith(".weight"):
+                to = diffusers_keys[k]
+                key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format
+                key_map[key_lora] = to
+
+                key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script
+                key_map[key_lora] = to
+
+                key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script
+                key_map[key_lora] = to
+
     if isinstance(model, comfy.model_base.HunyuanDiT):
         for k in sdk:
             if k.startswith("diffusion_model.") and k.endswith(".weight"):
```
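The new PixArt branch follows the same pattern as the surrounding model types: every `.weight` key of the converted state dict is registered under the LoRA naming schemes that common trainers emit, all pointing at the same internal target key. A simplified, self-contained sketch of what such a mapping ends up containing (the key names are made-up placeholders, not taken from a real checkpoint):

```python
# Hypothetical example of building a LoRA key map for one weight.
diffusers_keys = {
    "blocks.0.attn.qkv.weight": "diffusion_model.blocks.0.attn.qkv.weight",  # placeholder names
}

key_map = {}
for k, to in diffusers_keys.items():
    if k.endswith(".weight"):
        stem = k[:-len(".weight")]
        key_map["transformer.{}".format(stem)] = to              # default format
        key_map["base_model.model.{}".format(stem)] = to         # diffusers training script
        key_map["unet.base_model.model.{}".format(stem)] = to    # old reference peft script

print(key_map)
```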
@@ -26,6 +26,7 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug
 from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
 import comfy.ldm.genmo.joint_model.asymm_models_joint
 import comfy.ldm.aura.mmdit
+import comfy.ldm.pixart.pixartms
 import comfy.ldm.hydit.models
 import comfy.ldm.audio.dit
 import comfy.ldm.audio.embedders
@@ -718,6 +719,25 @@ class HunyuanDiT(BaseModel):
             out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
         return out

+class PixArt(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.EPS, device=None):
+        super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+        width = kwargs.get("width", None)
+        height = kwargs.get("height", None)
+        if width is not None and height is not None:
+            out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]]))
+            out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]]))
+
+        return out
+
 class Flux(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
@@ -754,7 +774,6 @@ class Flux(BaseModel):
                 mask = torch.ones_like(noise)[:, :1]

             mask = torch.mean(mask, dim=1, keepdim=True)
-            print(mask.shape)
             mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
             mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
             mask = utils.resize_to_batch_size(mask, noise.shape[0])
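Illustrative sketch (not part of the changeset): what the width/height conditioning built by PixArt.extra_conds above looks like once packed into tensors. The helper below only mimics the c_size/c_ar construction; the CONDRegular wrapper and the rest of the sampling pipeline are omitted.

    # Sketch: the size / aspect-ratio tensors PixArt.extra_conds derives from width and height.
    import torch

    def pixart_size_conds(width, height, aspect_ratio=None):
        if aspect_ratio is None:
            aspect_ratio = height / width
        return {
            "c_size": torch.FloatTensor([[height, width]]),
            "c_ar": torch.FloatTensor([[aspect_ratio]]),
        }

    print(pixart_size_conds(1024, 1024))  # {'c_size': tensor([[1024., 1024.]]), 'c_ar': tensor([[1.]])}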
@@ -203,11 +203,42 @@ def detect_unet_config(state_dict, key_prefix):
         dit_config["rope_theta"] = 10000.0
         return dit_config

+    if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
+        # PixArt diffusers
+        return None
+
     if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
         dit_config = {}
         dit_config["image_model"] = "ltxv"
         return dit_config

+    if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
+        patch_size = 2
+        dit_config = {}
+        dit_config["num_heads"] = 16
+        dit_config["patch_size"] = patch_size
+        dit_config["hidden_size"] = 1152
+        dit_config["in_channels"] = 4
+        dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
+
+        y_key = "{}y_embedder.y_embedding".format(key_prefix)
+        if y_key in state_dict_keys:
+            dit_config["model_max_length"] = state_dict[y_key].shape[0]
+
+        pe_key = "{}pos_embed".format(key_prefix)
+        if pe_key in state_dict_keys:
+            dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
+            dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
+
+        ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
+        if ar_key in state_dict_keys:
+            dit_config["image_model"] = "pixart_alpha"
+            dit_config["micro_condition"] = True
+        else:
+            dit_config["image_model"] = "pixart_sigma"
+            dit_config["micro_condition"] = False
+        return dit_config
+
     if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
         return None

@@ -573,6 +604,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
         num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
         num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
         sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
+    elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
+        num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
+        sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
     elif 'x_embedder.weight' in state_dict: #Flux
         depth = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
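Illustrative sketch (not part of the changeset): the detection rule added above, reduced to its key checks — `t_block.1.weight` marks a native PixArt checkpoint, and the `ar_embedder` micro-conditioning keys separate Alpha from Sigma. The fake state-dict key sets are made up:

    # Sketch of the PixArt branch of detect_unet_config, key checks only.
    def pixart_variant(state_dict_keys, key_prefix="model.diffusion_model."):
        if "{}t_block.1.weight".format(key_prefix) not in state_dict_keys:
            return None
        if "{}ar_embedder.mlp.0.weight".format(key_prefix) in state_dict_keys:
            return "pixart_alpha"   # micro_condition = True
        return "pixart_sigma"       # micro_condition = False

    print(pixart_variant({"model.diffusion_model.t_block.1.weight",
                          "model.diffusion_model.ar_embedder.mlp.0.weight"}))  # pixart_alpha
    print(pixart_variant({"model.diffusion_model.t_block.1.weight"}))          # pixart_sigma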
@@ -75,7 +75,7 @@ if args.directml is not None:
 try:
     import intel_extension_for_pytorch as ipex
     _ = torch.xpu.device_count()
-    xpu_available = torch.xpu.is_available()
+    xpu_available = xpu_available or torch.xpu.is_available()
 except:
     xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())

@@ -188,38 +188,44 @@ def is_nvidia():
             return True
     return False

+def is_amd():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.hip:
+            return True
+    return False
+
+MIN_WEIGHT_MEMORY_RATIO = 0.4
+if is_nvidia():
+    MIN_WEIGHT_MEMORY_RATIO = 0.2
+
 ENABLE_PYTORCH_ATTENTION = False
 if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False

-VAE_DTYPES = [torch.float32]
-
 try:
     if is_nvidia():
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-        if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-            VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
     if is_intel_xpu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
     pass

-if is_intel_xpu():
-    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-
-if args.cpu_vae:
-    VAE_DTYPES = [torch.float32]
-
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)

+try:
+    if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
+        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
+except:
+    logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
+
 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM
     lowvram_available = True
@@ -509,13 +515,14 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
             model_size = loaded_model.model_memory_required(torch_dev)
             loaded_memory = loaded_model.model_loaded_memory()
             current_free_mem = get_free_memory(torch_dev) + loaded_memory
-            lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
+
+            lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
             lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
             if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
                 lowvram_model_memory = 0

         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+            lowvram_model_memory = 0.1

         loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
@@ -743,7 +750,6 @@ def vae_offload_device():
         return torch.device("cpu")

 def vae_dtype(device=None, allowed_dtypes=[]):
-    global VAE_DTYPES
     if args.fp16_vae:
         return torch.float16
     elif args.bf16_vae:
@@ -752,12 +758,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         return torch.float32

     for d in allowed_dtypes:
-        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
-            return d
-        if d in VAE_DTYPES:
+        if d == torch.float16 and should_use_fp16(device):
             return d

-    return VAE_DTYPES[0]
+        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
+        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+            return d
+
+    return torch.float32

 def get_autocast_device(dev):
     if hasattr(dev, 'type'):
@@ -878,14 +886,19 @@ def pytorch_attention_flash_attention():
             return True
     return False

+def mac_version():
+    try:
+        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
+    except:
+        return None
+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
-    try:
-        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
-        if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
-            upcast = True
-    except:
-        pass
+    macos_version = mac_version()
+    if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
+        upcast = True
     if upcast:
         return torch.float32
     else:
@@ -956,17 +969,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP16:
         return True

-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False

     if directml_enabled:
         return False

-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
         return True

     if cpu_mode():
@@ -1015,17 +1024,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
         return False

-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False

     if directml_enabled:
         return False

-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
+        if mac_version() < (14,):
+            return False
         return True

     if cpu_mode():
@@ -1084,7 +1091,7 @@ def unload_all_models():


 def resolve_lowvram_weight(weight, model, key): #TODO: remove
-    print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
+    logging.warning("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
     return weight

 #TODO: might be cleaner to put this somewhere else
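Illustrative sketch (not part of the changeset): how the new MIN_WEIGHT_MEMORY_RATIO enters the low-VRAM weight budget computed in load_models_gpu above. The standalone function and the numbers are hypothetical; only the max/min expression is taken from the diff:

    # Sketch: the weight-memory budget from load_models_gpu, as a standalone function.
    MIN_WEIGHT_MEMORY_RATIO = 0.2  # NVIDIA value from the diff; 0.4 for other GPUs

    def lowvram_budget(current_free_mem, minimum_memory_required, minimum_inference_memory, loaded_memory=0):
        budget = max(64 * 1024 * 1024,
                     current_free_mem - minimum_memory_required,
                     min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO,
                         current_free_mem - minimum_inference_memory))
        return max(0.1, budget - loaded_memory)

    gib = 1024 ** 3
    print(lowvram_budget(8 * gib, 6 * gib, 1 * gib) / gib)  # 2.0 -> about 2 GiB of weights stay resident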
@@ -773,7 +773,7 @@ class ModelPatcher:
         return self.model.device

     def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
-        print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
+        logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
         return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)

     def cleanup(self):
@@ -1029,7 +1029,7 @@ class ModelPatcher:
         if cached_weights is not None:
             for key in cached_weights:
                 if key not in model_sd_keys:
-                    print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
+                    logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
                     continue
                 self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
         else:
@@ -1039,7 +1039,7 @@ class ModelPatcher:
             original_weights = self.get_key_patches()
             for key in relevant_patches:
                 if key not in model_sd_keys:
-                    print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
+                    logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
                     continue
                 self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
                                                  memory_counter=memory_counter)
comfy/ops.py (18 changed lines)

@@ -255,9 +255,10 @@ def fp8_linear(self, input):
         tensor_2d = True
         input = input.unsqueeze(1)

+    input_shape = input.shape
+    input_dtype = input.dtype
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
         w = w.t()

         scale_weight = self.scale_weight
@@ -269,23 +270,24 @@ def fp8_linear(self, input):

         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            input = input.reshape(-1, input_shape[2]).to(dtype)
         else:
             scale_input = scale_input.to(input.device)
-            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
+            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)

         if bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

         if isinstance(o, tuple):
             o = o[0]

         if tensor_2d:
-            return o.reshape(input.shape[0], -1)
+            return o.reshape(input_shape[0], -1)

-        return o.reshape((-1, input.shape[1], self.weight.shape[0]))
+        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

     return None
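Illustrative sketch (not part of the changeset): why the fp8_linear hunk records input_shape and input_dtype up front. After the in-place clamp and the reshape to 2D, the original shape and dtype can no longer be read back from `input`. Plain tensors and a regular matmul stand in for the fp8 path here:

    # Sketch: cache shape/dtype before the in-place clamp + reshape, as the diff does.
    import torch

    x = torch.randn(2, 5, 8, dtype=torch.float16)
    input_shape = x.shape      # (2, 5, 8), remembered before any reshape
    input_dtype = x.dtype      # float16, remembered before any cast

    x = torch.clamp(x, min=-448, max=448, out=x)   # in-place, avoids an extra copy
    x = x.reshape(-1, input_shape[2])              # now (10, 8); original shape gone from x

    o = x.float() @ torch.randn(8, 4)              # stand-in for torch._scaled_mm
    print(o.reshape((-1, input_shape[1], 4)).shape, input_dtype)  # torch.Size([2, 5, 4]) torch.float16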
comfy/sd.py (59 changed lines)

@@ -27,6 +27,7 @@ import comfy.text_encoders.sd2_clip
 import comfy.text_encoders.sd3_clip
 import comfy.text_encoders.sa_t5
 import comfy.text_encoders.aura_t5
+import comfy.text_encoders.pixart_t5
 import comfy.text_encoders.hydit
 import comfy.text_encoders.flux
 import comfy.text_encoders.long_clipl
@@ -110,7 +111,7 @@ class CLIP:
             model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
         self.use_clip_schedule = False
-        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
+        logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))

     def clone(self):
         n = CLIP(no_init=True)
@@ -258,6 +259,9 @@ class VAE:
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]

+        self.downscale_index_formula = None
+        self.upscale_index_formula = None
+
         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
                 encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -337,7 +341,9 @@ class VAE:
                 self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
+                self.upscale_index_formula = (6, 8, 8)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
+                self.downscale_index_formula = (6, 8, 8)
                 self.working_dtypes = [torch.float16, torch.float32]
             elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
                 tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
@@ -352,14 +358,18 @@ class VAE:
                 self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
                 self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
                 self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
+                self.upscale_index_formula = (8, 32, 32)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
+                self.downscale_index_formula = (8, 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
                 ddconfig["conv3d"] = True
                 ddconfig["time_compress"] = 4
                 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.upscale_index_formula = (4, 8, 8)
                 self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                self.downscale_index_formula = (4, 8, 8)
                 self.latent_dim = 3
                 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
                 self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
@@ -392,7 +402,7 @@ class VAE:
         self.output_device = model_management.intermediate_device()

         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
-        logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

     def vae_encode_crop_pixels(self, pixels):
         downscale_ratio = self.spacial_compression_encode()
@@ -425,7 +435,7 @@ class VAE:

     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
-        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
+        return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))

     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
         steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
@@ -446,7 +456,7 @@ class VAE:

     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
+        return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)

     def decode(self, samples_in):
         pixel_samples = None
@@ -478,7 +488,7 @@ class VAE:
             pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples

-    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
+    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
         model_management.load_models_gpu([self.patcher], memory_required=memory_used)
         dims = samples.ndim - 2
@@ -496,6 +506,13 @@ class VAE:
         elif dims == 2:
             output = self.decode_tiled_(samples, **args)
         elif dims == 3:
+            if overlap_t is None:
+                args["overlap"] = (1, overlap, overlap)
+            else:
+                args["overlap"] = (max(1, overlap_t), overlap, overlap)
+            if tile_t is not None:
+                args["tile_t"] = max(2, tile_t)
+
             output = self.decode_tiled_3d(samples, **args)
         return output.movedim(1, -1)

@@ -531,7 +548,7 @@ class VAE:

         return samples

-    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
+    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         dims = self.latent_dim
         pixel_samples = pixel_samples.movedim(-1, 1)
@@ -555,7 +572,20 @@ class VAE:
         elif dims == 2:
             samples = self.encode_tiled_(pixel_samples, **args)
         elif dims == 3:
-            samples = self.encode_tiled_3d(pixel_samples, **args)
+            if tile_t is not None:
+                tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
+            else:
+                tile_t_latent = 9999
+            args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
+
+            if overlap_t is None:
+                args["overlap"] = (1, overlap, overlap)
+            else:
+                args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
+            maximum = pixel_samples.shape[2]
+            maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
+
+            samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)

         return samples

@@ -574,6 +604,12 @@ class VAE:
         except:
             return self.downscale_ratio

+    def temporal_compression_decode(self):
+        try:
+            return round(self.upscale_ratio[0](8192) / 8192)
+        except:
+            return None
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
@@ -604,6 +640,8 @@ class CLIPType(Enum):
     MOCHI = 7
     LTXV = 8
     HUNYUAN_VIDEO = 9
+    PIXART = 10
+

 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
@@ -696,6 +734,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             elif clip_type == CLIPType.LTXV:
                 clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
+            elif clip_type == CLIPType.PIXART:
+                clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
             else: #CLIPType.MOCHI
                 clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -934,11 +975,11 @@ def load_diffusion_model(unet_path, model_options={}):
     return model

 def load_unet(unet_path, dtype=None):
-    print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
+    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
     return load_diffusion_model(unet_path, model_options={"dtype": dtype})

 def load_unet_state_dict(sd, dtype=None):
-    print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
+    logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
     return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})

 def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
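Illustrative sketch (not part of the changeset): what the new temporal_compression_decode helper evaluates to for the callable upscale ratio used in the first video-VAE branch above (lambda a: max(0, a * 6 - 5)); probing with a large value and rounding recovers the whole-number temporal factor:

    # Sketch: deriving the temporal compression factor the way temporal_compression_decode does.
    upscale_t = lambda a: max(0, a * 6 - 5)   # temporal part of the upscale_ratio from the branch above

    def temporal_compression_decode(upscale_ratio_t):
        try:
            return round(upscale_ratio_t(8192) / 8192)
        except TypeError:
            return None

    print(temporal_compression_decode(upscale_t))  # 6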
@@ -37,7 +37,10 @@ class ClipTokenWeightEncoder:

         sections = len(to_encode)
         if has_weights or sections == 0:
-            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
+            if hasattr(self, "gen_empty_tokens"):
+                to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
+            else:
+                to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

         o = self.encode(to_encode)
         out, pooled = o[:2]
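Illustrative sketch (not part of the changeset): the hasattr dispatch added above lets an encoder subclass supply its own gen_empty_tokens (as the PixArt T5 wrapper in comfy/text_encoders/pixart_t5.py below does) while everything else keeps the module-level helper. Dummy classes and token values:

    # Sketch of the hasattr()-based override added to ClipTokenWeightEncoder.
    def gen_empty_tokens(special_tokens, length):
        return [special_tokens.get("end", 0)] * length       # module-level fallback (dummy)

    class PlainEncoder:
        special_tokens = {"end": 1, "pad": 0}

    class PixArtStyleEncoder(PlainEncoder):
        def gen_empty_tokens(self, special_tokens, length):
            return [special_tokens["pad"]] * length          # empty prompt becomes all pad tokens

    def empty_tokens(encoder, length=4):
        if hasattr(encoder, "gen_empty_tokens"):
            return encoder.gen_empty_tokens(encoder.special_tokens, length)
        return gen_empty_tokens(encoder.special_tokens, length)

    print(empty_tokens(PlainEncoder()), empty_tokens(PixArtStyleEncoder()))  # [1, 1, 1, 1] [0, 0, 0, 0]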
@@ -8,6 +8,7 @@ import comfy.text_encoders.sd2_clip
 import comfy.text_encoders.sd3_clip
 import comfy.text_encoders.sa_t5
 import comfy.text_encoders.aura_t5
+import comfy.text_encoders.pixart_t5
 import comfy.text_encoders.hydit
 import comfy.text_encoders.flux
 import comfy.text_encoders.genmo
@@ -592,6 +593,37 @@ class AuraFlow(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model)

+class PixArtAlpha(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "pixart_alpha",
+    }
+
+    sampling_settings = {
+        "beta_schedule" : "sqrt_linear",
+        "linear_start" : 0.0001,
+        "linear_end" : 0.02,
+        "timesteps" : 1000,
+    }
+
+    unet_extra_config = {}
+    latent_format = latent_formats.SD15
+
+    vae_key_prefix = ["vae."]
+    text_encoder_key_prefix = ["text_encoders."]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.PixArt(self, device=device)
+        return out.eval()
+
+    def clip_target(self, state_dict={}):
+        return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL)
+
+class PixArtSigma(PixArtAlpha):
+    unet_config = {
+        "image_model": "pixart_sigma",
+    }
+    latent_format = latent_formats.SDXL
+
 class HunyuanDiT(supported_models_base.BASE):
     unet_config = {
         "image_model": "hydit",
@@ -787,6 +819,6 @@ class HunyuanVideo(supported_models_base.BASE):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))

-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]

 models += [SVD_img2vid]
comfy/text_encoders/pixart_t5.py (new file, 42 lines)

@@ -0,0 +1,42 @@
+import os
+
+from comfy import sd1_clip
+import comfy.text_encoders.t5
+import comfy.text_encoders.sd3_clip
+from comfy.sd1_clip import gen_empty_tokens
+
+from transformers import T5TokenizerFast
+
+class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def gen_empty_tokens(self, special_tokens, *args, **kwargs):
+        # PixArt expects the negative to be all pad tokens
+        special_tokens = special_tokens.copy()
+        special_tokens.pop("end")
+        return gen_empty_tokens(special_tokens, *args, **kwargs)
+
+class PixArtT5XXL(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)
+
+class T5XXLTokenizer(sd1_clip.SDTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
+        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
+
+class PixArtTokenizer(sd1_clip.SD1Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
+
+def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
+    class PixArtTEModel_(PixArtT5XXL):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+            if dtype is None:
+                dtype = dtype_t5
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return PixArtTEModel_
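Illustrative sketch (not part of the changeset): the class-factory pattern pixart_te uses above, in isolation. The returned class closes over the dtype and scaled-fp8 options detected at load time, so later instantiations do not have to pass them again. The dummy base class and option values are hypothetical:

    # Sketch: a factory that bakes detected options into the returned class, like pixart_te.
    class DummyTE:
        def __init__(self, device="cpu", dtype=None, model_options={}):
            self.device, self.dtype, self.model_options = device, dtype, model_options

    def make_te(dtype_t5=None, t5xxl_scaled_fp8=None):
        class TEModel_(DummyTE):
            def __init__(self, device="cpu", dtype=None, model_options={}):
                if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                    model_options = dict(model_options, t5xxl_scaled_fp8=t5xxl_scaled_fp8)
                if dtype is None:
                    dtype = dtype_t5
                super().__init__(device=device, dtype=dtype, model_options=model_options)
        return TEModel_

    TE = make_te(dtype_t5="float16", t5xxl_scaled_fp8=True)
    print(TE().dtype, TE().model_options)  # float16 {'t5xxl_scaled_fp8': True}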
@@ -386,6 +386,77 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):

     return key_map

+PIXART_MAP_BASIC = {
+    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
+    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
+    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
+    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
+    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
+    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
+    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
+    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
+    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
+    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
+    ("y_embedder.y_embedding", "caption_projection.y_embedding"),
+    ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
+    ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
+    ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
+    ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
+    ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
+    ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
+    ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
+    ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
+    ("t_block.1.weight", "adaln_single.linear.weight"),
+    ("t_block.1.bias", "adaln_single.linear.bias"),
+    ("final_layer.linear.weight", "proj_out.weight"),
+    ("final_layer.linear.bias", "proj_out.bias"),
+    ("final_layer.scale_shift_table", "scale_shift_table"),
+}
+
+PIXART_MAP_BLOCK = {
+    ("scale_shift_table", "scale_shift_table"),
+    ("attn.proj.weight", "attn1.to_out.0.weight"),
+    ("attn.proj.bias", "attn1.to_out.0.bias"),
+    ("mlp.fc1.weight", "ff.net.0.proj.weight"),
+    ("mlp.fc1.bias", "ff.net.0.proj.bias"),
+    ("mlp.fc2.weight", "ff.net.2.weight"),
+    ("mlp.fc2.bias", "ff.net.2.bias"),
+    ("cross_attn.proj.weight" ,"attn2.to_out.0.weight"),
+    ("cross_attn.proj.bias" ,"attn2.to_out.0.bias"),
+}
+
+def pixart_to_diffusers(mmdit_config, output_prefix=""):
+    key_map = {}
+
+    depth = mmdit_config.get("depth", 0)
+    offset = mmdit_config.get("hidden_size", 1152)
+
+    for i in range(depth):
+        block_from = "transformer_blocks.{}".format(i)
+        block_to = "{}blocks.{}".format(output_prefix, i)
+
+        for end in ("weight", "bias"):
+            s = "{}.attn1.".format(block_from)
+            qkv = "{}.attn.qkv.{}".format(block_to, end)
+            key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
+            key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
+            key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))
+
+            s = "{}.attn2.".format(block_from)
+            q = "{}.cross_attn.q_linear.{}".format(block_to, end)
+            kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)
+
+            key_map["{}to_q.{}".format(s, end)] = q
+            key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
+            key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))
+
+        for k in PIXART_MAP_BLOCK:
+            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])
+
+    for k in PIXART_MAP_BASIC:
+        key_map[k[1]] = "{}{}".format(output_prefix, k[0])
+
+    return key_map
+
 def auraflow_to_diffusers(mmdit_config, output_prefix=""):
     n_double_layers = mmdit_config.get("n_double_layers", 0)
@@ -751,7 +822,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
     return rows * cols

 @torch.inference_mode()
-def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None):
+def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
     dims = len(tile)

     if not (isinstance(upscale_amount, (tuple, list))):
@@ -760,6 +831,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
     if not (isinstance(overlap, (tuple, list))):
         overlap = [overlap] * dims

+    if index_formulas is None:
+        index_formulas = upscale_amount
+
+    if not (isinstance(index_formulas, (tuple, list))):
+        index_formulas = [index_formulas] * dims
+
     def get_upscale(dim, val):
         up = upscale_amount[dim]
         if callable(up):
@@ -774,10 +851,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
         else:
             return val / up

+    def get_upscale_pos(dim, val):
+        up = index_formulas[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return up * val
+
+    def get_downscale_pos(dim, val):
+        up = index_formulas[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return val / up
+
     if downscale:
         get_scale = get_downscale
+        get_pos = get_downscale_pos
     else:
         get_scale = get_upscale
+        get_pos = get_upscale_pos

     def mult_list_upscale(a):
         out = []
@@ -810,7 +903,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
             pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
             l = min(tile[d], s.shape[d + 2] - pos)
             s_in = s_in.narrow(d + 2, pos, l)
-            upscaled.append(round(get_scale(d, pos)))
+            upscaled.append(round(get_pos(d, pos)))

         ps = function(s_in).to(output_device)
         mask = torch.ones_like(ps)
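Illustrative sketch (not part of the changeset): why index_formulas is split from upscale_amount in tiled_scale_multidim. With a causal video VAE, the output size of a tile follows one rule while the position the tile is pasted at follows another; the two formulas below are the temporal pair set on the first video-VAE branch in comfy/sd.py (upscale_ratio vs. upscale_index_formula), and the tile positions are made up:

    # Sketch: size rule vs. position rule for the temporal axis of a tiled VAE decode.
    size_rule = lambda a: max(0, a * 6 - 5)   # upscale_amount: how long a decoded tile is
    pos_factor = 6                            # index_formulas: where the tile lands in the output

    for pos in (0, 4, 8):                     # hypothetical latent tile start positions
        tile_frames = size_rule(4)            # a 4-frame latent tile decodes to this many frames
        out_pos = pos_factor * pos            # paste position in the decoded clip
        print(pos, tile_frames, out_pos)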
@@ -1,5 +1,6 @@
+import logging
 from spandrel import ModelLoader

 def load_state_dict(state_dict):
-    print("WARNING: comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
+    logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
     return ModelLoader().load_from_state_dict(state_dict).eval()
@@ -24,7 +24,7 @@ class AlignYourStepsScheduler:
     def INPUT_TYPES(s):
         return {"required":
                     {"model_type": (["SD1", "SDXL", "SVD"], ),
-                     "steps": ("INT", {"default": 10, "min": 10, "max": 10000}),
+                     "steps": ("INT", {"default": 10, "min": 1, "max": 10000}),
                      "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                      }
                }
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING, Union
+import logging
 import torch
 from collections.abc import Iterable

@@ -539,7 +540,7 @@ class CreateHookKeyframesInterpolated:
             is_first = False
             prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
             if print_keyframes:
-                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
+                logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
         return (prev_hook_kf,)

 class CreateHookKeyframesFromFloats:
@@ -588,7 +589,7 @@ class CreateHookKeyframesFromFloats:
             is_first = False
             prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
             if print_keyframes:
-                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
+                logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
         return (prev_hook_kf,)
 #------------------------------------------
 ###########################################
comfy_extras/nodes_pixart.py (new file, 24 lines)

@@ -0,0 +1,24 @@
+from nodes import MAX_RESOLUTION
+
+class CLIPTextEncodePixArtAlpha:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {
+            "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
+            # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
+            }}
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+    CATEGORY = "advanced/conditioning"
+    DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma."
+
+    def encode(self, clip, width, height, text):
+        tokens = clip.tokenize(text)
+        return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),)
+
+NODE_CLASS_MAPPINGS = {
+    "CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha,
+}
main.py | 72
@@ -63,7 +63,7 @@ def execute_prestartup_script():
             spec.loader.exec_module(module)
             return True
         except Exception as e:
-            print(f"Failed to execute startup-script: {script_path} / {e}")
+            logging.error(f"Failed to execute startup-script: {script_path} / {e}")
         return False

     if args.disable_all_custom_nodes:
@@ -85,14 +85,14 @@ def execute_prestartup_script():
             success = execute_script(script_path)
             node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
     if len(node_prestartup_times) > 0:
-        print("\nPrestartup times for custom nodes:")
+        logging.info("\nPrestartup times for custom nodes:")
         for n in sorted(node_prestartup_times):
             if n[2]:
                 import_message = ""
             else:
                 import_message = " (PRESTARTUP FAILED)"
-            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
-        print()
+            logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
+        logging.info("")

 apply_custom_paths()
 execute_prestartup_script()
@@ -114,6 +114,10 @@ if __name__ == "__main__":
         os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
         logging.info("Set cuda device to: {}".format(args.cuda_device))

+    if args.oneapi_device_selector is not None:
+        os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
+        logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
+
     if args.deterministic:
         if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
             os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
@@ -146,9 +150,10 @@ def cuda_malloc_warning():
     if cuda_malloc_warning:
         logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")

-def prompt_worker(q, server):
+
+def prompt_worker(q, server_instance):
     current_time: float = 0.0
-    e = execution.PromptExecutor(server, lru_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, lru_size=args.cache_lru)
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0
@@ -163,7 +168,7 @@ def prompt_worker(q, server):
             item, item_id = queue_item
             execution_start_time = time.perf_counter()
             prompt_id = item[1]
-            server.last_prompt_id = prompt_id
+            server_instance.last_prompt_id = prompt_id

             e.execute(item[2], prompt_id, item[3], item[4])
             need_gc = True
@@ -173,8 +178,8 @@ def prompt_worker(q, server):
                                   status_str='success' if e.success else 'error',
                                   completed=e.success,
                                   messages=e.status_messages))
-            if server.client_id is not None:
-                server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)
+            if server_instance.client_id is not None:
+                server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id)

             current_time = time.perf_counter()
             execution_time = current_time - execution_start_time
@@ -201,21 +206,23 @@ def prompt_worker(q, server):
                 last_gc_collect = current_time
                 need_gc = False

-async def run(server, address='', port=8188, verbose=True, call_on_start=None):
+
+async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
     addresses = []
     for addr in address.split(","):
         addresses.append((addr, port))
-    await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
+    await asyncio.gather(server_instance.start_multi_address(addresses, call_on_start), server_instance.publish_loop())


-def hijack_progress(server):
+def hijack_progress(server_instance):
     def hook(value, total, preview_image):
         comfy.model_management.throw_exception_if_processing_interrupted()
-        progress = {"value": value, "max": total, "prompt_id": server.last_prompt_id, "node": server.last_node_id}
+        progress = {"value": value, "max": total, "prompt_id": server_instance.last_prompt_id, "node": server_instance.last_node_id}

-        server.send_sync("progress", progress, server.client_id)
+        server_instance.send_sync("progress", progress, server_instance.client_id)
         if preview_image is not None:
-            server.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server.client_id)
+            server_instance.send_sync(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE, preview_image, server_instance.client_id)

     comfy.utils.set_progress_bar_global_hook(hook)

@@ -225,7 +232,11 @@ def cleanup_temp():
     shutil.rmtree(temp_dir, ignore_errors=True)


-if __name__ == "__main__":
+def start_comfyui(asyncio_loop=None):
+    """
+    Starts the ComfyUI server using the provided asyncio event loop or creates a new one.
+    Returns the event loop, server instance, and a function to start the server asynchronously.
+    """
     if args.temp_directory:
         temp_dir = os.path.join(os.path.abspath(args.temp_directory), "temp")
         logging.info(f"Setting temp directory to: {temp_dir}")
@@ -239,19 +250,20 @@ if __name__ == "__main__":
         except:
             pass

-    loop = asyncio.new_event_loop()
-    asyncio.set_event_loop(loop)
-    server = server.PromptServer(loop)
-    q = execution.PromptQueue(server)
+    if not asyncio_loop:
+        asyncio_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(asyncio_loop)
+    prompt_server = server.PromptServer(asyncio_loop)
+    q = execution.PromptQueue(prompt_server)

     nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes)

     cuda_malloc_warning()

-    server.add_routes()
-    hijack_progress(server)
+    prompt_server.add_routes()
+    hijack_progress(prompt_server)

-    threading.Thread(target=prompt_worker, daemon=True, args=(q, server,)).start()
+    threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()

     if args.quick_test_for_ci:
         exit(0)
@@ -268,9 +280,19 @@ if __name__ == "__main__":
             webbrowser.open(f"{scheme}://{address}:{port}")
         call_on_start = startup_server

+    async def start_all():
+        await prompt_server.setup()
+        await run(prompt_server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start)
+
+    # Returning these so that other code can integrate with the ComfyUI loop and server
+    return asyncio_loop, prompt_server, start_all
+
+
+if __name__ == "__main__":
+    # Running directly, just start ComfyUI.
+    event_loop, _, start_all_func = start_comfyui()
     try:
-        loop.run_until_complete(server.setup())
-        loop.run_until_complete(run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server, call_on_start=call_on_start))
+        event_loop.run_until_complete(start_all_func())
     except KeyboardInterrupt:
         logging.info("\nStopped server")

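Splitting the entry point into `start_comfyui()` lets other code drive ComfyUI's event loop and `PromptServer` instead of relying on `__main__`. A rough embedding sketch, assuming `main.py` is importable from the repo root (importing it still runs the path setup and prestartup scripts at module level):

```python
import asyncio
import main  # ComfyUI's main.py

# Hand ComfyUI an existing loop; passing None would make start_comfyui create one.
loop = asyncio.new_event_loop()
event_loop, prompt_server, start_all = main.start_comfyui(asyncio_loop=loop)

# start_all() awaits prompt_server.setup() and then serves requests,
# so the embedding application decides when to block on the loop.
try:
    event_loop.run_until_complete(start_all())
except KeyboardInterrupt:
    pass
```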
@@ -32,4 +32,4 @@ def update_windows_updater():
     except:
         pass
     shutil.copy(bat_path, dest_bat_path)
-    print("Updated the windows standalone package updater.")
+    print("Updated the windows standalone package updater.") # noqa: T201
nodes.py | 29
@@ -293,17 +293,29 @@ class VAEDecodeTiled:
         return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
                              "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+                             "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
+                             "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                              }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "decode"

     CATEGORY = "_for_testing"

-    def decode(self, vae, samples, tile_size, overlap=64):
+    def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
+        if temporal_size < temporal_overlap * 2:
+            temporal_overlap = temporal_overlap // 2
+        temporal_compression = vae.temporal_compression_decode()
+        if temporal_compression is not None:
+            temporal_size = max(2, temporal_size // temporal_compression)
+            temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
+        else:
+            temporal_size = None
+            temporal_overlap = None
+
         compression = vae.spacial_compression_decode()
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
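The temporal inputs are given in frames but are rescaled into latent units before `decode_tiled` is called, mirroring how `tile_size` is divided by the spatial compression. A small worked example of that arithmetic under the node defaults (the compression factor of 8 is only an assumed value for a video VAE; the real one comes from `vae.temporal_compression_decode()`):

```python
# Assumed: a video VAE that packs 8 pixel frames into 1 latent frame.
temporal_size, temporal_overlap = 64, 8
temporal_compression = 8  # stand-in for vae.temporal_compression_decode()

temporal_size = max(2, temporal_size // temporal_compression)                            # 64 // 8 -> 8
temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)  # min(1, 4, 1) -> 1

assert (temporal_size, temporal_overlap) == (8, 1)  # 8 latent frames per tile, 1 frame of overlap
```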
@@ -327,15 +339,17 @@ class VAEEncodeTiled:
         return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
                              "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+                             "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
+                             "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                              }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"

     CATEGORY = "_for_testing"

-    def encode(self, vae, pixels, tile_size, overlap):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap)
-        return ({"samples":t}, )
+    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
+        return ({"samples": t}, )

 class VAEEncodeForInpaint:
     @classmethod
@@ -898,7 +912,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
                               }}
     RETURN_TYPES = ("CLIP",)
     FUNCTION = "load_clip"
@@ -918,6 +932,8 @@ class CLIPLoader:
             clip_type = comfy.sd.CLIPType.MOCHI
         elif type == "ltxv":
             clip_type = comfy.sd.CLIPType.LTXV
+        elif type == "pixart":
+            clip_type = comfy.sd.CLIPType.PIXART
         else:
             clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION

@@ -2164,6 +2180,7 @@ def init_builtin_extra_nodes():
         "nodes_stable3d.py",
         "nodes_sdupscale.py",
         "nodes_photomaker.py",
+        "nodes_pixart.py",
         "nodes_cond.py",
         "nodes_morphology.py",
         "nodes_stable_cascade.py",
@@ -4,7 +4,10 @@ lint.ignore = ["ALL"]
 # Enable specific rules
 lint.select = [
     "S307", # suspicious-eval-usage
+    "T201", # print-usage
     # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
     # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
     "F",
 ]
+
+exclude = ["*.ipynb"]
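With T201 (print-usage) selected, Ruff flags every bare `print`, which is why the intentional console output elsewhere in this compare gains `# noqa: T201` markers. A minimal illustration of the pattern (the function and message are made up):

```python
def report_done(path: str) -> None:
    # Deliberate user-facing output in a CLI script; suppress Ruff's T201 for this line only.
    print(f"wrote {path}")  # noqa: T201
```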
@@ -714,9 +714,7 @@ class PromptServer():
         self.app.add_routes(self.routes)

         for name, dir in nodes.EXTENSION_WEB_DIRS.items():
-            self.app.add_routes([
-                web.static('/extensions/' + urllib.parse.quote(name), dir),
-            ])
+            self.app.add_routes([web.static('/extensions/' + name, dir)])

         self.app.add_routes([
             web.static('/', self.web_root),
@@ -89,9 +89,9 @@ async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
     client = await aiohttp_client_factory()
     try:
         resp = await client.get('/files')
-        print(f"Response received: status {resp.status}")
+        print(f"Response received: status {resp.status}") # noqa: T201
     except Exception as e:
-        print(f"Exception occurred during GET request: {e}")
+        print(f"Exception occurred during GET request: {e}") # noqa: T201
         raise

     assert resp.status != 404, "Route /files does not exist"
@@ -28,7 +28,7 @@ def pytest_collection_modifyitems(items):
     last_items = []
     for test_name in LAST_TESTS:
         for item in items.copy():
-            print(item.module.__name__, item)
+            print(item.module.__name__, item) # noqa: T201
            if item.module.__name__ == test_name:
                last_items.append(item)
                items.remove(item)
@@ -134,7 +134,7 @@ class TestExecution:
         use_lru, lru_size = request.param
         if use_lru:
             pargs += ['--cache-lru', str(lru_size)]
-        print("Running server with args:", pargs)
+        print("Running server with args:", pargs) # noqa: T201
         p = subprocess.Popen(pargs)
         yield
         p.kill()
@@ -150,8 +150,8 @@ class TestExecution:
             try:
                 comfy_client.connect(listen=listen, port=port)
             except ConnectionRefusedError as e:
-                print(e)
-                print(f"({i+1}/{n_tries}) Retrying...")
+                print(e) # noqa: T201
+                print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
             else:
                 break
         return comfy_client
@@ -171,8 +171,8 @@ class TestInference:
             try:
                 comfy_client.connect(listen=listen, port=port)
             except ConnectionRefusedError as e:
-                print(e)
-                print(f"({i+1}/{n_tries}) Retrying...")
+                print(e) # noqa: T201
+                print(f"({i+1}/{n_tries}) Retrying...") # noqa: T201
             else:
                 break
         return comfy_client