Discrete glyph sequence prediction from BNC embeddings.
PiperOrigin-RevId: 720885107
agutkin committed Jan 29, 2025
1 parent c4a1058 commit 0d70fa9
Showing 7 changed files with 319 additions and 0 deletions.
1 change: 1 addition & 0 deletions protoscribe/models/pmmx/configs/glyph_concepts/README.md
@@ -0,0 +1 @@
# Discrete glyph-only models with concepts.
71 changes: 71 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_concepts/arch_p1_t5_1_1_flaxformer.gin
@@ -0,0 +1,71 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Multimodal PMMX-One flaxformer architecture.

from __gin__ import dynamic_registration

import seqio
from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import pmmx_architecture

from flaxformer.architectures.t5 import t5_architecture
from flaxformer.components import embedding

include "protoscribe/pmmx/configs/architectures/p1_t5_1_1_flaxformer.gin"

# Architecture (Flax Module).
ARCHITECTURE = @pmmx_architecture.MultimodalEncoderDecoder()

# Vocabulary for the encoder.
inputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
inputs/seqio.PassThroughVocabulary.size = 0

# Output vocabulary for the decoder. `GLYPH_TOKEN_VOCAB_SIZE` covers the
# real glyph token vocabulary (all the glyphs plus the special symbols,
# 314 elements), the non-administrative concepts (468 elements), and some
# provision for extra glyphs, rounded up to a multiple of 128 for TPU
# efficiency.
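#
# Illustrative arithmetic: 314 + 468 = 782 required entries; rounding up
# with headroom for extra glyphs gives 1024 (a multiple of 128), leaving
# 1024 - 782 = 242 spare slots.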

GLYPH_TOKEN_VOCAB_SIZE = 1024
END_OF_SKETCH = 2 # glyph_vocab.GLYPH_EOS
NUM_EMBEDDINGS = %GLYPH_TOKEN_VOCAB_SIZE

outputs/PASSTHROUGH_VOCABULARY = @seqio.PassThroughVocabulary()
outputs/seqio.PassThroughVocabulary.size = 0
outputs/seqio.PassThroughVocabulary.eos_id = %END_OF_SKETCH
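
# Note that pass-through vocabularies perform no tokenization: the encoder
# consumes raw float concept embeddings and the decoder emits integer glyph
# IDs directly, so only the EOS id needs to be declared for decoding to stop.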

# Actual multimodal encoder-decoder architecture.
pmmx_architecture.MultimodalEncoderDecoder:
  encoder_factory = @pmmx_architecture.MultimodalEncoder
  decoder_factory = @t5_architecture.Decoder
  shared_token_embedder_factory = @token_embedder/embedding.Embed
  dtype = %ACTIVATION_DTYPE

feature_converters.MultimodalEncDecFeatureConverterFactory:
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  feature_specs = (
      ("text.concept_embedding", "float32", 2),
  )
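
# Each spec above is assumed to be (feature name, dtype, rank): the concept
# embedding arrives as a rank-2 [positions, depth] float32 array.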

# Encoder
pmmx_architecture.MultimodalEncoder:
  feature_spec = [
      ("text.concept_embedding", "text.concept_embedding"),
  ]
  modality_spec = ["text.concept_embedding"]
  modality_embedders_spec = {
      "text.concept_embedding": [
          ("text.concept_embedding", @pmmx_architecture.DenseEmbed)
      ],
  }
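
# DenseEmbed is assumed to project the raw float32 concept embeddings into
# the model's EMBED_DIM before they enter the shared encoder stack.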
81 changes: 81 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_concepts/dataset.gin
@@ -0,0 +1,81 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Settings for the Protoscribe dataset reader using discrete glyph tokens.

from __gin__ import dynamic_registration

from t5x import utils

from protoscribe.corpus.reader import tasks

DATA_DIR = %gin.REQUIRED
TRAIN_DATA_DIR = %DATA_DIR
EVAL_DATA_DIR = %DATA_DIR
INFER_EVAL_DATA_DIR = %DATA_DIR

MAX_GLYPH_SEQUENCE_LENGTH = 20

tasks.register:
  concept_embedding_type = "bnc"
  glyph_only_targets = True
  max_glyph_sequence_length = %MAX_GLYPH_SEQUENCE_LENGTH

train_task/tasks.register:
  task_name = "bnc_glyphs.train"
  dataset_dir = %TRAIN_DATA_DIR
  is_training = True
  noisify_embeddings = True
  noisify_neftune_alphas = {
      %tasks.EMBEDDING_SEMANTICS: 0.01,
  }
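
# NEFTune-style noisification (Jain et al., 2023) perturbs embeddings during
# training with uniform noise scaled by alpha / sqrt(length * dim); here
# alpha = 0.01 is applied to the semantic embeddings only. The exact scaling
# is up to the tasks.register implementation.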

eval_task/tasks.register:
  task_name = "bnc_glyphs.eval"
  dataset_dir = %EVAL_DATA_DIR
  is_training = False

infer_eval_task/tasks.register:
  task_name = "bnc_glyphs.infer_eval"
  dataset_dir = %INFER_EVAL_DATA_DIR
  is_training = False

TRAIN_TASK = @train_task/tasks.register()
EVAL_TASK = @eval_task/tasks.register()
INFER_EVAL_TASK = @infer_eval_task/tasks.register()
MIXTURE_OR_TASK_NAME = %TRAIN_TASK
MIXTURE_OR_TASK_MODULE = "protoscribe.corpus.reader.tasks"
USE_CACHED_TASKS = False

TASK_FEATURE_LENGTHS = {
    "text.concept_embedding": 2,
    "targets": %MAX_GLYPH_SEQUENCE_LENGTH,
}
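
# The encoder budget is 2 positions for the concept embedding; decoder
# targets are capped at MAX_GLYPH_SEQUENCE_LENGTH (20) glyph tokens.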

train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  pack = False
  use_custom_packing_ops = False

train_eval/utils.DatasetConfig:
  mixture_or_task_name = %EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False

infer_eval/utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK
  pack = False
  shuffle = False
  use_custom_packing_ops = False
39 changes: 39 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_concepts/infer.gin
@@ -0,0 +1,39 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __gin__ import dynamic_registration

import __main__ as infer_script
from protoscribe.sketches.inference import json_utils
from t5x import utils

include "protoscribe/pmmx/configs/runs/infer.gin"

utils.DatasetConfig:
  mixture_or_task_name = %INFER_EVAL_TASK

infer_script.infer:
  mode = "predict_with_aux"
  write_fn = @json_utils.write_inferences_to_file
  merge_fn = @infer_script.merge_chunks_to_file
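
# "predict_with_aux" surfaces auxiliary decoder outputs (such as scores)
# alongside the predictions; each inference chunk is written out as JSON
# and then merged into a single file.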

json_utils.write_inferences_to_file:
  include_all_inputs = False
  input_fields_to_include = [
      "doc.id",
      "concept.name",
      "number.name",
      "text.sampa",
      "text.words",
  ]
27 changes: 27 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_concepts/model_base.gin
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Base configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/glyph_concepts/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6
NUM_HEADS = 4
HEAD_DIM = 32
EMBED_DIM = 96
MLP_DIM = 512
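
# Attention width is NUM_HEADS * HEAD_DIM = 4 * 32 = 128, projected back to
# the 96-dimensional embedding stream; the MLP expands each token to 512.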
73 changes: 73 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_concepts/model_common.gin
@@ -0,0 +1,73 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Common model configuration shared by all model sizes.

from __gin__ import dynamic_registration

from protoscribe.pmmx import feature_converters
from protoscribe.pmmx import models
from t5x import adafactor
from t5x import utils

ARCHITECTURE = %gin.REQUIRED

include "protoscribe/models/pmmx/configs/glyph_concepts/arch_p1_t5_1_1_flaxformer.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = %gin.REQUIRED
NUM_DECODER_LAYERS = %gin.REQUIRED
NUM_HEADS = %gin.REQUIRED
HEAD_DIM = %gin.REQUIRED
EMBED_DIM = %gin.REQUIRED
MLP_DIM = %gin.REQUIRED

# Optimizer
# `learning_rate` is set by `Trainer.learning_rate_fn`.
OPTIMIZER = @adafactor.Adafactor()
adafactor.Adafactor:
  decay_rate = 0.8
  step_offset = 0

# Loss defaults.
Z_LOSS = 0.0001
LABEL_SMOOTHING = 0.0
LOSS_NORMALIZING_FACTOR = None

# Model
MODEL = @models.MultimodalEncoderDecoderModel()
models.MultimodalEncoderDecoderModel:
  feature_converter_cls = @feature_converters.MultimodalEncDecFeatureConverterFactory()
  module = %ARCHITECTURE  # Provided by arch_p1_t5_1_1_flaxformer.gin.
  input_vocabulary = %inputs/PASSTHROUGH_VOCABULARY
  output_vocabulary = %outputs/PASSTHROUGH_VOCABULARY
  optimizer_def = %OPTIMIZER
  z_loss = %Z_LOSS
  label_smoothing = %LABEL_SMOOTHING
  loss_normalizing_factor = %LOSS_NORMALIZING_FACTOR

# Decoding
NUM_DECODES = 5
RETURN_ALL_DECODES = True
models.MultimodalEncoderDecoderModel.predict_batch_with_aux:
  num_decodes = %NUM_DECODES
  return_all_decodes = %RETURN_ALL_DECODES
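
# Each input is decoded NUM_DECODES times; with RETURN_ALL_DECODES enabled,
# all 5 hypotheses (and their scores) are returned rather than only the best.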

# Checkpoints
CHECKPOINT_PERIOD = 10_000
EVAL_PERIOD = %CHECKPOINT_PERIOD
utils.SaveCheckpointConfig:
  period = %CHECKPOINT_PERIOD
  keep = None  # Keep all checkpoints.
  save_dataset = False
27 changes: 27 additions & 0 deletions protoscribe/models/pmmx/configs/glyph_concepts/model_tiny.gin
@@ -0,0 +1,27 @@
# Copyright 2024 The Protoscribe Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Tiny configuration.

from __gin__ import dynamic_registration

include "protoscribe/models/pmmx/configs/glyph_concepts/model_common.gin"

# Architecture overrides.
NUM_ENCODER_LAYERS = 2
NUM_DECODER_LAYERS = 2
NUM_HEADS = 6
HEAD_DIM = 16
EMBED_DIM = 16
MLP_DIM = 16
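
# Attention width is NUM_HEADS * HEAD_DIM = 6 * 16 = 96 over a 16-dimensional
# embedding stream; all sizes are deliberately minimal.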
