
Mixed precision conversion getting Assertion Error #89

Open
caffeinetoomuch opened this issue Jun 2, 2022 · 15 comments
@caffeinetoomuch

First of all, thanks for this deployment package, @pommedeterresautee!
I was trying to optimize our fine-tuned T5 model (with modified pre- and post-forward code) using your recent T5 optimization notebook.
When converting the encoder to mixed precision, I got an assertion error: Graph is not a DAG.

Here is the snippet for conversion:

keep_fp32_encoder = get_keep_fp32_nodes(onnx_model_path=encoder_model_path, get_input=get_input)
assert len(keep_fp32_encoder) > 0
enc_model_onnx = convert_fp16(onnx_model=encoder_model_path, nodes_to_exclude=keep_fp32_encoder)

Stacktrace:

06-02 01:05:27 [onnxruntime.transformers.fusion_utils] INFO     - Removed 531 cascaded Cast nodes
06-02 01:05:34 [onnxruntime.transformers.onnx_model] INFO     - Graph pruned: 0 inputs, 0 outputs and 391 nodes are removed
06-02 01:05:43 [onnxruntime.transformers.fusion_utils] INFO     - Removed 431 Cast nodes with output type same as input
Traceback (most recent call last):
  File "convert_checkpoint_to_onnx.py", line 331, in <module>
    enc_model_onnx = convert_fp16(onnx_model=encoder_model_path, nodes_to_exclude=keep_fp32_encoder)
  File "/workspace/transformer-deploy/src/transformer_deploy/backends/ort_utils.py", line 418, in convert_fp16
    wrapped_fp16_model.topological_sort()
  File "/home/jinkoo/.local/lib/python3.8/site-packages/onnxruntime/transformers/onnx_model.py", line 898, in topological_sort
    OnnxModel.graph_topological_sort(self.model.graph)
  File "/home/jinkoo/.local/lib/python3.8/site-packages/onnxruntime/transformers/onnx_model.py", line 890, in graph_topological_sort    assert end == len(graph.node), "Graph is not a DAG"
AssertionError: Graph is not a DAG

I installed transformer-deploy from the most recent version of the repo as well. Could this be happening because of our modified forward call, which just involves reshaping and re-padding around the original forward?

@pommedeterresautee
Member

Hi @ice-americano,
1/ does it work with the unmodified version of the forward method?
2/ can you share your modified forward method?

@caffeinetoomuch
Author

caffeinetoomuch commented Jun 2, 2022

does it work with the unmodified version of the forward method?

Yeah, when I was using T5EncoderModel from my checkpoint (3b), the script was able to convert it (I was reshaping the encoder input ids before the forward call). However, the output tensor of the ONNX FP16 model was all nan.

can you share your modified forward method?

So basically, before calling the original forward, we reshape both input_ids and attention_mask from (batch_size, inner_batch_size, len) to (batch_size * inner_batch_size, len). Then after the forward, we reshape the last_hidden_state from (inner_batch_size, len, d_model) to (1, inner_batch_size * len, d_model).
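
Roughly, a simplified PyTorch sketch of that wrapper (illustrative only; the module wiring and names such as inner_batch_size follow the description above, not our actual code):

import torch


class ReshapingT5Encoder(torch.nn.Module):
    # Hypothetical wrapper illustrating the reshapes described above.
    def __init__(self, encoder: torch.nn.Module):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        batch_size, inner_batch_size, seq_len = input_ids.shape
        # (batch_size, inner_batch_size, len) -> (batch_size * inner_batch_size, len)
        input_ids = input_ids.reshape(batch_size * inner_batch_size, seq_len)
        attention_mask = attention_mask.reshape(batch_size * inner_batch_size, seq_len)
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # with batch_size == 1: (inner_batch_size, len, d_model) -> (1, inner_batch_size * len, d_model)
        d_model = out.last_hidden_state.shape[-1]
        return out.last_hidden_state.reshape(1, inner_batch_size * seq_len, d_model)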

@pommedeterresautee
Member

pommedeterresautee commented Jun 3, 2022

Are you using T5 3b with external data? If yes, retry with ORT compiled from master; it fixes a bug with external data.

@caffeinetoomuch
Author

I was using T5-3b, so I just assumed the package would default to using external data, wouldn't it?

retry with ORT compiled from master, it fixes a bug with ext data

Actually, I saw your issue on the ORT repo, so I built the ORT package from master two days ago. ORT should be pretty up to date, only missing 2-3 commits.

Also, I just decided to use the default encoder, and everything looks fine without converting to FP16 (including merging the graphs with the If node). Are you saying that the FP16 model producing nan is related to the external data bug?

@pommedeterresautee
Member

Haha, maybe not: the commit you want is 2 days old.
[screenshot of the commit timestamp on GitHub]

So it depends on the exact hour you refreshed things :-)

FP16 works only if you do mixed precision, keeping some nodes in FP32. trtexec can't do such a thing.

@caffeinetoomuch
Author

😎 Yeah, I pulled from master after seeing that Revert ... commit, since I was facing that issue (CUDA failure 700) before.
I actually meant mixed precision (although I did not fully understand what trtexec is). I just followed the script as below:

keep_fp32_encoder = get_keep_fp32_nodes(onnx_model_path=encoder_model_path, get_input=get_input)
assert len(keep_fp32_encoder) > 0
enc_model_onnx = convert_fp16(onnx_model=encoder_model_path, nodes_to_exclude=keep_fp32_encoder)
save_onnx(proto=enc_model_onnx, model_path=encoder_fp16_model_path)
...
are_equal(a=enc_onnx_out, b=out_enc.last_hidden_state)  # Failed!

Any clue about what I could be doing wrong?

@pommedeterresautee
Member

You may want to increase early_stop to a bigger value; it will keep more nodes in FP32 (at least in theory), see the sketch below.
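
For instance (a hedged sketch reusing the names from your snippet above; early_stop is assumed to be a keyword argument of get_keep_fp32_nodes in your version of transformer-deploy):

keep_fp32_encoder = get_keep_fp32_nodes(
    onnx_model_path=encoder_model_path,
    get_input=get_input,
    early_stop=200,  # default is 100; a larger value should exclude more nodes from FP16
)
assert len(keep_fp32_encoder) > 0
enc_model_onnx = convert_fp16(onnx_model=encoder_model_path, nodes_to_exclude=keep_fp32_encoder)
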
What kind of GPU are you using?

@caffeinetoomuch
Author

caffeinetoomuch commented Jun 3, 2022

Oh okay. The default value I was using was 100, so I might try 150, 200, etc.
I am using an A6000 (Driver Version: 470.103.01, CUDA Version: 11.4)!

@pommedeterresautee
Member

Just something that comes to mind: have you checked that your transformations are correctly exported to ONNX (using the onnx lib or Netron)?
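
For example, a quick check with the onnx library (a sketch reusing the path name from your snippets; Netron shows the same information graphically):

import onnx

# Run the standard checker on the exported file; passing the path handles models
# that rely on external data / exceed the 2 GB protobuf limit.
onnx.checker.check_model(encoder_model_path)

# Load only the graph structure (weights stay on disk) and inspect it: the custom
# reshapes should appear as Reshape nodes around the encoder, and the output
# names/shapes should match what the PyTorch wrapper returns.
model = onnx.load(encoder_model_path, load_external_data=False)
print([o.name for o in model.graph.output])
print({node.op_type for node in model.graph.node})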

@caffeinetoomuch
Author

I think the FP32 checkpoint is exported correctly, since the generated tensor was equal to that of the torch model. I did it with the convert_to_onnx function from the script. I guess you meant the mixed precision checkpoint?

convert_to_onnx(
    model_pytorch=encoder,
    output_path=encoder_model_path,
    inputs_pytorch={
        "input_ids": encoder_input_ids,
        "attention_mask": encoder_attention_mask,
    },
    var_output_seq=True,
    quantization=False,
)

@pommedeterresautee
Member

pommedeterresautee commented Jun 7, 2022

Just wondering, did you update the fp16_default_tolerance value used in are_equal?

It has to be increased with the size of the model to account for the rounding errors that accumulate with the number of layers (not a quality issue: these models are usually trained in FP16, it's just required when comparing with FP32).
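
As a rough illustration (a sketch with torch.allclose rather than the notebook's are_equal helper; it assumes both outputs are torch tensors on the same device, and the tolerance value is only a placeholder to tune):

import torch

# FP16 rounding errors accumulate layer by layer, so an exact match with the FP32
# reference is not expected; the tolerance has to grow with the model depth.
fp16_tolerance = 3e-1  # placeholder value, tune per model size
assert torch.allclose(
    enc_onnx_out.float(), out_enc.last_hidden_state.float(), atol=fp16_tolerance
), "FP16 output differs from the FP32 reference beyond the chosen tolerance"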

@caffeinetoomuch
Author

Just wondering, did you update the fp16_default_tolerance value used in are_equal?

As mentioned in the comments, I was using fp16_default_tolerance = 3 since I am dealing with t5-3b.
Also, I tried a higher early_stop value for the mixed precision conversion, and my last_hidden_state still gives a nan tensor, as follows:

tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],
                                                                                                                                                                                                                   
        [[nan, nan, nan,  ..., nan, nan, nan],                                                                                                                                                                     
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       dtype=torch.float16)

@pommedeterresautee
Member

We have completely rewritten this part. Can you recheck?

@caffeinetoomuch
Author

Sorry, I was busy 😅 Yeah, I was going to try again, but I just found the new notebook t5_bf16.ipynb!
Actually, I have been waiting for bf16 quantization, so I want to try that out instead.

@pommedeterresautee
Member

pommedeterresautee commented Aug 4, 2022

TBH, I think you will be disappointed (at least we were): the precision is very low. We haven't tried it, but I would expect it to show up in the accuracy measures whatever your task. Our understanding is that it's meant to be used with mixed precision in mind, so you still have some casting here and there, just none for out-of-range values.
