diff --git a/.gitignore b/.gitignore
index a5c2282..6d4b97c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -14,3 +14,5 @@ carp/dataset/dataset_dict.json
carp/dataset/train/cache-a5b0849dd9416bdb.arrow
carp/dataset/train/dataset_info.json
carp/dataset/train/state.json
+*.csv
+/carp/experiments/distil_carp/20B_tokenizer.json
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+<component name="InspectionProjectProfileManager">
+  <settings>
+    <option name="USE_PROJECT_PROFILE" value="false" />
+    <version value="1.0" />
+  </settings>
+</component>
\ No newline at end of file
diff --git a/.idea/magiCARP.iml b/.idea/magiCARP.iml
new file mode 100644
index 0000000..5fdd65b
--- /dev/null
+++ b/.idea/magiCARP.iml
@@ -0,0 +1,15 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..25253bd
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,17 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..8c7d79c
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="ProjectModuleManager">
+    <modules>
+      <module fileurl="file://$PROJECT_DIR$/.idea/magiCARP.iml" filepath="$PROJECT_DIR$/.idea/magiCARP.iml" />
+    </modules>
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="$PROJECT_DIR$" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/carp/pytorch/finetune/__init__.py b/carp/experiments/__init__.py
similarity index 100%
rename from carp/pytorch/finetune/__init__.py
rename to carp/experiments/__init__.py
diff --git a/checkpoints/.gitkeep b/carp/experiments/distil_carp/__init__.py
similarity index 100%
rename from checkpoints/.gitkeep
rename to carp/experiments/distil_carp/__init__.py
diff --git a/carp/experiments/distil_carp/examine_data.py b/carp/experiments/distil_carp/examine_data.py
new file mode 100644
index 0000000..d69e6fc
--- /dev/null
+++ b/carp/experiments/distil_carp/examine_data.py
@@ -0,0 +1,46 @@
+import pandas as pd
+import os
+import csv
+
+def read_dataset_component(filepath):
+ data = list()
+ with open(filepath, newline='') as csvfile:
+ reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+ for row in reader:
+ data.append(row[1])
+ return data
+
+def read_paraphrase_component(filepath):
+ data = list()
+ with open(filepath, newline='') as csvfile:
+ reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+ for row in reader:
+ data.append(row)
+ return data
+
+path = 'distil_data'
+
+crit_data = []
+crit_datapath = path+'/paraphrase_train_crits'
+files = os.listdir(crit_datapath)
+files.sort()
+for file in files:
+ print(file)
+ filepath = os.path.join(crit_datapath, file)
+ crit_data_chunk = read_paraphrase_component(filepath)
+ crit_data += crit_data_chunk
+
+story_file = 'train_stories.csv'
+story_datapath = os.path.join(path, story_file)
+story_data = read_dataset_component(story_datapath)
+
+orig_crit_file = 'train_crits.csv'
+orig_crit_datapath = os.path.join(path, orig_crit_file)
+orig_crit_data = read_dataset_component(orig_crit_datapath)
+
+print("NUM STORIES: ", len(story_data))
+print("NUM CRITIQUE LISTS: ", len(crit_data))
+print("NUM ORIG CRITS: ", len(orig_crit_data))
+print("NUM CRITIQUES PER: ", len(crit_data[1]))
+print(story_data[1])
+print(crit_data[1])
\ No newline at end of file
diff --git a/carp/experiments/distil_carp/generate_fuzzing_data.py b/carp/experiments/distil_carp/generate_fuzzing_data.py
new file mode 100644
index 0000000..1905c7b
--- /dev/null
+++ b/carp/experiments/distil_carp/generate_fuzzing_data.py
@@ -0,0 +1,78 @@
+import csv
+import numpy as np
+import openai
+from neox_tokenizer import *
+
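+# a logit bias of -1000 effectively bans a token from being generated;
+# 0 and 50256 are end-of-text ids (the dict is passed as logit_bias below)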
+bad_word_dict = \
+ {"0": -1000, "50256": -1000, "3353": -1000}
+# add all line breaks to the bad word dict
+for idx in line_break_token_ids:
+ bad_word_dict[str(idx)] = -1000
+
+key = ""
+with open("../../pytorch/data/utils/api_key.txt") as f:
+ key = f.read()
+
+# nice try ;)
+openai.api_key = key
+openai.api_base = "https://api.goose.ai/v1"
+
+prompt = "You are an editor of stories. Below is a set of stories and the the criticisms you have written for each " \
+ "manuscript.\n\n "
+
+
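+# to_csv.py writes rows of the form ["", text], so row[1] holds the story/critique text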
+def read_dataset_component(filepath):
+ data = list()
+ with open(filepath, newline='') as csvfile:
+ reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+ for row in reader:
+ data.append(row[1])
+ return data
+
+
+val_stories = read_dataset_component("../../pytorch/data/utils/val_stories.csv")
+val_crits = read_dataset_component("../../pytorch/data/utils/val_crits.csv")
+
+# Filter out stories of seven words or fewer.
+stories_crits = list(list(filter(lambda x: (len(x[0].split()) > 7), zip(val_stories, val_crits))))
+stories_crits = [[story for story, _ in stories_crits],
+ [crit for _, crit in stories_crits]]
+val_stories = stories_crits[0]
+val_crits = stories_crits[1]
+
+train_stories = read_dataset_component("../../pytorch/data/utils/train_stories.csv")[0:10]
+examples_n = 5
+indices = np.random.choice(len(val_stories), examples_n, replace=False)
+print("[" + ", ".join(list(map(str, indices))) + "]")
+indices = [56, 23, 14, 65, 60]
+
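+# build the few-shot section of the prompt: each selected validation pair is
+# appended as "<i>. Story: ...\nCriticism: ...\n\n"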
+for i in range(1, examples_n + 1, 1):
+ idx = indices[i - 1]
+ story_example = str(i) + ". Story: " + val_stories[idx] + "\nCriticism: " + val_crits[idx] + "\n\n"
+ prompt += story_example
+
+for input_story in train_stories:
+ story_prompt = prompt + str(examples_n + 1) + ". Story: " + input_story + "\nCriticism:"
+    # logit_bias (defined above) penalizes EOT (0, 50256) and newline tokens
+ completion = openai.Completion.create(
+ engine="gpt-neo-20b",
+ prompt=story_prompt,
+ max_tokens=40,
+ typical_p=0.5,
+ logit_bias=bad_word_dict,
+ logprobs=30,
+ stream=True)
+ print(input_story)
+ print("\n")
+
+ # Print each token as it is returned
+ for c in completion:
+ print(c.choices[0].text, end='')
+
+ print("\n=============\n")
diff --git a/carp/experiments/distil_carp/neox_tokenizer.py b/carp/experiments/distil_carp/neox_tokenizer.py
new file mode 100644
index 0000000..c75d0bd
--- /dev/null
+++ b/carp/experiments/distil_carp/neox_tokenizer.py
@@ -0,0 +1,10 @@
+import json
+tokenizer_data = None
+with open("20B_tokenizer.json") as f:
+ tokenizer_data = json.load(f)
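+# "\u010a" (the character Ċ) is how byte-level BPE vocabularies encode "\n",
+# so any vocab entry containing it corresponds to a token that emits a line break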
+line_break_token_ids = list()
+for token, token_id in tokenizer_data['model']['vocab'].items():
+    if u"\u010a" in token:
+        line_break_token_ids.append(token_id)
diff --git a/carp/experiments/distil_carp/paraphrase_critiques.py b/carp/experiments/distil_carp/paraphrase_critiques.py
new file mode 100644
index 0000000..2c187fe
--- /dev/null
+++ b/carp/experiments/distil_carp/paraphrase_critiques.py
@@ -0,0 +1,58 @@
+from transformers import PegasusForConditionalGeneration, PegasusTokenizer
+import csv
+import sys
+from tqdm import tqdm
+
+tokenizer_pegasus = PegasusTokenizer.from_pretrained('tuner007/pegasus_paraphrase')
+model_pegasus = PegasusForConditionalGeneration.from_pretrained('tuner007/pegasus_paraphrase').half().to("cuda")
+
+def read_dataset_component(filepath):
+ data = list()
+ with open(filepath, newline='') as csvfile:
+ reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+ for row in reader:
+ data.append(row[1])
+ return data
+
+num_beams = 5
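+# generate() with num_return_sequences == num_beams returns num_beams paraphrases
+# per input critique, flattened into a single list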
+def get_review_ensemble(input_text):
+ batch = tokenizer_pegasus(input_text,truncation=True,padding='longest', max_length=60, return_tensors="pt").to("cuda")
+ translated = model_pegasus.generate(**batch,max_length=60, num_beams=num_beams, num_return_sequences=num_beams, temperature=1.5)
+ return tokenizer_pegasus.batch_decode(translated, skip_special_tokens=True)
+
+def write_dataset_csv(data, filepath):
+ with open(filepath, mode='w') as csvfile:
+ writer = csv.writer(csvfile, quoting=csv.QUOTE_MINIMAL)
+ writer.writerows(data)
+
+filepath = '/mnt/raid/users/AlexH/magiCARP/carp/pytorch/data/utils/train_crits.csv'
+data = read_dataset_component(filepath)
+batch_size = 100
+num_batches = (len(data) + batch_size - 1) // batch_size
+output_file = 'paraphrase_train_crits.csv'
+write_thresh = 1000
+temp_csv = []
+print(len(data))
+for i in tqdm(range(num_batches)):
+ cur_batch_size = min(batch_size, len(data)-batch_size*i)
+ batch = data[i*batch_size:i*batch_size+cur_batch_size]
+ #print(batch)
+ num_paraphrases = 5
+ paraphrases = get_review_ensemble(batch)
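+    # regroup the flat paraphrase list into one row of num_paraphrases entries per critique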
+ reshaped_paraphrases = []
+ for j in range(len(paraphrases)//num_paraphrases):
+ reshaped_paraphrases.append([])
+ for k in range(num_paraphrases):
+ reshaped_paraphrases[-1].append(paraphrases[j*num_paraphrases+k])
+ #print(reshaped_paraphrases)
+ temp_csv += reshaped_paraphrases
+ if (i+1) % write_thresh == 0:
+ print("WRITING TO CSV")
+ cur_output_file = output_file+f"_{i}"
+        write_dataset_csv(temp_csv, cur_output_file)
+ temp_csv = []
+write_dataset_csv(temp_csv, output_file)
+
diff --git a/carp/pytorch/data/__init__.py b/carp/pytorch/data/__init__.py
index d7bd3ae..216cab5 100644
--- a/carp/pytorch/data/__init__.py
+++ b/carp/pytorch/data/__init__.py
@@ -14,8 +14,9 @@
# specifies a dictionary of architectures
_DATAPIPELINE: Dict[str, any] = {} # registry
+
def register_datapipeline(name):
- """Decorator used register a CARP architecture
+ """Decorator used register a CARP architecture
Args:
name: Name of the architecture
@@ -29,7 +30,7 @@ def register_class(cls, name):
if isinstance(name, str):
name = name.lower()
return lambda c: register_class(c, name)
-
+
cls = name
name = cls.__name__
register_class(cls, name.lower())
@@ -42,9 +43,9 @@ class BaseDataPipeline(Dataset):
"""Dataset wrapper class to ease working with the CARP dataset and Pytorch data utilities."""
def __init__(
- self,
- dupe_protection: bool = True,
- path: str = "dataset",
+ self,
+ dupe_protection: bool = True,
+ path: str = "dataset",
):
dataset = load_from_disk(path)
train = dataset["train"]
@@ -70,8 +71,7 @@ def __len__(self) -> int:
return len(self.passages)
@staticmethod
- def create_tokenizer_factory(call_tokenizer : Callable, tokenizer_factory : Callable, context_len : int) -> Callable:
-
+ def create_tokenizer_factory(call_tokenizer: Callable, tokenizer_factory: Callable, context_len: int) -> Callable:
"""Function creates a callable tokenizer subroutine and uses it to curry the tokenizer factory
Args:
@@ -85,7 +85,7 @@ def create_tokenizer_factory(call_tokenizer : Callable, tokenizer_factory : Call
return partial(tokenizer_factory, tok_func)
@staticmethod
- def tokenizer_factory(_tok : Callable, encoder: BaseEncoder) -> Callable:
+ def tokenizer_factory(_tok: Callable, encoder: BaseEncoder) -> Callable:
"""Function factory that creates a collate function for use with a torch.util.data.Dataloader
@@ -95,9 +95,10 @@ def tokenizer_factory(_tok : Callable, encoder: BaseEncoder) -> Callable:
Returns:
Callable: A function that will take a batch of string tuples and tokenize them properly.
"""
+
@typechecked
def collate(
- data: Iterable[Tuple[str, str]]
+ data: Iterable[Tuple[str, str]]
) -> Tuple[BatchElement, BatchElement]:
passages, reviews = zip(*data)
pass_tokens, rev_tokens = _tok(list(passages)), _tok(list(reviews))
@@ -111,12 +112,15 @@ def collate(
)
return collate
-
+
from carp.pytorch.data.mlm_pipeline import MLMDataPipeline
from carp.pytorch.data.scarecrow_pipeline import ScarecrowDataPipeline
+from carp.pytorch.data.distill_pipeline import DistillDataPipeline
+
def get_datapipeline(name):
return _DATAPIPELINE[name.lower()]
+
def get_datapipeline_names():
return _DATAPIPELINE.keys()
diff --git a/carp/pytorch/data/distill_pipeline.py b/carp/pytorch/data/distill_pipeline.py
new file mode 100644
index 0000000..95804dd
--- /dev/null
+++ b/carp/pytorch/data/distill_pipeline.py
@@ -0,0 +1,105 @@
+from torch.functional import Tensor
+from carp.pytorch.data import *
+from carp.pytorch.model.encoders import BaseEncoder
+from transformers.data.data_collator import DataCollatorForLanguageModeling
+from carp.pytorch.data.utils.data_util import read_dataset_component, read_paraphrase_component
+
+from dataclasses import dataclass
+from torchtyping import TensorType
+from typing import List
+import torch
+import os
+
+#TODO:
+'''Custom chunk_batch_element
+'''
+
+
+@dataclass
+class DistillBatchElement(BatchElement):
+ #Reducing over critiques for same stories
+ #reduction_matrix : TensorType["pass_N", -1]
+ reviews_per_passage: int
+
+
+@register_datapipeline
+class DistillDataPipeline(BaseDataPipeline):
+
+ """Dataset wrapper class to ease working with the CARP dataset and Pytorch data utilities."""
+ def __init__(
+ self,
+ #Prevents duplicates of multiple stories
+ dupe_protection: bool = True,
+ path: str = "dataset",
+ ):
+ crit_data = []
+ crit_datapath = path+'/paraphrase_train_crits'
+ files = os.listdir(crit_datapath)
+ files.sort()
+ for file in files:
+ print(file)
+ datapath = os.path.join(crit_datapath,file)
+ crit_data_chunk = read_paraphrase_component(datapath)
+ crit_data+=crit_data_chunk
+
+ orig_crit_file = 'train_crits.csv'
+ orig_crit_datapath = os.path.join(path, orig_crit_file)
+ orig_crit_data = read_dataset_component(orig_crit_datapath)
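+        # append the original critique to each paraphrase list, so the last element is always the un-paraphrased critique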
+ for orig_crit, crits in zip(orig_crit_data, crit_data):
+ crits.append(orig_crit)
+ self.reviews_list = crit_data
+
+ story_file = 'train_stories.csv'
+ story_datapath = os.path.join(path, story_file)
+ story_data = read_dataset_component(story_datapath)
+ self.passages = story_data
+
+
+        # keep only the original critique (the last element of each review list)
+        self.reviews_list = list(map(lambda x: [x[-1]], self.reviews_list))
+
+ print("NUM STORIES: ", len(self.passages))
+ print("NUM CRITIQUE LISTS: ", len(self.reviews_list))
+ print("NUM CRITIQUES PER: ", len(self.reviews_list[0]))
+
+
+ #Overload for data format (passage, [crit_1,...,crit_n])
+ def __getitem__(self, index: int) -> Tuple[str, List[str]]:
+ return self.passages[index], self.reviews_list[index]
+
+ @staticmethod
+ def tokenizer_factory(_tok : Callable, encoder: BaseEncoder) -> Callable:
+ """Function factory that creates a collate function for use with a torch.util.data.Dataloader
+
+ Args:
+ tokenizer (PreTrainedTokenizer): A Huggingface model tokenizer, taking strings to torch Tensors
+ context_len (int): Max length of the passages passed to the tokenizer
+
+ Returns:
+ Callable: A function that will take a batch of string tuples and tokenize them properly.
+ """
+
+ @typechecked
+ def collate(
+ data: Iterable[Tuple[str, List[str]]]
+ ) -> Tuple[BatchElement, DistillBatchElement]:
+            # data is an iterable of (passage, [crit_1, ..., crit_n]) pairs
+ passages, review_lists = zip(*data)
+ reviews_per_passage = len(review_lists[0])
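+            # flatten the per-passage critique lists so reviews are contiguous per passage (rev_N = pass_N * reviews_per_passage)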
+ reviews = [review for review_list in review_lists for review in review_list]
+ pass_tokens, rev_tokens = _tok(list(passages)), _tok(list(reviews))
+ pass_masks = pass_tokens["attention_mask"]
+ rev_masks = rev_tokens["attention_mask"]
+ pass_tokens = pass_tokens["input_ids"]
+ rev_tokens = rev_tokens["input_ids"]
+
+            # reduction_matrix = torch.arange(0, rev_tokens.size()[0], step=1).reshape(-1, reviews_per_passage)
+
+ return (
+ BatchElement(pass_tokens, pass_masks),
+ DistillBatchElement(rev_tokens, rev_masks, reviews_per_passage),
+ )
+
+ return collate
diff --git a/carp/pytorch/data/utils/data_util.py b/carp/pytorch/data/utils/data_util.py
index afe8ebc..1e7e372 100644
--- a/carp/pytorch/data/utils/data_util.py
+++ b/carp/pytorch/data/utils/data_util.py
@@ -5,6 +5,23 @@
from dataclasses import dataclass
import torch
import math
+import csv
+
+def read_dataset_component(filepath):
+ data = list()
+ with open(filepath, newline='') as csvfile:
+ reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+ for row in reader:
+ data.append(row[1])
+ return data
+
+def read_paraphrase_component(filepath):
+ data = list()
+ with open(filepath, newline='') as csvfile:
+ reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+ for row in reader:
+ data.append(row)
+ return data
def check_char(char):
"""Check if char can be encoded"""
@@ -85,8 +102,8 @@ def _tok(string_batch: Iterable[str]) -> BatchEncoding:
@dataclass
class BatchElement:
- input_ids : TensorType[-1, "pass_N"]
- mask : TensorType[-1, "pass_N"]
+ input_ids : TensorType[-1, "pass_N"]
+ mask : TensorType[-1, "pass_N"]
# Assumes first axis of all tensor attributes in data are the same
# If no tensor attributes, returns original data object
@@ -105,7 +122,7 @@ def chunkBatchElement(data : BatchElement, chunk_size : int) -> List[BatchElemen
is_tensor.append(True)
else:
is_tensor.append(False)
-
+
# If no tensor type just return
has_tensor = False
for t in is_tensor:
@@ -133,7 +150,7 @@ def chunkBatchElement(data : BatchElement, chunk_size : int) -> List[BatchElemen
data_args.append(vars(data)[key][inds])
else:
data_args.append(vars(data)[key])
-
+
new_datas.append(data_class(*data_args))
return new_datas
\ No newline at end of file
diff --git a/carp/pytorch/data/utils/examine_train_stories.py b/carp/pytorch/data/utils/examine_train_stories.py
new file mode 100644
index 0000000..db04182
--- /dev/null
+++ b/carp/pytorch/data/utils/examine_train_stories.py
@@ -0,0 +1,22 @@
+import pandas as pd
+import csv
+
+with open('train_stories.csv', 'r') as f:
+ data = f.readlines()
+
+print(len(data))
+print(data[:5])
+data = pd.read_csv('train_stories.csv')
+print(data.shape[0])
+
+
+def read_dataset_component(filepath):
+ data = list()
+ with open(filepath, newline='') as csvfile:
+ reader = csv.reader(csvfile, delimiter=",", quoting=csv.QUOTE_MINIMAL)
+ for row in reader:
+ data.append(row[1])
+ return data
+
+data = read_dataset_component('train_stories.csv')
+print(len(data))
\ No newline at end of file
diff --git a/carp/pytorch/data/utils/to_csv.py b/carp/pytorch/data/utils/to_csv.py
index ab1422e..6ed0e92 100644
--- a/carp/pytorch/data/utils/to_csv.py
+++ b/carp/pytorch/data/utils/to_csv.py
@@ -58,8 +58,19 @@ def write_dataset_csv(data, filepath):
if __name__ == "__main__":
train_set, val_set = get_dataset(100,dupe_protection=False)
train_set = list(map(lambda x: list(x), train_set))
- val_set = list(map(lambda x: ["", list(x)[1]], val_set))
- write_dataset_csv(train_set, 'train.csv')
- write_dataset_csv(val_set, 'val_crits.csv')
+ train_stories = list(map(lambda x: ["", list(x)[0]], train_set))
+ train_crits = list(map(lambda x: ["", list(x)[1]], train_set))
+
+ val_stories = list(map(lambda x: ["", list(x)[0]], val_set))
+ val_crits = list(map(lambda x: ["", list(x)[1]], val_set))
+
+ print(len(train_stories))
+ print(len(train_crits))
+
+ write_dataset_csv(train_stories, 'train_stories.csv')
+ write_dataset_csv(train_crits, 'train_crits.csv')
+
+ write_dataset_csv(val_stories, 'val_stories.csv')
+ write_dataset_csv(val_crits, 'val_crits.csv')
diff --git a/carp/pytorch/model/architectures/__init__.py b/carp/pytorch/model/architectures/__init__.py
index f35b3bd..67ddbfb 100644
--- a/carp/pytorch/model/architectures/__init__.py
+++ b/carp/pytorch/model/architectures/__init__.py
@@ -18,7 +18,7 @@
_ARCHITECTURES: Dict[str, any] = {} # registry
def register_architecture(name):
- """Decorator used register a CARP architecture
+ """Decorator used register a CARP architecture
Args:
name: Name of the architecture
@@ -32,7 +32,7 @@ def register_class(cls, name):
if isinstance(name, str):
name = name.lower()
return lambda c: register_class(c, name)
-
+
cls = name
name = cls.__name__
register_class(cls, name.lower())
@@ -65,7 +65,7 @@ def __init__(self, config, skip_init=False):
)
self.clamp_min = torch.log(torch.tensor([1 / 100], device=self.config.device))
self.clamp_max = torch.log(torch.tensor([100], device=self.config.device))
- # used to count the number of steps until the next accumulation
+ # used to count the number of steps until the next accumulation
self.accum_step = 0
self.config = config
@abstractmethod
@@ -87,12 +87,12 @@ def attempt_load(cls, path : str, component_name : str):
"""
Attempts to load a component of the model. Throws an exception and continues if the component cannot be loaded
Args:
- path : directory to load from
+ path : directory to load from
component_name : name of component to append onto path
Returns:
component : nn.module
"""
- try:
+ try:
return torch.load(path + component_name)
except:
print("Unable to load " + component_name + ". Continuing.")
@@ -110,7 +110,7 @@ def save(self, path : str):
pass
- # must be run after initialize
+ # must be run after initialize
def load(self, path : str):
self.passage_encoder.model = self.attempt_load(path, "passage_encoder.pt")
self.review_encoder.model = self.attempt_load(path, "review_encoder.pt")
@@ -131,7 +131,7 @@ def compute_accuracy(self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "l
acc_t = (torch.argmax(logits, dim=0) == labels).sum()
return (acc_i + acc_t) / n / 2
def cosine_sim(self,\
- x: TensorType[-1, "latent_dim"],
+ x: TensorType[-1, "latent_dim"],
y: TensorType[-1, "latent_dim"]):
"""
Computes the cosine similarity between two sets of vectors x,y
@@ -145,7 +145,7 @@ def cosine_sim(self,\
x = F.normalize(x)
y = F.normalize(y)
# small term added to avoid nans in low precision softmax
- return (x @ y.T + 1e-6)
+ return (x @ y.T + 1e-6)
def contrastive_loss(
self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "latent_dim"]
@@ -158,7 +158,7 @@ def contrastive_loss(
loss_i = F.cross_entropy(logits, labels)
loss_t = F.cross_entropy(logits.T, labels)
return (loss_i + loss_t) / 2
-
+
def clamp(self):
with torch.no_grad():
self.logit_scale.clamp(self.clamp_min, self.clamp_max)
@@ -183,10 +183,10 @@ def _make_projection_layers(self, config):
self.review_encoder.d_model, self.latent_dim, config.proj_dropout
)
return proj_pass, proj_rev
-
+
def _embed_data(
self,
- x: BatchElement,
+ x: BatchElement,
encoder,
projector,
):
@@ -218,7 +218,7 @@ def calculate_embeddings(
with torch.no_grad(), torch.cuda.amp.autocast():
pass_encs = [self.encode_passages(p) for p in passages]
rev_encs = [self.encode_reviews(r) for r in reviews]
-
+
# if we only need the embeddings, fetch them
if return_only_embeddings:
pass_encs = list(map(lambda x: x.hidden, pass_encs))
@@ -240,7 +240,7 @@ def zero_grad(self,
opt : torch.optim.Optimizer):
if self.accum_step % self.config.grad_accum == 0:
opt.zero_grad()
-
+
def step(self,
scaler : torch.cuda.amp.GradScaler,
opt: torch.optim.Optimizer):
@@ -258,7 +258,7 @@ def eval_step(self, dataset):
for p, r in dataset:
passages.append(p)
reviews.append(r)
-
+
# TODO: Ideally should get microbatch size from trainconfig for the second argument
passages = chunkBatchElement(passages[0], 8)
reviews = chunkBatchElement(reviews[0], 8)
@@ -297,6 +297,7 @@ def forward(
from carp.pytorch.model.architectures.carp_mlm import CARPMLM
from carp.pytorch.model.architectures.carp_coop import CARPCoOp
from carp.pytorch.model.architectures.carp_shared_encoder import CARPSharedEncoder
+from carp.pytorch.model.architectures.distill_carp import DistillCARP
def get_architecture(name):
return _ARCHITECTURES[name.lower()]
diff --git a/carp/pytorch/model/architectures/distill_carp.py b/carp/pytorch/model/architectures/distill_carp.py
new file mode 100644
index 0000000..8ae434c
--- /dev/null
+++ b/carp/pytorch/model/architectures/distill_carp.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn.functional as F
+from torchtyping import TensorType, patch_typeguard
+from typeguard import typechecked
+from carp.configs import CARPConfig, ModelConfig, TrainConfig
+from carp.pytorch.model.architectures import *
+from carp.pytorch.model.encoders import get_encoder
+from carp.util import mbTokens, generate_indices
+from typing import List
+
+from carp.pytorch.data.utils.data_util import BatchElement
+from carp.pytorch.data.distill_pipeline import DistillBatchElement
+
+patch_typeguard()
+
+@typechecked
+@register_architecture
+class DistillCARP(BaseModel):
+ def __init__(self, config: ModelConfig):
+ super().__init__(config)
+
+ def _reduction(self, logits: TensorType["pass_N", "rev_N"]) -> TensorType["pass_N", "reduced_rev_N"]:
+ n = logits.shape[0]
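+        # logits is (pass_N, rev_N) with each passage's reviews in contiguous columns;
+        # group every reviews_per_passage columns and average, yielding (pass_N, pass_N)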
+ logits = torch.sum(logits.reshape((n,-1,self.reviews_per_passage)), dim=-1) / float(self.reviews_per_passage)
+ return logits
+
+ def compute_accuracy(self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "latent_dim"]):
+ with torch.no_grad():
+ n = x.shape[0]
+ x = F.normalize(x)
+ y = F.normalize(y)
+ logits = x @ y.T * self.logit_scale.exp()
+
+ logits = self._reduction(logits)
+
+ labels = torch.arange(n, device=self.config.device)
+ acc_i = (torch.argmax(logits, dim=1) == labels).sum()
+ acc_t = (torch.argmax(logits, dim=0) == labels).sum()
+ return (acc_i + acc_t) / n / 2
+
+ def contrastive_loss(
+ self, x: TensorType[-1, "latent_dim"], y: TensorType[-1, "latent_dim"]
+ ) -> TensorType[(), float]:
+
+ n = x.shape[0]
+ # small term added to avoid nans in low precision softmax
+ #(num_passages, num_reviews)
+ logits = self.cosine_sim(x,y) * self.logit_scale.exp()
+ logits = self._reduction(logits)
+
+        # log_softmax is numerically safer than softmax followed by log,
+        # which matters under mixed-precision training
+        logits_i = F.log_softmax(logits, dim=-1)
+        logits_t = F.log_softmax(logits, dim=0)
+
+        # the labels pick out the diagonal: passage i should match its own reduced review group
+        labels = torch.arange(n, device=self.config.device)
+        loss_i = F.nll_loss(logits_i, labels)
+        loss_t = F.nll_loss(logits_t.T, labels)
+ return (loss_i + loss_t) / 2
+
+
+ def train_step(
+ self,
+ passages: BatchElement,
+ reviews: DistillBatchElement,
+ config: TrainConfig,
+ opt: torch.optim.Optimizer,
+ scaler: torch.cuda.amp.GradScaler,
+ ) -> Dict[str, TensorType[()]]:
+
+ self.reviews_per_passage = reviews.input_ids.size()[0] // passages.input_ids.size()[0]
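+        # the collate fn flattens each passage's critique list, so the batch-size ratio recovers reviews per passage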
+
+ microbatch_inds_passages = generate_indices(
+ passages.input_ids.shape[0], config.microbatch_size, shuffle=False
+ )
+ microbatch_inds_reviews = generate_indices(
+ reviews.input_ids.shape[0], config.microbatch_size, shuffle=False
+ )
+ # Split tokens and masks into these microbatches
+ pass_mbs: List[BatchElement] = [
+ BatchElement(passages.input_ids[i], passages.mask[i]) for i in microbatch_inds_passages
+ ]
+ rev_mbs: List[DistillBatchElement] = [
+ DistillBatchElement(reviews.input_ids[i], reviews.mask[i], reviews.reviews_per_passage) for i in microbatch_inds_reviews
+ ]
+
+ reviews_per_passage = reviews.reviews_per_passage
+
+ # Initially get all encodings without grad
+ pass_encs, rev_encs = self.calculate_embeddings(pass_mbs, rev_mbs)
+
+
+ #compute accuracy
+ forward_acc = self.compute_accuracy(torch.cat(pass_encs), torch.cat(rev_encs))
+
+ # does gradient accumulation
+ self.zero_grad(opt)
+
+ # Encode passages in microbatches (with grad)
+ for index, passage in enumerate(pass_mbs):
+ pass_tmp = pass_encs.copy()
+ with torch.cuda.amp.autocast():
+ pass_tmp[index] = self.encode_passages(passage).hidden
+ loss = self.contrastive_loss(
+ torch.cat(pass_tmp), torch.cat(rev_encs)
+ )
+ scaler.scale(loss).backward()
+ # Encode reviews in microbatches (with grad)
+ for index, review in enumerate(rev_mbs):
+ rev_tmp = rev_encs.copy() # no_grad
+ with torch.cuda.amp.autocast():
+ rev_tmp[index] = self.encode_reviews(review).hidden
+ # grad _just_ at positions in `index`
+ loss = self.contrastive_loss(
+ torch.cat(pass_encs), torch.cat(rev_tmp)
+ )
+ scaler.scale(loss).backward()
+ # Clipping
+ if self.config.grad_clip != -1:
+ scaler.unscale_(opt)
+ torch.nn.utils.clip_grad_norm_(self.parameters(), self.config.grad_clip)
+
+ self.step(scaler, opt)
+ return {
+ "Loss/Train": loss,
+ "Acc/Forward": forward_acc,
+ }
diff --git a/carp/pytorch/model/encoders/__init__.py b/carp/pytorch/model/encoders/__init__.py
index 8672348..b662f07 100644
--- a/carp/pytorch/model/encoders/__init__.py
+++ b/carp/pytorch/model/encoders/__init__.py
@@ -104,7 +104,7 @@ def forward(self, inputs_embeds=False, **kwargs):
def last_ones(self, t):
# Multipliying arange by max
# makes last non zero column have largest number in arange
- t = t * torch.arange(t.shape[1])
+ t = t * torch.arange(t.shape[1]).to(t)
# Then argmax gives index of last non zero column
t = t.argmax(1)
return t
diff --git a/configs/carp_shared_deberta.yml b/configs/carp_shared_deberta.yml
new file mode 100644
index 0000000..132a2f2
--- /dev/null
+++ b/configs/carp_shared_deberta.yml
@@ -0,0 +1,34 @@
+model:
+ latent_dim: 2048
+ proj_dropout: 0.1
+ linear_projection: true
+ model_path: "microsoft/deberta-v2-xlarge"
+ model_arch: "roberta"
+ encoder_type: "SharedSumTextEncoder"
+ grad_clip: -1.0
+ grad_accum: 1
+ momentum: 0.0
+ device: "cuda"
+
+train_job:
+ n_ctx: 512
+ epochs: 10
+ batch_size: 2048
+ microbatch_size: 8
+ lr_ramp_steps: 400
+ lr_decay_steps: 3366
+ learning_rate_init: 1.0e-4
+ learning_rate_target: 0.000006
+ log_interval: 2
+ checkpoint_interval: 500
+ validate_interval: 50
+ use_half: false
+ do_log: true
+ validation_size: 1000
+ eval_selection: "final_n"
+ use_bucket: false
+ dupe_protection: true
+ hard_dupe_protection: false
+ data_pipeline: "BaseDataPipeline"
+ orchestrator: "BaseOrchestrator"
+
\ No newline at end of file
diff --git a/configs/distill_carp.yml b/configs/distill_carp.yml
new file mode 100644
index 0000000..927fa25
--- /dev/null
+++ b/configs/distill_carp.yml
@@ -0,0 +1,33 @@
+model:
+ latent_dim: 2048
+ proj_dropout: 0.1
+ linear_projection: true
+ model_path: "johngiorgi/declutr-base"
+ model_arch: "roberta"
+ encoder_type: "SumTextEncoder"
+ grad_clip: -1.0
+ grad_accum: 1
+ momentum: 0.0
+ device: "cuda"
+
+train_job:
+ n_ctx: 512
+ epochs: 10
+ batch_size: 2048
+ microbatch_size: 48
+ lr_ramp_steps: 400
+ lr_decay_steps: 3366
+ learning_rate_init: 1.0e-4
+ learning_rate_target: 0.000006
+ log_interval: 2
+ checkpoint_interval: 500
+ validate_interval: 50
+ use_half: false
+ do_log: true
+ validation_size: 1000
+ eval_selection: "final_n"
+ use_bucket: false
+ dupe_protection: true
+ hard_dupe_protection: false
+ data_pipeline: "DistillDataPipeline"
+ orchestrator: "BaseOrchestrator"
\ No newline at end of file