
Fix compile issue for Marlin QQQ on sm<8.0 #1651

Open · wants to merge 2 commits into main
Conversation

gau-nernst
Collaborator

@gau-nernst gau-nernst commented Feb 2, 2025

Closes #1648

Tested on Google Colab T4 (sm75): https://colab.research.google.com/drive/107-fKXymnK-QNCvnfXggTK8SSqy9oz8o?usp=sharing

The main issue is that TORCH_CHECK_NOT_IMPLEMENTED() should not be called inside a __global__ function. However, I also believe the aggressive use of the __CUDA_ARCH__ guard can lead to undefined behavior, as outlined in the CUDA C++ Programming Guide (https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch). Note that __CUDA_ARCH__ is not defined in host code.

Following the practice from @alexsamardzic, I simply leave the kernel body empty when __CUDA_ARCH__ < 800 and add a runtime check in the torch function.
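The pattern described above can be sketched roughly as follows (a minimal sketch with simplified, hypothetical signatures, not the actual torchao kernel): the __CUDA_ARCH__ guard now only empties the kernel body, so the file compiles for any architecture, while the host-side wrapper rejects pre-sm80 devices at runtime.

```cuda
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

// Device code: the guard wraps the kernel body, not a TORCH_CHECK.
__global__ void marlin_qqq_kernel(/* simplified: the real kernel takes many args */) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  // actual w4a8 GEMM implementation goes here
#endif
  // On sm < 8.0 this compiles to an empty kernel; the host-side check
  // below guarantees it is never launched on such devices.
}

// Host code: __CUDA_ARCH__ is undefined here, so the capability
// check must happen at runtime instead of at compile time.
void marlin_qqq_check_arch() {
  auto* props = at::cuda::getCurrentDeviceProperties();
  TORCH_CHECK(props->major >= 8,
              "marlin_qqq requires compute capability >= 8.0, got sm",
              props->major, props->minor);
}
```

The key point is that the runtime check lives in ordinary host code, where at::cuda::getCurrentDeviceProperties() is available, while the device code stays valid on every architecture it is compiled for.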

  • Check if there is a perf hit due to the at::cuda::getCurrentDeviceProperties() call
    • No perf hit

Ran `python benchmarks/benchmark_marlin_qqq.py` on a 4070 Ti SUPER:

main

```
m k n group_size fp16_latency (ms) marlin_qqq_w4a8_latency (ms) speedup (d/s)
1 8192 8192 -1 219.744 15.9633 13.7656
1 8192 10240 -1 307.475 19.7781 15.5462
1 8192 57344 -1 1476.61 380.448 3.88124
1 28672 8192 -1 743.548 187.865 3.95788
2 8192 8192 -1 221.491 16.0433 13.8059
2 8192 10240 -1 271.668 19.7823 13.7329
2 8192 57344 -1 1477.62 382.154 3.86657
2 28672 8192 -1 745.256 188.454 3.95457
4 8192 8192 -1 222.612 16.6062 13.4053
4 8192 10240 -1 271.816 20.5471 13.2289
4 8192 57344 -1 1477.93 384.702 3.84176
4 28672 8192 -1 745.605 189.578 3.93298
8 8192 8192 -1 222.179 17.225 12.8987
8 8192 10240 -1 272.088 21.2379 12.8114
8 8192 57344 -1 1478.33 388.459 3.80564
8 28672 8192 -1 746.224 190.582 3.91549
16 8192 8192 -1 217.72 18.9155 11.5101
16 8192 10240 -1 270.616 22.9631 11.7848
16 8192 57344 -1 1479.04 393.542 3.75828
16 28672 8192 -1 750.581 192.803 3.89299
32 8192 8192 -1 223.47 26.1391 8.54924
32 8192 10240 -1 283.867 29.92 9.48754
32 8192 57344 -1 1606.37 405.499 3.96146
32 28672 8192 -1 752.581 198.579 3.78982
64 8192 8192 -1 234.093 44.7372 5.23263
64 8192 10240 -1 281.155 51.272 5.4836
64 8192 57344 -1 1719.75 437.085 3.93459
64 28672 8192 -1 769.055 219.65 3.50127
128 8192 8192 -1 309.433 72.3182 4.27877
128 8192 10240 -1 445.81 104.66 4.2596
128 8192 57344 -1 2041.16 729.174 2.79927
128 28672 8192 -1 961.523 357.706 2.68802
256 8192 8192 -1 464.695 141.607 3.28158
256 8192 10240 -1 634.023 204.13 3.10598
256 8192 57344 -1 3484.63 1422.08 2.45038
256 28672 8192 -1 1599.34 716.818 2.23117
512 8192 8192 -1 865.744 296.573 2.91916
512 8192 10240 -1 1096.16 432.602 2.53387
512 8192 57344 -1 6242.79 2926.94 2.13288
512 28672 8192 -1 3057.16 1441.21 2.12124
1 8192 8192 128 220.04 21.4464 10.26
1 8192 10240 128 309.417 27.4338 11.2787
1 8192 57344 128 1476.37 393.084 3.75588
1 28672 8192 128 743.653 194.238 3.82857
2 8192 8192 128 221.026 21.507 10.2769
2 8192 10240 128 272.281 27.6903 9.83308
2 8192 57344 128 1477.77 395.257 3.73877
2 28672 8192 128 745.279 194.722 3.8274
4 8192 8192 128 222.304 22.1281 10.0463
4 8192 10240 128 271.928 27.4937 9.89056
4 8192 57344 128 1478.09 396.968 3.72345
4 28672 8192 128 745.738 195.608 3.81241
8 8192 8192 128 222.97 22.5491 9.88819
8 8192 10240 128 272.134 29.1232 9.34424
8 8192 57344 128 1478.1 401.382 3.68252
8 28672 8192 128 746.104 197.306 3.78145
16 8192 8192 128 217.876 24.6423 8.84155
16 8192 10240 128 270.595 30.1804 8.96591
16 8192 57344 128 1479.11 406.62 3.63758
16 28672 8192 128 749.944 198.833 3.77172
32 8192 8192 128 223.591 29.636 7.54456
32 8192 10240 128 283.874 36.8304 7.70759
32 8192 57344 128 1605.43 416.498 3.85461
32 28672 8192 128 752.817 205.591 3.66173
64 8192 8192 128 237.745 46.2037 5.14558
64 8192 10240 128 284.857 64.3198 4.42876
64 8192 57344 128 1750.44 456.794 3.83202
64 28672 8192 128 769.348 229.535 3.35177
128 8192 8192 128 310.719 83.257 3.73205
128 8192 10240 128 442.015 127.657 3.46252
128 8192 57344 128 2040.24 788.657 2.58698
128 28672 8192 128 977.859 395.06 2.47522
256 8192 8192 128 471.349 175.002 2.69338
256 8192 10240 128 618.293 244.242 2.53148
256 8192 57344 128 3482.03 1625.35 2.14232
256 28672 8192 128 1610.62 802.132 2.00792
512 8192 8192 128 864.987 358.192 2.41487
512 8192 10240 128 1089.34 485.58 2.24339
512 8192 57344 128 6349.8 3174.48 2.00026
512 28672 8192 128 3097.25 1606.2 1.92831
```
this PR

```
m k n group_size fp16_latency (ms) marlin_qqq_w4a8_latency (ms) speedup (d/s)
1 8192 8192 -1 219.552 15.8444 13.8567
1 8192 10240 -1 307.078 19.4016 15.8275
1 8192 57344 -1 1475.47 380.388 3.87885
1 28672 8192 -1 743.336 187.826 3.95759
2 8192 8192 -1 221.373 15.9487 13.8803
2 8192 10240 -1 271.578 19.5928 13.8611
2 8192 57344 -1 1477.31 382.099 3.86629
2 28672 8192 -1 745.019 188.368 3.95512
4 8192 8192 -1 222.364 16.4044 13.5552
4 8192 10240 -1 271.628 20.2398 13.4205
4 8192 57344 -1 1477.62 384.606 3.84191
4 28672 8192 -1 745.275 189.516 3.93251
8 8192 8192 -1 221.886 17.1354 12.949
8 8192 10240 -1 271.832 21.1215 12.8699
8 8192 57344 -1 1477.95 388.432 3.80492
8 28672 8192 -1 746.099 190.516 3.91621
16 8192 8192 -1 217.785 18.7514 11.6143
16 8192 10240 -1 270.379 22.7819 11.8682
16 8192 57344 -1 1478.57 393.448 3.75797
16 28672 8192 -1 750.212 192.752 3.89211
32 8192 8192 -1 223.232 26.0726 8.56193
32 8192 10240 -1 283.416 29.5884 9.57862
32 8192 57344 -1 1609.67 405.298 3.97157
32 28672 8192 -1 752.512 198.473 3.79151
64 8192 8192 -1 238.469 44.4038 5.37047
64 8192 10240 -1 279.958 50.8142 5.50945
64 8192 57344 -1 1674.71 433.643 3.86196
64 28672 8192 -1 769.556 216.045 3.56202
128 8192 8192 -1 305.164 70.8571 4.30675
128 8192 10240 -1 433.473 102.401 4.23309
128 8192 57344 -1 2003.92 696.747 2.87612
128 28672 8192 -1 922.662 348.429 2.64806
256 8192 8192 -1 475.092 137.335 3.45936
256 8192 10240 -1 583.481 209.415 2.78625
256 8192 57344 -1 3421.62 1417.89 2.41318
256 28672 8192 -1 1561.5 696.361 2.24238
512 8192 8192 -1 862.27 285.891 3.01608
512 8192 10240 -1 1092.17 424.349 2.57375
512 8192 57344 -1 6110.66 2867.04 2.13135
512 28672 8192 -1 3008.58 1406.22 2.13948
1 8192 8192 128 219.967 21.0414 10.454
1 8192 10240 128 308.991 26.8087 11.5258
1 8192 57344 128 1476.37 393.098 3.75573
1 28672 8192 128 743.623 194.217 3.82882
2 8192 8192 128 220.814 21.2613 10.3857
2 8192 10240 128 271.731 27.2328 9.97809
2 8192 57344 128 1478.05 395.233 3.73969
2 28672 8192 128 745.453 194.702 3.82868
4 8192 8192 128 222.336 21.6202 10.2837
4 8192 10240 128 271.877 28.0634 9.68795
4 8192 57344 128 1478.29 396.969 3.72394
4 28672 8192 128 745.647 195.594 3.81222
8 8192 8192 128 222.622 22.7187 9.79907
8 8192 10240 128 272.076 28.4982 9.54714
8 8192 57344 128 1478.16 401.358 3.68291
8 28672 8192 128 746.178 197.292 3.7821
16 8192 8192 128 217.873 23.7224 9.18424
16 8192 10240 128 270.628 29.6718 9.12071
16 8192 57344 128 1479.02 406.611 3.63743
16 28672 8192 128 751.735 198.914 3.7792
32 8192 8192 128 223.5 29.4465 7.59006
32 8192 10240 128 283.74 37.6204 7.54217
32 8192 57344 128 1610.55 416.324 3.86849
32 28672 8192 128 752.454 205.499 3.66159
64 8192 8192 128 233.82 46.1815 5.06306
64 8192 10240 128 279.882 66.1781 4.22922
64 8192 57344 128 1684.68 449.205 3.75036
64 28672 8192 128 769.308 227.58 3.38038
128 8192 8192 128 307.883 82.9074 3.71357
128 8192 10240 128 437.021 125.096 3.4935
128 8192 57344 128 2027.19 776.887 2.60938
128 28672 8192 128 965.605 388.431 2.48591
256 8192 8192 128 464.812 171.261 2.71405
256 8192 10240 128 618.714 238.072 2.59886
256 8192 57344 128 3441.47 1586.64 2.16903
256 28672 8192 128 1594.15 780.242 2.04315
512 8192 8192 128 866.84 355.386 2.43915
512 8192 10240 128 1081.31 481.026 2.24792
512 8192 57344 128 6328.79 3164.21 2.00011
512 28672 8192 128 3119.8 1579.16 1.97561
```


pytorch-bot bot commented Feb 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1651

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit bb389e5 with merge base 6ffe236:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 2, 2025
@gau-nernst gau-nernst added the topic: bug fix Use this tag for PRs that fix bugs label Feb 2, 2025
@gau-nernst
Collaborator Author

cc @psinger Can you also check if this PR fixes your issue? Thank you.

Successfully merging this pull request may close these issues.

CUDA compile guard problem for marlin_qqq