Keep FFN output layer in float32 for T5 models #1239

Draft · wants to merge 6 commits into base: master
3 changes: 3 additions & 0 deletions include/ctranslate2/models/model.h
@@ -135,6 +135,9 @@ namespace ctranslate2 {
     // Returns true if the variable can be converted to another type.
     virtual bool is_convertible(const StorageView& variable, const std::string& name) const;
 
+    // Returns true if the variable should be kept in float32 precision.
+    virtual bool keep_in_float32(const std::string& variable_name) const;
+
     // Models can override these methods to execute some transformations if needed
     // (e.g. a variable name changed in a newer spec revision).
     virtual void register_variable(std::string name, StorageView variable);
2 changes: 2 additions & 0 deletions python/ctranslate2/converters/transformers.py
@@ -1002,6 +1002,8 @@ def set_ffn(self, spec, module):
         self.set_linear(spec.linear_0, module.DenseReluDense.wi)
 
         self.set_linear(spec.linear_1, module.DenseReluDense.wo)
+        spec.linear_1.keep_in_float32 = True
+
         self.set_layer_norm(spec.layer_norm, module.layer_norm)
 
     def set_self_attention(self, spec, module):
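For context, this is the layer behind the well-known T5 float16 overflow: T5 checkpoints are trained in bfloat16, and the activations flowing into `DenseReluDense.wo` can produce products beyond float16's maximum of roughly 65504, yielding inf/NaN outputs. A minimal numpy sketch with made-up magnitudes:

```python
import numpy as np

# Made-up magnitudes, for illustration only.
hidden = np.array([240.0, -310.0, 512.0], dtype=np.float16)
wo_col = np.array([150.0, -200.0, 180.0], dtype=np.float16)

# 240*150 + (-310)*(-200) + 512*180 = 190160, which exceeds float16's ~65504.
print(np.dot(hidden, wo_col))                                        # inf
print(np.dot(hidden.astype(np.float32), wo_col.astype(np.float32)))  # 190160.0
```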
1 change: 1 addition & 0 deletions python/ctranslate2/specs/common_spec.py
@@ -35,6 +35,7 @@ def __init__(self):
         self.weight = None
         self.weight_scale = model_spec.OPTIONAL
         self.bias = model_spec.OPTIONAL
+        self.keep_in_float32 = False
 
     def has_bias(self):
         return not isinstance(self.bias, str)
6 changes: 5 additions & 1 deletion python/ctranslate2/specs/model_spec.py
@@ -183,7 +183,11 @@ def _quantize(spec, name, value):
         is_quantizable = hasattr(spec, "%s_scale" % key)
         is_convertible = value.dtype in ("float32", "float16", "bfloat16")
 
-        if is_quantizable:
+        if hasattr(spec, "keep_in_float32") and spec.keep_in_float32.numpy():
+            if is_convertible:
+                value = value.to("float32")
+
+        elif is_quantizable:
             if quantization == "int16":
                 value = value.to("float32").numpy()
                 # Represent the value with 10 bits so the multiplication is 20 bits
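The new branch gives the flag precedence over quantization: a flagged variable is upcast to float32 and never enters the int16/int8/float16 paths below it. A simplified numpy sketch of that precedence (the real code operates on spec variables, and the int8 scale shown is only loosely modeled on CTranslate2's per-row scheme):

```python
import numpy as np

def convert_variable(value, keep_in_float32, quantization):
    # Flagged variables are upcast to float32 and skip quantization entirely.
    if keep_in_float32:
        return value.astype(np.float32)
    if quantization == "int8":
        # Per-row scales, loosely modeled on CTranslate2's int8 scheme.
        scale = 127.0 / np.amax(np.abs(value), axis=1, keepdims=True)
        return (value * scale).round().astype(np.int8)
    return value

wo = np.random.randn(4, 4).astype(np.float16)
print(convert_variable(wo, True, "int8").dtype)   # float32
print(convert_variable(wo, False, "int8").dtype)  # int8
```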
4 changes: 4 additions & 0 deletions src/layers/common.cc
@@ -353,6 +353,10 @@ namespace ctranslate2 {
                  /*trans_b=*/true,
                  output,
                  bias);
+      } else if (input.dtype() != weight->dtype()) {
+        StorageView tmp_output(weight->dtype(), weight->device());
+        _gemm_op(input.to(weight->dtype()), *weight, tmp_output, nullptr, bias);
+        output = tmp_output.to(output.dtype());
       } else {
         _gemm_op(input, *weight, output, nullptr, bias);
       }
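This `else if` covers the runtime mismatch the flag creates: with the model running in float16 but `wo` kept in float32, the input is upcast to the weight's dtype, the GEMM runs in float32, and only the finished result is cast back. A numpy sketch of the same dataflow (`linear_mixed_dtype` is a hypothetical helper, not the CTranslate2 API):

```python
import numpy as np

def linear_mixed_dtype(x, weight, bias=None):
    # Upcast the activations to the weight's dtype, accumulate the GEMM in
    # float32, then cast the finished result back to the compute type.
    tmp = x.astype(weight.dtype) @ weight.T
    if bias is not None:
        tmp += bias
    return tmp.astype(x.dtype)

x = np.random.randn(2, 8).astype(np.float16)   # float16 activations
w = np.random.randn(16, 8).astype(np.float32)  # a weight kept in float32
y = linear_mixed_dtype(x, w)
print(y.shape, y.dtype)  # (2, 16) float16
```

Only the accumulation happens in float32; the cast back to float16 mirrors `tmp_output.to(output.dtype())` above.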
14 changes: 14 additions & 0 deletions src/models/model.cc
@@ -174,6 +174,9 @@ namespace ctranslate2 {
       const auto& name = variable_pair.first;
       auto& variable = *variable_pair.second;
 
+      if (keep_in_float32(name))
+        continue;
+
       // Convert "weight" variables to the expected compute type.
       // Other float variables (e.g. biases) may be converted to another float type.
       if (is_quantizable(name))
@@ -253,6 +256,15 @@ namespace ctranslate2 {
       return !variable.is_scalar() && name.find("_scale") == std::string::npos;
     }
 
+    bool Model::keep_in_float32(const std::string& variable_name) const {
+      const size_t pos = variable_name.rfind('/');
+      if (pos == std::string::npos)
+        return false;
+
+      const std::string scope = variable_name.substr(0, pos);
+      return get_flag_with_default(scope + "/keep_in_float32", false);
+    }
+
     void Model::ensure_dtype(const std::string& name,
                              StorageView& variable,
                              const DataType target_dtype) {
@@ -327,6 +339,8 @@ namespace ctranslate2 {
     for (const auto& variable_pair : _variable_index) {
       const std::string& name = variable_pair.first;
       const StorageView& variable = *variable_pair.second;
+      if (keep_in_float32(name))
+        continue;
       if (is_quantizable(name)) {
         weight_type = variable.dtype();
       } else if (is_convertible(variable, name)) {
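To tie the two sides together: the converter's `spec.linear_1.keep_in_float32 = True` is serialized as a flag variable under the layer's scope, and `Model::keep_in_float32` recovers it by stripping the last path component of each variable name. A Python sketch of that lookup, using variable names that are only illustrative of CTranslate2's `scope/name` convention:

```python
def keep_in_float32(variable_name, flags):
    # Strip the last path component ("weight", "bias", ...) to get the scope,
    # then check for a "keep_in_float32" flag registered under that scope.
    pos = variable_name.rfind("/")
    if pos == -1:
        return False
    return flags.get(variable_name[:pos] + "/keep_in_float32", False)

flags = {"decoder/layer_0/ffn/linear_1/keep_in_float32": True}
print(keep_in_float32("decoder/layer_0/ffn/linear_1/weight", flags))  # True
print(keep_in_float32("decoder/layer_0/ffn/linear_0/weight", flags))  # False
```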