[Bugfix] Fix fully sharded LoRA bug (#10352)
Signed-off-by: Jee Jee Li <[email protected]>
jeejeelee authored Nov 15, 2024
1 parent 2690855 commit 1d65ec7
Showing 3 changed files with 21 additions and 19 deletions.
23 changes: 12 additions & 11 deletions vllm/lora/fully_sharded_layers.py
@@ -165,15 +165,14 @@ class MergedColumnParallelLinearWithShardedLoRA(
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_a[0] is None or lora_a[1] is None:
-            return lora_a
+        #NOTE: lora_a contains 2 subloras, and each sublora could be None.
         output_shard_size = self.lora_a_stacked[0].shape[2]
         output_start_idx = self.tp_rank * output_shard_size
         lora_a = [
-            lora_a[0][:,
-                      output_start_idx:output_start_idx + output_shard_size],
-            lora_a[1][:,
-                      output_start_idx:output_start_idx + output_shard_size],
+            lora_a[0][:, output_start_idx:output_start_idx +
+                      output_shard_size] if lora_a[0] is not None else None,
+            lora_a[1][:, output_start_idx:output_start_idx +
+                      output_shard_size] if lora_a[1] is not None else None,
         ]
         return lora_a

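The fix drops the early return that gave up on slicing whenever any sublora was None; each sublora is now sliced for the current tensor-parallel rank, or left as None, independently of the others. A minimal standalone sketch of that behavior, using assumed names (slice_lora_a_sharded, tp_rank, shard_size) rather than the actual vLLM classes:

from typing import List, Optional
import torch

def slice_lora_a_sharded(
    lora_a: List[Optional[torch.Tensor]],
    tp_rank: int,
    shard_size: int,
) -> List[Optional[torch.Tensor]]:
    # Slice each sublora's A matrix along dim 1 to this rank's shard;
    # a missing sublora stays None instead of short-circuiting the list.
    start = tp_rank * shard_size
    return [
        t[:, start:start + shard_size] if t is not None else None
        for t in lora_a
    ]

# Sublora 1 is absent: only that entry stays None, sublora 0 is still sliced.
lora_a = [torch.randn(8, 64), None]
sliced = slice_lora_a_sharded(lora_a, tp_rank=1, shard_size=16)
assert sliced[0].shape == (8, 16) and sliced[1] is None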
@@ -261,14 +260,16 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
     def slice_lora_a(
         self, lora_a: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_a[0] is None or lora_a[1] is None or lora_a[2] is None:
-            return lora_a
+        # NOTE: lora_a contains 3 subloras, and each sublora could be None.
         shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
         start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
         lora_a = [
-            lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
-            lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
-            lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
+            lora_a[0][:, start_idx[0]:start_idx[0] +
+                      shard_size[0]] if lora_a[0] is not None else None,
+            lora_a[1][:, start_idx[1]:start_idx[1] +
+                      shard_size[1]] if lora_a[1] is not None else None,
+            lora_a[2][:, start_idx[2]:start_idx[2] +
+                      shard_size[2]] if lora_a[2] is not None else None,
         ]
         return lora_a

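The QKV variant applies the same idea with three subloras, each of which may have its own shard width. A hedged sketch under the same assumptions (the q/k/v widths below are made up for illustration):

from typing import List, Optional
import torch

def slice_qkv_lora_a(
    lora_a: List[Optional[torch.Tensor]],
    tp_rank: int,
    shard_sizes: List[int],  # one shard width per sublora (q, k, v)
) -> List[Optional[torch.Tensor]]:
    out: List[Optional[torch.Tensor]] = []
    for sub, size in zip(lora_a, shard_sizes):
        # Each sublora is sliced with its own shard width; None stays None.
        start = tp_rank * size
        out.append(sub[:, start:start + size] if sub is not None else None)
    return out

# Only the q sublora is present in this adapter.
sliced = slice_qkv_lora_a([torch.randn(8, 128), None, None],
                          tp_rank=0, shard_sizes=[32, 8, 8])
assert sliced[0].shape == (8, 32) and sliced[1] is None and sliced[2] is None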
15 changes: 8 additions & 7 deletions vllm/lora/layers.py
@@ -685,26 +685,27 @@ def slice_lora_a(
     def slice_lora_b(
         self, lora_b: List[Union[torch.Tensor, None]]
     ) -> List[Union[torch.Tensor, None]]:
-        if lora_b[0] is None or lora_b[1] is None:
-            return lora_b
+        #NOTE: lora_b contains 2 subloras, and each sublora could be None.
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
         lora_b = [
-            lora_b[0][:, start_idx:end_idx],
-            lora_b[1][:, start_idx:end_idx],
+            lora_b[0][:, start_idx:end_idx] if lora_b[0] is not None else None,
+            lora_b[1][:, start_idx:end_idx] if lora_b[1] is not None else None,
         ]
         return lora_b

     def slice_bias(
         self, bias: List[Union[torch.Tensor,
                                None]]) -> List[Union[torch.Tensor, None]]:
-        if bias[0] is None or bias[1] is None:
-            return bias
+        # NOTE : each bias could be None.
         shard_size = self.output_dim
         start_idx = self.tp_rank * shard_size
         end_idx = (self.tp_rank + 1) * shard_size
-        bias = [bias[0][start_idx:end_idx], bias[1][start_idx:end_idx]]
+        bias = [
+            bias[0][start_idx:end_idx] if bias[0] is not None else None,
+            bias[1][start_idx:end_idx] if bias[1] is not None else None
+        ]
         return bias

     def set_lora(
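The layers.py changes mirror the pattern for the B matrices and biases, slicing along the output dimension owned by the current rank. A small sketch with assumed names and shapes:

from typing import List, Optional, Tuple
import torch

def slice_output_shard(
    lora_b: List[Optional[torch.Tensor]],
    bias: List[Optional[torch.Tensor]],
    tp_rank: int,
    output_dim: int,
) -> Tuple[List[Optional[torch.Tensor]], List[Optional[torch.Tensor]]]:
    start, end = tp_rank * output_dim, (tp_rank + 1) * output_dim
    # B matrices are sliced along their output columns, biases along rows;
    # absent entries remain None rather than aborting the whole slice.
    lora_b = [t[:, start:end] if t is not None else None for t in lora_b]
    bias = [b[start:end] if b is not None else None for b in bias]
    return lora_b, bias

lora_b, bias = slice_output_shard(
    lora_b=[torch.randn(16, 256), None],
    bias=[None, torch.randn(256)],
    tp_rank=3, output_dim=64)
assert lora_b[0].shape == (16, 64) and bias[1].shape == (64,)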
2 changes: 1 addition & 1 deletion vllm/worker/worker.py
@@ -232,7 +232,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         logger.info(
             "Memory profiling results: total_gpu_memory=%.2fGiB"
             " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB"
-            " memory_usage_post_profile=%.2fGib"
+            " memory_usage_post_profile=%.2fGiB"
             " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB"
             " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3),
             (total_gpu_memory - free_memory_pre_profile) / (1024**3),
