Support SDXS 0.9

This commit is contained in:
comfyanonymous
2024-03-27 23:51:17 -04:00
parent 8ae1e4d125
commit 327ca1313d
2 changed files with 9 additions and 3 deletions

View File

@@ -70,8 +70,8 @@ class SD20(supported_models_base.BASE):
def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
out = state_dict[k]
if torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
out = state_dict.get(k, None)
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
return model_base.ModelType.V_PREDICTION
return model_base.ModelType.EPS