Skip to content

Commit

Permalink
[distilbert] Preliminary perf benchmarks
Browse files Browse the repository at this point in the history
cc @VictorSanh

see also #5
  • Loading branch information
julien-c committed Sep 17, 2019
1 parent 02228fb commit f9ce549
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 3 deletions.
10 changes: 10 additions & 0 deletions CoreMLBert.xcodeproj/project.pbxproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
7934AF8822EA4CEC00396AD5 /* gpt2-512.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 7934AF8722EA4CEC00396AD5 /* gpt2-512.mlmodel */; };
7944CDBF22C501D6000C0B1E /* dev-v1.1.json in Resources */ = {isa = PBXBuildFile; fileRef = 7944CDBE22C501D6000C0B1E /* dev-v1.1.json */; };
7944CDC022C501D6000C0B1E /* dev-v1.1.json in Resources */ = {isa = PBXBuildFile; fileRef = 7944CDBE22C501D6000C0B1E /* dev-v1.1.json */; };
79458104233086EB00024429 /* distilbert-64-12.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 79458103233086EB00024429 /* distilbert-64-12.mlmodel */; };
79458105233086EB00024429 /* distilbert-64-12.mlmodel in Sources */ = {isa = PBXBuildFile; fileRef = 79458103233086EB00024429 /* distilbert-64-12.mlmodel */; };
796DF51022E0EB1D00140C02 /* GPT2Tokenizer.swift in Sources */ = {isa = PBXBuildFile; fileRef = 796DF50F22E0EB1D00140C02 /* GPT2Tokenizer.swift */; };
796DF51522E0EE7800140C02 /* gpt2-merges.txt in Resources */ = {isa = PBXBuildFile; fileRef = 796DF51322E0EE7800140C02 /* gpt2-merges.txt */; };
796DF51622E0EE7800140C02 /* gpt2-vocab.json in Resources */ = {isa = PBXBuildFile; fileRef = 796DF51422E0EE7800140C02 /* gpt2-vocab.json */; };
Expand All @@ -34,6 +36,7 @@
796DF57822E1047B00140C02 /* gpt2-vocab.json in Resources */ = {isa = PBXBuildFile; fileRef = 796DF51422E0EE7800140C02 /* gpt2-vocab.json */; };
796DF57F22E1209B00140C02 /* encoded_tokens.json in Resources */ = {isa = PBXBuildFile; fileRef = 796DF57D22E1209B00140C02 /* encoded_tokens.json */; };
796DF58722E2727000140C02 /* GPT2.swift in Sources */ = {isa = PBXBuildFile; fileRef = 796DF58622E2727000140C02 /* GPT2.swift */; };
79E5A6A223303A6500EC42C5 /* DistilBERTPerfTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79E5A6A123303A6500EC42C5 /* DistilBERTPerfTests.swift */; };
79F2CC5C22C50078009F8551 /* AppDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79F2CC5B22C50078009F8551 /* AppDelegate.swift */; };
79F2CC5E22C50078009F8551 /* SceneDelegate.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79F2CC5D22C50078009F8551 /* SceneDelegate.swift */; };
79F2CC6022C50078009F8551 /* ViewController.swift in Sources */ = {isa = PBXBuildFile; fileRef = 79F2CC5F22C50078009F8551 /* ViewController.swift */; };
Expand Down Expand Up @@ -87,6 +90,7 @@
791D169222EA4918004D7A79 /* MultiArrayUtilsTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MultiArrayUtilsTests.swift; sourceTree = "<group>"; };
7934AF8722EA4CEC00396AD5 /* gpt2-512.mlmodel */ = {isa = PBXFileReference; lastKnownFileType = file.mlmodel; path = "gpt2-512.mlmodel"; sourceTree = "<group>"; };
7944CDBE22C501D6000C0B1E /* dev-v1.1.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = "dev-v1.1.json"; sourceTree = "<group>"; };
79458103233086EB00024429 /* distilbert-64-12.mlmodel */ = {isa = PBXFileReference; lastKnownFileType = file.mlmodel; path = "distilbert-64-12.mlmodel"; sourceTree = "<group>"; };
796DF50F22E0EB1D00140C02 /* GPT2Tokenizer.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GPT2Tokenizer.swift; sourceTree = "<group>"; };
796DF51322E0EE7800140C02 /* gpt2-merges.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = "gpt2-merges.txt"; sourceTree = "<group>"; };
796DF51422E0EE7800140C02 /* gpt2-vocab.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = "gpt2-vocab.json"; sourceTree = "<group>"; };
Expand All @@ -104,6 +108,7 @@
796DF56822E1026800140C02 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
796DF57D22E1209B00140C02 /* encoded_tokens.json */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.json; path = encoded_tokens.json; sourceTree = "<group>"; };
796DF58622E2727000140C02 /* GPT2.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = GPT2.swift; sourceTree = "<group>"; };
79E5A6A123303A6500EC42C5 /* DistilBERTPerfTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = DistilBERTPerfTests.swift; sourceTree = "<group>"; };
79F2CC5822C50078009F8551 /* CoreMLBert.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = CoreMLBert.app; sourceTree = BUILT_PRODUCTS_DIR; };
79F2CC5B22C50078009F8551 /* AppDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AppDelegate.swift; sourceTree = "<group>"; };
79F2CC5D22C50078009F8551 /* SceneDelegate.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = SceneDelegate.swift; sourceTree = "<group>"; };
Expand Down Expand Up @@ -172,6 +177,7 @@
79F2CC8D22C55413009F8551 /* tokenized_questions.json */,
79F2CC9F22C666C7009F8551 /* question_tokens.json */,
79F2CC9022C5590C009F8551 /* BERTSQUADFP16.mlmodel */,
79458103233086EB00024429 /* distilbert-64-12.mlmodel */,
796DF51322E0EE7800140C02 /* gpt2-merges.txt */,
796DF51422E0EE7800140C02 /* gpt2-vocab.json */,
796DF57D22E1209B00140C02 /* encoded_tokens.json */,
Expand Down Expand Up @@ -248,6 +254,7 @@
children = (
79F2CC7222C5007B009F8551 /* BertTokenizerTests.swift */,
79F2CC9D22C57825009F8551 /* BertForQATests.swift */,
79E5A6A123303A6500EC42C5 /* DistilBERTPerfTests.swift */,
79F2CC7422C5007B009F8551 /* Info.plist */,
);
path = CoreMLBertTests;
Expand Down Expand Up @@ -481,6 +488,7 @@
796DF51A22E0FF7A00140C02 /* GPT2ByteEncoder.swift in Sources */,
79F2CC9722C56891009F8551 /* Math.swift in Sources */,
79F2CC5C22C50078009F8551 /* AppDelegate.swift in Sources */,
79458104233086EB00024429 /* distilbert-64-12.mlmodel in Sources */,
796DF51022E0EB1D00140C02 /* GPT2Tokenizer.swift in Sources */,
79F2CC9422C56693009F8551 /* BertForQuestionAnswering.swift in Sources */,
79F2CC9122C5590C009F8551 /* BERTSQUADFP16.mlmodel in Sources */,
Expand All @@ -497,8 +505,10 @@
79F2CC9E22C57825009F8551 /* BertForQATests.swift in Sources */,
79F2CC7322C5007B009F8551 /* BertTokenizerTests.swift in Sources */,
79F2CC9C22C57731009F8551 /* BertTokenizer.swift in Sources */,
79458105233086EB00024429 /* distilbert-64-12.mlmodel in Sources */,
79F2CC8C22C549C1009F8551 /* Utils.swift in Sources */,
79F2CC9522C56693009F8551 /* BertForQuestionAnswering.swift in Sources */,
79E5A6A223303A6500EC42C5 /* DistilBERTPerfTests.swift in Sources */,
79F7060E22EA0CA900C4432C /* BERTSQUADFP16.mlmodel in Sources */,
79F2CC9B22C57132009F8551 /* MLMultiArray+Utils.swift in Sources */,
79F2CC9822C56891009F8551 /* Math.swift in Sources */,
Expand Down
37 changes: 37 additions & 0 deletions CoreMLBertTests/DistilBERTPerfTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//
// DistilBERTPerfTests.swift
// CoreMLBertTests
//
// Created by Julien Chaumond on 16/09/2019.
// Copyright © 2019 Hugging Face. All rights reserved.
//

import XCTest
import CoreML
@testable import CoreMLBert

class DistilBERTPerfTests: XCTestCase {
let context = "Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50."

let question = "Which NFL team represented the AFC at Super Bowl 50?"
let m = BertForQuestionAnswering()
let mDistilbert = distilbert_64_12()

func testPerformanceNakedBERTModel() {
let input = m.featurizeTokens(question: question, context: context)

self.measure {
_ = try! m.model.prediction(input: input)
}
}

func testPerformanceDistilBERTModel() {
let input_ids = MLMultiArray.from(Array(repeating: 0, count: 64))

self.measure {
_ = try! mDistilbert.prediction(input_ids: input_ids)
/// print(output.output_logits)
}
}
}

3 changes: 3 additions & 0 deletions Resources/distilbert-512-6.mlmodel
Git LFS file not shown
3 changes: 3 additions & 0 deletions Resources/distilbert-64-12.mlmodel
Git LFS file not shown
2 changes: 1 addition & 1 deletion Sources/BertForQuestionAnswering.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Foundation
import CoreML

class BertForQuestionAnswering {
private let model = BERTSQUADFP16()
internal let model = BERTSQUADFP16()
private let tokenizer = BertTokenizer()
public let seqLen = 384

Expand Down
40 changes: 40 additions & 0 deletions model_generation/distilbert-performance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# DistilBERT performance benchmarks

#### Full BERT-Squad with tokenization/featurization.

```
~/swift-coreml-transformers/CoreMLBertTests/BertForQATests.swift:75:
Test Case '-[CoreMLBertTests.BertForQATests testPerformanceExample]' measured [Time, seconds]
average: 1.583, relative standard deviation: 5.232%, values: [1.746976, 1.550390, 1.549479, 1.529654, 1.528065, 1.508825, 1.534357, 1.702786, 1.514816, 1.665434], performanceMetricID:com.apple.XCTPerformanceMetric_WallClockTime, baselineName: "", baselineAverage: , maxPercentRegression: 10.000%, maxPercentRelativeStandardDeviation: 10.000%, maxRegression: 0.100, maxStandardDeviation: 0.100
Test Case '-[CoreMLBertTests.BertForQATests testPerformanceExample]' passed (16.130 seconds).
```

---

#### Full BERT-Squad, only the inference.

```
~/swift-coreml-transformers/CoreMLBertTests/DistilBERTPerfTests.swift:23:
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceNakedModel]' measured [Time, seconds]
average: 1.118, relative standard deviation: 5.919%, values: [1.195310, 1.068182, 1.131890, 1.251984, 1.095551, 1.186633, 1.060465, 1.072363, 1.061609, 1.059508], performanceMetricID:com.apple.XCTPerformanceMetric_WallClockTime, baselineName: "", baselineAverage: , maxPercentRegression: 10.000%, maxPercentRelativeStandardDeviation: 10.000%, maxRegression: 0.100, maxStandardDeviation: 0.100
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceNakedModel]' passed (11.822 seconds).
```

---

#### DistilBERT, only the inference.

```
~/swift-coreml-transformers/CoreMLBertTests/DistilBERTPerfTests.swift:32:
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceDistilBERTModel]' measured [Time, seconds]
average: 0.319, relative standard deviation: 0.548%, values: [0.321627, 0.321694, 0.317964, 0.316413, 0.318463, 0.319897, 0.319386, 0.317997, 0.318780, 0.321835], performanceMetricID:com.apple.XCTPerformanceMetric_WallClockTime, baselineName: "", baselineAverage: , maxPercentRegression: 10.000%, maxPercentRelativeStandardDeviation: 10.000%, maxRegression: 0.100, maxStandardDeviation: 0.100
Test Case '-[CoreMLBertTests.DistilBERTPerfTests testPerformanceDistilBERTModel]' passed (3.466 seconds).
```

---
155 changes: 155 additions & 0 deletions model_generation/distilbert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import coremltools
import coremltools.models.datatypes as datatypes
import numpy as np
import torch
from coremltools.models import neural_network as neural_network
from coremltools.models.utils import save_spec
# get weights
from pytorch_transformers.modeling_distilbert import (DistilBertConfig,
DistilBertModel,
TransformerBlock)

model = DistilBertModel.from_pretrained('distilbert-base-uncased-distilled-squad')
config: DistilBertConfig = model.config

sequence_length = config.max_position_embeddings # 512
steps = config.n_layers # 6


# build model
input_features = [
('input_ids', datatypes.Array(sequence_length)),
]
output_features = [
('output_logits', None)
]

builder = neural_network.NeuralNetworkBuilder(
input_features,
output_features,
mode=None,
disable_rank5_shape_mapping=True,
)
builder.add_expand_dims(
name='input_ids_expanded_to_rank5',
input_name='input_ids',
output_name='input_ids_expanded_to_rank5',
axes=(1, 2, 3, 4)
)
builder.add_embedding(
name='token_embeddings',
input_name='input_ids_expanded_to_rank5',
output_name='token_embeddings',
W=model.embeddings.word_embeddings.weight.data.numpy().transpose(), # shape (768, 30522)
b=None,
input_dim=config.vocab_size,
output_channels=768,
has_bias=False,
)
builder.add_mvn(
name='embeddings_ln',
input_name=f"token_embeddings",
output_name=f"embeddings_ln",
across_channels=True,
normalize_variance=True,
epsilon=model.embeddings.LayerNorm.eps
)
builder.add_scale(
name=f"embeddings_ln_scaled",
input_name=f"embeddings_ln",
output_name=f'{0}_previous_block',
# output_name=f'output_logits',
W=model.embeddings.LayerNorm.weight.data.numpy().reshape((1, 1, 768, 1, 1)),
b=model.embeddings.LayerNorm.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
has_bias=True,
shape_scale=[768],
shape_bias=[768]
)

for i in range(steps):
print(i)
layer: TransformerBlock = model.transformer.layer[i]

# MultiHeadSelfAttention
## wip
# sa_layer_norm
builder.add_mvn(
name=f"{i}_block_ln_2",
input_name=f"{i}_block_xa_sum",
output_name=f"{i}_block_ln_2",
across_channels=True,
normalize_variance=True,
epsilon=layer.sa_layer_norm.eps
)
builder.add_scale(
name=f"{i}_block_ln_2_scaled",
input_name=f"{i}_block_ln_2",
output_name=f"{i}_block_ln_2_scaled",
W=layer.sa_layer_norm.weight.data.numpy().reshape((1, 1, 768, 1, 1)),
b=layer.sa_layer_norm.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
has_bias=True,
shape_scale=[768],
shape_bias=[768]
)

# Feed Forward Network
builder.add_inner_product(
name=f"{i}_block_mlp_conv_fc",
input_name=f"{i}_block_ln_2_scaled",
output_name=f"{i}_block_mlp_conv_fc",
input_channels=768,
output_channels=3072,
W=layer.ffn.lin1.weight.data.numpy().transpose().reshape((1, 768, 3072, 1, 1)),
b=layer.ffn.lin1.bias.data.numpy().reshape((1, 1, 3072, 1, 1)),
has_bias=True
)
builder.add_gelu(
name=f"{i}_block_mlp_gelu",
input_name=f"{i}_block_mlp_conv_fc",
output_name=f"{i}_block_mlp_gelu",
mode='TANH_APPROXIMATION'
)
builder.add_inner_product(
name=f"{i}_block_mlp_conv_proj",
input_name=f"{i}_block_mlp_gelu",
output_name=f"{i}_block_mlp_conv_proj",
input_channels=3072,
output_channels=768,
W=layer.ffn.lin2.weight.data.numpy().transpose().reshape((1, 3072, 768, 1, 1)),
b=layer.ffn.lin2.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
has_bias=True
)

# output_layer_norm
# Input: (1, seq, 768, 1, 1), Output:
builder.add_mvn(
name=f"{i}_output_ln",
input_name=f"{i}_block_mlp_conv_proj",
output_name=f"{i}_output_ln",
across_channels=True,
normalize_variance=True,
epsilon=layer.output_layer_norm.eps
)
builder.add_scale(
name=f"{i}_output_ln_scaled",
input_name=f"{i}_output_ln",
output_name=f"{i}_output_ln_scaled",
W=layer.output_layer_norm.weight.data.numpy().reshape((1, 1, 768, 1, 1)),
b=layer.output_layer_norm.bias.data.numpy().reshape((1, 1, 768, 1, 1)),
has_bias=True,
shape_scale=[768],
shape_bias=[768]
)



# compile spec to model
# mlmodel = coremltools.models.MLModel(builder.spec)
# input_data = {
# 'input_ids': np.array([ 7592, 1010, 2026, 3899, 2003, 10140 ])
# }
# predictions = mlmodel.predict(input_data)["output_logits"]
# print(predictions)


save_spec(builder.spec, f'../Resources/distilbert-{sequence_length}-{steps}.mlmodel')
4 changes: 2 additions & 2 deletions model_generation/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
coremltools==3.0b3
pytorch-transformers==1.0.0
coremltools==3.0b6
pytorch-transformers==1.2.0

0 comments on commit f9ce549

Please sign in to comment.