-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add python script to convert quality json model to binary and vice versa
- Loading branch information
1 parent
962d419
commit cf4a3a2
Showing
2 changed files
with
118 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Quality Model Tool | ||
|
||
- The python script ```qualityestimator_json_to_bin.py``` converts a logistic regressor quality estimator model from json to binary file and vice versa. | ||
|
||
- To convert a JSON file to binary: | ||
|
||
```console | ||
python qualityestimator_json_to_bin.py --from_json qe_model.json --out qe_model.bin | ||
``` | ||
|
||
- To convert a binary file to JSON: | ||
|
||
```console | ||
python qualityestimator_json_to_bin.py --to_json qe_model.bin --out qe_model.json | ||
``` | ||
|
||
- The json must follow this structure: | ||
```json | ||
{ | ||
"stds": [ 0.0, 0.0, 0.0, 0.0 ], | ||
"means": [ 0.0, 0.0, 0.0, 0.0 ], | ||
"coefficients": [ 0.0, 0.0, 0.0, 0.0 ], | ||
"intercept": 0.0 | ||
} | ||
``` | ||
|
||
- The binary file follows the structure defined in [LogisticRegressorQualityEstimator](https://github.com/browsermt/bergamot-translator/blob/main/src/translator/quality_estimator.h#L100-L108). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import argparse
import json
import struct
import sys
from collections import namedtuple
|
||
# Binary header layout, little-endian with standard sizes:
# magic(uint64_t), lrParametersDims(uint64_t)
Header_fmt = "<1Q1Q"
Header_len = struct.Struct(Header_fmt).size

# Magic number written at the start of the binary file and checked on read
# (8704388732126802304 in decimal).
QE_MAGIC_NUMBER = 0x78CC336F1D54B180
|
||
|
||
def from_qe_file(file):
    """Read logistic regressor parameters from a binary quality estimator file.

    Args:
        file: Binary file object positioned at the start of the header.

    Returns:
        dict: The keys ``"stds"``, ``"means"`` and ``"coefficients"`` (lists
        of floats, one entry per parameter dimension) and ``"intercept"``
        (a single float).

    Raises:
        SystemExit: With exit code 1 if the header or the parameter payload
            is truncated, or if the magic number does not match.
    """
    header = file.read(Header_len)

    # A short read would otherwise surface as an uncaught struct.error.
    if len(header) != Header_len:
        print("Invalid quality estimator file.")
        raise SystemExit(1)

    magic, paramDim = struct.unpack(Header_fmt, header)

    if magic != QE_MAGIC_NUMBER:
        print("Invalid quality estimator file.")
        raise SystemExit(1)

    # stds[N] + means[N] + coefficients[N] + intercept
    lrParams_fmt = f"<{3*paramDim+1}f"
    lrParams_size = struct.calcsize(lrParams_fmt)

    payload = file.read(lrParams_size)

    # The header promised more floats than the file actually contains.
    if len(payload) != lrParams_size:
        print("Invalid quality estimator file.")
        raise SystemExit(1)

    params = list(struct.unpack(lrParams_fmt, payload))

    return {
        "stds": params[:paramDim],
        "means": params[paramDim : 2 * paramDim],
        "coefficients": params[2 * paramDim : 3 * paramDim],
        "intercept": params[3 * paramDim],
    }
|
||
|
||
def to_binary(lrParams):
    """Serialize logistic regressor parameters into the binary file format.

    Args:
        lrParams: Mapping with the sequence-valued keys ``"stds"``,
            ``"means"`` and ``"coefficients"`` (all of the same length) and
            the scalar key ``"intercept"``.

    Returns:
        bytes: The header (magic number and parameter dimension) followed by
        stds, means, coefficients and the intercept packed as little-endian
        32-bit floats.

    Raises:
        SystemExit: With exit code 1 if the three parameter sequences do not
            all have the same length.
    """
    stds = lrParams["stds"]
    means = lrParams["means"]
    coefficients = lrParams["coefficients"]

    paramDims = len(stds)

    # Any single mismatch invalidates the payload, so the checks are joined
    # with `or`; with `and` a mismatch in only one list slipped through.
    if paramDims != len(means) or paramDims != len(coefficients):
        print("Invalid LR parameters.")
        raise SystemExit(1)

    # stds[N] + means[N] + coefficients[N] + intercept
    lrParams_fmt = f"<{3*paramDims+1}f"

    params = [*stds, *means, *coefficients, lrParams["intercept"]]

    return struct.pack(Header_fmt, QE_MAGIC_NUMBER, paramDims) + struct.pack(
        lrParams_fmt, *params
    )
|
||
|
||
# Command-line entry point: convert a binary quality estimator file to JSON
# (--to_json) or a JSON file to the binary format (--from_json).
parser = argparse.ArgumentParser(description="Read and write quality estimator files.")
parser.add_argument(
    "--to_json", type=argparse.FileType("rb"), help="Read quality estimator file"
)
parser.add_argument(
    "--from_json",
    type=argparse.FileType("r"),
    help="Read json file and generate quality estimator binary",
)
parser.add_argument(
    "--out",
    type=argparse.FileType("wb"),
    help="Output generated data from to_json or from_json option",
)

args = parser.parse_args()

output = None

# --to_json takes precedence when both options are given. The input handles
# are closed as soon as they have been consumed.
if args.to_json:
    with args.to_json as qe_file:
        output = json.dumps(from_qe_file(qe_file), indent=3)
elif args.from_json:
    with args.from_json as json_file:
        output = to_binary(json.load(json_file))

if output is None:
    # Neither conversion was requested, so there is nothing to write.
    raise SystemExit(0)

if args.out:
    # --out is opened in binary mode, so JSON text must be encoded first.
    with args.out as out_file:
        out_file.write(output.encode("UTF-8") if isinstance(output, str) else output)
elif isinstance(output, str):
    print(output)
else:
    # Emit raw bytes; print() would write the Python repr (b'...') instead.
    sys.stdout.buffer.write(output)
    sys.stdout.buffer.flush()