From 5c88f1e74246a40f69f720f4551151712158768b Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Mon, 24 May 2021 15:51:15 +0800 Subject: [PATCH 1/9] script for hybrid embedding --- .../wdl_train_eval_with_hybrid_embd.py | 353 ++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py diff --git a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py new file mode 100644 index 0000000..c2e451d --- /dev/null +++ b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py @@ -0,0 +1,353 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import argparse +import oneflow as flow +import datetime +import os +import glob +from sklearn.metrics import roc_auc_score +import numpy as np +import time + +def str_list(x): + return x.split(',') +parser = argparse.ArgumentParser() +parser.add_argument('--dataset_format', type=str, default='ofrecord', help='ofrecord or onerec') +parser.add_argument('--train_data_dir', type=str, default='') +parser.add_argument('--train_data_part_num', type=int, default=1) +parser.add_argument('--train_part_name_suffix_length', type=int, default=-1) +parser.add_argument('--eval_data_dir', type=str, default='') +parser.add_argument('--eval_data_part_num', type=int, default=1) +parser.add_argument('--eval_part_name_suffix_length', type=int, default=-1) +parser.add_argument('--eval_batchs', type=int, default=20) +parser.add_argument('--eval_interval', type=int, default=1000) +parser.add_argument('--batch_size', type=int, default=16384) +parser.add_argument('--learning_rate', type=float, default=1e-3) +parser.add_argument('--wide_vocab_size', type=int, default=3200000) +parser.add_argument('--deep_vocab_size', type=int, default=3200000) +parser.add_argument('--hf_wide_vocab_size', type=int, default=1600000) +parser.add_argument('--hf_deep_vocab_size', type=int, default=1600000) +parser.add_argument('--deep_embedding_vec_size', type=int, default=16) +parser.add_argument('--deep_dropout_rate', type=float, default=0.5) +parser.add_argument('--num_dense_fields', type=int, default=13) +parser.add_argument('--num_wide_sparse_fields', type=int, default=2) +parser.add_argument('--num_deep_sparse_fields', type=int, default=26) +parser.add_argument('--max_iter', type=int, default=30000) +parser.add_argument('--loss_print_every_n_iter', type=int, default=100) +parser.add_argument('--gpu_num_per_node', type=int, default=8) +parser.add_argument('--num_nodes', type=int, default=1, + help='node/machine number for training') +parser.add_argument('--node_ips', type=str_list, default=['192.168.1.13', '192.168.1.14'], + help='nodes ip list for training, devided by ",", length >= num_nodes') +parser.add_argument("--ctrl_port", type=int, default=50051, help='ctrl_port for multinode job') +parser.add_argument('--hidden_units_num', type=int, default=7) 
+parser.add_argument('--hidden_size', type=int, default=1024) + +FLAGS = parser.parse_args() + +#DEEP_HIDDEN_UNITS = [1024, 1024]#, 1024, 1024, 1024, 1024, 1024] +DEEP_HIDDEN_UNITS = [FLAGS.hidden_size for i in range(FLAGS.hidden_units_num)] + +def _data_loader(data_dir, data_part_num, batch_size, part_name_suffix_length=-1, shuffle=True): + if FLAGS.dataset_format == 'ofrecord': + return _data_loader_ofrecord(data_dir, data_part_num, batch_size, part_name_suffix_length, + shuffle) + elif FLAGS.dataset_format == 'onerec': + return _data_loader_onerec(data_dir, batch_size, shuffle) + elif FLAGS.dataset_format == 'synthetic': + return _data_loader_synthetic(batch_size) + else: + assert 0, "Please specify dataset_type as `ofrecord`, `onerec` or `synthetic`." + + + +def _data_loader_ofrecord(data_dir, data_part_num, batch_size, part_name_suffix_length=-1, + shuffle=True): + assert data_dir + print('load ofrecord data form', data_dir) + ofrecord = flow.data.ofrecord_reader(data_dir, + batch_size=batch_size, + data_part_num=data_part_num, + part_name_suffix_length=part_name_suffix_length, + random_shuffle=shuffle, + shuffle_after_epoch=shuffle) + def _blob_decoder(bn, shape, dtype=flow.int32): + return flow.data.OFRecordRawDecoder(ofrecord, bn, shape=shape, dtype=dtype) + labels = _blob_decoder("labels", (1,)) + dense_fields = _blob_decoder("dense_fields", (FLAGS.num_dense_fields,), flow.float) + wide_sparse_fields = _blob_decoder("wide_sparse_fields", (FLAGS.num_wide_sparse_fields,)) + deep_sparse_fields = _blob_decoder("deep_sparse_fields", (FLAGS.num_deep_sparse_fields,)) + return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields]) + + +def _data_loader_synthetic(batch_size): + devices = ['{}:0-{}'.format(i, FLAGS.gpu_num_per_node - 1) for i in range(FLAGS.num_nodes)] + with flow.scope.placement("cpu", devices): + def _blob_random(shape, dtype=flow.int32, initializer=flow.zeros_initializer(flow.int32)): + return flow.data.decode_random(shape=shape, dtype=dtype, batch_size=batch_size, + initializer=initializer) + labels = _blob_random((1,), initializer=flow.random_uniform_initializer(dtype=flow.int32)) + dense_fields = _blob_random((FLAGS.num_dense_fields,), dtype=flow.float, + initializer=flow.random_uniform_initializer()) + wide_sparse_fields = _blob_random((FLAGS.num_wide_sparse_fields,)) + deep_sparse_fields = _blob_random((FLAGS.num_deep_sparse_fields,)) + print('use synthetic data') + return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields]) + + +def _data_loader_onerec(data_dir, batch_size, shuffle): + assert data_dir + print('load onerec data form', data_dir) + files = glob.glob(os.path.join(data_dir, '*.onerec')) + readdata = flow.data.onerec_reader(files=files, batch_size=batch_size, random_shuffle=shuffle, + verify_example=False, + shuffle_buffer_size=64, + shuffle_after_epoch=shuffle) + + def _blob_decoder(bn, shape, dtype=flow.int32): + return flow.data.onerec_decoder(readdata, key=bn, shape=shape, dtype=dtype) + + labels = _blob_decoder('labels', shape=(1,)) + dense_fields = _blob_decoder("dense_fields", (FLAGS.num_dense_fields,), flow.float) + wide_sparse_fields = _blob_decoder("wide_sparse_fields", (FLAGS.num_wide_sparse_fields,)) + deep_sparse_fields = _blob_decoder("deep_sparse_fields", (FLAGS.num_deep_sparse_fields,)) + return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields]) + + +def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size): + b, s = ids.shape + ids = 
flow.flatten(ids) + unique_ids, unique_ids_idx, _, _ = flow.experimental.unique_with_counts(ids) + hf_vocab_size_constant = flow.constant(hf_vocab_size, dtype=flow.int32) + hf_indices = flow.argwhere(flow.math.less(unique_ids, hf_vocab_size_constant)) + lf_indices = flow.argwhere(flow.math.greater_equal(unique_ids, hf_vocab_size_constant)) + hf_ids = flow.gather_nd(params=unique_ids, indices=hf_indices) + lf_ids = flow.gather_nd(params=unique_ids, indices=lf_indices) + hf_embedding_table = flow.get_variable( + name=f'hf_{name}', + shape=(hf_vocab_size, embedding_size), + dtype=flow.float, + initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), + ) + hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)#, no_duplicates_in_indices=True) + lf_ids = lf_ids - hf_vocab_size_constant + with flow.scope.placement('cpu', '0:0'): + lf_embedding_table = flow.get_variable( + name=f'lf_{name}', + shape=(vocab_size - hf_vocab_size, embedding_size), + #shape=(vocab_size, embedding_size), + dtype=flow.float, + initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), + ) + lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)#, no_duplicates_in_indices=True) + unique_embedding = flow.reshape(flow.zeros_like(unique_ids, dtype=flow.float), (-1, 1)) * flow.constant(0.0, dtype=flow.float, shape=(1,embedding_size)) + # unique_embedding = flow.constant(0.0, dtype=flow.float, shape=(b*s, embedding_size)) + unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=hf_embedding, indices=hf_indices) + unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=lf_embedding, indices=lf_indices) + unique_embedding = flow.gather(params=unique_embedding, indices=unique_ids_idx) + unique_embedding = flow.cast_to_static_shape(unique_embedding) + unique_embedding = flow.reshape(unique_embedding, shape=(b, s*embedding_size)) + return unique_embedding + + +def _embedding(name, ids, embedding_size, vocab_size, split_axis=0): + ids = flow.parallel_cast(ids, distribute=flow.distribute.broadcast()) + params = flow.get_variable( + name=name, + shape=(vocab_size, embedding_size), + initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), + distribute=flow.distribute.split(split_axis), + ) + embedding = flow.gather(params=params, indices=ids) + embedding = flow.reshape(embedding, shape=(-1, embedding.shape[-1] * embedding.shape[-2])) + return embedding + + +# def _wide_embedding(wide_sparse_fields): +# wide_sparse_fields = flow.parallel_cast(wide_sparse_fields, distribute=flow.distribute.broadcast()) +# wide_embedding_table = flow.get_variable( +# name='wide_embedding', +# shape=(FLAGS.wide_vocab_size, 1), +# initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), +# distribute=flow.distribute.split(0), +# ) +# wide_embedding = flow.gather(params=wide_embedding_table, indices=wide_sparse_fields) +# wide_embedding = flow.reshape(wide_embedding, shape=(-1, wide_embedding.shape[-1] * wide_embedding.shape[-2])) +# return wide_embedding + +# def _deep_embedding(deep_sparse_fields): +# deep_sparse_fields = flow.parallel_cast(deep_sparse_fields, distribute=flow.distribute.broadcast()) +# deep_embedding_table = flow.get_variable( +# name='deep_embedding', +# shape=(FLAGS.deep_vocab_size, FLAGS.deep_embedding_vec_size), +# initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), +# distribute=flow.distribute.split(1), +# ) +# deep_embedding = flow.gather(params=deep_embedding_table, 
indices=deep_sparse_fields) +# deep_embedding = flow.parallel_cast(deep_embedding, distribute=flow.distribute.split(0), +# gradient_distribute=flow.distribute.split(2)) +# deep_embedding = flow.reshape(deep_embedding, shape=(-1, deep_embedding.shape[-1] * deep_embedding.shape[-2])) +# return deep_embedding + + +def _model(dense_fields, wide_sparse_fields, deep_sparse_fields): + # wide_embedding = _wide_embedding(wide_sparse_fields) + wide_embedding = _embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size) + # wide_embedding = _hybrid_embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size, + # FLAGS.hf_wide_vocab_size) + wide_scores = flow.math.reduce_sum(wide_embedding, axis=[1], keepdims=True) + wide_scores = flow.parallel_cast(wide_scores, distribute=flow.distribute.split(0), + gradient_distribute=flow.distribute.broadcast()) + + # deep_embedding = _deep_embedding(deep_sparse_fields) + # deep_embedding = _embedding('deep_embedding', deep_sparse_fields, FLAGS.deep_embedding_vec_size, + # FLAGS.deep_vocab_size, split_axis=1) + deep_embedding = _hybrid_embedding('deep_embedding', deep_sparse_fields, FLAGS.deep_embedding_vec_size, + FLAGS.deep_vocab_size, FLAGS.hf_deep_vocab_size) + deep_features = flow.concat([deep_embedding, dense_fields], axis=1) + for idx, units in enumerate(DEEP_HIDDEN_UNITS): + deep_features = flow.layers.dense( + deep_features, + units=units, + kernel_initializer=flow.glorot_uniform_initializer(), + bias_initializer=flow.constant_initializer(0.0), + activation=flow.math.relu, + name='fc' + str(idx + 1) + ) + deep_features = flow.nn.dropout(deep_features, rate=FLAGS.deep_dropout_rate) + deep_scores = flow.layers.dense( + deep_features, + units=1, + kernel_initializer=flow.glorot_uniform_initializer(), + bias_initializer=flow.constant_initializer(0.0), + name='fc' + str(len(DEEP_HIDDEN_UNITS) + 1) + ) + + scores = wide_scores + deep_scores + return scores + + +global_loss = 0.0 +def _create_train_callback(step): + def nop(loss): + global global_loss + global_loss += loss.mean() + pass + + def print_loss(loss): + global global_loss + global_loss += loss.mean() + print(step+1, 'time', time.time(), 'loss', global_loss/FLAGS.loss_print_every_n_iter) + global_loss = 0.0 + + if (step + 1) % FLAGS.loss_print_every_n_iter == 0: + return print_loss + else: + return nop + + +def CreateOptimizer(args): + lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [args.learning_rate]) + return flow.optimizer.LazyAdam(lr_scheduler) + + +def _get_train_conf(): + train_conf = flow.FunctionConfig() + train_conf.default_data_type(flow.float) + train_conf.indexed_slices_optimizer_conf(dict(include_op_names=dict(op_name=['wide_embedding', 'deep_embedding']))) + return train_conf + + +@flow.global_function('train', _get_train_conf()) +def train_job(): + labels, dense_fields, wide_sparse_fields, deep_sparse_fields = \ + _data_loader(data_dir=FLAGS.train_data_dir, data_part_num=FLAGS.train_data_part_num, + batch_size=FLAGS.batch_size, + part_name_suffix_length=FLAGS.train_part_name_suffix_length, shuffle=True) + logits = _model(dense_fields, wide_sparse_fields, deep_sparse_fields) + loss = flow.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) + opt = CreateOptimizer(FLAGS) + opt.minimize(loss) + loss = flow.math.reduce_mean(loss) + return loss + + +@flow.global_function() +def eval_job(): + labels, dense_fields, wide_sparse_fields, deep_sparse_fields = \ + _data_loader(data_dir=FLAGS.eval_data_dir, 
data_part_num=FLAGS.eval_data_part_num, + batch_size=FLAGS.batch_size, + part_name_suffix_length=FLAGS.eval_part_name_suffix_length, shuffle=False) + logits = _model(dense_fields, wide_sparse_fields, deep_sparse_fields) + loss = flow.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits) + predict = flow.math.sigmoid(logits) + return loss, predict, labels + + +def InitNodes(args): + if args.num_nodes > 1: + assert args.num_nodes <= len(args.node_ips) + flow.env.ctrl_port(args.ctrl_port) + nodes = [] + for ip in args.node_ips[:args.num_nodes]: + addr_dict = {} + addr_dict["addr"] = ip + nodes.append(addr_dict) + + flow.env.machine(nodes) + +def print_args(args): + print("=".ljust(66, "=")) + print("Running {}: num_gpu_per_node = {}, num_nodes = {}.".format( + 'OneFlow-WDL', args.gpu_num_per_node, args.num_nodes)) + print("=".ljust(66, "=")) + for arg in vars(args): + print("{} = {}".format(arg, getattr(args, arg))) + print("-".ljust(66, "-")) + #print("Time stamp: {}".format( + # str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S")))) + + +def main(): + print_args(FLAGS) + InitNodes(FLAGS) + flow.config.gpu_device_num(FLAGS.gpu_num_per_node) + flow.config.enable_model_io_v2(True) + flow.config.enable_debug_mode(True) + flow.config.collective_boxing.nccl_enable_all_to_all(True) + #flow.config.enable_numa_aware_cuda_malloc_host(True) + #flow.config.collective_boxing.enable_fusion(False) + check_point = flow.train.CheckPoint() + check_point.init() + for i in range(FLAGS.max_iter): + train_job().async_get(_create_train_callback(i)) + if (i + 1 ) % FLAGS.eval_interval == 0: + labels = np.array([[0]]) + preds = np.array([[0]]) + cur_time = time.time() + eval_loss = 0.0 + for j in range(FLAGS.eval_batchs): + loss, pred, ref = eval_job().get() + label_ = ref.numpy().astype(np.float32) + labels = np.concatenate((labels, label_), axis=0) + preds = np.concatenate((preds, pred.numpy()), axis=0) + eval_loss += loss.mean() + auc = roc_auc_score(labels[1:], preds[1:]) + print(i+1, "eval_loss", eval_loss/FLAGS.eval_batchs, "eval_auc", auc) + + +if __name__ == '__main__': + main() From cada466df3015e1d588cd368083003d60a54d298 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Mon, 31 May 2021 12:53:13 +0800 Subject: [PATCH 2/9] rm usless lines --- .../wdl_train_eval_with_hybrid_embd.py | 39 +++++-------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py index c2e451d..c3b265c 100644 --- a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py +++ b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py @@ -176,35 +176,7 @@ def _embedding(name, ids, embedding_size, vocab_size, split_axis=0): return embedding -# def _wide_embedding(wide_sparse_fields): -# wide_sparse_fields = flow.parallel_cast(wide_sparse_fields, distribute=flow.distribute.broadcast()) -# wide_embedding_table = flow.get_variable( -# name='wide_embedding', -# shape=(FLAGS.wide_vocab_size, 1), -# initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), -# distribute=flow.distribute.split(0), -# ) -# wide_embedding = flow.gather(params=wide_embedding_table, indices=wide_sparse_fields) -# wide_embedding = flow.reshape(wide_embedding, shape=(-1, wide_embedding.shape[-1] * wide_embedding.shape[-2])) -# return wide_embedding - -# def _deep_embedding(deep_sparse_fields): -# deep_sparse_fields = flow.parallel_cast(deep_sparse_fields, 
distribute=flow.distribute.broadcast()) -# deep_embedding_table = flow.get_variable( -# name='deep_embedding', -# shape=(FLAGS.deep_vocab_size, FLAGS.deep_embedding_vec_size), -# initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), -# distribute=flow.distribute.split(1), -# ) -# deep_embedding = flow.gather(params=deep_embedding_table, indices=deep_sparse_fields) -# deep_embedding = flow.parallel_cast(deep_embedding, distribute=flow.distribute.split(0), -# gradient_distribute=flow.distribute.split(2)) -# deep_embedding = flow.reshape(deep_embedding, shape=(-1, deep_embedding.shape[-1] * deep_embedding.shape[-2])) -# return deep_embedding - - def _model(dense_fields, wide_sparse_fields, deep_sparse_fields): - # wide_embedding = _wide_embedding(wide_sparse_fields) wide_embedding = _embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size) # wide_embedding = _hybrid_embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size, # FLAGS.hf_wide_vocab_size) @@ -212,7 +184,6 @@ def _model(dense_fields, wide_sparse_fields, deep_sparse_fields): wide_scores = flow.parallel_cast(wide_scores, distribute=flow.distribute.split(0), gradient_distribute=flow.distribute.broadcast()) - # deep_embedding = _deep_embedding(deep_sparse_fields) # deep_embedding = _embedding('deep_embedding', deep_sparse_fields, FLAGS.deep_embedding_vec_size, # FLAGS.deep_vocab_size, split_axis=1) deep_embedding = _hybrid_embedding('deep_embedding', deep_sparse_fields, FLAGS.deep_embedding_vec_size, @@ -267,7 +238,15 @@ def CreateOptimizer(args): def _get_train_conf(): train_conf = flow.FunctionConfig() train_conf.default_data_type(flow.float) - train_conf.indexed_slices_optimizer_conf(dict(include_op_names=dict(op_name=['wide_embedding', 'deep_embedding']))) + indexed_slices_ops = [ + 'wide_embedding', + 'deep_embedding', + 'hf_wide_embedding', + 'hf_deep_embedding', + 'lf_wide_embedding', + 'lf_deep_embedding', + ] + train_conf.indexed_slices_optimizer_conf(dict(include_op_names=dict(op_name=indexed_slices_ops))) return train_conf From d06761bd51c349e7bbf077db9830a5d8867be2bb Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Mon, 31 May 2021 16:26:24 +0800 Subject: [PATCH 3/9] how to make hf dataset --- .../how_to_make_hf_dataset.md | 293 ++++++++++++++++++ 1 file changed, 293 insertions(+) create mode 100644 ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md diff --git a/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md b/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md new file mode 100644 index 0000000..a153a59 --- /dev/null +++ b/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md @@ -0,0 +1,293 @@ +# How to Make High Frequency Dataset for OneFlow-WDL + +[**how_to_make_ofrecord_for_wdl**](https://github.com/Oneflow-Inc/OneFlow-Benchmark/blob/master/ClickThroughRate/WideDeepLearning/how_to_make_ofrecord_for_wdl.md)一文中介绍了如何利用spark制作OneFlow-WDL使用的ofrecord数据集,GPU&CPU混合embedding的实践中,这个数据集就不好用了,主要原因是没有按照词频排序,所以需要制作新的数据集。本文将持续上文中的套路,介绍一下如何制作按照词频排序的数据集。 + +## 数据集及预处理 + +数据由[CriteoLabs](http://labs.criteo.com/2014/02/kaggle-display-advertising-challenge-dataset/)提供。原始数据包括三个部分:一个标签列`labels`、13个整型特征`I列`、26个分类特征`C列`。数据处理后: + +- `I列`转换为`dense_fields`; +- `C列`转换为`deep_sparse_fields`; +- `C列`中的`C1 C2`、`C3 C4`构成了交叉特征,形成了`wide_sparse_fields`。 + +数据经过处理后保存成`ofrecord`格式,结构如下: + +```bash +root + |-- deep_sparse_fields: array (nullable = true) + | |-- element: integer (containsNull = true) + |-- dense_fields: array (nullable = true) + | |-- element: float 
(containsNull = true) + |-- labels: integer (nullable = true) + |-- wide_sparse_fields: array (nullable = true) + | |-- element: integer (containsNull = true) + +``` + +## step0 准备工作 + +这一步主要是导入相关的库,并且准备一个临时目录。后面的很多步骤中都主动把中间结果保存到临时目录中,这样能够节省内存,而且方便中断恢复操作。 + +```scala +import org.apache.spark.sql.types._ +import org.apache.spark.sql.functions.{when, _} +import org.apache.spark.sql.DataFrame +import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, MinMaxScaler} +import org.apache.spark.ml.linalg._ + +import java.nio.file.{Files, Paths} +val tmp_dir = "/DATA/disk1/xuan/wdl_tmp" +Files.createDirectories(Paths.get(tmp_dir)) +``` + +## step1 导入数据 + +这一步中读入原始数据集,并根据需求做了如下操作: + +1. 给读入的每一列命名[label, I1,...,I13, C1,...,C26] +2. 给每一条数据加上`id`,后面所有的表的合并操作都基于这个`id` +3. 将`I列`转换成整型 +4. `I列`和`C列`空白处补`NaN` +5. `features`是后面经常用到的DataFrame + +```scala +// load input file +var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///DATA/disk1/xuan/train.shuf.bak") +// var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///DATA/disk1/xuan/train.shuf.txt") + +// rename columns [label, I1,...,I13, C1,...,C26] +val NUM_INTEGER_COLUMNS = 13 +val NUM_CATEGORICAL_COLUMNS = 26 + +// val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>s"I$id"} +val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>$"I$id"} // note +val categorical_cols = (1 to NUM_CATEGORICAL_COLUMNS).map{id=>s"C$id"} +val feature_cols = integer_cols.map{c=>c.toString} ++ categorical_cols +val all_cols = (Seq(s"labels") ++ feature_cols) +input = input.toDF(all_cols: _*).withColumn("id", monotonically_increasing_id()) + +input = input.withColumn("labels", col("labels").cast(IntegerType)) +// cast integer columns to int +for(i <- 1 to NUM_INTEGER_COLUMNS) { + val col_name = s"I$i" + input = input.withColumn(col_name, col(col_name).cast(IntegerType)) +} + +// replace `null` with `NaN` +val features = input.na.fill(Int.MinValue, integer_cols.map{c=>c.toString}).na.fill("80000000", categorical_cols) + +// dump features as parquet format +val features_dir = tmp_dir ++ "/filled_features" +features.write.mode("overwrite").parquet(features_dir) +``` + +duration: 1 min + +## step2 处理整型特征生成`dense_fields` + +需要两个步骤: + +1. 循环处理每一个`I列`,编码映射后保存到临时文件夹; +2. 
从临时文件夹中读取后转换成`dense_fields`并保存。 + +### `I列`编码映射 + +对于每一个整型特征: + +- 计算每个特征值的频次 +- 频次小于6的特征值修改为NaN +- 特征编码 +- 进行normalize操作,或仅+1操作 +- 保存该列到临时文件夹 + +```scala +val features_dir = tmp_dir ++ "/filled_features" +val features = spark.read.parquet(features_dir) + +// integer features +println("create integer feature cols") +val normalize_dense = 1 +val nanValue = Int.MinValue +val getItem = udf((v: Vector, i: Int) => v(i).toFloat) +for(column_name <- integer_cols) { + val col_name = column_name.toString + println(col_name) + val col_index = col_name ++ "_index" + val uniqueValueCounts = features.groupBy(col_name).count() + val df = features.join(uniqueValueCounts, Seq(col_name)) + .withColumn(col_name, when(col("count") >= 6, col(col_name)).otherwise(nanValue)) + .select("id", col_name) + val indexedDf = new StringIndexer().setInputCol(col_name) + .setOutputCol(col_index) + .fit(df).transform(df) + .drop(col_name) // trick: drop col_name here and will be reused later + + var scaledDf = spark.emptyDataFrame + if (normalize_dense > 0) { + val assembler = new VectorAssembler().setInputCols(Array(col_index)).setOutputCol("vVec") + val df= assembler.transform(indexedDf) + scaledDf = new MinMaxScaler().setInputCol("vVec") + .setOutputCol(col_name) + .fit(df).transform(df) + .select("id", col_name) + } else { + scaledDf = indexedDf.withColumn(col_name, col(col_index) + lit(1)) // trick: reuse col_name + .select("id", col_name) + //.withColumn(col_name, col(col_index).cast(IntegerType)) + } + val col_dir = tmp_dir ++ "/" ++ col_name + scaledDf = scaledDf.withColumn(col_name, getItem(column_name, lit(0))) + scaledDf.write.mode("overwrite").parquet(col_dir) + scaledDf.printSchema +} +``` + +duration: 3*13 ~= 40 min + +### 合并所有`I列`形成`dense_fields` + +- 从临时文件夹里分别读取各列,并合并到一个dataframe `df`里; +- 将`df`里的`I列`合并成`dense_fields`; +- 将`dense_fields`保存到临时文件夹。 + +```scala +val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>s"I$id"} +var df = features.select("id") +for(col_name <- integer_cols) { + println(col_name) + val df_col = spark.read.parquet(tmp_dir ++ "/" ++ col_name) + df = df.join(df_col, Seq("id")) +} +df = df.select($"id", array(integer_cols map col: _*).as("dense_fields")) +val parquet_dir = tmp_dir ++ "/parquet_dense" +df.write.mode("overwrite").parquet(parquet_dir) +``` + +Duration: 3mins + +## step3 处理分类特征并生成`deep_sparse_fields` + +在这一步中,我们首先处理`deep_sparse_fields`,也就是`C*`列数据。因为在原来的处理中,所有的id都需要加上一个offset,这样就保证了所有的C列id不会重复,但也导致了后面列高频词的id要比前面列低频词的id还要大,所以无法满足需求。为了解决这个问题,需要拿到所有的C列,得到不重复的词的列表,并且按照词频排序,为了做到列之间即使出现相同的词也能有不同的id,我们会在每列词的前面加上列名作为前缀,然后再计算词频并排序。下面介绍具体过程: + +#### 处理分类特征 + +- 创建`new_categorical_cols`用于给所有的分类特征值加上列的名称 +- 选择新的分类特征列,并保存到spark session中,表名为`f` +- 获得表`f`中的所有列的所有不重复的值,并且按照频次从高到低排序,结果存到uniqueValueCounts里面,忽略频次小于6的值 +- 按照频次的高低给每一个值分配一个`fid`,即频次最高的为0,存到`hf`表中 +- 再重新遍历所有的分类表,用新的`fid`替换原来的特征,并保存到文件系统 + +```scala +val new_categorical_cols = (1 to NUM_CATEGORICAL_COLUMNS).map{id=>concat(lit(s"C$id"), col(s"C$id")) as s"C$id"} +features.select(new_categorical_cols:_*).createOrReplaceTempView("f") +val orderedValues = spark.sql("select cid, count(*) as cnt from (select explode( array(" + categorical_cols.mkString(",") + ") ) as cid from f) group by cid ").filter("cnt>=6").orderBy($"cnt".desc) + +val hf = orderedValues.select("cid").as[(String)].rdd.zipWithIndex().toDF().select(col("_1").as("cid"), col("_2").as("fid")) + +for(col_name <- categorical_cols) { + println(col_name) + val col_feature = features.select(col("id"), concat(lit(col_name), col(col_name)) as col_name) + val scaledDf = 
col_feature.join(hf, col_feature(col_name)=== hf("cid")).select(col("id"), col("fid").as(col_name)) + val col_dir = tmp_dir ++ "/" ++ col_name + scaledDf.write.mode("overwrite").parquet(col_dir) +} +``` + +Mem: 110G +Time: 10 mins + +### 生成`deep_sparse_fields` + +这段操作和形成`dense_fields`的方式相似,代码冗余。 +这一段要处理26个列,内存消耗极大(170G),速度到不是最慢的。如果数据集更大,或可采用每次合一列的方式。前面的`dense_fields`也可以采用这种方式,列为`TODO`吧。 + +```scala +val tmp_dir = "/DATA/disk1/xuan/wdl_tmp" +val features_dir = tmp_dir ++ "/filled_features" +val features = spark.read.parquet(features_dir) + +val NUM_CATEGORICAL_COLUMNS = 26 +val categorical_cols = (1 to NUM_CATEGORICAL_COLUMNS).map{id=>s"C$id"} + +var df = features.select("id") +for(col_name <- categorical_cols) { + println(col_name) + val df_col = spark.read.parquet(tmp_dir ++ "/" ++ col_name) + df = df.join(df_col, Seq("id")) +} +df = df.select($"id", array(categorical_cols map col: _*).as("deep_sparse_fields")) +val parquet_dir = tmp_dir ++ "/parquet_deep_sparse" +df.write.mode("overwrite").parquet(parquet_dir) +``` + +Duration: 5min + +## Step4 生成交叉特征并生成wide_sparse_fields + +在OneFlow-WDL里,交叉特征被用来生成`wide_sparse_fields`,也是有可能需要按照高低频排序的。在之前交叉特征的id被排在了后面,存在将交叉特征和分类特征一起使用的可能,即使用同一个embedding表。如果这里单独按照高低频排序,就不能这么做了,不过不影响当前的WDL网络。 + +```scala +val cross_pairs = Array("C1_C2", "C3_C4") +var df = features.select("id") +for(cross_pair <- cross_pairs) { + val df_col = spark.read.parquet(tmp_dir ++ "/" ++ cross_pair) + df = df.join(df_col, Seq("id")) +} +// df.select("C1_C2", "C3_C4").createOrReplaceTempView("f") +// df.select(cross_pairs.map{id=>col(id)}:_*).createOrReplaceTempView("f") +df.select(cross_pairs map col: _*).createOrReplaceTempView("f") + +val orderedValues = spark.sql("select cid, count(*) as cnt from (select explode( array(" + cross_pairs.mkString(",") + ") ) as cid from f) group by cid ").filter("cnt>=6").orderBy($"cnt".desc) + +val hf = orderedValues.select("cid").as[(String)].rdd.zipWithIndex().toDF().select(col("_1").as("cid"), col("_2").as("fid")) + +for(cross_pair <- cross_pairs) { + df = df.join(hf, df(cross_pair)=== hf("cid")).drop(cross_pair, "cid").withColumnRenamed("fid", cross_pair) +} + +df = df.select($"id", array(cross_pairs map col: _*).as("wide_sparse_fields")) +val parquet_dir = tmp_dir ++ "/parquet_wide_sparse" +df.write.mode("overwrite").parquet(parquet_dir) +``` + +Duration: 2min + +## step5 合并所有字段 + +```scala +val fields = Array("dense", "deep_sparse", "wide_sparse") +var df = features.select("id", "labels") +for(field <- fields) { + val df_col = spark.read.parquet(tmp_dir ++ "/parquet_" ++ field) + df = df.join(df_col, Seq("id")) +} +val parquet_dir = tmp_dir ++ "/parquet_all" +df.write.mode("overwrite").parquet(parquet_dir) +``` + +## Step6 写入ofrecord + +```scala +val tmp_dir = "/DATA/disk1/xuan/wdl_tmp" +import org.oneflow.spark.functions._ +val parquet_dir = tmp_dir ++ "/parquet_all" +val df = spark.read.parquet(parquet_dir) + +val dfs = df.drop("id").randomSplit(Array(0.8, 0.1, 0.1)) + +val ofrecord_dir = tmp_dir ++ "/ofrecord/train" +dfs(0).repartition(256).write.mode("overwrite").ofrecord(ofrecord_dir) +dfs(0).count +sc.formatFilenameAsOneflowStyle(ofrecord_dir) + +val ofrecord_dir = tmp_dir ++ "/ofrecord/val" +dfs(1).repartition(256).write.mode("overwrite").ofrecord(ofrecord_dir) +dfs(1).count +sc.formatFilenameAsOneflowStyle(ofrecord_dir) + +val ofrecord_dir = tmp_dir ++ "/ofrecord/test" +dfs(2).repartition(256).write.mode("overwrite").ofrecord(ofrecord_dir) +dfs(2).count +sc.formatFilenameAsOneflowStyle(ofrecord_dir) +``` + From 
8964faed2e082d05c29e3fa8db365ffe76902f77 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 1 Jun 2021 11:15:12 +0800 Subject: [PATCH 4/9] hybrid wide branch --- .../WideDeepLearning/wdl_train_eval_with_hybrid_embd.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py index c3b265c..ea7f5f5 100644 --- a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py +++ b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py @@ -177,9 +177,9 @@ def _embedding(name, ids, embedding_size, vocab_size, split_axis=0): def _model(dense_fields, wide_sparse_fields, deep_sparse_fields): - wide_embedding = _embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size) - # wide_embedding = _hybrid_embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size, - # FLAGS.hf_wide_vocab_size) + # wide_embedding = _embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size) + wide_embedding = _hybrid_embedding('wide_embedding', wide_sparse_fields, 1, FLAGS.wide_vocab_size, + FLAGS.hf_wide_vocab_size) wide_scores = flow.math.reduce_sum(wide_embedding, axis=[1], keepdims=True) wide_scores = flow.parallel_cast(wide_scores, distribute=flow.distribute.split(0), gradient_distribute=flow.distribute.broadcast()) From e4bf9a93c7b3a50b1875a32106c9588529386f57 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Wed, 9 Jun 2021 16:44:57 +0800 Subject: [PATCH 5/9] update dataloader --- .../wdl_train_eval_with_hybrid_embd.py | 53 +++++++++++-------- 1 file changed, 30 insertions(+), 23 deletions(-) diff --git a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py index ea7f5f5..48cc619 100644 --- a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py +++ b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py @@ -62,15 +62,23 @@ def str_list(x): DEEP_HIDDEN_UNITS = [FLAGS.hidden_size for i in range(FLAGS.hidden_units_num)] def _data_loader(data_dir, data_part_num, batch_size, part_name_suffix_length=-1, shuffle=True): - if FLAGS.dataset_format == 'ofrecord': - return _data_loader_ofrecord(data_dir, data_part_num, batch_size, part_name_suffix_length, - shuffle) - elif FLAGS.dataset_format == 'onerec': - return _data_loader_onerec(data_dir, batch_size, shuffle) - elif FLAGS.dataset_format == 'synthetic': - return _data_loader_synthetic(batch_size) + assert FLAGS.num_dataloader_thread_per_gpu >= 1 + if FLAGS.use_single_dataloader_thread: + devices = ['{}:0'.format(i) for i in range(FLAGS.num_nodes)] else: - assert 0, "Please specify dataset_type as `ofrecord`, `onerec` or `synthetic`." + num_dataloader_thread = FLAGS.num_dataloader_thread_per_gpu * FLAGS.gpu_num_per_node + devices = ['{}:0-{}'.format(i, num_dataloader_thread - 1) for i in range(FLAGS.num_nodes)] + with flow.scope.placement("cpu", devices): + if FLAGS.dataset_format == 'ofrecord': + data = _data_loader_ofrecord(data_dir, data_part_num, batch_size, + part_name_suffix_length, shuffle) + elif FLAGS.dataset_format == 'onerec': + data = _data_loader_onerec(data_dir, batch_size, shuffle) + elif FLAGS.dataset_format == 'synthetic': + data = _data_loader_synthetic(batch_size) + else: + assert 0, "Please specify dataset_type as `ofrecord`, `onerec` or `synthetic`." 
+ return flow.identity_n(data) @@ -90,22 +98,20 @@ def _blob_decoder(bn, shape, dtype=flow.int32): dense_fields = _blob_decoder("dense_fields", (FLAGS.num_dense_fields,), flow.float) wide_sparse_fields = _blob_decoder("wide_sparse_fields", (FLAGS.num_wide_sparse_fields,)) deep_sparse_fields = _blob_decoder("deep_sparse_fields", (FLAGS.num_deep_sparse_fields,)) - return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields]) + return [labels, dense_fields, wide_sparse_fields, deep_sparse_fields] def _data_loader_synthetic(batch_size): - devices = ['{}:0-{}'.format(i, FLAGS.gpu_num_per_node - 1) for i in range(FLAGS.num_nodes)] - with flow.scope.placement("cpu", devices): - def _blob_random(shape, dtype=flow.int32, initializer=flow.zeros_initializer(flow.int32)): - return flow.data.decode_random(shape=shape, dtype=dtype, batch_size=batch_size, - initializer=initializer) - labels = _blob_random((1,), initializer=flow.random_uniform_initializer(dtype=flow.int32)) - dense_fields = _blob_random((FLAGS.num_dense_fields,), dtype=flow.float, - initializer=flow.random_uniform_initializer()) - wide_sparse_fields = _blob_random((FLAGS.num_wide_sparse_fields,)) - deep_sparse_fields = _blob_random((FLAGS.num_deep_sparse_fields,)) - print('use synthetic data') - return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields]) + def _blob_random(shape, dtype=flow.int32, initializer=flow.zeros_initializer(flow.int32)): + return flow.data.decode_random(shape=shape, dtype=dtype, batch_size=batch_size, + initializer=initializer) + labels = _blob_random((1,), initializer=flow.random_uniform_initializer(dtype=flow.int32)) + dense_fields = _blob_random((FLAGS.num_dense_fields,), dtype=flow.float, + initializer=flow.random_uniform_initializer()) + wide_sparse_fields = _blob_random((FLAGS.num_wide_sparse_fields,)) + deep_sparse_fields = _blob_random((FLAGS.num_deep_sparse_fields,)) + print('use synthetic data') + return [labels, dense_fields, wide_sparse_fields, deep_sparse_fields] def _data_loader_onerec(data_dir, batch_size, shuffle): @@ -114,6 +120,7 @@ def _data_loader_onerec(data_dir, batch_size, shuffle): files = glob.glob(os.path.join(data_dir, '*.onerec')) readdata = flow.data.onerec_reader(files=files, batch_size=batch_size, random_shuffle=shuffle, verify_example=False, + shuffle_mode="batch", shuffle_buffer_size=64, shuffle_after_epoch=shuffle) @@ -124,9 +131,9 @@ def _blob_decoder(bn, shape, dtype=flow.int32): dense_fields = _blob_decoder("dense_fields", (FLAGS.num_dense_fields,), flow.float) wide_sparse_fields = _blob_decoder("wide_sparse_fields", (FLAGS.num_wide_sparse_fields,)) deep_sparse_fields = _blob_decoder("deep_sparse_fields", (FLAGS.num_deep_sparse_fields,)) - return flow.identity_n([labels, dense_fields, wide_sparse_fields, deep_sparse_fields]) - + return [labels, dense_fields, wide_sparse_fields, deep_sparse_fields] + def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size): b, s = ids.shape ids = flow.flatten(ids) From 3aec0ec16ddbd2e50ad67e02cb1e51f1d79dcf6a Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Wed, 9 Jun 2021 16:45:58 +0800 Subject: [PATCH 6/9] update args --- .../WideDeepLearning/wdl_train_eval_with_hybrid_embd.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py index 48cc619..d938962 100644 --- 
a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py +++ b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py @@ -26,6 +26,12 @@ def str_list(x): return x.split(',') parser = argparse.ArgumentParser() parser.add_argument('--dataset_format', type=str, default='ofrecord', help='ofrecord or onerec') +parser.add_argument( + "--use_single_dataloader_thread", + action="store_true", + help="use single dataloader threads per node or not." +) +parser.add_argument('--num_dataloader_thread_per_gpu', type=int, default=2) parser.add_argument('--train_data_dir', type=str, default='') parser.add_argument('--train_data_part_num', type=int, default=1) parser.add_argument('--train_part_name_suffix_length', type=int, default=-1) @@ -133,7 +139,7 @@ def _blob_decoder(bn, shape, dtype=flow.int32): deep_sparse_fields = _blob_decoder("deep_sparse_fields", (FLAGS.num_deep_sparse_fields,)) return [labels, dense_fields, wide_sparse_fields, deep_sparse_fields] - + def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size): b, s = ids.shape ids = flow.flatten(ids) From d1275e633115aa867f0bfedaca51649a741e757f Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 15 Jun 2021 11:17:57 +0800 Subject: [PATCH 7/9] rm usless comments --- .../wdl_train_eval_with_hybrid_embd.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py index d938962..b7790c0 100644 --- a/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py +++ b/ClickThroughRate/WideDeepLearning/wdl_train_eval_with_hybrid_embd.py @@ -155,19 +155,17 @@ def _hybrid_embedding(name, ids, embedding_size, vocab_size, hf_vocab_size): dtype=flow.float, initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), ) - hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids)#, no_duplicates_in_indices=True) + hf_embedding = flow.gather(params=hf_embedding_table, indices=hf_ids) lf_ids = lf_ids - hf_vocab_size_constant with flow.scope.placement('cpu', '0:0'): lf_embedding_table = flow.get_variable( name=f'lf_{name}', shape=(vocab_size - hf_vocab_size, embedding_size), - #shape=(vocab_size, embedding_size), dtype=flow.float, initializer=flow.random_uniform_initializer(minval=-0.05, maxval=0.05), ) - lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids)#, no_duplicates_in_indices=True) + lf_embedding = flow.gather(params=lf_embedding_table, indices=lf_ids) unique_embedding = flow.reshape(flow.zeros_like(unique_ids, dtype=flow.float), (-1, 1)) * flow.constant(0.0, dtype=flow.float, shape=(1,embedding_size)) - # unique_embedding = flow.constant(0.0, dtype=flow.float, shape=(b*s, embedding_size)) unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=hf_embedding, indices=hf_indices) unique_embedding = flow.tensor_scatter_nd_update(params=unique_embedding, updates=lf_embedding, indices=lf_indices) unique_embedding = flow.gather(params=unique_embedding, indices=unique_ids_idx) @@ -309,8 +307,6 @@ def print_args(args): for arg in vars(args): print("{} = {}".format(arg, getattr(args, arg))) print("-".ljust(66, "-")) - #print("Time stamp: {}".format( - # str(datetime.now().strftime("%Y-%m-%d-%H:%M:%S")))) def main(): @@ -320,8 +316,6 @@ def main(): flow.config.enable_model_io_v2(True) flow.config.enable_debug_mode(True) flow.config.collective_boxing.nccl_enable_all_to_all(True) - 
#flow.config.enable_numa_aware_cuda_malloc_host(True) - #flow.config.collective_boxing.enable_fusion(False) check_point = flow.train.CheckPoint() check_point.init() for i in range(FLAGS.max_iter): From 395c28f5fd8198909e6ff3deaf1bae12a14f09c4 Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 15 Jun 2021 11:50:25 +0800 Subject: [PATCH 8/9] update how_to_make_hf_dataset.md --- .../WideDeepLearning/how_to_make_hf_dataset.md | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md b/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md index a153a59..60926d1 100644 --- a/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md +++ b/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md @@ -1,6 +1,6 @@ # How to Make High Frequency Dataset for OneFlow-WDL -[**how_to_make_ofrecord_for_wdl**](https://github.com/Oneflow-Inc/OneFlow-Benchmark/blob/master/ClickThroughRate/WideDeepLearning/how_to_make_ofrecord_for_wdl.md)一文中介绍了如何利用spark制作OneFlow-WDL使用的ofrecord数据集,GPU&CPU混合embedding的实践中,这个数据集就不好用了,主要原因是没有按照词频排序,所以需要制作新的数据集。本文将持续上文中的套路,介绍一下如何制作按照词频排序的数据集。 +[**how_to_make_ofrecord_for_wdl**](https://github.com/Oneflow-Inc/OneFlow-Benchmark/blob/master/ClickThroughRate/WideDeepLearning/how_to_make_ofrecord_for_wdl.md)一文中介绍了如何利用spark制作OneFlow-WDL使用的ofrecord数据集,GPU&CPU混合embedding的实践中需要把特征根据词频从大到小排序,本文将持续上文中的套路,介绍一下如何制作按照词频排序的数据集。 ## 数据集及预处理 @@ -53,14 +53,12 @@ Files.createDirectories(Paths.get(tmp_dir)) ```scala // load input file var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///DATA/disk1/xuan/train.shuf.bak") -// var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///DATA/disk1/xuan/train.shuf.txt") // rename columns [label, I1,...,I13, C1,...,C26] val NUM_INTEGER_COLUMNS = 13 val NUM_CATEGORICAL_COLUMNS = 26 -// val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>s"I$id"} -val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>$"I$id"} // note +val integer_cols = (1 to NUM_INTEGER_COLUMNS).map{id=>$"I$id"} val categorical_cols = (1 to NUM_CATEGORICAL_COLUMNS).map{id=>s"C$id"} val feature_cols = integer_cols.map{c=>c.toString} ++ categorical_cols val all_cols = (Seq(s"labels") ++ feature_cols) @@ -133,7 +131,6 @@ for(column_name <- integer_cols) { } else { scaledDf = indexedDf.withColumn(col_name, col(col_index) + lit(1)) // trick: reuse col_name .select("id", col_name) - //.withColumn(col_name, col(col_index).cast(IntegerType)) } val col_dir = tmp_dir ++ "/" ++ col_name scaledDf = scaledDf.withColumn(col_name, getItem(column_name, lit(0))) @@ -233,8 +230,6 @@ for(cross_pair <- cross_pairs) { val df_col = spark.read.parquet(tmp_dir ++ "/" ++ cross_pair) df = df.join(df_col, Seq("id")) } -// df.select("C1_C2", "C3_C4").createOrReplaceTempView("f") -// df.select(cross_pairs.map{id=>col(id)}:_*).createOrReplaceTempView("f") df.select(cross_pairs map col: _*).createOrReplaceTempView("f") val orderedValues = spark.sql("select cid, count(*) as cnt from (select explode( array(" + cross_pairs.mkString(",") + ") ) as cid from f) group by cid ").filter("cnt>=6").orderBy($"cnt".desc) From 308fb8e843a1f6f4ade30dc32fcf17e251a565eb Mon Sep 17 00:00:00 2001 From: ShawnXuan Date: Tue, 15 Jun 2021 11:54:42 +0800 Subject: [PATCH 9/9] modify default path name --- .../WideDeepLearning/how_to_make_hf_dataset.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md 
b/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md index 60926d1..02a8a94 100644 --- a/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md +++ b/ClickThroughRate/WideDeepLearning/how_to_make_hf_dataset.md @@ -36,7 +36,7 @@ import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, MinMaxScaler import org.apache.spark.ml.linalg._ import java.nio.file.{Files, Paths} -val tmp_dir = "/DATA/disk1/xuan/wdl_tmp" +val tmp_dir = "/path/to/wdl_tmp" Files.createDirectories(Paths.get(tmp_dir)) ``` @@ -52,7 +52,7 @@ Files.createDirectories(Paths.get(tmp_dir)) ```scala // load input file -var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///DATA/disk1/xuan/train.shuf.bak") +var input = spark.read.options(Map("delimiter"->"\t")).csv("file:///path/to/train.shuf") // rename columns [label, I1,...,I13, C1,...,C26] val NUM_INTEGER_COLUMNS = 13 @@ -199,7 +199,7 @@ Time: 10 mins 这一段要处理26个列,内存消耗极大(170G),速度到不是最慢的。如果数据集更大,或可采用每次合一列的方式。前面的`dense_fields`也可以采用这种方式,列为`TODO`吧。 ```scala -val tmp_dir = "/DATA/disk1/xuan/wdl_tmp" +val tmp_dir = "/path/to/wdl_tmp" val features_dir = tmp_dir ++ "/filled_features" val features = spark.read.parquet(features_dir) @@ -263,7 +263,7 @@ df.write.mode("overwrite").parquet(parquet_dir) ## Step6 写入ofrecord ```scala -val tmp_dir = "/DATA/disk1/xuan/wdl_tmp" +val tmp_dir = "/path/to/wdl_tmp" import org.oneflow.spark.functions._ val parquet_dir = tmp_dir ++ "/parquet_all" val df = spark.read.parquet(parquet_dir)
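
For readers skimming the patches above, the following is a minimal NumPy sketch (not OneFlow API; the function and variable names are illustrative) of the id-partition idea that `_hybrid_embedding` in `wdl_train_eval_with_hybrid_embd.py` implements: ids below `hf_vocab_size` are looked up in the small high-frequency table (GPU-resident in the script), the remaining ids are offset and looked up in the large low-frequency table (CPU-resident), and the two gathers are scattered back into one dense output. The deduplication via `unique_with_counts` and the device placement scopes used in the real script are omitted here.

```python
# Illustrative sketch only, assuming NumPy; mirrors the id split in _hybrid_embedding.
import numpy as np

def hybrid_lookup(ids, hf_table, lf_table):
    """ids: (batch, slots) int array; hf_table: (hf_vocab, dim); lf_table: (vocab - hf_vocab, dim)."""
    b, s = ids.shape
    flat = ids.reshape(-1)
    hf_vocab_size, dim = hf_table.shape

    hf_mask = flat < hf_vocab_size                 # ids served by the "hot" (high-frequency) table
    out = np.empty((flat.size, dim), dtype=hf_table.dtype)
    out[hf_mask] = hf_table[flat[hf_mask]]         # gather from the high-frequency table
    out[~hf_mask] = lf_table[flat[~hf_mask] - hf_vocab_size]  # offset ids into the low-frequency table
    return out.reshape(b, s * dim)

# toy usage: ids 0..3 live in the "GPU" table, ids 4..9 in the "CPU" table
rng = np.random.default_rng(0)
hf = rng.normal(size=(4, 2)).astype(np.float32)
lf = rng.normal(size=(6, 2)).astype(np.float32)
print(hybrid_lookup(np.array([[1, 7], [3, 4]]), hf, lf).shape)  # (2, 4)
```

Keeping only the first `hf_vocab_size` ids on the GPU is effective because the dataset recipe above re-encodes every categorical value so that smaller ids correspond to more frequent values, which is exactly what the frequency-ordered ofrecord produced by `how_to_make_hf_dataset.md` provides.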