Fix sub quadratic attention for SD2 and make it the default optimization.
@@ -175,13 +175,11 @@ class CrossAttentionBirchSan(nn.Module):
         value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1)
 
         dtype = query.dtype
-        # TODO: do we still need to do *everything* in float32, given how we delay the division?
-        # TODO: do we need to support upcast_softmax too? SD 2.1 seems to work without it
-        # if self.upcast_attention:
-        #     query = query.float()
-        #     key_t = key_t.float()
-
-        bytes_per_token = torch.finfo(query.dtype).bits//8
+        upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
+        if upcast_attention:
+            bytes_per_token = torch.finfo(torch.float32).bits//8
+        else:
+            bytes_per_token = torch.finfo(query.dtype).bits//8
         batch_x_heads, q_tokens, _ = query.shape
         _, _, k_tokens = key_t.shape
         qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
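Note: the commented-out upcast TODOs are replaced by an explicit upcast_attention flag, and the memory estimate now accounts for it. A minimal sketch of that relationship, assuming _ATTN_PRECISION is the module-level precision setting used elsewhere in this file (the helper name below is mine, not part of the commit):

# Sketch (not part of the commit): how the upcast flag maps to the per-token byte
# count used in the memory estimate: 4 bytes for fp32, 2 for fp16/bf16.
import torch

def attention_bytes_per_token(query: torch.Tensor, attn_precision: str = "fp32") -> int:
    upcast = attn_precision == "fp32" and query.dtype != torch.float32
    dtype = torch.float32 if upcast else query.dtype
    return torch.finfo(dtype).bits // 8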
@@ -198,7 +196,7 @@ class CrossAttentionBirchSan(nn.Module):
 
             query_chunk_size_x = 1024 * 4
             kv_chunk_size_min_x = None
-            kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 1.2) // 1024) * 1024
+            kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024
             if kv_chunk_size_x < 1024:
                 kv_chunk_size_x = None
 
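The only change in this hunk is the scale factor on the derived KV chunk size, raised from 1.2 to 2.0. Pulled out into a hedged standalone helper (function name is mine, variable names from the surrounding code), the arithmetic is: divide the byte budget by the bytes one 4096-token query chunk needs per KV token, scale it, then floor to a multiple of 1024.

# Sketch (hypothetical helper): derive a KV chunk size from the byte budget.
def pick_kv_chunk_size(chunk_threshold_bytes, batch_x_heads, bytes_per_token,
                       query_chunk_size=1024 * 4, scale=2.0):
    per_kv_token_bytes = batch_x_heads * bytes_per_token * query_chunk_size
    kv_chunk = (int((chunk_threshold_bytes // per_kv_token_bytes) * scale) // 1024) * 1024
    return None if kv_chunk < 1024 else kv_chunk  # None -> fall back to default chunking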
@@ -220,6 +218,7 @@ class CrossAttentionBirchSan(nn.Module):
             kv_chunk_size=kv_chunk_size,
             kv_chunk_size_min=kv_chunk_size_min,
             use_checkpoint=self.training,
+            upcast_attention=upcast_attention,
         )
 
         hidden_states = hidden_states.to(dtype)
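The new upcast_attention keyword is forwarded into the sub-quadratic attention call so the numerically sensitive matmul and softmax can run in float32 even when the model weights are fp16; the result is then cast back by the existing hidden_states.to(dtype). A rough sketch of that pattern, with a hypothetical helper name and the chunking omitted (this is not the library call itself):

# Sketch (hypothetical, unchunked): the upcast pattern the new keyword enables.
def upcast_attention_step(query, key_t, value, upcast_attention: bool):
    out_dtype = query.dtype
    if upcast_attention:                       # do matmul + softmax in float32
        query, key_t = query.float(), key_t.float()
    scale = query.shape[-1] ** -0.5
    scores = (query * scale) @ key_t           # (batch*heads, q_tokens, k_tokens)
    weights = scores.softmax(dim=-1).to(value.dtype)
    return (weights @ value).to(out_dtype)     # cast back, as hidden_states.to(dtype) does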
@@ -383,8 +382,15 @@ class OriginalCrossAttention(nn.Module):
         out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
         return self.to_out(out)
 
-class CrossAttention(CrossAttentionDoggettx):
-    pass
+import sys
+if "--use-split-cross-attention" in sys.argv:
+    print("Using split optimization for cross attention")
+    class CrossAttention(CrossAttentionDoggettx):
+        pass
+else:
+    print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
+    class CrossAttention(CrossAttentionBirchSan):
+        pass
 
 class MemoryEfficientCrossAttention(nn.Module):
     # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
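The last hunk makes the sub-quadratic (Birch-san) implementation the default CrossAttention and keeps the split (Doggettx) implementation reachable via the --use-split-cross-attention flag. The choice happens once at import time by shadowing the class name; a standalone sketch of the same pattern with stand-in class names:

# Sketch (standalone, stand-in names): import-time selection of the attention class.
import sys

class SplitAttention: ...        # stands in for CrossAttentionDoggettx
class SubQuadAttention: ...      # stands in for CrossAttentionBirchSan (new default)

if "--use-split-cross-attention" in sys.argv:
    CrossAttention = SplitAttention
else:
    CrossAttention = SubQuadAttention

The commit itself uses empty subclasses (class CrossAttention(...): pass) instead of assignment, which is equivalent for callers that only instantiate CrossAttention.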