diff --git a/.gitignore b/.gitignore
index a5c2282..6d4b97c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,5 @@ carp/dataset/dataset_dict.json
 carp/dataset/train/cache-a5b0849dd9416bdb.arrow
 carp/dataset/train/dataset_info.json
 carp/dataset/train/state.json
+*.csv
+/carp/experiments/distil_carp/20B_tokenizer.json
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
\ No newline at end of file
diff --git a/.idea/magiCARP.iml b/.idea/magiCARP.iml
new file mode 100644
index 0000000..5fdd65b
--- /dev/null
+++ b/.idea/magiCARP.iml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..25253bd
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..8c7d79c
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
\ No newline at end of file
diff --git a/carp/pytorch/finetune/__init__.py b/carp/experiments/__init__.py
similarity index 100%
rename from carp/pytorch/finetune/__init__.py
rename to carp/experiments/__init__.py
diff --git a/checkpoints/.gitkeep b/carp/experiments/distil_carp/__init__.py
similarity index 100%
rename from checkpoints/.gitkeep
rename to carp/experiments/distil_carp/__init__.py
diff --git a/carp/experiments/distil_carp/examine_data.py b/carp/experiments/distil_carp/examine_data.py
new file mode 100644
index 0000000..d69e6fc
--- /dev/null
+++ b/carp/experiments/distil_carp/examine_data.py
@@ -0,0 +1,46 @@
+import pandas as pd
+import os
+import csv
+
+def read_dataset_component(filepath):
+    data = list()
+    with open(filepath, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+        for row in reader:
+            data.append(row[1])
+    return data
+
+def read_paraphrase_component(filepath):
+    data = list()
+    with open(filepath, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+        for row in reader:
+            data.append(row)
+    return data
+
+path = 'distil_data'
+
+crit_data = []
+crit_datapath = path+'/paraphrase_train_crits'
+files = os.listdir(crit_datapath)
+files.sort()
+for file in files:
+    print(file)
+    filepath = os.path.join(crit_datapath, file)
+    crit_data_chunk = read_paraphrase_component(filepath)
+    crit_data += crit_data_chunk
+
+story_file = 'train_stories.csv'
+story_datapath = os.path.join(path, story_file)
+story_data = read_dataset_component(story_datapath)
+
+orig_crit_file = 'train_crits.csv'
+orig_crit_datapath = os.path.join(path, orig_crit_file)
+orig_crit_data = read_dataset_component(orig_crit_datapath)
+
+print("NUM STORIES: ", len(story_data))
+print("NUM CRITIQUE LISTS: ", len(crit_data))
+print("NUM ORIG CRITS: ", len(orig_crit_data))
+print("NUM CRITIQUES PER: ", len(crit_data[1]))
+print(story_data[1])
+print(crit_data[1])
\ No newline at end of file
diff --git a/carp/experiments/distil_carp/generate_fuzzing_data.py b/carp/experiments/distil_carp/generate_fuzzing_data.py
new file mode 100644
index 0000000..1905c7b
--- /dev/null
+++ b/carp/experiments/distil_carp/generate_fuzzing_data.py
@@ -0,0 +1,75 @@
+import csv
+import numpy as np
+import openai
+from neox_tokenizer import *
+
+bad_word_dict = \
+    {"0": -1000, "50256": -1000, "3353": -1000}
+# add all line breaks to the bad word dict
+for idx in line_break_token_ids:
+    bad_word_dict[str(idx)] = -1000
+
+key = ""
+with open("../../pytorch/data/utils/api_key.txt") as f:
+    key = f.read()
+
+# nice try ;)
+openai.api_key = key
+openai.api_base = "https://api.goose.ai/v1"
+
+prompt = "You are an editor of stories. Below is a set of stories and the the criticisms you have written for each " \
+         "manuscript.\n\n "
+
+
+def read_dataset_component(filepath):
+    data = list()
+    with open(filepath, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+        for row in reader:
+            data.append(row[1])
+    return data
+
+
+val_stories = read_dataset_component("../../pytorch/data/utils/val_stories.csv")
+val_crits = read_dataset_component("../../pytorch/data/utils/val_crits.csv")
+
+# Filter out stories that are below seven words long.
+stories_crits = list(list(filter(lambda x: (len(x[0].split()) > 7), zip(val_stories, val_crits))))
+stories_crits = [[story for story, _ in stories_crits],
+                 [crit for _, crit in stories_crits]]
+val_stories = stories_crits[0]
+val_crits = stories_crits[1]
+
+train_stories = read_dataset_component("../../pytorch/data/utils/train_stories.csv")[0:10]
+examples_n = 5
+indices = np.random.choice(len(val_stories), examples_n, replace=False)
+print("[" + ", ".join(list(map(str, indices))) + "]")
+indices = [56, 23, 14, 65, 60]
+
+for i in range(1, examples_n + 1, 1):
+    idx = indices[i - 1]
+    story_example = str(i) + ". Story: " + val_stories[idx] + "\nCriticism: " + val_crits[idx] + "\n\n"
+    prompt += story_example
+
+for input_story in train_stories:
+    story_prompt = prompt + str(examples_n + 1) + ". Story: " + input_story + "\nCriticism:"
+    # Biasing against EOT (0, 50256), and \n
+    print(story_prompt)
+    import sys
+    sys.exit()
+    completion = openai.Completion.create(
+        engine="gpt-neo-20b",
+        prompt=story_prompt,
+        max_tokens=40,
+        typical_p=0.5,
+        logit_bias=bad_word_dict,
+        logprobs=30,
+        stream=True)
+    print(input_story)
+    print("\n")
+
+    # Print each token as it is returned
+    for c in completion:
+        print(c.choices[0].text, end='')
+
+    print("\n=============\n")
diff --git a/carp/experiments/distil_carp/neox_tokenizer.py b/carp/experiments/distil_carp/neox_tokenizer.py
new file mode 100644
index 0000000..c75d0bd
--- /dev/null
+++ b/carp/experiments/distil_carp/neox_tokenizer.py
@@ -0,0 +1,8 @@
+import json
+tokenizer_data = None
+with open("20B_tokenizer.json") as f:
+    tokenizer_data = json.load(f)
+line_break_token_ids = list()
+for idx, k in enumerate(tokenizer_data['model']['vocab'].keys()):
+    if u"\u010a" in k:
+        line_break_token_ids.append(idx)
diff --git a/carp/experiments/distil_carp/paraphrase_critiques.py b/carp/experiments/distil_carp/paraphrase_critiques.py
new file mode 100644
index 0000000..2c187fe
--- /dev/null
+++ b/carp/experiments/distil_carp/paraphrase_critiques.py
@@ -0,0 +1,55 @@
+from transformers import PegasusForConditionalGeneration, PegasusTokenizer
+import csv
+import sys
+from tqdm import tqdm
+
+tokenizer_pegasus = PegasusTokenizer.from_pretrained('tuner007/pegasus_paraphrase')
+model_pegasus = PegasusForConditionalGeneration.from_pretrained('tuner007/pegasus_paraphrase').half().to("cuda")
+
+def read_dataset_component(filepath):
+    data = list()
+    with open(filepath, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+        for row in reader:
+            data.append(row[1])
+    return data
+
+num_beams = 5
+def get_review_ensemble(input_text):
+    batch = tokenizer_pegasus(input_text,truncation=True,padding='longest', max_length=60, return_tensors="pt").to("cuda")
+    translated = model_pegasus.generate(**batch,max_length=60, num_beams=num_beams, num_return_sequences=num_beams, temperature=1.5)
+    return tokenizer_pegasus.batch_decode(translated, skip_special_tokens=True)
+
+def write_dataset_csv(data, filepath):
+    with open(filepath, mode='w') as csvfile:
+        writer = csv.writer(csvfile, quoting=csv.QUOTE_MINIMAL)
+        writer.writerows(data)
+
+filepath = '/mnt/raid/users/AlexH/magiCARP/carp/pytorch/data/utils/train_crits.csv'
+data = read_dataset_component(filepath)
+batch_size = 100
+num_batches = (len(data) + batch_size - 1) // batch_size
+output_file = 'paraphrase_train_crits.csv'
+write_thresh = 1000
+temp_csv = []
+print(len(data))
+for i in tqdm(range(num_batches)):
+    cur_batch_size = min(batch_size, len(data)-batch_size*i)
+    batch = data[i*batch_size:i*batch_size+cur_batch_size]
+    #print(batch)
+    num_paraphrases = 5
+    paraphrases = get_review_ensemble(batch)
+    reshaped_paraphrases = []
+    for j in range(len(paraphrases)//num_paraphrases):
+        reshaped_paraphrases.append([])
+        for k in range(num_paraphrases):
+            reshaped_paraphrases[-1].append(paraphrases[j*num_paraphrases+k])
+    #print(reshaped_paraphrases)
+    temp_csv += reshaped_paraphrases
+    if (i+1) % write_thresh == 0:
+        print("WRITING TO CSV")
+        cur_output_file = output_file+f"_{i}"
+        write_dataset_csv(temp_csv ,cur_output_file)
+        temp_csv = []
+write_dataset_csv(temp_csv, output_file)
+
diff --git a/carp/pytorch/data/__init__.py b/carp/pytorch/data/__init__.py
index d7bd3ae..216cab5 100644
--- a/carp/pytorch/data/__init__.py
+++ b/carp/pytorch/data/__init__.py
@@ -14,8 +14,9 @@
 # specifies a dictionary of architectures
 _DATAPIPELINE: Dict[str, any] = {} # registry
 
+
 def register_datapipeline(name):
-    """Decorator used register a CARP architecture 
+    """Decorator used register a CARP architecture
 
     Args:
         name: Name of the architecture
@@ -29,7 +30,7 @@ def register_class(cls, name):
     if isinstance(name, str):
         name = name.lower()
         return lambda c: register_class(c, name)
-    
+
     cls = name
     name = cls.__name__
     register_class(cls, name.lower())
@@ -42,9 +43,9 @@
 class BaseDataPipeline(Dataset):
     """Dataset wrapper class to ease working with the CARP dataset and Pytorch data utilities."""
     def __init__(
-        self, 
-        dupe_protection: bool = True, 
-        path: str = "dataset", 
+        self,
+        dupe_protection: bool = True,
+        path: str = "dataset",
     ):
         dataset = load_from_disk(path)
         train = dataset["train"]
@@ -70,8 +71,7 @@ def __len__(self) -> int:
         return len(self.passages)
 
     @staticmethod
-    def create_tokenizer_factory(call_tokenizer : Callable, tokenizer_factory : Callable, context_len : int) -> Callable:
-
+    def create_tokenizer_factory(call_tokenizer: Callable, tokenizer_factory: Callable, context_len: int) -> Callable:
        """Function creates a callable tokenizer subroutine and uses it to curry the tokenizer factory

        Args:
@@ -85,7 +85,7 @@ def create_tokenizer_factory(call_tokenizer : Callable, tokenizer_factory : Call
        return partial(tokenizer_factory, tok_func)

    @staticmethod
-    def tokenizer_factory(_tok : Callable, encoder: BaseEncoder) -> Callable:
+    def tokenizer_factory(_tok: Callable, encoder: BaseEncoder) -> Callable:
        """Function factory that creates a collate function for use with a torch.util.data.Dataloader

        Args:
@@ -95,9 +95,10 @@ def tokenizer_factory(_tok : Callable, encoder: BaseEncoder) -> Callable:
        Returns:
            Callable: A function that will take a batch of string tuples and tokenize them properly.
        """
+
        @typechecked
        def collate(
-            data: Iterable[Tuple[str, str]] 
+            data: Iterable[Tuple[str, str]]
        ) -> Tuple[BatchElement, BatchElement]:
            passages, reviews = zip(*data)
            pass_tokens, rev_tokens = _tok(list(passages)), _tok(list(reviews))
@@ -111,12 +112,15 @@ def collate(
            )
        return collate
-    
+

from carp.pytorch.data.mlm_pipeline import MLMDataPipeline
from carp.pytorch.data.scarecrow_pipeline import ScarecrowDataPipeline
+from carp.pytorch.data.distill_pipeline import DistillDataPipeline
+

def get_datapipeline(name):
    return _DATAPIPELINE[name.lower()]
+

def get_datapipeline_names():
    return _DATAPIPELINE.keys()
diff --git a/carp/pytorch/data/distill_pipeline.py b/carp/pytorch/data/distill_pipeline.py
new file mode 100644
index 0000000..95804dd
--- /dev/null
+++ b/carp/pytorch/data/distill_pipeline.py
@@ -0,0 +1,103 @@
+from torch.functional import Tensor
+from carp.pytorch.data import *
+from carp.pytorch.model.encoders import BaseEncoder
+from transformers.data.data_collator import DataCollatorForLanguageModeling
+from carp.pytorch.data.utils.data_util import read_dataset_component, read_paraphrase_component
+
+from dataclasses import dataclass
+from torchtyping import TensorType
+from typing import List
+import torch
+import os
+
+#TODO:
+'''Custom chunk_batch_element
+'''
+
+
+@dataclass
+class DistillBatchElement(BatchElement):
+    #Reducing over critiques for same stories
+    #reduction_matrix : TensorType["pass_N", -1]
+    reviews_per_passage: int
+
+
+@register_datapipeline
+class DistillDataPipeline(BaseDataPipeline):
+
+    """Dataset wrapper class to ease working with the CARP dataset and Pytorch data utilities."""
+    def __init__(
+        self,
+        #Prevents duplicates of multiple stories
+        dupe_protection: bool = True,
+        path: str = "dataset",
+    ):
+        crit_data = []
+        crit_datapath = path+'/paraphrase_train_crits'
+        files = os.listdir(crit_datapath)
+        files.sort()
+        for file in files:
+            print(file)
+            datapath = os.path.join(crit_datapath,file)
+            crit_data_chunk = read_paraphrase_component(datapath)
+            crit_data+=crit_data_chunk
+
+        orig_crit_file = 'train_crits.csv'
+        orig_crit_datapath = os.path.join(path, orig_crit_file)
+        orig_crit_data = read_dataset_component(orig_crit_datapath)
+        for orig_crit, crits in zip(orig_crit_data, crit_data):
+            crits.append(orig_crit)
+        self.reviews_list = crit_data
+
+        story_file = 'train_stories.csv'
+        story_datapath = os.path.join(path, story_file)
+        story_data = read_dataset_component(story_datapath)
+        self.passages = story_data
+
+
+        # prune to the last 3
+        self.reviews_list = list(map(lambda x: [x[-1]], self.reviews_list))
+
+        print("NUM STORIES: ", len(self.passages))
+        print("NUM CRITIQUE LISTS: ", len(self.reviews_list))
+        print("NUM CRITIQUES PER: ", len(self.reviews_list[0]))
+
+
+    #Overload for data format (passage, [crit_1,...,crit_n])
+    def __getitem__(self, index: int) -> Tuple[str, List[str]]:
+        return self.passages[index], self.reviews_list[index]
+
+    @staticmethod
+    def tokenizer_factory(_tok : Callable, encoder: BaseEncoder) -> Callable:
+        """Function factory that creates a collate function for use with a torch.util.data.Dataloader
+
+        Args:
+            tokenizer (PreTrainedTokenizer): A Huggingface model tokenizer, taking strings to torch Tensors
+            context_len (int): Max length of the passages passed to the tokenizer
+
+        Returns:
+            Callable: A function that will take a batch of string tuples and tokenize them properly.
+        """
+
+        @typechecked
+        def collate(
+            data: Iterable[Tuple[str, List[str]]]
+        ) -> Tuple[BatchElement, DistillBatchElement]:
+            #Expects us to double reviews beforehand: passing in list of critiques for each story
+            passages, review_lists = zip(*data)
+            reviews_per_passage = len(review_lists[0])
+            reviews = [review for review_list in review_lists for review in review_list]
+            pass_tokens, rev_tokens = _tok(list(passages)), _tok(list(reviews))
+            pass_masks = pass_tokens["attention_mask"]
+            rev_masks = rev_tokens["attention_mask"]
+            pass_tokens = pass_tokens["input_ids"]
+            rev_tokens = rev_tokens["input_ids"]
+
+            #eduction_matrix = torch.arange(0, rev_tokens.size()[0], step=1).reshape(-1, reviews_per_passage)
+
+            return (
+                BatchElement(pass_tokens, pass_masks),
+                DistillBatchElement(rev_tokens, rev_masks, reviews_per_passage),
+            )
+
+        return collate
diff --git a/carp/pytorch/data/utils/data_util.py b/carp/pytorch/data/utils/data_util.py
index afe8ebc..1e7e372 100644
--- a/carp/pytorch/data/utils/data_util.py
+++ b/carp/pytorch/data/utils/data_util.py
@@ -5,6 +5,23 @@
 from dataclasses import dataclass
 import torch
 import math
+import csv
+
+def read_dataset_component(filepath):
+    data = list()
+    with open(filepath, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+        for row in reader:
+            data.append(row[1])
+    return data
+
+def read_paraphrase_component(filepath):
+    data = list()
+    with open(filepath, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+        for row in reader:
+            data.append(row)
+    return data
 
 def check_char(char):
     """Check if char can be encoded"""
@@ -85,8 +102,8 @@ def _tok(string_batch: Iterable[str]) -> BatchEncoding:
 
 @dataclass
 class BatchElement:
-    input_ids : TensorType[-1, "pass_N"] 
-    mask : TensorType[-1, "pass_N"] 
+    input_ids : TensorType[-1, "pass_N"]
+    mask : TensorType[-1, "pass_N"]
 
 # Assumes first axis of all tensor attributes in data are the same
 # If no tensor attributes, returns original data object
@@ -105,7 +122,7 @@ def chunkBatchElement(data : BatchElement, chunk_size : int) -> List[BatchElemen
             is_tensor.append(True)
         else:
             is_tensor.append(False)
-    
+
     # If no tensor type just return
     has_tensor = False
     for t in is_tensor:
@@ -133,7 +150,7 @@ def chunkBatchElement(data : BatchElement, chunk_size : int) -> List[BatchElemen
                 data_args.append(vars(data)[key][inds])
             else:
                 data_args.append(vars(data)[key])
-        
+
         new_datas.append(data_class(*data_args))
 
     return new_datas
\ No newline at end of file
diff --git a/carp/pytorch/data/utils/examine_train_stories.py b/carp/pytorch/data/utils/examine_train_stories.py
new file mode 100644
index 0000000..db04182
--- /dev/null
+++ b/carp/pytorch/data/utils/examine_train_stories.py
@@ -0,0 +1,22 @@
+import pandas as pd
+import csv
+
+with open('train_stories.csv', 'r') as f:
+    data = f.readlines()
+
+print(len(data))
+print(data[:5])
+data = pd.read_csv('train_stories.csv')
+print(data.shape[0])
+
+
+def read_dataset_component(filepath):
+    data = list()
+    with open(filepath, newline='') as csvfile:
+        reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+        for row in reader:
+            data.append(row[1])
+    return data
+
+data = read_dataset_component('train_stories.csv')
+print(len(data))
\ No newline at end of file
diff --git a/carp/pytorch/data/utils/to_csv.py b/carp/pytorch/data/utils/to_csv.py
index ab1422e..6ed0e92 100644
--- a/carp/pytorch/data/utils/to_csv.py
+++ b/carp/pytorch/data/utils/to_csv.py
@@ -58,8 +58,19 @@ def write_dataset_csv(data, filepath):
 
 if __name__ == "__main__":
     train_set, val_set = get_dataset(100,dupe_protection=False)
     train_set = list(map(lambda x: list(x), train_set))
-    val_set = list(map(lambda x: ["", list(x)[1]], val_set))
-    write_dataset_csv(train_set, 'train.csv')
-    write_dataset_csv(val_set, 'val_crits.csv')
+    train_stories = list(map(lambda x: ["", list(x)[0]], train_set))
+    train_crits = list(map(lambda x: ["", list(x)[1]], train_set))
+
+    val_stories = list(map(lambda x: ["", list(x)[0]], val_set))
+    val_crits = list(map(lambda x: ["", list(x)[1]], val_set))
+
+    print(len(train_stories))
+    print(len(train_crits))
+
+    write_dataset_csv(train_stories, 'train_stories.csv')
+    write_dataset_csv(train_crits, 'train_crits.csv')
+
+    write_dataset_csv(val_stories, 'val_stories.csv')
+    write_dataset_csv(val_crits, 'val_crits.csv')
 
diff --git a/carp/pytorch/model/architectures/__init__.py b/carp/pytorch/model/architectures/__init__.py
index f35b3bd..67ddbfb 100644
--- a/carp/pytorch/model/architectures/__init__.py
+++ b/carp/pytorch/model/architectures/__init__.py
@@ -18,7 +18,7 @@
 _ARCHITECTURES: Dict[str, any] = {} # registry
 
 def register_architecture(name):
-    """Decorator used register a CARP architecture 
+    """Decorator used register a CARP architecture
 
     Args:
         name: Name of the architecture
@@ -32,7 +32,7 @@ def register_class(cls, name):
     if isinstance(name, str):
         name = name.lower()
         return lambda c: register_class(c, name)
-    
+
     cls = name
     name = cls.__name__
     register_class(cls, name.lower())
@@ -65,7 +65,7 @@ def __init__(self, config, skip_init=False):
         )
         self.clamp_min = torch.log(torch.tensor([1 / 100], device=self.config.device))
         self.clamp_max = torch.log(torch.tensor([100], device=self.config.device))
-        # used to count the number of steps until the next accumulation 
+        # used to count the number of steps until the next accumulation
         self.accum_step = 0
         self.config = config
     @abstractmethod
@@ -87,12 +87,12 @@ def attempt_load(cls, path : str, component_name : str):
        """ Attempts to load a component of the model. Throws an exception and continues if the component cannot be loaded
        Args:
-            path : directory to load from 
+            path : directory to load from
            component_name : name of component to append onto path

        Returns:
            component : nn.module
        """
-        try: 
+        try:
            return torch.load(path + component_name)
        except:
            print("Unable to load " + component_name + ". Continuing.")
@@ -110,7 +110,7 @@ def save(self, path : str):
         pass
 
 
-    # must be run after initialize 
+    # must be run after initialize
     def load(self, path : str):
         self.passage_encoder.model = self.attempt_load(path, "passage_encoder.pt")
         self.review_encoder.model = self.attempt_load(path, "review_encoder.pt")
@@ -131,7 +131,7 @@ def compute_accuracy(self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "l
             acc_t = (torch.argmax(logits, dim=0) == labels).sum()
         return (acc_i + acc_t) / n / 2
     def cosine_sim(self,\
-        x: TensorType[-1, "latent_dim"], 
+        x: TensorType[-1, "latent_dim"],
         y: TensorType[-1, "latent_dim"]):
         """
         Computes the cosine similarity between two sets of vectors x,y
@@ -145,7 +145,7 @@ def cosine_sim(self,\
         x = F.normalize(x)
         y = F.normalize(y)
         # small term added to avoid nans in low precision softmax
-        return (x @ y.T + 1e-6) 
+        return (x @ y.T + 1e-6)
     def contrastive_loss(
         self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "latent_dim"]
     ) -> TensorType[(), float]:
@@ -158,7 +158,7 @@ def contrastive_loss(
         loss_i = F.cross_entropy(logits, labels)
         loss_t = F.cross_entropy(logits.T, labels)
         return (loss_i + loss_t) / 2
-    
+
     def clamp(self):
         with torch.no_grad():
             self.logit_scale.clamp(self.clamp_min, self.clamp_max)
@@ -183,10 +183,10 @@ def _make_projection_layers(self, config):
             self.review_encoder.d_model, self.latent_dim, config.proj_dropout
         )
         return proj_pass, proj_rev
-    
+
     def _embed_data(
         self,
-        x: BatchElement, 
+        x: BatchElement,
         encoder,
         projector,
     ):
@@ -218,7 +218,7 @@ def calculate_embeddings(
         with torch.no_grad(), torch.cuda.amp.autocast():
             pass_encs = [self.encode_passages(p) for p in passages]
             rev_encs = [self.encode_reviews(r) for r in reviews]
-        
+
         # if we only need the embeddings, fetch them
         if return_only_embeddings:
             pass_encs = list(map(lambda x: x.hidden, pass_encs))
@@ -240,7 +240,7 @@ def zero_grad(self, opt : torch.optim.Optimizer):
         if self.accum_step % self.config.grad_accum == 0:
             opt.zero_grad()
 
-    
+
 
     def step(self, scaler : torch.cuda.amp.GradScaler,
         opt: torch.optim.Optimizer):
@@ -258,7 +258,7 @@ def eval_step(self, dataset):
         for p, r in dataset:
             passages.append(p)
             reviews.append(r)
-        
+
         # TODO: Ideally should get microbatch size from trainconfig for the second argument
         passages = chunkBatchElement(passages[0], 8)
         reviews = chunkBatchElement(reviews[0], 8)
@@ -297,6 +297,7 @@ def forward(
 from carp.pytorch.model.architectures.carp_mlm import CARPMLM
 from carp.pytorch.model.architectures.carp_coop import CARPCoOp
 from carp.pytorch.model.architectures.carp_shared_encoder import CARPSharedEncoder
+from carp.pytorch.model.architectures.distill_carp import DistillCARP
 
 def get_architecture(name):
     return _ARCHITECTURES[name.lower()]
diff --git a/carp/pytorch/model/architectures/distill_carp.py b/carp/pytorch/model/architectures/distill_carp.py
new file mode 100644
index 0000000..8ae434c
--- /dev/null
+++ b/carp/pytorch/model/architectures/distill_carp.py
@@ -0,0 +1,126 @@
+import torch
+import torch.nn.functional as F
+from torchtyping import TensorType, patch_typeguard
+from typeguard import typechecked
+from carp.configs import CARPConfig, ModelConfig, TrainConfig
+from carp.pytorch.model.architectures import *
+from carp.pytorch.model.encoders import get_encoder
+from carp.util import mbTokens, generate_indices
+from typing import List
+
+from carp.pytorch.data.utils.data_util import BatchElement
+from carp.pytorch.data.distill_pipeline import DistillBatchElement
+
+patch_typeguard()
+
+@typechecked
+@register_architecture
+class DistillCARP(BaseModel):
+    def __init__(self, config: ModelConfig):
+        super().__init__(config)
+
+    def _reduction(self, logits: TensorType["pass_N", "rev_N"]) -> TensorType["pass_N", "reduced_rev_N"]:
+        n = logits.shape[0]
+        logits = torch.sum(logits.reshape((n,-1,self.reviews_per_passage)), dim=-1) / float(self.reviews_per_passage)
+        return logits
+
+    def compute_accuracy(self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "latent_dim"]):
+        with torch.no_grad():
+            n = x.shape[0]
+            x = F.normalize(x)
+            y = F.normalize(y)
+            logits = x @ y.T * self.logit_scale.exp()
+
+            logits = self._reduction(logits)
+
+            labels = torch.arange(n, device=self.config.device)
+            acc_i = (torch.argmax(logits, dim=1) == labels).sum()
+            acc_t = (torch.argmax(logits, dim=0) == labels).sum()
+        return (acc_i + acc_t) / n / 2
+
+    def contrastive_loss(
+        self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "latent_dim"]
+    ) -> TensorType[(), float]:
+
+        n = x.shape[0]
+        # small term added to avoid nans in low precision softmax
+        #(num_passages, num_reviews)
+        logits = self.cosine_sim(x,y) * self.logit_scale.exp()
+        logits = self._reduction(logits)
+
+        logits_i = F.softmax(logits, dim=-1)
+        logits_t = F.softmax(logits, dim=0)
+
+        #Reduce logits into diagonal
+        labels = torch.arange(n, device=self.config.device)
+        loss_i = F.nll_loss(torch.log(logits_i), labels)
+        loss_t = F.nll_loss(torch.log(logits_t.T), labels)
+        return (loss_i + loss_t) / 2
+
+
+    def train_step(
+        self,
+        passages: BatchElement,
+        reviews: DistillBatchElement,
+        config: TrainConfig,
+        opt: torch.optim.Optimizer,
+        scaler: torch.cuda.amp.GradScaler,
+    ) -> Dict[str, TensorType[()]]:
+
+        self.reviews_per_passage = reviews.input_ids.size()[0] // passages.input_ids.size()[0]
+
+        microbatch_inds_passages = generate_indices(
+            passages.input_ids.shape[0], config.microbatch_size, shuffle=False
+        )
+        microbatch_inds_reviews = generate_indices(
+            reviews.input_ids.shape[0], config.microbatch_size, shuffle=False
+        )
+        # Split tokens and masks into these microbatches
+        pass_mbs: List[BatchElement] = [
+            BatchElement(passages.input_ids[i], passages.mask[i]) for i in microbatch_inds_passages
+        ]
+        rev_mbs: List[DistillBatchElement] = [
+            DistillBatchElement(reviews.input_ids[i], reviews.mask[i], reviews.reviews_per_passage) for i in microbatch_inds_reviews
+        ]
+
+        reviews_per_passage = reviews.reviews_per_passage
+
+        # Initially get all encodings without grad
+        pass_encs, rev_encs = self.calculate_embeddings(pass_mbs, rev_mbs)
+
+
+        #compute accuracy
+        forward_acc = self.compute_accuracy(torch.cat(pass_encs), torch.cat(rev_encs))
+
+        # does gradient accumulation
+        self.zero_grad(opt)
+
+        # Encode passages in microbatches (with grad)
+        for index, passage in enumerate(pass_mbs):
+            pass_tmp = pass_encs.copy()
+            with torch.cuda.amp.autocast():
+                pass_tmp[index] = self.encode_passages(passage).hidden
+                loss = self.contrastive_loss(
+                    torch.cat(pass_tmp), torch.cat(rev_encs)
+                )
+            scaler.scale(loss).backward()
+        # Encode reviews in microbatches (with grad)
+        for index, review in enumerate(rev_mbs):
+            rev_tmp = rev_encs.copy()  # no_grad
+            with torch.cuda.amp.autocast():
+                rev_tmp[index] = self.encode_reviews(review).hidden
+                # grad _just_ at positions in `index`
+                loss = self.contrastive_loss(
+                    torch.cat(pass_encs), torch.cat(rev_tmp)
+                )
+            scaler.scale(loss).backward()
+        # Clipping
+        if self.config.grad_clip != -1:
+            scaler.unscale_(opt)
+            torch.nn.utils.clip_grad_norm_(self.parameters(), self.config.grad_clip)
+
+        self.step(scaler, opt)
+        return {
+            "Loss/Train": loss,
+            "Acc/Forward": forward_acc,
+        }
diff --git a/carp/pytorch/model/encoders/__init__.py b/carp/pytorch/model/encoders/__init__.py
index 8672348..b662f07 100644
--- a/carp/pytorch/model/encoders/__init__.py
+++ b/carp/pytorch/model/encoders/__init__.py
@@ -104,7 +104,7 @@ def forward(self, inputs_embeds=False, **kwargs):
     def last_ones(self, t):
         # Multipliying arange by max
         # makes last non zero column have largest number in arange
-        t = t * torch.arange(t.shape[1])
+        t = t * torch.arange(t.shape[1]).to(t)
         # Then argmax gives index of last non zero column
         t = t.argmax(1)
         return t
diff --git a/configs/carp_shared_deberta.yml b/configs/carp_shared_deberta.yml
new file mode 100644
index 0000000..132a2f2
--- /dev/null
+++ b/configs/carp_shared_deberta.yml
@@ -0,0 +1,34 @@
+model:
+  latent_dim: 2048
+  proj_dropout: 0.1
+  linear_projection: true
+  model_path: "microsoft/deberta-v2-xlarge"
+  model_arch: "roberta"
+  encoder_type: "SharedSumTextEncoder"
+  grad_clip: -1.0
+  grad_accum: 1
+  momentum: 0.0
+  device: "cuda"
+
+train_job:
+  n_ctx: 512
+  epochs: 10
+  batch_size: 2048
+  microbatch_size: 8
+  lr_ramp_steps: 400
+  lr_decay_steps: 3366
+  learning_rate_init: 1.0e-4
+  learning_rate_target: 0.000006
+  log_interval: 2
+  checkpoint_interval: 500
+  validate_interval: 50
+  use_half: false
+  do_log: true
+  validation_size: 1000
+  eval_selection: "final_n"
+  use_bucket: false
+  dupe_protection: true
+  hard_dupe_protection: false
+  data_pipeline: "BaseDataPipeline"
+  orchestrator: "BaseOrchestrator"
+  
\ No newline at end of file
diff --git a/configs/distill_carp.yml b/configs/distill_carp.yml
new file mode 100644
index 0000000..927fa25
--- /dev/null
+++ b/configs/distill_carp.yml
@@ -0,0 +1,33 @@
+model:
+  latent_dim: 2048
+  proj_dropout: 0.1
+  linear_projection: true
+  model_path: "johngiorgi/declutr-base"
+  model_arch: "roberta"
+  encoder_type: "SumTextEncoder"
+  grad_clip: -1.0
+  grad_accum: 1
+  momentum: 0.0
+  device: "cuda"
+
+train_job:
+  n_ctx: 512
+  epochs: 10
+  batch_size: 2048
+  microbatch_size: 48
+  lr_ramp_steps: 400
+  lr_decay_steps: 3366
+  learning_rate_init: 1.0e-4
+  learning_rate_target: 0.000006
+  log_interval: 2
+  checkpoint_interval: 500
+  validate_interval: 50
+  use_half: false
+  do_log: true
+  validation_size: 1000
+  eval_selection: "final_n"
+  use_bucket: false
+  dupe_protection: true
+  hard_dupe_protection: false
+  data_pipeline: "DistillDataPipeline"
+  orchestrator: "BaseOrchestrator"
\ No newline at end of file