Skip to content

Commit

Permalink
Surgery updates.
Browse files Browse the repository at this point in the history
  • Loading branch information
V committed Mar 1, 2025
1 parent a4ca0c3 commit 801d64d
Showing 1 changed file with 94 additions and 30 deletions.
124 changes: 94 additions & 30 deletions examples/llava/qwen2_5_vl_surgery.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from transformers import (
Qwen2_5_VLForConditionalGeneration,
Qwen2_5_VLProcessor,
AutoProcessor,
Qwen2_5_VLConfig,
Qwen2VLImageProcessor
)
Expand All @@ -21,14 +20,29 @@ def k(raw_key: str, arch: str) -> str:

def to_gguf_name(name: str) -> str:
    """Translate a vision-tower tensor name into its shortened ``v.*`` form.

    The name is first normalised so that it always carries the
    ``vision_model.`` prefix (callers may pass either the bare state-dict key
    or an already-prefixed one), and then a fixed sequence of substring
    substitutions is applied.  The original and the converted name are
    printed for debugging.

    Args:
        name: Tensor name, with or without a leading ``vision_model.``.

    Returns:
        The converted tensor name.
    """
    og = name

    # The prefix check has to run BEFORE any substitution: once
    # "vision_model" has been rewritten to "v" the check can no longer
    # recognise an already-prefixed name and would prepend a second prefix
    # (yielding "v.v....").
    if not name.startswith("vision_model."):
        name = "vision_model." + name

    # Applied strictly in order.  The "mlp.*_proj" entries must precede the
    # generic "proj." entry, and "attn." must precede it as well so that
    # "attn.proj." ends up as "attn_out.".
    substitutions = (
        ("vision_model", "v"),
        ("text_model", "t"),
        ("blocks", "blk"),
        ("embeddings.", ""),
        ("attn.", "attn_"),
        # MLP components
        ("mlp.gate_proj", "ffn_gate"),
        ("mlp.up_proj", "ffn_up"),
        ("mlp.down_proj", "ffn_down"),
        # Projection and norm components
        ("proj.", "out."),
        ("norm1", "ln1"),
        ("norm2", "ln2"),
        # Merger components
        ("merger.mlp", "mm"),
    )
    for old, new in substitutions:
        name = name.replace(old, new)

    print(f"[to_gguf_name] {og} --> {name}")
    return name

Expand All @@ -37,6 +51,10 @@ def find_vision_tensors(qwen2vl, np_dtype) -> Dict[str, np.ndarray]:
vision_model = qwen2vl.visual
tensor_map = {}

# Debug info
print(f"Vision model type: {type(vision_model)}")
print(f"Number of blocks: {len(vision_model.blocks)}")

for name, ten in vision_model.state_dict().items():
ten = ten.numpy()

Expand All @@ -51,14 +69,14 @@ def find_vision_tensors(qwen2vl, np_dtype) -> Dict[str, np.ndarray]:
wq = ten[:c]
wk = ten[c: c * 2]
wv = ten[c * 2:]
base_name = to_gguf_name(f"vision_model.{name}")
base_name = to_gguf_name(name)
tensor_map[base_name.replace("qkv", "q")] = wq
tensor_map[base_name.replace("qkv", "k")] = wk
tensor_map[base_name.replace("qkv", "v")] = wv

elif 'gate_proj' in name or 'up_proj' in name or 'down_proj' in name:
# Handle the MLP structure with gate/up/down projections
tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
tensor_map[to_gguf_name(name)] = ten

elif 'merger' in name:
# Map merger layernorm parameters to post_ln keys
Expand All @@ -85,26 +103,38 @@ def find_vision_tensors(qwen2vl, np_dtype) -> Dict[str, np.ndarray]:
# For the Conv3d, split the temporal kernel dimension (which is 2)
c1, c2, kt, kh, kw = ten.shape
assert kt == 2, "Current implementation only supports temporal_patch_size of 2"
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, ...]
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, ...]

# Properly handle the Conv3d weights for GGUF
# Reshape from [output_channels, input_channels, temporal, height, width]
# to the format expected by GGUF
# For temporal slice 0
tensor_map["v.patch_embd.weight"] = ten[:, :, 0, :, :].reshape(c1, c2 * kh * kw)
# For temporal slice 1
tensor_map["v.patch_embd.weight.1"] = ten[:, :, 1, :, :].reshape(c1, c2 * kh * kw)

elif 'norm1' in name or 'norm2' in name:
# Handle the RMSNorm correctly
tensor_map[to_gguf_name(name)] = ten

else:
tensor_map[to_gguf_name(f"vision_model.{name}")] = ten
tensor_map[to_gguf_name(name)] = ten

# Ensure biases and layer norm weights remain in fp32
for new_name, ten in tensor_map.items():
if (ten.ndim <= 1 or
new_name.endswith("ln1.weight") or
new_name.endswith("ln1.bias") or
new_name.endswith("ln2.weight") or
new_name.endswith("ln2.bias")):
new_name.endswith("ln2.bias") or
new_name.endswith("post_ln.weight") or
new_name.endswith("post_ln.bias")):
tensor_map[new_name] = ten.astype(np.float32)
else:
tensor_map[new_name] = ten.astype(np_dtype)

# Dummy tensor as a placeholder for position embeddings
# Required even when using rotary embeddings
tensor_map["v.position_embd.weight"] = np.zeros([10, 10], dtype=np.float32)
# Add rotary embeddings info - dummy tensor as a placeholder
# This is needed because the model uses rotary position embeddings
tensor_map["v.position_embd.weight"] = np.zeros([1, 1], dtype=np.float32)

return tensor_map

Expand Down Expand Up @@ -160,36 +190,70 @@ def main(args):
for name, data in tensor_map.items():
fout.add_tensor(name, data)

# Add key vision model parameters
fout.add_uint32("clip.vision.patch_size", vcfg.patch_size)
fout.add_uint32("clip.vision.image_size", 560)
fout.add_uint32("clip.vision.projection_dim", 1536)
fout.add_uint32("clip.vision.projection_dim", 1536) # Output of the merger
fout.add_uint32("clip.vision.embedding_length", vcfg.hidden_size)
fout.add_uint32(k(KEY_ATTENTION_HEAD_COUNT, VISION), vcfg.num_heads)
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6)
fout.add_float32(k(KEY_ATTENTION_LAYERNORM_EPS, VISION), 1e-6) # From the RMSNorm epsilon
fout.add_uint32(k(KEY_BLOCK_COUNT, VISION), vcfg.depth)
# For Qwen2.5VL the feed forward dim is 0 since we handle the MLP differently
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 0)

# For Qwen2.5VL, specify the feed forward dimension from mlp
fout.add_uint32(k(KEY_FEED_FORWARD_LENGTH, VISION), 3420) # From gate_proj/up_proj dimensions

# Add additional flags for Qwen2.5 specific features
fout.add_bool("clip.vision.use_rms_norm", True) # Qwen2 uses RMSNorm
fout.add_bool("clip.vision.use_rotary_embeddings", True) # Uses rotary embeddings

fout.add_name(model_name)

fout.add_string("clip.vision.mm_patch_merge_type", "qwen2vl_merger")
# Set the appropriate crop resolution based on image_size
fout.add_uint32("clip.vision.image_crop_resolution", 560)

# Add image grid pinpoints to avoid buffer overflow
# This array defines normalized coordinates for grid sampling in the vision model
# Using standard grid points for 560x560 image with patch size 14
grid_size = 560 // 14 # Number of patches in each dimension
pinpoints = []
for y in range(grid_size):
for x in range(grid_size):
# Normalized coordinates from 0.0 to 1.0
# Convert to Python float instead of numpy.float32
pinpoints.append(float(x / (grid_size - 1)))
pinpoints.append(float(y / (grid_size - 1)))

# Add pinpoints as a float array
fout.add_array("clip.vision.image_grid_pinpoints", pinpoints)

# Load processor for image normalization values
if MODEL_INPUT_DIR is not None:
processor: Qwen2_5_VLProcessor = Qwen2VLImageProcessor.from_pretrained(model_path)
processor = Qwen2VLImageProcessor.from_pretrained(model_path)
else:
processor: Qwen2_5_VLProcessor = Qwen2_5_VLProcessor.from_pretrained(model_name)

fout.add_array("clip.vision.image_mean", processor.image_mean)
fout.add_array("clip.vision.image_std", processor.image_std)
processor = Qwen2_5_VLProcessor.from_pretrained(model_name)

# Get the image mean and std values and ensure they're in the right format
try:
# Try accessing through image_processor first (newer versions)
image_mean = processor.image_mean
image_std = processor.image_std
except AttributeError:
# Fallback to direct access (older versions)
image_mean = processor.image_mean
image_std = processor.image_std

# Convert numpy values to Python floats
image_mean = [float(x) for x in image_mean]
image_std = [float(x) for x in image_std]

# Add arrays with Python float values
fout.add_array("clip.vision.image_mean", image_mean)
fout.add_array("clip.vision.image_std", image_std)

# Set the activation function flags based on the model config
if hasattr(vcfg, 'hidden_act') and 'silu' in vcfg.hidden_act.lower():
fout.add_bool("clip.use_silu", True)
fout.add_bool("clip.use_gelu", False)
else:
fout.add_bool("clip.use_silu", False)
fout.add_bool("clip.use_gelu", False) # Use defaults from dump
fout.add_bool("clip.use_silu", True) # Qwen2.5VL uses SiLU activation in MLP
fout.add_bool("clip.use_gelu", False)

fout.write_header_to_file()
fout.write_kv_data_to_file()
Expand Down

0 comments on commit 801d64d

Please sign in to comment.