diff --git a/nodes.py b/nodes.py index ecd931d6..815631f5 100644 --- a/nodes.py +++ b/nodes.py @@ -1076,7 +1076,7 @@ class ImageToMask: def image_to_mask(self, image, channel): channels = ["red", "green", "blue"] - mask = torch.select(image[0], 2, channels.index(channel)) + mask = image[0, :, :, channels.index(channel)] return (mask,) class MaskToImage: