Pull in latest upscale model code from chainner.

This commit is contained in:
comfyanonymous
2023-05-23 22:26:50 -04:00
parent c00bb1a0b7
commit 7310290f17
12 changed files with 1530 additions and 2 deletions

View File

@@ -141,6 +141,19 @@ def sequential(*args):
ConvMode = Literal["CNA", "NAC", "CNAC"]
# 2x2x2 Conv Block
def conv_block_2c2(
in_nc,
out_nc,
act_type="relu",
):
return sequential(
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
act(act_type) if act_type else None,
)
def conv_block(
in_nc: int,
out_nc: int,
@@ -153,12 +166,17 @@ def conv_block(
norm_type: str | None = None,
act_type: str | None = "relu",
mode: ConvMode = "CNA",
c2x2=False,
):
"""
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
"""
if c2x2:
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
@@ -285,6 +303,7 @@ class RRDB(nn.Module):
_convtype="Conv2D",
_spectral_norm=False,
plus=False,
c2x2=False,
):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(
@@ -298,6 +317,7 @@ class RRDB(nn.Module):
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB2 = ResidualDenseBlock_5C(
nf,
@@ -310,6 +330,7 @@ class RRDB(nn.Module):
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
self.RDB3 = ResidualDenseBlock_5C(
nf,
@@ -322,6 +343,7 @@ class RRDB(nn.Module):
act_type,
mode,
plus=plus,
c2x2=c2x2,
)
def forward(self, x):
@@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
act_type="leakyrelu",
mode: ConvMode = "CNA",
plus=False,
c2x2=False,
):
super(ResidualDenseBlock_5C, self).__init__()
@@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv2 = conv_block(
nf + gc,
@@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv3 = conv_block(
nf + 2 * gc,
@@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
self.conv4 = conv_block(
nf + 3 * gc,
@@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=act_type,
mode=mode,
c2x2=c2x2,
)
if mode == "CNA":
last_act = None
@@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
norm_type=norm_type,
act_type=last_act,
mode=mode,
c2x2=c2x2,
)
def forward(self, x):
@@ -499,6 +527,7 @@ def upconv_block(
norm_type: str | None = None,
act_type="relu",
mode="nearest",
c2x2=False,
):
# Up conv
# described in https://distill.pub/2016/deconv-checkerboard/
@@ -512,5 +541,6 @@ def upconv_block(
pad_type=pad_type,
norm_type=norm_type,
act_type=act_type,
c2x2=c2x2,
)
return sequential(upsample, conv)