Skip to content

Commit

Permalink
Feat (examples/generative): block-based optimization for GPTQ (#1046)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: i-colbert <[email protected]>
  • Loading branch information
Giuseppe5 and i-colbert authored Oct 17, 2024
1 parent 4617f7b commit 59f8df7
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 23 deletions.
145 changes: 125 additions & 20 deletions src/brevitas_examples/llm/llm_quant/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,39 +3,144 @@
# SPDX-License-Identifier: BSD-3-Clause
"""

from copy import deepcopy

from accelerate.utils.operations import send_to_device
import torch
from tqdm import tqdm

from brevitas.graph.calibrate import disable_return_quant_tensor
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
from brevitas.graph.gpfq import gpfq_mode
from brevitas.graph.gptq import gptq_mode
from brevitas.graph.gpxq import StopFwdException
from brevitas.utils.python_utils import recurse_getattr


@torch.no_grad()
def block_optimization(model, dataloader, block_name, context_manager_func, context_manager_kwargs):
disable_quant_inference = DisableEnableQuantization()
cache_state = model.config.use_cache
model.config.use_cache = False
blocks = recurse_getattr(model, block_name)
first_block = blocks[0]
cached_args, cached_kwargs = [], []

# Intercept input to first block
def intercept_input(module, args, kwargs):
args = send_to_device(args, 'cpu')
kwargs = send_to_device(kwargs, 'cpu')
cached_args.append(args)
cached_kwargs.append(kwargs)
raise StopFwdException

# Intercept output from block N-1 to set it as input to block N
def intercept_output(module, args, kwargs, output):
if isinstance(output, tuple):
output = output[0]
output = send_to_device(output, 'cpu')
cached_args.append((output,))
raise StopFwdException

# Collect input to first block
if not context_manager_kwargs.get('use_quant_activations', True):
return_quant_tensor_state = disable_return_quant_tensor(model)
disable_quant_inference.disable_act_quantization(model, is_training=model.training)
disable_quant_inference.disable_bias_quantization(model, is_training=model.training)

hook = first_block.register_forward_pre_hook(intercept_input, with_kwargs=True)
for inps in dataloader:
try:
model(**inps)
except StopFwdException:
pass
hook.remove()

if not context_manager_kwargs.get('use_quant_activations', True):
disable_quant_inference.enable_act_quantization(model, is_training=model.training)
disable_quant_inference.enable_bias_quantization(model, is_training=model.training)
restore_return_quant_tensor(model, return_quant_tensor_state)

# Iterate through all the blocks
for index, block in tqdm(enumerate(blocks), desc="Blocks", total=len(blocks)):
with context_manager_func(block, **context_manager_kwargs) as gpxq:
for _ in tqdm(range(gpxq.num_layers), desc="Layers", leave=False):
for args, kwargs in zip(cached_args, cached_kwargs):
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
gpxq.update()

if index < len(blocks) - 1:
# Once the block is done, we need to update the input to the next block
past_cached_args, past_cached_kwargs = deepcopy(cached_args), deepcopy(cached_kwargs)
cached_args = []
hook = block.register_forward_hook(intercept_output, with_kwargs=True)

if not context_manager_kwargs.get('use_quant_activations', True):
return_quant_tensor_state = disable_return_quant_tensor(model)
disable_quant_inference.disable_act_quantization(model, is_training=model.training)
disable_quant_inference.disable_bias_quantization(model, is_training=model.training)

for args, kwargs in zip(past_cached_args, past_cached_kwargs):
try:
args = send_to_device(args, 'cuda')
kwargs = send_to_device(kwargs, 'cuda')
block(*args, **kwargs)
except StopFwdException:
pass

if not context_manager_kwargs.get('use_quant_activations', True):
disable_quant_inference.enable_act_quantization(model, is_training=model.training)
disable_quant_inference.enable_bias_quantization(model, is_training=model.training)
restore_return_quant_tensor(model, return_quant_tensor_state)

hook.remove()
# Restore cache state
model.config.use_cache = cache_state


@torch.no_grad()
def apply_gptq(
model,
dataloader,
act_order=True,
use_quant_activations=False,
create_weight_orig=False,
group_of_parallel_layers=None,
use_quant_activations=True,
create_weight_orig=False):
with gptq_mode(model,
act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers,
use_quant_activations=use_quant_activations,
create_weight_orig=create_weight_orig) as gptq:
gptq_model = gptq.model
for _ in tqdm(range(gptq.num_layers)):
for inps in dataloader:
gptq_model(**inps)
gptq.update()
block_name=None):
if block_name is not None:
context_manager_kwargs = {
'act_order': act_order,
'group_of_parallel_layers': group_of_parallel_layers,
'create_weight_orig': create_weight_orig,
'use_quant_activations': use_quant_activations}
block_optimization(model, dataloader, block_name, gptq_mode, context_manager_kwargs)
else:
with gptq_mode(model,
use_quant_activations=use_quant_activations,
group_of_parallel_layers=group_of_parallel_layers,
act_order=act_order,
create_weight_orig=create_weight_orig) as gptq:
gptq_model = gptq.model
for _ in tqdm(range(gptq.num_layers)):
for inps in dataloader:
gptq_model(**inps)
gptq.update()


@torch.no_grad()
def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None):
with gpfq_mode(model, act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers) as gpfq:
gpfq_model = gpfq.model
for _ in tqdm(range(gpfq.num_layers)):
for inps in dataloader:
gpfq_model(**inps)
gpfq.update()
def apply_gpfq(model, dataloader, act_order=True, group_of_parallel_layers=None, block_name=None):
if block_name is not None:
raise RuntimeError("Block optimization not support for GPFQ at the moment")
else:
with gpfq_mode(model,
act_order=act_order,
group_of_parallel_layers=group_of_parallel_layers,
create_weight_orig=True) as gpfq:
gpfq_model = gpfq.model
for _ in tqdm(range(gpfq.num_layers)):
for inps in dataloader:
gpfq_model(**inps)
gpfq.update()
19 changes: 16 additions & 3 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ def validate(args):
def main(args):
validate(args)
set_seed(args.seed)

if args.export_prefix is None:
args.export_prefix = f"{args.model.replace('/', '--')}"

Expand Down Expand Up @@ -189,6 +188,8 @@ def main(args):

if require_fx:
model = get_fx(model)
# Blockwise optimization does not work with FX at the moment
args.gpxq_block_name = None

# Apply LN affine merging before inserting MHA layers
# since currently there is support only for merging into Linear
Expand Down Expand Up @@ -285,12 +286,17 @@ def main(args):
calibration_loader,
act_order=args.gpxq_act_order,
use_quant_activations=args.gpxq_use_quant_activations,
create_weight_orig=args.gpxq_create_weight_orig)
create_weight_orig=args.gpxq_create_weight_orig,
block_name=args.gpxq_block_name)
print("GPTQ applied.")

if args.gpfq:
print("Applying GPFQ...")
apply_gpfq(model, calibration_loader, act_order=args.gpxq_act_order)
apply_gpfq(
model,
calibration_loader,
act_order=args.gpxq_act_order,
block_name=args.gpxq_block_name)
print("GPFQ applied.")

if args.bias_corr:
Expand Down Expand Up @@ -340,6 +346,13 @@ def parse_args(args):
choices=['wikitext2', 'c4'],
default='wikitext2',
help='Dataset to use for quantization (default: %(default)s)')
parser.add_argument(
'--gpxq-block-name',
type=str,
default=None,
help=
'Block name for faster GPxQ optimization. It works only if FX is not needed (default: %(default)s)'
)
parser.add_argument(
'--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.')
parser.add_argument(
Expand Down

0 comments on commit 59f8df7

Please sign in to comment.