diff --git a/med_palm/model.py b/med_palm/model.py
index 0655245..cc30c45 100644
--- a/med_palm/model.py
+++ b/med_palm/model.py
@@ -1,241 +1,3 @@
-
-# import bitsandbytes
-# import torch
-# import torch.nn as nn
-# from flamingo_pytorch import PerceiverResampler
-# from transformers import AutoTokenizer, CLIPModel, CLIPProcessor
-
-# from med_palm.palm import PaLM
-
-
-# class MedPalmTokenizer:
-# def __init__(self):
-# try:
-# self.processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
-# self.tokenizer = AutoTokenizer.from_pretrained(
-# "EleutherAI/gpt-neox-20b",
-#                 additional_special_tokens=["<image>", "</image>"],
-#                 eos_token ="<eos>",
-#                 pad_token="<pad>",
-# extra_ids=0,
-# model_max_length=8192
-# )
-
-#             self.im_idx, self.im_end_idx = self.tokenizer.convert_tokens_to_ids(["<image>", "</image>"])
-# except Exception as e:
-# print(f"Error init tokenizer: {e}")
-
-
-# def tokenize_texts(self, texts):
-# try:
-
-# texts = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True).input_ids
-# image_tokens = torch.tensor([[self.im_idx, self.im_end_idx]] * texts.shape[0])
-# return torch.cat([texts[:, 0:1], image_tokens, texts[:, 1:]], dim=1), texts
-# except Exception as e:
-# print(f"Error tokenizing texts: {e}")
-
-# def tokenize_images(self, images):
-# try:
-
-# tokenized_images = self.processor(images=images, return_tensors="pt").pixel_values
-# print(f"Tokenized image: {tokenized_images.shape}")
-# return tokenized_images
-
-# except Exception as e:
-# print(f"Error tokenizing texts: {e}")
-
-# def tokenize(self, sample):
-# try:
-
-# text_tokens, only_text_tokens = self.tokenize_texts(sample["target_text"])
-# attention_mask = text_tokens != self.tokenizer.pad_token_id
-# dummy_image_features = torch.ones((text_tokens.shape[0], 64))
-# attention_mask = torch.cat([dummy_image_features, attention_mask], dim=1)
-# return {
-# "text_tokens": text_tokens,
-# "images": self.tokenize_images(sample["image"]),
-# "labels": only_text_tokens,
-# "attention_mask": attention_mask,
-# }
-
-# except Exception as e:
-# print(f"Error during tokenization {e}")
-
-# class MedPalm(nn.Module):
-# def __init__(self,
-# num_tokens: int = 50528,
-# dim: int = 2048,
-# depth: int = 16,
-# dim_head:int = 128,
-# heads: int = 8,
-# flash_attn: bool = True,
-# qk_rmsnorm: bool = False
-# ):
-# super(MedPalm, self).__init__()
-
-# self.num_tokens = num_tokens
-# self.dim = dim
-# self.depth = depth
-
-# self.heads = heads
-# self.flash_attn = flash_attn
-# self.qk_rmsnorm = qk_rmsnorm
-
-# try:
-
-# self.vit_model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K").vision_model
-
-# self.embed = bitsandbytes.nn.modules.Embedding(
-# num_tokens,
-# dim,
-# padding_idx=1
-# )
-
-# self.output_projection = torch.nn.Linear(
-# dim, num_tokens, bias=False
-# )
-# torch.nn.init.normal_(
-# self.output_projection.weight, mean=0, std=dim**-0.5
-# )
-
-# self.decoder = PaLM(
-# num_tokens=num_tokens,
-# dim=dim,
-# depth=depth,
-# dim_head=dim_head,
-# heads=heads,
-# flash_attn=flash_attn,
-# qk_rmsnorm=qk_rmsnorm,
-# )
-
-# self.perceive = PerceiverResampler(
-# dim= 1024,
-# depth = 2,
-# dim_head = 8,
-# num_latents = 64,
-# num_media_embeds = 257
-# )
-
-# # self.image_resize = torch.nn.Linear(224 * 224, 1024 * 1024)
-
-# self.image_proj = torch.nn.Linear(1024, dim, bias=False)
-# torch.nn.init.normal_(
-# self.image_proj.weight, mean=0, std=dim**-0.5
-# )
-
-# except Exception as e:
-# print(f"Error initlizing palme components: {e}")
-
-# def forward(self, text_tokens, images):
-# # try:
-
-# # # images = images.view(images.size(0), -1) # Flatten the images
-# # # images = self.image_resize(images) # Resize the images using the linear transformation layer
-# # # images = images.view(images.size(0), 3, 1024, 1024) # Reshape the images to the expected size
-
-# # images = self.vit_model(pixel_values=images)["last_hidden_state"]
-# # print(f'Images first" {images.shape}')
-
-# # images = self.perceive(images).squeeze(1)
-# # print(f"Images perceive: {images}")
-
-# # images = self.image_proj(images)
-# # print(f"Images projected: {images}")
-
-# # # images_flattened = images.view(images.size(0), -1)
-# # # print(f"Images flattened: {images_flattened}")
-
-# # model_input = self.decoder(text_tokens)
-
-# # print(model_input[:, 0:2].shape, images.shape, model_input[:, 2:].shape)
-
-# # # images_flattened = images_flattened.view(1, 2, -1)
-# # # print(f"Images flattened: {images_flattened}")
-
-
-# # if model_input.size(1) < 3:
-# # print(f"Error model_input has less than 3 columns: {model_input.shape}")
-# # return None
-
-# # model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=-1)
-# # print(f"Model input: {model_input}")
-
-# # model_input = self.decoder(model_input, tokens_mask=None)
-# # print(f"Model input: {model_input}")
-
-# # output = self.decoder(model_input, passed_x=model_input)[0]
-# # print(f"output: {output}")
-
-# # return output
-
-# # except Exception as e:
-# # print(f"Error duing forward pass: {e}")
-# # return None
-
-
-# ######################## v2
-# # if not isinstance(text_tokens, torch.Tensor) or not isinstance(images, torch.Tensor):
-# # raise TypeError("text_tokens and images must be instances of torch.Tensor")
-
-# # print(f'RAWWWW IMAGE SHAPE: {images.shape}')
-
-
-# # images = self.vit_model(pixel_values=images)["last_hidden_state"]
-# # print(f'1st images shape in vit: {images}')
-
-# # images = self.perceive(images).squeeze(1)
-# # print(f'self perceive: {images}')
-
-# # images = self.image_proj(images)
-# # print(f'projection layer :{images}')
-
-
-# # model_input = self.decoder(text_tokens)
-# # print(f'1ST MODEL INPUT {model_input}')
-
-# # model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=1)
-# # print(f'MODEL INPUT : {model_input}')
-
-# # model_input = self.decoder(model_input, tokens_mask=None)
-# # print(f'model_input: {model_input}')
-
-# # output = self.decoder(model_input, passed_x=model_input)[0]
-# # print(f'output: {output}')
-
-# # return output
-
-# images = self.vit_model(pixel_values=images)["last_hidden_state"]
-# print(f'1st images shape in vit: {images}')
-
-# images = self.perceive(images).squeeze(1)
-# print(f'self perceive: {images.shape}')
-
-# images = self.image_proj(images)
-# print(f'projection layer :{images.shape}')
-
-# model_input = self.decoder(text_tokens)
-
-# if model_input.size(1) < 3:
-# print(f"Error model_input has less than 3 columns: {model_input.shape}")
-# return None
-
-# model_input = torch.cat([model_input[:, 0:2], images, model_input[:, 2:]], dim=-1)
-# print(f"Model input: {model_input.shape}")
-
-# model_input = self.decoder(model_input, tokens_mask=None)
-# print(f"Model input: {model_input.shape}")
-
-# output = self.decoder(model_input, passed_x=model_input)[0]
-# print(f"output: {output.shape}")
-
-
-
-
-
-
-
-import bitsandbytes
import torch
import torch.nn as nn
from flamingo_pytorch import PerceiverResampler
@@ -243,7 +5,6 @@
from med_palm.palm import PaLM
-
class MedPalmTokenizer:
def __init__(self):
try:
@@ -341,70 +102,6 @@ def __init__(self):
except Exception as e:
print(f"Error initlizing palme components: {e}")
- # def forward(self, text_tokens, images):
- # try:
- # # Average the number of channels dimension
- # images = images.mean(dim=1, keepdim=True)
- # print(f'images1st mean: {images.shape}')
-
- # # Flatten the images
- # images = images.view(images.size(0), -1)
- # print(f"Images shape before resize: {images.shape}")
-
- # # Check if images have the correct shape for image_resize
- # if images.size(-1) != self.image_resize.in_features:
- # print(f"Error: images has incorrect shape for image_resize. Expected last dimension: {self.image_resize.in_features}, got: {images.size(-1)}")
- # return None
-
- # # Resize the images
- # images = self.image_resize(images)
-
- # # Reshape the images to the expected size
- # images = images.view(images.size(0), 3, 1024, 1024)
-
- # # Apply the PerceiverResampler to the images
- # images = self.perceive(images).squeeze(1)
- # print(f"Images perceive: {images}")
-
- # # Check if images have the correct shape for image_proj
- # print(f"Images shape before proj: {images.shape}")
- # if images.size(-1) != self.image_proj.in_features:
- # print(f"Error: images has incorrect shape for image_proj. Expected last dimension: {self.image_proj.in_features}, got: {images.size(-1)}")
- # return None
-
- # # Project the images
- # images = self.image_proj(images)
- # print(f"Images projected: {images}")
-
- # # Flatten the images
- # images_flattened = images.view(images.size(0), -1)
- # print(f"Images flattened: {images_flattened}")
-
- # # Pass the text tokens through the decoder
- # model_input = self.decoder(text_tokens)
- # print(model_input[:, 0:2].shape, images.shape, model_input[:, 2:].shape)
-
- # # Reshape the flattened images
- # images_flattened = images_flattened.view(1, 2, -1)
- # print(f"Images flattened: {images_flattened}")
-
- # # Concatenate the model input and the flattened images
- # model_input = torch.cat([model_input[:, 0:2], images_flattened, model_input[:, 2:]], dim=-1)
- # print(f"Model input: {model_input}")
-
- # # Pass the model input through the decoder
- # model_input = self.decoder(model_input, tokens_mask=None)
- # print(f"Model input: {model_input}")
-
- # # Get the output from the decoder
- # output = self.decoder(model_input, passed_x=model_input)[0]
- # print(f"output: {output}")
-
- # return output
-
- # except Exception as e:
- # print(f"Error during forward pass: {e}")
- # return None
def forward(self, text_tokens, images):
try:
# images = images.view(images.size(0), -1)
diff --git a/requirements.txt b/requirements.txt
index 2e86ddf..a2ed243 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ lion-pytorch
accelerate
datasets
deepspeed
+bitsandbytes
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5f0cc3e..6b190b2 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
setup(
name = 'MedPalm',
packages = find_packages(exclude=['examples']),
- version = '0.0.3',
+ version = '0.0.5',
license='MIT',
description = 'MedPalm - Pytorch',
author = 'Kye Gomez',
@@ -21,6 +21,7 @@
"numpy",
"einops",
"accelerate",
+ "bitsandbytes",
"transformers",
"SentencePiece",
"datasets",