From c17dd1e302256627a82767611dec6db355356bee Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sat, 7 Sep 2024 05:42:28 -0700
Subject: [PATCH] rotary embedding done in full prec

---
 MEGABYTE_pytorch/megabyte.py | 3 +++
 setup.py                     | 2 +-
 2 files changed, 4 insertions(+), 1 deletion(-)

diff --git a/MEGABYTE_pytorch/megabyte.py b/MEGABYTE_pytorch/megabyte.py
index 34c16da..0d8bf28 100644
--- a/MEGABYTE_pytorch/megabyte.py
+++ b/MEGABYTE_pytorch/megabyte.py
@@ -8,6 +8,7 @@
 import torch.nn.functional as F
 from torch.nn import Module, ModuleList
 from torch import nn, einsum
+from torch.amp import autocast
 
 from einops import rearrange, reduce, repeat, pack, unpack
 from einops.layers.torch import Rearrange
@@ -80,6 +81,7 @@ def __init__(self, dim, theta = 10000):
     def device(self):
         return next(self.buffers()).device
 
+    @autocast('cuda', enabled = False)
     def forward(self, seq_len):
         t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
         freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
@@ -90,6 +92,7 @@ def rotate_half(x):
     x1, x2 = x.chunk(2, dim=-1)
     return torch.cat((-x2, x1), dim=-1)
 
+@autocast('cuda', enabled = False)
 def apply_rotary_pos_emb(pos, t):
     return t * pos.cos() + rotate_half(t) * pos.sin()
 
diff --git a/setup.py b/setup.py
index fc4e114..353f7f7 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'MEGABYTE-pytorch',
   packages = find_packages(),
-  version = '0.3.1',
+  version = '0.3.2',
   license='MIT',
   description = 'MEGABYTE - Pytorch',
   long_description_content_type = 'text/markdown',
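
Note, not part of the patch: a minimal sketch of what the added decorator buys. Under torch.autocast on CUDA, ops such as einsum may execute in float16/bfloat16, which can degrade the sin/cos rotary frequencies at large positions; wrapping the rotary computation in autocast('cuda', enabled = False) keeps it in float32 even inside a mixed-precision region. RotarySketch below is a hypothetical stand-in that mirrors the patched RotaryEmbedding, not the repo's public API, and the dtype check at the end is only an illustration.

import torch
from torch.amp import autocast

class RotarySketch(torch.nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        # same inverse-frequency buffer as RotaryEmbedding in megabyte.py
        inv_freq = 1. / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

    @autocast('cuda', enabled = False)  # the guard the patch adds
    def forward(self, seq_len):
        # positions and frequencies stay float32 because autocast is disabled here
        t = torch.arange(seq_len, device = self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        return torch.cat((freqs, freqs), dim = -1)

if torch.cuda.is_available():
    rotary = RotarySketch(64).cuda()
    # even inside an enabled autocast region, the decorator keeps the output full precision
    with autocast('cuda', dtype = torch.float16):
        pos = rotary(1024)
    print(pos.dtype)  # expected: torch.float32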