Basic Flux Schnell and Flux Dev model implementation.
This commit is contained in:
@@ -10,6 +10,8 @@ import comfy.ldm.aura.mmdit
|
||||
import comfy.ldm.hydit.models
|
||||
import comfy.ldm.audio.dit
|
||||
import comfy.ldm.audio.embedders
|
||||
import comfy.ldm.flux.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
import comfy.ops
|
||||
@@ -26,6 +28,7 @@ class ModelType(Enum):
|
||||
EDM = 5
|
||||
FLOW = 6
|
||||
V_PREDICTION_CONTINUOUS = 7
|
||||
FLUX = 8
|
||||
|
||||
|
||||
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
|
||||
@@ -53,6 +56,9 @@ def model_sampling(model_config, model_type):
|
||||
elif model_type == ModelType.V_PREDICTION_CONTINUOUS:
|
||||
c = V_PREDICTION
|
||||
s = ModelSamplingContinuousV
|
||||
elif model_type == ModelType.FLUX:
|
||||
c = comfy.model_sampling.CONST
|
||||
s = comfy.model_sampling.ModelSamplingFlux
|
||||
|
||||
class ModelSampling(s, c):
|
||||
pass
|
||||
@@ -681,3 +687,18 @@ class HunyuanDiT(BaseModel):
|
||||
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
||||
class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user