run sdpa with dtensor #180
Conversation
[ghstack-poisoned]
ghstack-source-id: 33d3d0b6a19c747269aab1a95589bb61bf9c1f51 Pull Request resolved: #180
This PR gets rid of the manual adjustment of the number of heads in attention layers by using DTensor outputs of `wq`, `wk`, `wv`, so that SDPA is aware of the distributedness. [ghstack-poisoned]
ghstack-source-id: 43941c1ca0dfc7a04589a7513a110b877c217917 Pull Request resolved: #180
"attention.wq": col_parallel_strategy(), | ||
"attention.wk": col_parallel_strategy(), | ||
"attention.wv": col_parallel_strategy(), | ||
"attention.wq": col_parallel_strategy(use_local_output=False), |
🤔 I thought we needed to replicate the freq_cis, but here it seems we don't need to?
Just curious, is this going to land soon, or does it have some risk or unfinished business? Also, it looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems to have changed (attention.wo and attention_norm).
It hasn't been landed because there is a very strange bug (#267) associated with (but seemingly not caused by) multiplication using DTensor. It would be triggered in the rotary embedding computation if this PR lands. I will work on the bug soon, since it will also benefit PP (IIUC). @wconstab
Oh, is this related to dispatching for complex numbers, by any chance?
@wconstab Possibly, we don't know.
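For reference, a hedged sketch of the llama-style rotary embedding step being discussed; names and shapes follow the common reference implementation rather than this repo, and `freqs_cis` is assumed to be pre-reshaped for broadcasting. The complex multiplication `xq_ * freqs_cis` is the spot where a DTensor input would exercise the complex-number dispatch path mentioned above:

```python
import torch

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # View the last dim as complex pairs: (..., head_dim) -> (..., head_dim // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # The complex multiplication applies the rotation; a DTensor input here would
    # hit the complex-number dispatch path discussed in this thread.
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# Toy usage with made-up shapes; freqs_cis is reshaped to broadcast against
# (bs, seqlen, n_heads, head_dim // 2).
bs, seqlen, n_heads, head_dim = 2, 16, 8, 64
xq = torch.randn(bs, seqlen, n_heads, head_dim)
xk = torch.randn(bs, seqlen, n_heads, head_dim)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs).reshape(1, seqlen, 1, head_dim // 2)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis)
```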
ghstack-source-id: 58ba72163a4b03d77f4b2ba7c97cef7e7e8b3096 Pull Request resolved: #180
ghstack-source-id: a18a3cb1ba48fb751f437a5ee44f186ff9a26e9a Pull Request resolved: #180
ghstack-source-id: b8b2b58ffc72fcb8bfc88f4ba2a3455e3cc92c0a Pull Request resolved: #180
ghstack-source-id: 55bb9e1ba289c212f4af58e19d9bede2ad0246a8 Pull Request resolved: #180
Force-pushed from 9d45a6c to e773b75
Force-pushed from fe1f241 to a28e74e
Stack from ghstack (oldest at bottom):
This PR gets rid of the manual adjustment of the number of heads in attention layers by using DTensor outputs of `wq`, `wk`, `wv`, so that SDPA is aware of the distributedness.
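As a rough illustration of the intended effect, not code from this PR: `wq`/`wk`/`wv` produce DTensors sharded on the head dimension, SDPA is called directly on them, and each rank attends over its local heads with no manual `n_heads // tp_degree` bookkeeping. The import path, mesh size, dtype, and shapes below are assumptions (the public `torch.distributed.tensor` module path is from newer PyTorch releases), and the sketch needs a distributed launch such as `torchrun --nproc_per_node=8`:

```python
import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (8,))            # assumed 8-way TP mesh
bs, n_heads, seqlen, head_dim = 2, 32, 256, 128  # made-up shapes

def shard_heads(t: torch.Tensor):
    # Shard(1): split along the head dimension across the mesh.
    return distribute_tensor(t, mesh, [Shard(1)])

q = shard_heads(torch.randn(bs, n_heads, seqlen, head_dim, dtype=torch.bfloat16, device="cuda"))
k = shard_heads(torch.randn(bs, n_heads, seqlen, head_dim, dtype=torch.bfloat16, device="cuda"))
v = shard_heads(torch.randn(bs, n_heads, seqlen, head_dim, dtype=torch.bfloat16, device="cuda"))

# DTensor's sharding propagation handles the head-sharded SDPA; the output
# comes back as a DTensor that is also sharded on the head dimension.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.placements)
```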