Merge branch 'master' into worksplit-multigpu

Author: Jedrzej Kosinski
Date: 2025-03-26 22:26:26 -05:00
25 changed files with 1682 additions and 41 deletions


@@ -51,6 +51,32 @@ cpu_state = CPUState.GPU

 total_vram = 0

+def get_supported_float8_types():
+    float8_types = []
+    try:
+        float8_types.append(torch.float8_e4m3fn)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e4m3fnuz)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e5m2)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e5m2fnuz)
+    except:
+        pass
+    try:
+        float8_types.append(torch.float8_e8m0fnu)
+    except:
+        pass
+    return float8_types
+
+FLOAT8_TYPES = get_supported_float8_types()
+
 xpu_available = False
 torch_version = ""
 try:
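The new helper only collects the torch.float8_* dtypes that the installed torch build actually defines, so builds that predate e.g. torch.float8_e8m0fnu end up with a shorter FLOAT8_TYPES list instead of raising AttributeError at import time. A minimal sketch of the same probing idea, using getattr instead of one try/except per dtype (probe_float8_types is a hypothetical name, not part of this commit):

import torch

# Sketch only: collect whichever float8 dtypes this torch build exposes.
# getattr returns None for dtypes the build does not define, so the result
# degrades gracefully on older torch versions.
_FLOAT8_NAMES = ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2",
                 "float8_e5m2fnuz", "float8_e8m0fnu")

def probe_float8_types():
    found = []
    for name in _FLOAT8_NAMES:
        dtype = getattr(torch, name, None)  # None when the attribute is missing
        if dtype is not None:
            found.append(dtype)
    return found

print(probe_float8_types())  # contents depend on the installed torch version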
@@ -729,11 +755,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
         return torch.float8_e5m2

     fp8_dtype = None
-    try:
-        if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
-            fp8_dtype = weight_dtype
-    except:
-        pass
+    if weight_dtype in FLOAT8_TYPES:
+        fp8_dtype = weight_dtype

     if fp8_dtype is not None:
         if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
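With FLOAT8_TYPES built at import time, the membership test no longer needs a try/except guard, and any float8 dtype the build supports (not just float8_e4m3fn and float8_e5m2) is now accepted as a candidate fp8 weight dtype. A hypothetical, simplified mirror of that branch; pick_fp8_dtype and the fp8_compute_supported flag are illustrative names, not ComfyUI API, and the real function falls back to further checks instead of returning None:

import torch

# Sketch only: keep the checkpoint's float8 dtype when this build defines it
# and the device supports fp8 compute; the real unet_dtype() continues with
# additional dtype selection logic instead of returning None.
FLOAT8_TYPES = [getattr(torch, name) for name in
                ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2",
                 "float8_e5m2fnuz", "float8_e8m0fnu")
                if hasattr(torch, name)]

def pick_fp8_dtype(weight_dtype, fp8_compute_supported):
    fp8_dtype = weight_dtype if weight_dtype in FLOAT8_TYPES else None
    if fp8_dtype is not None and fp8_compute_supported:
        return fp8_dtype  # casting is cheap when fp8 compute is native
    return None

if hasattr(torch, "float8_e5m2"):
    print(pick_fp8_dtype(torch.float8_e5m2, True))   # torch.float8_e5m2
    print(pick_fp8_dtype(torch.float8_e5m2, False))  # None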