Add xPos embeddings #370

Open · wants to merge 3 commits into base: main
finetune_t0_non_causal_decoder.py (5 additions, 1 deletion)
@@ -94,7 +94,11 @@ def get_batch_pipe(data):
segment_ids=segment_ids.long(),
)

if args.position_embedding_type not in [PositionEmbeddingType.alibi, PositionEmbeddingType.rotary]:
if args.position_embedding_type not in [
PositionEmbeddingType.alibi,
PositionEmbeddingType.rotary,
PositionEmbeddingType.xpos,
]:
raise NotImplementedError("absolute positional embeddings require us to reset position_ids accordingly.")

return (tokens, position_ids, attention_mask), (labels, loss_mask)
megatron/arguments.py (1 addition, 1 deletion)
@@ -398,7 +398,7 @@ def _add_network_size_args(parser):
group.add_argument('--position-embedding-type', type=lambda x: PositionEmbeddingType[x],
choices=list(PositionEmbeddingType),
default=PositionEmbeddingType.absolute,
help='Define position embedding type ("absolute" | "rotary" | "alibi"). "absolute" by default.'
help='Define position embedding type ("absolute" | "rotary" | "alibi" | "xpos"). "absolute" by default.'
)
group.add_argument('--glu-activation', type=str,
choices=megatron.model.glu_activations.GLU_ACTIVATIONS.keys(),
megatron/enums.py (1 addition, 0 deletions)
@@ -33,3 +33,4 @@ class PositionEmbeddingType(enum.Enum):
rotary = 1
absolute = 2
alibi = 3
xpos = 4
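
With the new enum member in place, the existing `--position-embedding-type` flag (parsed in `megatron/arguments.py` above via `type=lambda x: PositionEmbeddingType[x]`) resolves the string `"xpos"` by name. A minimal sketch of that lookup, outside this diff:

```python
# Minimal sketch (not part of the diff): the argparse converter above maps the
# CLI choice onto the new enum member by name.
from megatron.enums import PositionEmbeddingType

embedding_type = PositionEmbeddingType["xpos"]   # e.g. --position-embedding-type xpos
assert embedding_type is PositionEmbeddingType.xpos
assert embedding_type.value == 4
```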
megatron/model/positional_embeddings.py (108 additions, 1 deletion)
@@ -48,4 +48,111 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):

def apply_rotary_pos_emb_torch(q, k, cos, sin, offset: int = 0): # jitting fails with bf16
cos, sin = cos[offset:q.shape[0] + offset, ...], sin[offset:q.shape[0] + offset, ...]
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


# Original implementation adjusted from https://github.com/sunyt32/torchscale

def fixed_pos_embedding(x, base):
seq_len, dim = x.shape
inv_freq = 1.0 / (base ** (torch.arange(0, dim) / dim))
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x)
)
return torch.cos(sinusoid_inp), torch.sin(sinusoid_inp)


class XPosEmbedding(torch.nn.Module):
"""
xPos positional embeddings from https://arxiv.org/abs/2212.10554.
"""

def __init__(self, head_dim, freq_base=10000, scale_base=512, gamma=0.4, precision=torch.half):
super().__init__()
self.scale_base = scale_base
self.register_buffer(
"scale",
(
(torch.arange(0, head_dim, 2) + gamma * head_dim)
/ ((1.0 + gamma) * head_dim)
),
)
self.max_seq_len_cached = None
self.precision = precision
self.freq_base = freq_base

def forward(self, x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
scale = (
self.scale
** (
torch.arange(0, seq_len, 1) - seq_len // 2
).to(self.scale).div(self.scale_base)[:, None]
)

if (
self.max_seq_len_cached is None
or (seq_len > self.max_seq_len_cached)
):
self.max_seq_len_cached = seq_len
cos, sin = fixed_pos_embedding(scale, self.freq_base)
self.cos_cached = cos
self.sin_cached = sin
if self.precision == torch.bfloat16:
self.cos_cached = self.cos_cached.bfloat16()
self.sin_cached = self.sin_cached.bfloat16()
return (
self.cos_cached[:seq_len],
self.sin_cached[:seq_len],
scale,
)


def rotate_every_two(x):
x1 = x[:, :, ::2]
x2 = x[:, :, 1::2]
x = torch.stack((-x2, x1), dim=-1)
return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')


def duplicate_interleave(m):
"""
A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy.
"""
dim0 = m.shape[0]
m = m.view(-1, 1) # flatten the matrix
m = m.repeat(1, 2) # repeat all elements into the 2nd dimension
m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy
return m.unsqueeze(1)


def _apply_xpos_emb(x, cos, sin, scale):
# x is assumed to be (seq_len, batch_size, dim) here.
cos = duplicate_interleave(cos * scale)
sin = duplicate_interleave(sin * scale)
# duplicate_interleave expands cos/sin from (seq_len, dim // 2) to (seq_len, 1, dim),
# broadcasting over the batch dimension.
return (x * cos) + (rotate_every_two(x) * sin)


@torch.jit.script
def apply_xpos_emb(q, k, cos, sin, scale, offset: int = 0):
# q/k are assumed to be (seq_len, batch_size, dim) here.
cos = cos[offset:q.shape[0] + offset]
sin = sin[offset:q.shape[0] + offset]
scale = scale[offset:q.shape[0] + offset]
return (
_apply_xpos_emb(q, cos, sin, scale),
_apply_xpos_emb(k, cos, sin, 1.0 / scale),
)


def apply_xpos_emb_torch(q, k, cos, sin, scale, offset: int = 0):
# q/k are assumed to be (seq_len, batch_size, dim) here.
cos = cos[offset:q.shape[0] + offset]
sin = sin[offset:q.shape[0] + offset]
scale = scale[offset:q.shape[0] + offset]
return (
_apply_xpos_emb(q, cos, sin, scale),
_apply_xpos_emb(k, cos, sin, 1.0 / scale),
)
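
For reviewers who want to exercise the new helpers in isolation, here is a standalone sketch (an assumption-laden example, not part of the PR): it uses fp32, the import path introduced by this diff, and the `(seq_len, batch_size, dim)` layout that the helper comments above assume.

```python
# Standalone sketch only: dummy tensors in the (seq_len, batch_size, dim) layout
# assumed by the helpers above; fp32 so no bf16 special-casing is needed.
import torch

from megatron.model.positional_embeddings import XPosEmbedding, apply_xpos_emb_torch

seq_len, batch, head_dim = 16, 2, 64
xpos = XPosEmbedding(head_dim, precision=torch.float)

q = torch.randn(seq_len, batch, head_dim)
k = torch.randn(seq_len, batch, head_dim)
v = torch.randn(seq_len, batch, head_dim)

# cos/sin come back as (seq_len, head_dim // 2); scale holds the per-position decay.
cos, sin, scale = xpos(v, seq_len=seq_len)

# Queries are multiplied by scale, keys by 1 / scale (handled inside the helper).
q_xpos, k_xpos = apply_xpos_emb_torch(q, k, cos, sin, scale)
assert q_xpos.shape == q.shape and k_xpos.shape == k.shape
```

The `scale` tensor is what distinguishes xPos from plain rotary embeddings: it attenuates the rotated components with relative distance, which the linked xPos paper introduces to improve length extrapolation.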
megatron/model/transformer.py (21 additions, 4 deletions)
@@ -31,7 +31,14 @@
import deepspeed

from .glu_activations import GLU_ACTIVATIONS
from .positional_embeddings import RotaryEmbedding, apply_rotary_pos_emb_torch, apply_rotary_pos_emb
from .positional_embeddings import (
apply_rotary_pos_emb,
apply_rotary_pos_emb_torch,
apply_xpos_emb,
apply_xpos_emb_torch,
RotaryEmbedding,
XPosEmbedding,
)

# flags required to enable jit fusion kernels
torch._C._jit_set_profiling_mode(False)
@@ -204,6 +211,8 @@ def __init__(self, init_method,

if self.position_embedding_type == PositionEmbeddingType.rotary:
self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype)
elif self.position_embedding_type == PositionEmbeddingType.xpos:
self.xpos_emb = XPosEmbedding(self.hidden_size_per_attention_head, precision=args.params_dtype)

def forward(self, hidden_states, attention_mask, layer_past=None,
get_key_value=False, encoder_output=None, alibi=None):
@@ -291,16 +300,24 @@ def forward(self, hidden_states, attention_mask, layer_past=None,
matmul_result = alibi[:output_size[0]*output_size[1], :, :output_size[3]]

# Rotary embeddings
if self.position_embedding_type == PositionEmbeddingType.rotary:
apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb

if self.position_embedding_type in [
PositionEmbeddingType.rotary, PositionEmbeddingType.xpos]:
seq_len = key_layer.shape[0]
offset = 0
if layer_past is not None and layer_past.numel() > 0:
offset = layer_past[0].shape[0]
seq_len += offset

if self.position_embedding_type == PositionEmbeddingType.rotary:
apply_rotary_fn = apply_rotary_pos_emb_torch if self.bf16 else apply_rotary_pos_emb
cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
elif self.position_embedding_type == PositionEmbeddingType.xpos:
apply_xpos_fn = apply_xpos_emb_torch if self.bf16 else apply_xpos_emb
cos, sin, scale = self.xpos_emb(value_layer, seq_len=seq_len)
query_layer, key_layer = apply_xpos_fn(
query_layer, key_layer, cos, sin, scale, offset=offset)


# Raw attention scores. [b * np, sq, sk]
if alibi is None:
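
One note on the dispatch above, mirroring the rotary path: the jit-scripted `apply_xpos_emb` is used for fp16/fp32 and `apply_xpos_emb_torch` is the bf16 fallback (the existing comment in `positional_embeddings.py` notes that jitting fails with bf16). A hedged, standalone restatement of that selection:

```python
# Sketch only (not part of the PR): the same bf16 fallback rule the diff applies inline.
from megatron.model.positional_embeddings import apply_xpos_emb, apply_xpos_emb_torch

def select_xpos_apply_fn(bf16: bool):
    # bf16 cannot go through the jit-scripted kernel, so use the plain torch version.
    return apply_xpos_emb_torch if bf16 else apply_xpos_emb
```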