diff --git a/models/12B_8bit/modeling_telechat.py b/models/12B_8bit/modeling_telechat.py
index b6c126b..81c0d57 100644
--- a/models/12B_8bit/modeling_telechat.py
+++ b/models/12B_8bit/modeling_telechat.py
@@ -96,31 +96,34 @@ def get_mscale(self, scale=1):
         return 0.1 * math.log(scale) + 1.0
 
     def get_ntk_alpha(self, true_seq_len):
-        context_value = math.log(true_seq_len / 4096, 2) + 1
+        context_value = math.log(true_seq_len / 8192, 2) + 1
+        # ntk_alpha = 2 ** context_value - 1
         ntk_alpha = 2 ** math.ceil(context_value) - 1
         ntk_alpha = max(ntk_alpha, 1)
         return ntk_alpha
 
-    def forward(self, x, dtype, seq_dim=0):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+    def forward(self, x, seq_dim=0, seq_len=None):
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
         base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-        self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-        # if self.precision == torch.bfloat16:
-        emb = emb.float() if dtype == torch.bfloat16 else emb
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
         # [sx, 1 (b * np), hn]
-        self.cos_cached = self.mscale * emb.cos()[:, None, :].to(dtype)
-        self.sin_cached = self.mscale * emb.sin()[:, None, :].to(dtype)
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers:
diff --git a/models/7B/modeling_telechat.py b/models/7B/modeling_telechat.py
index a615692..309e6f0 100644
--- a/models/7B/modeling_telechat.py
+++ b/models/7B/modeling_telechat.py
@@ -105,29 +105,27 @@ def get_ntk_alpha(self, true_seq_len):
         return ntk_alpha
 
     def forward(self, x, seq_dim=0, seq_len=None):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
-        if True:
-            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            # [sx, 1 (b * np), hn]
-            self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
-            self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
+        # [sx, 1 (b * np), hn]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers:
diff --git a/models/7B_4bit/modeling_telechat.py b/models/7B_4bit/modeling_telechat.py
index a615692..309e6f0 100644
--- a/models/7B_4bit/modeling_telechat.py
+++ b/models/7B_4bit/modeling_telechat.py
@@ -105,29 +105,27 @@ def get_ntk_alpha(self, true_seq_len):
         return ntk_alpha
 
     def forward(self, x, seq_dim=0, seq_len=None):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
-        if True:
-            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            # [sx, 1 (b * np), hn]
-            self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
-            self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
+        # [sx, 1 (b * np), hn]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers:
diff --git a/models/7B_8bit/modeling_telechat.py b/models/7B_8bit/modeling_telechat.py
index a615692..309e6f0 100644
--- a/models/7B_8bit/modeling_telechat.py
+++ b/models/7B_8bit/modeling_telechat.py
@@ -105,29 +105,27 @@ def get_ntk_alpha(self, true_seq_len):
         return ntk_alpha
 
     def forward(self, x, seq_dim=0, seq_len=None):
-        seq_len = x.shape[seq_dim]
-        self.mscale = 1.0
-        if not self.training:
-            seq_len = max(seq_len, self.config.training_seqlen)
-            self.mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        if seq_len is None:
+            seq_len = x.shape[seq_dim]
+        seq_len = max(seq_len, self.config.training_seqlen)
         ntk_alpha = self.get_ntk_alpha(seq_len)
-        if True:
-            base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
-            self.inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float() / self.dim))
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
-            # Different from paper, but it uses a different permutation in order to obtain the same calculation
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            # [sx, 1 (b * np), hn]
-            self.cos_cached = self.mscale * emb.cos()[:, None, :].half()
-            self.sin_cached = self.mscale * emb.sin()[:, None, :].half()
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]
+        mscale = float(self.get_mscale(seq_len / self.config.training_seqlen))
+        base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, device=x.device).float( )/ self.dim ))
+        max_seq_len_cached = seq_len
+        t = torch.arange(max_seq_len_cached, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.einsum('i,j->ij', t, inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+        if self.precision == torch.bfloat16:
+            emb = emb.float()
+        # [sx, 1 (b * np), hn]
+        cos_cached = mscale *emb.cos()[:, None, :].half()
+        sin_cached = mscale *emb.sin()[:, None, :].half()
+        if self.precision == torch.bfloat16:
+            cos_cached = cos_cached.bfloat16()
+            sin_cached = sin_cached.bfloat16()
+        return cos_cached[:seq_len, ...], sin_cached[:seq_len, ...]
 
 
 # rotary pos emb helpers: