Update Gemma Model
erfanzar committed Apr 17, 2024
1 parent 4ee1119 commit 2f9a434
Showing 1 changed file with 138 additions and 39 deletions.
177 changes: 138 additions & 39 deletions lib/python/EasyDel/modules/gemma/modelling_gemma_flax.py
@@ -1,4 +1,5 @@
import math
import warnings
from typing import Optional, Tuple, Union

import chex
@@ -25,12 +26,58 @@
get_dot_general_by_bits,
block_wise_ffn,
precompute_freq_cis,
apply_rotary_pos_emb,
apply_rotary_pos_emb
)
from ..easydel_modelling_utils import EasyDelFlaxPretrainedModel
from .gemma_configuration import GemmaConfig


def add_positional_embedding(
input_embedding: jax.Array,
position: int,
theta: int = 10_000,
) -> jax.Array:
"""Adds positional embeddings to input embeddings. From DeepMind Gemma"""
embed_dim = input_embedding.shape[-1]
num_timescales = embed_dim // 2
log_timescale_increment = jnp.log(float(theta)) / jnp.maximum(
jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1
)
inv_timescales = jnp.exp(
jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
)
scaled_time = position * inv_timescales
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)])
signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]])
position_embedding = signal.astype(jnp.float32)

return input_embedding + position_embedding
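
# [Editor's sketch — illustration only, not part of this commit] `add_positional_embedding`
# builds the classic sin/cos signal for a single position, so applying it across a whole
# sequence would typically go through jax.vmap, e.g.:
#
#     embeddings = jnp.zeros((7, 256))   # [seq_len, embed_dim], dummy values
#     positions = jnp.arange(7)
#     out = jax.vmap(add_positional_embedding)(embeddings, positions)
#     # out.shape == (7, 256); each row now carries the signal for its own position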


def apply_rope(
inputs: jax.Array,  # [B, L, num_heads, head_dim]
positions: jax.Array, # [B, L]
head_dim: int,
theta: int = 10_000,
) -> jax.Array:
"""Applies RoPE. From DeepMind Gemma"""
fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
timescale = theta ** fraction

sinusoid_inp = (
positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :]
)
sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)

first_half, second_half = jnp.split(inputs, 2, axis=-1)
first_part = first_half * cos - second_half * sin
second_part = second_half * cos + first_half * sin
out = jnp.concatenate([first_part, second_part], axis=-1)
return out.astype(inputs.dtype)
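
# [Editor's sketch — illustration only, not part of this commit] `apply_rope` expects the
# per-head layout [batch, seq, num_heads, head_dim] for `inputs` and [batch, seq] for
# `positions`; the rotation is applied pairwise over the last axis, e.g.:
#
#     q = jnp.ones((1, 5, 8, 64))        # [batch, seq, heads, head_dim], dummy values
#     pos = jnp.arange(5)[None, :]       # [batch, seq]
#     q_rot = apply_rope(q, pos, head_dim=64)
#     # q_rot.shape == (1, 5, 8, 64); position 0 is left unrotated (sin=0, cos=1)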


class FlaxGemmaRMSNorm(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32
@@ -40,37 +87,30 @@ def setup(self):
self.weight_kernel = self.param("kernel", lambda _, shape: jnp.ones(shape), self.config.hidden_size)

def __call__(self, hidden_states):
hidden_states = hidden_states * (
1 / jnp.sqrt(jnp.power(
jnp.asarray(hidden_states, dtype=jnp.float32), 2
).mean(-1, keepdims=True) + self.epsilon)
variance = jnp.asarray(hidden_states, dtype=jnp.float32)
variance = jnp.power(variance, 2)
variance = variance.mean(-1, keepdims=True)
hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)

return (1 + nn.linen.control_quantization(self.weight_kernel, self.dtype)) * jnp.asarray(
hidden_states, dtype=self.dtype
)
kernel = nn.linen.control_quantization(self.weight_kernel, self.dtype)
return (1 + kernel) * jnp.asarray(hidden_states, dtype=self.dtype)
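
# [Editor's note — not part of this commit] The normalisation above is standard RMSNorm
# with Gemma's twist that the learned scale is applied as (1 + kernel), i.e. roughly:
#
#     x = jnp.asarray(hidden_states, jnp.float32)
#     rms = jnp.sqrt(jnp.mean(x ** 2, axis=-1, keepdims=True) + eps)
#     y = (1 + kernel) * (hidden_states / rms)
#
# so a checkpoint weight of 0 corresponds to an identity scale of 1 (the names
# `x`, `rms`, `eps`, `y` are illustrative).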


class FlaxGemmaRotaryEmbedding(nn.Module):
config: GemmaConfig
dtype: jnp.dtype = jnp.float32

def __call__(self, freq_cis, key_states, query_states, position_ids):
b, s, h, d = key_states.shape
sin_pos, cos_pos = freq_cis
key_states = apply_rotary_pos_emb(
key_states,
sin_pos[None, :s, None, :],
cos_pos[None, :s, None, :]
)
query_states = apply_rotary_pos_emb(
query_states,
sin_pos[None, :s, None, :],
cos_pos[None, :s, None, :]
)
sin, cos = freq_cis

sin = sin[position_ids][:, None, :, :]
cos = cos[position_ids][:, None, :, :]

key_states = jnp.asarray(key_states, dtype=self.dtype)
query_states = jnp.asarray(query_states, dtype=self.dtype)
key = apply_rotary_pos_emb(key_states, sin, cos)
query = apply_rotary_pos_emb(query_states, sin, cos)

return key_states, query_states
return query.astype(self.dtype), key.astype(self.dtype)
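
# [Editor's note — not part of this commit] By the time this module is called from
# `apply_rotary` below, `query_states`/`key_states` are already laid out as
# [batch, heads, seq, head_dim], so `sin[position_ids][:, None, :, :]` yields
# [batch, 1, seq, rot_dim] tables that broadcast across the head axis inside
# `apply_rotary_pos_emb`.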


class FlaxGemmaAttention(BaseJAXAttentionModule):
@@ -159,7 +199,7 @@ def setup(self):
value_partition_spec=self.config.value_partition_spec,
scan_ring_attention=self.config.scan_ring_attention,
mesh=self.config.jax_mesh(),
sm_scale=1 / math.sqrt(self.head_dim),
sm_scale=self.head_dim ** -0.5,
)

self.rotary_emb = FlaxGemmaRotaryEmbedding(config, dtype=self.dtype)
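
# [Editor's note — not part of this commit] `self.head_dim ** -0.5` equals
# 1 / math.sqrt(self.head_dim); the change is purely stylistic and keeps the usual
# softmax scaling of the attention scores.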
@@ -170,6 +210,61 @@ def _merge_heads(self, hidden_states):
def _split_heads(self, hidden_states, num_heads):
return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))

@staticmethod
def _transpose_sequence_head(query, key, value):
"""
Transposes the query, key and value tensors from [batch, sequence, heads, head_dim]
to [batch, heads, sequence, head_dim].
:param query: query tensor of shape [batch, sequence, num_heads, head_dim]
:param key: key tensor of shape [batch, sequence, num_key_value_heads, head_dim]
:param value: value tensor of shape [batch, sequence, num_key_value_heads, head_dim]
:return: the transposed query, key and value tensors
"""
return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))

def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
"""
Reshapes the projected query, key and value tensors into per-head layout and applies
rotary position embeddings to the query and key states.
:param batch_size: batch size used to reshape the projections
:param sequence_length: sequence length used to reshape the projections
:param query: query projection of shape [batch, sequence, num_attention_heads * head_dim]
:param key: key projection of shape [batch, sequence, num_key_value_heads * head_dim]
:param value: value projection of shape [batch, sequence, num_key_value_heads * head_dim]
:param freq_cis: precomputed sine and cosine tables for the rotary embedding
:param position_ids: position of each token in the sequence
:return: the query, key and value tensors, with rotary embeddings applied to query and key
"""
query = query.reshape(
batch_size,
sequence_length,
self.config.num_attention_heads,
self.head_dim
)
key = key.reshape(
batch_size,
sequence_length,
self.config.num_key_value_heads,
self.head_dim
)
value = value.reshape(
batch_size,
sequence_length,
self.config.num_key_value_heads,
self.head_dim
)

query, key, value = self._transpose_sequence_head(query, key, value)
query, key = self.rotary_emb(
position_ids=position_ids, query_states=query, key_states=key, freq_cis=freq_cis
)
return self._transpose_sequence_head(query, key, value)
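
# [Editor's sketch — illustration only, not part of this commit] Shape flow through
# `apply_rotary`, assuming e.g. hidden_size=2048, num_attention_heads=8,
# num_key_value_heads=1, head_dim=256 (Gemma-2B-like numbers):
#
#     q: [B, L, 2048] -> reshape -> [B, L, 8, 256] -> transpose -> [B, 8, L, 256]
#     k: [B, L,  256] -> reshape -> [B, L, 1, 256] -> transpose -> [B, 1, L, 256]
#
# rotary embeddings rotate q/k in place (shapes unchanged), and the final
# `_transpose_sequence_head` call puts all three back to [B, L, heads, head_dim].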

def __call__(
self,
hidden_states: chex.Array,
@@ -188,11 +283,15 @@ def __call__(
value_states
) = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

query_states = self._split_heads(query_states, self.num_heads)
key_states = self._split_heads(key_states, self.num_key_value_heads)
value_states = self._split_heads(value_states, self.num_key_value_heads)

key_states, query_states = self.rotary_emb(freq_cis, key_states, query_states, position_ids)
query_states, key_states, value_states = self.apply_rotary(
query_states.shape[0],
query_states.shape[1],
query_states,
key_states,
value_states,
freq_cis,
position_ids
)

query_length, key_length = query_states.shape[1], key_states.shape[1]

@@ -205,16 +304,6 @@ def __call__(
else:
causal_mask = causal_mask[:, :, :query_length, :key_length]

# if self.config.use_sharding_constraint:
# query_states = with_sharding_constraint(
# query_states, PartitionSpec(("dp", "fsdp"), "sp" if query_states.shape[1] != 1 else None, "tp", None)
# )
# key_states = with_sharding_constraint(
# key_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None)
# )
# value_states = with_sharding_constraint(
# value_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None)
# )
batch_size = hidden_states.shape[0]
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

@@ -286,7 +375,17 @@ def setup(self):
inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim

kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
self.act = ACT2FN[self.config.hidden_act if self.config.hidden_act != "gelu" else "gelu_new"]
if self.config.hidden_activation is None:
warnings.warn(
"Gemma's activation function should be approximate GeLU and not exact GeLU. "
"Changing the activation function to `gelu_pytorch_tanh`."
f"if you want to use the legacy `{self.config.hidden_act}`, "
f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}` "
)
hidden_activation = "gelu_pytorch_tanh"
else:
hidden_activation = self.config.hidden_activation
self.act = ACT2FN[hidden_activation]
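
# [Editor's note — not part of this commit] `gelu_pytorch_tanh` is the tanh approximation
# of GELU that Gemma was trained with, roughly:
#
#     gelu_tanh(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))
#
# whereas plain `gelu` uses the exact Gaussian CDF, hence the warning above when the
# config does not pin `hidden_activation`.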

self.gate_proj = Linear(
inner_dim,
