Pull in latest upscale model code from chainner.
This commit is contained in:
@@ -141,6 +141,19 @@ def sequential(*args):
|
||||
ConvMode = Literal["CNA", "NAC", "CNAC"]
|
||||
|
||||
|
||||
# 2x2x2 Conv Block
|
||||
def conv_block_2c2(
|
||||
in_nc,
|
||||
out_nc,
|
||||
act_type="relu",
|
||||
):
|
||||
return sequential(
|
||||
nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
|
||||
nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
|
||||
act(act_type) if act_type else None,
|
||||
)
|
||||
|
||||
|
||||
def conv_block(
|
||||
in_nc: int,
|
||||
out_nc: int,
|
||||
@@ -153,12 +166,17 @@ def conv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type: str | None = "relu",
|
||||
mode: ConvMode = "CNA",
|
||||
c2x2=False,
|
||||
):
|
||||
"""
|
||||
Conv layer with padding, normalization, activation
|
||||
mode: CNA --> Conv -> Norm -> Act
|
||||
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
|
||||
"""
|
||||
|
||||
if c2x2:
|
||||
return conv_block_2c2(in_nc, out_nc, act_type=act_type)
|
||||
|
||||
assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
|
||||
padding = get_valid_padding(kernel_size, dilation)
|
||||
p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
|
||||
@@ -285,6 +303,7 @@ class RRDB(nn.Module):
|
||||
_convtype="Conv2D",
|
||||
_spectral_norm=False,
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(RRDB, self).__init__()
|
||||
self.RDB1 = ResidualDenseBlock_5C(
|
||||
@@ -298,6 +317,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB2 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@@ -310,6 +330,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.RDB3 = ResidualDenseBlock_5C(
|
||||
nf,
|
||||
@@ -322,6 +343,7 @@ class RRDB(nn.Module):
|
||||
act_type,
|
||||
mode,
|
||||
plus=plus,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -365,6 +387,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
act_type="leakyrelu",
|
||||
mode: ConvMode = "CNA",
|
||||
plus=False,
|
||||
c2x2=False,
|
||||
):
|
||||
super(ResidualDenseBlock_5C, self).__init__()
|
||||
|
||||
@@ -382,6 +405,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv2 = conv_block(
|
||||
nf + gc,
|
||||
@@ -393,6 +417,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv3 = conv_block(
|
||||
nf + 2 * gc,
|
||||
@@ -404,6 +429,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
self.conv4 = conv_block(
|
||||
nf + 3 * gc,
|
||||
@@ -415,6 +441,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
if mode == "CNA":
|
||||
last_act = None
|
||||
@@ -430,6 +457,7 @@ class ResidualDenseBlock_5C(nn.Module):
|
||||
norm_type=norm_type,
|
||||
act_type=last_act,
|
||||
mode=mode,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -499,6 +527,7 @@ def upconv_block(
|
||||
norm_type: str | None = None,
|
||||
act_type="relu",
|
||||
mode="nearest",
|
||||
c2x2=False,
|
||||
):
|
||||
# Up conv
|
||||
# described in https://distill.pub/2016/deconv-checkerboard/
|
||||
@@ -512,5 +541,6 @@ def upconv_block(
|
||||
pad_type=pad_type,
|
||||
norm_type=norm_type,
|
||||
act_type=act_type,
|
||||
c2x2=c2x2,
|
||||
)
|
||||
return sequential(upsample, conv)
|
||||
|
||||
Reference in New Issue
Block a user