
RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph. #9704

Open
alansmithee-cpu opened this issue Oct 17, 2024 · 53 comments
Labels
bug Something isn't working

Comments

@alansmithee-cpu

alansmithee-cpu commented Oct 17, 2024

Describe the bug

Hello. I tried the Img2Img pipeline and encountered the error shown in the screenshots below. Could you please check it for me? Thank you.
Screenshot 2024-10-17 at 11 39 30
Screenshot 2024-10-17 at 11 39 46

Reproduction

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import make_image_grid, load_image

pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5/", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipeline.enable_model_cpu_offload()


url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/img2img-init.png"
init_image = load_image(url)

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"


image = pipeline(prompt, image=init_image).images[0]
make_image_grid([init_image, image], rows=1, cols=2)

Logs

No response

System Info

diffusers 0.30.3
Python 3.9.20

Who can help?

No response

@alansmithee-cpu alansmithee-cpu added the bug Something isn't working label Oct 17, 2024
@a-r-r-o-w
Member

What version of PyTorch are you using? It seems like this error comes from the latest changes in PyTorch: this, this, and this.

@alansmithee-cpu
Author

What version of PyTorch are you using? It seems like this error comes from the latest changes in PyTorch: this, this, and this.

I'm using torch 2.5.0+cu124

@a-r-r-o-w
Member

Could you try the 2.4.0 stable release and see if the problem persists?

@alansmithee-cpu
Author

Could you try the 2.4.0 stable release and see if the problem persists?

Now I encountered this error
Screenshot 2024-10-17 at 12 16 22

@a-r-r-o-w
Member

If you're running in a notebook, make sure to restart it and please do a clean reinstall of v0.30.3. Auraflow was released in v0.30.0, so this should not lead to any errors. Just to be sure that there are no longer any environment errors, could you paste the output of diffusers-cli env?

@alansmithee-cpu
Author

If you're running in a notebook, make sure to restart it and please do a clean reinstall of v0.30.3. Auraflow was released in v0.30.0, so this should not lead to any errors. Just to be sure that there are no longer any environment errors, could you paste the output of diffusers-cli env?

Yes, I've reinstalled v0.30.0 (image 1), but I still get the error shown in image 2.

Screenshot 2024-10-17 at 12 25 46 Screenshot 2024-10-17 at 12 26 28

@alansmithee-cpu
Author

Hello, the problem is now solved, thank you for your time and consideration.

Here are the versions that worked for me:
Diffusers: v0.30.3
Torch: 2.4.0+cu121

@readleyj

I am facing the same error on torch 2.5.0+cu124. The error is preceded by the following warning:

cuDNN SDPA backward got grad_output.strides() != output.strides()

I'm on an H100; I'm guessing this has to do with the new cuDNN SDPA backend introduced in PyTorch 2.5.

@a-r-r-o-w
Member

Yes, this seems like a problem with torch 2.5.0, and I've been able to reproduce it now as well. We'll need to look into how best to fix this (either on our end, or we could talk with the PyTorch folks). cc @sayakpaul @DN6 @yiyixuxu. Re-opening the issue for now.

@drisspg

drisspg commented Oct 18, 2024

As a workaround, you can disable the cuDNN backend via https://pytorch.org/docs/stable/backends.html#torch.backends.cuda.enable_cudnn_sdp

Would you mind opening an issue on PyTorch with a smallish repro? I can then forward it to the NVIDIA folks.
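
For reference, a minimal sketch of that workaround (assuming torch 2.5.0, where this toggle is available):

```python
import torch

# Globally disable the cuDNN SDPA backend; scaled_dot_product_attention
# will then fall back to the flash / memory-efficient / math kernels.
torch.backends.cuda.enable_cudnn_sdp(False)
```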

@vladmandic
Contributor

vladmandic commented Oct 18, 2024

torch==2.5.0 breaks the SDPA functionality used by transformers, which diffusers uses for CLIP during prompt encoding:

transformers/models/clip/modeling_clip.py:491 in forward

490 │   │   # CLIP text model uses both `causal_attention_mask` and `attention_mask` sequentially.
491 │   │   attn_output = torch.nn.functional.scaled_dot_product_attention( ... )

Yes, torch.backends.cuda.enable_cudnn_sdp(False) is a workaround, but comes at a massive performance cost.

IMO, this should be reported to the transformers team, as they can implement a workaround much faster than torch can release a service pack, which takes a while.
(From what I gather, the issue was first caught back in May in the cudnn-frontend package and is still not assigned.)

@drisspg

drisspg commented Oct 18, 2024

but comes at a massive performance cost.

The performance should be the same as in 2.4.1, since this is the first release with the cuDNN backend enabled.

Can you link the frontend issue?

@readleyj

readleyj commented Oct 18, 2024

Seems to be NVIDIA/cudnn-frontend#75 and NVIDIA/cudnn-frontend#78

@JackismyShephard

JackismyShephard commented Oct 18, 2024

Having this issue as well, but only on Linux. No problems with CUDA on Windows.

@vladmandic
Contributor

but comes at a massive performance cost.

The performance should be the same as in 2.4.1, since this is the first release with the cuDNN backend enabled.

Can you link the frontend issue?

The performance degradation is from 6 it/s to 2.5 it/s using SDXL, with everything else the same except that one param.

Links to the issues are already posted above.

@eqy

eqy commented Oct 18, 2024

The cuDNN issues linked are generic across any unsupported config and may not correspond to this particular issue. Would it be possible to link a shorter repro? I'm currently trying to clone stable-diffusion-v1-5/stable-diffusion-v1-5/, which seems to be > 10 GiB.

@vladmandic
Contributor

vladmandic commented Oct 19, 2024

Here's the shortest reproduction. Like I said, it's when transformers uses SDPA to process CLIP:

import torch
from transformers import CLIPTextModel, AutoTokenizer

device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir='/mnt/models/huggingface').to(device=device, dtype=torch.float16)

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
print(inputs)
outputs = encoder(**inputs)
print(outputs)
  File "/home/vlado/dev/clip/venv/lib/python3.12/site-packages/transformers/models/clip/modeling_clip.py", line 491, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph.

BTW, I just noticed that there is no issue when using torch.float32, but nobody uses torch.float32 anymore.
And yes, this is the same issue as noted here with diffusion models - it happens when encoding the prompt.
You can try any other CLIP model as long as the underlying processor is the same.

@eqy

eqy commented Oct 19, 2024

Thanks. It not happening with float32 is expected, as PyTorch will not dispatch to cuDNN for float32.

@eqy

eqy commented Oct 19, 2024

@vladmandic I am not seeing the same error locally with cuDNN 9.3. Which GPU are you on? I will try 9.1.7 in the meantime

@vladmandic
Contributor

vladmandic commented Oct 19, 2024

@vladmandic I am not seeing the same error locally with cuDNN 9.3. Which GPU are you on? I will try 9.1.7 in the meantime

print(f'torch={torch.__version__} cuda={torch.version.cuda} cuDNN={torch.backends.cudnn.version()} device={torch.cuda.get_device_name(0)} cap={torch.cuda.get_device_capability(0)}')

torch=2.5.0+cu124 cuda=12.4 cuDNN=90100 device=NVIDIA GeForce RTX 4090 cap=(8, 9)

Note that the CUDA and cuDNN versions are the ones that come with torch; if torch 2.5 requires a newer cuDNN, it should handle its installation.
This is a simple pip install torch transformers in a clean venv without any extra flags.

@eqy

eqy commented Oct 19, 2024

Yes, 9.1.0.70 is the cuDNN version that comes with torch, and I didn't see the failure on L40, L4, or RTX 6000 Ada, which are also sm89 (it is able to generate and run a kernel).

I'm thinking that maybe the issue is the CUDA version; I will also try that later.

@soumendukrg

Even a clean environment didn't help me. I had to install torch==2.4.0 to get rid of the issue.

@sayakpaul
Member

Hmm. How much speedup does one get when using CLIP in SDPA? I remember when we incorporated SDPA in CLIP the speedup wasn't that significant.

We could verify this by instantiating the CLIP with:

text_encoder = CLIPTextModel.from_pretrained(..., attn_implementation="eager", ...)
pipeline = DiffusionPipeline.from_pretrained(..., text_encoder=text_encoder)
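
A rough, runnable version of the same idea (the model ID and subfolder layout below follow the repro above and are assumptions, not a prescribed setup):

```python
import torch
from diffusers import DiffusionPipeline
from transformers import CLIPTextModel

# Load the text encoder with the eager attention implementation so that
# SDPA (and hence the cuDNN backend) is never used for the CLIP encoder.
text_encoder = CLIPTextModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="text_encoder",
    torch_dtype=torch.float16,
    attn_implementation="eager",
)

# Pass the pre-built text encoder into the pipeline.
pipeline = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    text_encoder=text_encoder,
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipeline.enable_model_cpu_offload()
```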

Cc: @ArthurZucker

@vladmandic
Contributor

Hmm. How much speedup does one get when using CLIP in SDPA? I remember when we incorporated SDPA in CLIP the speedup wasn't that significant.

I tried torch==2.4.1 with the default SDPA and with eager, and for 1,000 iterations I'm getting 4.72s vs 7.59s, so a pretty significant impact at ~60% slower. The good thing is that encoding only happens once, so the overall performance hit would hardly be seen.
But in how many places would this need to be touched?
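
(Roughly, the comparison was along these lines - a sketch; the exact model, prompt, and warmup details here are assumptions rather than the precise setup used:)

```python
import time

import torch
from transformers import CLIPTextModel, AutoTokenizer

device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
inputs = tokenizer(["a photo of a cat"], padding=True, return_tensors="pt").to(device)

for impl in ("sdpa", "eager"):
    encoder = CLIPTextModel.from_pretrained(
        "openai/clip-vit-base-patch32", attn_implementation=impl
    ).to(device=device, dtype=torch.float16)
    with torch.no_grad():
        encoder(**inputs)  # warmup
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(1000):
            encoder(**inputs)
        torch.cuda.synchronize()
        print(f"{impl}: {time.perf_counter() - start:.2f}s for 1,000 iterations")
```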

@sayakpaul
Member

sayakpaul commented Oct 19, 2024

But in how many places would this need to be touched?

You mean changing CLIP (and potentially other transformers models we rely on in diffusers) to use "eager" as the attn_implementation?

I guess we have a couple of ways but I think we could pass this info to load_method here:

loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)

Something like (pseudo-code):

if is_transformers_model:
    if is_transformers_version(...):
        if is_torch_version(">=", "2.5"):
            loading_kwargs.update({"attn_implementation": "eager"})
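
As a rough, runnable sketch of the pseudo-code above (the helper import path and function name below are assumptions about diffusers' internal utilities):

```python
from diffusers.utils import is_torch_version  # assumed helper location


def maybe_force_eager_attention(loading_kwargs: dict, is_transformers_model: bool) -> dict:
    # Hypothetical helper: ask transformers models (e.g. the CLIP text encoder)
    # to use the eager attention implementation on torch >= 2.5, sidestepping
    # the cuDNN SDPA path that currently errors out.
    if is_transformers_model and is_torch_version(">=", "2.5"):
        loading_kwargs.update({"attn_implementation": "eager"})
    return loading_kwargs
```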

@DN6 WDYT? Or maybe @ArthurZucker from transformers has a better idea.

@eqy

eqy commented Oct 21, 2024

@vladmandic does your output look similar to this?
I was able to run the 2.5.0 binary on an RTX 6000 (Ada):

import torch
from transformers import CLIPTextModel, AutoTokenizer

print(f"cuda: {torch.version.cuda} cudnn: {torch.backends.cudnn.version()} compute capability: {torch.cuda.get_device_capability()}")

device = torch.device('cuda')
tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", cache_dir='/mnt/models/huggingface').to(device=device, dtype=torch.float16)

inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
print(inputs)
outputs = encoder(**inputs)
print(outputs)
cuda: 12.4 cudnn: 90100 compute capability: (8, 9)
{'input_ids': tensor([[49406,   320,  1125,   539,   320,  2368, 49407],
        [49406,   320,  1125,   539,   320,  1929, 49407]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}
BaseModelOutputWithPooling(last_hidden_state=tensor([[[ 0.3391,  0.1165,  0.1020,  ...,  0.2469,  0.5903,  0.1014],
         [ 1.9775, -0.5840,  0.3699,  ...,  1.1670,  0.8047, -0.9795],
         [ 1.0586, -0.9580,  1.0039,  ..., -0.5151, -0.1436, -1.9443],
         ...,
         [ 0.3076, -1.4961, -0.4001,  ..., -0.0224,  0.9111, -0.3879],
         [ 1.0117, -0.6704,  1.7734,  ..., -0.1541, -0.0244, -1.5059],
         [-0.5151,  0.1665,  0.8887,  ..., -0.0677, -0.4563, -1.7959]],

        [[ 0.3391,  0.1165,  0.1020,  ...,  0.2469,  0.5903,  0.1014],
         [ 1.9775, -0.5840,  0.3699,  ...,  1.1670,  0.8047, -0.9795],
         [ 1.0586, -0.9580,  1.0039,  ..., -0.5151, -0.1436, -1.9443],
         ...,
         [ 0.3076, -1.4961, -0.4001,  ..., -0.0224,  0.9111, -0.3879],
         [-0.1440, -0.5166,  1.7109,  ..., -0.0795,  0.3611, -1.2441],
         [ 0.0415,  0.0185,  1.2754,  ..., -0.4209, -0.4387, -1.3018]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>), pooler_output=tensor([[-0.5151,  0.1665,  0.8887,  ..., -0.0677, -0.4563, -1.7959],
        [ 0.0415,  0.0185,  1.2754,  ..., -0.4209, -0.4387, -1.3018]],
       device='cuda:0', dtype=torch.float16, grad_fn=<IndexBackward0>), hidden_states=None, attentions=None)

@vladmandic
Contributor

That would be weird - the SM count is really close between those two; after all, it's just a variation of the same AD102 die.

drisspg added a commit to pytorch/pytorch that referenced this issue Oct 22, 2024
# Summary
Currently we have a `cudnn_order` that says on H100 w/ new enough CuDNN backend (we ship a 9.1 version in OSS) try to run CuDNN attention first. We have already encountered a few bugs with the release of 2.5:

1. #138529
2. huggingface/diffusers#9704
3. #138354

In light of the above we are going to make the CuDNN backend Opt-in by default.

This can be done easily with the context manager for choosing backends I.e.:
``` Python
from torch.nn.attention import sdpa_kernel, SDPBackend    

with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)

```

This PR puts the CuDNN backend as the lowest precedence in the backend list, meaning that the Math backend will always be chosen unless disabled (which is done via the context manager). 


Cc atalman

cc mikaylagawarecki

@sayakpaul
Member

sayakpaul commented Oct 22, 2024

Related: pytorch/pytorch#138522.

Also concur with @vladmandic on #9704 (comment). From all the issues, replies, and PRs, the PyTorch devs have made it clear that they care quite a bit.

@vladmandic
Contributor

@eqy I see that the torch PR that changes the order of SDPA backends was just merged - do you have a tentative target for the 2.5.1 release?

pytorchbot pushed a commit to pytorch/pytorch that referenced this issue Oct 22, 2024
Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)
@drisspg

drisspg commented Oct 22, 2024

@vladmandic It just landed yesterday; we are planning to do a cherry-pick release sooner than we normally would. No date is set in stone, but it is imminent.

Would you be up for testing out the "nightly" build when the changes get merged into that release (should be tomorrow)?

@vladmandic
Contributor

Would you be up for testing out the "nightly" build when the changes get merged into that release (should be tomorrow)?

sure

kit1980 pushed a commit to pytorch/pytorch that referenced this issue Oct 22, 2024
[SDPA-CUDNN] Make CuDNN Attention Opt in (#138522)

Pull Request resolved: #138522
Approved by: https://github.com/ngimel, https://github.com/eqy, https://github.com/malfet

(cherry picked from commit 9a9a0ab)

Co-authored-by: drisspg <[email protected]>
digitalsp added a commit to digitalsp/ai-rakugaki-app that referenced this issue Oct 23, 2024
With the combination of diffusers v0.30.3 and PyTorch 2.4.3, a cuDNN error occurred when precision was set to fp16 and images could not be generated.
Related issue: huggingface/diffusers#9704
Because of this, the diffusers version was kept as-is and the PyTorch version was pinned to 2.4.0.

update: edited `pyproject.toml` to configure where torch and torchvision are fetched from, etc.
update: bumped the Python version to the 3.11 series, somewhat arbitrarily
@drisspg

drisspg commented Oct 24, 2024

@vladmandic we have an RC available, would you mind trying w/ this version of PyTorch:

pip3 install torch==2.5.1 --index-url https://download.pytorch.org/whl/test/cu124

@vladmandic
Contributor

vladmandic commented Oct 24, 2024

Tried both the 2.5.1 RC and the 2.6.0 nightly, and both look fine - thanks!

pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/test/cu124
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu124

But... this is basically moving the cuDNN backend from highest priority to lowest priority so that it behaves the same as previous versions of torch - there is still an underlying issue with cuDNN...

Yup - confirmed with:

torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

This makes SDPA pick the cuDNN backend, and the issue is back.
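
(For completeness, the same thing can be confirmed in a scoped way with the context manager from the PR description - a small sketch with made-up tensor shapes:)

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Dummy query/key/value tensors; the shapes here are arbitrary.
q = k = v = torch.randn(1, 8, 77, 64, device="cuda", dtype=torch.float16)

# Force the cuDNN backend only inside this block; outside it, the default
# (post-2.5.1) backend priority applies and cuDNN is tried last.
with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)
```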

@drisspg

drisspg commented Oct 24, 2024

@vladmandic That is expected. The goal of this is to fix the default behavior. We are currently making CuDNN always the lowest priority so that math will be picked before cuDNN. You can of course manually call CuDNN and you will hit the existing error. @eqy is working on the actual fix to the CuDNN backend so that we can in the future increase its priority.

@vladmandic
Contributor

vladmandic commented Oct 24, 2024

@vladmandic That is expected. The goal of this is to fix the default behavior. We are currently making CuDNN always the lowest priority so that math will be picked before cuDNN. You can of course manually call CuDNN and you will hit the existing error. @eqy is working on the actual fix to the CuDNN backend so that we can in the future increase its priority.

Yup, makes total sense; just wanted to confirm - thanks.
From my perspective, this is OK for 2.5.1 while you work on the cuDNN improvements.

@eqy

eqy commented Oct 24, 2024

@vladmandic in the meantime we tried to repro the issue on 4090 but were unable to see it on our end (both cuDNN team and my own local testing). Could you share some more details about your environment? Is it e.g., Windows?

In the meantime, I'm working on cuDNN robustness on sm8x and have found similar issues, but it would be good if we could guarantee that your specific use case is covered.

@vladmandic
Contributor

@vladmandic in the meantime we tried to repro the issue on 4090 but were unable to see it on our end (both cuDNN team and my own local testing). Could you share some more details about your environment? Is it e.g., Windows?

ahhh, that made me wonder...

running on host:

  • Windows 11 23H2
  • NVIDIA drivers 566 (just updated from 560, no change)

and...no issues!

but running in a VM

  • WSL2 2.3.24
  • Ubuntu 24.04

RuntimeError: cuDNN Frontend error: [cudnn_frontend] Error: No execution plans support the graph.

So it's something about virtualization.
I'm assuming you're using NVML to get GPU info, so I've compared nvidia-smi and CUDA\extras\demo_suite\deviceQuery on both Windows and Linux, and they look the same.

@felixniemeyer

ran into the same issue,

pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/test/cu124 

as suggested fixed it for me for now.
Love the vibe of your conversation.
Thanks!

@vladmandic
Contributor

vladmandic commented Oct 24, 2024

thanks @felixniemeyer :)
btw, can you document your env briefly - which os/gpu, any virtualization, etc...

@sayakpaul
Member

@vladmandic thanks for leading the charge here! Also, thanks to @eqy @drisspg for the help!

@felixniemeyer

thanks @felixniemeyer :) btw, can you document your env briefly - which os/gpu, any virtualization, etc...

sure!
OS: Linux 6.6.52-1-MANJARO x86_64
GPU: NVIDIA Corporation AD106 [GeForce RTX 4060 Ti 16GB]
No virtualization, just a Python venv.

@bghira
Contributor

bghira commented Oct 25, 2024

Yes, it happens on Linux systems without any para-/hardware virtualization.

@eqy

eqy commented Oct 25, 2024

Could you post the repro(s) that you are running if they are different? @bghira @felixniemeyer (The same error message can be triggered with different root causes, e.g., compilation failure due to environment differences vs. compilation failure due to incorrect code generation)

@bghira
Contributor

bghira commented Oct 25, 2024

I wasn't able to identify any cause; we had CUDA 12.4 images working okay, but there's no clear link between library versions and this error.

@felixniemeyer

Could you post the repro(s) that you are running if they are different? @bghira @felixniemeyer (The same error message can be triggered with different root causes, e.g., compilation failure due to environment differences vs. compilation failure due to incorrect code generation)

A repro means a way to reproduce the error message, right?

I was following this guide to train a Stable Diffusion 1.5 LoRA:
https://huggingface.co/docs/diffusers/training/lora

This was the command I was executing after setting up according to the guide:

export MODEL_NAME="stable-diffusion-v1-5/stable-diffusion-v1-5"
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=../../../../training_dataset \
  --dataloader_num_workers=8 \
  --resolution=512 \
  --center_crop \
  --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-04 \
  --max_grad_norm=1 \
  --lr_scheduler="cosine" \
  --lr_warmup_steps=0 \
  --output_dir=../../../lora_15_out \
  --checkpointing_steps=1500 \
  --mixed_precision="fp16" \
  --seed=1337

I created a venv at the diffusers root level and installed the examples/text_to_image requirements into the same venv.

Can you reproduce it like this?
