From c00f87bc41423e33bb6160a0533da8bc33cca51d Mon Sep 17 00:00:00 2001 From: Alexander Gutkin Date: Thu, 6 Feb 2025 09:39:25 +0000 Subject: [PATCH] Sketch generation using phonetics. PiperOrigin-RevId: 723838136 --- .../configs/sketch-token_phonemes/README.md | 1 + .../arch_p1_t5_1_1_flaxformer.gin | 75 ++++++++++++++++ .../configs/sketch-token_phonemes/dataset.gin | 86 +++++++++++++++++++ .../configs/sketch-token_phonemes/infer.gin | 40 +++++++++ .../sketch-token_phonemes/model_base.gin | 27 ++++++ .../sketch-token_phonemes/model_common.gin | 73 ++++++++++++++++ .../sketch-token_phonemes/model_tiny.gin | 27 ++++++ .../models/pmmx/model_config_gin_test.py | 2 + 8 files changed, 331 insertions(+) create mode 100644 protoscribe/models/pmmx/configs/sketch-token_phonemes/README.md create mode 100644 protoscribe/models/pmmx/configs/sketch-token_phonemes/arch_p1_t5_1_1_flaxformer.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_phonemes/dataset.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_phonemes/infer.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_phonemes/model_base.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin create mode 100644 protoscribe/models/pmmx/configs/sketch-token_phonemes/model_tiny.gin diff --git a/protoscribe/models/pmmx/configs/sketch-token_phonemes/README.md b/protoscribe/models/pmmx/configs/sketch-token_phonemes/README.md new file mode 100644 index 0000000..d51111b --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_phonemes/README.md @@ -0,0 +1 @@ +# Sketch generation using discrete tokens given the phonetics. 
diff --git a/protoscribe/models/pmmx/configs/sketch-token_phonemes/arch_p1_t5_1_1_flaxformer.gin b/protoscribe/models/pmmx/configs/sketch-token_phonemes/arch_p1_t5_1_1_flaxformer.gin new file mode 100644 index 0000000..dc0679b --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_phonemes/arch_p1_t5_1_1_flaxformer.gin @@ -0,0 +1,75 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Multimodal PMMX-One flaxformer architecture. + +from __gin__ import dynamic_registration + +import seqio +from protoscribe.pmmx import feature_converters +from protoscribe.pmmx import pmmx_architecture + +from flaxformer.architectures.t5 import t5_architecture +from flaxformer.components import embedding + +include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin" + +# Architecture (Flax Module). +ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder() + +# Vocabulary for the encoder. +inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary() +inputs/seqio.PassThroughVocabulary.size = 0 + +# Output vocabulary for the decoder. The `SKETCH_TOKEN_VOCAB_SIZE` corresponds +# to the real sketch token vocabulary size (N+3, where 3 is the number of special +# symbols) rounded up to a multiple of the batch size B (16); in other words, N + B. 
This is +# +# - 2064: for N=2048 +# - 4112: for N=4096 + +SKETCH_TOKEN_VOCAB_SIZE = 2064 +END_OF_SKETCH = 3 # sketch_tokenizer.END_OF_SKETCH +NUM_EMBEDDINGS = %SKETCH_TOKEN_VOCAB_SIZE + +outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary() +outputs/seqio.PassThroughVocabulary.size = 0 +outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH + +# Actual multimodal encoder-decoder architecture. +pmmx_architecture.MultimodalEncoderDecoder: + encoder_factory = @pmmx_architecture.MultimodalEncoder + decoder_factory = @t5_architecture.Decoder + shared_token_embedder_factory = @token_embedder/embedding.Embed + dtype = %ACTIVATION_DTYPE + +feature_converters.MultimodalEncDecFeatureConverterFactory: + task_feature_lengths = %TASK_FEATURE_LENGTHS + feature_specs = ( + ("text.phonetic_embedding", "float32", 2), + ) + +# Encoder +pmmx_architecture.MultimodalEncoder: + feature_spec = [ + ("text.phonetic_embedding", "text.phonetic_embedding"), + ] + modality_spec = [ + "text.phonetic_embedding", + ] + modality_embedders_spec = { + "text.phonetic_embedding": [ + ("text.phonetic_embedding", @pmmx_architecture.DenseEmbed) + ], + } diff --git a/protoscribe/models/pmmx/configs/sketch-token_phonemes/dataset.gin b/protoscribe/models/pmmx/configs/sketch-token_phonemes/dataset.gin new file mode 100644 index 0000000..533a86f --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_phonemes/dataset.gin @@ -0,0 +1,86 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# Settings for Protoscribe dataset reader using discrete sketch tokens. + +from __gin__ import dynamic_registration + +from t5x import utils + +from protoscribe.corpus.reader import tasks + +DATA_DIR = %gin.REQUIRED +TRAIN_DATA_DIR = %DATA_DIR +EVAL_DATA_DIR = %DATA_DIR +INFER_EVAL_DATA_DIR = %DATA_DIR + +MAX_STROKE_SEQUENCE_LENGTH = 250 +MAX_PHONETIC_SEQUENCE_LENGTH = 10 + +tasks.register: + concept_embedding_type = "bnc" + max_stroke_sequence_length = %MAX_STROKE_SEQUENCE_LENGTH + max_glyph_sequence_length = 20 + max_phonetic_sequence_length = %MAX_PHONETIC_SEQUENCE_LENGTH + stroke_random_scale_factor = 0.0 + stroke_normalization_type = "sketch-rnn" + stroke_token_vocab_filename = "vocab2048_normalized_sketchrnn.npy" + +train_task/tasks.register: + task_name = "bnc_tokens.train" + dataset_dir = %TRAIN_DATA_DIR + is_training = True + noisify_embeddings = True + noisify_neftune_alphas = { + %tasks.EMBEDDING_PHONETICS: 0.01, + } + +eval_task/tasks.register: + task_name = "bnc_tokens.eval" + dataset_dir = %EVAL_DATA_DIR + is_training = False + +infer_eval_task/tasks.register: + task_name = "bnc_tokens.infer_eval" + dataset_dir = %INFER_EVAL_DATA_DIR + is_training = False + +TRAIN_TASK = @train_task/tasks.register() +EVAL_TASK = @eval_task/tasks.register() +INFER_EVAL_TASK = @infer_eval_task/tasks.register() +MIXTURE_OR_TASK_NAME = %TRAIN_TASK +MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks" +USE_CACHED_TASKS = False + +TASK_FEATURE_LENGTHS = { + "text.phonetic_embedding": %MAX_PHONETIC_SEQUENCE_LENGTH, + "targets": %MAX_STROKE_SEQUENCE_LENGTH +} + +train/utils.DatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + pack = False + use_custom_packing_ops = False + +train_eval/utils.DatasetConfig: + mixture_or_task_name = %EVAL_TASK + pack = False + shuffle = False + use_custom_packing_ops = False + +infer_eval/utils.DatasetConfig: + 
mixture_or_task_name = %INFER_EVAL_TASK + pack = False + shuffle = False + use_custom_packing_ops = False \ No newline at end of file diff --git a/protoscribe/models/pmmx/configs/sketch-token_phonemes/infer.gin b/protoscribe/models/pmmx/configs/sketch-token_phonemes/infer.gin new file mode 100644 index 0000000..97bab39 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_phonemes/infer.gin @@ -0,0 +1,40 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __gin__ import dynamic_registration + +import __main__ as infer_script +from protoscribe.sketches.inference import json_utils +from t5x import utils + +include "protoscribe/pmmx/configs/runs/infer.gin" + +utils.DatasetConfig: + mixture_or_task_name = %INFER_EVAL_TASK + +infer_script.infer: + mode = "predict_with_aux" + write_fn = @json_utils.write_inferences_to_file + merge_fn = @infer_script.merge_chunks_to_file + +json_utils.write_inferences_to_file: + include_all_inputs = False + # We need the following fields to annotate the results. 
+ input_fields_to_include = [ + "doc.id", + "concept.name", + "number.name", + "text.sampa", + "text.words", + ] diff --git a/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_base.gin b/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_base.gin new file mode 100644 index 0000000..5d7b973 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_base.gin @@ -0,0 +1,27 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Base configuration. + +from __gin__ import dynamic_registration + +include "protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = 12 +NUM_DECODER_LAYERS = 12 +NUM_HEADS = 12 +HEAD_DIM = 64 +EMBED_DIM = 768 +MLP_DIM = 2048 diff --git a/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin b/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin new file mode 100644 index 0000000..91b7ed9 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin @@ -0,0 +1,73 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Base configuration of the model. + +from __gin__ import dynamic_registration + +from protoscribe.pmmx import feature_converters +from protoscribe.pmmx import models +from t5x import adafactor +from t5x import utils + +ARCHITECTURE = %gin.REQUIRED + +include "protoscribe/models/pmmx/configs/sketch-token_phonemes/arch_p1_t5_1_1_flaxformer.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = %gin.REQUIRED +NUM_DECODER_LAYERS = %gin.REQUIRED +NUM_HEADS = %gin.REQUIRED +HEAD_DIM = %gin.REQUIRED +EMBED_DIM = %gin.REQUIRED +MLP_DIM = %gin.REQUIRED + +# Optimizer +# `learning_rate` is set by `Trainer.learning_rate_fn`. +OPTIMIZER = @adafactor.Adafactor() +adafactor.Adafactor: + decay_rate = 0.8 + step_offset = 0 + +# Loss defaults. 
+Z_LOSS = 0.0001 +LABEL_SMOOTHING = 0.0 +LOSS_NORMALIZING_FACTOR = None + +# Model +MODEL = @models.MultimodalEncoderDecoderModel() +models.MultimodalEncoderDecoderModel: + feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory() + module = %ARCHITECTURE # provided by t5_flaxformer + input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY + output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY + optimizer_def = %OPTIMIZER + z_loss = %Z_LOSS + label_smoothing = %LABEL_SMOOTHING + loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR + +# Decoding +NUM_DECODES = 8 +RETURN_ALL_DECODES = True +models.MultimodalEncoderDecoderModel.predict_batch_with_aux: + num_decodes = %NUM_DECODES + return_all_decodes = %RETURN_ALL_DECODES + +# Checkpoints +CHECKPOINT_PERIOD = 20_000 +EVAL_PERIOD = %CHECKPOINT_PERIOD +utils.SaveCheckpointConfig: + period = %CHECKPOINT_PERIOD + keep = None # Keep all checkpoints. + save_dataset = False diff --git a/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_tiny.gin b/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_tiny.gin new file mode 100644 index 0000000..30822a7 --- /dev/null +++ b/protoscribe/models/pmmx/configs/sketch-token_phonemes/model_tiny.gin @@ -0,0 +1,27 @@ +# Copyright 2024 The Protoscribe Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Tiny configuration. 
+ +from __gin__ import dynamic_registration + +include "protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin" + +# Architecture overrides. +NUM_ENCODER_LAYERS = 2 +NUM_DECODER_LAYERS = 2 +NUM_HEADS = 6 +HEAD_DIM = 16 +EMBED_DIM = 16 +MLP_DIM = 16 diff --git a/protoscribe/models/pmmx/model_config_gin_test.py b/protoscribe/models/pmmx/model_config_gin_test.py index b8de0f2..2999167 100644 --- a/protoscribe/models/pmmx/model_config_gin_test.py +++ b/protoscribe/models/pmmx/model_config_gin_test.py @@ -54,6 +54,7 @@ def tearDown(self): "glyph_logmel-spectrum", "glyph_phonemes", "sketch-token_concepts", + "sketch-token_phonemes", ) def test_model_train(self, model_dir: str) -> None: """Tests tiny model configuration for training.""" @@ -80,6 +81,7 @@ def test_model_train(self, model_dir: str) -> None: "glyph_logmel-spectrum", "glyph_phonemes", "sketch-token_concepts", + "sketch-token_phonemes", ) def test_model_infer(self, model_dir: str) -> None: """Tests tiny model configuration in inference mode."""