Commit be8af21: update infer

hannawong committed May 16, 2024 (1 parent: b970158)

Showing 6 changed files with 15 additions and 15 deletions.
models/12B/modeling_telechat.py: 3 additions & 2 deletions

@@ -81,14 +81,15 @@

 class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
+        self.precision = precision

     def get_mscale(self, scale=1):
         if scale <= 1:

@@ -436,7 +437,7 @@ def forward(
             offset = past_key.shape[0]
             seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)
+        cos, sin = self.rotary_emb(value_layer)

         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
         if use_cache:
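The 12B call site above stops passing dtype=value_layer.dtype, so the rotary embedding's output dtype presumably now comes from the new precision field stored in __init__. Below is a minimal sketch of a forward() consistent with that reading, assuming a gpt-neox-style cos/sin cache; the real forward body is outside the hunks shown here, so the class name RotaryEmbeddingSketch, the dropped config argument, and the cache logic are illustrative assumptions rather than the repository's code.

import torch

class RotaryEmbeddingSketch(torch.nn.Module):
    # Hypothetical stand-in: only the diffed __init__ fields come from the
    # commit; everything in forward() below is an assumption.
    def __init__(self, dim, base=10000, precision=torch.half):
        super().__init__()
        self.dim = dim
        self.base = base
        self.precision = precision        # replaces the old per-call dtype argument
        self.max_seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, offset=0):
        # Assumes a sequence-first layout [seq, batch, ...], as in gpt-neox.
        seq_len = x.shape[0] + offset
        if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached:
            # Build the cache in float32 for accuracy, then cast once to self.precision.
            inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
            t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
            freqs = torch.outer(t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos().to(self.precision)
            self.sin_cached = emb.sin().to(self.precision)
            self.max_seq_len_cached = seq_len
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]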
models/12B_4bit/modeling_telechat.py: 3 additions & 2 deletions

@@ -81,14 +81,15 @@

 class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
+        self.precision = precision

     def get_mscale(self, scale=1):
         if scale <= 1:

@@ -436,7 +437,7 @@ def forward(
             offset = past_key.shape[0]
             seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)
+        cos, sin = self.rotary_emb(value_layer)

         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
         if use_cache:
models/12B_8bit/modeling_telechat.py: 3 additions & 2 deletions

@@ -81,14 +81,15 @@

 class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
+        self.precision = precision

     def get_mscale(self, scale=1):
         if scale <= 1:

@@ -436,7 +437,7 @@ def forward(
             offset = past_key.shape[0]
             seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, dtype=value_layer.dtype)
+        cos, sin = self.rotary_emb(value_layer)

         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
         if use_cache:
models/7B/modeling_telechat.py: 2 additions & 3 deletions

@@ -81,12 +81,11 @@

 class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000, precision=torch.half):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
-        self.inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().half() / dim)).cuda()
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None

@@ -438,7 +437,7 @@ def forward(
             offset = past_key.shape[0]
             seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+        cos, sin = self.rotary_emb(value_layer)

         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
         if use_cache:
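In the 7B variants the commit also deletes the inv_freq buffer that __init__ built with .float().half() and pinned to CUDA at construction time, and the call site stops passing seq_len. A plausible reading, stated as an assumption because the new forward() body is not in this diff: the frequencies are now derived lazily from the input, which lets the module be constructed without a GPU present and keeps the exponentiation in float32 (computing inv_freq in half precision is a known source of rotary-position drift at large offsets). A one-function sketch of such a lazy replacement:

import torch

def rotary_inv_freq(dim, base=10000, device="cpu"):
    # Hypothetical lazy replacement for the deleted __init__-time buffer:
    # computed in float32 on the caller's device whenever the cos/sin cache
    # is rebuilt, instead of .float().half() on CUDA at construction.
    return 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))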
models/7B_4bit/modeling_telechat.py: 2 additions & 3 deletions

@@ -81,12 +81,11 @@

 class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000, precision=torch.half):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
-        self.inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().half() / dim)).cuda()
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None

@@ -438,7 +437,7 @@ def forward(
             offset = past_key.shape[0]
             seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+        cos, sin = self.rotary_emb(value_layer)

         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
         if use_cache:
models/7B_8bit/modeling_telechat.py: 2 additions & 3 deletions

@@ -81,12 +81,11 @@

 class RotaryEmbedding(torch.nn.Module):
     # Extracted from: https://github.com/EleutherAI/gpt-neox
-    def __init__(self, dim, config, base=10000, precision=torch.half):
+    def __init__(self, dim, config, base=10000,precision=torch.half):
         super().__init__()
         self.config = config
         self.dim = dim
         self.base = base
-        self.inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float().half() / dim)).cuda()
         self.max_seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None

@@ -438,7 +437,7 @@ def forward(
             offset = past_key.shape[0]
             seq_len += offset

-        cos, sin = self.rotary_emb(value_layer, seq_len=seq_len)
+        cos, sin = self.rotary_emb(value_layer)

         query_layer, key_layer = apply_rotary_fn(query_layer, key_layer, cos, sin, offset=offset)
         if use_cache:
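With all six files updated, the attention forward uses one call convention: it passes only value_layer, and sequence length and dtype are recovered inside the embedding. A usage sketch under the assumptions above (the class and the shapes are illustrative, not from the repository):

import torch

rotary = RotaryEmbeddingSketch(dim=128)   # hypothetical instance from the sketch above
value_layer = torch.randn(16, 2, 128)     # assumed [seq, batch, head_dim] layout
cos, sin = rotary(value_layer)            # no seq_len= or dtype= argument anymore
print(cos.shape, cos.dtype)               # torch.Size([16, 128]) torch.float16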
