Tweak - optimize the data preprocessing pipeline
neavo committed Aug 25, 2024
1 parent 9f66343 commit 8ba4665
Showing 1 changed file with 67 additions and 58 deletions.
125 changes: 67 additions & 58 deletions 01.py
@@ -24,26 +24,26 @@
from model.PreTrainerCallback import PreTrainerCallback

# Parameter settings
MODEL_NAME = "microsoft_mdeberta_v3_base"
MODEL_NAME = "facebookai_xlm_roberta_base"
MODEL_PATH = f"assets/{MODEL_NAME}"
OUTPUT_PATH = f"output/{MODEL_NAME}_pretrain"
EPOCHS = 2
MAX_LENGTH = 512
BATCH_SIZE = 6
GRADIENT_ACCUMULATION_SIZE = 64
LENGTH_THRESHOLD = 256
BATCH_SIZE = 11
GRADIENT_ACCUMULATION_SIZE = 128
DO_LOWER_CASE = False
LEARNING_RATE = 2 * 1e-5
INTERVAL_STEPS = 200
AUTO_RESUME_FROM_CHECKPOINT = True

DATASET_PATH = [
("dataset/pretrain/en", 10 * 10000),
("dataset/pretrain/en_r18_visual_novels", 10 * 10000),
("dataset/pretrain/zh", 10 * 10000),
("dataset/pretrain/zh_r18_pixiv", 10 * 10000),
("dataset/pretrain/jp", 8.5 * 10000),
("dataset/pretrain/jp_r18", 8.5 * 10000),
("dataset/pretrain/jp_r18_rpgmaker", 3 * 10000),
("dataset/pretrain/en", 20 * 10000),
("dataset/pretrain/en_r18_visual_novels", 20 * 10000),
("dataset/pretrain/zh", 20 * 10000),
("dataset/pretrain/zh_r18_pixiv", 20 * 10000),
("dataset/pretrain/jp", 17 * 10000),
("dataset/pretrain/jp_r18", 17 * 10000),
("dataset/pretrain/jp_r18_rpgmaker", 6 * 10000),
]
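A quick sanity check on the new settings (illustrative only, not part of the commit), assuming GRADIENT_ACCUMULATION_SIZE is the number of accumulation steps:

```python
# Illustrative arithmetic, not part of the commit.
BATCH_SIZE = 11
GRADIENT_ACCUMULATION_SIZE = 128
effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION_SIZE
print(effective_batch)  # 1408 sequences per optimizer step

# The new per-corpus caps sum to the total sampling budget.
total_entries = (20 + 20 + 20 + 20 + 17 + 17 + 6) * 10000
print(total_entries)  # 1,200,000 entries (up from 600,000 before this commit)
```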

# Load the tokenizer
@@ -54,6 +54,10 @@ def load_tokenizer():
local_files_only = True,
)

# Split a list into chunks of the given size
def split(datas, size):
return [datas[i:(i + size)] for i in range(0, len(datas), size)]
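A tiny usage example of the new split() helper (not from the commit):

```python
# Same one-liner as the helper above, repeated so the snippet runs standalone.
def split(datas, size):
    return [datas[i:(i + size)] for i in range(0, len(datas), size)]

print(split([1, 2, 3, 4, 5], 2))  # [[1, 2], [3, 4], [5]] -- the last chunk may be shorter
```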

# Clean up text
def cleanup(line):
# Patterns like 【\N[123]】 are variables that stand in for character names
@@ -87,28 +91,22 @@ def cleanup(line):

return line
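The body of cleanup() is collapsed in the diff. As a purely hypothetical illustration of the 【\N[123]】 placeholders mentioned in the comment above, a regex that strips them might look like this (strip_name_variables is an invented name, not the project's actual code):

```python
import re

def strip_name_variables(line):
    # Hypothetical example only: remove RPG-Maker style name placeholders such as 【\N[123]】.
    return re.sub(r"【?\\N\[\d+\]】?", "", line).strip()

print(strip_name_variables("【\\N[123]】 said hello"))  # "said hello"
```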

# Get the token length
def get_length(tokenizer, line):
return len(
tokenizer(line).input_ids
)

# Generate data entries
def generate_datas(tokenizer, lines):
lines = [cleanup(line) for line in lines]

datas = []
tokens = tokenizer(
lines,
padding = False,
padding = False,
truncation = True,
max_length = MAX_LENGTH,
max_length = LENGTH_THRESHOLD,
)

for line, input_id in zip(lines, tokens.input_ids):
for line, input_ids in zip(lines, tokens.input_ids):
datas.append({
"line": line,
"length": len(input_id),
"length": len(input_ids),
})

return datas
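For reference (not part of the diff), each record produced by generate_datas() pairs a cleaned line with its tokenized length; a stand-in sketch using a fake whitespace tokenizer:

```python
# Stand-in sketch: the real code uses a Hugging Face tokenizer, not whitespace splitting.
lines = ["first cleaned sentence", "second one"]
datas = [{"line": line, "length": len(line.split()) + 2} for line in lines]  # +2 for the BOS/EOS tokens
print(datas)  # [{'line': 'first cleaned sentence', 'length': 5}, {'line': 'second one', 'length': 4}]
```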
@@ -118,7 +116,6 @@ def generate_chunks(tokenizer, lines):
chunks = []
datas = generate_datas(tokenizer, lines)

longest = 0
chunk = ""
chunk_length = 0
for data in datas:
@@ -129,41 +126,56 @@
if "�" in line:
continue

# Skip the line if it alone already exceeds the maximum length
if length > MAX_LENGTH - 3:
continue
if chunk_length + length >= LENGTH_THRESHOLD - 3:
chunk = re.sub(r" +", " ", chunk + " " + line)
chunks.append(chunk)

if chunk_length + length > MAX_LENGTH - 3:
longest = max(longest, get_length(tokenizer, re.sub(r" +", " ", chunk.strip())))
chunks.append(re.sub(r" +", " ", chunk.strip()))
chunk = ""
chunk_length = 0
chunk = chunk + " " + line
chunk_length = chunk_length + length + 1
else:
chunk = chunk + " " + line
chunk_length = chunk_length + 1 + length - 2 # the joining space may or may not become a token, so add 1 to be safe, then subtract the two special tokens at the start and end

if chunk.strip() != "":
longest = max(longest, get_length(tokenizer, re.sub(r" +", " ", chunk.strip())))
chunks.append(re.sub(r" +", " ", chunk.strip()))
chunk = re.sub(r" +", " ", chunk)
chunks.append(chunk)

return chunks
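For intuition about the new packing rule (illustrative, not from the commit): with LENGTH_THRESHOLD = 256 each chunk has a budget of 253 tokens, and every line folded into a chunk contributes its tokenized length plus one joining space minus the two special tokens.

```python
# Hypothetical walk-through of the packing budget; the line lengths are made up.
LENGTH_THRESHOLD = 256
budget = LENGTH_THRESHOLD - 3       # 253
line_lengths = [40, 60, 80, 90]     # tokenized lengths, each including BOS/EOS

chunk_length = 0
for length in line_lengths:
    if chunk_length + length >= budget:
        print(f"flush chunk at ~{chunk_length + length} tokens")  # the triggering line is included
        chunk_length = 0
    else:
        chunk_length = chunk_length + 1 + length - 2  # joining space, minus BOS/EOS
print(f"last open chunk holds about {chunk_length} tokens")
```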

# Map function
def map_function(tokenizer, samples):
encodings = tokenizer(
samples["line"],
padding = "max_length",
truncation = True,
max_length = LENGTH_THRESHOLD,
return_attention_mask = True,
return_offsets_mapping = True,
return_special_tokens_mask = True,
)

return chunks, longest
# Count the effective (non-padding) tokens
encodings["input_length"] = [sum(item) for item in encodings.attention_mask]

return encodings
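A small illustration (not from the commit) of the input_length column added above: summing the attention mask counts the non-padding tokens of a padded sequence.

```python
# Illustrative only: 4 real tokens followed by 4 padding positions.
attention_mask = [1, 1, 1, 1, 0, 0, 0, 0]
input_length = sum(attention_mask)
print(input_length)  # 4
```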

# Load the datasets
def load_dataset(tokenizer):
print(f"")
print(f"正在加载数据集 ...")

print(f"")
longest = 0
count = 0
datas = []
count = 0
for path, num in DATASET_PATH:
datas_by_type = []
dir_path, dir_name = os.path.split(path)

if os.path.exists(f"{path}_{MODEL_NAME}.txt"):
if os.path.exists(f"{dir_path}/{MODEL_NAME}_{dir_name}.txt"):
count = count + 1
with open(f"{path}_{MODEL_NAME}.txt", "r", encoding = "utf-8") as file:
datas_by_type = [line.strip() for line in file]
with open(f"{dir_path}/{MODEL_NAME}_{dir_name}.txt", "r", encoding = "utf-8") as file:
datas_by_type = [line.strip() for line in tqdm(file, desc = path, total = num)]
random.shuffle(datas_by_type)
else:
total = len([entry for entry in os.listdir(path) if os.path.isfile(os.path.join(path, entry))])

@@ -174,45 +186,42 @@ def load_dataset(tokenizer):
count = count + 1
lines.extend([line.strip() for line in file if line.strip() != ""])

lines = [lines[i:(i + 64 * 1024)] for i in range(0, len(lines), 64 * 1024)]
lines = split(lines, 32 * 1024)
results = Parallel(n_jobs = -1, prefer = "processes", return_as = "generator_unordered")(
delayed(generate_chunks)(tokenizer, v) for v in lines
)

for v in tqdm(results, desc = path, total = len(lines)):
datas_by_type.extend(v[0])
longest = max(longest, v[1])
datas_by_type.extend(v)

datas_by_type = random.sample(datas_by_type, min(int(num), len(datas_by_type)))
with open(f"{path}_{MODEL_NAME}.txt", "w", encoding = "utf-8") as file:
with open(f"{dir_path}/{MODEL_NAME}_{dir_name}.txt", "w", encoding = "utf-8") as file:
file.writelines([f"{line}\n" for line in datas_by_type])

datas.extend(datas_by_type)
print(f"")
print(f"最大长度为 {longest} ...") if longest > 0 else None
print(f"找到数据文件 {count} 个,共 {len(datas)} 条数据 ...")

# Build the tokenized dataset
print(f"")
os.makedirs("cache", exist_ok = True)
dataset_train = Dataset.from_dict({"line": datas})
dataset_train_tokenized = dataset_train.map(
lambda samples: tokenizer(
samples["line"],
padding = "max_length",
truncation = True,
max_length = MAX_LENGTH,
return_attention_mask = True,
return_offsets_mapping = True,
return_special_tokens_mask = True,
),
lambda samples: map_function(tokenizer, samples),
num_proc = 1,
batched = True,
batch_size = 1024,
writer_batch_size = 4096,
batch_size = 512,
writer_batch_size = 8 * 1024,
remove_columns = ["line"],
cache_file_name = f"cache/{MODEL_NAME}.cache",
load_from_cache_file = True,
)

# Compute the total number of effective tokens
total_length = sum(dataset_train_tokenized["input_length"])

print(f"")
print(
f"找到数据文件 {count} 个,数据条目 {len(datas)} 个," +
f"有效 Token {(total_length / 1000 / 1000):.2f} M,平均每个条目 {(total_length / len(datas)):.2f} Token ..."
)
print(f"")

return dataset_train_tokenized
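A reminder of the batched Dataset.map contract used above (illustrative, not part of the commit): with batched = True, map_function receives a dict of column lists and returns new columns of the same length.

```python
# Hypothetical shape of one batch handed to map_function when batched = True.
samples = {"line": ["first packed chunk ...", "second packed chunk ..."]}

# map_function(tokenizer, samples) then returns parallel columns such as
# "input_ids", "attention_mask", "special_tokens_mask" and the new "input_length";
# remove_columns = ["line"] drops the raw text from the cached dataset.
print(len(samples["line"]))  # 2
```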
