Add automatic batching #22

Draft · wants to merge 1 commit into base: main
auto_fp8/quantize.py · 30 changes: 25 additions & 5 deletions
@@ -246,10 +246,27 @@ def quantize_weights(
     cleanup_memory()
 
 
+def find_max_batch_size(model: AutoModelForCausalLM, tokens: torch.Tensor) -> int:
+    # Halve the batch size until a full forward pass fits in memory
+    batch_size = tokens.shape[0]
+    while batch_size > 1:
+        try:
+            with torch.inference_mode():
+                model(tokens[:batch_size].reshape(batch_size, -1))
+            return batch_size
+        except RuntimeError as e:
+            if "out of memory" not in str(e):
+                raise
+            # OOM at this size: free cached memory and retry with half the batch
+            cleanup_memory()
+            batch_size //= 2
+    return batch_size
+
+
 def quantize_activations(
     model: AutoModelForCausalLM,
     quantize_config: BaseQuantizeConfig,
-    calibration_tokens,
+    calibration_tokens: torch.Tensor,
 ):
     # Replace weight quantizer with a dynamic activation quantizer observer
     for name, dynamic_quant_linear in model.named_modules():
@@ -271,13 +288,16 @@ def quantize_activations(
             del dynamic_quant_linear
     cleanup_memory()
 
+    # Find the maximum batch size that can be used without going OOM
+    max_batch_size = find_max_batch_size(model, calibration_tokens)
+
     # Pass through calibration data to measure activation scales
     with torch.inference_mode():
         with tqdm.tqdm(total=calibration_tokens.shape[0], desc="Calibrating activation scales") as pbar:
-            for row_idx in range(calibration_tokens.shape[0]):
-                model(calibration_tokens[row_idx].reshape(1, -1))
-                cleanup_memory()
-                pbar.update(1)
+            for i in range(0, calibration_tokens.shape[0], max_batch_size):
+                batch = calibration_tokens[i:i + max_batch_size]
+                model(batch.reshape(batch.shape[0], -1))
+                pbar.update(batch.shape[0])
 
     # Replace dynamic quantizer observer with StaticLinear for export
     for name, quantizer in model.named_modules():
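Below is a minimal, self-contained sketch (not part of the diff) of how the halving search behaves. `ToyModel`, `max_rows`, and the simulated OOM error are illustrative stand-ins, and `cleanup_memory` is reduced to a safe no-op so the snippet runs on CPU:

```python
import torch

def cleanup_memory():
    # Stand-in for the repo's helper; on GPU this frees cached CUDA memory.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

class ToyModel:
    # Simulates a model that runs out of memory above a fixed batch size.
    def __init__(self, max_rows: int):
        self.max_rows = max_rows

    def __call__(self, tokens: torch.Tensor):
        if tokens.shape[0] > self.max_rows:
            raise RuntimeError("CUDA out of memory (simulated)")
        return tokens

def find_max_batch_size(model, tokens: torch.Tensor) -> int:
    # Same logic as the PR, minus the AutoModelForCausalLM type hint.
    batch_size = tokens.shape[0]
    while batch_size > 1:
        try:
            with torch.inference_mode():
                model(tokens[:batch_size].reshape(batch_size, -1))
            return batch_size
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            cleanup_memory()
            batch_size //= 2
    return batch_size

tokens = torch.zeros(512, 16, dtype=torch.long)  # 512 calibration rows
print(find_max_batch_size(ToyModel(max_rows=100), tokens))  # -> 64
```

Because the probe only halves, it lands on a power-of-two fraction of the initial size and can undershoot the true maximum (64 here, even though 100 rows would fit); that seems like a reasonable trade-off for a quick one-time probe.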
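The new calibration loop's slicing also deserves a note: stepping by `max_batch_size` leaves a possibly shorter tail slice, and using `batch.shape[0]` rather than `max_batch_size` keeps both the reshape and the progress bar correct. A tiny sketch with illustrative sizes:

```python
import torch

calibration_tokens = torch.zeros(10, 4, dtype=torch.long)  # 10 rows of length 4
max_batch_size = 4

for i in range(0, calibration_tokens.shape[0], max_batch_size):
    batch = calibration_tokens[i:i + max_batch_size]
    print(batch.shape)  # (4, 4), (4, 4), then a shorter tail batch of (2, 4)
```

Also worth noting: the per-row `cleanup_memory()` call is dropped from the loop, so cached memory is no longer flushed after every forward pass; the up-front batch-size probe takes over the job of keeping the calibration pass within memory.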