copy pastes of past files into this new rebase
Commit adda5fb, 1 parent 7798682
Showing 7 changed files with 812 additions and 0 deletions.
@@ -0,0 +1,8 @@
from .phi3_v import (
    LanguageModel,
    Model,
    ModelConfig,
    TextConfig,
    VisionConfig,
    VisionModel,
)
@@ -0,0 +1,19 @@
from dataclasses import dataclass
import inspect


@dataclass
class TextConfig:
    @classmethod
    def from_dict(cls, params):
        # Keep only the keys that match this dataclass's constructor parameters,
        # so unknown entries in a raw config dict are silently ignored.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class LanguageModel:
    pass
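
The `from_dict` pattern above filters a raw parameter dict down to the fields the dataclass actually declares. A minimal standalone sketch of that pattern, using a hypothetical `DemoConfig` with made-up fields (the stub `TextConfig` above has none yet):

import inspect
from dataclasses import dataclass


@dataclass
class DemoConfig:
    # Hypothetical fields, only to illustrate the filtering behavior.
    hidden_size: int = 2048
    num_attention_heads: int = 32

    @classmethod
    def from_dict(cls, params):
        # Same pattern as TextConfig.from_dict: drop keys that are not fields.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


# Unknown keys such as "architectures" are ignored instead of raising TypeError.
cfg = DemoConfig.from_dict(
    {"hidden_size": 3072, "num_attention_heads": 32, "architectures": ["Phi3VForCausalLM"]}
)
print(cfg)  # DemoConfig(hidden_size=3072, num_attention_heads=32)
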
@@ -0,0 +1,235 @@
import inspect
import math
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Dict, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
import numpy as np

from .su_rope import Phi3SuScaledRotaryEmbedding
from .language import TextConfig, LanguageModel
from .vision import VisionConfig, VisionModel


@dataclass
class ModelConfig:
    text_config: TextConfig
    vision_config: VisionConfig
    model_type: str
    vocab_size: int

    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    rms_norm_eps: float

    ignore_index: int = -100
    image_token_index: int = 257152
    hidden_size: int = 2048
    pad_token_id: int = 0

    num_key_value_heads: Optional[int] = None
    rope_theta: float = 10000
    rope_traditional: bool = False
    rope_scaling: Optional[Dict[str, Union[float, str]]] = None
    max_position_embeddings: int = 131072
    original_max_position_embeddings: int = 4096

    @classmethod
    def from_dict(cls, params):
        # Same filtering pattern as TextConfig.from_dict: ignore unknown keys.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )


class Attention(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()

        dim = args.hidden_size
        self.n_heads = n_heads = args.num_attention_heads
        self.n_kv_heads = n_kv_heads = args.num_key_value_heads
        self.num_hidden_layers = args.num_hidden_layers

        self.head_dim = head_dim = args.hidden_size // n_heads
        self.scale = head_dim**-0.5

        # Fused QKV projection: query heads first, then the key/value heads.
        op_size = n_heads * head_dim + 2 * (n_kv_heads * head_dim)
        self.qkv_proj = nn.Linear(dim, op_size, bias=False)
        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False)

        rope_scale = 1.0
        if args.rope_scaling and args.rope_scaling["type"] == "su":
            # Phi-3 "su" rotary embeddings for long-context scaling.
            self.rope = Phi3SuScaledRotaryEmbedding(
                head_dim,
                traditional=False,
                base=args.rope_theta,
                scale=rope_scale,
                max_position_embeddings=args.max_position_embeddings,
                original_max_position_embeddings=args.original_max_position_embeddings,
                short_factor=args.rope_scaling["short_factor"],
                long_factor=args.rope_scaling["long_factor"],
            )
        else:
            if args.rope_scaling and args.rope_scaling["type"] == "linear":
                rope_scale = 1 / args.rope_scaling["factor"]
            self.rope = nn.RoPE(
                head_dim,
                traditional=args.rope_traditional,
                base=args.rope_theta,
                scale=rope_scale,
            )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        qkv = self.qkv_proj(x)
        query_pos = self.n_heads * self.head_dim
        queries, keys, values = mx.split(
            qkv, [query_pos, query_pos + self.n_kv_heads * self.head_dim], axis=-1
        )

        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            # Offset the rotary embedding by the number of cached tokens, then
            # append the new keys/values to the cache.
            offset = cache[0].shape[2]
            queries = self.rope(queries, offset=offset)
            keys = self.rope(keys, offset=offset)
            keys = mx.concatenate([cache[0], keys], axis=2)
            values = mx.concatenate([cache[1], values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output), (keys, values)


class MLP(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.gate_up_proj = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)

    def __call__(self, x) -> mx.array:
        x = self.gate_up_proj(x)
        gate, x = mx.split(x, 2, axis=-1)
        return self.down_proj(nn.silu(gate) * x)


class TransformerBlock(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()
        self.num_attention_heads = args.num_attention_heads
        self.hidden_size = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args.hidden_size, args.intermediate_size)
        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.post_attention_layernorm = nn.RMSNorm(
            args.hidden_size, eps=args.rms_norm_eps
        )
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.self_attn(self.input_layernorm(x), mask, cache)
        h = x + r
        r = self.mlp(self.post_attention_layernorm(h))
        out = h + r
        return out, cache


class Phi3V(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.num_hidden_layers = args.num_hidden_layers
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.vision_embed_tokens = VisionModel(args)
        self.layers = [
            TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)

    def __call__(
        self,
        inputs: mx.array,
        pixel_values=None,
        image_sizes=None,
        cache=None,
    ):
        # print('inputs', inputs) # debug
        h = self.embed_tokens(inputs)
        # Positions of image placeholder tokens (encoded as negative ids).
        p = np.argwhere(inputs < 0).tolist()
        if pixel_values is not None:
            h = self.vision_embed_tokens(pixel_values, h, image_sizes, p)
        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)
        if cache is None:
            cache = [None] * len(self.layers)
        for i, layer in enumerate(self.layers):
            h, cache[i] = layer(h, mask, cache[i])
        return self.norm(h), cache


class Model(nn.Module):
    def __init__(self, args: TextConfig):
        super().__init__()
        self.model_type = args.model_type
        self.model = Phi3V(args)
        self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
        self.config = args

    def __call__(
        self,
        inputs: mx.array,
        pixel_values=None,
        mask=None,
        cache=None,
    ):
        out, cache = self.model(inputs, pixel_values, mask, cache)
        return self.lm_head(out).astype(self.lm_head.weight.dtype), cache

    @property
    def layers(self):
        return self.model.layers

    @property
    def head_dim(self):
        # The config is stored as self.config in __init__ (self.args is never
        # set on this class), so read it from there.
        return self.config.hidden_size // self.config.num_attention_heads

    @property
    def n_kv_heads(self):
        return self.config.num_key_value_heads

    @property
    def language_model(self):
        return self

    @property
    def vision_model(self):
        return self.model.vision_embed_tokens
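
Each Attention layer above returns its updated (keys, values) pair, and Phi3V collects those into the cache list it returns, so a caller can feed them back on the next decoding step. That flow can be sketched in isolation with plain mlx primitives; the shapes, head counts, and random inputs below are made up purely for illustration and this is not the full Model forward pass:

import mlx.core as mx
import mlx.nn as nn

B, n_heads, head_dim = 1, 4, 16
rope = nn.RoPE(head_dim, traditional=False, base=10000)

def step(q, k, v, cache=None):
    # q, k, v: (B, n_heads, L, head_dim), the layout used after reshape/transpose.
    if cache is not None:
        offset = cache[0].shape[2]  # number of tokens already processed
        q, k = rope(q, offset=offset), rope(k, offset=offset)
        k = mx.concatenate([cache[0], k], axis=2)
        v = mx.concatenate([cache[1], v], axis=2)
    else:
        q, k = rope(q), rope(k)
    out = mx.fast.scaled_dot_product_attention(
        q, k, v, scale=head_dim**-0.5, mask=None
    )
    return out, (k, v)

# Prefill with a 5-token prompt, then decode one more token reusing the cache.
q, k, v = (mx.random.normal((B, n_heads, 5, head_dim)) for _ in range(3))
out, cache = step(q, k, v)
q1, k1, v1 = (mx.random.normal((B, n_heads, 1, head_dim)) for _ in range(3))
out, cache = step(q1, k1, v1, cache=cache)
print(out.shape, cache[0].shape)  # (1, 4, 1, 16) (1, 4, 6, 16)
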
@@ -0,0 +1,70 @@
import math
import mlx.core as mx


class Phi3SuScaledRotaryEmbedding:
    def __init__(
        self,
        dims: int,
        traditional: bool = False,
        base: float = 10000.0,
        scale: float = 1.0,
        max_position_embeddings: int = 131072,
        original_max_position_embeddings: int = 4096,
        short_factor: list[float] | float = 1.0,
        long_factor: list[float] | float = 1.0,
    ):
        """
        Phi3Su scaled rotary embedding layer for Phi-3 models.

        Args:
            dims (int): The feature dimensions to be rotated.
            traditional (bool, optional): Unused. Default: ``False``.
            base (float, optional): Base for the exponential scaling.
            scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
            max_position_embeddings (int, optional): The maximum sequence length the
                scaled embeddings support. Default: ``131072``.
            original_max_position_embeddings (int, optional): The maximum sequence length
                the model was originally trained with. This is used to determine the size
                of the original RoPE embeddings when using long scaling. Default: ``4096``.
            short_factor (float or list of floats, optional): Scaling factors for sequences
                shorter than ``original_max_position_embeddings``. Default: ``1.0``.
            long_factor (float or list of floats, optional): Scaling factors for sequences
                longer than ``original_max_position_embeddings``. Default: ``1.0``.
        """
        # Per-frequency inverse frequencies, rescaled by the short/long factors.
        self.inv_freq_short = 1.0 / (
            mx.array(short_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.inv_freq_long = 1.0 / (
            scale
            * mx.array(long_factor, dtype=mx.float32)
            * base ** (mx.arange(0, dims, 2, dtype=mx.float32) / dims)
        )
        self.original_max_position_embeddings = original_max_position_embeddings
        # Amplitude correction applied to cos/sin when extending the context window.
        self.scaling_factor = math.sqrt(
            1
            + math.log(max_position_embeddings / original_max_position_embeddings)
            / math.log(original_max_position_embeddings)
        )

    def _get_cos_sin(self, offset, L):
        position_ids = mx.arange(offset, offset + L, dtype=mx.float32)[None]
        # Switch to the long-context factors once positions exceed the original window.
        inv_freq = (
            self.inv_freq_long
            if position_ids.max() + 1 > self.original_max_position_embeddings
            else self.inv_freq_short
        )
        inv_freq_expanded = mx.repeat(
            inv_freq[None, :, None], position_ids.shape[0], axis=0
        )
        position_ids_expanded = position_ids[:, None, :]
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(0, 2, 1)
        emb = mx.concatenate([freqs, freqs], axis=-1)
        cos = mx.cos(emb) * self.scaling_factor
        sin = mx.sin(emb) * self.scaling_factor
        return mx.expand_dims(cos, axis=1), mx.expand_dims(sin, axis=1)

    def __call__(self, x, offset: int = 0):
        def _rotate_half(_x):
            midpoint = _x.shape[-1] // 2
            x1, x2 = _x[..., :midpoint], _x[..., midpoint:]
            return mx.concatenate([-x2, x1], axis=-1)

        cos, sin = self._get_cos_sin(offset, x.shape[2])
        return (x * cos) + (_rotate_half(x) * sin)
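
A hedged usage sketch for the class above, assuming the file is importable so Phi3SuScaledRotaryEmbedding is in scope. The factor lists and tensor shapes here are dummies for illustration; the real short/long factors come from the model's rope_scaling config:

import mlx.core as mx

head_dim = 32
dummy_factors = [1.0] * (head_dim // 2)  # one factor per rotated frequency pair

rope = Phi3SuScaledRotaryEmbedding(
    head_dim,
    base=10000.0,
    max_position_embeddings=131072,
    original_max_position_embeddings=4096,
    short_factor=dummy_factors,
    long_factor=dummy_factors,
)

x = mx.random.normal((1, 4, 8, head_dim))  # (batch, heads, seq_len, head_dim)
y = rope(x)                                # prefill: positions 0..7
y_next = rope(x[:, :, :1, :], offset=8)    # single decode step at position 8
print(y.shape, y_next.shape)               # (1, 4, 8, 32) (1, 4, 1, 32)
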