Added onnx friendly merging #43

Open
wants to merge 1 commit into main
21 changes: 21 additions & 0 deletions tests/onnx_friendly_tome.py
@@ -0,0 +1,21 @@
import unittest

import torch

from tomesd.merge import bipartite_soft_matching_random2d


class TestOnnxFriendlyToMeOperations(unittest.TestCase):
def test_component_correctness(self):
c = 320
w = h = 64
r = 0.2
sx = sy = 2
x = torch.rand(2, w * h, c)
m_orig, u_orig = bipartite_soft_matching_random2d(x, w, h, sx, sy, int(w * h * r), no_rand=True)
m_onnx, u_onnx = bipartite_soft_matching_random2d(x, w, h, sx, sy, int(w * h * r), no_rand=True, onnx_friendly=True)
torch.testing.assert_close(u_orig(m_orig(x)), u_onnx(m_onnx(x)))


if __name__ == '__main__':
unittest.main()
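The test above exercises the default 'mean' reduce mode. A similar check for the 'sum' mode, which the ONNX-friendly branch also accepts, could look like the sketch below (not part of this commit; it reuses the same deterministic no_rand setup):

import torch
from tomesd.merge import bipartite_soft_matching_random2d

# Same deterministic setup as the test above, but comparing the 'sum' reduce mode.
c, r, sx, sy = 320, 0.2, 2, 2
w = h = 64
x = torch.rand(2, w * h, c)
m_orig, _ = bipartite_soft_matching_random2d(x, w, h, sx, sy, int(w * h * r), no_rand=True)
m_onnx, _ = bipartite_soft_matching_random2d(x, w, h, sx, sy, int(w * h * r), no_rand=True, onnx_friendly=True)
torch.testing.assert_close(m_orig(x, mode="sum"), m_onnx(x, mode="sum"))
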
33 changes: 23 additions & 10 deletions tomesd/merge.py
@@ -2,7 +2,7 @@
from typing import Tuple, Callable


def do_nothing(x: torch.Tensor, mode:str=None):
def do_nothing(x: torch.Tensor, mode: str = None):
return x


@@ -20,7 +20,8 @@ def mps_gather_workaround(input, dim, index):
def bipartite_soft_matching_random2d(metric: torch.Tensor,
w: int, h: int, sx: int, sy: int, r: int,
no_rand: bool = False,
generator: torch.Generator = None) -> Tuple[Callable, Callable]:
generator: torch.Generator = None,
onnx_friendly: bool = False) -> Tuple[Callable, Callable]:
"""
Partitions the tokens into src and dst and merges r tokens from src to dst.
Dst tokens are partitioned by choosing one randomly in each (sx, sy) region.
@@ -34,25 +35,26 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
- r: number of tokens to remove (by merging)
- no_rand: if true, disable randomness (use top left corner only)
- generator: if no_rand is false, and if not None, a torch.Generator to use for randomness.
- onnx_friendly: if True, replaces `torch.scatter_reduce` with ONNX-friendly operators: `scatter` and `bincount`.
"""
B, N, _ = metric.shape

if r <= 0:
return do_nothing, do_nothing

gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather

with torch.no_grad():
hsy, wsx = h // sy, w // sx

# For each sy by sx kernel, randomly assign one token to be dst and the rest src
if no_rand:
rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
else:
rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)
rand_idx = torch.randint(sy * sx, size=(hsy, wsx, 1), device=generator.device, generator=generator).to(metric.device)

# The image might not divide evenly into sx and sy, so we need to work on a view of the top left of the idx buffer instead
idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
idx_buffer_view = torch.zeros(hsy, wsx, sy * sx, device=metric.device, dtype=torch.int64)
idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
idx_buffer_view = idx_buffer_view.view(hsy, wsx, sy, sx).transpose(1, 2).reshape(hsy * sy, wsx * sx)

@@ -71,8 +73,8 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,

# rand_idx is currently dst|src, so split them
num_dst = hsy * wsx
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst
a_idx = rand_idx[:, num_dst:, :] # src
b_idx = rand_idx[:, :num_dst, :] # dst

def split(x):
C = x.shape[-1]
@@ -99,10 +101,21 @@ def split(x):
def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
src, dst = split(x)
n, t1, c = src.shape

unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
if not onnx_friendly:
dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
else:
if mode not in ("mean", "sum"):
raise NotImplementedError(f"ONNX-friendly merging currently supports only 'mean' and 'sum' modes, got '{mode}'")
dst = dst.scatter(-2, dst_idx.expand(n, r, c), src, reduce='add')
if mode == "mean":
counts = torch.stack([
torch.bincount(dst_idx[i, :, 0], minlength=dst.size(-2))
for i in range(dst_idx.size(0))
], dim=0) + 1
dst = dst / counts[..., None]

return torch.cat([unm, dst], dim=1)

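For reference, the core substitution made above can be checked in isolation. The sketch below (illustrative shapes, not part of the commit) verifies that an additive scatter followed by division by per-index counts reproduces scatter_reduce with reduce='mean'; the + 1 on the counts accounts for the original dst token, matching scatter_reduce's default include_self=True:

import torch

n, r, c, t = 2, 6, 4, 8                     # batch, merged src tokens, channels, dst tokens
src = torch.rand(n, r, c)
dst = torch.rand(n, t, c)
dst_idx = torch.randint(t, size=(n, r, 1))  # destination token for each merged src token

# Reference: fused mean reduction (include_self=True by default).
ref = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce="mean")

# ONNX-friendly path: additive scatter, then divide by (count + 1).
out = dst.scatter(-2, dst_idx.expand(n, r, c), src, reduce="add")
counts = torch.stack([torch.bincount(dst_idx[i, :, 0], minlength=t) for i in range(n)], dim=0) + 1
out = out / counts[..., None]

torch.testing.assert_close(ref, out)
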
66 changes: 24 additions & 42 deletions tomesd/patch.py
@@ -6,7 +6,6 @@
from .utils import isinstance_str, init_generator



def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable, ...]:
original_h, original_w = tome_info["size"]
original_tokens = original_h * original_w
@@ -24,27 +23,22 @@ def compute_merge(x: torch.Tensor, tome_info: Dict[str, Any]) -> Tuple[Callable,
args["generator"] = init_generator(x.device)
elif args["generator"].device != x.device:
args["generator"] = init_generator(x.device, fallback=args["generator"])

# If the batch size is odd, then it's not possible for prompted and unprompted images to be in the same
# batch, which causes artifacts with use_rand, so force it to be off.
use_rand = False if x.shape[0] % 2 == 1 else args["use_rand"]
m, u = merge.bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r,
no_rand=not use_rand, generator=args["generator"])
m, u = merge.bipartite_soft_matching_random2d(x, w, h, args["sx"], args["sy"], r,
no_rand=not use_rand, generator=args["generator"], onnx_friendly=tome_info["onnx_friendly"])
else:
m, u = (merge.do_nothing, merge.do_nothing)

m_a, u_a = (m, u) if args["merge_attn"] else (merge.do_nothing, merge.do_nothing)
m_a, u_a = (m, u) if args["merge_attn"] else (merge.do_nothing, merge.do_nothing)
m_c, u_c = (m, u) if args["merge_crossattn"] else (merge.do_nothing, merge.do_nothing)
m_m, u_m = (m, u) if args["merge_mlp"] else (merge.do_nothing, merge.do_nothing)
m_m, u_m = (m, u) if args["merge_mlp"] else (merge.do_nothing, merge.do_nothing)

return m_a, m_c, m_m, u_a, u_c, u_m # Okay this is probably not very good







def make_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
"""
Make a patched class on the fly so we don't have to import any specific modules.
@@ -64,32 +58,29 @@ def _forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
x = u_m(self.ff(m_m(self.norm3(x)))) + x

return x

return ToMeBlock




return ToMeBlock


def make_diffusers_tome_block(block_class: Type[torch.nn.Module]) -> Type[torch.nn.Module]:
"""
Make a patched class for a diffusers model.
This patch applies ToMe to the forward function of the block.
"""

class ToMeBlock(block_class):
# Save for unpatching later
_parent = block_class

def forward(
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
self,
hidden_states,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
timestep=None,
cross_attention_kwargs=None,
class_labels=None,
) -> torch.Tensor:
# (1) ToMe
m_a, m_c, m_m, u_a, u_c, u_m = compute_merge(hidden_states, self._tome_info)
@@ -139,7 +130,7 @@ def forward(

# 3. Feed-forward
norm_hidden_states = self.norm3(hidden_states)

if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

Expand All @@ -159,25 +150,16 @@ def forward(
return ToMeBlock






def hook_tome_model(model: torch.nn.Module):
""" Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """

def hook(module, args):
module._tome_info["size"] = (args[0].shape[2], args[0].shape[3])
return None

model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))








def apply_patch(
model: torch.nn.Module,
ratio: float = 0.5,
@@ -186,7 +168,8 @@ def apply_patch(
use_rand: bool = True,
merge_attn: bool = True,
merge_crossattn: bool = False,
merge_mlp: bool = False):
merge_mlp: bool = False,
onnx_friendly: bool = False):
"""
Patches a stable diffusion model with ToMe.
Apply this to the highest level stable diffusion object (i.e., it should have a .model.diffusion_model).
@@ -208,6 +191,7 @@
- merge_attn: Whether or not to merge tokens for attention (recommended).
- merge_crossattn: Whether or not to merge tokens for cross attention (not recommended).
- merge_mlp: Whether or not to merge tokens for the mlp layers (very not recommended).
- onnx_friendly: Whether or not to replace scatter_reduce with ONNX-friendly ops (scatter and bincount).
"""

# Make sure the module is not currently patched
@@ -235,8 +219,9 @@ def apply_patch(
"generator": None,
"merge_attn": merge_attn,
"merge_crossattn": merge_crossattn,
"merge_mlp": merge_mlp
}
"merge_mlp": merge_mlp,
},
"onnx_friendly": onnx_friendly
}
hook_tome_model(diffusion_model)

@@ -259,9 +244,6 @@ def apply_patch(
return model





def remove_patch(model: torch.nn.Module):
""" Removes a patch from a ToMe Diffusion module if it was already patched. """
# For diffusers
Expand All @@ -275,5 +257,5 @@ def remove_patch(model: torch.nn.Module):

if module.__class__.__name__ == "ToMeBlock":
module.__class__ = module._parent

return model
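
A usage sketch for the new flag (assuming a diffusers pipeline is available and this branch is installed; the model id and the export step are placeholders, not prescribed by this PR):

import tomesd
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # example model id
tomesd.apply_patch(pipe, ratio=0.5, onnx_friendly=True)  # merge via scatter + bincount instead of scatter_reduce
# ... export the patched UNet here, e.g. with torch.onnx.export or optimum ...
tomesd.remove_patch(pipe)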