From 250cb702e2f880a05e6e6b778b8f1f6bbd2dd451 Mon Sep 17 00:00:00 2001 From: Yangyang Li Date: Tue, 15 Oct 2024 10:56:48 -0500 Subject: [PATCH] update notebook --- .gitignore | 8 +- notebooks/.gitkeep | 0 notebooks/Introduction.ipynb | 33 - notebooks/model.ipynb | 1213 - notebooks/protype.ipynb | 2151 - notebooks/smooth2.ipynb | 191061 -------------------------------- 6 files changed, 3 insertions(+), 194463 deletions(-) delete mode 100644 notebooks/.gitkeep delete mode 100644 notebooks/Introduction.ipynb delete mode 100644 notebooks/model.ipynb delete mode 100644 notebooks/protype.ipynb delete mode 100644 notebooks/smooth2.ipynb diff --git a/.gitignore b/.gitignore index db6d490..44f67c0 100644 --- a/.gitignore +++ b/.gitignore @@ -580,10 +580,7 @@ hyena_model_use_qual hyena_model_train train.slurm deepchopper_train* -notebooks/dc* -notebooks/cdc* -notebooks/ont* -notebooks/sg* +notebooks/* test_predict* cnn.onnx tmp @@ -594,4 +591,5 @@ scripts/*.json !build.rs .ruff_cache hg_deepchopper -analysis_data \ No newline at end of file +analysis_data + diff --git a/notebooks/.gitkeep b/notebooks/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/notebooks/Introduction.ipynb b/notebooks/Introduction.ipynb deleted file mode 100644 index d2b8454..0000000 --- a/notebooks/Introduction.ipynb +++ /dev/null @@ -1,33 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0468d1f9-45ec-4931-9067-fc1c7dd6dbb4", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.13" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/model.ipynb b/notebooks/model.ipynb deleted file mode 100644 index 58dcb1d..0000000 --- a/notebooks/model.ipynb +++ /dev/null @@ -1,1213 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "0", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from pathlib import Path\n", - "\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "import seaborn as sns\n", - "import torch\n", - "from datasets import Dataset, load_dataset\n", - "from IPython.core.interactiveshell import InteractiveShell\n", - "\n", - "import deepchopper\n", - "\n", - "InteractiveShell.ast_node_interactivity = \"all\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1", - "metadata": {}, - "outputs": [], - "source": [ - "from rich.console import Console\n", - "from rich.text import Text\n", - "\n", - "\n", - "def highlight_target(seq: str, start: int, end: int, style=\"bold magenta\"):\n", - " text = Text(seq)\n", - " console = Console()\n", - " text.stylize(style, start, end)\n", - " console.print(text)\n", - "\n", - "\n", - "def hightlight_predict(\n", - " seq: str, target_start: int, target_end: int, predict_start: int, predict_end: int\n", - "):\n", - " text = Text(seq)\n", - " console = Console()\n", - "\n", - " text.stylize(\"#adb0b1\", target_start, target_end)\n", - " text.stylize(\"bold magenta\", predict_start, predict_end)\n", - "\n", - " console.print(text)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import platform\n", - "\n", - "print(f\"{platform.system()=}\")\n", - "if platform.system() == \"Linux\":\n", - " root_dir = Path(\"/projects/b1171/ylk4626/project/DeepChopper\")\n", - "else:\n", - " root_dir = Path(\"/Users/ylk4626/ClionProjects/DeepChopper\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "train_file = root_dir / \"tests/data/test_input.parquet\"\n", - "data_files = {\"train\": train_file.as_posix()}\n", - "\n", - "num_proc = 8\n", - "train_dataset = load_dataset(\n", - " \"parquet\",\n", - " data_files=data_files,\n", - " num_proc=num_proc,\n", - " split=\"train[:80%]\",\n", - ").with_format(\"torch\")\n", - "val_dataset = load_dataset(\n", - " \"parquet\", data_files=data_files, num_proc=num_proc, split=\"train[80%:90%]\"\n", - ").with_format(\"torch\")\n", - "test_dataset = load_dataset(\n", - " \"parquet\", data_files=data_files, num_proc=num_proc, split=\"train[90%:]\"\n", - ").with_format(\"torch\")\n", - "\n", - "print(f\"train_dataset: {train_dataset}\")\n", - "print(f\"val_dataset: {val_dataset}\")\n", - "print(f\"test_dataset: {test_dataset}\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4", - "metadata": {}, - "outputs": [], - "source": [ - "train_dataset.features" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5", - "metadata": {}, - "outputs": [], - "source": [ - "import pandas as pd\n", - "\n", - "\n", - "def show_example_for_dataset(dataset, split=None, first_examples: int = 10):\n", - " if split is not None:\n", - " id = dataset[split][\"id\"][0:first_examples]\n", - " seq = dataset[split][\"seq\"][0:first_examples]\n", - " qual = dataset[split][\"qual\"][0:first_examples]\n", - " target = dataset[split][\"target\"][0:first_examples]\n", - " else:\n", - " id = dataset[\"id\"][0:first_examples]\n", - " seq = dataset[\"seq\"][0:first_examples]\n", - " qual = dataset[\"qual\"][0:first_examples]\n", - " target = dataset[\"target\"][0:first_examples]\n", - "\n", - " qual = [i.tolist() for i in qual]\n", - " target = [i.tolist() for i in target]\n", - " df = pd.DataFrame({\"id\": id, \"seq\": seq, \"qual\": qual, \"target\": target})\n", - " return df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6", - "metadata": {}, - "outputs": [], - "source": [ - "show_example_for_dataset(train_dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8", - "metadata": {}, - "outputs": [], - "source": [ - "highlight_target(seq, *target)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9", - "metadata": {}, - "outputs": [], - "source": [ - "hightlight_predict(seq, *target, 1070, 1120)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "hightlight_predict(seq, *target, 1060, 1120)" - ] - }, - { - "cell_type": "markdown", - "id": "11", - "metadata": {}, - "source": [ - "# 1. Read Len of Direct RNA" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "12", - "metadata": {}, - "outputs": [], - "source": [ - "def vis_bam_record_len():\n", - " direc_rna_samples = [\"22Rv1\", \"DU145\", \"LNCaP\", \"LuCaP\", \"PC3\", \"VCaP\"]\n", - " data = [np.load(root_dir / f\"data/direct_rna/{p}.npy\") for p in direc_rna_samples]\n", - " # plt.rc('font', family='Times New Roman')\n", - "\n", - " fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(10, 6))\n", - "\n", - " flat_axs = axs.flatten()\n", - "\n", - " for i, sample in enumerate(range(len(direc_rna_samples))):\n", - " # Create the density plot\n", - " sns.kdeplot(data[i], fill=True, ax=flat_axs[i])\n", - " flat_axs[i].set_title(f\"Sample {sample}\")\n", - "\n", - " # _ = ax1.set_xlabel('Threshold', fontsize=20)\n", - " # _ = ax1.set_ylabel('Length of itemsets', fontsize=20)\n", - "\n", - " # ax1.legend(['Sliding window average'],fontsize=18,loc='lower left',edgecolor='k',fancybox=True)\n", - "\n", - " # ax1.tick_params(axis='y', labelsize=15)\n", - " # ax1.tick_params(axis='x', labelsize=15\n", - " fig.set_size_inches(20, 20)\n", - "\n", - " # Adding labels and title\n", - " plt.title(\"Read Length of Direc RNA\")\n", - " plt.xticks(rotation=30)\n", - "\n", - " return data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "13", - "metadata": {}, - "outputs": [], - "source": [ - "vis_bam_record_len()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "14", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "15", - "metadata": {}, - "outputs": [], - "source": [ - "data = vis_bam_record_len(root_dir / f\"data/direct_rna/{direc_rna_samples[0]}.npy\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16", - "metadata": {}, - "outputs": [], - "source": [ - "max(data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "17", - "metadata": {}, - "outputs": [], - "source": [ - "d2 = list(data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "18", - "metadata": {}, - "outputs": [], - "source": [ - "d2.remove(103380)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "19", - "metadata": {}, - "outputs": [], - "source": [ - "max(d2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20", - "metadata": {}, - "outputs": [], - "source": [ - "sns.kdeplot(d2, fill=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "21", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "22", - "metadata": {}, - "outputs": [], - "source": [ - "data.sort()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "23", - "metadata": {}, - "outputs": [], - "source": [ - "sns.kdeplot(data[:-800], fill=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "24", - "metadata": {}, - "outputs": [], - "source": [ - "des = pd.Series(data).describe()" - ] - }, - { - "cell_type": "markdown", - "id": "25", - "metadata": {}, - "source": [ - "# 2. Build Model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "26", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "from transformers import (\n", - " AutoConfig,\n", - " AutoModelForSequenceClassification,\n", - " AutoTokenizer,\n", - " Trainer,\n", - " TrainingArguments,\n", - " logging,\n", - ")\n", - "\n", - "\n", - "def load_tokenizer_from_hyena_model(model_name):\n", - " max_lengths = {\n", - " \"hyenadna-tiny-1k-seqlen\": 1024,\n", - " \"hyenadna-small-32k-seqlen\": 32768,\n", - " \"hyenadna-medium-160k-seqlen\": 160000,\n", - " \"hyenadna-medium-450k-seqlen\": 450000, # T4 up to here\n", - " \"hyenadna-large-1m-seqlen\": 1_000_000, # only A100 (paid tier)\n", - " }\n", - "\n", - " if model_name not in max_lengths:\n", - " msg = f\"Model name {model_name} not found in available models.\"\n", - " raise ValueError(msg)\n", - "\n", - " max_length = max_lengths[model_name]\n", - " # bfloat16 for better speed and reduced memory usage\n", - " model_name = f\"LongSafari/{model_name}-hf\"\n", - " return AutoTokenizer.from_pretrained(\n", - " model_name, max_length=max_length, truncation=True, padding=True, trust_remote_code=True\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "27", - "metadata": {}, - "outputs": [], - "source": [ - "import evaluate\n", - "import numpy as np\n", - "\n", - "clf_metrics = evaluate.combine([\"accuracy\", \"f1\", \"precision\", \"recall\"])\n", - "\n", - "\n", - "def compute_metrics(p):\n", - " predictions, labels = p\n", - "\n", - " # print(f\"{predictions.shape=}\")\n", - " # print(f\"{labels.shape=}\")\n", - "\n", - " predictions = np.argmax(predictions, axis=2)\n", - " # Initialize lists to hold the filtered predictions and labels\n", - " true_predictions = []\n", - " true_labels = []\n", - "\n", - " # Filter out '-100' labels and correspondingly filter predictions\n", - " for prediction, label in zip(predictions, labels):\n", - " filtered_prediction = []\n", - " filtered_label = []\n", - "\n", - " for p, l in zip(prediction, label):\n", - " if l != -100:\n", - " filtered_prediction.append(p)\n", - " filtered_label.append(l)\n", - " true_predictions.append(filtered_prediction)\n", - " true_labels.append(filtered_label)\n", - "\n", - " for preds, refs in zip(true_predictions, true_labels):\n", - " clf_metrics.add_batch(predictions=preds, references=refs)\n", - "\n", - " result = clf_metrics.compute()\n", - " return result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "28", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import List\n", - "\n", - "import torch\n", - "from torch import nn\n", - "from transformers import AutoModel, PretrainedConfig, PreTrainedModel\n", - "from transformers.modeling_outputs import TokenClassifierOutput\n", - "from transformers.utils import logging\n", - "\n", - "logging.set_verbosity_info()\n", - "logger = logging.get_logger(\"transformers\")\n", - "\n", - "\n", - "class TokenClassificationHead(nn.Module):\n", - " def __init__(\n", - " self,\n", - " input_size: int,\n", - " lin1_size: int,\n", - " lin2_size: int,\n", - " num_class: int,\n", - " *,\n", - " use_identity_layer_for_qual: bool,\n", - " ):\n", - " super().__init__()\n", - " self.activation = nn.ReLU()\n", - " self.linear1 = nn.Linear(input_size, lin1_size)\n", - " self.linear2 = nn.Linear(lin1_size, lin2_size)\n", - " self.linear3 = nn.Linear(lin2_size, num_class)\n", - " self.qual_linear1 = (\n", - " nn.Identity() if use_identity_layer_for_qual else nn.Linear(1, lin1_size)\n", - " )\n", - "\n", - " def forward(self, x: torch.Tensor, input_quals: torch.Tensor) -> torch.Tensor:\n", - " output = self.activation(self.linear1(x))\n", - " residual = output + self.qual_linear1(input_quals.unsqueeze(-1)) # may add activation\n", - " output = self.activation(self.linear2(residual) + residual)\n", - " return self.linear3(output)\n", - "\n", - "\n", - "class TokenClassificationConfig(PretrainedConfig):\n", - " model_type = \"token-classification\"\n", - "\n", - " def __init__(\n", - " self,\n", - " input_size: int = 256,\n", - " lin1_size: int = 1024,\n", - " lin2_size: int = 1024,\n", - " num_class: int = 2,\n", - " *,\n", - " use_identity_layer_for_qual: bool = True,\n", - " **kwargs,\n", - " ):\n", - " self.input_size = input_size\n", - " self.lin1_size = lin1_size\n", - " self.lin2_size = lin2_size\n", - " self.num_class = num_class\n", - " self.use_identity_layer_for_qual = use_identity_layer_for_qual\n", - " super().__init__(**kwargs)\n", - "\n", - "\n", - "class TokenClassification(PreTrainedModel):\n", - " config_class = TokenClassificationConfig\n", - "\n", - " def __init__(\n", - " self,\n", - " config,\n", - " hyenadna_model: str = \"hyenadna-small-32k-seqlen\",\n", - " **kwargs,\n", - " ):\n", - " super().__init__(config, **kwargs)\n", - " self.num_class = config.num_class\n", - " self.hyenadna_model_name = hyenadna_model\n", - " self.hyenadna = AutoModel.from_pretrained(\n", - " f\"LongSafari/{hyenadna_model}-hf\", trust_remote_code=True\n", - " )\n", - "\n", - " self.head = TokenClassificationHead(\n", - " input_size=config.input_size,\n", - " lin1_size=config.lin1_size,\n", - " lin2_size=config.lin2_size,\n", - " num_class=config.num_class,\n", - " use_identity_layer_for_qual=config.use_identity_layer_for_qual,\n", - " )\n", - "\n", - " # Initialize weights and apply final processing\n", - " self.post_init()\n", - "\n", - " def forward(\n", - " self,\n", - " input_ids: torch.Tensor,\n", - " labels: torch.Tensor,\n", - " input_quals: torch.Tensor,\n", - " inputs_embeds: torch.FloatTensor | None = None,\n", - " output_hidden_states: bool | None = None,\n", - " return_dict: bool | None = None,\n", - " ) -> torch.Tensor:\n", - " # logger.info(f\"{input_ids.shape=}\")\n", - " # logger.info(f\"{labels.shape=}\")\n", - " # logger.info(f\"{input_quals.shape=}\")\n", - "\n", - " transformer_outputs = self.backbone(\n", - " input_ids,\n", - " inputs_embeds=inputs_embeds,\n", - " output_hidden_states=output_hidden_states,\n", - " return_dict=return_dict,\n", - " )\n", - "\n", - " batch_size = input_ids.shape[0]\n", - " hidden_states = transformer_outputs[0]\n", - "\n", - " logits = self.head(hidden_states, input_quals)\n", - " labels = labels.to(logits.device)\n", - " loss_fct = nn.CrossEntropyLoss()\n", - "\n", - " loss = loss_fct(logits.view(-1, self.num_class), labels.view(-1))\n", - "\n", - " return TokenClassifierOutput(\n", - " loss=loss,\n", - " logits=logits,\n", - " hidden_states=transformer_outputs.hidden_states,\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "29", - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import DataCollatorForTokenClassification\n", - "\n", - "\n", - "def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):\n", - " \"\"\"\n", - " Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.\n", - " \"\"\"\n", - "\n", - " # To avoid errors when using Feature extractors\n", - " if not hasattr(tokenizer, \"deprecation_warnings\"):\n", - " return tokenizer.pad(*pad_args, **pad_kwargs)\n", - "\n", - " # Save the state of the warning, then disable it\n", - " warning_state = tokenizer.deprecation_warnings.get(\"Asking-to-pad-a-fast-tokenizer\", False)\n", - " tokenizer.deprecation_warnings[\"Asking-to-pad-a-fast-tokenizer\"] = True\n", - "\n", - " try:\n", - " padded = tokenizer.pad(*pad_args, **pad_kwargs)\n", - " finally:\n", - " # Restore the state of the warning.\n", - " tokenizer.deprecation_warnings[\"Asking-to-pad-a-fast-tokenizer\"] = warning_state\n", - "\n", - " return padded\n", - "\n", - "\n", - "class DataCollatorForTokenClassificationWithQual(DataCollatorForTokenClassification):\n", - "\n", - " def torch_call(self, features):\n", - " import torch\n", - "\n", - " label_name = \"label\" if \"label\" in features[0].keys() else \"labels\"\n", - " labels = (\n", - " [feature[label_name] for feature in features]\n", - " if label_name in features[0].keys()\n", - " else None\n", - " )\n", - "\n", - " qual_name = \"input_quals\"\n", - " qual_pad_token_id = 0\n", - " input_quals = [feature[qual_name] for feature in features]\n", - "\n", - " no_labels_features = [\n", - " {k: v for k, v in feature.items() if k not in [qual_name, label_name]}\n", - " for feature in features\n", - " ]\n", - "\n", - " batch = pad_without_fast_tokenizer_warning(\n", - " self.tokenizer,\n", - " no_labels_features,\n", - " padding=self.padding,\n", - " max_length=self.max_length,\n", - " pad_to_multiple_of=self.pad_to_multiple_of,\n", - " return_tensors=\"pt\",\n", - " )\n", - "\n", - " if labels is None:\n", - " return batch\n", - "\n", - " sequence_length = batch[\"input_ids\"].shape[1]\n", - " padding_side = self.tokenizer.padding_side\n", - "\n", - " def to_list(tensor_or_iterable):\n", - " if isinstance(tensor_or_iterable, torch.Tensor):\n", - " return tensor_or_iterable.tolist()\n", - " return list(tensor_or_iterable)\n", - "\n", - " if padding_side == \"right\":\n", - " batch[label_name] = [\n", - " to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label))\n", - " for label in labels\n", - " ]\n", - " batch[qual_name] = [\n", - " to_list(qual) + [qual_pad_token_id] * (sequence_length - len(qual))\n", - " for qual in input_quals\n", - " ]\n", - " else:\n", - " batch[label_name] = [\n", - " [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label)\n", - " for label in labels\n", - " ]\n", - " batch[qual_name] = [\n", - " [qual_pad_token_id] * (sequence_length - len(qual)) + to_list(qual)\n", - " for qual in input_quals\n", - " ]\n", - "\n", - " batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)\n", - " batch[qual_name] = torch.tensor(batch[qual_name], dtype=torch.int64)\n", - " return batch" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30", - "metadata": {}, - "outputs": [], - "source": [ - "from functools import partial\n", - "\n", - "from transformers import DataCollatorForTokenClassification\n", - "\n", - "\n", - "def tokenize_and_align_labels_and_quals(data, tokenizer, max_length, pad_qual=0, pad_label=-100):\n", - " tokenized_inputs = tokenizer(data[\"seq\"], max_length=max_length, truncation=True, padding=True)\n", - " labels = torch.tensor(\n", - " deepchopper.vertorize_target(*data[\"target\"], len(data[\"seq\"])) + [pad_label]\n", - " )\n", - " quals = torch.cat((data[\"qual\"], torch.tensor([pad_qual]))).float()\n", - " normalized_quals = torch.nn.functional.normalize(quals, dim=0)\n", - " tokenized_inputs.update({\"labels\": labels, \"input_quals\": quals})\n", - " return tokenized_inputs\n", - "\n", - "\n", - "def tokenize_dataset(dataset, tokenizer, max_length):\n", - " return dataset.map(\n", - " partial(tokenize_and_align_labels_and_quals, tokenizer=tokenizer, max_length=max_length)\n", - " ).remove_columns([\"id\", \"seq\", \"qual\", \"target\"])\n", - "\n", - "\n", - "hyenadna_name = \"hyenadna-small-32k-seqlen\"\n", - "tokenizer = load_tokenizer_from_hyena_model(hyenadna_name)\n", - "\n", - "tokenize_train_dataset = tokenize_dataset(\n", - " train_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence\n", - ")\n", - "tokenize_val_dataset = tokenize_dataset(\n", - " val_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence\n", - ")\n", - "tokenize_test_dataset = tokenize_dataset(\n", - " test_dataset, tokenizer, max_length=tokenizer.max_len_single_sentence\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "31", - "metadata": {}, - "outputs": [], - "source": [ - "# data_collator = DataCollatorForTokenClassification(tokenizer)\n", - "data_collator = DataCollatorForTokenClassificationWithQual(tokenizer)\n", - "model_config = TokenClassificationConfig()\n", - "model = TokenClassification(model_config)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "32", - "metadata": {}, - "outputs": [], - "source": [ - "model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "33", - "metadata": {}, - "outputs": [], - "source": [ - "tokenize_train_dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "34", - "metadata": {}, - "outputs": [], - "source": [ - "from accelerate import Accelerator\n", - "\n", - "accelerator = Accelerator()\n", - "\n", - "training_args = TrainingArguments(\n", - " output_dir=\"hyena_model_use_qual_test\",\n", - " learning_rate=2e-5,\n", - " per_device_train_batch_size=12,\n", - " per_device_eval_batch_size=12,\n", - " num_train_epochs=1,\n", - " weight_decay=0.01,\n", - " evaluation_strategy=\"epoch\",\n", - " save_strategy=\"epoch\",\n", - " load_best_model_at_end=True,\n", - " push_to_hub=False,\n", - " torch_compile=False,\n", - " # tf32=True,\n", - " report_to=\"wandb\",\n", - " run_name=\"hyena_model_use_qual\",\n", - " resume_from_checkpoint=False,\n", - ")\n", - "\n", - "trainer = Trainer(\n", - " model=model,\n", - " args=training_args,\n", - " train_dataset=tokenize_train_dataset,\n", - " eval_dataset=tokenize_test_dataset,\n", - " tokenizer=tokenizer,\n", - " data_collator=data_collator,\n", - " compute_metrics=compute_metrics,\n", - ")\n", - "\n", - "trainer = accelerator.prepare(trainer)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "35", - "metadata": {}, - "outputs": [], - "source": [ - "result = trainer.train()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "37", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "38", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "39", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "40", - "metadata": {}, - "outputs": [], - "source": [ - "# resume_config = TokenClassificationConfig.from_pretrained(\"./hyena_model_test2/checkpoint-1000/\")\n", - "resume_model = TokenClassification.from_pretrained(\"./hyena_model_test2/checkpoint-500/\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "41", - "metadata": {}, - "outputs": [], - "source": [ - "for k in model.state_dict():\n", - " v1 = resume_model.state_dict()[k]\n", - " v2 = model.state_dict()[k]\n", - " result = v1.eq(v2)\n", - " if not torch.all(result):\n", - " print(f\"{k} is not equal\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "42", - "metadata": {}, - "outputs": [], - "source": [ - "from safetensors import safe_open\n", - "\n", - "tensors = {}\n", - "with safe_open(\n", - " \"./hyena_model_test2/checkpoint-500/model.safetensors\", framework=\"pt\", device=0\n", - ") as f:\n", - " for k in f.keys():\n", - " tensors[k] = f.get_tensor(k)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "43", - "metadata": {}, - "outputs": [], - "source": [ - "trainer.evaluate()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "44", - "metadata": {}, - "outputs": [], - "source": [ - "predicts = trainer.predict(tokenize_val_dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "45", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "46", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47", - "metadata": {}, - "outputs": [], - "source": [ - "def summary_predict(predictions, labels):\n", - " predictions = np.argmax(predictions, axis=2)\n", - " # Initialize lists to hold the filtered predictions and labels\n", - " true_predictions = []\n", - " true_labels = []\n", - "\n", - " # Filter out '-100' labels and correspondingly filter predictions\n", - " for prediction, label in zip(predictions, labels):\n", - " filtered_prediction = []\n", - " filtered_label = []\n", - "\n", - " for p, l in zip(prediction, label):\n", - " if l != -100:\n", - " filtered_prediction.append(p)\n", - " filtered_label.append(l)\n", - " true_predictions.append(filtered_prediction)\n", - " true_labels.append(filtered_label)\n", - "\n", - " return true_predictions, true_labels\n", - "\n", - "\n", - "from rich.console import Console\n", - "from rich.highlighter import RegexHighlighter\n", - "from rich.theme import Theme\n", - "\n", - "\n", - "class LabelHighlighter(RegexHighlighter):\n", - " \"\"\"Apply style to anything that looks like an email.\"\"\"\n", - "\n", - " base_style = \"label.\"\n", - " highlights = [r\"(?P