Add a TomePatchModel node to the _for_testing section.
ToMe (Token Merging) increases sampling speed at the expense of quality.
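For orientation, here is a hedged sketch of what the TomePatchModel node named in the message plausibly looks like; the class body, the set_model_tomesd helper, and the default ratio are assumptions, not copied from this commit (only the attention-side change is shown in the diff below).

    # Hedged sketch, not the committed nodes.py code: every name below
    # except "TomePatchModel" and "_for_testing" is an assumption.
    class TomePatchModel:
        @classmethod
        def INPUT_TYPES(s):
            return {"required": {"model": ("MODEL",),
                                 "ratio": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01})}}
        RETURN_TYPES = ("MODEL",)
        FUNCTION = "patch"
        CATEGORY = "_for_testing"

        def patch(self, model, ratio):
            m = model.clone()
            # assumed helper: stores {"ratio": ratio} where the transformer
            # blocks can read it as transformer_options["tomesd"] (see diff)
            m.set_model_tomesd(ratio)
            return (m,)

The diff below is the consumer side of that option: BasicTransformerBlock checks transformer_options for a "tomesd" entry and routes self-attention through the merge/unmerge pair.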
@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
 import model_management
+from . import tomesd
 
 if model_management.xformers_enabled():
     import xformers
@@ -508,8 +509,18 @@ class BasicTransformerBlock(nn.Module):
         return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
 
     def _forward(self, x, context=None, transformer_options={}):
-        x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
-        x = self.attn2(self.norm2(x), context=context) + x
+        n = self.norm1(x)
+        if "tomesd" in transformer_options:
+            m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
+            n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
+        else:
+            n = self.attn1(n, context=context if self.disable_self_attn else None)
+
+        x += n
+        n = self.norm2(x)
+        n = self.attn2(n, context=context)
+
+        x += n
         x = self.ff(self.norm3(x)) + x
         return x
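tomesd.get_functions itself is not part of this diff. As a rough guide to what it returns, here is a minimal sketch of a ToMe-style (merge, unmerge) pair using the bipartite soft matching from the ToMe paper, simplified (no 2D-locality bias, no size-weighted attention) and not the actual tomesd code; the function name and all internals are illustrative.

    # Minimal sketch of a ToMe-style (merge, unmerge) pair; an assumption
    # about what tomesd.get_functions returns, not its implementation.
    import torch

    def simple_tome_functions(x, ratio):
        # x: (batch, tokens, channels). Returns (merge, unmerge) callables
        # that fold away and later restore the `r` most redundant tokens.
        b, t, c = x.shape
        r = min(int(t * ratio), t // 2)
        if r <= 0:
            return (lambda z: z), (lambda z: z)

        with torch.no_grad():
            metric = x / x.norm(dim=-1, keepdim=True)              # cosine-similarity metric
            src_m, dst_m = metric[..., ::2, :], metric[..., 1::2, :]  # bipartite split
            scores = src_m @ dst_m.transpose(-1, -2)               # src-to-dst similarity
            node_max, node_idx = scores.max(dim=-1)                # best dst for each src
            edge_idx = node_max.argsort(dim=-1, descending=True)
            src_idx = edge_idx[..., :r]                            # r most mergeable src tokens
            unm_idx = edge_idx[..., r:]                            # src tokens kept unmerged
            dst_idx = node_idx.gather(-1, src_idx)                 # their merge targets

        def merge(z):
            src, dst = z[..., ::2, :], z[..., 1::2, :]
            unm = src.gather(-2, unm_idx.unsqueeze(-1).expand(-1, -1, c))
            mrg = src.gather(-2, src_idx.unsqueeze(-1).expand(-1, -1, c))
            # fold merged src tokens into their dst partners by averaging
            dst = dst.scatter_reduce(-2, dst_idx.unsqueeze(-1).expand(-1, -1, c),
                                     mrg, reduce="mean")
            return torch.cat([unm, dst], dim=-2)                   # fewer tokens out

        def unmerge(z):
            n_unm = unm_idx.shape[-1]
            unm, dst = z[..., :n_unm, :], z[..., n_unm:, :]
            out = torch.zeros(b, t, c, device=z.device, dtype=z.dtype)
            out[..., 1::2, :] = dst
            # scatter kept src tokens back; merged slots get a copy of
            # the dst token they were folded into
            src_out = out[..., ::2, :]
            src_out.scatter_(-2, unm_idx.unsqueeze(-1).expand(-1, -1, c), unm)
            src_out.scatter_(-2, src_idx.unsqueeze(-1).expand(-1, -1, c),
                             dst.gather(-2, dst_idx.unsqueeze(-1).expand(-1, -1, c)))
            return out

        return merge, unmerge

In the patched _forward, m(n) shrinks the token count before the quadratic self-attention and u(...) restores it before the residual add; the merge is lossy, which is exactly the speed-for-quality trade described in the commit message.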