Flatten layer followed by linear layer causes HardFault on Cortex M4F #7651

Open
ChristophKarlHeck opened this issue Jan 14, 2025 · 6 comments
Labels: "partner: arm" (for backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm), "triaged" (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

ChristophKarlHeck commented Jan 14, 2025

🐛 Describe the bug

Hi,
I want to run the following model on a NUCLEO_WB55RG, but execution currently ends in a HardFault. I am using this executor runner: https://github.com/ChristophKarlHeck/mbed-torch-fusion-os/tree/main/src/model_executor

import torch
from torch import nn
import torch.nn.functional as F
from torch.export import export, export_for_training, ExportedProgram
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager
import executorch.exir as exir
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

import numpy as np

# Define the Conv1D Model with LightningModule
class Conv1DModel(pl.LightningModule):
    def __init__(self, input_channels, output_channels, kernel_size):
        super().__init__()
        self.conv1d = nn.Conv1d(input_channels, output_channels, kernel_size)
        self.pool = nn.MaxPool1d(3, stride=3) # window size 3, stride 3
        self.flatten = nn.Flatten() # [batch, 2, 32] -> [batch, 64]
        self.linear = nn.Linear(64, 16) # fully connected layer; 64 = 2 channels * 32 pooled steps
        self.output = nn.Linear(16, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # Ensure input is [batch_size, input_channels, seq_length]
        x = self.conv1d(x)
        x = F.relu(x)
        x = self.pool(x) # [batch, 2, 96] -> [batch, 2, 32]
        x = self.flatten(x)
        x = self.linear(x)
        x = F.relu(x) # non-linearity so the network can learn non-linear patterns
        x = F.softmax(self.output(x), dim=1) # class probabilities; each row sums to 1
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        print(output.shape)
        y = torch.argmax(y, dim=1)
        loss = self.loss_fn(output, y)  # Example loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)
    
    example_input = torch.randn(1, 1, 100)
    can_delegate = False

# Prepare Dummy Data
x_data = torch.randn(10, 1, 100)  # [batch_size, input_channels, seq_length in [50, 100]]; Conv1d expects channels before length (unlike nn.LSTM, which expects features last)
y_data = torch.randn(10, 2)  # [batch_size, num_classes]; training_step takes argmax over dim=1 to get class indices
dataset = TensorDataset(x_data, y_data)
train_loader = DataLoader(dataset, batch_size=3)

# Train the Model
model = Conv1DModel(input_channels=1, output_channels=2, kernel_size=5)
trainer = pl.Trainer(max_epochs=5, logger=False)
trainer.fit(model, train_loader)

# Export the Model
model.eval()

# Example Input
final_example_input = (torch.randn(1, 1, 100),)  # [batch_size, input_channels, seq_length]

# Export the Model
pre_autograd_aten_dialect = export_for_training(
        model,
        final_example_input
    ).module()

aten_dialect: ExportedProgram = export(pre_autograd_aten_dialect, final_example_input)

# The graph returned by torch.export only contains functional ATen operators (~2000 ops), which we will call the ATen Dialect.
print(aten_dialect)
edge_program: exir.EdgeProgramManager = exir.to_edge(aten_dialect)

executorch_program: exir.ExecutorchProgramManager = edge_program.to_executorch(
    ExecutorchBackendConfig(
        passes=[],  # User-defined passes
    )
)

with open("model.pte", "wb") as file:
    file.write(executorch_program.buffer)

print("CNN model saved as model.pte")

Runner output (ends in a HardFault):

I executorch:ModelExecutor.cpp:83] Model in 0x200008A0 <
I executorch:ModelExecutor.cpp:104] Model PTE file loaded. Size: 2600 bytes.
I executorch:ModelExecutor.cpp:124] Model buffer loaded, has 1 methods
I executorch:ModelExecutor.cpp:132] Running method forward
I executorch:ModelExecutor.cpp:224] Setting up planned buffer 0, size 1216.
I executorch:ModelExecutor.cpp:265] Method loaded.
I executorch:ModelExecutor.cpp:267] Preparing inputs...
I executorch:ModelExecutor.cpp:280] Input prepared.
I executorch:ModelExecutor.cpp:288] Number of input values required by model:100
Input[0][0]: 1.000000
Input[0][1]: 1.000000
... (Input[0][2] through Input[0][98] are also 1.000000)
Input[0][99]: 1.000000
I executorch:ModelExecutor.cpp:324] Starting the model execution...
Thread: 0x20002088, Stack size: 1920 / 4096
Thread: 0x20001BFC, Stack size: 64 / 896
Thread: 0x20001BB8, Stack size: 88 / 768
Heap size: 20160 / 149760 bytes

++ MbedOS Fault Handler ++

FaultType: HardFault

Context:
R   0: 20002E08
R   1: 2000BA4C
R   2: 00000000
R   3: 00000004
R   4: 20002EE0
R   5: 00000001
R   6: 00000000
R   7: 200012C8
R   8: 20000AD8
R   9: 2000BA4C
R  10: 00000070
R  11: 20000A68
R  12: 00000004
SP   : 20002DC8
LR   : 0800DEC7
PC   : 00000004
xPSR : 40070000
PSP  : 20002D60
MSP  : 2002FFC0
CPUID: 410FC241
HFSR : 40000000
MMFSR: 00000000
BFSR : 00000000
UFSR : 00000002
DFSR : 00000000
AFSR : 00000000
Mode : Thread
Priv : Privileged
Stack: PSP

-- MbedOS Fault Handler --

++ MbedOS Error Info ++
Error Status: 0x80FF013D Code: 317 Module: 255
Error Message: Fault exception
Location: 0x4
Error Value: 0x20001610
Current Thread: main Id: 0x20002088 Entry: 0x8008265 StackSize: 0x1000 StackMem: 0x200020D0 SP: 0x20002DC8
For more info, visit: https://mbed.com/s/error?error=0x80FF013D&osver=61700&core=0x410FC241&comp=2&ver=130300&tgt=NUCLEO_WB55RG
-- MbedOS Error Info --

The following models, on the other hand, run without a HardFault:

# 1) Conv1d + ReLU + MaxPool1d: runs fine
class Conv1DModel(pl.LightningModule):
    def __init__(self, input_channels, output_channels, kernel_size):
        super().__init__()
        self.conv1d = nn.Conv1d(input_channels, output_channels, kernel_size)
        self.pool = nn.MaxPool1d(3, stride=3)  # Window size 3, stride 3
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # Apply convolution, ReLU activation, and max pooling
        x = self.pool(F.relu(self.conv1d(x)))
        return x

# 2) Conv1d + ReLU + MaxPool1d + Flatten: runs fine
class Conv1DModel(pl.LightningModule):
    def __init__(self, input_channels, output_channels, kernel_size):
        super().__init__()
        self.conv1d = nn.Conv1d(input_channels, output_channels, kernel_size)
        self.pool = nn.MaxPool1d(3, stride=3)
        self.flatten = nn.Flatten()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1d(x)))
        x = self.flatten(x)
        return x

# 3) Linear only: runs fine
class LinearModel(pl.LightningModule):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.linear(x)
        return x

# 4) Conv1d + ReLU + MaxPool1d + global average pooling + Linear layers: runs fine
class Conv1DModel(pl.LightningModule):
    def __init__(self, input_channels, output_channels, kernel_size):
        super().__init__()
        self.conv1d = nn.Conv1d(input_channels, output_channels, kernel_size)
        self.pool = nn.MaxPool1d(3, stride=3)  # Window size 3, stride 3
        self.linear = nn.Linear(output_channels, 16)  # Adjusted input features
        self.output = nn.Linear(16, 2)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # Ensure input is [batch_size, input_channels, seq_length]
        x = self.conv1d(x)
        x = F.relu(x)
        x = self.pool(x)  # Output shape: [batch_size, output_channels, reduced_seq_length]
        x = x.mean(dim=2)  # Global average pooling across seq_length
        x = self.linear(x)
        x = F.relu(x)  # ReLU for non-linearity
        x = F.softmax(self.output(x), dim=1)  # Softmax for class probabilities
        return x

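Flatten on its own works, Linear on its own works, and Linear layers after global average pooling work too, so the problem seems to be specific to a flatten layer followed by a linear layer.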

I also tried torch.reshape instead of flatten, but export maps both reshape and flatten to the same view op: "f32[1, 64]" = torch.ops.aten.view.default(max_pool1d, [1, 64]); max_pool1d = None.
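For reference, the [1, 64] shape in that view op follows directly from the layer sizes. Below is a minimal sketch of the arithmetic, using the output-length formula documented for nn.Conv1d and nn.MaxPool1d (floor mode, no padding); the helper name conv1d_out_len is illustrative only, not a PyTorch or ExecuTorch API.

```python
def conv1d_out_len(l_in: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1) -> int:
    """Output length of nn.Conv1d / nn.MaxPool1d (formula from the PyTorch docs, floor mode)."""
    return (l_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Layer settings taken from the failing Conv1DModel in the issue
seq_len, conv_channels = 100, 2
conv_len = conv1d_out_len(seq_len, kernel=5)             # Conv1d(1, 2, 5):        [1, 1, 100] -> [1, 2, 96]
pool_len = conv1d_out_len(conv_len, kernel=3, stride=3)  # MaxPool1d(3, stride=3): [1, 2, 96]  -> [1, 2, 32]
flat_features = conv_channels * pool_len                 # Flatten (aten.view):    [1, 2, 32]  -> [1, 64]
mean_features = conv_channels                            # x.mean(dim=2):          [1, 2, 32]  -> [1, 2]

print(conv_len, pool_len, flat_features, mean_features)  # 96 32 64 2
```

So in the failing model the first nn.Linear receives 64 features through the view op, while in the working mean-pooling variant it receives only 2.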

Has anyone run into a similar issue, or does anyone have an idea how to fix it?
Please let me know if you need any further information.

Versions

PyTorch version: 2.5.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.31.0
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov  6 2024, 20:22:13) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.15.0-122-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 11.5.119
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 950M
Nvidia driver version: 535.183.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        39 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               8
On-line CPU(s) list:                  0-7
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Core(TM) i7-6700HQ CPU @ 2.60GHz
CPU family:                           6
Model:                                94
Thread(s) per core:                   2
Core(s) per socket:                   4
Socket(s):                            1
Stepping:                             3
CPU max MHz:                          3500,0000
CPU min MHz:                          800,0000
BogoMIPS:                             5199.98
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single pti ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid mpx rdseed adx smap clflushopt intel_pt xsaveopt xsavec xgetbv1 xsaves dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp md_clear flush_l1d arch_capabilities
Virtualization:                       VT-x
L1d cache:                            128 KiB (4 instances)
L1i cache:                            128 KiB (4 instances)
L2 cache:                             1 MiB (4 instances)
L3 cache:                             6 MiB (1 instance)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-7
Vulnerability Gather data sampling:   Vulnerable: No microcode
Vulnerability Itlb multihit:          KVM: Mitigation: VMX disabled
Vulnerability L1tf:                   Mitigation; PTE Inversion; VMX conditional cache flushes, SMT vulnerable
Vulnerability Mds:                    Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Meltdown:               Mitigation; PTI
Vulnerability Mmio stale data:        Mitigation; Clear CPU buffers; SMT vulnerable
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Mitigation; IBRS
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; IBRS; IBPB conditional; STIBP conditional; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Mitigation; Microcode
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] executorch==0.4.0
[pip3] numpy==1.21.3
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-lightning==2.5.0.post0
[pip3] torch==2.5.0
[pip3] torchaudio==2.5.0
[pip3] torchmetrics==1.6.1
[pip3] torchvision==0.20.0
[pip3] triton==3.1.0
[conda] Could not collect
@digantdesai
Contributor

Thanks for the detailed bug report @ChristophKarlHeck.

I skimmed through your runner code and it seems OK. Looking at the error, the PC suggests we are off in the weeds; did we dereference a nullptr? Can you figure out which branch we took by following the LR? Also, looking at the FSRs, it seems that UFSR:INVSTATE and HFSR:FORCED are set. Not sure what is going on, TBH.
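The fault status registers mentioned above can be decoded mechanically. The sketch below assumes the ARMv7-M register layout used by the Cortex-M4 (CFSR at 0xE000ED28, HFSR at 0xE000ED2C) and simply maps set bits to their names; the helper names are hypothetical and the example values are made up to match the UFSR:INVSTATE + HFSR:FORCED combination. On Cortex-M, INVSTATE usually means the core tried to execute with the Thumb bit cleared, for example after branching through a null or corrupted function pointer, which would fit the nullptr hypothesis. FORCED means a configurable fault (here the UsageFault) was escalated to a HardFault, typically because its own handler is disabled.

```python
# Bit positions in the ARMv7-M Configurable Fault Status Register (CFSR, 0xE000ED28).
CFSR_BITS = {
    # MemManage Fault Status (MMFSR), CFSR[7:0]
    0: "MMFSR.IACCVIOL",
    1: "MMFSR.DACCVIOL",
    3: "MMFSR.MUNSTKERR",
    4: "MMFSR.MSTKERR",
    5: "MMFSR.MLSPERR",
    7: "MMFSR.MMARVALID",
    # BusFault Status (BFSR), CFSR[15:8]
    8: "BFSR.IBUSERR",
    9: "BFSR.PRECISERR",
    10: "BFSR.IMPRECISERR",
    11: "BFSR.UNSTKERR",
    12: "BFSR.STKERR",
    13: "BFSR.LSPERR",
    15: "BFSR.BFARVALID",
    # UsageFault Status (UFSR), CFSR[31:16]
    16: "UFSR.UNDEFINSTR",
    17: "UFSR.INVSTATE",
    18: "UFSR.INVPC",
    19: "UFSR.NOCP",
    24: "UFSR.UNALIGNED",
    25: "UFSR.DIVBYZERO",
}

# Bit positions in the HardFault Status Register (HFSR, 0xE000ED2C).
HFSR_BITS = {
    1: "HFSR.VECTTBL",
    30: "HFSR.FORCED",
    31: "HFSR.DEBUGEVT",
}


def decode_bits(value, table):
    """Return the names of all bits from `table` that are set in `value`, lowest bit first."""
    return [name for bit, name in sorted(table.items()) if value & (1 << bit)]


def decode_fault(cfsr, hfsr):
    """Decode raw CFSR and HFSR values read inside the HardFault handler."""
    return decode_bits(cfsr, CFSR_BITS) + decode_bits(hfsr, HFSR_BITS)


if __name__ == "__main__":
    # Hypothetical values: only INVSTATE in the CFSR, only FORCED in the HFSR.
    print(decode_fault(0x00020000, 0x40000000))
```

Reading the two registers in the HardFault handler and printing them over the serial console is usually enough to feed this kind of decoder.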

A couple of suggestions:
(1) Can you reproduce it on the M55 FVP we have in the CI? That way we can start poking at it too. If you find a more closely matching FVP, that should be OK too.
(2) Can you figure out where in the C code (perhaps in ET) or in the generated asm this is happening? A call stack would be ideal.
(3) Could you also share your compiler and compiler flags?
(4) I saw you already tried other model combinations, but if figuring out (2) teaches you something, maybe a small repro would go a long way towards quickly iterating on this.
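For suggestion (2), a common first step is to dump the registers the core pushes on exception entry and read the stacked PC and LR. The sketch below assumes the ARMv7-M basic frame layout (R0, R1, R2, R3, R12, LR, PC, xPSR, lowest address first; the extended FP frame on a Cortex-M4F begins with the same eight words) and that those words have already been copied out of the fault handler, e.g. over the serial console. The function names and example values are hypothetical.

```python
# Order of the eight words the hardware stacks on exception entry (ARMv7-M basic frame).
FRAME_REGS = ["R0", "R1", "R2", "R3", "R12", "LR", "PC", "xPSR"]


def decode_exc_return(exc_return):
    """Interpret the EXC_RETURN value held in LR on entry to the fault handler."""
    return {
        # Bit 2 set: the frame was pushed onto the PSP; clear: onto the MSP.
        "uses_psp": bool(exc_return & (1 << 2)),
        # Bit 3 set: the fault was taken from Thread mode.
        "thread_mode": bool(exc_return & (1 << 3)),
        # Bit 4 clear: extended frame that also contains FP registers.
        "fp_frame": not (exc_return & (1 << 4)),
    }


def decode_stacked_frame(words):
    """Map the first eight stacked words to register names."""
    if len(words) < len(FRAME_REGS):
        raise ValueError("need at least 8 stacked words")
    frame = dict(zip(FRAME_REGS, words))
    # xPSR bit 24 is the Thumb (T) bit; executing with T == 0 raises UFSR.INVSTATE.
    frame["thumb_bit"] = bool(frame["xPSR"] & (1 << 24))
    return frame


if __name__ == "__main__":
    # Hypothetical dump: PC == 0 and a cleared T bit would be consistent with a null call.
    frame = decode_stacked_frame([0, 0x20001000, 0, 0, 0, 0x08001235, 0x0, 0x0])
    print({name: hex(value) for name, value in frame.items() if name != "thumb_bit"})
    print("thumb bit set:", frame["thumb_bit"])
    print(decode_exc_return(0xFFFFFFE9))
```

The stacked PC and LR can then be mapped back to source with a tool such as `arm-none-eabi-addr2line -f -e <firmware.elf> <address>`; the LR often points into the caller that made the bad branch, which is what "following LR" refers to.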

Thanks again.

@mcr229 mcr229 added partner: arm For backend delegation, kernels, demo, etc. from the 3rd-party partner, Arm bug triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module labels Jan 15, 2025
@zingo
Collaborator

zingo commented Jan 23, 2025

Hi @ChristophKarlHeck, how is it going? Are you still stuck?

@ChristophKarlHeck
Author

ChristophKarlHeck commented Jan 23, 2025

Hi @zingo,
At the moment, we are using a workaround for the flatten operation:

import torch

# generate a random integer tensor of shape (32, 2, 20) with values in [1, 100)
x = torch.randint(1, 100, (32, 2, 20))

# flatten operation: (32, 2, 20) -> (32, 40)
flatten_layer = torch.nn.Flatten()
x_flatten = flatten_layer(x)

# equivalent operation (hardcoded for exactly 2 channels in dim 1)
x_test = torch.transpose(x, 1, 2)
x_test = torch.cat((x_test[:, :, 0], x_test[:, :, 1]), dim=1)

# get all shapes
print(f"shape of x: {x.shape}, shape of x_flatten: {x_flatten.shape}, shape of x_test: {x_test.shape}")

# check that all values are identical
print(f"all values are the same: {torch.all(torch.eq(x_flatten, x_test))}")
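The workaround gives the same result because `nn.Flatten()` (default `start_dim=1`) on a `[batch, channels, length]` tensor lays values out in row-major order: element `(b, c, t)` lands at flat position `c * length + t`, so the output is channel 0's sequence followed by channel 1's. Transposing and then concatenating column `c` for each channel reproduces that order. The plain-Python sketch below uses nested lists in place of tensors (no torch needed; helper names are hypothetical) and checks the equivalence for an arbitrary channel count, whereas the snippet above hardcodes two channels.

```python
def flatten(x):
    """Row-major flatten of a [B][C][L] nested list into [B][C*L], like nn.Flatten()."""
    return [[v for channel in sample for v in channel] for sample in x]


def transpose_then_gather(x):
    """Mimic the workaround: transpose each sample to [L][C], then concatenate column c for c = 0..C-1."""
    out = []
    for sample in x:
        num_channels, length = len(sample), len(sample[0])
        transposed = [[sample[c][t] for c in range(num_channels)] for t in range(length)]  # [L][C]
        row = []
        for c in range(num_channels):
            row.extend(transposed[t][c] for t in range(length))
        out.append(row)
    return out


if __name__ == "__main__":
    # 2 samples, 3 channels, length 4; values encode (channel, time) as c * 10 + t.
    x = [[[c * 10 + t for t in range(4)] for c in range(3)] for _ in range(2)]
    print(flatten(x) == transpose_then_gather(x))
```

In torch terms this suggests the transpose is not strictly needed: concatenating `x[:, c, :]` over all channels along `dim=1` yields the same order, although whether that lowers any differently in ExecuTorch has not been tested here.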

Since I am in the middle of my master's thesis, I must prepare for the MVP. I will do further research when the MVP is ready :)

@zingo
Collaborator

zingo commented Jan 25, 2025

Since I am in the middle of my master's thesis,

Good luck, I hope it goes well!

@digantdesai
Contributor

@zingo, can we try to repro it on our end using our M55 FVPs?

@zingo
Collaborator

zingo commented Jan 28, 2025

With a bit of luck, running the Arm AOT compiler without the delegate flag should let the model run on the Cortex-M alone.

@jackzhxng jackzhxng removed the bug label Feb 21, 2025