diff --git a/tutorials/llm/llama-3/README.rst b/tutorials/llm/llama/README.rst similarity index 100% rename from tutorials/llm/llama-3/README.rst rename to tutorials/llm/llama/README.rst diff --git a/tutorials/llm/llama-3/biomedical-qa/README.rst b/tutorials/llm/llama/biomedical-qa/README.rst similarity index 100% rename from tutorials/llm/llama-3/biomedical-qa/README.rst rename to tutorials/llm/llama/biomedical-qa/README.rst diff --git a/tutorials/llm/llama-3/biomedical-qa/img/e2e-lora-train-and-deploy.png b/tutorials/llm/llama/biomedical-qa/img/e2e-lora-train-and-deploy.png similarity index 100% rename from tutorials/llm/llama-3/biomedical-qa/img/e2e-lora-train-and-deploy.png rename to tutorials/llm/llama/biomedical-qa/img/e2e-lora-train-and-deploy.png diff --git a/tutorials/llm/llama-3/biomedical-qa/llama3-lora-deploy-nim.ipynb b/tutorials/llm/llama/biomedical-qa/llama3-lora-deploy-nim.ipynb similarity index 100% rename from tutorials/llm/llama-3/biomedical-qa/llama3-lora-deploy-nim.ipynb rename to tutorials/llm/llama/biomedical-qa/llama3-lora-deploy-nim.ipynb diff --git a/tutorials/llm/llama-3/biomedical-qa/llama3-lora-nemofw.ipynb b/tutorials/llm/llama/biomedical-qa/llama3-lora-nemofw.ipynb similarity index 100% rename from tutorials/llm/llama-3/biomedical-qa/llama3-lora-nemofw.ipynb rename to tutorials/llm/llama/biomedical-qa/llama3-lora-nemofw.ipynb diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/.gitignore b/tutorials/llm/llama/domain-adaptive-pretraining/.gitignore new file mode 100644 index 000000000000..ca5fecf18b75 --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/.gitignore @@ -0,0 +1,8 @@ +/code/general_data/* +/code/data/* +/code/models/* +/code/nemo_experiments/ +./preprocessed_data_text_document.bin +./preprocessed_data_text_document.idx +./llama2_7b.py +./test_convert.py \ No newline at end of file diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/README.md b/tutorials/llm/llama/domain-adaptive-pretraining/README.md new file mode 100755 index 000000000000..2cd24d5ab712 --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/README.md @@ -0,0 +1,35 @@ +# ChipNeMo - Custom tokenization + Domain Adaptive Pre-training on Llama 2 7b with NeMo Framework + +[ChipNeMo](https://arxiv.org/pdf/2311.00176) is a chip design domain-adapted Large Language Model (LLM). Instead of directly deploying off-the-shelf commercial or open-source LLMs, the paper adopts the following domain adaptation techniques: domain-adaptive tokenization, domain-adaptive continued pre-training, model alignment with domain-specific instructions, and domain-adapted retrieval models. Specifically, Llama 2 foundation models are continually pre-trained with more than 20 billion tokens on domain-specific chip design data, including code and documents. They are then fine-tuned with instruction datasets from design data as well as external sources. Evaluations on the resultant domain-adapted ChipNeMo model demonstrate that domain-adaptive pre-training of language models can lead to superior performance in domain-related downstream tasks compared to their base Llama 2 counterparts, without degradations in generic capabilities. + +Here, we share a tutorial with best practices on custom tokenization and DAPT (Domain-Adaptive Pre-Training) for a ChipNeMo-like code generation use case. + +## Requirements + +### Software Requirements +* Access to latest NeMo Framework NGC Containers +* This playbook has been tested on: nvcr.io/nvidia/nemo:24.07. 
It is expected to work similarly on other environments. + +### Hardware Requirements +* This playbook can run on CPUs or GPUs. For GPUs, this playbook has been tested on a minimum of 2xA100 80G + +### Data Curation + +* In this tutorial, we will leverage chip domain/hardware datasets from open-source GitHub repositories, wiki URLs, and academic papers. Therefore, as a prerequisite, the user should curate the domain-specific and general-purpose data using NeMo Curator and place it in the directories mentioned below. + +* `./code/data` should contain curated data from the chip domain after processing with NeMo Curator. The playbook for DAPT data curation can be found [here](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/dapt-curation) + +* `./code/general_data` should contain open-source general-purpose data that Llama 2 was trained on. This data will help identify token/vocabulary differences between general-purpose and domain-specific datasets. Data can be downloaded from [Wikipedia](https://huggingface.co/datasets/legacy-datasets/wikipedia), [CommonCrawl](https://data.commoncrawl.org/), etc., and curated with [NeMo Curator](https://github.com/NVIDIA/NeMo-Curator/tree/main/tutorials/single_node_tutorial) + + +## Custom Tokenization for DAPT + +After placing the curated data in the directories mentioned above, we can proceed with custom tokenization and DAPT. + +* `./code/custom_tokenization.ipynb` walks through the custom tokenization workflow required for DAPT + +## Pretraining for DAPT + +Once we have the domain-adapted custom tokenizer from above, we can proceed with pretraining using the custom tokenizer. + +* `./code/domain_adaptive_pretraining.ipynb` walks through the pretraining workflow required for DAPT diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/custom_tokenization.ipynb b/tutorials/llm/llama/domain-adaptive-pretraining/code/custom_tokenization.ipynb new file mode 100644 index 000000000000..f4f0547c59a2 --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/code/custom_tokenization.ipynb @@ -0,0 +1,1920 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6196c6f2-f088-4c28-b9a8-e921f9a7465d", + "metadata": {}, + "source": [ + "# Custom Tokenization for DAPT (Domain Adaptive Pre-Training)" + ] + }, + { + "cell_type": "markdown", + "id": "dbd33f30-2a18-480f-a7f1-210ac99b937c", + "metadata": {}, + "source": [ + "This notebook walks through the custom tokenization workflow required for DAPT (Domain Adaptive Pre-training), as shown in the schematic diagram below.\n", + "\n", + "![pipeline](imgs/tokenization_diagram.png)" + ] + }, + { + "cell_type": "markdown", + "id": "56f509da", + "metadata": {}, + "source": [ + "### Custom Tokenization Workflow" + ] + }, + { + "cell_type": "markdown", + "id": "12fe579a", + "metadata": {}, + "source": [ + "#### Goal\n", + "Given a pre-trained tokenizer trained on general-purpose datasets (the Original Tokenizer), our goal is to adapt it to a given domain (in this notebook, the example domain is chip design).\n", + "\n", + "When adapting a pre-trained tokenizer to a given domain, the main goals are to improve tokenization efficiency on domain-specific data, maintain efficiency and language model performance on general-purpose datasets, and minimize the effort for retraining/fine-tuning.
Since we don't have access to the entire general-purpose data used for pretraining the original tokenizer, we want to preserve the existing token mappings, and any new tokens that are added should be strictly an \"extension\". \n", + "\n", + "Generally, when adapting a tokenizer to domain-specific data, the goal is to create a tokenizer that is better suited to the vocabulary and structure of that specific domain. This can improve the efficiency and performance of the model on tasks within that domain through an efficient representation of domain-specific information.\n", + "\n", + "#### Approach\n", + "The general approach we adopt is to train a Domain Specific Tokenizer from scratch on domain data and use it to identify domain-specific tokens that are missing from the original tokenizer. This is done by simply comparing the vocabularies of the Original Tokenizer and the newly trained Domain Specific Tokenizer. The missing domain-specific tokens are then added to the original tokenizer, extending it to get the final Domain Adapted Tokenizer.\n", + "\n", + "#### Tradeoff\n", + "However, there is a tradeoff to adding missing domain-specific tokens to the Original Tokenizer. The challenge is to balance tokenization efficiency on domain data against the disturbance to performance on general-purpose data that results from adding domain-specific tokens to the Original Tokenizer.\n", + "\n", + "For instance, adding a large number of domain-specific tokens can lead to higher efficiency on domain-specific data, but the DAPT process might take longer since the loss would take longer to converge due to the disturbance of efficiency/performance on the general-purpose data.\n", + "\n", + "On the other hand, adding only a small number of domain-specific tokens maintains efficiency on general-purpose data, but may lack coverage of the domain-specific dataset.\n", + "\n", + "#### Balancing the Tradeoff\n", + "To balance this tradeoff, instead of adding all identified missing domain-specific tokens to the original tokenizer, we identify the most frequently occurring tokens using a threshold and only add the ones with usage frequencies above the given threshold to get the final Domain Adapted Tokenizer.\n", + "\n", + "To identify the most frequently used tokens, we first extend the Original Tokenizer by adding all identified missing domain-specific tokens to get an Extended Tokenizer. The Extended Tokenizer is then applied to the domain-specific data in order to identify high-frequency tokens. Thus, the Extended Tokenizer is just an intermediate step towards building the Domain Adapted Tokenizer.\n", + "\n", + "Finally, the Original Tokenizer is extended using only the high-frequency tokens to get the final Domain Adapted Tokenizer." + ] + },
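+ { + "cell_type": "markdown", + "id": "3c7e1f2a", + "metadata": {}, + "source": [ + "To make the approach above concrete, the next cell is a minimal, illustrative sketch of the vocabulary comparison and frequency-based token selection described here. It is not the implementation used in this notebook: the tokenizer names, paths, and tiny placeholder corpus are hypothetical, and the helper modules imported in Step 0 (`tokenization_helper`, `extend_tokenizer_utils`, `get_high_freq_tokens`) implement the full workflow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b4d2e7c", + "metadata": {}, + "outputs": [], + "source": [ + "# Illustrative sketch only; names and paths below are hypothetical placeholders.\n", + "from collections import Counter\n", + "from transformers import AutoTokenizer\n", + "\n", + "# General-purpose (Original) tokenizer and a domain tokenizer trained from scratch on domain data\n", + "original_tok = AutoTokenizer.from_pretrained(\"meta-llama/Llama-2-7b-hf\")\n", + "domain_tok = AutoTokenizer.from_pretrained(\"./models/tokenizer/domain_tokenizer\")\n", + "\n", + "# 1. Candidate new tokens: domain vocabulary entries missing from the original vocabulary\n", + "missing_tokens = set(domain_tok.get_vocab()) - set(original_tok.get_vocab())\n", + "\n", + "# 2. Count how often the candidate tokens are used on domain text. The full workflow applies\n", + "#    an Extended Tokenizer here; counting with the domain tokenizer is a simplification.\n", + "domain_docs = [\"...curated domain-specific text...\"]  # placeholder corpus\n", + "counts = Counter(\n", + "    tok for doc in domain_docs for tok in domain_tok.tokenize(doc) if tok in missing_tokens\n", + ")\n", + "\n", + "# 3. Keep the most frequent candidates until ~98% of new-token usage is covered\n", + "total = sum(counts.values())\n", + "selected, covered = [], 0\n", + "for tok, freq in counts.most_common():\n", + "    selected.append(tok)\n", + "    covered += freq\n", + "    if covered / total >= 0.98:\n", + "        break\n", + "print(f\"Selected {len(selected)} of {len(missing_tokens)} candidate domain tokens\")" + ] + },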
+ { + "cell_type": "markdown", + "id": "d0b69b3b-aa66-42c9-b76f-2fde7f29a4b0", + "metadata": {}, + "source": [ + "## Notebook Outline\n", + "\n", + "To achieve the process described above, we’ve developed a step-by-step approach that this notebook will walk you through:\n", + "\n", + "- Step 0: Install prerequisites and import the required modules\n", + "- Step 1: Download the llama-2-70b embedding model and tokenizer (Original Tokenizer). Convert the original weights to a trainable format and save them.\n", + "- Step 2: Train an opt-350m tokenizer from scratch using domain-specific data to get a Domain Specific Tokenizer.\n", + "- Step 3: From the vocabulary of the newly trained tokenizer, identify tokens that are absent in the general-purpose tokenizer and are rarely found in general-purpose datasets. Next, expand the general-purpose tokenizer with the newly identified tokens to get an Extended Tokenizer.\n", + "- Step 4: Apply the Extended Tokenizer to the domain-specific dataset, analyze the usage frequencies of the newly added tokens, and select the top-K tokens in a way that their cumulative frequency accounts for approximately 98% (a hyper-parameter) of the total frequency of the new tokens.\n", + "- Step 5: Initialize the embeddings of the new tokens by utilizing the general-purpose tokenizer, i.e., the Original Tokenizer. When a new token is encountered, it is tokenized using the pretrained general-purpose tokenizer. The embedding and output layer weights corresponding to the new token are determined by averaging the embeddings/weights corresponding to the tokens generated by the general-purpose tokenizer.\n", + "- Step 6: Merge the new embeddings with the original embedding table (in llama-2-70b) to get the final Domain Adapted Tokenizer.\n\n", + "## Data\n", + "\n", + "In this playbook, we will leverage chip domain/hardware datasets from open-source GitHub repositories, wiki URLs, and academic papers. Data has been processed and curated using [NeMo Curator](https://github.com/NVIDIA/NeMo-Curator/tree/main) as shown in this [playbook](https://github.com/jvamaraju/ndc_dapt_playbook/tree/dapt_jv)" + ] + }, + { + "cell_type": "markdown", + "id": "fbee82dd", + "metadata": {}, + "source": [ + "## NeMo Tools and Resources\n", + "\n", + "* [NVIDIA NeMo Framework](https://github.com/NVIDIA/NeMo)" + ] + }, + { + "cell_type": "markdown", + "id": "74be8ece", + "metadata": {}, + "source": [ + "## Software Requirements\n", + "* Access to the latest NeMo Framework NGC Containers\n", + "* This playbook has been tested on: nvcr.io/nvidia/nemo:24.07. It is expected to work similarly on other environments. " + ] + }, + { + "cell_type": "markdown", + "id": "d2b5ad09", + "metadata": {}, + "source": [ + "## Hardware Requirements\n", + "* This playbook can run on CPUs or GPUs. For GPUs, this playbook has been tested on a minimum of 1xA100 80G" + ] + }, + { + "cell_type": "markdown", + "id": "80bae538-308f-4d1b-8186-69de6226f3cd", + "metadata": {}, + "source": [ + "## Step 0: Install the prerequisites and import the required modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc83794a-daf9-44fb-9b89-f8cde05101a4", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "! pip install datasets sentencepiece jsonlines tokenizers transformers torch ftfy matplotlib\n", + "! pip install protobuf==3.20.1\n", + "!
pip install --upgrade jupyter ipywidgets widgetsnbextension pandas-profiling" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a91cd358", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import torch\n", + "from datasets import Dataset\n", + "from datasets import IterableDataset\n", + "from datasets import load_dataset\n", + "import jsonlines\n", + "import glob\n", + "from tokenizers import (\n", + " decoders,\n", + " models,\n", + " normalizers,\n", + " pre_tokenizers,\n", + " processors,\n", + " trainers,\n", + " Tokenizer,\n", + ")\n", + "from transformers import AutoTokenizer\n", + "from tokenization_helper import *\n", + "from extend_tokenizer_utils import extend_tokenizer, extend_tokenizer_high_freq_tokens\n", + "from get_high_freq_tokens import get_high_freq_tokens\n", + "from util import load_weights, merge_embed" + ] + }, + { + "cell_type": "markdown", + "id": "24d00d02", + "metadata": {}, + "source": [ + "## Step 1: Download the llama-2-70b embedding model and tokenizer (Original Tokenizer). Convert the original weights to a trainable format and save them.\n", + "\n", + "The Original Tokenizer used here is the Llama 2 tokenizer, which is a Byte Pair Encoding (BPE) model based on SentencePiece.\n", + "\n", + "Here we first log in to Hugging Face before downloading the model, since the model is in a gated repo." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "38cf0264", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Install the Hugging Face CLI\n", + "! pip install -U \"huggingface_hub[cli]\"\n", + "# Generate a user access token at https://huggingface.co/settings/tokens\n", + "\n", + "# To download the model, please log in via huggingface-cli login since it is a gated repo\n", + "! huggingface-cli login\n", + "# You will be prompted to enter your User Access Token. Copy and paste the token, then press Enter. The CLI will verify the token and save it locally." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d54f21e5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Create a directory for storing the downloaded Hugging Face model\n", + "os.makedirs(\"models/weight/llama2-hf\", exist_ok=True)\n", + "\n", + "# Create directories for storing the model weights\n", + "os.makedirs(\"models/weight/llama2/ori_llama2-hf_weight\", exist_ok=True)\n", + "os.makedirs(\"models/weight/llama2/new_llama2-hf_weight\", exist_ok=True)\n", + "\n", + "# Create directories for storing the tokenizers\n", + "os.makedirs(\"models/tokenizer/llama2/original_tokenizer\", exist_ok=True)\n", + "os.makedirs(\"models/tokenizer/llama2/new_tokenizer\", exist_ok=True)" + ] + }, + { + "cell_type": "markdown", + "id": "558753ac", + "metadata": {}, + "source": [ + "Before running the next step, make sure you have been granted access to Meta's gated Llama 2 models group. You can fill out the form available at https://huggingface.co/meta-llama/Llama-2-7b to request access. (Access is typically granted in ~20 minutes.)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7f0c3988", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "# Download llama2-70b model weights and tokenizer\n", + "! huggingface-cli download meta-llama/Llama-2-70b --local-dir ./models/weight/llama2-hf/\n", + "\n", + "# Copy the original tokenizer to a different folder\n", + "!
cp ./models/weight/llama2-hf/tokenizer.model ./models/tokenizer/llama2/original_tokenizer\n", + "\n", + "# Load embedding and output layer weights (size = (vocab size,embedding dim)) from each snapshot and create a dict\n", + "load_path = \"./models/weight/llama2-hf\"\n", + "save_path = './models/weight/llama2/ori_llama2-hf_weight'\n", + "\n", + "if not os.path.exists(save_path):\n", + " os.makedirs(save_path)\n", + " \n", + "#load weight and store in a dictionary suitable for NeMo\n", + "load_weights(load_path, save_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6ade9a02-d38f-436c-82d0-bd16b54dbbf8", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Index: 0, layer: tok_embeddings.weight, Layer size: torch.Size([32000, 1024])\n", + "Index: 1, layer: norm.weight, Layer size: torch.Size([8192])\n", + "Index: 2, layer: output.weight, Layer size: torch.Size([4000, 8192])\n", + "Index: 3, layer: layers.0.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 4, layer: layers.0.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 5, layer: layers.0.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 6, layer: layers.0.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 7, layer: layers.0.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 8, layer: layers.0.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 9, layer: layers.0.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 10, layer: layers.0.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 11, layer: layers.0.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 12, layer: layers.1.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 13, layer: layers.1.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 14, layer: layers.1.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 15, layer: layers.1.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 16, layer: layers.1.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 17, layer: layers.1.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 18, layer: layers.1.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 19, layer: layers.1.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 20, layer: layers.1.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 21, layer: layers.2.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 22, layer: layers.2.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 23, layer: layers.2.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 24, layer: layers.2.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 25, layer: layers.2.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 26, layer: layers.2.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 27, layer: layers.2.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 28, layer: layers.2.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 29, layer: layers.2.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 30, layer: layers.3.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + 
"Index: 31, layer: layers.3.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 32, layer: layers.3.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 33, layer: layers.3.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 34, layer: layers.3.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 35, layer: layers.3.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 36, layer: layers.3.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 37, layer: layers.3.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 38, layer: layers.3.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 39, layer: layers.4.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 40, layer: layers.4.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 41, layer: layers.4.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 42, layer: layers.4.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 43, layer: layers.4.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 44, layer: layers.4.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 45, layer: layers.4.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 46, layer: layers.4.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 47, layer: layers.4.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 48, layer: layers.5.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 49, layer: layers.5.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 50, layer: layers.5.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 51, layer: layers.5.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 52, layer: layers.5.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 53, layer: layers.5.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 54, layer: layers.5.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 55, layer: layers.5.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 56, layer: layers.5.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 57, layer: layers.6.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 58, layer: layers.6.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 59, layer: layers.6.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 60, layer: layers.6.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 61, layer: layers.6.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 62, layer: layers.6.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 63, layer: layers.6.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 64, layer: layers.6.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 65, layer: layers.6.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 66, layer: layers.7.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 67, layer: layers.7.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 68, layer: layers.7.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 69, layer: layers.7.attention.wo.weight, Layer size: torch.Size([8192, 
1024])\n", + "Index: 70, layer: layers.7.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 71, layer: layers.7.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 72, layer: layers.7.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 73, layer: layers.7.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 74, layer: layers.7.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 75, layer: layers.8.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 76, layer: layers.8.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 77, layer: layers.8.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 78, layer: layers.8.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 79, layer: layers.8.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 80, layer: layers.8.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 81, layer: layers.8.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 82, layer: layers.8.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 83, layer: layers.8.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 84, layer: layers.9.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 85, layer: layers.9.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 86, layer: layers.9.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 87, layer: layers.9.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 88, layer: layers.9.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 89, layer: layers.9.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 90, layer: layers.9.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 91, layer: layers.9.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 92, layer: layers.9.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 93, layer: layers.10.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 94, layer: layers.10.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 95, layer: layers.10.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 96, layer: layers.10.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 97, layer: layers.10.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 98, layer: layers.10.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 99, layer: layers.10.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 100, layer: layers.10.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 101, layer: layers.10.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 102, layer: layers.11.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 103, layer: layers.11.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 104, layer: layers.11.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 105, layer: layers.11.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 106, layer: layers.11.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 107, layer: layers.11.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 108, layer: 
layers.11.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 109, layer: layers.11.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 110, layer: layers.11.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 111, layer: layers.12.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 112, layer: layers.12.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 113, layer: layers.12.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 114, layer: layers.12.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 115, layer: layers.12.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 116, layer: layers.12.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 117, layer: layers.12.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 118, layer: layers.12.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 119, layer: layers.12.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 120, layer: layers.13.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 121, layer: layers.13.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 122, layer: layers.13.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 123, layer: layers.13.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 124, layer: layers.13.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 125, layer: layers.13.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 126, layer: layers.13.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 127, layer: layers.13.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 128, layer: layers.13.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 129, layer: layers.14.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 130, layer: layers.14.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 131, layer: layers.14.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 132, layer: layers.14.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 133, layer: layers.14.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 134, layer: layers.14.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 135, layer: layers.14.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 136, layer: layers.14.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 137, layer: layers.14.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 138, layer: layers.15.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 139, layer: layers.15.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 140, layer: layers.15.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 141, layer: layers.15.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 142, layer: layers.15.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 143, layer: layers.15.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 144, layer: layers.15.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 145, layer: layers.15.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 146, layer: 
layers.15.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 147, layer: layers.16.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 148, layer: layers.16.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 149, layer: layers.16.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 150, layer: layers.16.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 151, layer: layers.16.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 152, layer: layers.16.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 153, layer: layers.16.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 154, layer: layers.16.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 155, layer: layers.16.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 156, layer: layers.17.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 157, layer: layers.17.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 158, layer: layers.17.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 159, layer: layers.17.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 160, layer: layers.17.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 161, layer: layers.17.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 162, layer: layers.17.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 163, layer: layers.17.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 164, layer: layers.17.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 165, layer: layers.18.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 166, layer: layers.18.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 167, layer: layers.18.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 168, layer: layers.18.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 169, layer: layers.18.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 170, layer: layers.18.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 171, layer: layers.18.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 172, layer: layers.18.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 173, layer: layers.18.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 174, layer: layers.19.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 175, layer: layers.19.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 176, layer: layers.19.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 177, layer: layers.19.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 178, layer: layers.19.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 179, layer: layers.19.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 180, layer: layers.19.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 181, layer: layers.19.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 182, layer: layers.19.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 183, layer: layers.20.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 184, layer: 
layers.20.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 185, layer: layers.20.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 186, layer: layers.20.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 187, layer: layers.20.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 188, layer: layers.20.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 189, layer: layers.20.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 190, layer: layers.20.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 191, layer: layers.20.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 192, layer: layers.21.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 193, layer: layers.21.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 194, layer: layers.21.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 195, layer: layers.21.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 196, layer: layers.21.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 197, layer: layers.21.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 198, layer: layers.21.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 199, layer: layers.21.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 200, layer: layers.21.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 201, layer: layers.22.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 202, layer: layers.22.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 203, layer: layers.22.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 204, layer: layers.22.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 205, layer: layers.22.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 206, layer: layers.22.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 207, layer: layers.22.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 208, layer: layers.22.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 209, layer: layers.22.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 210, layer: layers.23.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 211, layer: layers.23.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 212, layer: layers.23.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 213, layer: layers.23.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 214, layer: layers.23.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 215, layer: layers.23.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 216, layer: layers.23.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 217, layer: layers.23.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 218, layer: layers.23.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 219, layer: layers.24.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 220, layer: layers.24.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 221, layer: layers.24.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 222, layer: 
layers.24.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 223, layer: layers.24.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 224, layer: layers.24.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 225, layer: layers.24.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 226, layer: layers.24.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 227, layer: layers.24.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 228, layer: layers.25.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 229, layer: layers.25.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 230, layer: layers.25.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 231, layer: layers.25.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 232, layer: layers.25.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 233, layer: layers.25.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 234, layer: layers.25.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 235, layer: layers.25.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 236, layer: layers.25.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 237, layer: layers.26.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 238, layer: layers.26.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 239, layer: layers.26.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 240, layer: layers.26.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 241, layer: layers.26.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 242, layer: layers.26.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 243, layer: layers.26.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 244, layer: layers.26.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 245, layer: layers.26.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 246, layer: layers.27.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 247, layer: layers.27.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 248, layer: layers.27.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 249, layer: layers.27.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 250, layer: layers.27.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 251, layer: layers.27.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 252, layer: layers.27.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 253, layer: layers.27.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 254, layer: layers.27.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 255, layer: layers.28.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 256, layer: layers.28.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 257, layer: layers.28.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 258, layer: layers.28.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 259, layer: layers.28.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 260, layer: 
layers.28.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 261, layer: layers.28.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 262, layer: layers.28.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 263, layer: layers.28.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 264, layer: layers.29.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 265, layer: layers.29.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 266, layer: layers.29.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 267, layer: layers.29.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 268, layer: layers.29.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 269, layer: layers.29.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 270, layer: layers.29.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 271, layer: layers.29.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 272, layer: layers.29.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 273, layer: layers.30.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 274, layer: layers.30.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 275, layer: layers.30.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 276, layer: layers.30.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 277, layer: layers.30.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 278, layer: layers.30.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 279, layer: layers.30.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 280, layer: layers.30.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 281, layer: layers.30.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 282, layer: layers.31.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 283, layer: layers.31.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 284, layer: layers.31.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 285, layer: layers.31.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 286, layer: layers.31.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 287, layer: layers.31.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 288, layer: layers.31.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 289, layer: layers.31.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 290, layer: layers.31.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 291, layer: layers.32.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 292, layer: layers.32.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 293, layer: layers.32.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 294, layer: layers.32.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 295, layer: layers.32.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 296, layer: layers.32.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 297, layer: layers.32.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 298, layer: 
layers.32.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 299, layer: layers.32.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 300, layer: layers.33.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 301, layer: layers.33.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 302, layer: layers.33.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 303, layer: layers.33.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 304, layer: layers.33.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 305, layer: layers.33.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 306, layer: layers.33.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 307, layer: layers.33.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 308, layer: layers.33.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 309, layer: layers.34.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 310, layer: layers.34.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 311, layer: layers.34.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 312, layer: layers.34.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 313, layer: layers.34.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 314, layer: layers.34.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 315, layer: layers.34.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 316, layer: layers.34.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 317, layer: layers.34.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 318, layer: layers.35.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 319, layer: layers.35.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 320, layer: layers.35.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 321, layer: layers.35.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 322, layer: layers.35.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 323, layer: layers.35.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 324, layer: layers.35.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 325, layer: layers.35.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 326, layer: layers.35.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 327, layer: layers.36.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 328, layer: layers.36.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 329, layer: layers.36.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 330, layer: layers.36.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 331, layer: layers.36.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 332, layer: layers.36.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 333, layer: layers.36.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 334, layer: layers.36.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 335, layer: layers.36.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 336, layer: 
layers.37.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 337, layer: layers.37.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 338, layer: layers.37.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 339, layer: layers.37.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 340, layer: layers.37.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 341, layer: layers.37.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 342, layer: layers.37.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 343, layer: layers.37.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 344, layer: layers.37.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 345, layer: layers.38.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 346, layer: layers.38.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 347, layer: layers.38.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 348, layer: layers.38.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 349, layer: layers.38.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 350, layer: layers.38.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 351, layer: layers.38.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 352, layer: layers.38.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 353, layer: layers.38.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 354, layer: layers.39.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 355, layer: layers.39.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 356, layer: layers.39.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 357, layer: layers.39.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 358, layer: layers.39.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 359, layer: layers.39.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 360, layer: layers.39.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 361, layer: layers.39.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 362, layer: layers.39.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 363, layer: layers.40.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 364, layer: layers.40.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 365, layer: layers.40.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 366, layer: layers.40.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 367, layer: layers.40.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 368, layer: layers.40.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 369, layer: layers.40.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 370, layer: layers.40.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 371, layer: layers.40.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 372, layer: layers.41.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 373, layer: layers.41.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 374, layer: 
layers.41.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 375, layer: layers.41.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 376, layer: layers.41.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 377, layer: layers.41.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 378, layer: layers.41.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 379, layer: layers.41.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 380, layer: layers.41.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 381, layer: layers.42.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 382, layer: layers.42.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 383, layer: layers.42.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 384, layer: layers.42.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 385, layer: layers.42.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 386, layer: layers.42.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 387, layer: layers.42.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 388, layer: layers.42.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 389, layer: layers.42.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 390, layer: layers.43.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 391, layer: layers.43.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 392, layer: layers.43.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 393, layer: layers.43.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 394, layer: layers.43.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 395, layer: layers.43.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 396, layer: layers.43.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 397, layer: layers.43.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 398, layer: layers.43.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 399, layer: layers.44.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 400, layer: layers.44.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 401, layer: layers.44.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 402, layer: layers.44.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 403, layer: layers.44.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 404, layer: layers.44.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 405, layer: layers.44.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 406, layer: layers.44.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 407, layer: layers.44.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 408, layer: layers.45.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 409, layer: layers.45.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 410, layer: layers.45.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 411, layer: layers.45.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 412, layer: 
layers.45.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 413, layer: layers.45.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 414, layer: layers.45.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 415, layer: layers.45.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 416, layer: layers.45.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 417, layer: layers.46.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 418, layer: layers.46.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 419, layer: layers.46.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 420, layer: layers.46.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 421, layer: layers.46.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 422, layer: layers.46.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 423, layer: layers.46.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 424, layer: layers.46.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 425, layer: layers.46.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 426, layer: layers.47.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 427, layer: layers.47.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 428, layer: layers.47.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 429, layer: layers.47.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 430, layer: layers.47.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 431, layer: layers.47.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 432, layer: layers.47.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 433, layer: layers.47.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 434, layer: layers.47.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 435, layer: layers.48.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 436, layer: layers.48.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 437, layer: layers.48.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 438, layer: layers.48.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 439, layer: layers.48.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 440, layer: layers.48.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 441, layer: layers.48.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 442, layer: layers.48.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 443, layer: layers.48.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 444, layer: layers.49.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 445, layer: layers.49.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 446, layer: layers.49.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 447, layer: layers.49.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 448, layer: layers.49.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 449, layer: layers.49.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 450, layer: 
layers.49.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 451, layer: layers.49.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 452, layer: layers.49.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 453, layer: layers.50.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 454, layer: layers.50.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 455, layer: layers.50.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 456, layer: layers.50.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 457, layer: layers.50.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 458, layer: layers.50.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 459, layer: layers.50.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 460, layer: layers.50.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 461, layer: layers.50.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 462, layer: layers.51.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 463, layer: layers.51.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 464, layer: layers.51.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 465, layer: layers.51.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 466, layer: layers.51.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 467, layer: layers.51.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 468, layer: layers.51.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 469, layer: layers.51.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 470, layer: layers.51.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 471, layer: layers.52.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 472, layer: layers.52.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 473, layer: layers.52.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 474, layer: layers.52.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 475, layer: layers.52.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 476, layer: layers.52.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 477, layer: layers.52.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 478, layer: layers.52.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 479, layer: layers.52.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 480, layer: layers.53.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 481, layer: layers.53.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 482, layer: layers.53.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 483, layer: layers.53.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 484, layer: layers.53.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 485, layer: layers.53.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 486, layer: layers.53.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 487, layer: layers.53.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 488, layer: 
layers.53.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 489, layer: layers.54.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 490, layer: layers.54.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 491, layer: layers.54.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 492, layer: layers.54.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 493, layer: layers.54.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 494, layer: layers.54.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 495, layer: layers.54.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 496, layer: layers.54.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 497, layer: layers.54.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 498, layer: layers.55.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 499, layer: layers.55.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 500, layer: layers.55.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 501, layer: layers.55.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 502, layer: layers.55.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 503, layer: layers.55.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 504, layer: layers.55.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 505, layer: layers.55.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 506, layer: layers.55.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 507, layer: layers.56.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 508, layer: layers.56.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 509, layer: layers.56.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 510, layer: layers.56.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 511, layer: layers.56.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 512, layer: layers.56.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 513, layer: layers.56.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 514, layer: layers.56.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 515, layer: layers.56.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 516, layer: layers.57.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 517, layer: layers.57.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 518, layer: layers.57.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 519, layer: layers.57.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 520, layer: layers.57.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 521, layer: layers.57.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 522, layer: layers.57.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 523, layer: layers.57.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 524, layer: layers.57.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 525, layer: layers.58.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 526, layer: 
layers.58.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 527, layer: layers.58.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 528, layer: layers.58.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 529, layer: layers.58.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 530, layer: layers.58.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 531, layer: layers.58.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 532, layer: layers.58.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 533, layer: layers.58.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 534, layer: layers.59.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 535, layer: layers.59.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 536, layer: layers.59.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 537, layer: layers.59.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 538, layer: layers.59.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 539, layer: layers.59.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 540, layer: layers.59.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 541, layer: layers.59.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 542, layer: layers.59.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 543, layer: layers.60.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 544, layer: layers.60.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 545, layer: layers.60.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 546, layer: layers.60.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 547, layer: layers.60.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 548, layer: layers.60.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 549, layer: layers.60.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 550, layer: layers.60.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 551, layer: layers.60.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 552, layer: layers.61.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 553, layer: layers.61.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 554, layer: layers.61.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 555, layer: layers.61.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 556, layer: layers.61.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 557, layer: layers.61.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 558, layer: layers.61.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 559, layer: layers.61.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 560, layer: layers.61.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 561, layer: layers.62.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 562, layer: layers.62.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 563, layer: layers.62.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 564, layer: 
layers.62.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 565, layer: layers.62.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 566, layer: layers.62.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 567, layer: layers.62.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 568, layer: layers.62.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 569, layer: layers.62.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 570, layer: layers.63.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 571, layer: layers.63.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 572, layer: layers.63.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 573, layer: layers.63.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 574, layer: layers.63.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 575, layer: layers.63.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 576, layer: layers.63.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 577, layer: layers.63.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 578, layer: layers.63.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 579, layer: layers.64.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 580, layer: layers.64.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 581, layer: layers.64.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 582, layer: layers.64.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 583, layer: layers.64.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 584, layer: layers.64.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 585, layer: layers.64.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 586, layer: layers.64.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 587, layer: layers.64.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 588, layer: layers.65.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 589, layer: layers.65.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 590, layer: layers.65.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 591, layer: layers.65.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 592, layer: layers.65.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 593, layer: layers.65.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 594, layer: layers.65.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 595, layer: layers.65.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 596, layer: layers.65.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 597, layer: layers.66.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 598, layer: layers.66.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 599, layer: layers.66.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 600, layer: layers.66.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 601, layer: layers.66.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 602, layer: 
layers.66.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 603, layer: layers.66.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 604, layer: layers.66.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 605, layer: layers.66.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 606, layer: layers.67.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 607, layer: layers.67.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 608, layer: layers.67.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 609, layer: layers.67.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 610, layer: layers.67.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 611, layer: layers.67.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 612, layer: layers.67.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 613, layer: layers.67.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 614, layer: layers.67.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 615, layer: layers.68.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 616, layer: layers.68.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 617, layer: layers.68.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 618, layer: layers.68.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 619, layer: layers.68.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 620, layer: layers.68.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 621, layer: layers.68.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 622, layer: layers.68.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 623, layer: layers.68.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 624, layer: layers.69.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 625, layer: layers.69.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 626, layer: layers.69.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 627, layer: layers.69.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 628, layer: layers.69.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 629, layer: layers.69.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 630, layer: layers.69.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 631, layer: layers.69.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 632, layer: layers.69.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 633, layer: layers.70.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 634, layer: layers.70.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 635, layer: layers.70.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 636, layer: layers.70.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 637, layer: layers.70.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 638, layer: layers.70.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 639, layer: layers.70.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 640, layer: 
layers.70.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 641, layer: layers.70.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 642, layer: layers.71.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 643, layer: layers.71.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 644, layer: layers.71.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 645, layer: layers.71.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 646, layer: layers.71.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 647, layer: layers.71.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 648, layer: layers.71.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 649, layer: layers.71.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 650, layer: layers.71.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 651, layer: layers.72.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 652, layer: layers.72.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 653, layer: layers.72.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 654, layer: layers.72.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 655, layer: layers.72.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 656, layer: layers.72.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 657, layer: layers.72.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 658, layer: layers.72.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 659, layer: layers.72.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 660, layer: layers.73.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 661, layer: layers.73.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 662, layer: layers.73.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 663, layer: layers.73.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 664, layer: layers.73.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 665, layer: layers.73.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 666, layer: layers.73.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 667, layer: layers.73.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 668, layer: layers.73.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 669, layer: layers.74.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 670, layer: layers.74.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 671, layer: layers.74.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 672, layer: layers.74.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 673, layer: layers.74.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 674, layer: layers.74.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 675, layer: layers.74.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 676, layer: layers.74.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 677, layer: layers.74.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 678, layer: 
layers.75.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 679, layer: layers.75.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 680, layer: layers.75.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 681, layer: layers.75.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 682, layer: layers.75.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 683, layer: layers.75.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 684, layer: layers.75.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 685, layer: layers.75.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 686, layer: layers.75.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 687, layer: layers.76.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 688, layer: layers.76.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 689, layer: layers.76.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 690, layer: layers.76.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 691, layer: layers.76.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 692, layer: layers.76.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 693, layer: layers.76.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 694, layer: layers.76.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 695, layer: layers.76.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 696, layer: layers.77.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 697, layer: layers.77.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 698, layer: layers.77.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 699, layer: layers.77.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 700, layer: layers.77.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 701, layer: layers.77.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 702, layer: layers.77.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 703, layer: layers.77.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 704, layer: layers.77.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 705, layer: layers.78.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 706, layer: layers.78.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 707, layer: layers.78.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 708, layer: layers.78.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 709, layer: layers.78.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 710, layer: layers.78.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 711, layer: layers.78.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 712, layer: layers.78.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 713, layer: layers.78.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 714, layer: layers.79.attention.wq.weight, Layer size: torch.Size([1024, 8192])\n", + "Index: 715, layer: layers.79.attention.wk.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 716, layer: 
layers.79.attention.wv.weight, Layer size: torch.Size([128, 8192])\n", + "Index: 717, layer: layers.79.attention.wo.weight, Layer size: torch.Size([8192, 1024])\n", + "Index: 718, layer: layers.79.feed_forward.w1.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 719, layer: layers.79.feed_forward.w2.weight, Layer size: torch.Size([8192, 3584])\n", + "Index: 720, layer: layers.79.feed_forward.w3.weight, Layer size: torch.Size([3584, 8192])\n", + "Index: 721, layer: layers.79.attention_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 722, layer: layers.79.ffn_norm.weight, Layer size: torch.Size([8192])\n", + "Index: 723, layer: rope.freqs, Layer size: torch.Size([64])\n" + ] + } + ], + "source": [ + "# check layers and dimensions (optional)\n", + "state_dict = torch.load(f\"{load_path}/consolidated.01.pth\")\n", + "for index, (key, value) in enumerate(state_dict.items()):\n", + " print(f\"Index: {index}, layer: {key}, Layer size: {value.size()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a5a65e1f", + "metadata": {}, + "source": [ + "## Step 2: Train a tokenizer from scratch using domain-specific data to get a Domain Specific Tokenizer." + ] + }, + { + "cell_type": "markdown", + "id": "fa45b02d-2e2f-4b81-9cba-ed07cbda5b9d", + "metadata": {}, + "source": [ + "First, we train a tokenizer from scratch using domain-specific data.\n", + "\n", + "The tokenizer that we use is the facebook/opt-350m tokenizer, available on Hugging Face (https://huggingface.co/facebook/opt-350m). Like the llama-2 tokenizer, the opt-350m tokenizer is a Byte Pair Encoding (BPE) model, and since we are training from scratch either would work. In fact, any BPE-based tokenizer can be used here, since what matters is the training algorithm inside the tokenizer. We chose opt-350m because it has a more general-purpose design and can be used flexibly across different tasks/domains and with models beyond the OPT series, whereas the llama-2 tokenizer is designed specifically for the llama-2 architecture and is optimized for the tasks that the llama-2 model is intended to handle. \n", + "\n", + "The two hyperparameters that need to be set here are ```batch_size``` and ```vocab_size```.
\n", + "\n", + "```vocab_size``` : is the target vocab size in finetuning the tokenizer. This depends on the original tokenizer and should be slightly higher than half of the original vocab size. Note that this doesn't have to equal the number of new tokens that will be added. \n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "518ca72b-4ab8-4538-8eb4-42d648346347", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Is a directory: True\n" + ] + } + ], + "source": [ + "data_root = \"./data/all_jsonl_data_sample/\" # path where the domain specific data is stored\n", + "save_root = \"./models/tokenizer/llama2/\" # path to save the finetuned opt tokenizer\n", + "batch_size = 1000 # batch size used in the tokenization process\n", + "vocab_size = 20000 # target vocab size for training opt tokenizer\n", + "\n", + "# ensure that the directory exists before changing permissions\n", + "directory = \"../code/\"\n", + "is_directory = os.path.isdir(directory)\n", + "print(f\"Is a directory: {is_directory}\")\n", + "\n", + "# change permissions to ensure we have read, write and execute permissions\n", + "! chmod ugo+rwx ../code/" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "519e6c13-45c3-4f52-9ed0-770d4ec62766", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:797: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Before Training: \n", + "total token cnt 66025\n", + "\n", + "\n", + "\n", + "After Training: \n", + "total token cnt 47712\n" + ] + }, + { + "data": { + "text/plain": [ + "('./models/tokenizer/llama2/custom_tokenizer_init_20000_json/tokenizer_config.json',\n", + " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/special_tokens_map.json',\n", + " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/vocab.json',\n", + " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/merges.txt',\n", + " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/added_tokens.json',\n", + " './models/tokenizer/llama2/custom_tokenizer_init_20000_json/tokenizer.json')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Train a tokenizer from scratch and save output files\n", + "keys = [\"text\"] # keys to extract from json files\n", + "tokenizer = AutoTokenizer.from_pretrained(\"facebook/opt-350m\") # load pre-trained tokenizer (https://huggingface.co/facebook/opt-350m)\n", + "# Train the tokenizer from scratch on a new corpus with the same defaults (in terms of special tokens or tokenization pipeline) as the current one.\n", + "tokenizer = train_tokenizer(data_root, batch_size, vocab_size, tokenizer, keys)\n", + "\n", + "#Save and print paths\n", + "tokenizer.save_pretrained(save_root + \"custom_tokenizer_init_\" + str(vocab_size) + \"_json\")" + ] + }, + { + "cell_type": "markdown", + "id": "127e0591-fbaa-41bc-87a5-f594587ea12d", + "metadata": { + "tags": [] + }, + "source": [ + "## Step 3: From the vocabulary of the newly trained tokenizer, identify tokens that are absent in the general-purpose tokenizer and are rarely found in 
general-purpose datasets. Next, expand the general-purpose tokenizer with the newly identified tokens to get an extended Tokenizer.\n", + "\n", + "Here we expand/resize the model embeddings of the original general-purpose tokenizer with the newly identified tokens in Step 3 to get an extended tokenizer.\n", + "\n", + "The two hyperparemeters that need to be set here are ```split``` and ```model_type```. \n", + "\n", + "```split```: is the number of partitions to split the embeddings in (.pt files) for the purpose of model parallelism.\n", + "\n", + "```model_type``` : this is the original tokenizer model (llama2 in our case)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ec202bf4-3508-4a64-90c6-debcf116e81e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Domain vocab size: 5965\n", + "token pattern: [a-zA-Z]\n", + "Num of added tokens and dropped tokens 4931 1034\n", + "Original model pieces: 32000\n", + "input: \"/large_experiments/theorem/datasets/MERGED/all.test1.merged\"\n", + "model_prefix: \"spm_model_32k_200M_charcov099995_allowWSO__v2\"\n", + "model_type: BPE\n", + "vocab_size: 32000\n", + "self_test_sample_size: 0\n", + "input_format: \"text\"\n", + "character_coverage: 0.99995\n", + "input_sentence_size: 200000000\n", + "seed_sentencepiece_size: 1000000\n", + "shrinking_factor: 0.75\n", + "num_threads: 80\n", + "num_sub_iterations: 2\n", + "max_sentence_length: 4192\n", + "shuffle_input_sentence: true\n", + "max_sentencepiece_length: 16\n", + "split_by_unicode_script: true\n", + "split_by_whitespace: true\n", + "split_by_number: true\n", + "treat_whitespace_as_suffix: false\n", + "split_digits: true\n", + "allow_whitespace_only_pieces: true\n", + "vocabulary_output_piece_score: true\n", + "hard_vocab_limit: true\n", + "use_all_vocab: false\n", + "byte_fallback: true\n", + "required_chars: \"\"\n", + "unk_id: 0\n", + "bos_id: 1\n", + "eos_id: 2\n", + "pad_id: -1\n", + "unk_surface: \" \\342\\201\\207 \"\n", + "unk_piece: \"\"\n", + "bos_piece: \"\"\n", + "eos_piece: \"\"\n", + "pad_piece: \"\"\n", + "train_extremely_large_corpus: false\n", + "enable_differential_privacy: false\n", + "differential_privacy_noise_level: 0.0\n", + "differential_privacy_clipping_threshold: 0\n", + "\n", + "original vocab size: 32000\n", + "new token cnt: 1400\n", + "add token cnt: 2048\n", + "add normal token cnt: 1400\n", + "add dummy token cnt: 648\n", + "New model pieces: 34048\n", + "input: \"/large_experiments/theorem/datasets/MERGED/all.test1.merged\"\n", + "model_prefix: \"spm_model_32k_200M_charcov099995_allowWSO__v2\"\n", + "model_type: BPE\n", + "vocab_size: 32000\n", + "self_test_sample_size: 0\n", + "input_format: \"text\"\n", + "character_coverage: 0.99995\n", + "input_sentence_size: 200000000\n", + "seed_sentencepiece_size: 1000000\n", + "shrinking_factor: 0.75\n", + "num_threads: 80\n", + "num_sub_iterations: 2\n", + "max_sentence_length: 4192\n", + "shuffle_input_sentence: true\n", + "max_sentencepiece_length: 16\n", + "split_by_unicode_script: true\n", + "split_by_whitespace: true\n", + "split_by_number: true\n", + "treat_whitespace_as_suffix: false\n", + "split_digits: true\n", + "allow_whitespace_only_pieces: true\n", + "vocabulary_output_piece_score: true\n", + "hard_vocab_limit: true\n", + "use_all_vocab: false\n", + "byte_fallback: true\n", + "required_chars: \"\"\n", + "unk_id: 0\n", + "bos_id: 1\n", + "eos_id: 2\n", + "pad_id: -1\n", + "unk_surface: \" \\342\\201\\207 \"\n", + 
"unk_piece: \"\"\n", + "bos_piece: \"\"\n", + "eos_piece: \"\"\n", + "pad_piece: \"\"\n", + "train_extremely_large_corpus: false\n", + "enable_differential_privacy: false\n", + "differential_privacy_noise_level: 0.0\n", + "differential_privacy_clipping_threshold: 0\n", + "\n", + "Parent directory './models/tokenizer/llama2/new_tokenizer' exists.\n", + "Parent directory './models/tokenizer/llama2/new_tokenizer' exists.\n", + "word_embedding shape: torch.Size([32000, 8192])\n", + "output_layer shape: torch.Size([32000, 8192])\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Parent directory './models/weight/llama2/new_llama2-hf_weight' exists.\n", + "Completed saving new embeddings\n", + "Vocabulary path for extended tokenizer: ./models/tokenizer/llama2/new_tokenizer/code_gen_vocab.json\n", + "Tokenizer model path for extended tokenizer: ./models/tokenizer/llama2/new_tokenizer/tokenizer_code_gen.model\n", + "Modified embedding weights path for extended tokenizer: ./models/weight/llama2/new_llama2-hf_weight/\n" + ] + } + ], + "source": [ + "split = 8 # number of partitions to split the embeddings of domain-adapted tokenizer\n", + "model_type = \"llama2\" # Add more model_types if you want the codebase to support alternate ones\n", + "extend_tokenizer(vocab_size, split, model_type)" + ] + }, + { + "cell_type": "markdown", + "id": "82ddf5a4-663f-40b0-9477-bd3d3e803c12", + "metadata": {}, + "source": [ + "## Step 4: Use the extended Tokenizer to anylze the frequency of newly added tokens" + ] + }, + { + "cell_type": "markdown", + "id": "9c78a7b5-8122-4d95-afd1-997f047c37f1", + "metadata": {}, + "source": [ + "Here we apply the extended tokenizer to the domain-specific dataset, analyzing the usage frequencies of the newly-added tokens, and selecting the top-K tokens in a way that their cumulative frequency accounts for approximately 98% (a hyper-parameter: ```freq_threshold```) of the total frequency of the new tokens.\n", + "\n", + "The idea is that only high-frequency tokens will be added to the vocabulary of the original tokenizer to get the final domain adapted tokenizer. \n", + "\n", + "The benefits of high-frequency token analysis have been explored in several studies: ([Liu, Mingjie, et al](https://research.nvidia.com/publication/2023-10_chipnemo-domain-adapted-llms-chip-design); [Lian, Haoran, et al](https://arxiv.org/abs/2404.17808)).This is because previous studies have shown that disparities in token frequencies can result in imbalanced learning difficulties across different tokens. For instance, low frequency tokens are harder to learn for models ([Su, Zhenpeng, et al](https://arxiv.org/abs/2310.19531); [Lin, Tsung-Yi, et al](https://openaccess.thecvf.com/content_iccv_2017/html/Lin_Focal_Loss_for_ICCV_2017_paper.html)).\n", + "\n", + "We use two functions for frequency analysis. Helper function `analyze_token_usage` applies the extended tokenizer to domain specific data, and stores the usage/occurence frequencies of the newly-added tokens at `token_usage_path`.
\n", + "\n", + "Helper function `get_high_freq_tokens` looks at the token usage frequencies from above and performs a binary search to search for domain specific tokens with usage frequency above the specified threshold (`freq_threshold` parameter). It stores the tokens it finds at `high_freq_tokens_path`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b36a5465-b409-4d32-9a2f-1dd7c91ef917", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "split = 8 # number of partitions to split the embeddings of domain-adapted tokenizer\n", + "model_type = \"llama2\"\n", + "tag = \"code_gen\"\n", + "keys = [\"text\"]\n", + "# path to the saved extended tokenizer (from previous tep)\n", + "extended_tokenizer_path = f\"./models/tokenizer/{model_type}/new_tokenizer/tokenizer_{tag}.model\"\n", + "# path to save token usage frequency analysis results\n", + "token_usage_path = f\"./models/tokenizer/{model_type}/new_tokenizer/{model_type}_token_usage.json\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7e2b40db-82d6-48ca-a900-145647b4dff1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "vocab_size: 34048\n", + "ori cnt and new cnt: 2209.0 22.0\n", + "ori cnt and new cnt: 1764.0 20.0\n", + "ori cnt and new cnt: 4062.0 259.0\n", + "ori cnt and new cnt: 406.0 7.0\n", + "ori cnt and new cnt: 1872.0 39.0\n", + "ori cnt and new cnt: 645.0 32.0\n", + "ori cnt and new cnt: 2655.0 20.0\n", + "ori cnt and new cnt: 154.0 6.0\n", + "ori cnt and new cnt: 997.0 30.0\n", + "ori cnt and new cnt: 523.0 29.0\n", + "ori cnt and new cnt: 523.0 29.0\n", + "ori cnt and new cnt: 2317.0 95.0\n", + "ori cnt and new cnt: 419.0 10.0\n", + "ori cnt and new cnt: 813.0 13.0\n", + "ori cnt and new cnt: 18796.0 1238.0\n", + "ori cnt and new cnt: 3327.0 113.0\n", + "ori cnt and new cnt: 963.0 29.0\n", + "ori cnt and new cnt: 500.0 21.0\n", + "ori cnt and new cnt: 610.0 22.0\n", + "ori cnt and new cnt: 879.0 18.0\n", + "ori cnt and new cnt: 1681.0 88.0\n", + "ori cnt and new cnt: 654.0 16.0\n", + "ori cnt and new cnt: 62.0 2.0\n", + "ori cnt and new cnt: 1230.0 151.0\n", + "ori cnt and new cnt: 786.0 40.0\n", + "ori cnt and new cnt: 1454.0 22.0\n", + "ori cnt and new cnt: 1237.0 29.0\n", + "ori cnt and new cnt: 1610.0 60.0\n", + "ori cnt and new cnt: 383.0 20.0\n", + "ori cnt and new cnt: 766.0 22.0\n", + "ori cnt and new cnt: 2361.0 20.0\n", + "ori cnt and new cnt: 120.0 3.0\n", + "ori cnt and new cnt: 714.0 31.0\n", + "ori cnt and new cnt: 2185.0 137.0\n", + "ori cnt and new cnt: 1270.0 75.0\n", + "ori cnt and new cnt: 506.0 24.0\n" + ] + } + ], + "source": [ + "# analyze tokens using frequency analysis\n", + "analyze_token_usage(data_root, extended_tokenizer_path, batch_size, keys, token_usage_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ea495ba4-8b65-4560-94a6-269e8af8a83a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# path to save selected high-frequency tokens (new tokens to be added)\n", + "high_freq_tokens_path = f\"./models/tokenizer/{model_type}/new_tokenizer/{model_type}_freq_analy_new_token.json\"\n", + "\n", + "# hyperparameter \n", + "freq_threshold = 0.98" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b816793c-a77c-4cfa-8317-f278cfdbe247", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"./data/all_jsonl_data_sample/7db92aa7a05ae3eb86ec8bd0ab6e6768.lef.gz-0.jsonl\n", + "[4 4 2 2 2 1 1 1 1 1 1 1 1] 21.56\n", + "[4 4 2 2 2 1 1 1 1 1 1 1 1] 21.56\n", + "3\n", + "./data/all_jsonl_data_sample/7d3eb10b8384155f4f262b9d4a9d95b2.lef.gz-0.jsonl\n", + "[4 2 2 2 2 2 1 1 1 1 1 1] 19.6\n", + "[4 2 2 2 2 2 1 1 1 1 1 1] 19.6\n", + "3\n", + "./data/all_jsonl_data_sample/7d1e17d8e8367778544c7664a0dcca34.scala.gz-0.jsonl\n", + "[31 30 30 7 4 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1] 253.82\n", + "[31 30 30 7 4 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1] 253.82\n", + "[31 30 30 7 4 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1] 253.82\n", + "[31 30 30 7 4 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1] 253.82\n", + "[31 30 30 7 4 3 3 3 3 3 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1] 253.82\n", + "31\n", + "./data/all_jsonl_data_sample/7d9023bf5e97a4417b3e3d15bd0155e5.v.gz-0.jsonl\n", + "[2 1 1 1 1 1] 6.859999999999999\n", + "1\n", + "./data/all_jsonl_data_sample/7d4746f7028947dbbf6f6be1e705a343.h.gz-0.jsonl\n", + "[12 6 5 4 3 2 1 1 1 1 1 1 1] 38.22\n", + "[12 6 5 4 3 2 1 1 1 1 1 1 1] 38.22\n", + "[12 6 5 4 3 2 1 1 1 1 1 1 1] 38.22\n", + "[12 6 5 4 3 2 1 1 1 1 1 1 1] 38.22\n", + "12\n", + "./data/all_jsonl_data_sample/7de4bd2089ed29650c6813a692c0b7fd.cdl.gz-0.jsonl\n", + "[5 5 3 3 3 3 3 2 2 1 1 1] 31.36\n", + "[5 5 3 3 3 3 3 2 2 1 1 1] 31.36\n", + "4\n", + "./data/all_jsonl_data_sample/7da75b519311e22a70fe54061b51b67c.sv.gz-0.jsonl\n", + "[2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 19.6\n", + "1\n", + "./data/all_jsonl_data_sample/7d2caac63ccb0ee43f143dfad745a878.v.gz-0.jsonl\n", + "[2 1 1 1 1] 5.88\n", + "1\n", + "./data/all_jsonl_data_sample/7d3ac231744dee023fffc01079a56367.v.gz-0.jsonl\n", + "[4 3 3 3 3 3 2 2 1 1 1 1 1 1 1] 29.4\n", + "[4 3 3 3 3 3 2 2 1 1 1 1 1 1 1] 29.4\n", + "3\n", + "./data/all_jsonl_data_sample/7d223aaa5ad782ba0e026a4fcd6a5e0d.v.gz-0.jsonl\n", + "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n", + "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n", + "3\n", + "./data/all_jsonl_data_sample/7d233e7cb17ecddca9baf0704309e739.v.gz-0.jsonl\n", + "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n", + "[4 3 3 3 3 3 2 2 1 1 1 1 1 1] 28.419999999999998\n", + "3\n", + "./data/all_jsonl_data_sample/7de185d29809f5259616436204ae6c07.spice.gz-0.jsonl\n", + "[21 20 20 16 16 1 1] 93.1\n", + "[21 20 20 16 16 1 1] 93.1\n", + "6\n", + "./data/all_jsonl_data_sample/7d398c165432cac33c33442b2b2b9915.v.gz-0.jsonl\n", + "[3 2 1 1 1 1 1] 9.8\n", + "[3 2 1 1 1 1 1] 9.8\n", + "3\n", + "./data/all_jsonl_data_sample/7db39e3425d097664d5b3aa4800501ad.h.gz-0.jsonl\n", + "[8 1 1 1 1 1] 12.74\n", + "[8 
1 1 1 1 1] 12.74\n", + "6\n", + "./data/all_jsonl_data_sample/7dbb099365a3ef31bbc60c3fc37be762.qip.gz-0.jsonl\n", + "[143 27 26 25 20 19 19 17 17 17 17 17 17 17 17 17 16 16\n", + " 13 12 11 11 11 10 10 10 10 9 8 8 8 7 7 7 7 7\n", + " 6 6 6 6 6 6 6 6 6 6 6 5 5 5 5 5 5 5\n", + " 5 5 5 5 5 4 4 4 4 4 4 4 4 4 4 4 4 4\n", + " 4 4 4 4 4 4 4 4 4 4 4 3 3 3 3 3 3 3\n", + " 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1] 1213.24\n", + "[143 27 26 25 20 19 19 17 17 17 17 17 17 17 17 17 16 16\n", + " 13 12 11 11 11 10 10 10 10 9 8 8 8 7 7 7 7 7\n", + " 6 6 6 6 6 6 6 6 6 6 6 5 5 5 5 5 5 5\n", + " 5 5 5 5 5 4 4 4 4 4 4 4 4 4 4 4 4 4\n", + " 4 4 4 4 4 4 4 4 4 4 4 3 3 3 3 3 3 3\n", + " 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1] 1213.24\n", + "[143 27 26 25 20 19 19 17 17 17 17 17 17 17 17 17 16 16\n", + " 13 12 11 11 11 10 10 10 10 9 8 8 8 7 7 7 7 7\n", + " 6 6 6 6 6 6 6 6 6 6 6 5 5 5 5 5 5 5\n", + " 5 5 5 5 5 4 4 4 4 4 4 4 4 4 4 4 4 4\n", + " 4 4 4 4 4 4 4 4 4 4 4 3 3 3 3 3 3 3\n", + " 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1] 1213.24\n", + "[143 27 26 25 20 19 19 17 17 17 17 17 17 17 17 17 16 16\n", + " 13 12 11 11 11 10 10 10 10 9 8 8 8 7 7 7 7 7\n", + " 6 6 6 6 6 6 6 6 6 6 6 5 5 5 5 5 5 5\n", + " 5 5 5 5 5 4 4 4 4 4 4 4 4 4 4 4 4 4\n", + " 4 4 4 4 4 4 4 4 4 4 4 3 3 3 3 3 3 3\n", + " 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 
1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1] 1213.24\n", + "[143 27 26 25 20 19 19 17 17 17 17 17 17 17 17 17 16 16\n", + " 13 12 11 11 11 10 10 10 10 9 8 8 8 7 7 7 7 7\n", + " 6 6 6 6 6 6 6 6 6 6 6 5 5 5 5 5 5 5\n", + " 5 5 5 5 5 4 4 4 4 4 4 4 4 4 4 4 4 4\n", + " 4 4 4 4 4 4 4 4 4 4 4 3 3 3 3 3 3 3\n", + " 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1] 1213.24\n", + "[143 27 26 25 20 19 19 17 17 17 17 17 17 17 17 17 16 16\n", + " 13 12 11 11 11 10 10 10 10 9 8 8 8 7 7 7 7 7\n", + " 6 6 6 6 6 6 6 6 6 6 6 5 5 5 5 5 5 5\n", + " 5 5 5 5 5 4 4 4 4 4 4 4 4 4 4 4 4 4\n", + " 4 4 4 4 4 4 4 4 4 4 4 3 3 3 3 3 3 3\n", + " 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1] 1213.24\n", + "[143 27 26 25 20 19 19 17 17 17 17 17 17 17 17 17 16 16\n", + " 13 12 11 11 11 10 10 10 10 9 8 8 8 7 7 7 7 7\n", + " 6 6 6 6 6 6 6 6 6 6 6 5 5 5 5 5 5 5\n", + " 5 5 5 5 5 4 4 4 4 4 4 4 4 4 4 4 4 4\n", + " 4 4 4 4 4 4 4 4 4 4 4 3 3 3 3 3 3 3\n", + " 3 3 3 3 3 3 3 3 3 3 3 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2\n", + " 2 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1] 1213.24\n", + "142\n", + "./data/all_jsonl_data_sample/7dc1209d13f0e65aab95b30e28fdc7b0.spice.gz-0.jsonl\n", + "[29 24 24 15 11 7 1 1 1] 110.74\n", + "[29 24 24 15 11 7 1 1 1] 110.74\n", + "7\n", + "./data/all_jsonl_data_sample/7dddcc0a609031cc37396af611abd521.v.gz-0.jsonl\n", + "[7 7 7 2 2 2 1 1] 28.419999999999998\n", + "[7 7 7 2 2 2 1 1] 28.419999999999998\n", + "[7 7 7 2 2 2 1 1] 28.419999999999998\n", + "7\n", + 
"./data/all_jsonl_data_sample/7dbf14d1da77bcf50408a3548fba5443.v.gz-0.jsonl\n", + "[4 3 3 3 2 2 1 1 1 1] 20.58\n", + "[4 3 3 3 2 2 1 1 1 1] 20.58\n", + "3\n", + "./data/all_jsonl_data_sample/7d1585b1aef10fb448a5b5b8fbd0b624.v.gz-0.jsonl\n", + "[3 3 3 2 2 2 2 2 2 1] 21.56\n", + "[3 3 3 2 2 2 2 2 2 1] 21.56\n", + "3\n", + "./data/all_jsonl_data_sample/7d0e56309d51283206ed83074b1ccf76.sv.gz-0.jsonl\n", + "[4 4 1 1 1 1 1 1 1 1 1 1] 17.64\n", + "[4 4 1 1 1 1 1 1 1 1 1 1] 17.64\n", + "3\n", + "./data/all_jsonl_data_sample/7dab8d5649d18c7a9c155b51392a6588.v.gz-0.jsonl\n", + "[29 18 4 4 4 4 4 3 3 3 2 2 2 1 1 1 1 1 1] 86.24\n", + "[29 18 4 4 4 4 4 3 3 3 2 2 2 1 1 1 1 1 1] 86.24\n", + "[29 18 4 4 4 4 4 3 3 3 2 2 2 1 1 1 1 1 1] 86.24\n", + "18\n", + "./data/all_jsonl_data_sample/7dbb8e6199e137dcf3d7085bb0ff9975.v.gz-0.jsonl\n", + "[11 3 1 1] 15.68\n", + "[11 3 1 1] 15.68\n", + "4\n", + "./data/all_jsonl_data_sample/7d62d6bd44676d9f2ea5cdbbd594c4ef.c.gz-0.jsonl\n", + "[2] 1.96\n", + "2\n", + "./data/all_jsonl_data_sample/7dbfcd8e236b3b802c78e9ab57b3a1d0.scala.gz-0.jsonl\n", + "[63 11 7 7 5 4 4 3 3 3 3 3 2 2 2 2 2 2 2 2 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 147.98\n", + "[63 11 7 7 5 4 4 3 3 3 3 3 2 2 2 2 2 2 2 2 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 147.98\n", + "[63 11 7 7 5 4 4 3 3 3 3 3 2 2 2 2 2 2 2 2 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 147.98\n", + "[63 11 7 7 5 4 4 3 3 3 3 3 2 2 2 2 2 2 2 2 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 147.98\n", + "36\n", + "./data/all_jsonl_data_sample/7d7b96f51a734da259a6a2ecf379cded.cdl.gz-0.jsonl\n", + "[6 6 4 4 4 4 3 2 2 2 1 1 1] 39.2\n", + "[6 6 4 4 4 4 3 2 2 2 1 1 1] 39.2\n", + "[6 6 4 4 4 4 3 2 2 2 1 1 1] 39.2\n", + "6\n", + "./data/all_jsonl_data_sample/7d23810b472d58f3487c52b1f773189a.lef.gz-0.jsonl\n", + "[4 3 2 2 2 1 1 1 1 1 1 1 1 1] 21.56\n", + "[4 3 2 2 2 1 1 1 1 1 1 1 1 1] 21.56\n", + "3\n", + "./data/all_jsonl_data_sample/7dc70412013409c8b16c0c9e5f14fcfa.v.gz-0.jsonl\n", + "[7 7 7 2 2 2 1 1] 28.419999999999998\n", + "[7 7 7 2 2 2 1 1] 28.419999999999998\n", + "[7 7 7 2 2 2 1 1] 28.419999999999998\n", + "7\n", + "./data/all_jsonl_data_sample/7d5dd4296bf9ada66d63916f69a46faf.emf.gz-0.jsonl\n", + "[20 14 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1] 58.8\n", + "[20 14 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1] 58.8\n", + "[20 14 2 2 2 2 2 2 2 2 2 1 1 1 1 1 1 1 1] 58.8\n", + "18\n", + "./data/all_jsonl_data_sample/7dd8f2f49230dee32da3142ac3984412.v.gz-0.jsonl\n", + "[3 2 2 2 2 2 1 1 1 1 1 1 1] 19.6\n", + "[3 2 2 2 2 2 1 1 1 1 1 1 1] 19.6\n", + "3\n", + "./data/all_jsonl_data_sample/7d43a09a7438300752244fb0d9fb05e8.v.gz-0.jsonl\n", + "[3 3 3 2 2 2 2 2 2 1] 21.56\n", + "[3 3 3 2 2 2 2 2 2 1] 21.56\n", + "3\n", + "./data/all_jsonl_data_sample/7de4d2d50bb4424b0fe35bae7a83be7b.lef.gz-0.jsonl\n", + "[4 2 2 2 2 1 1 1 1 1 1 1 1] 19.6\n", + "[4 2 2 2 2 1 1 1 1 1 1 1 1] 19.6\n", + "3\n", + "./data/all_jsonl_data_sample/7d5db7ecf8c09e4d6872e9998d3ffc4c.v.gz-0.jsonl\n", + "[2 1] 2.94\n", + "1\n", + "./data/all_jsonl_data_sample/7d89dd77dcd8779cd013cdad4527558c.h.gz-0.jsonl\n", + "[3 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 30.38\n", + "[3 2 2 2 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1] 30.38\n", + "3\n", + "./data/all_jsonl_data_sample/7d5a331f93e5a37c5e367230ca1c1a14.cdl.gz-0.jsonl\n", + "[16 14 14 14 11 11 8 7 7 5 5 4 3 3 2 2 2 2 2 2 1 1 1] 134.26\n", + "[16 14 14 14 11 11 8 7 7 5 5 4 3 3 2 2 2 2 2 2 1 1 1] 134.26\n", + "[16 14 14 14 11 11 8 7 7 5 5 4 3 3 2 2 2 2 2 2 1 1 1] 134.26\n", + "[16 14 14 14 11 11 8 7 7 5 5 4 3 3 2 2 2 2 2 2 1 1 
1] 134.26\n", + "15\n", + "./data/all_jsonl_data_sample/7deb467f7b1a328162b5b8ae171ca139.scala.gz-0.jsonl\n", + "[30 7 6 6 6 4 3 3 2 2 2 1 1 1 1] 73.5\n", + "[30 7 6 6 6 4 3 3 2 2 2 1 1 1 1] 73.5\n", + "[30 7 6 6 6 4 3 3 2 2 2 1 1 1 1] 73.5\n", + "[30 7 6 6 6 4 3 3 2 2 2 1 1 1 1] 73.5\n", + "14\n", + "./data/all_jsonl_data_sample/7d146ea4e04027f987bc4c8d1bf2326e.cdl.gz-0.jsonl\n", + "[4 4 3 2 2 2 2 2 1 1 1] 23.52\n", + "[4 4 3 2 2 2 2 2 1 1 1] 23.52\n", + "3\n" + ] + } + ], + "source": [ + "# selecting the top-K tokens in a way that their cumulative frequency accounts for approximately 98%\n", + "get_high_freq_tokens(token_usage_path, high_freq_tokens_path, float(freq_threshold))" + ] + }, + { + "cell_type": "markdown", + "id": "b040dc93-6667-4c8c-91a7-dd715151bbc6", + "metadata": {}, + "source": [ + "## Step 5: Initialize the embeddings of the new tokens by utilizing the extended general-purpose tokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "b2e64e1d-fb92-4cd5-b19a-ff411fb231d7", + "metadata": {}, + "source": [ + "Here we use the `extend_tokenizer_high_freq_tokens` helper function to first add the high-frequency tokens identified in Step 4 to the original tokenizer vocabulary.\n", + "\n", + "Both the embedding table and the output layer weights of the original model depend on the vocab size. Since the vocab size has now changed due to the addition of high-frequency domain-specific tokens, both of these need to be updated.\n", + "\n", + "`extend_sentencepiece` initializes the embeddings of the new tokens by utilizing the general-purpose tokenizer. When a new token (a word or subword unit) is encountered, it is first broken down (tokenized) using the pretrained general-purpose tokenizer.\n", + "\n", + "The new token doesn’t have a predefined embedding (a numerical representation). The embedding of the new token is determined by averaging the embeddings of the sub-tokens generated by the general-purpose tokenizer. For example, if the new token is split into three sub-tokens, the embeddings of these three sub-tokens are averaged to form the embedding of the new token.\n", + "\n", + "Similarly, the weights in the output layer corresponding to the new token are initialized to the average of the output-layer weights of the sub-tokens generated by the general-purpose tokenizer. For example, if the new token is split into three sub-tokens, the weights corresponding to these three sub-tokens are averaged to form the weights corresponding to the new token.\n", + "\n", + "Once done, in Step 6 we will merge the new embeddings with the original embedding table (in llama2) to get the final Domain Adapted Tokenizer."
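+ ,"\n",
+ "\n",
+ "To make the averaging concrete, the following is a minimal, illustrative sketch (not part of the notebook's helper code) of how the rows for a single new token could be initialized. It assumes a Hugging Face tokenizer `base_tokenizer` and weight tensors `embedding` and `output_layer` of shape `[vocab_size, hidden_dim]`; these names are hypothetical.\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def init_new_token_rows(new_token, base_tokenizer, embedding, output_layer):\n",
+ "    # Break the new token down with the original general-purpose tokenizer\n",
+ "    sub_ids = base_tokenizer.encode(new_token, add_special_tokens=False)\n",
+ "    # Embedding row for the new token = mean of its sub-token embedding rows\n",
+ "    new_emb = embedding[sub_ids].mean(dim=0, keepdim=True)\n",
+ "    # Output-layer row for the new token = mean of the corresponding output-layer rows\n",
+ "    new_out = output_layer[sub_ids].mean(dim=0, keepdim=True)\n",
+ "    # Append the new rows, growing the vocabulary dimension by one\n",
+ "    return torch.cat([embedding, new_emb], dim=0), torch.cat([output_layer, new_out], dim=0)\n",
+ "```\n",
+ "\n",
+ "The `extend_tokenizer_high_freq_tokens` helper used in the following cells applies this kind of update for all of the selected high-frequency tokens and splits the augmented embeddings into the requested number of partitions."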
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2644f313-66d8-47d0-9fc0-f1b2edc72d79", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "ori_tokenizer_path = f\"./models/tokenizer/{model_type}/original_tokenizer/tokenizer.model\" # original sentencepiece tokenizer model\n", + "new_vocab_path = f\"./models/tokenizer/{model_type}/new_tokenizer/freq_vocab.json\" # path to record added new tokens\n", + "old_ebd_path = f\"./models/weight/{model_type}/ori_{model_type}-hf_weight/\" # original embeddings\n", + "new_ebd_path = f\"./models/weight/{model_type}/new_{model_type}-hf_weight/\" # path to store augmented embeddings\n", + "domain_adapter_tokenizer_path = f\"./models/tokenizer/{model_type}/new_tokenizer/tokenizer_freq.model\" # augmented sentencepiece model\n", + "split = 8 # num of partitions to split the augmented embeddings" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3946599e-43b6-4391-b6ab-0068f9f93113", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "f = open(high_freq_tokens_path, \"r\")\n", + "new_tokens = json.load(f)\n", + "print(\"new_tokens: \", new_tokens)\n", + "extend_tokenizer_high_freq_tokens(data_root, ori_tokenizer_path, new_tokens, new_vocab_path, domain_adapter_tokenizer_path, old_ebd_path, new_ebd_path, split)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fda5c5ad", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(new_ebd_path) #New weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "27f8da73", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(domain_adapter_tokenizer_path) # domained adapted tokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "227c5b66-bb02-4fc6-96c3-c4284d2f6e99", + "metadata": {}, + "source": [ + "# Step 6: Merge the new embeddings with the original embedding table (in llama2) to get the final Domain Adapted Tokenizer and Embeddings." + ] + }, + { + "cell_type": "markdown", + "id": "05999ff2", + "metadata": {}, + "source": [ + "Helper function `merge_embed` takes the original embeddings downloaded from hugging face, and the augmented embeddings generated in Step 5 above, merges them and then saves the result at `save_path`.\n", + "\n", + "For instance, figure below shows an illustration of embedding table modification. Here each row corresponds to a unique token and each column represents a dimension of the embedding vector. The size of the vocabulary determines the number of rows in the embedding table. The embedding layer in the LLM which is responsible for converting the data into numerical vectors uses the embedding table to perform this conversion. The dimensionality of the embedding layer is given by the number of columns in the embedding table.
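For intuition about what the merge produces, the sketch below mimics only the shape bookkeeping: the merged table gains one row per added token while the embedding dimension (number of columns) stays fixed. The sizes are illustrative, and the real `merge_embed` additionally handles the sharded weight files and places the new rows ahead of any padded vocabulary entries. The figure below shows the same idea.

```python
import torch

ori_vocab, added_tokens, hidden = 32000, 1024, 4096  # illustrative sizes

original_table = torch.randn(ori_vocab, hidden)  # one row per original token
new_rows = torch.randn(added_tokens, hidden)     # averaged rows created in Step 5

# Row count grows by the number of added tokens; the column count is unchanged.
merged_table = torch.cat([original_table, new_rows], dim=0)
print(merged_table.shape)  # torch.Size([33024, 4096])
```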
\n", + "\n", + "![pipeline](imgs/embedding_table.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21658fb2-a54a-41cf-94fc-81735767cdab", + "metadata": {}, + "outputs": [], + "source": [ + "os.makedirs(f\"/models/weight/new_merged_{model_type}-hf\", exist_ok=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4c616e7f", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "old_ebd_path = f\"./models/weight/{model_type}-hf\" # original embeddings downloaded from hf\n", + "new_ebd_path = f\"./models/weight/{model_type}/new_{model_type}-hf_weight\" # augmented embeddings\n", + "save_path = f\"./models/weight/new_merged_{model_type}-hf\" # Path to adapted llama2 weights\n", + "merge_embed(old_ebd_path, new_ebd_path, save_path)" + ] + }, + { + "cell_type": "markdown", + "id": "9060868f", + "metadata": {}, + "source": [ + "### New weights and tokenizer are stored at:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ca159e0c-28b2-4a64-bdf4-314e4191c2a0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(new_ebd_path) #New weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9d657312-814a-4ad8-a46f-c9c7bc4ee978", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "print(domain_adapter_tokenizer_path) # domained adapted tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d0782e9c", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [], + "source": [ + "# check layers and dimensions (optional)\n", + "state_dict = torch.load(f'{save_path}/consolidated.01.pth')\n", + "for index, (key, value) in enumerate(state_dict.items()):\n", + " print(f\"Index: {index}, layer: {key}, Layer size: {value.size()}\")" + ] + }, + { + "cell_type": "markdown", + "id": "15a172e2", + "metadata": {}, + "source": [ + "# Next Step\n", + "\n", + "The final Domain adapted Tokenizer obtained using this notebook can be used in a continual pre-training pipeline for domain adaptive pretraining." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/domain_adaptive_pretraining_nemo1.0.ipynb b/tutorials/llm/llama/domain-adaptive-pretraining/code/domain_adaptive_pretraining_nemo1.0.ipynb new file mode 100644 index 000000000000..3d48fd6e8b24 --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/code/domain_adaptive_pretraining_nemo1.0.ipynb @@ -0,0 +1,489 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cd13460c", + "metadata": {}, + "source": [ + "# Domain Adaptive Pre-Training (DAPT)\n", + "\n", + "## Goal\n", + "\n", + "Given a foundational language model (in this case llama-2-7B) that was pre-trained on a broad, general-purpose corpus, our goal is to further pretrain the model on a specific domain (in this example, ChipDesign) to enhance its understanding of domain-specific language and context. This process is called Domain-Adaptive Pretraining (DAPT). DAPT adapts a general-purpose model to specialized tasks within a particular field. 
Instead of training from scratch, we aim to “specialize” the model by focusing on a target domain corpus, allowing it to adapt to the unique vocabulary, semantics, and syntax of that field.\n", + "\n", + "Our primary goals with respect to DAPT are as follows:\n", + "* Improve the model’s performance and accuracy on domain-specific tasks\n", + "* Ensure the model retains general language capabilities\n", + "* Minimize pretraining time by leveraging existing knowledge in the model\n", + "\n", + "DAPT typically enhances a model’s efficacy in downstream tasks for the domain by exposing it to domain-relevant texts. This pretraining phase can result in more accurate and context-aware predictions on domain-specific data, as the model gains an understanding of field-specific terminology, abbreviations, and common phrases." + ] + }, + { + "cell_type": "markdown", + "id": "c43ef563", + "metadata": {}, + "source": [ + "# NeMo Tools and Resources\n", + "\n", + "* [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html)" + ] + }, + { + "cell_type": "markdown", + "id": "bea0b51f", + "metadata": {}, + "source": [ + "# Software Requirements\n", + "* Access to latest NeMo Framework NGC Containers\n", + "* This playbook has been tested on: nvcr.io/nvidia/nemo:dev. It is expected to work similarly on other environments.\n", + "\n", + "\n", + "#### Launch the NeMo Framework container as follows: \n", + "\n", + "```\n", + "docker run -it -p 8080:8080 -p 8088:8088 --rm --gpus '\"device=0,1\"' --ipc=host --network host -v $(pwd):/workspace nvcr.io/nvidia/nemo:dev\n", + "```\n", + "\n", + "#### Launch Jupyter Notebook as follows: \n", + "```\n", + "jupyter notebook --allow-root --ip 0.0.0.0 --port 8088 --no-browser --NotebookApp.token=''\n", + "\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "id": "7137e1db", + "metadata": {}, + "source": [ + "# Hardware Requirements\n", + "\n", + "* This playbook has been tested on 2xA100 80G but can be scaled to multiple GPUs as well as multiple nodes by modifying the appropriate parameters" + ] + }, + { + "cell_type": "markdown", + "id": "91ecb0d3", + "metadata": {}, + "source": [ + "# Data\n", + "\n", + "* In this playbook, we will leverage chip domain/hardware datasets from open-source GitHub repositories, wiki URLs, and academic papers. Data has been processed and curated using [NeMo Curator](https://github.com/NVIDIA/NeMo-Curator/tree/main) as shown in this [playbook](https://github.com/jvamaraju/ndc_dapt_playbook/tree/dapt_jv)" + ] + }, + { + "cell_type": "markdown", + "id": "ba16a72b", + "metadata": {}, + "source": [ + "# Notebook Outline\n", + "\n", + "* Step 1: Prepare the data for pretraining. This is a multi-step process discussed in detail later in the specific section (later in the notebook).\n", + "\n", + "* Step 2: Download the llama-2-7B hugging face checkpoint and convert to .nemo format.\n", + "\n", + "* Step 3: Continued pretraining the llama-2-7b model using the prepared data and the custom trained tokenizer (from the previous notebook)." + ] + }, + { + "cell_type": "markdown", + "id": "ec372453", + "metadata": {}, + "source": [ + "# Step 1: Data Preparation for pretraining\n", + "\n", + "Identify the different file types (example: code, text, etc) in the pretraining data, in this case we only have 'code' type files. This is typically dataset dependent. \n", + "\n", + "If you used the Data Curation tutorial as instructed in the Readme, you can point ```data_path ``` variable to the path containing the curated data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2c935b99", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Number of Files containing 'file_type':'text': 0\n", + "Number of Files containing 'file_type':'code': 8835\n" + ] + } + ], + "source": [ + "import os\n", + "import json\n", + "\n", + "# Function to count the number of files in each of the different file types- code, text\n", + "def identify_jsonl_files(data_path):\n", + " code_files = []\n", + " text_files = []\n", + " cnt_text = 0\n", + " cnt_code = 0\n", + " for root, _, files in os.walk(data_path):\n", + " for file in files:\n", + " if file.endswith('.jsonl'):\n", + " file_path = os.path.join(root, file)\n", + " with open(file_path, 'r') as f:\n", + " has_code = False\n", + " has_text = False\n", + " for line in f:\n", + " try:\n", + " json_obj = json.loads(line.strip())\n", + " file_type = json_obj.get('file_type', '').lower()\n", + " if file_type == 'code':\n", + " has_code = True\n", + " elif file_type == 'text':\n", + " has_text = True\n", + " if has_code and has_text:\n", + " break\n", + " except json.JSONDecodeError:\n", + " continue\n", + " if has_code:\n", + " code_files.append(file_path)\n", + " cnt_code = cnt_code + 1\n", + " if has_text:\n", + " text_files.append(file_path)\n", + " cnt_text = cnt_text + 1\n", + " return code_files, text_files, cnt_code, cnt_text\n", + "\n", + "# Modify data path to point to jsonl data source, in this case data_path='code/data/all_jsonl_data'\n", + "data_path = 'code/data/all_jsonl_data'\n", + "\n", + "code_files, text_files, cnt_code, cnt_text = identify_jsonl_files(data_path)\n", + "\n", + "print(\"\\nNumber of Files containing 'file_type':'text':\", cnt_text)\n", + "print(\"Number of Files containing 'file_type':'code':\", cnt_code)" + ] + }, + { + "cell_type": "markdown", + "id": "60987ff2", + "metadata": {}, + "source": [ + "### Merging code JSONL files into a single JSONL file for further preprocessing" + ] + }, + { + "cell_type": "markdown", + "id": "c02f2e6f", + "metadata": {}, + "source": [ + "This is an optional step, it is possible to use multiple jsonl files in this workflow as well. This example uses a single merged. 
jsonl file" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "892f4493", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "\n", + "def list_jsonl_files(directory):\n", + " jsonl_files = []\n", + " for root, _, files in os.walk(directory):\n", + " for file in files:\n", + " if file.endswith('.jsonl'):\n", + " jsonl_files.append(os.path.join(root, file))\n", + " return jsonl_files\n", + "\n", + "# Function to merge multiple jsonl files into a single file \n", + "def merge_jsonl_files(directory, output_file):\n", + " jsonl_files = list_jsonl_files(directory)\n", + " \n", + " with open(output_file, 'w') as outfile:\n", + " for input_file in jsonl_files:\n", + " with open(input_file, 'r') as infile:\n", + " for line in infile:\n", + " try:\n", + " json_object = json.loads(line.strip())\n", + " json.dump(json_object, outfile)\n", + " outfile.write('\\n')\n", + " except json.JSONDecodeError:\n", + " print(f\"Skipping invalid JSON in {input_file}: {line.strip()}\")\n", + "\n", + " print(f\"Merged {len(jsonl_files)} JSONL files into {output_file}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9bb0c80a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Merged 8835 JSONL files into code_merged_output.jsonl\n" + ] + } + ], + "source": [ + "directory = 'code/data/all_jsonl_data'\n", + "output_file = 'code_merged_output.jsonl'\n", + "merge_jsonl_files(directory, output_file)" + ] + }, + { + "cell_type": "markdown", + "id": "6d00ad63", + "metadata": {}, + "source": [ + "### Data Format Conversion for pretraining: JSONL to bin/idx files \n", + "\n", + "For efficient pretraining, we convert data from JSONL to bin/idx format. \n", + "\n", + "JSONL files, while convenient for storing structured text data, are not optimized for high-speed data loading during large language model training. In pretraining workflows, particularly those with large datasets and complex model architectures, the need for fast data access and efficient memory management is essential.\n", + "\n", + "The bin/idx format is a binary format specifically designed to facilitate high-throughput data loading. This format allows direct, randomized access to data samples, which speeds up I/O operations and reduces the memory footprint compared to loading JSONL files. By converting data to bin/idx format, hardware utilization can be maximized and bottlenecks in data processing can be avoided, leading to a more efficient pretraining process.\n", + "\n", + "#### Benefits of bin/idx format for Pretraining:\n", + "\n", + "* **Optimized I/O Performance:** The binary format enables quicker data reads and reduces latency, allowing the model to continuously access data at high speeds.\n", + "* **Efficient Memory Usage:** Data in bin/idx format consumes less memory during loading, making it suitable for large datasets and enabling better use of available system resources.\n", + "* **Enhanced Scalability:** With bin/idx, it’s easier to handle shuffling and batching of large datasets, which is essential for pretraining on diverse domain-specific data." 
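Before launching the conversion script in the next cell, a quick sanity check of the merged JSONL can save a failed run, since the script reads the `text` field selected by `--json-keys=text`. A small sketch, assuming the merged file produced above:

```python
import json

num_docs, missing_text = 0, 0
with open('code_merged_output.jsonl') as f:
    for line in f:
        record = json.loads(line)
        num_docs += 1
        if 'text' not in record:
            missing_text += 1

print(f"documents: {num_docs}, records without a 'text' field: {missing_text}")
```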
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "709f2c08", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "merges.txt\t\t tokenizer.json\t\tvocab.json\r\n", + "special_tokens_map.json tokenizer_config.json\r\n" + ] + } + ], + "source": [ + "# After the running through the custom_tokenization.ipynb, you would have \n", + "# the new domain adpated tokenizer model in the following directory\n", + "!ls models/tokenizer/llama2/custom_tokenizer_init_20000_json" + ] + }, + { + "cell_type": "markdown", + "id": "de696d7b", + "metadata": {}, + "source": [ + "Modify the `input` to point to the merged `jsonl` file. Similarly modify paths to `vocab`, `tokenizer-model`, `merge-file` to point to relevant file paths. \n", + "\n", + "In the following code block, ```tokenizer-model``` is set to using the original tokenizer that comes as a part of llama2-7b-hf, but `tokenizer-model` should point to the custom tokenizer (trained in the custom tokenizer training notebook) if your data has domain specific terminology" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcbf66a2", + "metadata": {}, + "outputs": [], + "source": [ + "!python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input='code_merged_output.jsonl' \\\n", + "--json-keys=text \\\n", + "--tokenizer-library=sentencepiece \\\n", + "--vocab 'models/tokenizer/llama2/custom_tokenizer_init_20000_json/vocab.json' \\\n", + "--dataset-impl mmap \\\n", + "--tokenizer-model '/workspace/Llama-2-7b-hf/tokenizer.model' \\\n", + "--tokenizer-type llama \\\n", + "--merge-file 'models/tokenizer/llama2/custom_tokenizer_init_20000_json/merges.txt' \\\n", + "--append-eod \\\n", + "--output-prefix='preprocessed_data'" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "0f05efa5", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "README.md\t\t\t nemo_experiments\r\n", + "cdeng\t\t\t\t preprocessed_data_text_document\r\n", + "code\t\t\t\t preprocessed_data_text_document.bin\r\n", + "code_merged_output.jsonl\t preprocessed_data_text_document.idx\r\n", + "domain_adaptive_pretraining.ipynb venv\r\n" + ] + } + ], + "source": [ + "# If the above step runs successfully, two files with the extensions .bin and .idx will be generated\n", + "!ls " + ] + }, + { + "cell_type": "markdown", + "id": "82f95149", + "metadata": {}, + "source": [ + "# Step 2: Download Llama-2-7b Hugging Face checkpoint and convert to .nemo checkpoint\n", + "\n", + "The code below assumes you already have the llama-2-7b checkpoint downloaded in ```/workspace/Llama-2-7b-hf/```\n", + "\n", + "Llama-2-7b-hf checkpoint can be downloaded from https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46c7f997", + "metadata": {}, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py --input_name_or_path=/workspace/Llama-2-7b-hf/ --output_path=/workspace/llama2-7b.nemo" + ] + }, + { + "cell_type": "markdown", + "id": "b94e774b", + "metadata": {}, + "source": [ + "The conversion will generate a ```llama2-7b.nemo``` file which can be used for the continued pretraining using NeMo Toolkit as shown in Step 3. 
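If the Hugging Face checkpoint is not already present under `/workspace/Llama-2-7b-hf`, one possible way to fetch it is with `huggingface_hub` (a sketch, not part of the original playbook; it assumes you have been granted access to the gated Llama 2 repository and have logged in or supplied a token):

```python
from huggingface_hub import snapshot_download

# Download the Llama-2-7b-hf checkpoint so the conversion script above can read it.
# Requires approved access to the gated repo; run `huggingface-cli login` or pass token=... first.
snapshot_download(
    repo_id="meta-llama/Llama-2-7b-hf",
    local_dir="/workspace/Llama-2-7b-hf",
)
```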
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "c689e584", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Llama-2-7b-hf\t\t dapt-custom-tokenization megatron_llama\r\n", + "bin-idx-conversion.ipynb dapt-data-curation\t megatron_llama_config.yaml\r\n", + "convert.py\t\t llama2-7b.nemo\t sentencepiece\r\n", + "custom-tokenizer\t loader_llama2.py\t venv\r\n" + ] + } + ], + "source": [ + "!ls /workspace" + ] + }, + { + "cell_type": "markdown", + "id": "fe1bdfe0", + "metadata": {}, + "source": [ + "# Step 3: Continued Pretraining using Llama2-7b with NeMo\n", + "\n", + "For this step `megatron_gpt_pretraining.py` from NeMo Toolkit is used for continued pretraining, this step allows to configure different parameters for the pretraining depending on the set up. For example `trainer.devices` `model.tensor_model_parallel_size` depend on the number of GPUs available for this job. \n", + "\n", + "Additionally, specify the path to the custom trained tokenizer for `model.tokenizer.model`, the `.nemo` checkpoint for `model.restore_from_path`. \n", + "\n", + "The `model.data.data_prefix` is specified in the form [weightage to data, datafile] Example `[1,preprocessed_data_text_document]` assigns the whole weightage [=1] to `preprocessed_data_text_document`. If there are multiple files, different weightage (should sum to 1) can be assigned to each file to control the data blend for pretraining. \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a40f547", + "metadata": {}, + "outputs": [], + "source": [ + "# Test out the pretraining set up with mock data: model.data.data_impl=mock\n", + "\n", + "!python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \\\n", + " --config-path=/opt/NeMo/examples/nlp/language_modeling/conf \\\n", + " --config-name=megatron_llama_config \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=1 \\\n", + " trainer.num_nodes=1 \\\n", + " trainer.max_steps=2 \\\n", + " trainer.val_check_interval=8 \\\n", + " model.data.data_impl=mock \\\n", + " model.micro_batch_size=1 \\\n", + " model.global_batch_size=4 \\\n", + " model.tensor_model_parallel_size=1 \\\n", + " model.pipeline_model_parallel_size=1 \\\n", + " model.tokenizer.library=sentencepiece \\\n", + " model.tokenizer.model=/workspace/Llama-2-7b-hf/tokenizer.model \\\n", + " +model.restore_from_path=/workspace/llama2-7b.nemo \\\n", + " exp_manager.name=megatron_llama_continual \\\n", + " exp_manager.resume_ignore_no_checkpoint=false \\\n", + " exp_manager.resume_if_exists=false " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71672ff4", + "metadata": {}, + "outputs": [], + "source": [ + "# Pretraining using preprocessed data (+model.data.data_prefix)\n", + "\n", + "!python /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \\\n", + " --config-path=/opt/NeMo/examples/nlp/language_modeling/conf \\\n", + " --config-name=megatron_llama_config \\\n", + " trainer.precision=bf16 \\\n", + " trainer.devices=2 \\\n", + " trainer.num_nodes=1 \\\n", + " trainer.max_steps=5 \\\n", + " trainer.val_check_interval=8 \\\n", + " model.micro_batch_size=1 \\\n", + " model.global_batch_size=4 \\\n", + " model.tensor_model_parallel_size=2 \\\n", + " model.pipeline_model_parallel_size=1 \\\n", + " model.tokenizer.library=sentencepiece \\\n", + " model.tokenizer.model=/workspace/Llama-2-7b-hf/tokenizer.model \\\n", + " model.megatron_amp_O2=True \\\n", + " 
+model.restore_from_path=/workspace/llama2-7b.nemo \\\n", + " +model.data.data_prefix=[1,preprocessed_data_text_document] \\\n", + " exp_manager.name=megatron_llama_continual \\\n", + " exp_manager.resume_ignore_no_checkpoint=true \\\n", + " exp_manager.resume_if_exists=false " + ] + }, + { + "cell_type": "markdown", + "id": "cf30d8c8", + "metadata": {}, + "source": [ + "### To monitor the training, launch Tensorboard from another terminal\n", + "\n", + "`tensorboard --logdir nemo_experiments --bind_all`" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/domain_adaptive_pretraining_nemo2.0.ipynb b/tutorials/llm/llama/domain-adaptive-pretraining/code/domain_adaptive_pretraining_nemo2.0.ipynb new file mode 100644 index 000000000000..84d3ce6b619d --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/code/domain_adaptive_pretraining_nemo2.0.ipynb @@ -0,0 +1,577 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cd13460c", + "metadata": {}, + "source": [ + "# Domain Adaptive Pre-Training (DAPT)\n", + "\n", + "## Goal\n", + "\n", + "Given a foundational language model (in this case llama-2-7B) that was pre-trained on a broad, general-purpose corpus, our goal is to further pretrain the model on a specific domain (in this example, ChipDesign) to enhance its understanding of domain-specific language and context. This process is called Domain-Adaptive Pretraining (DAPT). DAPT adapts a general-purpose model to specialized tasks within a particular field. Instead of training from scratch, we aim to “specialize” the model by focusing on a target domain corpus, allowing it to adapt to the unique vocabulary, semantics, and syntax of that field.\n", + "\n", + "Our primary goals with respect to DAPT are as follows:\n", + "* Improve the model’s performance and accuracy on domain-specific tasks\n", + "* Ensure the model retains general language capabilities\n", + "* Minimize pretraining time by leveraging existing knowledge in the model\n", + "\n", + "DAPT typically enhances a model’s efficacy in downstream tasks for the domain by exposing it to domain-relevant texts. This pretraining phase can result in more accurate and context-aware predictions on domain-specific data, as the model gains an understanding of field-specific terminology, abbreviations, and common phrases." + ] + }, + { + "cell_type": "markdown", + "id": "c43ef563", + "metadata": {}, + "source": [ + "# NeMo Tools and Resources\n", + "\n", + "* [NeMo Framework](https://docs.nvidia.com/nemo-framework/user-guide/latest/overview.html)" + ] + }, + { + "cell_type": "markdown", + "id": "bea0b51f", + "metadata": {}, + "source": [ + "# Software Requirements\n", + "* Access to latest NeMo Framework NGC Containers\n", + "* This playbook has been tested on: nvcr.io/nvidia/nemo:dev. 
It is expected to work similarly on other environments.\n", + "\n", + "\n", + "#### Launch the NeMo Framework container as follows: \n", + "\n", + "```\n", + "docker run -it -p 8080:8080 -p 8088:8088 --rm --gpus '\"device=0,1\"' --ipc=host --network host -v $(pwd):/workspace nvcr.io/nvidia/nemo:dev\n", + "```\n", + "\n", + "#### Launch Jupyter Notebook as follows: \n", + "```\n", + "jupyter notebook --allow-root --ip 0.0.0.0 --port 8088 --no-browser --NotebookApp.token=''\n", + "\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "id": "7137e1db", + "metadata": {}, + "source": [ + "# Hardware Requirements\n", + "\n", + "* This playbook has been tested on 2xA100 80G but can be scaled to multiple GPUs as well as multiple nodes by modifying the appropriate parameters" + ] + }, + { + "cell_type": "markdown", + "id": "91ecb0d3", + "metadata": {}, + "source": [ + "# Data\n", + "\n", + "* In this playbook, we will leverage chip domain/hardware datasets from open-source GitHub repositories, wiki URLs, and academic papers. Data has been processed and curated using [NeMo Curator](https://github.com/NVIDIA/NeMo-Curator/tree/main) as shown in this [playbook](https://github.com/jvamaraju/ndc_dapt_playbook/tree/dapt_jv)" + ] + }, + { + "cell_type": "markdown", + "id": "ba16a72b", + "metadata": {}, + "source": [ + "# Notebook Outline\n", + "\n", + "* Step 1: Prepare the data for pretraining. This is a multi-step process discussed in detail later in the specific section (later in the notebook).\n", + "\n", + "* Step 2: Download the llama-2-7B hugging face checkpoint and convert to .nemo format.\n", + "\n", + "* Step 3: Continued pretraining the llama-2-7b model using the prepared data and the custom trained tokenizer (from the previous notebook)." + ] + }, + { + "cell_type": "markdown", + "id": "115e8b1f", + "metadata": {}, + "source": [ + "# Step 0: Clone the Model Checkpoint\n", + "\n", + "This notebook assumed the model has been cloned from [hugging face](https://huggingface.co/meta-llama/Llama-2-7b-hf) in the mounted directory ```/workspace```" + ] + }, + { + "cell_type": "markdown", + "id": "9bc658bd", + "metadata": {}, + "source": [ + "Clone the model: \n", + "```\n", + "git lfs install\n", + "git clone https://huggingface.co/meta-llama/Llama-2-7b-hf\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "ec372453", + "metadata": {}, + "source": [ + "# Step 1: Data Preparation for pretraining\n", + "\n", + "Identify the different file types (example: code, text, etc) in the pretraining data, in this case we only have 'code' type files. This is typically dataset dependent. 
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2c935b99", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "\n", + "from nemo.collections.llm import Llama2Config7B\n", + "\n", + "\n", + "# Function to count the number of files in each of the different file types- code, text\n", + "def identify_jsonl_files(data_path):\n", + " code_files = []\n", + " text_files = []\n", + " cnt_text = 0\n", + " cnt_code = 0\n", + " for root, _, files in os.walk(data_path):\n", + " for file in files:\n", + " if file.endswith('.jsonl'):\n", + " file_path = os.path.join(root, file)\n", + " with open(file_path, 'r') as f:\n", + " has_code = False\n", + " has_text = False\n", + " for line in f:\n", + " try:\n", + " json_obj = json.loads(line.strip())\n", + " file_type = json_obj.get('file_type', '').lower()\n", + " if file_type == 'code':\n", + " has_code = True\n", + " elif file_type == 'text':\n", + " has_text = True\n", + " if has_code and has_text:\n", + " break\n", + " except json.JSONDecodeError:\n", + " continue\n", + " if has_code:\n", + " code_files.append(file_path)\n", + " cnt_code = cnt_code + 1\n", + " if has_text:\n", + " text_files.append(file_path)\n", + " cnt_text = cnt_text + 1\n", + " return code_files, text_files, cnt_code, cnt_text\n", + "\n", + "# Modify data path to point to jsonl data source, in this case data_path='code/data/all_jsonl_data'\n", + "data_path = '/workspace/dapt-custom-tokenization/code/data/all_jsonl_data'\n", + "\n", + "code_files, text_files, cnt_code, cnt_text = identify_jsonl_files(data_path)\n", + "\n", + "print(\"\\nNumber of Files containing 'file_type':'text':\", cnt_text)\n", + "print(\"Number of Files containing 'file_type':'code':\", cnt_code)" + ] + }, + { + "cell_type": "markdown", + "id": "60987ff2", + "metadata": {}, + "source": [ + "### Merging code JSONL files into a single JSONL file for further preprocessing" + ] + }, + { + "cell_type": "markdown", + "id": "c02f2e6f", + "metadata": {}, + "source": [ + "This is an optional step, it is possible to use multiple jsonl files in this workflow as well. This example uses a single merged. 
jsonl file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "892f4493", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import json\n", + "\n", + "def list_jsonl_files(directory):\n", + " jsonl_files = []\n", + " for root, _, files in os.walk(directory):\n", + " for file in files:\n", + " if file.endswith('.jsonl'):\n", + " jsonl_files.append(os.path.join(root, file))\n", + " return jsonl_files\n", + "\n", + "# Function to merge multiple jsonl files into a single file \n", + "def merge_jsonl_files(directory, output_file):\n", + " jsonl_files = list_jsonl_files(directory)\n", + " \n", + " with open(output_file, 'w') as outfile:\n", + " for input_file in jsonl_files:\n", + " with open(input_file, 'r') as infile:\n", + " for line in infile:\n", + " try:\n", + " json_object = json.loads(line.strip())\n", + " json.dump(json_object, outfile)\n", + " outfile.write('\\n')\n", + " except json.JSONDecodeError:\n", + " print(f\"Skipping invalid JSON in {input_file}: {line.strip()}\")\n", + "\n", + " print(f\"Merged {len(jsonl_files)} JSONL files into {output_file}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bb0c80a", + "metadata": {}, + "outputs": [], + "source": [ + "directory = '/workspace/dapt-custom-tokenization/code/data/all_jsonl_data'\n", + "output_file = '/workspace/dapt-custom-tokenization/code_merged_output.jsonl'\n", + "merge_jsonl_files(directory, output_file)" + ] + }, + { + "cell_type": "markdown", + "id": "6d00ad63", + "metadata": {}, + "source": [ + "### Data Format Conversion for pretraining: JSONL to bin/idx files \n", + "\n", + "For efficient pretraining, we convert data from JSONL to bin/idx format. \n", + "\n", + "JSONL files, while convenient for storing structured text data, are not optimized for high-speed data loading during large language model training. In pretraining workflows, particularly those with large datasets and complex model architectures, the need for fast data access and efficient memory management is essential.\n", + "\n", + "The bin/idx format is a binary format specifically designed to facilitate high-throughput data loading. This format allows direct, randomized access to data samples, which speeds up I/O operations and reduces the memory footprint compared to loading JSONL files. By converting data to bin/idx format, hardware utilization can be maximized and bottlenecks in data processing can be avoided, leading to a more efficient pretraining process.\n", + "\n", + "#### Benefits of bin/idx format for Pretraining:\n", + "\n", + "* **Optimized I/O Performance:** The binary format enables quicker data reads and reduces latency, allowing the model to continuously access data at high speeds.\n", + "* **Efficient Memory Usage:** Data in bin/idx format consumes less memory during loading, making it suitable for large datasets and enabling better use of available system resources.\n", + "* **Enhanced Scalability:** With bin/idx, it’s easier to handle shuffling and batching of large datasets, which is essential for pretraining on diverse domain-specific data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "709f2c08", + "metadata": {}, + "outputs": [], + "source": [ + "!ls /workspace/dapt-custom-tokenization/code/code/models/tokenizer/llama2/custom_tokenizer_init_20000.json" + ] + }, + { + "cell_type": "markdown", + "id": "de696d7b", + "metadata": {}, + "source": [ + "Modify the `input` to point to the merged `jsonl` file. 
Similarly modify paths to `vocab`, `tokenizer-model`, `merge-file` to point to relevant file paths. `tokenizer-model` should point to the custom tokenizer (trained in the custom tokenizer training notebook) if your data has domain specific terminology" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dcbf66a2", + "metadata": {}, + "outputs": [], + "source": [ + "#### Uncomment to use custom trained tokenizer ####\n", + "# !python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "# --input='/workspace/dapt-custom-tokenization/code_merged_output.jsonl' \\\n", + "# --json-keys=text \\\n", + "# --tokenizer-library=sentencepiece \\\n", + "# --vocab '/workspace/dapt-custom-tokenization/code/code/models/tokenizer/llama2/custom_tokenizer_init_20000.json/vocab.json' \\\n", + "# --dataset-impl mmap \\\n", + "# --tokenizer-model '/workspace/Llama-2-7b-hf/tokenizer.model' \\\n", + "# --tokenizer-type llama \\\n", + "# --merge-file '/workspace/dapt-custom-tokenization/code/code/models/tokenizer/llama2/custom_tokenizer_init_20000.json/merges.txt' \\\n", + "# --append-eod \\\n", + "# --output-prefix='preprocessed_data'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "89b9583d-1dac-4717-b028-c78d0d703f45", + "metadata": {}, + "outputs": [], + "source": [ + "# Using default Llama-2 tokenizer for testing purpose\n", + "!python3 /opt/NeMo/scripts/nlp_language_modeling/preprocess_data_for_megatron.py \\\n", + "--input='/workspace/dapt-custom-tokenization/code_merged_output.jsonl' \\\n", + "--json-keys=text \\\n", + "--tokenizer-library=sentencepiece \\\n", + "--dataset-impl mmap \\\n", + "--tokenizer-model '/workspace/Llama-2-7b-hf/tokenizer.model' \\\n", + "--tokenizer-type llama \\\n", + "--append-eod \\\n", + "--output-prefix='preprocessed_data'" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0f05efa5", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "# If the above step runs successfully, two files with the extensions .bin and .idx will be generated\n", + "!ls " + ] + }, + { + "cell_type": "markdown", + "id": "82f95149", + "metadata": {}, + "source": [ + "# Step 2: Download and Import Llama-2-7b Hugging Face checkpoint\n", + "\n", + "Llama2-7B model can be automatically downloaded and converted to NeMo2 format with the following script:\n", + "\n", + "* Save the following code snippet as ```converttonemo2.py```\n", + "* Run ```python3 converttonemo2.py```" + ] + }, + { + "cell_type": "markdown", + "id": "b3260b62-c179-4bc6-b256-729ff6403fa4", + "metadata": {}, + "source": [ + "```\n", + "from nemo.collections import llm\n", + "from nemo.collections.llm import Llama2Config7B\n", + "\n", + "if __name__ == \"__main__\":\n", + " output = llm.import_ckpt(\n", + " model=llm.LlamaModel(config=Llama2Config7B()),\n", + " source=\"hf:///workspace/Llama-2-7b-hf\",\n", + " )\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "46c7f997", + "metadata": {}, + "outputs": [], + "source": [ + "from nemo.collections import llm\n", + "from nemo.collections.llm import Llama2Config7B\n", + "\n", + "if __name__ == \"__main__\":\n", + " output = llm.import_ckpt(\n", + " model=llm.LlamaModel(config=Llama2Config7B()),\n", + " source=\"hf:///workspace/Llama-2-7b-hf\",\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "b94e774b", + "metadata": {}, + "source": [ + "The conversion will generate a ```llama-2``` NeMo2 checkpoint directory which can be used for the 
continued pretraining using NeMo Toolkit as shown in Step 3 in default ```$NEMO_HOME``` folder, unless otherwise specified ```NEMO_HOME``` is set as ```/root/.cache/nemo```\n", + "\n", + "Alternatively, you can directly use ```source=\"meta-llama/Llama2-2-7b-hf\"``` to use the model directly from Hugging Face instead of using the locally downloaded version in ```\\workspace```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c689e584", + "metadata": {}, + "outputs": [], + "source": [ + "!ls /workspace" + ] + }, + { + "cell_type": "markdown", + "id": "fe1bdfe0", + "metadata": {}, + "source": [ + "# Step 3: Continued Pretraining using Llama2-7b with NeMo2\n", + "\n", + "For this step we use a predefined recipe `llama2_7b.pretrain_recipe` from NeMo Toolkit for continued pretraining. We will modify the `pretrain_recipe` and use it for continued pretraining workflow. Typically this involves changing dataset files and data blends, changing learning rate scheduler, changing default parallelism based on number of devices available, adding connector to resume training, etc.\n", + "\n", + "First, we define the recipe and executor for using NeMo2 as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4a40f547", + "metadata": {}, + "outputs": [], + "source": [ + "import nemo_run as run\n", + "from nemo.collections import llm\n", + "\n", + "# Configure recipe to pre-train based on the default llama-2-7b recipe\n", + "def configure_recipe(nodes: int = 1, gpus_per_node: int = 1):\n", + " recipe = llm.llama2_7b.pretrain_recipe(\n", + " name=\"llama2_7b_dapt\",\n", + " # Modify based on number of nodes available\n", + " num_nodes=nodes,\n", + " num_gpus_per_node=gpus_per_node,\n", + " )\n", + " # Modify\n", + " recipe.trainer.strategy.context_parallel_size = 1\n", + " recipe.trainer.strategy.tensor_model_parallel_size=1\n", + " recipe.trainer.val_check_interval = 100\n", + " return recipe\n", + "\n", + "# Executor for running pretraining \n", + "def local_executor_torchrun(devices: int = 1) -> run.LocalExecutor:\n", + " executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\")\n", + " return executor" + ] + }, + { + "cell_type": "markdown", + "id": "464d303fc973333d", + "metadata": {}, + "source": [ + "Let's instantiate the `recipe` and modify it so that it uses the desired number of GPUs, resuming from the pretrained Llama2-7b checkpoint instead of training from scratch.\n", + "\n", + "The default `recipe` initializes all the essential components required for Llama2 7B pretraining, including model, dataloader, trainer, logger, optimizer etc. `recipe` is not executed during instantiation, so it is very simple to modify it to fit your custom training workflow. In our case, we want to do the DAPT (instead of pretraining from scratch), and all we need to do is to add a `resume` config which points to the Llama2 7B checkpoint.\n", + "\n", + "You can easily change the optimizer, parallelism, data as per your use case. Look at the following example for guidance on how to tweak these parameters. Note: you are only configuring your task at this stage; the underlying code is not executed unless you launch the job using the executor." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b70481ad7579de7a", + "metadata": {}, + "outputs": [], + "source": [ + "import nemo.lightning as nl\n", + "from nemo.collections.common.tokenizers import AutoTokenizer\n", + "\n", + "# Instantiate data\n", + "data = run.Config(\n", + " llm.PreTrainingDataModule,\n", + " # Pass the path to your data here\n", + " paths=['preprocessed_data_text_document'],\n", + " seq_length=4096,\n", + " tokenizer=run.Config(\n", + " AutoTokenizer,\n", + " pretrained_model_name=\"/workspace/Llama-2-7b-hf\",\n", + " ),\n", + " micro_batch_size=1,\n", + " global_batch_size=8,\n", + " )\n", + "\n", + "\n", + "# Instantiate the recipe\n", + "recipe = configure_recipe(nodes=1, gpus_per_node=2)\n", + "\n", + "# Modify resume connector\n", + "resume = run.Config(\n", + " nl.AutoResume,\n", + " restore_config=run.Config(nl.RestoreConfig, path=\"/root/.cache/nemo/models/Llama-2-7b-hf\"),\n", + " )\n", + "recipe.resume = resume\n", + "recipe.data.tokenizer = run.Config(\n", + " AutoTokenizer,\n", + " pretrained_model_name=\"/workspace/Llama-2-7b-hf\"\n", + " )\n", + "\n", + "# (Optional) Modify the TP/PP/CP settings\n", + "recipe.trainer.strategy.tensor_model_parallel_size = 2\n", + "recipe.trainer.strategy.pipeline_model_parallel_size = 1\n", + "recipe.trainer.strategy.context_parallel_size = 1\n", + "\n", + "# (Optional) Modify the batch size settings\n", + "recipe.data.global_batch_size = 8\n", + "recipe.data.micro_batch_size = 1\n", + "\n", + "# (Optional) Modify the checkpoint and log location\n", + "recipe.log.log_dir= \"/workspace/logs_01_31\"\n", + "\n", + "# (Optional) Modify the learning rate scheudler\n", + "recipe.optim.config.lr = 1e-5\n", + "recipe.optim.lr_scheduler.min_lr = 1e-6\n", + "\n", + "# If not configured, the recipe uses mock data for pretraining\n", + "recipe.data = data\n", + "\n", + "# (Optional) Modify the data blends\n", + "# recipe.data.paths = [0.2, 'path/to/data1', 0.1, 'path/to/data2']\n", + "# recipe.data.paths = [1, 'preprocessed_data_text_document']" + ] + }, + { + "cell_type": "markdown", + "id": "303b9f780763d641", + "metadata": {}, + "source": [ + "After configure the training procedure properly, we can run the training by instantiate the `executor` and use `nemorun` to start the training:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c1f8b3071d8ff80", + "metadata": {}, + "outputs": [], + "source": [ + "# Launch the pretraining job \n", + "executor = local_executor_torchrun(devices=recipe.trainer.devices)\n", + "run.run(recipe, executor=executor)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77a82ff2-b15d-48bd-8cea-c3a2503190a8", + "metadata": {}, + "outputs": [], + "source": [ + "!nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "cf30d8c8", + "metadata": {}, + "source": [ + "### To monitor the training, launch Tensorboard from another terminal\n", + "\n", + "`tensorboard --logdir nemo_experiments --bind_all`" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/extend_tokenizer_utils.py 
b/tutorials/llm/llama/domain-adaptive-pretraining/code/extend_tokenizer_utils.py new file mode 100755 index 000000000000..1e6de449e9a4 --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/code/extend_tokenizer_utils.py @@ -0,0 +1,289 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob +import io +import json +import math +import os +import pprint +import random +import re +import sys +from collections import Counter + +import jsonlines +import numpy as np +import sentencepiece as spm +import sentencepiece.sentencepiece_model_pb2 as model +import torch +from datasets import Dataset, IterableDataset, load_dataset +from tokenization_helper import * +from tokenizers import ( + SentencePieceBPETokenizer, + Tokenizer, + decoders, + models, + normalizers, + pre_tokenizers, + processors, + trainers, +) +from transformers import PreTrainedTokenizerFast + + +def get_token_cnt_spm(data_root, tokenizer, batchsize, keys): + """ + Function to get number of tokens generated from a given dataset + + Args: + data_root (str): Path to folder containing data files in jsonl format. + tokenizer (AutoTokenizer): Tokenizer to create tokens from data + batchsize (int): batch size used for the text_iterator that generates of batches of text. + keys (list): Keys/metadata to extract from jsonl files + + Returns: + A new tokenizer of the same type as the original one, trained on data_root + """ + readers = [] + for f in glob.glob(data_root + "**/*.jsonl", recursive=True): + f = open(f, mode="r") + readers.append(jsonlines.Reader(f)) + + def gen(): + data = [] + cnt = 0 + for reader in readers: + for obj in reader: + for key in keys: + data.append(obj[key]) + cnt += 1 + if cnt >= batchsize: + yield data + cnt = 0 + data = [] + if len(data) > 0: + yield data + + ds = IterableDataset.from_generator(gen) + total_cnt = 0 + for d in ds: + ids = tokenizer.encode(d) # for spm model + total_cnt += sum([len(i) for i in ids]) + print("total token cnt", total_cnt) + + +def extend_tokenizer(vocab_size, split, model_type): + """ + Expand the general-purpose tokenizer with the newly identified tokens to get an extended Tokenizer + Args: + vocab_size (int): The target size of the vocabulary you want for domain specific tokenizer. 
+ split (int): Number of splits used for original model weights (model parallelism) + model_type (str): Model type/family + + Returns: + Extended tokenizer is created and saved in the paths specified below + + """ + digit_flag = False + rm_subword_flag = False + unseen_flag = True + init_out_flag = True + newinit_flag = False + + tag = "code_gen" # Tag to identify custom_tokenization per use case + data_root = "./general_data" # path to general datasets collected from open-source domain + original_tokenizer_path = ( + f"./models/tokenizer/{model_type}/original_tokenizer/tokenizer.model" # path to original tokenizer + ) + domain_tok_vocab_path = f"./models/tokenizer/{model_type}/custom_tokenizer_init_{vocab_size}.json/vocab.json" # path to domain specific vocab file (created previously) + + # New model file paths that will be created + new_vocab_path = f"./models/tokenizer/{model_type}/new_tokenizer/" + tag + "_vocab.json" + new_model_path = f"./models/tokenizer/{model_type}/new_tokenizer/tokenizer_" + tag + ".model" + old_ebd_path = f"./models/weight/{model_type}/ori_{model_type}-hf_weight/" + new_ebd_path = f"./models/weight/{model_type}/new_{model_type}-hf_weight/" + + extend_tokenizer_llama( + data_root, + original_tokenizer_path, + domain_tok_vocab_path, + new_vocab_path, + new_model_path, + old_ebd_path, + new_ebd_path, + split, + ) + + print("Vocabulary path for extended tokenizer: ", new_vocab_path) + print("Tokenizer model path for extended tokenizer: ", new_model_path) + print("Modified embedding weights path for extended tokenizer: ", new_ebd_path) + + +def extend_tokenizer_high_freq_tokens( + data_root, + original_tokenizer_path, + new_tokens, + new_vocab_path, + new_model_path, + old_ebd_path=None, + new_ebd_path=None, + split=8, +): + """ + Expand the original llama tokenizer with the newly identified high frequency tokens to get a customized tokenizer + Args: + data_root (str): Path to general/domain specific data to identify tokens and extend tokenizer + original_tokenizer_path (str): Path to original tokenizer (llama 2 tokenizer downlaoded from hf) + new_tokens (List(str)): List of idenitfied high frequency tokens + new_vocab_path (str): Path to new vocabulary file + new_model_path (str): Path to new/customized tokenizer + old_ebd_path (str): Path to original llama2 embedding weights downlaoded from hf + new_ebd_path (str): Path to new embedding weights (modified due to tokenizer changes) + split (int): Number of splits used for original model weights (model parallelism) + + Returns: + New model files created and saved in the paths specified below + + """ + m = model.ModelProto() + m.ParseFromString(open(original_tokenizer_path, 'rb').read()) + ori_vocab_size = len(m.pieces) + + print("token_cnt with original tokenizer: ") + sp = spm.SentencePieceProcessor() + sp.load(original_tokenizer_path) + get_token_cnt_spm(data_root, sp, batchsize=1000, keys=["text"]) + + add_normal_cnt = len(new_tokens) + add_dummy_cnt = (len(new_tokens) // 1024 + 1) * 1024 - len(new_tokens) + total_add_cnt = add_normal_cnt + add_dummy_cnt + new_vocab_size = total_add_cnt + ori_vocab_size + total_cnt = new_vocab_size + 768 ## consider 768 padding vocab in llama/mixtral tokenizer + print("original vocab_size: ", ori_vocab_size) + print("added normal vocab: ", add_normal_cnt) + print("added dummy vocab: ", add_dummy_cnt) + print("new vocab_size: ", new_vocab_size) + print("padded vocab: ", 768) + print("total cnt (with padding vocab): ", total_cnt) + assert add_dummy_cnt >= 3, "there should be at least 
3 extra tokens for finetuning" + + record = [] + N = len(m.pieces) + for i, sym in enumerate(new_tokens): + new_sym = m.SentencePiece() + new_sym.piece = sym + new_sym.score = 0.0 # default score for USER_DEFINED + new_sym.type = 4 # type value for USER_DEFINED + m.pieces.insert(N + i, new_sym) # position after default control symbols ("", "", "") + record.append([sym, N + i]) + + N = len(m.pieces) + for i in range(add_dummy_cnt): + new_sym = m.SentencePiece() + new_sym.piece = f"" + new_sym.score = 0.0 # default score for USER_DEFINED + new_sym.type = 4 # type value for USER_DEFINED + m.pieces.insert(N + i, new_sym) # position after default control symbols ("", "", "") + record.append([new_sym.piece, N + i]) + + with open(new_vocab_path, "w", encoding="utf8") as fp: + json.dump(record, fp) + + with open(new_model_path, 'wb') as f: + f.write(m.SerializeToString()) + + print("token_cnt with customized tokenizer: ") + sp = spm.SentencePieceProcessor() + sp.load(new_model_path) + get_token_cnt_spm(data_root, sp, batchsize=1000, keys=["text"]) + + old_ebd_paths = [] + for f in glob.glob(old_ebd_path + "/*.pt"): + old_ebd_paths.append(f) + + def myFunc(s): + return int(s.split("embedding_")[-1].split(".")[0]) + + old_ebd_paths.sort(key=myFunc) + word_embeddings = [] + output_layers = [] + for f in old_ebd_paths: + temp = torch.load(f) + word_embeddings.append(temp['word_embeddings']) + output_layers.append(temp['output_layer']) + word_embedding = torch.cat(word_embeddings, dim=1) + output_layer = torch.cat(output_layers, dim=0) + print("word_embedding shape: ", word_embedding.shape) + print("output_layer shape: ", output_layer.shape) + + N_ori_emb, N = word_embedding.shape + add_weight = torch.zeros(total_add_cnt, N) + word_embedding = torch.cat((word_embedding[:ori_vocab_size], add_weight, word_embedding[ori_vocab_size:]), 0) + + _, M = output_layer.shape + add_out = torch.zeros(total_add_cnt, M) + output_layer = torch.cat((output_layer[:ori_vocab_size], add_out, output_layer[ori_vocab_size:]), 0) + + sp = spm.SentencePieceProcessor() + sp.load(original_tokenizer_path) + + for r in record: + token = r[0] + idx = r[1] + ids = sp.encode_as_ids(token) + word_embedding[idx] = torch.mean(word_embedding[ids], dim=0) + output_layer[idx] = torch.mean(output_layer[ids], dim=0) + + word_embedding = word_embedding.bfloat16() + output_layer = output_layer.bfloat16() + + vocab_size, dimension = word_embedding.shape + split_dimension = dimension // (split) + split_vocab_size = vocab_size // split + prefix = new_ebd_path + "/embedding_" + for i in range(split): + start = i * split_dimension + end = (i + 1) * split_dimension + st = i * split_vocab_size + ed = (i + 1) * split_vocab_size + save_name = prefix + f"{i}" + ".pt" + temp = {} + temp['word_embeddings'] = word_embedding[:, start:end] # split word_embedding + temp['output_layer'] = output_layer[st:ed, :] # split output_layer + torch.save(temp, save_name) + + print("Completed saving new embeddings") + + +if __name__ == "__main__": + original_tokenizer_path = sys.argv[1] # original sentencepiece model + new_tokens = sys.argv[2] # new tokens to be added + new_model_path = sys.argv[3] # augmented sentencepiece model + old_ebd_path = sys.argv[4] # original embeddings + new_ebd_path = sys.argv[5] # augmented embeddings + new_vocab_path = sys.argv[6] # path to record added new tokens + split = int(sys.argv[7]) # num of partitions to split the augmented embeddings + data_root = sys.argv[8] + + extend_tokenizer_high_freq_tokens( + data_root, + 
original_tokenizer_path, + new_tokens, + new_vocab_path, + new_model_path, + old_ebd_path, + new_ebd_path, + split, + ) diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/get_high_freq_tokens.py b/tutorials/llm/llama/domain-adaptive-pretraining/code/get_high_freq_tokens.py new file mode 100755 index 000000000000..613f72d7eba1 --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/code/get_high_freq_tokens.py @@ -0,0 +1,130 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import sys +from bisect import bisect_left + +import matplotlib.pyplot as plt +import numpy as np + + +# + +def binary_search(arr, low, high, bar=0.98): + total = arr.sum() + target = total * bar + print(arr, target) + if (high - low) >= 2: + mid = (high + low) // 2 + s = arr[0:mid].sum() + if s == target: + return mid + elif s < target: + if arr[0 : mid + 1].sum() >= target: + return mid + 1 + else: + return binary_search(arr, mid + 1, high, bar) + else: + if arr[0 : mid - 1].sum() <= target: + return mid + else: + return binary_search(arr, low, mid - 1, bar) + else: + return low + + +def binary_search2(arr, low, high, bar=0.98): + arr_csum = np.cumsum(arr) + total = arr.sum() + target = int(total * bar) + print(arr, arr_csum, target) + i = bisect_left(arr_csum, target) + if i != len(arr_csum) and arr_csum[i] == target: + return arr[i] + else: + return low + + +# - + + +def get_high_freq_tokens(token_usage_path, high_freq_tokens_path, p_th=0.98): + """ + Function to identify high frequency tokens from previous frequency analysis based on cutoff threshold. Selects the top-K tokens in a way that their cumulative frequency accounts for approximately 98%. 
+ Args: + token_usage_path (str): Path to saved token usage frequency analysis results + high_freq_tokens_path (str): path to save selected high-frequency tokens (new tokens to be added) + p_th (float): Frequency Threshold + Returns: + Saves a file with high frequency tokens + """ + f = open(token_usage_path) + freq_dict = json.load(f) + + topics = [] + p_ths = [] + for key in freq_dict: + topics.append(key) + p_ths.append(p_th) + + tokens = {} + i = 0 + for topic in topics: + print(topic) + freq = freq_dict[topic] + freq_list = freq["new_freq"] + freqs = [] + ids = [] + for term in freq_list: + freqs.append(term[-1]) + ids.append(term[0]) + freqs_np = np.array(freqs) + th = binary_search(freqs_np, freqs_np.min(), freqs_np.max(), bar=p_ths[i]) + print(th) + i += 1 + if th > 0: + tokens[topic] = ids[0:th] + else: + raise ValueError("Threshold value is not greater than 0") + + L = [] + for key in tokens: + L = L + tokens[key] + L = set(L) + + token_category_dict = {} + for key in freq_dict: + temp = freq_dict[key]["new_freq"] + for tok in temp: + ids = tok[0] + name = tok[1] + cate = tok[2] + if ids in token_category_dict: + assert name == token_category_dict[ids][1] + else: + token_category_dict[ids] = [cate, name] + + add_tokens = [] + for i in L: + add_tokens.append(token_category_dict[i][1]) + + with open(high_freq_tokens_path, "w") as outfile: + json.dump(add_tokens, outfile) + + +if __name__ == "__main__": + token_usage_path = sys.argv[1] # token usage frequency + high_freq_tokens_path = sys.argv[2] + freq_threshold = float(sys.argv[3]) + get_high_freq_tokens(token_usage_path, high_freq_tokens_path, freq_threshold) diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/imgs/embedding_table.png b/tutorials/llm/llama/domain-adaptive-pretraining/code/imgs/embedding_table.png new file mode 100644 index 000000000000..2ce90f866c73 Binary files /dev/null and b/tutorials/llm/llama/domain-adaptive-pretraining/code/imgs/embedding_table.png differ diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/imgs/tokenization_diagram.png b/tutorials/llm/llama/domain-adaptive-pretraining/code/imgs/tokenization_diagram.png new file mode 100755 index 000000000000..983afbf1cc6b Binary files /dev/null and b/tutorials/llm/llama/domain-adaptive-pretraining/code/imgs/tokenization_diagram.png differ diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/tokenization_helper.py b/tutorials/llm/llama/domain-adaptive-pretraining/code/tokenization_helper.py new file mode 100755 index 000000000000..3c1d221ba07a --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/code/tokenization_helper.py @@ -0,0 +1,414 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import glob +import io +import json +import os +import random +import re +import sys +from collections import Counter + +import jsonlines +import numpy as np +import sentencepiece as spm +import sentencepiece.sentencepiece_model_pb2 as model +import torch +from datasets import Dataset, IterableDataset, load_dataset +from tokenizers import ( + SentencePieceBPETokenizer, + Tokenizer, + decoders, + models, + normalizers, + pre_tokenizers, + processors, + trainers, +) +from transformers import PreTrainedTokenizerFast + + +def check_parent_directory_exists(directory_path): + parent_directory = os.path.dirname(directory_path) + if not os.path.exists(parent_directory): + raise FileNotFoundError(f"Parent directory '{parent_directory}' does not exist. Please create it.") + else: + print(f"Parent directory '{parent_directory}' exists.") + + +def flatten(l): + return [item for sublist in l for item in sublist] + + +def get_token_cnt(data_root, tokenizer, batchsize, keys): + """ + Function to count the number of tokens generated from a given dataset + + Args: + data_root (str): Path to folder containing data files in jsonl format. + tokenizer (AutoTokenizer): Tokenizer to create tokens from data + batchsize (int): batch size used for the text_iterator that generates batches of text. + keys (list): Keys/metadata to extract from jsonl files + + Returns: + None; prints the total number of tokens the tokenizer produces on data_root + """ + readers = [] + for f in glob.glob(data_root + "**/*.jsonl", recursive=True): + f = open(f, mode="r") + readers.append(jsonlines.Reader(f)) + + def gen(): + data = [] + cnt = 0 + for reader in readers: + for obj in reader: + for key in keys: + data.append(obj[key]) + cnt += 1 + if cnt >= batchsize: + yield data + cnt = 0 + data = [] + if len(data) > 0: + yield data + + ds = IterableDataset.from_generator(gen) + total_cnt = 0 + for d in ds: + ids = tokenizer(d).input_ids # tokenizer.encode(d) + total_cnt += sum([len(i) for i in ids]) + print("total token cnt", total_cnt) + + +def train_tokenizer(data_root, batchsize, vocab_size, tokenizer, keys): + """ + Train a tokenizer from scratch and evaluate the number of tokens both before and after training + Args: + data_root (str): Path to folder containing data files in jsonl format. + batchsize (int): batch size used for the text_iterator that generates batches of text. + vocab_size (int): The target size of the vocabulary you want for your tokenizer.
+ tokenizer (AutoTokenizer): Tokenizer to create tokens from data + keys (list): Keys/metadata to extract from jsonl files + + Returns: + A new tokenizer of the same type as the original one, trained on data_root + + """ + print("Before Training: ") + get_token_cnt(data_root, tokenizer, batchsize, keys) + + def gen(): + data = [] + cnt = 0 + for f in glob.glob(data_root + "*.jsonl", recursive=True): + f = open(f, mode="r") + reader = jsonlines.Reader(f) + for obj in reader: + for key in keys: + data.append(obj[key]) + cnt += 1 + if cnt >= batchsize: + yield data + cnt = 0 + data = [] + f.close() + if len(data) > 0: + yield data + + ds = IterableDataset.from_generator(gen) + tokenizer = tokenizer.train_new_from_iterator(ds, vocab_size) + print("After Training: ") + get_token_cnt(data_root, tokenizer, batchsize, keys) + return tokenizer + + +def extend_tokenizer_llama( + data_root, + original_tokenizer_path, + domain_tok_vocab_path, + new_vocab_path, + new_model_path, + old_ebd_path, + new_ebd_path, + split=1, +): + """ + Expand the general-purpose llama tokenizer with the newly identified tokens to get an extended tokenizer + Args: + data_root (str): Path to general/domain specific data to identify tokens and extend tokenizer + original_tokenizer_path (str): Path to original tokenizer (llama 2 tokenizer downloaded from hf) + domain_tok_vocab_path (str): Path to domain specific vocab file (created from training a tokenizer from scratch) + new_vocab_path (str): Path to new vocabulary file + new_model_path (str): Path to new/extended tokenizer + old_ebd_path (str): Path to original llama2 embedding weights downloaded from hf + new_ebd_path (str): Path to new embedding weights (modified due to tokenizer changes) + split (int): Number of splits used for original model weights (model parallelism) + + Returns: + Extended/new model files are created and saved at the paths passed in as arguments + + """ + keys = ["text"] + occur_limit = 3 + + token_pattern = '[a-zA-Z]' # or [a-zA-Z0-9] + + # Read data from data path and store + readers = [] + for f in glob.glob(data_root + "**/*.jsonl", recursive=True): + f = open(f, mode="r") + readers.append(jsonlines.Reader(f)) + data = [] + for reader in readers: + for obj in reader: + for key in keys: + if key in obj: + data.append(" " + obj[key]) + + # Read domain specific vocab file and analyze added tokens + f = open(domain_tok_vocab_path) + vocab = json.load(f) + print("Domain vocab size:", len(vocab)) + + tokens = [] + drop_tokens = [] + print("token pattern: ", token_pattern) + for v in vocab: + if re.search(token_pattern, v): + tokens.append(v.replace("Ġ", "▁")) + else: + drop_tokens.append(v) + print("Num of added tokens and dropped tokens", len(tokens), len(drop_tokens)) + + m = model.ModelProto() + m.ParseFromString(open(original_tokenizer_path, 'rb').read()) + print(f'Original model pieces: {len(m.pieces)}') + print(m.trainer_spec) + ori_vol = [] + for piece in m.pieces: + ori_vol.append(piece.piece) + print("original vocab size: ", len(ori_vol)) + ori_vol = set(ori_vol) + data = set(data) + + new_tokens = [] + for token in tokens: + if token not in ori_vol: + token1 = token.replace("▁", " ") + occur_cnt = 0 + flag = True + for s in data: + if token1 in s: + occur_cnt += 1 + if occur_cnt > occur_limit: + flag = False + break + if flag: + new_tokens.append(token) + print("new token cnt: ", len(new_tokens)) + + normal_cnt = len(new_tokens) + dummy_cnt = (len(new_tokens) // 1024 + 1) * 1024 - len(new_tokens) + add_cnt = normal_cnt + dummy_cnt + print("add token
cnt: ", add_cnt) + print("add normal token cnt: ", normal_cnt) + print("add dummy token cnt: ", dummy_cnt) + assert dummy_cnt >= 3, "should be at least 3 extra tokens for finetuning" + + dummy_tokens = [] + for i in range(dummy_cnt): + dummy_tokens.append(f"<dummy_token_{i}>") # unique placeholder pieces used only to pad the vocabulary to a multiple of 1024 + + record = [] + N = len(m.pieces) + for i, sym in enumerate(new_tokens): + new_sym = m.SentencePiece() + new_sym.piece = sym + new_sym.score = 0.0 # default score for USER_DEFINED + new_sym.type = 4 # type value for USER_DEFINED + m.pieces.insert(N + i, new_sym) # position after default control symbols ("<unk>", "<s>", "</s>") + record.append([sym, N + i]) + + N = len(m.pieces) + for i, sym in enumerate(dummy_tokens): + new_sym = m.SentencePiece() + new_sym.piece = sym + new_sym.score = 0.0 # default score for USER_DEFINED + new_sym.type = 4 # type value for USER_DEFINED + m.pieces.insert(N + i, new_sym) # position after default control symbols ("<unk>", "<s>", "</s>") + record.append([sym, N + i]) + + print(f'New model pieces: {len(m.pieces)}') + print(m.trainer_spec) + + check_parent_directory_exists(new_vocab_path) + with open(new_vocab_path, "w", encoding="utf8") as fp: + json.dump(record, fp) + + check_parent_directory_exists(new_model_path) + with open(new_model_path, 'wb') as f: + f.write(m.SerializeToString()) + + if split > 1: + old_ebd_paths = [] + for f in glob.glob(old_ebd_path + "/*.pt"): + old_ebd_paths.append(f) + + def myFunc(s): + return int(s.split("embedding_")[-1].split(".")[0]) ### embedding_0.pt + + old_ebd_paths.sort(key=myFunc) + word_embeddings = [] + output_layers = [] + for f in old_ebd_paths: + temp = torch.load(f) + word_embeddings.append(temp['word_embeddings']) + output_layers.append(temp['output_layer']) + word_embedding = torch.cat(word_embeddings, dim=1) + output_layer = torch.cat(output_layers, dim=0) + print("word_embedding shape: ", word_embedding.shape) + print("output_layer shape: ", output_layer.shape) + + _, N = word_embedding.shape + add_weight = torch.zeros(add_cnt, N) + word_embedding = torch.cat((word_embedding, add_weight), 0) + else: + old_ebd = torch.load(old_ebd_path) + _, N = old_ebd['word_embeddings'].shape + add_weight = torch.zeros(add_cnt, N) + old_ebd['word_embeddings'] = torch.cat((old_ebd['word_embeddings'], add_weight), 0) + + if split > 1: + _, M = output_layer.shape + add_out = torch.zeros(add_cnt, M) + output_layer = torch.cat((output_layer, add_out), 0) + else: + _, M = old_ebd['output_layer'].shape + add_out = torch.zeros(add_cnt, M) + old_ebd['output_layer'] = torch.cat((old_ebd['output_layer'], add_out), 0) + + sp = spm.SentencePieceProcessor() + sp.load(original_tokenizer_path) + + for r in record: + token = r[0] + idx = r[1] + ids = sp.encode_as_ids(token) + if split > 1: + word_embedding[idx] = torch.mean(word_embedding[ids], dim=0) + output_layer[idx] = torch.mean(output_layer[ids], dim=0) + else: + old_ebd['word_embeddings'][idx] = torch.mean(old_ebd['word_embeddings'][ids], dim=0) + old_ebd['output_layer'][idx] = torch.mean(old_ebd['output_layer'][ids], dim=0) + + if split > 1: + vocab_size, dimension = word_embedding.shape + split_dimension = dimension // (split) + split_vocab_size = vocab_size // split + prefix = new_ebd_path + "/embedding_" + for i in range(split): + start = i * split_dimension + end = (i + 1) * split_dimension + st = i * split_vocab_size + ed = (i + 1) * split_vocab_size + save_name = prefix + f"{i}" + ".pt" + temp = {} + temp['word_embeddings'] = word_embedding[:, start:end] # split word_embedding + temp['output_layer'] = output_layer[st:ed, :] # split
output_layer + check_parent_directory_exists(save_name) + torch.save(temp, save_name) + else: + torch.save(old_ebd, new_ebd_path + str(len(m.pieces)) + ".pt") + + print("Completed saving new embeddings") + + +def analyze_token_usage(data_root, tokenizer_path, batchsize, keys, save_path): + """ + Function to analyze domain tokens using frequency analysis + Args: + data_root (str): Path to general/domain specific data to identify tokens + tokenizer_path (str): Path to original tokenizer (llama 2 tokenizer downlaoded from hf) + batchsize (int): batch size used for the text_iterator that generates of batches of text. + keys (list): Keys/metadata to extract from jsonl files + save_path (str): path to save token usage frequency analysis results + + Returns: + None, saves frequency analysis results to the provided path + + """ + extra_id = 32000 + sp = spm.SentencePieceProcessor() + sp.load(tokenizer_path) + + vocab_size = sp.get_piece_size() + print("vocab_size: ", vocab_size) + results = {} + + for name in glob.glob(data_root + "**/*.jsonl", recursive=True): + readers = [] + f = open(name, mode="r") + readers.append(jsonlines.Reader(f)) + + def gen(): + data = [] + cnt = 0 + for reader in readers: + for obj in reader: + for key in keys: + data.append(obj[key]) + cnt += 1 + if cnt >= batchsize: + yield data + cnt = 0 + data = [] + if len(data) > 0: + yield data + + ds = IterableDataset.from_generator(gen) + cnt_np = np.zeros(vocab_size) + for d in ds: + ids = sp.encode(d) + ids = flatten(ids) + counts = Counter(ids) + for key in counts: + cnt_np[key] += counts[key] + ori_cnt = cnt_np[0:extra_id].sum() + new_cnt = cnt_np[extra_id:].sum() + total_cnt = ori_cnt + new_cnt + print("ori cnt and new cnt: ", ori_cnt, new_cnt) + indices = np.flip(cnt_np.ravel().argsort()[-vocab_size:]) + flag = indices >= extra_id + cnts = cnt_np[indices] + old_freq = [] + new_freq = [] + for i in range(len(indices)): + if cnts[i] < 1: + break + id = indices[i] + if flag[i]: + new_freq.append([int(id), str(sp.id_to_piece(int(id))), int(flag[i]), int(cnts[i])]) + else: + old_freq.append([int(id), str(sp.id_to_piece(int(id))), int(flag[i]), int(cnts[i])]) + results[name] = {} + results[name]["ori_cnt"] = [int(ori_cnt), float(ori_cnt / total_cnt)] + results[name]["new_cnt"] = [int(new_cnt), float(new_cnt / total_cnt)] + results[name]["old_freq"] = old_freq + results[name]["new_freq"] = new_freq + f.close() + + with open(save_path, "w") as outfile: + json.dump(results, outfile) diff --git a/tutorials/llm/llama/domain-adaptive-pretraining/code/util.py b/tutorials/llm/llama/domain-adaptive-pretraining/code/util.py new file mode 100755 index 000000000000..a0fece98cfce --- /dev/null +++ b/tutorials/llm/llama/domain-adaptive-pretraining/code/util.py @@ -0,0 +1,65 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
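Taken together, the helpers in `tokenization_helper.py` are meant to be called in sequence. The sketch below shows one possible driver under stated assumptions: every path, the batch size, and the vocabulary size are placeholders, and the Hugging Face model ID is only an assumption about where the base Llama 2 tokenizer comes from.

```python
# Hypothetical driver for tokenization_helper.py; adjust paths to your setup.
from transformers import AutoTokenizer

from tokenization_helper import analyze_token_usage, extend_tokenizer_llama, train_tokenizer

data_root = "./data/"   # curated domain JSONL files
keys = ["text"]         # JSONL fields to tokenize
batchsize = 1000

# 1) Train a domain tokenizer from scratch to surface candidate domain tokens.
base_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
domain_tok = train_tokenizer(data_root, batchsize, 32000, base_tok, keys)
domain_tok.save_pretrained("./models/domain_tokenizer")

# 2) Extend the original Llama 2 sentencepiece model with the new tokens and
#    mean-initialize their embedding and output-layer rows from existing pieces.
extend_tokenizer_llama(
    data_root,
    "./models/llama2-7b/tokenizer.model",    # original sentencepiece model
    "./models/domain_tokenizer/vocab.json",  # domain vocab file from step 1
    "./models/extended/new_vocab.json",
    "./models/extended/tokenizer.model",
    "./models/embeddings_old",               # produced by util.load_weights
    "./models/embeddings_new",
    split=8,
)

# 3) Check how often the added tokens are actually used on the domain data.
analyze_token_usage(data_root, "./models/extended/tokenizer.model", batchsize, keys, "./token_usage.json")
```

The `token_usage.json` written in step 3 is the input expected by `get_high_freq_tokens.py` above.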
+ +import os + +import torch + + +def check_directory_exists(directory): + if os.path.isdir(directory): + print(f"Directory '{directory}' exists") + else: + raise FileNotFoundError(f"The directory '{directory}' does not exist. Please create it.") + + +def load_weights(load_path, save_path): + """ + This function loads llama2 weights (hugging face) and converts it to a Dict format suitable for NeMo + + Args: + load_path (str): Path to llama2 weights downlaoded from hugging face + save_path (str): Path to save modified dictionary containing the weights. + + Returns: + None + + """ + model_type = "llama2" + for i in range(8): + state_dict = torch.load(f"{load_path}/consolidated.0{i}.pth") + batch_dict = {} + if model_type == "llama2": + batch_dict['word_embeddings'] = state_dict['tok_embeddings.weight'] + batch_dict['output_layer'] = state_dict['output.weight'] + else: + batch_dict['word_embeddings'] = state_dict['model']['embedding.word_embeddings.weight'] # embedding layer + batch_dict['output_layer'] = state_dict['model']['output_layer.weight'] # output layer + torch.save(batch_dict, f'{save_path}/embedding_{i}.pt') + + +def merge_embed(old_embd_path, new_embd_path, save_path): + "Function to merge embeddings and convert back to hugging face format" + model_type = "llama2" + for i in range(8): + state_dict = torch.load(f"{old_embd_path}/consolidated.0{i}.pth") + batch_dict = torch.load(f'{new_embd_path}/embedding_{i}.pt') + if model_type == "llama2": + state_dict['output.weight'] = batch_dict['output_layer'] + state_dict['tok_embeddings.weight'] = batch_dict['word_embeddings'] + else: + state_dict['tok_embeddings.weight'] = batch_dict['word_embeddings'] + state_dict['output.weight'] = batch_dict['output_layer'] + check_directory_exists(save_path) + torch.save(state_dict, f"{save_path}/consolidated.0{i}.pth") diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/README.rst b/tutorials/llm/llama/nemo2-sft-peft/README.rst similarity index 100% rename from tutorials/llm/llama-3/nemo2-sft-peft/README.rst rename to tutorials/llm/llama/nemo2-sft-peft/README.rst diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb b/tutorials/llm/llama/nemo2-sft-peft/nemo2-peft.ipynb similarity index 100% rename from tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb rename to tutorials/llm/llama/nemo2-sft-peft/nemo2-peft.ipynb diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb b/tutorials/llm/llama/nemo2-sft-peft/nemo2-sft.ipynb similarity index 100% rename from tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb rename to tutorials/llm/llama/nemo2-sft-peft/nemo2-sft.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb b/tutorials/llm/llama/pruning-distillation/01_data_preparation.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb rename to tutorials/llm/llama/pruning-distillation/01_data_preparation.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb b/tutorials/llm/llama/pruning-distillation/02_teacher_finetuning.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb rename to tutorials/llm/llama/pruning-distillation/02_teacher_finetuning.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb b/tutorials/llm/llama/pruning-distillation/03_a_depth_pruning.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb rename 
to tutorials/llm/llama/pruning-distillation/03_a_depth_pruning.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb b/tutorials/llm/llama/pruning-distillation/03_b_width_pruning.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb rename to tutorials/llm/llama/pruning-distillation/03_b_width_pruning.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb b/tutorials/llm/llama/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb rename to tutorials/llm/llama/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb b/tutorials/llm/llama/pruning-distillation/04_b_distilling_width_pruned_student.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb rename to tutorials/llm/llama/pruning-distillation/04_b_distilling_width_pruned_student.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb b/tutorials/llm/llama/pruning-distillation/05_display_results.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb rename to tutorials/llm/llama/pruning-distillation/05_display_results.ipynb diff --git a/tutorials/llm/llama-3/pruning-distillation/README.rst b/tutorials/llm/llama/pruning-distillation/README.rst similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/README.rst rename to tutorials/llm/llama/pruning-distillation/README.rst diff --git a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb b/tutorials/llm/llama/pruning-distillation/introduction.ipynb similarity index 100% rename from tutorials/llm/llama-3/pruning-distillation/introduction.ipynb rename to tutorials/llm/llama/pruning-distillation/introduction.ipynb diff --git a/tutorials/llm/llama-3/sdg-law-title-generation/README.rst b/tutorials/llm/llama/sdg-law-title-generation/README.rst similarity index 100% rename from tutorials/llm/llama-3/sdg-law-title-generation/README.rst rename to tutorials/llm/llama/sdg-law-title-generation/README.rst diff --git a/tutorials/llm/llama-3/sdg-law-title-generation/img/e2e-lora-train-and-deploy.png b/tutorials/llm/llama/sdg-law-title-generation/img/e2e-lora-train-and-deploy.png similarity index 100% rename from tutorials/llm/llama-3/sdg-law-title-generation/img/e2e-lora-train-and-deploy.png rename to tutorials/llm/llama/sdg-law-title-generation/img/e2e-lora-train-and-deploy.png diff --git a/tutorials/llm/llama-3/sdg-law-title-generation/llama3-sdg-lora-deploy-nim.ipynb b/tutorials/llm/llama/sdg-law-title-generation/llama3-sdg-lora-deploy-nim.ipynb similarity index 100% rename from tutorials/llm/llama-3/sdg-law-title-generation/llama3-sdg-lora-deploy-nim.ipynb rename to tutorials/llm/llama/sdg-law-title-generation/llama3-sdg-lora-deploy-nim.ipynb diff --git a/tutorials/llm/llama-3/sdg-law-title-generation/llama3-sdg-lora-nemofw.ipynb b/tutorials/llm/llama/sdg-law-title-generation/llama3-sdg-lora-nemofw.ipynb similarity index 100% rename from tutorials/llm/llama-3/sdg-law-title-generation/llama3-sdg-lora-nemofw.ipynb rename to tutorials/llm/llama/sdg-law-title-generation/llama3-sdg-lora-nemofw.ipynb diff --git a/tutorials/llm/llama-3/slimpajama/README.md 
b/tutorials/llm/llama/slimpajama/README.md similarity index 100% rename from tutorials/llm/llama-3/slimpajama/README.md rename to tutorials/llm/llama/slimpajama/README.md diff --git a/tutorials/llm/llama-3/slimpajama/data/concat.sh b/tutorials/llm/llama/slimpajama/data/concat.sh similarity index 100% rename from tutorials/llm/llama-3/slimpajama/data/concat.sh rename to tutorials/llm/llama/slimpajama/data/concat.sh diff --git a/tutorials/llm/llama-3/slimpajama/data/download.py b/tutorials/llm/llama/slimpajama/data/download.py similarity index 100% rename from tutorials/llm/llama-3/slimpajama/data/download.py rename to tutorials/llm/llama/slimpajama/data/download.py diff --git a/tutorials/llm/llama-3/slimpajama/data/extract.py b/tutorials/llm/llama/slimpajama/data/extract.py similarity index 100% rename from tutorials/llm/llama-3/slimpajama/data/extract.py rename to tutorials/llm/llama/slimpajama/data/extract.py diff --git a/tutorials/llm/llama-3/slimpajama/data/preprocess.py b/tutorials/llm/llama/slimpajama/data/preprocess.py similarity index 100% rename from tutorials/llm/llama-3/slimpajama/data/preprocess.py rename to tutorials/llm/llama/slimpajama/data/preprocess.py diff --git a/tutorials/llm/llama-3/slimpajama/data_pipeline.ipynb b/tutorials/llm/llama/slimpajama/data_pipeline.ipynb similarity index 100% rename from tutorials/llm/llama-3/slimpajama/data_pipeline.ipynb rename to tutorials/llm/llama/slimpajama/data_pipeline.ipynb diff --git a/tutorials/llm/llama-3/slimpajama/data_pipeline.py b/tutorials/llm/llama/slimpajama/data_pipeline.py similarity index 100% rename from tutorials/llm/llama-3/slimpajama/data_pipeline.py rename to tutorials/llm/llama/slimpajama/data_pipeline.py diff --git a/tutorials/llm/llama-3/slimpajama/pretraining.ipynb b/tutorials/llm/llama/slimpajama/pretraining.ipynb similarity index 100% rename from tutorials/llm/llama-3/slimpajama/pretraining.ipynb rename to tutorials/llm/llama/slimpajama/pretraining.ipynb
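As a closing note on the new `util.py` helpers: they bracket the tokenizer-extension step, with `load_weights` pulling the embedding and output-layer weights out of a sharded Meta-format Llama 2 checkpoint and `merge_embed` writing the extended embeddings back into that layout. The sketch below is illustrative only; the directory names are placeholders and assume a checkpoint shipped as `consolidated.0*.pth` shards.

```python
# Hypothetical usage of util.py; directory names are placeholders.
from util import load_weights, merge_embed

# Extract per-shard embedding_{i}.pt files (word embeddings + output layer)
# from the original Meta-format Llama 2 checkpoint.
load_weights("./models/llama2-meta", "./models/embeddings_old")

# ... run extend_tokenizer_llama(...) here to produce ./models/embeddings_new ...

# Merge the extended embeddings back into the original checkpoint layout.
merge_embed("./models/llama2-meta", "./models/embeddings_new", "./models/llama2-meta-extended")
```

Note that `merge_embed` expects the output directory to already exist; it raises `FileNotFoundError` otherwise.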