FIX: Failing BOFT tests due to device (#2242)
This pull request resolves the above issue regarding BOFT forward/merging with CUDA
by ensuring that all relevant tensors and models are moved to the correct
device. This change is necessary to prevent problems such as zero matrices and
test failures when using CUDA.

Also fixed the fbd_cuda deprecation warning.
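To illustrate the device handling this change introduces, here is a minimal, self-contained sketch of the pattern the tests now follow. It assumes `infer_device` is importable from `peft.utils` (as in the PEFT test suite) and uses a simplified stand-in for the `MLP` model defined in tests/test_custom_models.py; it is not the exact test code.

```python
# Minimal sketch of the device pattern applied in the tests (simplified, not the exact test code).
import torch
from torch import nn

from peft.utils import infer_device  # assumed import path, as used by the PEFT test suite

torch_device = infer_device()  # "cuda" when available, otherwise "cpu" (or another accelerator)


class MLP(nn.Module):
    """Simplified stand-in for the MLP used in tests/test_custom_models.py."""

    def __init__(self, bias=True):
        super().__init__()
        self.lin = nn.Linear(10, 2, bias=bias)

    def forward(self, X):
        return self.lin(X.float())


# Keep the model and the test input on the same device; previously the inputs stayed
# on the CPU while a BOFT-adapted model could be on CUDA, which led to the zero
# matrices and merge/forward test failures described above.
model = MLP(bias=True).to(torch_device).eval()
X = torch.arange(90).view(9, 10).to(torch_device)
output = model(X=X)
```

The corresponding CUDA change replaces the deprecated `tensor.data<scalar_t>()` accessor with `tensor.data_ptr<scalar_t>()`, which silences the fbd_cuda deprecation warning without changing behavior.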
d-kleine authored Dec 9, 2024
1 parent de88c70 commit ec92cdc
Showing 2 changed files with 11 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/peft/tuners/boft/fbd/fbd_cuda_kernel.cu
@@ -65,8 +65,8 @@ std::vector<at::Tensor> forward_fast_block_diag_cuda(
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "forward_fast_block_diag1", ([&] {
         forward_fast_block_diag_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
-            input.data<scalar_t>(),
-            output.data<scalar_t>(),
+            input.data_ptr<scalar_t>(),
+            output.data_ptr<scalar_t>(),
             z, N, b);
         }));
 
@@ -96,8 +96,8 @@ std::vector<at::Tensor> backward_fast_block_diag_cuda(
 
     AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_output.type(), "backward_fast_block_diag", ([&] {
         backward_fast_block_diag_cuda_kernel<scalar_t><<<blocks_1, threads>>>(
-            grad_output.data<scalar_t>(),
-            grad_input.data<scalar_t>(),
+            grad_output.data_ptr<scalar_t>(),
+            grad_input.data_ptr<scalar_t>(),
             z, N, b);
         }));
 
15 changes: 7 additions & 8 deletions tests/test_custom_models.py
@@ -1935,8 +1935,10 @@ class MultipleActiveAdaptersTester(unittest.TestCase):
     would be overkill.
     """
 
+    torch_device = infer_device()
+
     def prepare_inputs_for_testing(self):
-        X = torch.arange(90).view(9, 10)
+        X = torch.arange(90).view(9, 10).to(self.torch_device)
         return {"X": X}
 
     def set_multiple_active_adapters(self, model, adapter_names):
@@ -1949,8 +1951,7 @@ def test_multiple_active_adapters_forward(
         self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2
     ):
         torch.manual_seed(0)
-        model = MLP(bias=tuner_method != "ia3")
-        model.eval()
+        model = MLP(bias=tuner_method != "ia3").to(self.torch_device).eval()
         X = self.prepare_inputs_for_testing()
 
         config_1 = config_cls(**config_kwargs_1)
@@ -1990,8 +1991,7 @@ def test_multiple_active_adapters_merge_and_unmerge(
         self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2
     ):
         torch.manual_seed(0)
-        model = MLP(bias=tuner_method != "ia3")
-        model.eval()
+        model = MLP(bias=tuner_method != "ia3").to(self.torch_device).eval()
         X = self.prepare_inputs_for_testing()
         base_output = model(**X)
 
@@ -2007,7 +2007,7 @@ def test_multiple_active_adapters_merge_and_unmerge(
 
         peft_model.merge_adapter()
         merged_combined_output = peft_model(**X)
-        assert torch.allclose(merged_combined_output, combined_output, atol=1e-5)
+        assert torch.allclose(merged_combined_output, combined_output, atol=1e-4)
 
         peft_model.unmerge_adapter()
 
@@ -2019,8 +2019,7 @@ def test_multiple_active_adapters_merge_and_unmerge(
     @parameterized.expand(MULTIPLE_ACTIVE_ADAPTERS_TEST_CASES)
     def test_merge_layers_multi(self, test_name, tuner_method, config_cls, config_kwargs_1, config_kwargs_2):
         torch.manual_seed(0)
-        model = MLP(bias=tuner_method != "ia3")
-        model.eval()
+        model = MLP(bias=tuner_method != "ia3").to(self.torch_device).eval()
 
         config_1 = config_cls(**config_kwargs_1)
         config_2 = config_cls(**config_kwargs_2)
