
Multi-GPU parallel error: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! #419

Open
Darknessrky opened this issue Nov 13, 2024 · 4 comments
Labels
question Further information is requested


@Darknessrky

Darknessrky commented Nov 13, 2024

Edited Nov 14: fixed typos and wording.

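Because both the 7B and 3B models run the first GPU out of memory on two RTX 3090s, I ran MEMIT on llama-7b across four 3090s instead and got the error in the title. I checked similar issues but could not resolve it.
Code: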

import sys
import os
import json
# Set the import path, working directory, and visible GPUs before importing easyeditor
sys.path.append('/data/renky/EasyEdit')
os.chdir("/data/renky/EasyEdit")
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"
from easyeditor.editors.editor import BaseEditor
from easyeditor import MEMITHyperParams

# load zsre_data
edit_data = json.load(open('./data/ZsRE/ZsRE-test-all.json', 'r', encoding='utf-8'))[:100]
prompts = [edit_data_['prompt'] for edit_data_ in edit_data]
ground_truth = [edit_data_['ground_truth'][0] for edit_data_ in edit_data]
subject = [edit_data_['subject'] for edit_data_ in edit_data]
target_new = [edit_data_['target_new'] for edit_data_ in edit_data]

# MEMIT
hparams=MEMITHyperParams.from_hparams('./hparams/MEMIT/llama-7b.yaml')
editor = BaseEditor.from_hparams(hparams)
metrics, edited_model_false, _ = editor.edit(
    prompts=prompts,
    ground_truth=ground_truth,
    target_new=target_new,
    subject=subject,
    keep_original_weight=False
)
print(metrics)

Hyperparameters (./hparams/MEMIT/llama-7b.yaml):

alg_name: "MEMIT"
model_name: "./hugging_cache/llama-2-7b"
stats_dir: "./data/stats"
device: 0
layers: [4, 5, 6, 7, 8]
clamp_norm_factor: 4
layer_selection: "all"
fact_token: "subject_last"
v_num_grad_steps: 25
v_lr: 5e-1
v_loss_layer: 31
v_weight_decay: 1e-3
kl_factor: 0.0625
mom2_adjustment: true
mom2_update_weight: 15000
rewrite_module_tmp: "model.layers.{}.mlp.down_proj"
layer_module_tmp: "model.layers.{}"
mlp_module_tmp: "model.layers.{}.mlp"
attn_module_tmp: "model.layers.{}.self_attn"
ln_f_module: "model.norm"
lm_head_module: "lm_head"
mom2_dataset: "wikipedia"
mom2_n_samples: 100000
mom2_dtype: "float32"
model_parallel: true
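A rough way to see why adding GPUs can trigger this error: with `model_parallel: true` the decoder layers are spread across the visible cards (presumably via accelerate, as the `accelerate/hooks.py` frames in the traceback suggest), so some of the edited `layers` may land on a GPU other than `device: 0`. The sketch below assumes a hypothetical equal, contiguous split of Llama-2-7B's 32 decoder layers; `contiguous_split` and `layers_off_device` are illustrative names, not EasyEdit functions. accelerate's `device_map="auto"` balances by memory and also places the embedding, final norm, and LM head, so real placements can differ.

```python
# Hypothetical illustration: assumes the 32 decoder layers of Llama-2-7B are
# split into equal contiguous chunks, one chunk per visible GPU. accelerate's
# device_map="auto" balances by memory instead, so real placements differ.
import math

NUM_LAYERS = 32                # Llama-2-7B decoder layers (0..31)
EDIT_LAYERS = [4, 5, 6, 7, 8]  # `layers` from the MEMIT hparams
LAYER_TMP = "model.layers.{}"  # `layer_module_tmp` from the hparams


def contiguous_split(num_layers, num_gpus):
    """Map each layer index to a GPU index using equal contiguous chunks."""
    chunk = math.ceil(num_layers / num_gpus)
    return {layer: layer // chunk for layer in range(num_layers)}


def layers_off_device(num_gpus, device=0):
    """Return module names of edited layers that would NOT sit on `device`."""
    placement = contiguous_split(NUM_LAYERS, num_gpus)
    return [LAYER_TMP.format(l) for l in EDIT_LAYERS if placement[l] != device]


for n in (1, 2, 4):
    print(n, layers_off_device(n))
```

Under this simplified split only the 4-GPU case pushes `model.layers.8` off GPU 0; the real 3-GPU placement depends on how accelerate balances memory, which this sketch does not model.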

Error:

Traceback (most recent call last):
  File "/data/renky/EasyEdit/test.py", line 21, in <module>
    metrics, edited_model_false, _ = editor.edit(
  File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 183, in edit
    return self.edit_requests(requests, sequential_edit, verbose, test_generation=test_generation, **kwargs)
  File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 371, in edit_requests
    edited_model, weights_copy, icl_examples = edit_func(request)
  File "/data/renky/EasyEdit/easyeditor/editors/editor.py", line 319, in edit_func
    edited_model, weights_copy = self.apply_algo(
  File "/data/renky/EasyEdit/easyeditor/models/memit/memit_main.py", line 46, in apply_memit_to_model
    deltas = execute_memit(model, tok, requests, hparams, cache_template=cache_template)
  File "/data/renky/EasyEdit/easyeditor/models/memit/memit_main.py", line 137, in execute_memit
    cur_z = compute_z(
  File "/data/renky/EasyEdit/easyeditor/models/memit/compute_z.py", line 129, in compute_z
    logits = model(**input_tok).logits
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/accelerate/hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/transformers/models/llama/modeling_llama.py", line 1001, in forward
    layer_outputs = decoder_layer(
  File "/home/renky/anaconda3/envs/EasyEdit/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1547, in _call_impl
    hook_result = hook(self, args, result)
  File "/data/renky/EasyEdit/easyeditor/util/nethook.py", line 80, in retain_hook
    output = invoke_with_optional_args(
  File "/data/renky/EasyEdit/easyeditor/util/nethook.py", line 454, in invoke_with_optional_args
    return fn(*pass_args, **pass_kw)
  File "/data/renky/EasyEdit/easyeditor/models/memit/compute_z.py", line 106, in edit_output_fn
    cur_out[0][i, idx, :] += delta
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!
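The failing line adds `delta` (created on one GPU) into the output of a decoder layer that has been placed on another GPU. A common remedy for this class of error, offered here as an untested assumption rather than a confirmed EasyEdit fix, is to move `delta` onto the output's device inside the hook: `cur_out[0][i, idx, :] += delta.to(cur_out[0].device)`. `Tensor.to` is differentiable, so gradients still reach `delta`; other lines in `compute_z.py` may need the same treatment. The stdlib-only toy below models just PyTorch's same-device rule (`ToyTensor` is a hypothetical stand-in, not `torch.Tensor`) to show why the move avoids the error.

```python
class ToyTensor:
    """Toy stand-in for a tensor: floats tagged with a device name.

    It models only PyTorch's same-device rule; it is NOT torch.Tensor.
    """

    def __init__(self, values, device):
        self.values = list(values)
        self.device = device

    def to(self, device):
        # Like Tensor.to(device): return a copy placed on `device`.
        return ToyTensor(self.values, device)

    def __iadd__(self, other):
        # PyTorch raises a RuntimeError when operands live on different devices.
        if other.device != self.device:
            raise RuntimeError(
                "Expected all tensors to be on the same device, but found at "
                f"least two devices, {self.device} and {other.device}!"
            )
        self.values = [a + b for a, b in zip(self.values, other.values)]
        return self


def edit_output_fn(layer_out, delta, move_delta=True):
    """Mimics the hook line `cur_out[0][i, idx, :] += delta`.

    With move_delta=True, delta is first moved to the layer output's device,
    i.e. the hypothetical fix `+= delta.to(cur_out[0].device)`.
    """
    if move_delta:
        delta = delta.to(layer_out.device)
    layer_out += delta  # mutates layer_out in place via __iadd__
    return layer_out


out = ToyTensor([1.0, 2.0], "cuda:1")    # layer output on GPU 1
delta = ToyTensor([0.5, 0.5], "cuda:0")  # delta created on hparams.device
try:
    edit_output_fn(out, delta, move_delta=False)
except RuntimeError as err:
    print(err)
print(edit_output_fn(out, delta).values)
```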

Could you tell me what is going wrong or what I need to change? Thanks!

@zxlzr zxlzr added the question Further information is requested label Nov 14, 2024
@XeeKee
Collaborator

XeeKee commented Nov 15, 2024

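It runs fine in my local test, but it also runs out of GPU memory.
You could try quantizing the model or switching to a machine with more GPU memory.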
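For a sense of scale, here is back-of-the-envelope arithmetic for the weights of Llama-2-7B (about 6.74B parameters) at different precisions, against the 24 GB of one RTX 3090 (treated as 24 GiB here). It counts weights only; activations, the gradient of the optimized delta, and MEMIT's covariance statistics (`mom2_dtype: "float32"`) come on top. Whether EasyEdit's MEMIT works on a quantized model is not confirmed in this thread.

```python
# Rough estimate of GPU memory for Llama-2-7B *weights only*; activations,
# gradients for the optimized delta, and covariance stats are not included.
PARAMS = 6.74e9   # approximate parameter count of Llama-2-7B
GPU_GIB = 24      # memory of one RTX 3090, treated as 24 GiB

BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1, "4-bit": 0.5}


def weight_gib(dtype):
    """Bytes needed to hold every parameter at `dtype`, converted to GiB."""
    return PARAMS * BYTES_PER_PARAM[dtype] / 2**30


for dtype in BYTES_PER_PARAM:
    fits = "fits" if weight_gib(dtype) < GPU_GIB else "does NOT fit"
    print(f"{dtype:>8}: {weight_gib(dtype):5.1f} GiB, {fits} on one card")
```

In float32 the weights alone come to roughly 25 GiB, more than a single card holds, while float16 halves that to roughly 12.6 GiB.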

@Darknessrky
Author

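With 2 or fewer GPUs it runs but runs out of GPU memory; with 3 or more GPUs I get the error above, haha. I'll keep trying and see whether I can fix it.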

@zxlzr
Contributor

zxlzr commented Nov 16, 2024

hi, do you have any further issues?

@Darknessrky
Author


Right now, no.
I tried setting device_map manually on the 3090s, but it failed; I'm still working on it.
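For the manual `device_map` route, one option is to keep the embedding and every layer up to the last edited one (layer 8) on GPU 0, so the hooked layer output and `delta` share a device, and to spread the remaining layers over the other cards. The hypothetical helper `pinned_device_map` below only builds the dictionary; its keys follow Llama's module names (`model.norm` and `lm_head` also appear in the hparams). Such a dict can be passed as `device_map=` to transformers' `from_pretrained`, but how, or whether, EasyEdit exposes that when `model_parallel: true` is an assumption. Depending on the transformers version, extra top-level modules may need entries, and tensors computed on the last GPU may still meet tensors from GPU 0 elsewhere in `compute_z.py`.

```python
def pinned_device_map(num_layers, num_gpus, last_pinned_layer, pinned_gpu=0):
    """Hypothetical helper: keep the embedding and layers 0..last_pinned_layer
    on `pinned_gpu`, then split the remaining layers into contiguous chunks
    over the other GPUs. Keys follow Llama's module names."""
    device_map = {"model.embed_tokens": pinned_gpu}
    for layer in range(last_pinned_layer + 1):
        device_map[f"model.layers.{layer}"] = pinned_gpu

    others = [g for g in range(num_gpus) if g != pinned_gpu] or [pinned_gpu]
    rest = list(range(last_pinned_layer + 1, num_layers))
    chunk = -(-len(rest) // len(others))  # ceiling division
    for i, layer in enumerate(rest):
        device_map[f"model.layers.{layer}"] = others[i // chunk]

    # Final norm and LM head follow the last decoder layer.
    last = device_map[f"model.layers.{num_layers - 1}"]
    device_map["model.norm"] = last
    device_map["lm_head"] = last
    return device_map


# Llama-2-7B has 32 decoder layers; MEMIT edits layers 4-8 here.
dm = pinned_device_map(num_layers=32, num_gpus=4, last_pinned_layer=8)
for name, gpu in dm.items():
    print(name, gpu)
```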
