Pull in latest upscale model code from chainner.

This commit is contained in:
comfyanonymous
2023-05-23 22:26:50 -04:00
parent c00bb1a0b7
commit 7310290f17
12 changed files with 1530 additions and 2 deletions

View File

@@ -79,6 +79,12 @@ class RRDBNet(nn.Module):
self.scale: int = self.get_scale()
self.num_filters: int = self.state[self.key_arr[0]].shape[0]
c2x2 = False
if self.state["model.0.weight"].shape[-2] == 2:
c2x2 = True
self.scale = round(math.sqrt(self.scale / 4))
self.model_arch = "ESRGAN-2c2"
self.supports_fp16 = True
self.supports_bfp16 = True
self.min_size_restriction = None
@@ -105,11 +111,15 @@ class RRDBNet(nn.Module):
out_nc=self.num_filters,
upscale_factor=3,
act_type=self.act,
c2x2=c2x2,
)
else:
upsample_blocks = [
upsample_block(
in_nc=self.num_filters, out_nc=self.num_filters, act_type=self.act
in_nc=self.num_filters,
out_nc=self.num_filters,
act_type=self.act,
c2x2=c2x2,
)
for _ in range(int(math.log(self.scale, 2)))
]
@@ -122,6 +132,7 @@ class RRDBNet(nn.Module):
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
B.ShortcutBlock(
B.sequential(
@@ -138,6 +149,7 @@ class RRDBNet(nn.Module):
act_type=self.act,
mode="CNA",
plus=self.plus,
c2x2=c2x2,
)
for _ in range(self.num_blocks)
],
@@ -149,6 +161,7 @@ class RRDBNet(nn.Module):
norm_type=self.norm,
act_type=None,
mode=self.mode,
c2x2=c2x2,
),
)
),
@@ -160,6 +173,7 @@ class RRDBNet(nn.Module):
kernel_size=3,
norm_type=None,
act_type=self.act,
c2x2=c2x2,
),
# hr_conv1
B.conv_block(
@@ -168,6 +182,7 @@ class RRDBNet(nn.Module):
kernel_size=3,
norm_type=None,
act_type=None,
c2x2=c2x2,
),
)