Commit c00f87b

Sketch generation using phonetics.

PiperOrigin-RevId: 723838136
agutkin committed Feb 6, 2025
1 parent 38ee313 commit c00f87b

Showing 8 changed files with 331 additions and 0 deletions.
@@ -0,0 +1 @@
# Sketch generation using discrete tokens given the phonetics.
protoscribe/models/pmmx/configs/sketch-token_phonemes/arch_p1_t5_1_1_flaxformer.gin
@@ -0,0 +1,75 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Multimodal PMMX-One flaxformer architecture.

from __gin__ import dynamic_registration

import seqio
from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import pmmx_architecture

from flaxformer.architectures.t5 import t5_architecture
from flaxformer.components import embedding

include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin"

# Architecture (Flax Module).
ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder()

# Vocabulary for the encoder.
inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
inputs/seqio.PassThroughVocabulary.size = 0

# Output vocabulary for the decoder. `SKETCH_TOKEN_VOCAB_SIZE` is the real
# sketch-token vocabulary size (N tokens plus 3 special symbols) rounded up
# to a multiple of the batch size B (16), which works out to N + B:
#
# - 2064 for N=2048
# - 4112 for N=4096

SKETCH_TOKEN_VOCAB_SIZE = 2064
END_OF_SKETCH = 3 # sketch_tokenizer.END_OF_SKETCH
NUM_EMBEDDINGS = %SKETCH_TOKEN_VOCAB_SIZE
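# Sanity check (hypothetical Python helper, not part of this config): the
# rounding rule above can be reproduced as
#
#   def sketch_token_vocab_size(n, num_special=3, batch=16):
#       # Round n + num_special up to the next multiple of batch.
#       return -(-(n + num_special) // batch) * batch
#
#   assert sketch_token_vocab_size(2048) == 2064
#   assert sketch_token_vocab_size(4096) == 4112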

outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
outputs/seqio.PassThroughVocabulary.size = 0
outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH

# Actual multimodal encoder-decoder architecture.
pmmx_architecture.MultimodalEncoderDecoder:
  encoder_factory = @pmmx_architecture.MultimodalEncoder
  decoder_factory = @t5_architecture.Decoder
  shared_token_embedder_factory = @token_embedder/embedding.Embed
  dtype = %ACTIVATION_DTYPE

feature_converters.MultimodalEncDecFeatureConverterFactory:
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  feature_specs = (
      ("text.phonetic_embedding", "float32", 2),
  )

# Encoder
pmmx_architecture.MultimodalEncoder:
  feature_spec = [
      ("text.phonetic_embedding", "text.phonetic_embedding"),
  ]
  modality_spec = [
      "text.phonetic_embedding",
  ]
  modality_embedders_spec = {
      "text.phonetic_embedding": [
          ("text.phonetic_embedding", @pmmx_architecture.DenseEmbed)
      ],
  }
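This architecture file is not meant to be parsed standalone; the T5X entry
points assemble it with the run and dataset configs via --gin_file flags.
Still, as a minimal sketch of loading it directly with gin-config (assuming
the include and all %gin.REQUIRED macros resolve):

import gin

# `from __gin__ import dynamic_registration` means the import statements in
# the .gin file itself register the referenced modules at parse time.
gin.parse_config_file(
    "protoscribe/models/pmmx/configs/sketch-token_phonemes/"
    "arch_p1_t5_1_1_flaxformer.gin"
)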
protoscribe/models/pmmx/configs/sketch-token_phonemes/dataset.gin
@@ -0,0 +1,86 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Settings for Protoscribe dataset reader using discrete sketch tokens.

from __gin__ import dynamic_registration

from t5x import utils

from protoscribe.corpus.reader import tasks

DATA_DIR = %gin.REQUIRED
TRAIN_DATA_DIR = %DATA_DIR
EVAL_DATA_DIR = %DATA_DIR
INFER_EVAL_DATA_DIR = %DATA_DIR

MAX_STROKE_SEQUENCE_LENGTH = 250
MAX_PHONETIC_SEQUENCE_LENGTH = 10

tasks.register:
  concept_embedding_type = "bnc"
  max_stroke_sequence_length = %MAX_STROKE_SEQUENCE_LENGTH
  max_glyph_sequence_length = 20
  max_phonetic_sequence_length = %MAX_PHONETIC_SEQUENCE_LENGTH
  stroke_random_scale_factor = 0.0
  stroke_normalization_type = "sketch-rnn"
  stroke_token_vocab_filename = "vocab2048_normalized_sketchrnn.npy"

train_task/tasks.register:
  task_name = "bnc_tokens.train"
  dataset_dir = %TRAIN_DATA_DIR
  is_training = True
  noisify_embeddings = True
  noisify_neftune_alphas = {
      %tasks.EMBEDDING_PHONETICS: 0.01,
  }
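# NEFTune note: `noisify_neftune_alphas` enables noisy-embedding training for
# the phonetic features, during training only. Assuming the implementation in
# protoscribe.corpus.reader.tasks follows the published NEFTune recipe (an
# assumption, not verified here), the noise added to an (L, d) embedding
# array is uniform and scaled as, in JAX:
#
#   noise = jax.random.uniform(rng, emb.shape, minval=-1.0, maxval=1.0)
#   emb = emb + (alpha / jnp.sqrt(L * d)) * noise
#
# with alpha = 0.01 above.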

eval_task/tasks.register:
  task_name = "bnc_tokens.eval"
  dataset_dir = %EVAL_DATA_DIR
  is_training = False

infer_eval_task/tasks.register:
  task_name = "bnc_tokens.infer_eval"
  dataset_dir = %INFER_EVAL_DATA_DIR
  is_training = False

TRAIN_TASK = @train_task/tasks.register()
EVAL_TASK = @eval_task/tasks.register()
INFER_EVAL_TASK = @infer_eval_task/tasks.register()
MIXTURE_OR_TASK_NAME = %TRAIN_TASK
MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks"
USE_CACHED_TASKS = False

TASK_FEATURE_LENGTHS = {
    "text.phonetic_embedding": %MAX_PHONETIC_SEQUENCE_LENGTH,
    "targets": %MAX_STROKE_SEQUENCE_LENGTH,
}

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  pack = False
  use_custom_packing_ops = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False
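Once this file is parsed and the three tasks.register() calls have run, the
tasks live in the regular seqio registry under the task_name values above. A
minimal inspection sketch (assuming registration has already happened, e.g.
via T5X gin parsing):

import seqio

# Feature lengths mirror TASK_FEATURE_LENGTHS above.
ds = seqio.get_mixture_or_task("bnc_tokens.train").get_dataset(
    sequence_length={"text.phonetic_embedding": 10, "targets": 250},
    split="train",
    shuffle=False,
)
for example in ds.take(1):  # a tf.data.Dataset of feature dicts
    print({name: feature.shape for name, feature in example.items()})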
protoscribe/models/pmmx/configs/sketch-token_phonemes/infer.gin
@@ -0,0 +1,40 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __gin__ import dynamic_registration

import __main__ as infer_script
from protoscribe.sketches.inference import json_utils
from t5x import utils

include "protoscribe/pmmx/configs/runs/infer.gin"

utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK

infer_script.infer:
  mode = "predict_with_aux"
  write_fn = @json_utils.write_inferences_to_file
  merge_fn = @infer_script.merge_chunks_to_file

json_utils.write_inferences_to_file:
  include_all_inputs = False
  # We need the following fields to annotate the results.
  input_fields_to_include = [
      "doc.id",
      "concept.name",
      "number.name",
      "text.sampa",
      "text.words",
  ]
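The write function name suggests JSON-lines output. A small post-processing
sketch, assuming one JSON object per line carrying the fields listed above
plus the model's predictions (the exact schema is owned by json_utils and is
an assumption here):

import json

with open("inferences.jsonl") as f:  # hypothetical output file name
    for line in f:
        record = json.loads(line)
        print(record.get("doc.id"), record.get("concept.name"))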
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 12
NUM_DECODER_LAYERS = 12
NUM_HEADS = 12
HEAD_DIM = 64
EMBED_DIM = 768
MLP_DIM = 2048
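These are T5-Base-like dimensions (NUM_HEADS * HEAD_DIM = 12 * 64 = 768 =
EMBED_DIM). A rough back-of-envelope parameter count under T5 1.1's
gated-MLP layer layout, ignoring embeddings, cross-attention and layer norms
(so an underestimate, for orientation only):

EMBED_DIM, MLP_DIM, NUM_LAYERS = 768, 2048, 12 + 12  # encoder + decoder layers
attn = 4 * EMBED_DIM * EMBED_DIM  # Q, K, V and output projections per layer
mlp = 3 * EMBED_DIM * MLP_DIM     # wi_0, wi_1, wo in the gated MLP
print(f"~{NUM_LAYERS * (attn + mlp) / 1e6:.0f}M parameters")  # ~170M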
protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin
@@ -0,0 +1,73 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration of the model.

from __gin__ import dynamic_registration

from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import models
from t5x import adafactor
from t5x import utils

ARCHITECTURE = %gin.REQUIRED

include "protoscribe/models/pmmx/configs/sketch-token_phonemes/arch_p1_t5_1_1_flaxformer.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = %gin.REQUIRED
NUM_DECODER_LAYERS = %gin.REQUIRED
NUM_HEADS = %gin.REQUIRED
HEAD_DIM = %gin.REQUIRED
EMBED_DIM = %gin.REQUIRED
MLP_DIM = %gin.REQUIRED

# Optimizer
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0

# Loss defaults.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None

# Model
MODEL = @models.MultimodalEncoderDecoderModel()
models.MultimodalEncoderDecoderModel:
  feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory()
  module = %ARCHITECTURE  # provided by t5_flaxformer
  input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY
  output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY
  optimizer_def = %OPTIMIZER
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# Decoding
NUM_DECODES = 8
RETURN_ALL_DECODES = True
models.MultimodalEncoderDecoderModel.predict_batch_with_aux:
  num_decodes = %NUM_DECODES
  return_all_decodes = %RETURN_ALL_DECODES

# Checkpoints
CHECKPOINT_PERIOD = 20_000
EVAL_PERIOD = %CHECKPOINT_PERIOD
utils.SaveCheckpointConfig:
  period = %CHECKPOINT_PERIOD
  keep = None  # Keep all checkpoints.
  save_dataset = False
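With RETURN_ALL_DECODES = True, predict_batch_with_aux is expected to return
every one of the NUM_DECODES candidates per input rather than only the
best-scoring one. A hypothetical call sketch (exact return shapes and aux
contents depend on the pmmx model class, so treat this as orientation only):

# predictions: decoded token ids, plausibly [batch, NUM_DECODES, target_len];
# aux: auxiliary outputs such as per-candidate scores.
predictions, aux = model.predict_batch_with_aux(params, batch)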
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/sketch-token_phonemes/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2
NUM_HEADS = 6
HEAD_DIM = 16
EMBED_DIM = 16
MLP_DIM = 16
protoscribe/models/pmmx/model_config_gin_test.py
@@ -54,6 +54,7 @@ def tearDown(self):
"glyph_logmel-spectrum",
"glyph_phonemes",
"sketch-token_concepts",
"sketch-token_phonemes",
)
def test_model_train(self, model_dir: str) -> None:
"""Tests tiny model configuration for training."""
@@ -80,6 +81,7 @@ def test_model_train(self, model_dir: str) -> None:
"glyph_logmel-spectrum",
"glyph_phonemes",
"sketch-token_concepts",
"sketch-token_phonemes",
)
def test_model_infer(self, model_dir: str) -> None:
"""Tests tiny model configuration in inference mode."""