Pull in latest upscale model code from chainner.
comfy_extras/chainner_models/architecture/OmniSR/OSA.py | 577 | Normal file
@@ -0,0 +1,577 @@
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################
# File: OSA.py
# Created Date: Tuesday April 28th 2022
# Author: Chen Xuanhong
# Email: chenxuanhongzju@outlook.com
# Last Modified: Sunday, 23rd April 2023 3:07:42 pm
# Modified By: Chen Xuanhong
# Copyright (c) 2020 Shanghai Jiao Tong University
#############################################################

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch import einsum, nn

from .layernorm import LayerNorm2d
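
# OSA (omni self-attention) building blocks used by the OmniSR upscale
# architecture: MaxViT-style block/grid window attention combined with
# channel attention and gated convolutional feed-forward layers.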

# helpers


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def cast_tuple(val, length=1):
    return val if isinstance(val, tuple) else ((val,) * length)


# helper classes
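# PreNormResidual / Conv_PreNormResidual: pre-norm + residual wrappers around an
# arbitrary sub-module (nn.LayerNorm for token layouts, LayerNorm2d for NCHW maps).
# FeedForward / Conv_FeedForward: plain MLP / 1x1-conv expansion blocks.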


class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x)) + x


class Conv_PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = LayerNorm2d(dim)
        self.fn = fn

    def forward(self, x):
        return self.fn(self.norm(x)) + x


class FeedForward(nn.Module):
    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Conv_FeedForward(nn.Module):
    def __init__(self, dim, mult=2, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Conv2d(dim, inner_dim, 1, 1, 0),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Conv2d(inner_dim, dim, 1, 1, 0),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
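

# Gated convolutional feed-forward: a 1x1 conv expands to twice the hidden width,
# a 3x3 depthwise conv mixes spatially, then one half (GELU-gated) modulates the
# other half before a 1x1 projection back to `dim`.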
class Gated_Conv_FeedForward(nn.Module):
    def __init__(self, dim, mult=1, bias=False, dropout=0.0):
        super().__init__()

        hidden_features = int(dim * mult)

        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)

        self.dwconv = nn.Conv2d(
            hidden_features * 2,
            hidden_features * 2,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=hidden_features * 2,
            bias=bias,
        )

        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x


# MBConv
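# Inverted-residual block: 1x1 expansion, 3x3 depthwise conv, squeeze-excitation
# gating, 1x1 projection; a Dropsample (stochastic-depth) residual is added when
# the input and output shapes match.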


class SqueezeExcitation(nn.Module):
    def __init__(self, dim, shrinkage_rate=0.25):
        super().__init__()
        hidden_dim = int(dim * shrinkage_rate)

        self.gate = nn.Sequential(
            Reduce("b c h w -> b c", "mean"),
            nn.Linear(dim, hidden_dim, bias=False),
            nn.SiLU(),
            nn.Linear(hidden_dim, dim, bias=False),
            nn.Sigmoid(),
            Rearrange("b c -> b c 1 1"),
        )

    def forward(self, x):
        return x * self.gate(x)


class MBConvResidual(nn.Module):
    def __init__(self, fn, dropout=0.0):
        super().__init__()
        self.fn = fn
        self.dropsample = Dropsample(dropout)

    def forward(self, x):
        out = self.fn(x)
        out = self.dropsample(out)
        return out + x


class Dropsample(nn.Module):
    def __init__(self, prob=0):
        super().__init__()
        self.prob = prob

    def forward(self, x):
        device = x.device

        if self.prob == 0.0 or (not self.training):
            return x

        # sample a per-batch-element keep mask on the input's device
        keep_mask = (
            torch.empty((x.shape[0], 1, 1, 1), device=device).uniform_() > self.prob
        )
        return x * keep_mask / (1 - self.prob)


def MBConv(
    dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
):
    hidden_dim = int(expansion_rate * dim_out)
    stride = 2 if downsample else 1

    net = nn.Sequential(
        nn.Conv2d(dim_in, hidden_dim, 1),
        # nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        nn.Conv2d(
            hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
        ),
        # nn.BatchNorm2d(hidden_dim),
        nn.GELU(),
        SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
        nn.Conv2d(hidden_dim, dim_out, 1),
        # nn.BatchNorm2d(dim_out)
    )

    if dim_in == dim_out and not downsample:
        net = MBConvResidual(net, dropout=dropout)

    return net


# attention related classes
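# Attention: window-based multi-head self-attention over tensors shaped
# (b, x, y, w1, w2, d), i.e. one attention call per w1 x w2 window, with an
# optional learned relative positional bias.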
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=32,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)

        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))

        self.to_out = nn.Sequential(
            nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
        )

        # relative positional bias
        if self.with_pe:
            self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)

            pos = torch.arange(window_size)
            grid = torch.stack(torch.meshgrid(pos, pos))
            grid = rearrange(grid, "c i j -> (i j) c")
            rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
                grid, "j ... -> 1 j ..."
            )
            rel_pos += window_size - 1
            rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
                dim=-1
            )

            self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)

    def forward(self, x):
        batch, height, width, window_height, window_width, _, device, h = (
            *x.shape,
            x.device,
            self.heads,
        )

        # flatten

        x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")

        # project for queries, keys, values

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)

        # split heads

        q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))

        # scale

        q = q * self.scale

        # sim

        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        # add positional bias
        if self.with_pe:
            bias = self.rel_pos_bias(self.rel_pos_indices)
            sim = sim + rearrange(bias, "i j h -> h i j")

        # attention

        attn = self.attend(sim)

        # aggregate

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # merge heads

        out = rearrange(
            out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
        )

        # combine heads out

        out = self.to_out(out)
        return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
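

# Block_Attention: the same windowed ("block") attention computed directly on
# NCHW feature maps; qkv comes from a 1x1 conv followed by a 3x3 depthwise conv,
# and the map is partitioned into non-overlapping ps x ps windows.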
class Block_Attention(nn.Module):
    def __init__(
        self,
        dim,
        dim_head=32,
        bias=False,
        dropout=0.0,
        window_size=7,
        with_pe=True,
    ):
        super().__init__()
        assert (
            dim % dim_head
        ) == 0, "dimension should be divisible by dimension per head"

        self.heads = dim // dim_head
        self.ps = window_size
        self.scale = dim_head**-0.5
        self.with_pe = with_pe

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )

        self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))

        self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        # project for queries, keys, values
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)

        # split heads

        q, k, v = map(
            lambda t: rearrange(
                t,
                "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
                h=self.heads,
                w1=self.ps,
                w2=self.ps,
            ),
            (q, k, v),
        )

        # scale

        q = q * self.scale

        # sim

        sim = einsum("b h i d, b h j d -> b h i j", q, k)

        # attention

        attn = self.attend(sim)

        # aggregate

        out = einsum("b h i j, b h j d -> b h i d", attn, v)

        # merge heads
        out = rearrange(
            out,
            "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
            x=h // self.ps,
            y=w // self.ps,
            head=self.heads,
            w1=self.ps,
            w2=self.ps,
        )

        out = self.to_out(out)
        return out
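

# Channel_Attention: transposed self-attention within each ps x ps window;
# queries and keys are L2-normalized and attention is computed across the
# channel dimension per head, scaled by a learnable temperature.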
class Channel_Attention(nn.Module):
    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super(Channel_Attention, self).__init__()
        self.heads = heads

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.ps = window_size

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        qkv = qkv.chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
                ph=self.ps,
                pw=self.ps,
                head=self.heads,
            ),
            qkv,
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = rearrange(
            out,
            "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
            h=h // self.ps,
            w=w // self.ps,
            ph=self.ps,
            pw=self.ps,
            head=self.heads,
        )

        out = self.project_out(out)

        return out
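

# Channel_Attention_grid: identical to Channel_Attention, except tokens are
# grouped grid-wise (one group per offset inside the ps x ps window, gathered
# across the whole map) instead of window-wise.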
class Channel_Attention_grid(nn.Module):
    def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
        super(Channel_Attention_grid, self).__init__()
        self.heads = heads

        self.temperature = nn.Parameter(torch.ones(heads, 1, 1))

        self.ps = window_size

        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(
            dim * 3,
            dim * 3,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim * 3,
            bias=bias,
        )
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv_dwconv(self.qkv(x))
        qkv = qkv.chunk(3, dim=1)

        q, k, v = map(
            lambda t: rearrange(
                t,
                "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
                ph=self.ps,
                pw=self.ps,
                head=self.heads,
            ),
            qkv,
        )

        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        out = attn @ v

        out = rearrange(
            out,
            "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
            h=h // self.ps,
            w=w // self.ps,
            ph=self.ps,
            pw=self.ps,
            head=self.heads,
        )

        out = self.project_out(out)

        return out
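

# OSA_Block: one omni self-attention block. The sequence is MBConv, then
# block-window attention + gated conv FFN, window channel attention + FFN,
# grid attention + gated conv FFN, and grid channel attention + FFN, each
# wrapped in a pre-norm residual.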
class OSA_Block(nn.Module):
    def __init__(
        self,
        channel_num=64,
        bias=True,
        ffn_bias=True,
        window_size=8,
        with_pe=False,
        dropout=0.0,
    ):
        super(OSA_Block, self).__init__()

        w = window_size

        self.layer = nn.Sequential(
            MBConv(
                channel_num,
                channel_num,
                downsample=False,
                expansion_rate=1,
                shrinkage_rate=0.25,
            ),
            Rearrange(
                "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
            ),  # block-like attention
            PreNormResidual(
                channel_num,
                Attention(
                    dim=channel_num,
                    dim_head=channel_num // 4,
                    dropout=dropout,
                    window_size=window_size,
                    with_pe=with_pe,
                ),
            ),
            Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            # channel-like attention
            Conv_PreNormResidual(
                channel_num,
                Channel_Attention(
                    dim=channel_num, heads=4, dropout=dropout, window_size=window_size
                ),
            ),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            Rearrange(
                "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
            ),  # grid-like attention
            PreNormResidual(
                channel_num,
                Attention(
                    dim=channel_num,
                    dim_head=channel_num // 4,
                    dropout=dropout,
                    window_size=window_size,
                    with_pe=with_pe,
                ),
            ),
            Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
            # channel-like attention
            Conv_PreNormResidual(
                channel_num,
                Channel_Attention_grid(
                    dim=channel_num, heads=4, dropout=dropout, window_size=window_size
                ),
            ),
            Conv_PreNormResidual(
                channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
            ),
        )

    def forward(self, x):
        out = self.layer(x)
        return out
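

# Usage sketch (illustrative, not part of the upstream file): OSA_Block keeps the
# input resolution; H and W must be divisible by window_size and channel_num by 4.
#
#   block = OSA_Block(channel_num=64, window_size=8)
#   out = block(torch.randn(1, 64, 64, 64))  # -> torch.Size([1, 64, 64, 64])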