dequantize_and_merge.py
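# Dequantize an 8-bit, LoRA-trained Palme checkpoint: load the trained weights,
# merge the LoRA adapters into the base language model in bfloat16, and save the
# resulting state dicts under a 'dequant' subdirectory of the checkpoint folder.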
import os
import torch
from accelerate import Accelerator
from os.path import join as pjoin
from palme_model_openx import Palme
if __name__ == "__main__":
    checkpoint_dir_path = None  # insert path to the trained checkpoint directory
    acces_token = None  # insert Hugging Face access token if needed
    llama_checkpoint = "meta-llama/Llama-2-7b-hf"
    # checkpoint_image_model = "google/vit-base-patch16-224-in21k"
    # checkpoint_image_model = "google/vit-large-patch16-224"
    checkpoint_image_model = "openai/clip-vit-large-patch14"
    # checkpoint_image_model = "openai/clip-vit-base-patch32"
    # Place the whole model on this process's device when launched with accelerate.
    device_index = Accelerator().process_index
    device_map = {"": device_index}
    model = Palme(
        llama_checkpoint=llama_checkpoint,
        acces_token=acces_token,
        image_model_name=checkpoint_image_model,
        # config=None,
        load_in_8bit=True,
        lora_lm=True,
        lora_vision=False,
        freeze_vision=True,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
    )
print("Load trained model")
model.load(pjoin(checkpoint_dir_path, "pytorch_model.bin"))
#
print("Dequantize model")
model.lm = model.lm._unload_and_optionally_merge(dtype=torch.bfloat16) # does not work on titan x
# model.merge_and_unload() # does not work on titan x
#
print("Save dequantized model")
os.makedirs(pjoin(checkpoint_dir_path, 'dequant'), exist_ok=True)
torch.save(model.lm.state_dict(),
pjoin(checkpoint_dir_path, 'dequant', "lm_model.bin"))
if 'openai/clip' in checkpoint_image_model:
torch.save(model.proj_layer.state_dict(),
pjoin(checkpoint_dir_path, 'dequant', "img_proj_layer_model.bin"))
elif 'google/vit' in checkpoint_image_model:
torch.save(model.img_embed_model.classifier.state_dict(),
pjoin(checkpoint_dir_path, 'dequant', "img_embed_model_classifier_model.bin"))
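# Usage (a sketch, assuming the placeholders above are filled in and the Palme
# dependencies from this repo are installed): set checkpoint_dir_path, and
# acces_token if the LLaMA weights are gated, then run
#   python dequantize_and_merge.py
# or, on a multi-GPU setup, `accelerate launch dequantize_and_merge.py`.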