Commit

goliaro committed Oct 5, 2024
1 parent 21f8cb9 commit c5e813b
Showing 4 changed files with 95 additions and 22 deletions.
6 changes: 3 additions & 3 deletions include/flexflow/ops/kernels/lora_linear_kernels.h
@@ -25,9 +25,9 @@ class LoraLinearMeta : public OpMeta {
LoraLinearMeta(FFHandler handle, LoraLinear const *li);
~LoraLinearMeta(void);
// PEFT related fields
void *low_rank_activation;
void *input_activation;
std::unordered_map<PEFTModelID, LoraLinearWeight> model_state;
// void *low_rank_activation;
// void *input_activation;
// std::unordered_map<PEFTModelID, LoraLinearWeight> model_state;
// std::unordered_map<PEFTModelID, LoraLinearModelState> model_state;
// size_t allocated_peft_buffer_size1 = 0, allocated_peft_buffer_size2 = 0;
PEFTMemoryManager *peft_memory_manager;
20 changes: 12 additions & 8 deletions include/flexflow/utils/peft_weight_allocator.h
@@ -95,23 +95,27 @@ struct LoraLinearWeight {
void *w0_ptr, *w1_ptr;
// gradients
void *w0_grad_ptr, *w1_grad_ptr;
// activations
void *input_activation;
void *low_rank_activation;
// v values for SGD optimizer (when using momentum)
void *w0_v_values_ptr, *w1_v_values_ptr;
// int in_dim, out_dim, rank, num_shards;
LoraLinearWeight(void *w0=nullptr, void *w1=nullptr, void *w0_grad=nullptr, void *w1_grad=nullptr,
void *w0_v_values=nullptr, void *w1_v_values=nullptr)
void *w0_v_values=nullptr, void *w1_v_values=nullptr, void *low_rank_activation_=nullptr, void *input_activation_=nullptr)
: w0_ptr(w0), w1_ptr(w1),
w0_grad_ptr(w0_grad), w1_grad_ptr(w1_grad),
w0_v_values_ptr(w0_v_values), w1_v_values_ptr(w1_v_values) {}
w0_v_values_ptr(w0_v_values), w1_v_values_ptr(w1_v_values),
low_rank_activation(low_rank_activation_), input_activation(input_activation_) {}
};
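
// Sketch (not part of this commit): roles and sizes of the new activation
// fields, as implied by the allocation logic in src/runtime/peft_weight_allocator.cc.
// During the forward pass of a finetuning request the kernel saves
//   input_activation     -- the layer input, max_peft_tokens * in_dim elements
//   low_rank_activation  -- the rank-sized intermediate, max_peft_tokens * rank elements
// so that the backward pass can compute the LoRA gradients without re-running
// the forward GEMMs.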

class PEFTMemoryManager {
public:
PEFTMemoryManager(Memory gpu_mem_, size_t max_lora_size_, int max_concurrent_adapters_, int in_dim_, int out_dim_, int num_shards_, int shard_id_, std::string const &lora_layername_substr_, DataType dt_)
PEFTMemoryManager(Memory gpu_mem_, size_t max_lora_size_, int max_concurrent_adapters_, int max_peft_tokens_, int in_dim_, int out_dim_, int num_shards_, int shard_id_, std::string const &lora_layername_substr_, DataType dt_)
: gpu_mem(gpu_mem_),
max_concurrent_adapters(max_concurrent_adapters_),
max_lora_size(max_lora_size_),
in_dim(in_dim_), out_dim(out_dim_), num_shards(num_shards_), shard_id(shard_id_),
max_peft_tokens(max_peft_tokens_),
lora_layername_substr(lora_layername_substr_), dt(dt_),
base_ptr(nullptr),
finetuning_ptr(nullptr),
@@ -128,17 +132,16 @@ class PEFTMemoryManager {
// allocate memory for the PEFT adapter for a finetuning request for a given layer and shard
void allocate_finetuning_memory();

LoraLinearWeight get_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config);

private:
// Check whether the PEFT adapter for the given model is in memory. If not, set the cache_miss flag to true. If this is the first finetuning request, allocate memory for the finetuning adapter.
void get_finetuning_slot(PEFTModelID const &model_id, bool *cache_miss);

// Returns the slot in memory where the peft model weights are/will be stored.
// If the model is not in memory (cache miss), set the cache_miss flag to true.
int get_inference_peft_slot(PEFTModelID const &model_id, bool *cache_miss);

void load_peft_model(LoraLinearWeight &weight, LoraLinearConfig const &lora_config);

LoraLinearWeight get_inference_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config);

LoraLinearWeight get_finetuning_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config);

// Legion memory management apparatus
@@ -149,6 +152,7 @@
int max_concurrent_adapters;
size_t max_lora_size;
int in_dim, out_dim, num_shards, shard_id;
int max_peft_tokens;
// LRU cache apparatus
std::unordered_map<PEFTModelID, int> lru_hashtable;
std::vector<PEFTModelID> lru_list; // head = least recently used, tail=most recently used
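
For context, a minimal usage sketch of the reworked manager from the kernel side (the real call site is in lora_linear_kernels.cu below; the function name and the peft_bwd parameter here are illustrative): get_peft() chooses the inference or finetuning path based on lora_config.trainable, returns the adapter's weight slot, and loads the weights on a cache miss; for a finetuning request the returned LoraLinearWeight also carries the activation buffers.

#include "flexflow/ops/kernels/lora_linear_kernels.h"
#include "flexflow/utils/peft_weight_allocator.h"
using namespace FlexFlow;

// Illustrative sketch only, not the commit's code.
void run_lora_layer(LoraLinearMeta *m,
                    PEFTModelID const &model_id,
                    LoraLinearConfig const &lora_config,
                    bool peft_bwd) {
  // one lookup per request; inference vs. finetuning dispatch happens inside get_peft()
  LoraLinearWeight weight =
      m->peft_memory_manager->get_peft(model_id, lora_config);
  // finetuning requests keep the low-rank intermediate in the adapter's own
  // buffer; inference requests reuse the shared workspace instead
  void *intermediate = peft_bwd ? weight.low_rank_activation
                                : m->handle.workSpace;
  // ... the two LoRA GEMMs then run as in inference_kernel() ...
  (void)intermediate;
}
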
80 changes: 69 additions & 11 deletions src/ops/kernels/lora_linear_kernels.cu
@@ -311,6 +311,12 @@ void inference_kernel(LoraLinearMeta *m,
ffStream_t stream) {
checkCUDA(cublasSetStream(m->handle.blas, stream));
checkCUDNN(cudnnSetStream(m->handle.dnn, stream));
cudaDataType_t input_type = ff_to_cuda_datatype(m->input_type[0]);
cudaDataType_t output_type = ff_to_cuda_datatype(m->input_type[1]);
cudaDataType_t lr_actv_type = output_type;
assert(input_type == output_type);
cudaDataType_t weight_type = output_type;
cudaDataType_t compute_type = output_type;

int num_peft_requests = 0;
for (int i=0; i< bc->max_requests_per_batch(); i++) {
@@ -320,22 +326,74 @@
if (bc->requestsInfo[i].peft_bwd) {
num_peft_requests++;
}
LoraLinearConfig deserialized_config = LoraLinearConfig::deserialize_from_json_string(bc->requestsInfo[i].peft_adapters[bc->requestsInfo[i].peft_model_id]);
if (!lora_applies_to_this_layer(m, deserialized_config)) {
LoraLinearConfig lora_config = LoraLinearConfig::deserialize_from_json_string(bc->requestsInfo[i].peft_adapters[bc->requestsInfo[i].peft_model_id]);
if (!lora_applies_to_this_layer(m, lora_config)) {
continue;
}
assert(lora_config.trainable == bc->requestsInfo[i].peft_bwd && "Trainable flag mismatch");
bool cache_miss;
void *peft_slot;
if (!lora_config.trainable) {
peft_slot = m->peft_memory_manager->get_peft_model_handle(bc->requestsInfo[i].peft_model_id, &cache_miss);
int num_peft_tokens = bc->requestsInfo[i].num_tokens_in_batch;
// int max_peft_tokens = bc->requestsInfo[i].max_length;
int first_token_offset = bc->requestsInfo[i].first_token_offset_in_batch;
LoraLinearWeight weight = m->peft_memory_manager->get_peft(bc->requestsInfo[i].peft_model_id, lora_config);
void *intermediate_result_ptr = (bc->requestsInfo[i].peft_bwd) ? weight.low_rank_activation : m->handle.workSpace;
if (bc->requestsInfo[i].peft_bwd) {
checkCUDA(cudaMemcpyAsync(weight.input_activation,
input_ptr + first_token_offset * in_dim,
data_type_size(m->input_type[0]) *
num_peft_tokens * in_dim,
cudaMemcpyDeviceToDevice,
stream));
} else {
peft_slot = m->peft_memory_manager->get_finetuning_handle(bc->requestsInfo[i].peft_model_id, &cache_miss);
}
if (cache_miss) {
// load model into memory
load_peft_model(m, peft_slot, deserialized_config, in_dim, out_dim, num_shards);
// use workspace to save intermediate result
assert(m->handle.workSpaceSize >=
data_type_size(m->input_type[1]) * num_peft_tokens * lora_config.rank);
}
DT alpha = 1.0f, beta = 0.0f;
// buffer = weight_first * input
// [rank, num_peft_tokens] = [in_dim, rank].T * [in_dim, num_peft_tokens]
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
lora_config.rank,
num_peft_tokens,
in_dim,
&alpha,
weight.w0_ptr,
weight_type,
in_dim,
input_ptr + first_token_offset * in_dim,
input_type,
in_dim,
&beta,
intermediate_result_ptr,
lr_actv_type,
lora_config.rank,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// output = weight_second * buffer
// [out_dim, num_peft_tokens] = [rank, out_dim].T * [rank, num_peft_tokens]
// Note that we use alpha in both places since we do
// an in-place update for LoraLinear
DT scaling_constant = (DT)(lora_config.lora_alpha / lora_config.rank);
checkCUDA(cublasGemmEx(m->handle.blas,
CUBLAS_OP_T,
CUBLAS_OP_N,
out_dim,
num_peft_tokens,
lora_config.rank,
&scaling_constant,
weight.w1_ptr,
weight_type,
lora_config.rank,
intermediate_result_ptr,
lr_actv_type,
lora_config.rank,
&alpha,
output_ptr + first_token_offset * out_dim,
output_type,
out_dim,
compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
}
}

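For reference, the two cublasGemmEx calls above implement the standard LoRA forward update output += (lora_alpha / rank) * w1^T * (w0^T * x), accumulated in place on top of the base layer's output (which is why alpha is also passed as beta in the second GEMM). A scalar sketch of the same math for a single token, using the column-major layouts assumed above (w0 with leading dimension in_dim, w1 with leading dimension rank); illustrative only, not the commit's code:

#include <vector>

// out must already hold the base linear layer's output for this token.
void lora_forward_reference(float const *x,   // [in_dim]
                            float const *w0,  // [in_dim x rank], ld = in_dim
                            float const *w1,  // [rank x out_dim], ld = rank
                            float *out,       // [out_dim], updated in place
                            int in_dim, int out_dim, int rank,
                            float lora_alpha) {
  std::vector<float> low_rank(rank, 0.0f);
  // first GEMM: low_rank = w0^T * x
  for (int r = 0; r < rank; r++) {
    for (int i = 0; i < in_dim; i++) {
      low_rank[r] += w0[i + r * in_dim] * x[i];
    }
  }
  // second GEMM: out += (lora_alpha / rank) * w1^T * low_rank
  float scaling = lora_alpha / rank;
  for (int o = 0; o < out_dim; o++) {
    for (int r = 0; r < rank; r++) {
      out[o] += scaling * w1[r + o * rank] * low_rank[r];
    }
  }
}
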
11 changes: 11 additions & 0 deletions src/runtime/peft_weight_allocator.cc
@@ -21,6 +21,7 @@ void PEFTMemoryManager::allocate_inference_memory() {

void PEFTMemoryManager::allocate_finetuning_memory() {
size_t ft_size = max_lora_size*3; // weights, gradients, momentum values
ft_size += max_peft_tokens*(in_dim+rank); // input, low-rank activations
// allocate chunk of memory for PEFT adapter
Realm::Rect<1, coord_t> bounds(
Realm::Point<1, coord_t>(0),
@@ -254,10 +255,20 @@ LoraLinearWeight PEFTMemoryManager::get_finetuning_peft(PEFTModelID const &model
result.w1_grad_ptr = result.w0_grad_ptr + w0_num_elements*data_size;
result.w0_v_values_ptr = result.w1_grad_ptr + w1_num_elements*data_size;
result.w1_v_values_ptr = result.w0_v_values_ptr + w0_num_elements*data_size;
result.input_activation = result.w1_v_values_ptr + w1_num_elements*data_size; // max_peft_tokens*in_dim
result.low_rank_activation = result.input_activation + max_peft_tokens*in_dim*data_size; // max_peft_tokens*rank
if (cache_miss) {
load_peft_model(result, lora_config);
}
return result;
}
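
// Layout sketch (not part of the commit) of one finetuning slot, as implied by
// the pointer arithmetic above, assuming the adapter weights themselves sit at
// the start of the slot (the w0_ptr/w1_ptr assignments are above the shown hunk):
//   [ w0 | w1 | w0_grad | w1_grad | w0_v_values | w1_v_values
//     | input_activation     (max_peft_tokens * in_dim * data_size bytes)
//     | low_rank_activation  (max_peft_tokens * rank   * data_size bytes) ]
// where each weight/gradient/momentum block is w0_num_elements or
// w1_num_elements elements times data_size bytes, matching allocate_finetuning_memory().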

LoraLinearWeight PEFTMemoryManager::get_peft(PEFTModelID const &model_id, LoraLinearConfig const &lora_config) {
if (lora_config.trainable) {
return get_finetuning_peft(model_id, lora_config);
} else {
return get_inference_peft(model_id, lora_config);
}
}

}; // namespace FlexFlow
