Load the SD3 T5xxl model in the same dtype stored in the checkpoint.

comfyanonymous
2024-06-11 17:03:26 -04:00
parent 5889b7ca0a
commit 0e49211a11
6 changed files with 49 additions and 6 deletions


@@ -44,24 +44,36 @@ class SD3Tokenizer:
         return self.clip_g.untokenize(token_weight_pair)
 
 class SD3ClipModel(torch.nn.Module):
-    def __init__(self, clip_l=True, clip_g=True, t5=True, device="cpu", dtype=None):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
         super().__init__()
         self.dtypes = set()
         if clip_l:
             self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
             self.dtypes.add(dtype)
         else:
             self.clip_l = None
 
         if clip_g:
             self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
             self.dtypes.add(dtype)
         else:
             self.clip_g = None
 
         if t5:
-            self.t5xxl = T5XXLModel(device=device, dtype=dtype)
+            if dtype_t5 is None:
+                dtype_t5 = dtype
+            elif comfy.model_management.dtype_size(dtype_t5) > comfy.model_management.dtype_size(dtype):
+                dtype_t5 = dtype
+
+            if not comfy.model_management.supports_cast(device, dtype_t5):
+                dtype_t5 = dtype
+
+            self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
+            self.dtypes.add(dtype_t5)
         else:
             self.t5xxl = None
-        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}".format(clip_l, clip_g, t5))
+        logging.debug("Created SD3 text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}".format(clip_l, clip_g, t5, dtype_t5))
 
     def set_clip_options(self, options):
         if self.clip_l is not None:
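
The added logic clamps the T5xxl dtype: the dtype taken from the checkpoint is kept only if it is no larger than the base text-encoder dtype and the device can cast to it; otherwise it falls back to dtype. Below is a minimal standalone sketch of that selection rule. The names dtype_size, supports_cast, and pick_t5_dtype are hypothetical stand-ins written for illustration, not the actual comfy.model_management API.

import torch

# Hypothetical stand-ins for the comfy.model_management helpers referenced in
# the diff; they only illustrate the selection rule, not ComfyUI's behavior.
def dtype_size(dtype):
    # Size in bytes of one element of the given dtype.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.bits // 8

def supports_cast(device, dtype):
    # Assumption: the common float dtypes are castable everywhere; anything
    # more exotic (e.g. fp8) would need a real capability check.
    return dtype in (torch.float32, torch.bfloat16, torch.float16)

def pick_t5_dtype(dtype, dtype_t5=None, device="cpu"):
    # Mirrors the fallback chain in SD3ClipModel.__init__: a missing, larger,
    # or uncastable checkpoint dtype collapses to the base dtype.
    if dtype_t5 is None:
        dtype_t5 = dtype
    elif dtype_size(dtype_t5) > dtype_size(dtype):
        dtype_t5 = dtype
    if not supports_cast(device, dtype_t5):
        dtype_t5 = dtype
    return dtype_t5

print(pick_t5_dtype(torch.float16, torch.float32))   # fp32 is larger than the base dtype -> torch.float16
print(pick_t5_dtype(torch.float16, torch.bfloat16))  # same size and castable -> torch.bfloat16
print(pick_t5_dtype(torch.float16))                  # nothing stored -> torch.float16

The practical effect is that a checkpoint storing T5xxl in a smaller dtype (e.g. an fp8 T5xxl) is loaded as stored when the hardware supports it, while anything missing or heavier than the requested precision quietly falls back to the base dtype.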