Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load them. unCLIPConditioning is used to add the image conditioning and takes as input the output of a CLIPVisionEncode node, which has been moved to the conditioning section.
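For illustration, a minimal sketch of how the new nodes could be wired together as an API-format prompt graph (a plain Python dict). The node class names come from this commit; the checkpoint filename, the output slot indices, and input names such as `strength` are assumptions, not something this commit confirms.

# Hypothetical unCLIP workflow graph; ["node_id", slot] references follow the
# usual API-prompt convention. Slot indices and input names are assumptions.
prompt = {
    "1": {"class_type": "unCLIPCheckpointLoader",
          "inputs": {"ckpt_name": "sd21-unclip-h.ckpt"}},
    "2": {"class_type": "LoadImage",
          "inputs": {"image": "reference.png"}},
    "3": {"class_type": "CLIPVisionEncode",  # now listed under conditioning in the UI
          "inputs": {"clip_vision": ["1", 3], "image": ["2", 0]}},
    "4": {"class_type": "CLIPTextEncode",
          "inputs": {"clip": ["1", 1], "text": "a photo of a cat"}},
    "5": {"class_type": "unCLIPConditioning",  # merges the image cond into the text cond
          "inputs": {"conditioning": ["4", 0],
                     "clip_vision_output": ["3", 0],
                     "strength": 1.0}},
}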
comfy/ldm/modules/encoders/kornia_functions.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import List, Tuple, Union

import torch
import torch.nn as nn

# from: https://github.com/kornia/kornia/blob/master/kornia/enhance/normalize.py

def enhance_normalize(data: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    r"""Normalize an image/video tensor with mean and standard deviation.

    .. math::
        \text{input[channel] = (input[channel] - mean[channel]) / std[channel]}

    Where `mean` is :math:`(M_1, ..., M_n)` and `std` :math:`(S_1, ..., S_n)` for `n` channels.

    Args:
        data: Image tensor of size :math:`(B, C, *)`.
        mean: Mean for each channel.
        std: Standard deviations for each channel.

    Return:
        Normalised tensor with same size as input :math:`(B, C, *)`.

    Examples:
        >>> x = torch.rand(1, 4, 3, 3)
        >>> out = enhance_normalize(x, torch.tensor([0.0]), torch.tensor([255.]))
        >>> out.shape
        torch.Size([1, 4, 3, 3])

        >>> x = torch.rand(1, 4, 3, 3)
        >>> mean = torch.zeros(4)
        >>> std = 255. * torch.ones(4)
        >>> out = enhance_normalize(x, mean, std)
        >>> out.shape
        torch.Size([1, 4, 3, 3])
    """
    shape = data.shape
    # A scalar or single-element mean/std applies to every channel.
    if len(mean.shape) == 0 or mean.shape[0] == 1:
        mean = mean.expand(shape[1])
    if len(std.shape) == 0 or std.shape[0] == 1:
        std = std.expand(shape[1])

    # Allow broadcast on channel dimension
    if mean.shape and mean.shape[0] != 1:
        if mean.shape[0] != data.shape[1] and mean.shape[:2] != data.shape[:2]:
            raise ValueError(f"mean length and number of channels do not match. Got {mean.shape} and {data.shape}.")

    # Allow broadcast on channel dimension
    if std.shape and std.shape[0] != 1:
        if std.shape[0] != data.shape[1] and std.shape[:2] != data.shape[:2]:
            raise ValueError(f"std length and number of channels do not match. Got {std.shape} and {data.shape}.")

    mean = torch.as_tensor(mean, device=data.device, dtype=data.dtype)
    std = torch.as_tensor(std, device=data.device, dtype=data.dtype)

    # Add a trailing singleton dim so the stats broadcast over the flattened spatial dims.
    if mean.shape:
        mean = mean[..., :, None]
    if std.shape:
        std = std[..., :, None]

    out: torch.Tensor = (data.view(shape[0], shape[1], -1) - mean) / std

    return out.view(shape)
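A quick usage sketch for the function above. The mean/std values shown are the standard CLIP image-preprocessing constants; using them with this function is purely illustrative.

import torch

# Normalize a batch of RGB images with per-channel statistics.
images = torch.rand(2, 3, 224, 224)  # (B, C, H, W), values in [0, 1]
mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])   # CLIP image mean (illustrative)
std = torch.tensor([0.26862954, 0.26130258, 0.27577711])   # CLIP image std (illustrative)

out = enhance_normalize(images, mean, std)
print(out.shape)  # torch.Size([2, 3, 224, 224])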