
A bug with n_fused #41

Open
JiayiFeng opened this issue Oct 18, 2022 · 4 comments
Labels
bug Something isn't working

Comments

@JiayiFeng

JiayiFeng commented Oct 18, 2022

When an attn_qkv layer is set with n_fused>1 and reversed=False, the shape of its sliced weight is incorrect.

The root cause seems to be here:

dim = dim if not reversed or is_bias else abs(dim - 1)
n_fused = 1 if not n_fused else n_fused
proj_layer = proj_layer.chunk(
    n_fused * self.world_size,
    dim=dim,
)
if n_fused > 1:
    ranks = (len(proj_layer) + self.world_size - 1) // self.world_size
    proj_layer = [
        proj_layer[i * self.world_size : (i + 1) * self.world_size]
        for i in range(ranks)
    ]
    proj_layer = list(
        map(lambda x: torch.cat([*x], dim=-1), zip(*proj_layer))
    )

For an attn_qkv weight, the arg dim is 0. So when reversed=False and n_fused>1, the tensor is chunked along dim 0 but then concatenated along dim -1, which makes its shape incorrect.
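A minimal standalone reproduction of the mismatch, with hypothetical sizes (hidden=4, world_size=2, n_fused=3 are made up for illustration; a fused QKV weight of shape (3 * hidden, hidden) sliced across two tensor-parallel ranks):

```python
import torch

hidden, world_size, n_fused, dim = 4, 2, 3, 0
weight = torch.randn(3 * hidden, hidden)  # (12, 4)

chunks = weight.chunk(n_fused * world_size, dim=dim)  # 6 pieces of (2, 4)
ranks = (len(chunks) + world_size - 1) // world_size  # 3 groups (q, k, v)
groups = [chunks[i * world_size : (i + 1) * world_size] for i in range(ranks)]

# Buggy: chunked along dim=0 but concatenated along dim=-1.
buggy = [torch.cat(x, dim=-1) for x in zip(*groups)]
print(buggy[0].shape)  # torch.Size([2, 12]) -- wrong

# Expected: concatenate back along the same dim we chunked on,
# so each rank holds a (3 * hidden / world_size, hidden) slice.
fixed = [torch.cat(x, dim=dim) for x in zip(*groups)]
print(fixed[0].shape)  # torch.Size([6, 4])
```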

@JiayiFeng JiayiFeng added the bug Something isn't working label Oct 18, 2022
@hyunwoongko
Contributor

which model did you use?

@JiayiFeng
Author

I used a modified GPT-NeoX model, which is not officially supported, so I wrote a custom policy and found this issue.

@JiayiFeng JiayiFeng changed the title A bug in with n_fused A bug with n_fused Oct 18, 2022
@JiayiFeng
Author

Maybe the

proj_layer = list(
    map(lambda x: torch.cat([*x], dim=-1), zip(*proj_layer))
)

should be:

proj_layer = list(
    map(lambda x: torch.cat([*x], dim=dim), zip(*proj_layer))
)

I guess.
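A quick sanity check that the proposed `dim=dim` fix would not change behavior in the case that already worked: when reversed=True, dim becomes 1 for a 2D weight, and for a 2D tensor dim=1 is the same axis as dim=-1, so the old and new code agree there (sizes below are hypothetical, for illustration only):

```python
import torch

world_size, n_fused, dim = 2, 3, 1
weight = torch.randn(4, 3 * 4)  # transposed layout, chunked on the last dim

chunks = weight.chunk(n_fused * world_size, dim=dim)
ranks = (len(chunks) + world_size - 1) // world_size
groups = [chunks[i * world_size : (i + 1) * world_size] for i in range(ranks)]

old = [torch.cat(x, dim=-1) for x in zip(*groups)]   # current code
new = [torch.cat(x, dim=dim) for x in zip(*groups)]  # proposed fix
print(all(torch.equal(o, n) for o, n in zip(old, new)))  # True
```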

@hyunwoongko
Contributor

hyunwoongko commented Oct 18, 2022

Okay, could you test it with other models?
