
Commit

update modeling_telechat.py
hannawong committed May 16, 2024
1 parent 59d9858 commit d2091c2
Showing 4 changed files with 79 additions and 82 deletions.
35 changes: 19 additions & 16 deletions models/12B_8bit/modeling_telechat.py
@@ -96,31 +96,34 @@ def get_mscale(self, scale=1):
         return 0.1 * math.log(scale) + 1.0
 
     def get_ntk_alpha(self, true_seq_len):
-        context_value = math.log(true_seq_len / 4096, 2) + 1
+        context_value = math.log(true_seq_len / 8192, 2) + 1
         # ntk_alpha = 2 ** context_value - 1
         ntk_alpha = 2 ** math.ceil(context_value) - 1
         ntk_alpha = max(ntk_alpha, 1)
         return ntk_alpha
 
-    def forward(self, x, dtype, seq_dim=0):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+    def forward(self, x, seq_dim=0, seq_len=None):
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
         base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-        self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-        # if self.precision == torch.bfloat16:
-        emb = emb.float() if dtype == torch.bfloat16 else emb
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
         # [sx, 1 (b * np), hn]
-        self.cos_cached = self.mscale * emb.cos()[:, None, :].to(dtype)
-        self.sin_cached = self.mscale * emb.sin()[:, None, :].to(dtype)
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers:
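For reference, a minimal sketch of the scaling arithmetic the rewritten forward now performs on every call: the sequence length is clamped to the training length, an NTK alpha is derived from it, the rotary base is stretched by alpha ** (dim / (dim - 2)), and a magnitude scale follows get_mscale. The constants below (DIM, BASE, TRAINING_SEQLEN) are illustrative assumptions, not values read from the shipped configs.

import math

# Illustrative constants -- assumed, not taken from the TeleChat configs.
DIM = 128                # per-head rotary dimension (assumed)
BASE = 10000.0           # default RoPE base (assumed)
TRAINING_SEQLEN = 8192   # matches the new 8192 divisor in get_ntk_alpha

def get_ntk_alpha(true_seq_len: int) -> float:
    # Same formula as the committed code: alpha grows in powers of two
    # once the sequence exceeds the training length.
    context_value = math.log(true_seq_len / TRAINING_SEQLEN, 2) + 1
    ntk_alpha = 2 ** math.ceil(context_value) - 1
    return max(ntk_alpha, 1)

def get_mscale(scale: float = 1.0) -> float:
    # Magnitude correction applied to the cos/sin caches; valid for
    # scale >= 1, which the clamp below guarantees.
    return 0.1 * math.log(scale) + 1.0

def scaled_rope_params(seq_len: int):
    seq_len = max(seq_len, TRAINING_SEQLEN)      # clamp, as the new forward does
    alpha = get_ntk_alpha(seq_len)
    base = BASE * alpha ** (DIM / (DIM - 2))     # stretched rotary base
    mscale = get_mscale(seq_len / TRAINING_SEQLEN)
    return alpha, base, mscale

if __name__ == "__main__":
    # At 16384 tokens: context_value = log2(2) + 1 = 2, so alpha = 2**2 - 1 = 3.
    print(scaled_rope_params(16_384))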
42 changes: 20 additions & 22 deletions models/7B/modeling_telechat.py
@@ -105,29 +105,27 @@ def get_ntk_alpha(self, true_seq_len):
         return ntk_alpha
 
     def forward(self, x, seq_dim=0, seq_len=None):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
-        if True:
-            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            # [sx, 1 (b * np), hn]
-            self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
-            self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
+        # [sx, 1 (b * np), hn]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers:
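The trailing context comment points at the rotary position embedding helpers further down in each file. They are not part of this diff; the standard GPT-NeoX-style pair that comment usually refers to looks like the following sketch (names and layout assumed from the common pattern, not copied from this repository).

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and swap the halves with a sign flip:
    # (x1, x2) -> (-x2, x1), matching the "different permutation" noted above.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    # Slice the cached cos/sin to the query/key positions and rotate,
    # assuming the sequence-first [sx, 1, hn] cache layout shown above.
    cos = cos[offset : offset + q.shape[0]]
    sin = sin[offset : offset + q.shape[0]]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)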
42 changes: 20 additions & 22 deletions models/7B_4bit/modeling_telechat.py
@@ -105,29 +105,27 @@ def get_ntk_alpha(self, true_seq_len):
         return ntk_alpha
 
     def forward(self, x, seq_dim=0, seq_len=None):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
-        if True:
-            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            # [sx, 1 (b * np), hn]
-            self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
-            self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
+        # [sx, 1 (b * np), hn]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers:
42 changes: 20 additions & 22 deletions models/7B_8bit/modeling_telechat.py
@@ -105,29 +105,27 @@ def get_ntk_alpha(self, true_seq_len):
         return ntk_alpha
 
     def forward(self, x, seq_dim=0, seq_len=None):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
-        if True:
-            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            # [sx, 1 (b * np), hn]
-            self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
-            self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
+        # [sx, 1 (b * np), hn]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers:
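Putting the four near-identical changes together, the sketch below approximates the rewritten embedding as a self-contained module: seq_len may now be passed explicitly, it is always clamped to the training length, and cos/sin are computed as locals on each call instead of being cached on self. Constructor arguments, the config object, and the 8192 divisor (which follows the 12B_8bit change; the 7B files' get_ntk_alpha is outside this diff) are assumptions for illustration, not the exact shipped class.

import math
import torch

class SimpleConfig:
    # Stand-in for the TeleChat config object; only the field used here.
    training_seqlen = 8192

class RotaryEmbeddingSketch(torch.nn.Module):
    """Approximation of the updated forward() shared by the four files."""

    def __init__(self, dim, config, base=10000, precision=torch.half):
        super().__init__()
        self.dim = dim
        self.config = config
        self.base = base
        self.precision = precision

    def get_mscale(self, scale=1):
        # Valid for scale >= 1, which the clamp in forward() guarantees.
        return 0.1 * math.log(scale) + 1.0

    def get_ntk_alpha(self, true_seq_len):
        # 8192 divisor per the 12B_8bit change above (assumed for the sketch).
        context_value = math.log(true_seq_len / 8192, 2) + 1
        ntk_alpha = 2 ** math.ceil(context_value) - 1
        return max(ntk_alpha, 1)

    def forward(self, x, seq_dim=0, seq_len=None):
        # seq_len may include past key/value length; otherwise read it from x.
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        seq_len = max(seq_len, self.config.training_seqlen)
        ntk_alpha = self.get_ntk_alpha(seq_len)
        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
        t = torch.arange(seq_len, device=x.device, dtype=inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        if self.precision == torch.bfloat16:
            emb = emb.float()
        # [seq_len, 1 (b * np), head_dim]; locals now, no cos_cached/sin_cached on self.
        cos = mscale * emb.cos()[:, None, :].half()
        sin = mscale * emb.sin()[:, None, :].half()
        if self.precision == torch.bfloat16:
            cos, sin = cos.bfloat16(), sin.bfloat16()
        return cos[:seq_len, ...], sin[:seq_len, ...]

# Usage sketch: request embeddings covering the prompt plus generated tokens.
rotary = RotaryEmbeddingSketch(dim=128, config=SimpleConfig())
dummy = torch.zeros(1024, 1, 128)   # [seq, batch*heads, head_dim], assumed layout
cos, sin = rotary(dummy, seq_dim=0, seq_len=1024 + 256)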
