From 2f9a43420bf050d9b56b5aedc889b0f41654860f Mon Sep 17 00:00:00 2001
From: erfanzar
Date: Wed, 17 Apr 2024 18:50:50 +0330
Subject: [PATCH] Update `Gemma Model`

---
 .../modules/gemma/modelling_gemma_flax.py | 177 ++++++++++++++----
 1 file changed, 138 insertions(+), 39 deletions(-)

diff --git a/lib/python/EasyDel/modules/gemma/modelling_gemma_flax.py b/lib/python/EasyDel/modules/gemma/modelling_gemma_flax.py
index 8408aa23e..4f72d7be4 100644
--- a/lib/python/EasyDel/modules/gemma/modelling_gemma_flax.py
+++ b/lib/python/EasyDel/modules/gemma/modelling_gemma_flax.py
@@ -1,4 +1,5 @@
 import math
+import warnings
 from typing import Optional, Tuple, Union
 
 import chex
@@ -25,12 +26,58 @@
     get_dot_general_by_bits,
     block_wise_ffn,
     precompute_freq_cis,
-    apply_rotary_pos_emb,
+    apply_rotary_pos_emb
 )
 from ..easydel_modelling_utils import EasyDelFlaxPretrainedModel
 from .gemma_configuration import GemmaConfig
 
 
+def add_positional_embedding(
+        input_embedding: jax.Array,
+        position: int,
+        theta: int = 10_000,
+) -> jax.Array:
+    """Adds positional embeddings to input embeddings. From DeepMind Gemma"""
+    embed_dim = input_embedding.shape[-1]
+    num_timescales = embed_dim // 2
+    log_timescale_increment = jnp.log(float(theta)) / jnp.maximum(
+        jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1
+    )
+    inv_timescales = jnp.exp(
+        jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
+    )
+    scaled_time = position * inv_timescales
+    signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)])
+    signal = jnp.pad(signal, [[0, jnp.mod(embed_dim, 2)]])
+    position_embedding = signal.astype(jnp.float32)
+
+    return input_embedding + position_embedding
+
+
+def apply_rope(
+        inputs: jax.Array,  # [B, L]
+        positions: jax.Array,  # [B, L]
+        head_dim: int,
+        theta: int = 10_000,
+) -> jax.Array:
+    """Applies RoPE. From DeepMind Gemma"""
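+    # The head dimension is split into two halves; each feature pair is rotated by an
+    # angle of position / theta ** (2 * i / head_dim), following the reference
+    # DeepMind implementation below.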
+    fraction = 2 * jnp.arange(0, head_dim // 2) / head_dim
+    timescale = theta ** fraction
+
+    sinusoid_inp = (
+        positions[..., jnp.newaxis] / timescale[jnp.newaxis, jnp.newaxis, :]
+    )
+    sinusoid_inp = sinusoid_inp[..., jnp.newaxis, :]
+    sin = jnp.sin(sinusoid_inp)
+    cos = jnp.cos(sinusoid_inp)
+
+    first_half, second_half = jnp.split(inputs, 2, axis=-1)
+    first_part = first_half * cos - second_half * sin
+    second_part = second_half * cos + first_half * sin
+    out = jnp.concatenate([first_part, second_part], axis=-1)
+    return out.astype(inputs.dtype)
+
+
 class FlaxGemmaRMSNorm(nn.Module):
     config: GemmaConfig
     dtype: jnp.dtype = jnp.float32
@@ -40,13 +87,14 @@ def setup(self):
         self.weight_kernel = self.param("kernel", lambda _, shape: jnp.ones(shape), self.config.hidden_size)
 
     def __call__(self, hidden_states):
-        hidden_states = hidden_states * (
-            1 / jnp.sqrt(jnp.power(
-                jnp.asarray(hidden_states, dtype=jnp.float32), 2
-            ).mean(-1, keepdims=True) + self.epsilon)
+        variance = jnp.asarray(hidden_states, dtype=jnp.float32)
+        variance = jnp.power(variance, 2)
+        variance = variance.mean(-1, keepdims=True)
+        hidden_states = hidden_states / jnp.sqrt(variance + self.epsilon)
+
+        return (1 + nn.linen.control_quantization(self.weight_kernel, self.dtype)) * jnp.asarray(
+            hidden_states, dtype=self.dtype
         )
-        kernel = nn.linen.control_quantization(self.weight_kernel, self.dtype)
-        return (1 + kernel) * jnp.asarray(hidden_states, dtype=self.dtype)
 
 
 class FlaxGemmaRotaryEmbedding(nn.Module):
@@ -54,23 +102,15 @@ class FlaxGemmaRotaryEmbedding(nn.Module):
     dtype: jnp.dtype = jnp.float32
 
     def __call__(self, freq_cis, key_states, query_states, position_ids):
-        b, s, h, d = key_states.shape
-        sin_pos, cos_pos = freq_cis
-        key_states = apply_rotary_pos_emb(
-            key_states,
-            sin_pos[None, :s, None, :],
-            cos_pos[None, :s, None, :]
-        )
-        query_states = apply_rotary_pos_emb(
-            query_states,
-            sin_pos[None, :s, None, :],
-            cos_pos[None, :s, None, :]
-        )
+        sin, cos = freq_cis
+
+        sin = sin[position_ids][:, None, :, :]
+        cos = cos[position_ids][:, None, :, :]
 
-        key_states = jnp.asarray(key_states, dtype=self.dtype)
-        query_states = jnp.asarray(query_states, dtype=self.dtype)
+        key = apply_rotary_pos_emb(key_states, sin, cos)
+        query = apply_rotary_pos_emb(query_states, sin, cos)
 
-        return key_states, query_states
+        return query.astype(self.dtype), key.astype(self.dtype)
 
 
 class FlaxGemmaAttention(BaseJAXAttentionModule):
@@ -159,7 +199,7 @@ def setup(self):
             value_partition_spec=self.config.value_partition_spec,
             scan_ring_attention=self.config.scan_ring_attention,
             mesh=self.config.jax_mesh(),
-            sm_scale=1 / math.sqrt(self.head_dim),
+            sm_scale=self.head_dim ** -0.5,
         )
         self.rotary_emb = FlaxGemmaRotaryEmbedding(config, dtype=self.dtype)
 
@@ -170,6 +210,61 @@ def _merge_heads(self, hidden_states):
     def _split_heads(self, hidden_states, num_heads):
         return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, self.head_dim))
 
+    @staticmethod
+    def _transpose_sequence_head(query, key, value):
+        """
+        The _transpose_sequence_head function transposes the query, key and value matrices.
+
+        :param query: Query tensor of shape [batch, sequence, heads, head_dim]
+        :param key: Key tensor of shape [batch, sequence, heads, head_dim]
+        :param value: Value tensor of shape [batch, sequence, heads, head_dim]
+        :return: The query, key and value tensors with the sequence and head axes swapped
+
+        """
+        return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))
+
+    def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
+        """
+        The apply_rotary function applies the rotary position embedding to the query and key states.
+        It reshapes the fused projections into per-head tensors, rotates query and key with the
+        precomputed sin/cos tables in freq_cis, and returns all three tensors per head.
+
+        :param self: Access variables that belong to the class
+        :param batch_size: Batch size used to reshape the query, key and value tensors
+        :param sequence_length: Sequence length used to reshape the query, key and value tensors
+        :param query: Query projection to rotate
+        :param key: Key projection to rotate
+        :param value: Value projection; it is only reshaped and transposed, not rotated
+        :param freq_cis: Precomputed sin/cos tables for the rotary embedding
+        :param position_ids: Position of each token in the sequence, used to index freq_cis
+        :return: A tuple of 3 tensors: query, key and value
+
+        """
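+        # The projections arrive fused as [batch, sequence, heads * head_dim];
+        # reshape them into per-head tensors of [batch, sequence, heads, head_dim].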
+        query = query.reshape(
+            batch_size,
+            sequence_length,
+            self.config.num_attention_heads,
+            self.head_dim
+        )
+        key = key.reshape(
+            batch_size,
+            sequence_length,
+            self.config.num_key_value_heads,
+            self.head_dim
+        )
+        value = value.reshape(
+            batch_size,
+            sequence_length,
+            self.config.num_key_value_heads,
+            self.head_dim
+        )
+
+        query, key, value = self._transpose_sequence_head(query, key, value)
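+        # Tensors are now [batch, heads, sequence, head_dim]; rotate query and key,
+        # then transpose back so callers receive [batch, sequence, heads, head_dim].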
+        query, key = self.rotary_emb(
+            position_ids=position_ids, query_states=query, key_states=key, freq_cis=freq_cis
+        )
+        return self._transpose_sequence_head(query, key, value)
+
     def __call__(
             self,
             hidden_states: chex.Array,
@@ -188,11 +283,15 @@ def __call__(
             value_states
         ) = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
 
-        query_states = self._split_heads(query_states, self.num_heads)
-        key_states = self._split_heads(key_states, self.num_key_value_heads)
-        value_states = self._split_heads(value_states, self.num_key_value_heads)
-
-        key_states, query_states = self.rotary_emb(freq_cis, key_states, query_states, position_ids)
+        query_states, key_states, value_states = self.apply_rotary(
+            query_states.shape[0],
+            query_states.shape[1],
+            query_states,
+            key_states,
+            value_states,
+            freq_cis,
+            position_ids
+        )
 
         query_length, key_length = query_states.shape[1], key_states.shape[1]
 
@@ -205,16 +304,6 @@ def __call__(
         else:
             causal_mask = causal_mask[:, :, :query_length, :key_length]
 
-        # if self.config.use_sharding_constraint:
-        #     query_states = with_sharding_constraint(
-        #         query_states, PartitionSpec(("dp", "fsdp"), "sp" if query_states.shape[1] != 1 else None, "tp", None)
-        #     )
-        #     key_states = with_sharding_constraint(
-        #         key_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None)
-        #     )
-        #     value_states = with_sharding_constraint(
-        #         value_states, PartitionSpec(("dp", "fsdp"), "sp", "tp", None)
-        #     )
         batch_size = hidden_states.shape[0]
         causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
 
@@ -286,7 +375,17 @@ def setup(self):
         inner_dim = self.config.intermediate_size if self.config.intermediate_size is not None else 4 * embed_dim
         kernel_init = jax.nn.initializers.normal(self.config.initializer_range)
 
-        self.act = ACT2FN[self.config.hidden_act if self.config.hidden_act != "gelu" else "gelu_new"]
+        if self.config.hidden_activation is None:
+            warnings.warn(
+                "Gemma's activation function should be approximate GeLU and not exact GeLU. "
+                "Changing the activation function to `gelu_pytorch_tanh`. "
+                f"If you want to use the legacy `{self.config.hidden_act}`, "
+                f"edit the `model.config` to set `hidden_activation={self.config.hidden_act}`."
+            )
+            hidden_activation = "gelu_pytorch_tanh"
+        else:
+            hidden_activation = self.config.hidden_activation
+        self.act = ACT2FN[hidden_activation]
 
         self.gate_proj = Linear(
             inner_dim,