Modifying RNN Quantization for bitwidth lower than 8-bit #1041

Open
JiaMingLin opened this issue Oct 3, 2024 · 0 comments
Labels: enhancement (New feature or request)

Comments


JiaMingLin commented Oct 3, 2024

Hi,

I've tested the QuantLSTM on the EMG dataset with the following bitwidth settings (see the sketch after this list for how they might map onto QuantLSTM's arguments):

weight_quant: 8
io_quant: 2
sigmoid: 8
tanh_quant: 8
cell_state_quant: 8
accumulation_quant: 16
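
Roughly, these settings map onto QuantLSTM's quantizer arguments as in the sketch below. This is a sketch only: the keyword names follow the QuantLSTM signature I am working against and may differ between Brevitas versions, the *Scratch quantizers are the ones defined in the P.S. at the end, and the 16-bit accumulator variant is a hypothetical bit_width override.

from brevitas.nn import QuantLSTM

# Hypothetical 16-bit variant of the activation quantizer, used for the accumulators.
class Int16ActPerTensorFloatScratch(Int8ActPerTensorFloatScratch):
    bit_width = 16

quant_lstm = QuantLSTM(
    input_size=16, hidden_size=128, num_layers=2,    # placeholder sizes
    weight_quant=Int8WeightPerTensorFloatScratch,    # weight_quant: 8
    io_quant=Int2ActPerTensorFloatScratch,           # io_quant: 2
    sigmoid_quant=Int8ActPerTensorFloatScratch,      # sigmoid: 8
    tanh_quant=Int8ActPerTensorFloatScratch,         # tanh_quant: 8
    cell_state_quant=Int8ActPerTensorFloatScratch,   # cell_state_quant: 8
    gate_acc_quant=Int16ActPerTensorFloatScratch)    # accumulation_quant: 16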

However, the quantized model failed to learn properly (classification accuracy for the 8-category task is only 12.5%).
After reviewing the code and visualizing the quantization process in the figure below:
[Screenshot: visualization of the current quantization flow, 2024-10-03]
I noticed that with io_quant set to 2-bit, the output of the final LSTM layer that feeds the fully connected layer is also 2-bit, so the classifier sees at most four distinct levels per feature. In my experience, extremely low-bitwidth models generally need a higher bitwidth at the final classifier (i.e., the FC layer), which could explain the poor performance.

I've made a proposed modification, as shown in the figure below:
[Screenshot: proposed quantization scheme, 2024-10-04]

  • Quantize the input for every LSTM layer using the $q_x$ quantizer (previously, only the first layer's input was quantized).
  • Quantize the input hidden state for each LSTM layer using the $q_h$ quantizer (previously, this was done by output_quant).
  • Remove output_quant from the output hidden state (see the sketch after this list).
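
In plain PyTorch terms, the proposed per-layer flow looks roughly like the sketch below. It is a rough, self-contained illustration rather than the actual Brevitas internals; q_x and q_h stand for the io/input and hidden-state quantizers (e.g. QuantIdentity modules), and the cells are placeholder nn.LSTMCell instances.

import torch
import torch.nn as nn

def stacked_lstm_forward(x, cells, q_x, q_h):
    # x: (seq_len, batch, features); cells: a list of nn.LSTMCell, one per layer
    for cell in cells:
        h = x.new_zeros(x.size(1), cell.hidden_size)
        c = x.new_zeros(x.size(1), cell.hidden_size)
        outputs = []
        for x_t in x:                           # iterate over time steps
            h, c = cell(q_x(x_t), (q_h(h), c))  # quantize the layer input and the incoming hidden state
            outputs.append(h)                   # no output_quant on the outgoing hidden state
        x = torch.stack(outputs)                # higher-precision sequence feeds the next layer / the FC head
    return x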

With this change, some quick tests show the new scheme reaching 53% accuracy, much higher than the previous 12.5%.

Next, I plan to run more experiments on datasets like PTB and explore different LSTM and RNN variations. Any feedback or suggestions are welcome, and I'll update this thread with results.

P.S. The weight and activation quantizers are defined as follows:

# Imports (assumed Brevitas module paths; adjust to the installed version):
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.inject.enum import BitWidthImplType, FloatToIntImplType, QuantType
from brevitas.inject.enum import RestrictValueType, ScalingImplType, StatsOp
from brevitas.quant.solver import ActQuantSolver, WeightQuantSolver

class Int8WeightPerTensorFloatScratch(WeightQuantSolver):
    quant_type = QuantType.INT # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND # round to nearest
    scaling_impl_type = ScalingImplType.STATS # scale based on statistics
    scaling_stats_op = StatsOp.MAX # scale statistics is the absmax value
    restrict_scaling_type = RestrictValueType.FP # scale factor is a floating point value
    scaling_per_output_channel = False # scale is per tensor
    bit_width = 8 # bit width is 8
    signed = True # quantization range is signed
    narrow_range = True # quantization range is [-127,127] rather than [-128, 127]
    zero_point_impl = ZeroZeroPoint # zero point is 0.
    scaling_min_val = 1e-10 # minimum value for the scale factor

class Int8ActPerTensorFloatScratch(ActQuantSolver):
    quant_type = QuantType.INT # integer quantization
    bit_width_impl_type = BitWidthImplType.CONST # constant bit width
    float_to_int_impl_type = FloatToIntImplType.ROUND # round to nearest
    scaling_impl_type = ScalingImplType.STATS # scale is a parameter initialized from statistics
    scaling_stats_op = StatsOp.PERCENTILE # scale statistics is a percentile of the abs value
    high_percentile_q = 99.999 # percentile is 99.999
    collect_stats_steps = 300  # statistics are collected for 300 forward steps before switching to a learned parameter
    restrict_scaling_type = RestrictValueType.FP # scale is a floating-point value
    scaling_per_output_channel = False  # scale is per tensor
    bit_width = 8  # bit width is 8
    signed = True # quantization range is signed
    narrow_range = False # quantization range is [-128, 127] rather than [-127, 127]
    zero_point_impl = ZeroZeroPoint # zero point is 0.
    scaling_min_val = 1e-10 # minimum value for the scale factor

class Int2ActPerTensorFloatScratch(Int8ActPerTensorFloatScratch):
    bit_width = 2 # everything else is inherited from the 8-bit activation quantizer
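
As a quick sanity check of these quantizers in isolation, one can wrap the 2-bit activation quantizer in a QuantIdentity and inspect the returned QuantTensor (a sketch; it assumes the standard QuantIdentity / return_quant_tensor API):

import torch
from brevitas.nn import QuantIdentity

q_io = QuantIdentity(act_quant=Int2ActPerTensorFloatScratch, return_quant_tensor=True)
y = q_io(torch.randn(8, 16))  # run in training mode so the scale statistics are collected
print(y.scale, y.bit_width)   # expect a single per-tensor scale and bit_width == 2
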
nickfraser added the enhancement label on Oct 8, 2024