diff --git a/models/12B/modeling_telechat.py b/models/12B/modeling_telechat.py
index b337969..81ce405 100644
--- a/models/12B/modeling_telechat.py
+++ b/models/12B/modeling_telechat.py
@@ -81,7 +81,7 @@ class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
@@ -89,6 +89,7 @@ def __init__(self, dim, config, base=10000):
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
+        self.precision = precision

     def get_mscale(self, scale=1):
         if scale <= 1:
@@ -436,7 +437,7 @@ def forward(
            offset = past_key.shape[0]
            seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)
+        cos, sin = self.rotary_emb(value_layer)
         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

         if use_cache:
diff --git a/models/12B_4bit/modeling_telechat.py b/models/12B_4bit/modeling_telechat.py
index 81c0d57..9fa51b5 100644
--- a/models/12B_4bit/modeling_telechat.py
+++ b/models/12B_4bit/modeling_telechat.py
@@ -81,7 +81,7 @@ class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
@@ -89,6 +89,7 @@ def __init__(self, dim, config, base=10000):
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
+        self.precision = precision

     def get_mscale(self, scale=1):
         if scale <= 1:
@@ -436,7 +437,7 @@ def forward(
            offset = past_key.shape[0]
            seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)
+        cos, sin = self.rotary_emb(value_layer)
         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

         if use_cache:
diff --git a/models/12B_8bit/modeling_telechat.py b/models/12B_8bit/modeling_telechat.py
index 81c0d57..9fa51b5 100644
--- a/models/12B_8bit/modeling_telechat.py
+++ b/models/12B_8bit/modeling_telechat.py
@@ -81,7 +81,7 @@ class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
@@ -89,6 +89,7 @@ def __init__(self, dim, config, base=10000):
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
+        self.precision = precision

     def get_mscale(self, scale=1):
         if scale <= 1:
@@ -436,7 +437,7 @@ def forward(
            offset = past_key.shape[0]
            seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)
+        cos, sin = self.rotary_emb(value_layer)
         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

         if use_cache:
diff --git a/models/7B/modeling_telechat.py b/models/7B/modeling_telechat.py
index 309e6f0..c699b95 100644
--- a/models/7B/modeling_telechat.py
+++ b/models/7B/modeling_telechat.py
@@ -81,12 +81,11 @@ class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000, precision=torch.half):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
-        self.inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().half() / dim)).cuda()
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
@@ -438,7 +437,7 @@ def forward(
            offset = past_key.shape[0]
            seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+        cos, sin = self.rotary_emb(value_layer)
         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

         if use_cache:
diff --git a/models/7B_4bit/modeling_telechat.py b/models/7B_4bit/modeling_telechat.py
index 309e6f0..c699b95 100644
--- a/models/7B_4bit/modeling_telechat.py
+++ b/models/7B_4bit/modeling_telechat.py
@@ -81,12 +81,11 @@ class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000, precision=torch.half):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
-        self.inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().half() / dim)).cuda()
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
@@ -438,7 +437,7 @@ def forward(
            offset = past_key.shape[0]
            seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+        cos, sin = self.rotary_emb(value_layer)
         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

         if use_cache:
diff --git a/models/7B_8bit/modeling_telechat.py b/models/7B_8bit/modeling_telechat.py
index 309e6f0..c699b95 100644
--- a/models/7B_8bit/modeling_telechat.py
+++ b/models/7B_8bit/modeling_telechat.py
@@ -81,12 +81,11 @@ class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000, precision=torch.half):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
-        self.inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().half() / dim)).cuda()
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
@@ -438,7 +437,7 @@ def forward(
            offset = past_key.shape[0]
            seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+        cos, sin = self.rotary_emb(value_layer)
         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)

         if use_cache:
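The hunks above make the `RotaryEmbedding` call convention uniform across all six model directories: the constructor takes a `precision` keyword (default `torch.half`), the attention layer passes only the value tensor to `self.rotary_emb(...)` (no more `seq_len=` in the 7B files or `dtype=` in the 12B files), and the 7B variants stop building `inv_freq` eagerly on CUDA inside `__init__`. The sketch below is a hypothetical illustration of that interface, not the repository's implementation; the tensor layout `[seq_len, batch, heads, head_dim]`, the class name, and every parameter other than `precision` are assumptions.

```python
import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    """Hypothetical sketch of the interface this patch converges on."""

    def __init__(self, dim, base=10000, precision=torch.half):
        super().__init__()
        # Inverse frequencies kept in fp32 and registered lazily as a buffer,
        # instead of being materialized on CUDA inside __init__.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.precision = precision          # cache dtype chosen at construction
        self.max_seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, offset=0):
        # x is assumed to be [seq_len, batch, heads, head_dim]; the sequence
        # length and target device are read from the tensor instead of being
        # passed in by the caller.
        seq_len = x.shape[0] + offset
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.device))
            emb = torch.cat((freqs, freqs), dim=-1)
            # Cast the cache to self.precision rather than the caller's dtype.
            self.cos_cached = emb.cos()[:, None, None, :].to(self.precision)
            self.sin_cached = emb.sin()[:, None, None, :].to(self.precision)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
```

With this shape, a call site such as `cos, sin = self.rotary_emb(value_layer)` needs no `seq_len` or `dtype` arguments, which is exactly the simplification each `forward()` hunk applies.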