[distilbert] Preliminary perf benchmarks
cc @VictorSanh see also #5
Showing 8 changed files with 251 additions and 3 deletions.
@@ -0,0 +1,37 @@
//
// DistilBERTPerfTests.swift
// CoreMLBertTests
//
// Created by Julien Chaumond on 16/09/2019.
// Copyright © 2019 Hugging Face. All rights reserved.
//

import XCTest
import CoreML
@testable import CoreMLBert

class DistilBERTPerfTests: XCTestCase {
    let context = "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50."

    let question = "Which NFL team represented the AFC at Super Bowl 50?"
    let m = BertForQuestionAnswering()
    let mDistilbert = distilbert_64_12()

    // Full BERT-Squad, inference only: featurization happens outside the measured block.
    func testPerformanceNakedBERTModel() {
        let input = m.featurizeTokens(question: question, context: context)

        self.measure {
            _ = try! m.model.prediction(input: input)
        }
    }

    // DistilBERT, inference only, on a dummy all-zero input of 64 token ids.
    func testPerformanceDistilBERTModel() {
        let input_ids = MLMultiArray.from(Array(repeating: 0, count: 64))

        self.measure {
            _ = try! mDistilbert.prediction(input_ids: input_ids)
            // print(output.output_logits)
        }
    }
}
Git LFS file not shown
Git LFS file not shown
@@ -0,0 +1,40 @@
# DistilBERT performance benchmarks

#### Full BERT-Squad with tokenization/featurization.

```
~/swift-coreml-transformers/CoreMLBertTests/BertForQATests.swift:75:
Test Case '-[CoreMLBertTests.BertForQATests testPerformanceExample]' measured [Time, seconds]
average: 1.583, relative standard deviation: 5.232%, values: [1.746976, 1.550390, 1.549479, 1.529654, 1.528065, 1.508825, 1.534357, 1.702786, 1.514816, 1.665434], performanceMetricID:com.apple.XCTPerformanceMetric_WallClockTime, baselineName: "", baselineAverage: , maxPercentRegression: 10.000%, maxPercentRelativeStandardDeviation: 10.000%, maxRegression: 0.100, maxStandardDeviation: 0.100
Test Case '-[CoreMLBertTests.BertForQATests testPerformanceExample]' passed (16.130 seconds).
```

---

#### Full BERT-Squad, only the inference.

```
~/swift-coreml-transformers/CoreMLBertTests/DistilBERTPerfTests.swift:23:
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceNakedModel]' measured [Time, seconds]
average: 1.118, relative standard deviation: 5.919%, values: [1.195310, 1.068182, 1.131890, 1.251984, 1.095551, 1.186633, 1.060465, 1.072363, 1.061609, 1.059508], performanceMetricID:com.apple.XCTPerformanceMetric_WallClockTime, baselineName: "", baselineAverage: , maxPercentRegression: 10.000%, maxPercentRelativeStandardDeviation: 10.000%, maxRegression: 0.100, maxStandardDeviation: 0.100
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceNakedModel]' passed (11.822 seconds).
```

---

#### DistilBERT, only the inference.

```
~/swift-coreml-transformers/CoreMLBertTests/DistilBERTPerfTests.swift:32:
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceDistilBERTModel]' measured [Time, seconds]
average: 0.319, relative standard deviation: 0.548%, values: [0.321627, 0.321694, 0.317964, 0.316413, 0.318463, 0.319897, 0.319386, 0.317997, 0.318780, 0.321835], performanceMetricID:com.apple.XCTPerformanceMetric_WallClockTime, baselineName: "", baselineAverage: , maxPercentRegression: 10.000%, maxPercentRelativeStandardDeviation: 10.000%, maxRegression: 0.100, maxStandardDeviation: 0.100
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceDistilBERTModel]' passed (3.466 seconds).
```

---
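#### Takeaway.

Comparing the averages above: DistilBERT inference (0.319 s) is roughly 3.5× faster than full BERT-Squad inference (1.118 s), and about 5× faster than end-to-end BERT-Squad including tokenization/featurization (1.583 s). These are preliminary numbers: the DistilBERT test runs on a fixed 64-token dummy input (see DistilBERTPerfTests.swift), so it is not a strictly like-for-like comparison against the featurized BERT input.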
@@ -0,0 +1,155 @@
import coremltools
import coremltools.models.datatypes as datatypes
import numpy as np
import torch
from coremltools.models import neural_network as neural_network
from coremltools.models.utils import save_spec

# get weights
from pytorch_transformers.modeling_distilbert import (DistilBertConfig,
                                                      DistilBertModel,
                                                      TransformerBlock)

model = DistilBertModel.from_pretrained('distilbert-base-uncased-distilled-squad')
config: DistilBertConfig = model.config

sequence_length = config.max_position_embeddings  # 512
steps = config.n_layers  # 6


# build model
input_features = [
    ('input_ids', datatypes.Array(sequence_length)),
]
output_features = [
    ('output_logits', None)
]

builder = neural_network.NeuralNetworkBuilder(
    input_features,
    output_features,
    mode=None,
    disable_rank5_shape_mapping=True,
)
builder.add_expand_dims(
    name='input_ids_expanded_to_rank5',
    input_name='input_ids',
    output_name='input_ids_expanded_to_rank5',
    axes=(1, 2, 3, 4)
)
builder.add_embedding(
    name='token_embeddings',
    input_name='input_ids_expanded_to_rank5',
    output_name='token_embeddings',
    W=model.embeddings.word_embeddings.weight.data.numpy().transpose(),  # shape (768, 30522)
    b=None,
    input_dim=config.vocab_size,
    output_channels=768,
    has_bias=False,
)
# Embedding LayerNorm, decomposed into mean/variance normalization (add_mvn)
# followed by the learned elementwise scale and bias (add_scale).
builder.add_mvn(
    name='embeddings_ln',
    input_name='token_embeddings',
    output_name='embeddings_ln',
    across_channels=True,
    normalize_variance=True,
    epsilon=model.embeddings.LayerNorm.eps
)
builder.add_scale(
    name='embeddings_ln_scaled',
    input_name='embeddings_ln',
    output_name='0_previous_block',
    # output_name='output_logits',
    W=model.embeddings.LayerNorm.weight.data.numpy().reshape((1, 1, 768, 1, 1)),
    b=model.embeddings.LayerNorm.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
    has_bias=True,
    shape_scale=[768],
    shape_bias=[768]
)

for i in range(steps):
    print(i)
    layer: TransformerBlock = model.transformer.layer[i]

    # MultiHeadSelfAttention
    ## wip: the attention sub-block is not built yet, so the
    ## f"{i}_block_xa_sum" input consumed below does not exist in the graph.
    # sa_layer_norm
    builder.add_mvn(
        name=f"{i}_block_ln_2",
        input_name=f"{i}_block_xa_sum",
        output_name=f"{i}_block_ln_2",
        across_channels=True,
        normalize_variance=True,
        epsilon=layer.sa_layer_norm.eps
    )
    builder.add_scale(
        name=f"{i}_block_ln_2_scaled",
        input_name=f"{i}_block_ln_2",
        output_name=f"{i}_block_ln_2_scaled",
        W=layer.sa_layer_norm.weight.data.numpy().reshape((1, 1, 768, 1, 1)),
        b=layer.sa_layer_norm.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
        has_bias=True,
        shape_scale=[768],
        shape_bias=[768]
    )

    # Feed Forward Network: Linear(768 -> 3072) -> GELU -> Linear(3072 -> 768)
    builder.add_inner_product(
        name=f"{i}_block_mlp_conv_fc",
        input_name=f"{i}_block_ln_2_scaled",
        output_name=f"{i}_block_mlp_conv_fc",
        input_channels=768,
        output_channels=3072,
        W=layer.ffn.lin1.weight.data.numpy().transpose().reshape((1, 768, 3072, 1, 1)),
        b=layer.ffn.lin1.bias.data.numpy().reshape((1, 1, 3072, 1, 1)),
        has_bias=True
    )
    builder.add_gelu(
        name=f"{i}_block_mlp_gelu",
        input_name=f"{i}_block_mlp_conv_fc",
        output_name=f"{i}_block_mlp_gelu",
        mode='TANH_APPROXIMATION'
    )
    builder.add_inner_product(
        name=f"{i}_block_mlp_conv_proj",
        input_name=f"{i}_block_mlp_gelu",
        output_name=f"{i}_block_mlp_conv_proj",
        input_channels=3072,
        output_channels=768,
        W=layer.ffn.lin2.weight.data.numpy().transpose().reshape((1, 3072, 768, 1, 1)),
        b=layer.ffn.lin2.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
        has_bias=True
    )

    # output_layer_norm
    # Input: (1, seq, 768, 1, 1), Output: same shape (LayerNorm is shape-preserving)
    builder.add_mvn(
        name=f"{i}_output_ln",
        input_name=f"{i}_block_mlp_conv_proj",
        output_name=f"{i}_output_ln",
        across_channels=True,
        normalize_variance=True,
        epsilon=layer.output_layer_norm.eps
    )
    builder.add_scale(
        name=f"{i}_output_ln_scaled",
        input_name=f"{i}_output_ln",
        output_name=f"{i}_output_ln_scaled",
        W=layer.output_layer_norm.weight.data.numpy().reshape((1, 1, 768, 1, 1)),
        b=layer.output_layer_norm.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
        has_bias=True,
        shape_scale=[768],
        shape_bias=[768]
    )


# compile spec to model
# mlmodel = coremltools.models.MLModel(builder.spec)
# input_data = {
#     'input_ids': np.array([ 7592, 1010, 2026, 3899, 2003, 10140 ])
# }
# predictions = mlmodel.predict(input_data)["output_logits"]
# print(predictions)

# Note: as written, no layer produces the declared 'output_logits' output yet
# (the attention path is still wip), so the saved spec is a work in progress.
save_spec(builder.spec, f'../Resources/distilbert-{sequence_length}-{steps}.mlmodel')
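A side note on the script above: the `add_mvn` + `add_scale` pair is the same LayerNorm decomposition repeated three times (embeddings, `sa_layer_norm`, `output_layer_norm`). A minimal sketch of a helper that factors that pattern out — the `add_layer_norm` name is illustrative, not part of this commit, and it only reuses the builder calls already shown above:

```python
# Hypothetical helper (not in this commit): LayerNorm as Core ML
# mean/variance normalization followed by the learned scale + bias.
def add_layer_norm(builder, prefix, input_name, ln, hidden=768):
    # Normalize across the channel axis, matching the PyTorch LayerNorm eps.
    builder.add_mvn(
        name=f"{prefix}_ln",
        input_name=input_name,
        output_name=f"{prefix}_ln",
        across_channels=True,
        normalize_variance=True,
        epsilon=ln.eps,
    )
    # Apply gamma (weight) and beta (bias) as a per-channel scale.
    builder.add_scale(
        name=f"{prefix}_ln_scaled",
        input_name=f"{prefix}_ln",
        output_name=f"{prefix}_ln_scaled",
        W=ln.weight.data.numpy().reshape((1, 1, hidden, 1, 1)),
        b=ln.bias.data.numpy().reshape((1, 1, hidden, 1, 1)),
        has_bias=True,
        shape_scale=[hidden],
        shape_bias=[hidden],
    )
    return f"{prefix}_ln_scaled"

# e.g. the embeddings LayerNorm above would become:
# out = add_layer_norm(builder, 'embeddings', 'token_embeddings',
#                      model.embeddings.LayerNorm)
```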
@@ -1,2 +1,2 @@
-coremltools==3.0b3
-pytorch-transformers==1.0.0
+coremltools==3.0b6
+pytorch-transformers==1.2.0
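A note on the bump: the conversion script above imports `pytorch_transformers.modeling_distilbert`, which shipped in pytorch-transformers 1.2.0; the move to a newer coremltools 3.0 beta presumably tracks the builder ops used there (e.g. `add_gelu`, `add_expand_dims`).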