
Commit

update
srd13311390427 committed May 17, 2024
1 parent 6acae12 commit 9c2c2f4
Showing 63 changed files with 3,442 additions and 9 deletions.
19 changes: 10 additions & 9 deletions README.md
@@ -38,7 +38,7 @@
- On the training data side, we collected a large volume of Chinese and English data covering books, encyclopedias, news, government affairs, law, medicine, patents, academic papers, mathematics, code, and more; by optimizing the data-cleaning strategy we substantially improved the text's cleanliness, viewpoint neutrality, content validity, and formatting consistency.
- On the training method side, we combine data-mixing-ratio search with curriculum learning: small models are first fitted on data under multiple mixing ratios to obtain a prior estimate of each dataset's difficulty; during training, the current model's loss on every dataset and its generation quality on evaluation sets are automatically measured at regular intervals, and the weights of harder-to-learn datasets are dynamically increased so that the model fits well on every dataset (a toy sketch of this reweighting idea follows below).

- The **TeleChat-12B-v2** release uses dynamic data mixing and curriculum learning to continue training the base model up to 3.7T tokens of data. Its chat model improves by 5.5% on general-capability evaluations, including gains of 24.6% in math, 9% in translation, and 10.3% on hallucination tests, with further improvements in safety refusals, knowledge QA, and casual chat.
- The **TeleChat-12B-V2** release uses dynamic data mixing and curriculum learning to further strengthen the base model through continued training. Its chat model improves by 5.5% on general-capability evaluations, including gains of 24.6% in math, 9% in translation, and 10.3% on hallucination tests, with further improvements in safety refusals, knowledge QA, and casual chat.
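
A toy sketch of the loss-driven reweighting idea described above (illustrative only; the update rule, dataset names, and loss values are assumptions, not the authors' implementation):

```python
# Illustrative sketch of loss-driven dynamic data mixing: datasets with higher
# evaluation loss (harder to learn) receive a larger sampling weight.
import math

def reweight(weights, losses, temperature=1.0):
    """Scale each dataset's weight by exp(loss / T), then renormalize."""
    scores = {name: weights[name] * math.exp(losses[name] / temperature)
              for name in weights}
    total = sum(scores.values())
    return {name: score / total for name, score in scores.items()}

# Hypothetical evaluation losses measured partway through training.
weights = {"books": 0.4, "code": 0.3, "math": 0.3}
losses = {"books": 1.8, "code": 2.0, "math": 2.6}
print(reweight(weights, losses))  # "math" now gets the largest share
```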



@@ -68,14 +68,15 @@

The released versions and their download links are listed in the table below.

| Model Version | Download Link |
|-------------|-----------------------------------------------------------------------|
| 7B-FP16 | [TeleChat-7B-FP16](https://huggingface.co/Tele-AI/Telechat-7B) |
| 7B-int8 | [TeleChat-7B-int8](https://huggingface.co/Tele-AI/Telechat-7B-int8) |
| 7B-int4 | [TeleChat-7B-int4](https://huggingface.co/Tele-AI/Telechat-7B-int4) |
| 12B-FP16 | [TeleChat-12B-FP16](https://huggingface.co/Tele-AI/TeleChat-12B) |
| 12B-int8 | [TeleChat-12B-int8](https://huggingface.co/Tele-AI/TeleChat-12B-int8) |
| 12B-int4 | [TeleChat-12B-int4](https://huggingface.co/Tele-AI/TeleChat-12B-int4) |
| 12B-V2-FP16 | [TeleChat-12B-V2-FP16](https://modelscope.cn/models/TeleAI/TeleChat-12B-v2/files) |

**Environment Image Download**
To help you get started quickly, we provide a ready-to-run environment image. Download link: [Image Download](https://cloud.189.cn/web/share?code=vQFJRf7JBfmq) (access code: ona6)
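
For quick reference, a minimal loading sketch (not part of the original README), following the standard Hugging Face custom-code pattern implied by the `auto_map` entries in the config below; the repo id and device settings are illustrative:

```python
# Minimal loading sketch. trust_remote_code is needed because TeleChat ships
# custom model classes registered via auto_map in its config.
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "Tele-AI/TeleChat-12B"  # or a local directory with downloaded weights
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True,
                                             device_map="auto")
```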
43 changes: 43 additions & 0 deletions models/12B-V2/config.json
@@ -0,0 +1,43 @@
{
"apply_residual_connection_post_layernorm": false,
"architectures": [
"TelechatForCausalLM"
],
"auto_map": {
"AutoConfig": "configuration_telechat.TelechatConfig",
"AutoModelForCausalLM": "modeling_telechat.TelechatForCausalLM"
},
"attention_dropout": 0.0,
"attention_softmax_in_fp32": true,
"bias_dropout_fusion": true,
"bos_token_id": 1,
"eos_token_id": 2,
"hidden_dropout": 0.0,
"hidden_size": 5120,
"initializer_range": 0.02,
"layer_norm_epsilon": 1e-05,
"masked_softmax_fusion": true,
"model_type": "telechat",
"n_head": 32,
"n_inner": null,
"n_layer": 38,
"offset_alibi": 100,
"pad_token_id": 3,
"pretraining_tp": 2,
"seq_length": 8192,
"skip_bias_add": true,
"skip_bias_add_qkv": false,
"slow_but_exact": false,
"transformers_version": "4.24.0",
"unk_token_id": 0,
"use_cache": true,
"vocab_size": 120000,
"ffn_hidden_size": 12288,
"flash_attn":true,
"tie_word_embeddings":false,
"training_seqlen":8192,
"logn":false,
"semi_causal":false,
"embed_layernorm":false
}
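
A brief sketch (not part of the commit) of how this config is typically consumed via the `auto_map` hook above; the local directory path is an assumption:

```python
# Sketch: loading the config above through its auto_map hook.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("models/12B-V2", trust_remote_code=True)
print(config.model_type)                   # telechat
print(config.n_layer, config.hidden_size)  # 38 5120
```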

93 changes: 93 additions & 0 deletions models/12B-V2/configuration_telechat.py
@@ -0,0 +1,93 @@
# coding=utf-8
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Telechat configuration"""

from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging

logger = logging.get_logger(__name__)

class TelechatConfig(PretrainedConfig):
"""
Args:
vocab_size (`int`, *optional*, defaults to 160256): Vocabulary size of the Telechat model.
hidden_size (`int`, *optional*, defaults to 4096): Dimensionality of the embeddings and hidden states.
ffn_hidden_size (`int`, *optional*, defaults to 12288): Dimensionality of the feed-forward hidden states.
n_layer (`int`, *optional*, defaults to 30): Number of hidden layers in the Transformer
n_head (`int`, *optional*, defaults to 32): Number of attention heads for each attention layer.
layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): The epsilon to use in the layer normalization layers.
initializer_range (`float`, *optional*, defaults to 0.02): The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
apply_residual_connection_post_layernorm (`bool`, *optional*, defaults to `False`): If enabled, use the layer norm of the hidden states as the residual in the transformer blocks
        hidden_dropout (`float`, *optional*, defaults to 0.0): Dropout rate applied in the bias-dropout-add operation on the hidden states.
attention_dropout (`float`, *optional*, defaults to 0.0): Dropout rate applied to the attention probs
use_cache (`bool`, *optional*, defaults to `True`): Whether or not the model should return the last key/values attentions.
        training_seqlen (`int`, *optional*, defaults to 8192): Sequence length used during the final fine-tuning stage.
        logn (`bool`, *optional*, defaults to `True`): Whether or not to use logN attention scaling for length extrapolation.
        embed_layernorm (`bool`, *optional*, defaults to `False`): Whether or not to apply layer normalization to the word embeddings.
"""

model_type = "telechat"
keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
}

def __init__(
self,
vocab_size=160256,
hidden_size=4096,
n_layer=30,
n_head=32,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0.0,
attention_dropout=0.0,
ffn_hidden_size=12288,
        training_seqlen=8192,
        logn=True,
        embed_layernorm=False,
**kwargs,
):
self.vocab_size = vocab_size
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.use_cache = use_cache
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bos_token_id = bos_token_id
self.eos_token_id = eos_token_id
self.logn = logn
self.ffn_hidden_size = ffn_hidden_size
self.training_seqlen = training_seqlen
self.embed_layernorm = embed_layernorm


super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
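
As a usage note on the `attribute_map` above, a sketch (not part of the commit; the import path is an assumption) showing how standard Transformers attribute names resolve to the Telechat-specific fields:

```python
# Sketch: attribute_map aliases num_hidden_layers/num_attention_heads
# to the Telechat-specific n_layer/n_head fields.
from configuration_telechat import TelechatConfig  # assumed import path

config = TelechatConfig(n_layer=38, hidden_size=5120, vocab_size=120000)
assert config.num_hidden_layers == config.n_layer == 38
assert config.num_attention_heads == config.n_head == 32  # n_head default
```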

14 changes: 14 additions & 0 deletions models/12B-V2/generation_config.json
@@ -0,0 +1,14 @@
{
"max_length": 8192,
"do_sample": false,
"use_cache": true,
"temperature": 0.3,
"top_k": 5,
"top_p": 0.85,
"repetition_penalty": 1.003,
"pad_token_id": 3,
"bos_token_id": 1,
"eos_token_id": 2,
"user_token_id": 20,
"bot_token_id": 21
}
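
Note that with `do_sample` set to `false`, decoding is greedy by default; the `temperature`, `top_k`, and `top_p` values only take effect once sampling is enabled. A brief sketch of loading these defaults (the local path is an assumption, and `GenerationConfig.from_pretrained` requires a reasonably recent transformers release):

```python
# Sketch: loading the generation defaults above from a local directory.
from transformers import GenerationConfig

gen_config = GenerationConfig.from_pretrained("models/12B-V2")
print(gen_config.do_sample, gen_config.top_k, gen_config.top_p)  # False 5 0.85
# outputs = model.generate(**inputs, generation_config=gen_config)
```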
162 changes: 162 additions & 0 deletions models/12B-V2/generation_utils.py
@@ -0,0 +1,162 @@
from typing import Optional
from collections import deque
from queue import Queue
import copy


class History:

def __init__(self, tokenizer, history):
        '''
        Initialize from a list of message dicts.
        '''
        # use a deque so messages can be added or removed at both ends
self.input_history = deque()
self.tokenizer = tokenizer
if history:
self._transfer_from_list(history)

def _transfer_from_list(self, history):
for message in history:
content = message.get("content")
            # the tokenization of the content may not exactly match the tokens the model generated
message.update(self.tokenizer(content))
self.input_history.append(message)

def append(self, message):
content = message.get("content")
if "input_ids" not in message or "attention_mask" not in message:
message.update(self.tokenizer(content))
self.input_history.append(message)

def append_left(self, message):
content = message.get("content")
if "input_ids" not in message or "attention_mask" not in message:
message.update(self.tokenizer(content))
self.input_history.appendleft(message)

def pop(self):
x = self.input_history.pop()
return x

    def pop_left(self):
        x = self.input_history.popleft()
        return x

def update(self, message):
self.input_history.pop()
self.append(message)

def __len__(self):
return self.input_history.__len__()

def __str__(self):
return self.input_history.__str__()

def __copy__(self):
new_instance = type(self)(self.tokenizer, [])
new_instance.input_history = copy.copy(self.input_history)
return new_instance

def __deepcopy__(self, memodict={}):
new_instance = type(self)(self.tokenizer, [])
new_instance.input_history = copy.deepcopy(self.input_history)
return new_instance


class TelechatIterTextStreamer:
"""
    Rewritten with reference to transformers' TextIteratorStreamer, adapted to also carry the conversation history.
"""

def __init__(
self, tokenizer, history: History = None, skip_prompt: bool = False, timeout: Optional[float] = None,
**decode_kwargs
):

        self.tokenizer = tokenizer
        # fall back to an empty History; streamed text is appended into it below
        self.history = history if history is not None else History(tokenizer, [])
self.skip_prompt = skip_prompt
self.timeout = timeout
self.decode_kwargs = decode_kwargs

self.text_queue = Queue()
self.cache_time = 0
self.text_until = ""
self.token_until = []
self.stop_signal = None
self.next_tokens_are_prompt = True

self.history.append({"role": "bot", "content": self.text_until})

def put(self, value):
"""
put printable text into queue
"""
if len(value.shape) > 1 and value.shape[0] > 1:
raise ValueError("TextStreamer only supports batch size 1")
elif len(value.shape) > 1:
value = value[0]

if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return

if value[-1] == self.tokenizer.eos_token_id:
return

        # decode the accumulated tokens each time; incremental decoding could be smarter
self.token_until.extend(value.tolist())
text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)


        if self._is_printable(text) or self.cache_time >= 6:
            output_text = text[len(self.text_until):]
            self.text_until = text
            # reset the cache counter once buffered text has been flushed
            self.cache_time = 0
        else:
            self.cache_time += 1
            return

        self.on_finalized_text(output_text)

def end(self):
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
text = self.tokenizer.decode(self.token_until, **self.decode_kwargs)
output_text = text[len(self.text_until):]
self.text_until = text
self.on_finalized_text(output_text, stream_end=True)
self.clear_cache()

def clear_cache(self):
self.cache_time = 0
self.token_until = []
self.text_until = ""
self.history = None
self.next_tokens_are_prompt = True

def on_finalized_text(self, text: str, stream_end: bool = False):
"""Put the text tuple in the queue."""
self.history.update({"role": "bot", "content": self.text_until, "input_ids": self.token_until,
"attention_mask": [1] * len(self.token_until)})
self.text_queue.put((text, self.history), timeout=self.timeout)
if stream_end:
self.text_queue.put((self.stop_signal, self.history), timeout=self.timeout)

@staticmethod
def _is_printable(cp):
"""Checks whether tokens can be decoded or not"""
if "�" in cp:
return False
return True

def __iter__(self):
return self

def __next__(self):
value_now, history_until = self.text_queue.get(timeout=self.timeout)
if value_now == self.stop_signal:
raise StopIteration()
else:
return value_now, history_until
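
For context, a usage sketch for the two classes above (not part of the commit), following the standard pattern of running `generate` in a background thread while iterating the streamer; the checkpoint path and generate arguments are assumptions:

```python
# Usage sketch. generate() accepts any object exposing put()/end() as streamer.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "Tele-AI/TeleChat-12B"  # illustrative checkpoint path
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)

# History tokenizes each message on insertion and supports edits at both ends.
history = History(tokenizer, [{"role": "user", "content": "你好"}])

streamer = TelechatIterTextStreamer(tokenizer, history=history, skip_prompt=True)
inputs = tokenizer("你好", return_tensors="pt")
Thread(target=model.generate,
       kwargs=dict(**inputs, streamer=streamer, max_new_tokens=128)).start()

# Each iteration yields the newly decoded text plus the updated history.
for text_piece, hist in streamer:
    print(text_piece, end="", flush=True)
```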
