-
Notifications
You must be signed in to change notification settings - Fork 7
/
config_data.py
73 lines (64 loc) · 2.31 KB
/
config_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Taken from https://github.com/asyml/texar/blob/master/examples/bert/config_data.py and modified
# ---------------------------------------------------------------------------
# Scalar settings
# ---------------------------------------------------------------------------
tfrecord_data_dir = "data"  # directory holding the {train,eval,test}.tf_record files
num_train_data = 1731  # presumably the number of training examples -- verify against the consumer

max_seq_length = 96  # fixed length of every feature declared below
max_decoding_length = 96

max_train_epoch = 250
display_steps = 10  # Print the training loss every `display_steps`; -1 disables it
eval_steps = 75  # Evaluate on the dev set every `eval_steps`; -1 disables it

train_batch_size = 28
eval_batch_size = 28
test_batch_size = 28

# ---------------------------------------------------------------------------
# TFRecord feature schema
# ---------------------------------------------------------------------------
# Every feature stored in the TFRecord files, in declaration order.
_FEATURE_NAMES = (
    "src_input_ids",
    "src_segment_ids",
    "tgt_input_ids",
    "tgt_labels",
)

# How each feature is read from the TFRecord file: as dtype `tf.int64`, as a
# "FixedLenFeature" (same length for every data instance), with that length
# set to `max_seq_length`. Each feature gets its own list object.
feature_original_types = {
    feature: ["tf.int64", "FixedLenFeature", max_seq_length]
    for feature in _FEATURE_NAMES
}

# Dtype each feature is converted to after reading (`tf.int64` -> `tf.int32`).
feature_convert_types = {feature: "tf.int32" for feature in _FEATURE_NAMES}


# ---------------------------------------------------------------------------
# Per-split data hyperparameters
# ---------------------------------------------------------------------------
def _dataset_hparams(path_template):
    """Build the ``"dataset"`` sub-config for a single TFRecord file.

    Args:
        path_template: Format string with one ``{}`` placeholder, which is
            filled in with ``tfrecord_data_dir``.

    Returns:
        A new dict whose feature-type entries reference the module-level
        ``feature_convert_types`` / ``feature_original_types`` objects
        (they are shared between splits, not copied).
    """
    return {
        "data_name": "data",
        "feature_convert_types": feature_convert_types,
        "feature_original_types": feature_original_types,
        "files": path_template.format(tfrecord_data_dir),
    }


# Training split: shuffled, `allow_smaller_final_batch` turned off.
train_hparam = {
    "allow_smaller_final_batch": False,
    "batch_size": train_batch_size,
    "dataset": _dataset_hparams("{}/train.tf_record"),
    "shuffle": True,
    "shuffle_buffer_size": 100,
}

# Evaluation split: not shuffled, `allow_smaller_final_batch` turned on.
eval_hparam = {
    "allow_smaller_final_batch": True,
    "batch_size": eval_batch_size,
    "dataset": _dataset_hparams("{}/eval.tf_record"),
    "shuffle": False,
}

# Test split: same layout as the evaluation split, reading the test file.
test_hparam = {
    "allow_smaller_final_batch": True,
    "batch_size": test_batch_size,
    "dataset": _dataset_hparams("{}/test.tf_record"),
    "shuffle": False,
}