diff --git a/med_palm/model.py b/med_palm/model.py
index 0655245..cc30c45 100644
--- a/med_palm/model.py
+++ b/med_palm/model.py
@@ -1,241 +1,3 @@
-
-# import bitsandbytes
-# import torch
-# import torch.nn as nn
-# from flamingo_pytorch import PerceiverResampler
-# from transformers import AutoTokenizer, CLIPModel, CLIPProcessor
-
-# from med_palm.palm import PaLM
-
-
-# class MedPalmTokenizer:
-#     def __init__(self):
-#         try:
-#             self.processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
-#             self.tokenizer = AutoTokenizer.from_pretrained(
-#                 "EleutherAI/gpt-neox-20b",
-#                 additional_special_tokens=["", ""],
-#                 eos_token ="",
-#                 pad_token="",
-#                 extra_ids=0,
-#                 model_max_length=8192
-#             )
-
-#             self.im_idx, self.im_end_idx = self.tokenizer.convert_tokens_to_ids(["", ""])
-#         except Exception as e:
-#             print(f"Error init tokenizer: {e}")
-
-
-#     def tokenize_texts(self, texts):
-#         try:
-
-#             texts = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True).input_ids
-#             image_tokens = torch.tensor([[self.im_idx, self.im_end_idx]] * texts.shape[0])
-#             return torch.cat([texts[:, 0:1], image_tokens, texts[:, 1:]], dim=1), texts
-#         except Exception as e:
-#             print(f"Error tokenizing texts: {e}")
-
-#     def tokenize_images(self, images):
-#         try:
-
-#             tokenized_images = self.processor(images=images, return_tensors="pt").pixel_values
-#             print(f"Tokenized image: {tokenized_images.shape}")
-#             return tokenized_images
-
-#         except Exception as e:
-#             print(f"Error tokenizing texts: {e}")
-
-#     def tokenize(self, sample):
-#         try:
-
-#             text_tokens, only_text_tokens = self.tokenize_texts(sample["target_text"])
-#             attention_mask = text_tokens != self.tokenizer.pad_token_id
-#             dummy_image_features = torch.ones((text_tokens.shape[0], 64))
-#             attention_mask = torch.cat([dummy_image_features, attention_mask], dim=1)
-#             return {
-#                 "text_tokens": text_tokens,
-#                 "images": self.tokenize_images(sample["image"]),
-#                 "labels": only_text_tokens,
-#                 "attention_mask": attention_mask,
-#             }
-
-#         except Exception as e:
-#             print(f"Error during tokenization {e}")
-
-# class MedPalm(nn.Module):
-#     def __init__(self,
-#                  num_tokens: int = 50528,
-#                  dim: int = 2048,
-#                  depth: int = 16,
-#                  dim_head:int = 128,
-#                  heads: int = 8,
-#                  flash_attn: bool = True,
-#                  qk_rmsnorm: bool = False
-#                  ):
-#         super(MedPalm, self).__init__()
-
-#         self.num_tokens = num_tokens
-#         self.dim = dim
-#         self.depth = depth
-
-#         self.heads = heads
-#         self.flash_attn = flash_attn
-#         self.qk_rmsnorm = qk_rmsnorm
-
-#         try:
-
-#             self.vit_model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").vision_model
-
-#             self.embed = bitsandbytes.nn.modules.Embedding(
-#                 num_tokens,
-#                 dim,
-#                 padding_idx=1
-#             )
-
-#             self.output_projection = torch.nn.Linear(
-#                 dim, num_tokens, bias=False
-#             )
-#             torch.nn.init.normal_(
-#                 self.output_projection.weight, mean=0, std=dim**-0.5
-#             )
-
-#             self.decoder = PaLM(
-#                 num_tokens=num_tokens,
-#                 dim=dim,
-#                 depth=depth,
-#                 dim_head=dim_head,
-#                 heads=heads,
-#                 flash_attn=flash_attn,
-#                 qk_rmsnorm=qk_rmsnorm,
-#             )
-
-#             self.perceive = PerceiverResampler(
-#                 dim= 1024,
-#                 depth = 2,
-#                 dim_head = 8,
-#                 num_latents = 64,
-#                 num_media_embeds = 257
-#             )
-
-#             # self.image_resize = torch.nn.Linear(224 * 224, 1024 * 1024)
-
-#             self.image_proj = torch.nn.Linear(1024, dim, bias=False)
-#             torch.nn.init.normal_(
-#                 self.image_proj.weight, mean=0, std=dim**-0.5
-#             )
-
-#         except Exception as e:
-#             print(f"Error initlizing palme components: {e}")
-
-#     def forward(self, text_tokens, images):
-#         # try:
-
-#         # images = images.view(images.size(0), -1) # Flatten the images
-#         # images = self.image_resize(images) # Resize the images using the linear transformation layer
-#         # images = images.view(images.size(0), 3, 1024, 1024) # Reshape the images to the expected size
-
-#         # images = self.vit_model(pixel_values=images)["last_hidden_state"]
-#         # print(f'Images first" {images.shape}')
-
-#         # images = self.perceive(images).squeeze(1)
-#         # print(f"Images perceive: {images}")
-
-#         # images = self.image_proj(images)
-#         # print(f"Images projected: {images}")
-
-#         # # images_flattened = images.view(images.size(0), -1)
-#         # # print(f"Images flattened: {images_flattened}")
-
-#         # model_input = self.decoder(text_tokens)
-
-#         # print(model_input[:, 0:2].shape, images.shape, model_input[:, 2:].shape)
-
-#         # # images_flattened = images_flattened.view(1, 2, -1)
-#         # # print(f"Images flattened: {images_flattened}")
-
-
-#         # if model_input.size(1) < 3:
-#         #     print(f"Error model_input has less than 3 columns: {model_input.shape}")
-#         #     return None
-
-#         # model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=-1)
-#         # print(f"Model input: {model_input}")
-
-#         # model_input = self.decoder(model_input, tokens_mask=None)
-#         # print(f"Model input: {model_input}")
-
-#         # output = self.decoder(model_input, passed_x=model_input)[0]
-#         # print(f"output: {output}")
-
-#         # return output
-
-#         # except Exception as e:
-#         #     print(f"Error duing forward pass: {e}")
-#         #     return None
-
-
-#         ######################## v2
-#         # if not isinstance(text_tokens, torch.Tensor) or not isinstance(images, torch.Tensor):
-#         #     raise TypeError("text_tokens and images must be instances of torch.Tensor")
-
-#         # print(f'RAWWWW IMAGE SHAPE: {images.shape}')
-
-
-#         # images = self.vit_model(pixel_values=images)["last_hidden_state"]
-#         # print(f'1st images shape in vit: {images}')
-
-#         # images = self.perceive(images).squeeze(1)
-#         # print(f'self perceive: {images}')
-
-#         # images = self.image_proj(images)
-#         # print(f'projection layer :{images}')
-
-
-#         # model_input = self.decoder(text_tokens)
-#         # print(f'1ST MODEL INPUT {model_input}')
-
-#         # model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=1)
-#         # print(f'MODEL INPUT : {model_input}')
-
-#         # model_input = self.decoder(model_input, tokens_mask=None)
-#         # print(f'model_input: {model_input}')
-
-#         # output = self.decoder(model_input, passed_x=model_input)[0]
-#         # print(f'output: {output}')
-
-#         # return output
-
-#         images = self.vit_model(pixel_values=images)["last_hidden_state"]
-#         print(f'1st images shape in vit: {images}')
-
-#         images = self.perceive(images).squeeze(1)
-#         print(f'self perceive: {images.shape}')
-
-#         images = self.image_proj(images)
-#         print(f'projection layer :{images.shape}')
-
-#         model_input = self.decoder(text_tokens)
-
-#         if model_input.size(1) < 3:
-#             print(f"Error model_input has less than 3 columns: {model_input.shape}")
-#             return None
-
-#         model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=-1)
-#         print(f"Model input: {model_input.shape}")
-
-#         model_input = self.decoder(model_input, tokens_mask=None)
-#         print(f"Model input: {model_input.shape}")
-
-#         output = self.decoder(model_input, passed_x=model_input)[0]
-#         print(f"output: {output.shape}")
-
-
-
-
-
-
-
 import bitsandbytes
 import torch
 import torch.nn as nn
 from flamingo_pytorch import PerceiverResampler
@@ -243,7 +5,6 @@
 from med_palm.palm import PaLM
 
-
 
 class MedPalmTokenizer:
     def __init__(self):
         try:
@@ -341,70 +102,6 @@ def __init__(self):
         except Exception as e:
             print(f"Error initlizing palme components: {e}")
 
-    # def forward(self, text_tokens, images):
-    #     try:
-    #         # Average the number of channels dimension
-    #         images = images.mean(dim=1, keepdim=True)
-    #         print(f'images1st mean: {images.shape}')
-
-    #         # Flatten the images
-    #         images = images.view(images.size(0), -1)
-    #         print(f"Images shape before resize: {images.shape}")
-
-    #         # Check if images have the correct shape for image_resize
-    #         if images.size(-1) != self.image_resize.in_features:
-    #             print(f"Error: images has incorrect shape for image_resize. Expected last dimension: {self.image_resize.in_features}, got: {images.size(-1)}")
-    #             return None
-
-    #         # Resize the images
-    #         images = self.image_resize(images)
-
-    #         # Reshape the images to the expected size
-    #         images = images.view(images.size(0), 3, 1024, 1024)
-
-    #         # Apply the PerceiverResampler to the images
-    #         images = self.perceive(images).squeeze(1)
-    #         print(f"Images perceive: {images}")
-
-    #         # Check if images have the correct shape for image_proj
-    #         print(f"Images shape before proj: {images.shape}")
-    #         if images.size(-1) != self.image_proj.in_features:
-    #             print(f"Error: images has incorrect shape for image_proj. Expected last dimension: {self.image_proj.in_features}, got: {images.size(-1)}")
-    #             return None
-
-    #         # Project the images
-    #         images = self.image_proj(images)
-    #         print(f"Images projected: {images}")
-
-    #         # Flatten the images
-    #         images_flattened = images.view(images.size(0), -1)
-    #         print(f"Images flattened: {images_flattened}")
-
-    #         # Pass the text tokens through the decoder
-    #         model_input = self.decoder(text_tokens)
-    #         print(model_input[:, 0:2].shape, images.shape, model_input[:, 2:].shape)
-
-    #         # Reshape the flattened images
-    #         images_flattened = images_flattened.view(1, 2, -1)
-    #         print(f"Images flattened: {images_flattened}")
-
-    #         # Concatenate the model input and the flattened images
-    #         model_input = torch.cat([model_input[:, 0:2], images_flattened, model_input[:, 2:]], dim=-1)
-    #         print(f"Model input: {model_input}")
-
-    #         # Pass the model input through the decoder
-    #         model_input = self.decoder(model_input, tokens_mask=None)
-    #         print(f"Model input: {model_input}")
-
-    #         # Get the output from the decoder
-    #         output = self.decoder(model_input, passed_x=model_input)[0]
-    #         print(f"output: {output}")
-
-    #         return output
-
-    #     except Exception as e:
-    #         print(f"Error during forward pass: {e}")
-    #         return None
     def forward(self, text_tokens, images):
         try:
             # images = images.view(images.size(0), -1)
diff --git a/requirements.txt b/requirements.txt
index 2e86ddf..a2ed243 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ lion-pytorch
 accelerate
 datasets
 deepspeed
+bitsandbytes
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5f0cc3e..6b190b2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'MedPalm',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.3',
+  version = '0.0.5',
   license='MIT',
   description = 'MedPalm - Pytorch',
   author = 'Kye Gomez',
@@ -21,6 +21,7 @@
     "numpy",
     "einops",
     "accelerate",
+    "bitsandbytes",
     "transformers",
     "SentencePiece",
     "datasets",