Pull in latest upscale model code from chainner.
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding:utf-8 -*-
|
||||
#############################################################
|
||||
# File: pixelshuffle.py
|
||||
# Created Date: Friday July 1st 2022
|
||||
# Author: Chen Xuanhong
|
||||
# Email: chenxuanhongzju@outlook.com
|
||||
# Last Modified: Friday, 1st July 2022 10:18:39 am
|
||||
# Modified By: Chen Xuanhong
|
||||
# Copyright (c) 2022 Shanghai Jiao Tong University
|
||||
#############################################################
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def pixelshuffle_block(
|
||||
in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
|
||||
):
|
||||
"""
|
||||
Upsample features according to `upscale_factor`.
|
||||
"""
|
||||
padding = kernel_size // 2
|
||||
conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels * (upscale_factor**2),
|
||||
kernel_size,
|
||||
padding=1,
|
||||
bias=bias,
|
||||
)
|
||||
pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
||||
return nn.Sequential(*[conv, pixel_shuffle])
|
||||
Reference in New Issue
Block a user