Add a TomePatchModel node to the _for_testing section.

ToMe (Token Merging) increases sampling speed at the expense of some image quality.
This commit is contained in:
comfyanonymous
2023-03-31 17:19:58 -04:00
parent 7e682784d7
commit 18a6c1db33
5 changed files with 166 additions and 11 deletions

View File

@@ -11,6 +11,7 @@ from .sub_quadratic_attention import efficient_dot_product_attention
import model_management
from . import tomesd
if model_management.xformers_enabled():
import xformers
@@ -508,8 +509,18 @@ class BasicTransformerBlock(nn.Module):
return checkpoint(self._forward, (x, context, transformer_options), self.parameters(), self.checkpoint)
def _forward(self, x, context=None, transformer_options={}):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
n = self.norm1(x)
if "tomesd" in transformer_options:
m, u = tomesd.get_functions(x, transformer_options["tomesd"]["ratio"], transformer_options["original_shape"])
n = u(self.attn1(m(n), context=context if self.disable_self_attn else None))
else:
n = self.attn1(n, context=context if self.disable_self_attn else None)
x += n
n = self.norm2(x)
n = self.attn2(n, context=context)
x += n
x = self.ff(self.norm3(x)) + x
return x