[application] add lora sft example #6192

Merged
merged 5 commits on Feb 18, 2025
3 changes: 1 addition & 2 deletions .github/workflows/run_chatgpt_examples.yml
@@ -31,13 +31,12 @@ jobs:

- name: Install Colossal-AI
run: |
-        BUILD_EXT=1 pip install --no-cache-dir -v -e .
+        pip install --no-cache-dir -v -e .

- name: Install ChatGPT
run: |
cd applications/ColossalChat
pip install --no-cache-dir -v .
-        export BUILD_EXT=1
pip install --no-cache-dir -r examples/requirements.txt

- name: Install Transformers
34 changes: 33 additions & 1 deletion applications/ColossalChat/README.md
@@ -29,6 +29,7 @@
- [Alternative Option For RLHF: KTO](#alternative-option-for-rlhf-kahneman-tversky-optimization-kto)
- [O1 Journey](#o1-journey)
- [Inference with Self-refined MCTS](#inference-with-self-refined-mcts)
- [SFT for DeepSeek V3/R1](#sft-for-deepseek-v3)
- [FAQ](#faq)
- [How to save/load checkpoint](#faq)
- [How to train with limited resources](#faq)
@@ -389,6 +390,37 @@ You can find more examples in this [repo](https://github.com/XueFuzhao/InstructionWild).
- Cannot deviate from OpenAI's policy: prompts generated via the OpenAI API always abide by its policy, so no violation cases appear in the datasets.
</details>

## SFT for DeepSeek V3

We provide a script for supervised fine-tuning (SFT) of the DeepSeek V3/R1 model with LoRA. The script is located at `examples/training_scripts/lora_finetune.py`. It is similar to the SFT script for Coati7B, but differs in a few details, and is compatible with PEFT.

### Dataset preparation

This script takes a JSONL file as its input dataset; each line of the dataset should be a list of chat messages. For example:
```json
[{"role": "user", "content": "Hello, how are you?"}, {"role": "assistant", "content": "I'm doing great. How can I help you today?"}]
```
```json
[{"role": "user", "content": "火烧赤壁 曹操为何不拨打119求救?"}, {"role": "assistant", "content": "因为在三国时期,还没有电话和现代的消防系统,所以曹操无法拨打119求救。"}]
```

Dialogues can span multiple turns and can include a system prompt. For more details, see the [chat templating documentation](https://huggingface.co/docs/transformers/main/chat_templating).
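
As a quick way to produce such a file, the minimal Python sketch below writes one JSON-encoded conversation per line (the file name `data.jsonl` and the dialogues themselves are placeholders):
```python
import json

# Each element is one training sample: a complete chat, possibly multi-turn,
# optionally starting with a system message.
conversations = [
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
    ],
]

# JSONL layout: one JSON document per line, which is what the SFT script expects.
with open("data.jsonl", "w", encoding="utf-8") as f:
    for conv in conversations:
        f.write(json.dumps(conv, ensure_ascii=False) + "\n")
```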

### Model weights preparation

We use bf16 weights for fine-tuning. If you downloaded fp8 DeepSeek V3/R1 weights, you can use this [script](https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/fp8_cast_bf16.py) to convert them to bf16 on GPU. For Ascend NPU, you can use this [script](https://gitee.com/ascend/ModelZoo-PyTorch/blob/master/MindIE/LLM/DeepSeek/DeepSeek-V2/NPU_inference/fp8_cast_bf16.py).
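
For reference, a typical invocation of the GPU conversion script looks like the sketch below; the input/output paths are placeholders, and you should verify the flag names against the version of the script you download:
```bash
git clone https://github.com/deepseek-ai/DeepSeek-V3.git
cd DeepSeek-V3/inference
# Convert the downloaded fp8 checkpoint to bf16 (requires a CUDA GPU).
python fp8_cast_bf16.py \
    --input-fp8-hf-path /path/to/DeepSeek-R1 \
    --output-bf16-hf-path /path/to/DeepSeek-R1-bf16
```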

### Usage

After preparing the dataset and model weights, you can run the script with the following command:
```bash
colossalai run --hostfile path-to-host-file --nproc_per_node 8 lora_finetune.py --pretrained path-to-DeepSeek-R1-bf16 --dataset path-to-dataset.jsonl --plugin moe --lr 2e-5 --max_length 256 -g --ep 8 --pp 3 --batch_size 24 --lora_rank 8 --lora_alpha 16 --num_epochs 2 --warmup_steps 8 --tensorboard_dir logs --save_dir DeepSeek-R1-bf16-lora
```

For more details on each argument, you can run `python lora_finetune.py --help`.

The sample command does not use CPU offload, for better throughput. The minimum hardware requirement for the sample command is 32 Ascend 910B NPUs (with `ep=8,pp=4`) or 24 H100/H800 GPUs (with `ep=8,pp=3`). If you enable CPU offload with `--zero_cpu_offload`, the hardware requirement can be reduced further.

## FAQ

<details><summary><b>How to save/load checkpoint</b></summary>
@@ -501,7 +533,7 @@ Thanks so much to all of our amazing contributors!
- Keep in a sufficiently high running speed

| Model Pair | Alpaca-7B ⚔ Coati-7B | Coati-7B ⚔ Alpaca-7B |
-| :-----------: | :------------------: | :------------------: |
+|:-------------:|:--------------------:|:--------------------:|
| Better Cases | 38 ⚔ **41** | **45** ⚔ 33 |
| Win Rate | 48% ⚔ **52%** | **58%** ⚔ 42% |
| Average Score | 7.06 ⚔ **7.13** | **7.31** ⚔ 6.82 |
75 changes: 75 additions & 0 deletions applications/ColossalChat/coati/dataset/loader.py
@@ -8,6 +8,7 @@
from dataclasses import dataclass
from typing import Dict, Iterator, List, Optional, Sequence, Union

import jsonlines
import torch
import torch.nn.functional as F
from coati.dataset.utils import chuncate_sequence, pad_to_max_len
@@ -344,3 +345,77 @@ def __len__(self) -> int:

def set_start_index(self, start_index: int) -> None:
self.start_index = start_index


def apply_chat_template_and_mask(
tokenizer: PreTrainedTokenizer,
chat: List[Dict[str, str]],
max_length: Optional[int] = None,
padding: bool = True,
truncation: bool = True,
ignore_idx: int = -100,
) -> Dict[str, torch.Tensor]:
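    """Tokenize a chat (a list of messages) with the tokenizer's chat template and
    build labels that supervise only the assistant turns: every non-assistant token,
    including padding, is labeled `ignore_idx` so it is excluded from the loss.
    """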
tokens = []
assistant_mask = []
for i, msg in enumerate(chat):
msg_tokens = tokenizer.apply_chat_template([msg], tokenize=True)
# remove unexpected bos token
if i > 0 and msg_tokens[0] == tokenizer.bos_token_id:
msg_tokens = msg_tokens[1:]
tokens.extend(msg_tokens)
if msg["role"] == "assistant":
assistant_mask.extend([True] * len(msg_tokens))
else:
assistant_mask.extend([False] * len(msg_tokens))
attention_mask = [1] * len(tokens)
if max_length is not None:
if padding and len(tokens) < max_length:
to_pad = max_length - len(tokens)
if tokenizer.padding_side == "right":
tokens.extend([tokenizer.pad_token_id] * to_pad)
assistant_mask.extend([False] * to_pad)
attention_mask.extend([0] * to_pad)
else:
tokens = [tokenizer.pad_token_id] * to_pad + tokens
assistant_mask = [False] * to_pad + assistant_mask
attention_mask = [0] * to_pad + attention_mask
if truncation and len(tokens) > max_length:
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
attention_mask = attention_mask[:max_length]
input_ids = torch.tensor(tokens, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
labels = input_ids.clone()
labels[~torch.tensor(assistant_mask, dtype=torch.bool)] = ignore_idx

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}


class RawConversationDataset(Dataset):
"""
Raw conversation dataset.
    Each instance is a list of chat messages, where each message is a dict with `role` and `content` fields.
"""

def __init__(self, tokenizer: PreTrainedTokenizer, input_file: str, max_length: int) -> None:
self.tokenizer = tokenizer
self.raw_texts = []
with jsonlines.open(input_file) as f:
for line in f:
self.raw_texts.append(line)
self.tokenized_texts = [None] * len(self.raw_texts)
self.max_length = max_length

def __len__(self) -> int:
return len(self.raw_texts)

def __getitem__(self, index: int):
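        # Tokenize lazily on first access and cache the result for subsequent epochs.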
if self.tokenized_texts[index] is None:
message = self.raw_texts[index]
tokens = apply_chat_template_and_mask(self.tokenizer, message, self.max_length)
self.tokenized_texts[index] = dict(tokens)
return self.tokenized_texts[index]