Skip to content

Commit

Permalink
Fix TPUEmbedding layer to use named parameters
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 384784083
  • Loading branch information
Bruce Fontaine authored and TensorFlow Recommenders Team committed Jul 14, 2021
1 parent 9303469 commit b0fa907
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions tensorflow_recommenders/layers/embedding/tpu_embedding_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,17 @@ def _clone_and_prepare_features(feature_config):
table_configs[config.table] = table_configs.get(
config.table,
tf.tpu.experimental.embedding.TableConfig(
config.table.vocabulary_size, config.table.dim,
config.table.initializer, config.table.optimizer,
config.table.combiner, config.table.name))
vocabulary_size=config.table.vocabulary_size,
dim=config.table.dim,
initializer=config.table.initializer,
optimizer=config.table.optimizer,
combiner=config.table.combiner,
name=config.table.name))

output_objects.append(tf.tpu.experimental.embedding.FeatureConfig(
table_configs[config.table], config.max_sequence_length, config.name))
table=table_configs[config.table],
max_sequence_length=config.max_sequence_length,
name=config.name))

# Fix up the optimizers.
for _, new_table in table_configs.items():
Expand Down Expand Up @@ -206,8 +211,9 @@ def _update_table_configs(feature_config, table_config_map):
raise ValueError("TableConfig %s does not match any of the TableConfigs "
"used to configure this layer." % config.table)
output_objects.append(tf.tpu.experimental.embedding.FeatureConfig(
table_config_dict[config.table], config.max_sequence_length,
config.name))
table=table_config_dict[config.table],
max_sequence_length=config.max_sequence_length,
name=config.name))

return tf.nest.pack_sequence_as(feature_config, output_objects)

Expand Down

0 comments on commit b0fa907

Please sign in to comment.