From 8ba46657954858eb1268328e713d6b9d70658648 Mon Sep 17 00:00:00 2001
From: neavo
Date: Sun, 25 Aug 2024 19:40:34 +0800
Subject: [PATCH] Adjustment - optimize the data preprocessing workflow
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 01.py | 125 +++++++++++++++++++++++++++++++---------------------------
 1 file changed, 67 insertions(+), 58 deletions(-)

diff --git a/01.py b/01.py
index 69db6ac..0a80dcb 100644
--- a/01.py
+++ b/01.py
@@ -24,26 +24,26 @@ from model.PreTrainerCallback import PreTrainerCallback
 
 # Parameter settings
-MODEL_NAME = "microsoft_mdeberta_v3_base"
+MODEL_NAME = "facebookai_xlm_roberta_base"
 MODEL_PATH = f"assets/{MODEL_NAME}"
 OUTPUT_PATH = f"output/{MODEL_NAME}_pretrain"
 
 EPOCHS = 2
-MAX_LENGTH = 512
-BATCH_SIZE = 6
-GRADIENT_ACCUMULATION_SIZE = 64
+LENGTH_THRESHOLD = 256
+BATCH_SIZE = 11
+GRADIENT_ACCUMULATION_SIZE = 128
 DO_LOWER_CASE = False
 LEARNING_RATE = 2 * 1e-5
 INTERVAL_STEPS = 200
 AUTO_RESUME_FROM_CHECKPOINT = True
 
 DATASET_PATH = [
-    ("dataset/pretrain/en", 10 * 10000),
-    ("dataset/pretrain/en_r18_visual_novels", 10 * 10000),
-    ("dataset/pretrain/zh", 10 * 10000),
-    ("dataset/pretrain/zh_r18_pixiv", 10 * 10000),
-    ("dataset/pretrain/jp", 8.5 * 10000),
-    ("dataset/pretrain/jp_r18", 8.5 * 10000),
-    ("dataset/pretrain/jp_r18_rpgmaker", 3 * 10000),
+    ("dataset/pretrain/en", 20 * 10000),
+    ("dataset/pretrain/en_r18_visual_novels", 20 * 10000),
+    ("dataset/pretrain/zh", 20 * 10000),
+    ("dataset/pretrain/zh_r18_pixiv", 20 * 10000),
+    ("dataset/pretrain/jp", 17 * 10000),
+    ("dataset/pretrain/jp_r18", 17 * 10000),
+    ("dataset/pretrain/jp_r18_rpgmaker", 6 * 10000),
 ]
 
 # Load the tokenizer
@@ -54,6 +54,10 @@ def load_tokenizer():
         local_files_only = True,
     )
 
+# Split a list into batches of the given size
+def split(datas, size):
+    return [datas[i:(i + size)] for i in range(0, len(datas), size)]
+
 # Clean up text
 def cleanup(line):
     # Patterns like 【\N[123]】 are variables that stand in for character names
@@ -87,12 +91,6 @@ def cleanup(line):
 
     return line
 
-# Get the tokenized length
-def get_length(tokenizer, line):
-    return len(
-        tokenizer(line).input_ids
-    )
-
 # Generate data entries
 def generate_datas(tokenizer, lines):
     lines = [cleanup(line) for line in lines]
@@ -100,15 +98,15 @@ def generate_datas(tokenizer, lines):
     datas = []
     tokens = tokenizer(
         lines,
-        padding = False, 
+        padding = False,
         truncation = True,
-        max_length = MAX_LENGTH,
+        max_length = LENGTH_THRESHOLD,
     )
 
-    for line, input_id in zip(lines, tokens.input_ids):
+    for line, input_ids in zip(lines, tokens.input_ids):
         datas.append({
             "line": line,
-            "length": len(input_id),
+            "length": len(input_ids),
         })
 
     return datas
@@ -118,7 +116,6 @@ def generate_chunks(tokenizer, lines):
     chunks = []
     datas = generate_datas(tokenizer, lines)
 
-    longest = 0
     chunk = ""
     chunk_length = 0
     for data in datas:
@@ -129,24 +126,38 @@ def generate_chunks(tokenizer, lines):
         if "�" in line:
             continue
 
-        # Skip the sentence if its length alone already exceeds the maximum
-        if length > MAX_LENGTH - 3:
-            continue
+        if chunk_length + length >= LENGTH_THRESHOLD - 3:
+            chunk = re.sub(r" +", " ", chunk + " " + line)
+            chunks.append(chunk)
 
-        if chunk_length + length > MAX_LENGTH - 3:
-            longest = max(longest, get_length(tokenizer, re.sub(r" +", " ", chunk.strip())))
-            chunks.append(re.sub(r" +", " ", chunk.strip()))
             chunk = ""
             chunk_length = 0
-
-        chunk = chunk + " " + line
-        chunk_length = chunk_length + length + 1
+        else:
+            chunk = chunk + " " + line
+            chunk_length = chunk_length + 1 + length - 2 # The space may or may not count as a token, so add 1 to be safe, then subtract the two special tokens at the start and end
 
     if chunk.strip() != "":
-        longest = max(longest, get_length(tokenizer, re.sub(r" +", " ", chunk.strip())))
-        chunks.append(re.sub(r" +", " ", chunk.strip()))
+        chunk = re.sub(r" +", " ", chunk)
+        chunks.append(chunk)
+
+    return chunks
+
+# Mapping function
+def map_function(tokenizer, samples):
+    encodings = tokenizer(
+        samples["line"],
+        padding = "max_length",
+        truncation = True,
+        max_length = LENGTH_THRESHOLD,
+        return_attention_mask = True,
+        return_offsets_mapping = True,
+        return_special_tokens_mask = True,
+    )
 
-    return chunks, longest
+    # Count the effective tokens
+    encodings["input_length"] = [sum(item) for item in encodings.attention_mask]
+
+    return encodings
 
 # Load the datasets
 def load_dataset(tokenizer):
@@ -154,16 +165,17 @@ def load_dataset(tokenizer):
     print(f"Loading datasets ...")
     print(f"")
 
-    longest = 0
-    count = 0
     datas = []
+    count = 0
     for path, num in DATASET_PATH:
         datas_by_type = []
+        dir_path, dir_name = os.path.split(path)
 
-        if os.path.exists(f"{path}_{MODEL_NAME}.txt"):
+        if os.path.exists(f"{dir_path}/{MODEL_NAME}_{dir_name}.txt"):
             count = count + 1
-            with open(f"{path}_{MODEL_NAME}.txt", "r", encoding = "utf-8") as file:
-                datas_by_type = [line.strip() for line in file]
+            with open(f"{dir_path}/{MODEL_NAME}_{dir_name}.txt", "r", encoding = "utf-8") as file:
+                datas_by_type = [line.strip() for line in tqdm(file, desc = path, total = num)]
+            random.shuffle(datas_by_type)
         else:
             total = len([entry for entry in os.listdir(path) if os.path.isfile(os.path.join(path, entry))])
 
@@ -174,45 +186,42 @@ def load_dataset(tokenizer):
                     count = count + 1
                     lines.extend([line.strip() for line in file if line.strip() != ""])
 
-            lines = [lines[i:(i + 64 * 1024)] for i in range(0, len(lines), 64 * 1024)]
+            lines = split(lines, 32 * 1024)
             results = Parallel(n_jobs = -1, prefer = "processes", return_as = "generator_unordered")(
                 delayed(generate_chunks)(tokenizer, v) for v in lines
             )
             for v in tqdm(results, desc = path, total = len(lines)):
-                datas_by_type.extend(v[0])
-                longest = max(longest, v[1])
+                datas_by_type.extend(v)
 
             datas_by_type = random.sample(datas_by_type, min(int(num), len(datas_by_type)))
-            with open(f"{path}_{MODEL_NAME}.txt", "w", encoding = "utf-8") as file:
+            with open(f"{dir_path}/{MODEL_NAME}_{dir_name}.txt", "w", encoding = "utf-8") as file:
                 file.writelines([f"{line}\n" for line in datas_by_type])
 
         datas.extend(datas_by_type)
 
-    print(f"")
-    print(f"Maximum length is {longest} ...") if longest > 0 else None
-    print(f"Found {count} data files with a total of {len(datas)} entries ...")
 
     # Generate the dataset
-    print(f"")
     os.makedirs("cache", exist_ok = True)
     dataset_train = Dataset.from_dict({"line": datas})
    dataset_train_tokenized = dataset_train.map(
-        lambda samples: tokenizer(
-            samples["line"],
-            padding = "max_length",
-            truncation = True,
-            max_length = MAX_LENGTH,
-            return_attention_mask = True,
-            return_offsets_mapping = True,
-            return_special_tokens_mask = True,
-        ),
+        lambda samples: map_function(tokenizer, samples),
+        num_proc = 1,
         batched = True,
-        batch_size = 1024,
-        writer_batch_size = 4096,
+        batch_size = 512,
+        writer_batch_size = 8 * 1024,
         remove_columns = ["line"],
         cache_file_name = f"cache/{MODEL_NAME}.cache",
         load_from_cache_file = True,
     )
+
+    # Count the effective tokens
+    total_length = sum(dataset_train_tokenized["input_length"])
+
+    print(f"")
+    print(
+        f"Found {count} data files with {len(datas)} entries, "
+        + f"{(total_length / 1000 / 1000):.2f} M effective tokens, {(total_length / len(datas)):.2f} tokens per entry on average ..."
+    )
     print(f"")
 
     return dataset_train_tokenized
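
Note (illustration only, not part of the patch): the new chunk_length bookkeeping in generate_chunks() estimates the tokenized length of a merged chunk as, per line, 1 + len(input_ids) - 2, i.e. dropping each line's two special tokens and adding 1 for the joining space. A minimal sketch of that estimate follows, assuming the public "xlm-roberta-base" Hub checkpoint stands in for the local facebookai_xlm_roberta_base assets; the sample lines and variable names here are made up for the example.

from transformers import AutoTokenizer

LENGTH_THRESHOLD = 256  # mirrors the new constant in 01.py

# Assumption: the Hub checkpoint "xlm-roberta-base" is used in place of the local assets/ copy.
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

sample_lines = ["Hello world.", "This is a short sentence.", "Another line of text."]
tokens = tokenizer(sample_lines, padding = False, truncation = True, max_length = LENGTH_THRESHOLD)

estimated = 0
chunk = ""
for line, input_ids in zip(sample_lines, tokens.input_ids):
    estimated += 1 + len(input_ids) - 2  # +1 for the joining space, -2 for each line's <s> and </s>
    chunk = (chunk + " " + line).strip()

actual = len(tokenizer(chunk).input_ids)  # the merged chunk gets a single <s>/</s> pair
print(f"estimated = {estimated}, actual = {actual}")

The estimate stays within a couple of tokens of the true length, which is enough to keep merged chunks near LENGTH_THRESHOLD before truncation in map_function().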