Skip to content

Commit

Permalink
Small bug fixes for running models without tokenizers (#168)
Browse files · Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Aug 28, 2024
1 parent 2a6fff4 commit c1f953f
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
4 changes: 2 additions & 2 deletions diffusion/callbacks/log_diffusion_images.py
Original file line number Diff line number Diff line change
Expand Up · @@ -98,7 +98,7 @@ def __init__(self,
latent_batch = {}
tokenized_t5 = t5_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model.max_length,
max_length=t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
Expand All · @@ -108,7 +108,7 @@ def __init__(self,

tokenized_clip = clip_tokenizer(batch,
padding='max_length',
max_length=t5_tokenizer.model.max_length,
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_attention_mask = tokenized_clip['attention_mask'].cuda()
Expand Down
9 changes: 6 additions & 3 deletions diffusion/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,20 @@ def train(config: DictConfig) -> None:

model: ComposerModel = hydra.utils.instantiate(config.model)

# If the model has a tokenizer, we'll need it for the dataset
if hasattr(model, 'tokenizer'):
tokenizer = model.tokenizer
else:
tokenizer = None

if hasattr(model, 'autoencoder_loss'):
# Check if this is training an autoencoder. If so, the optimizer needs different param groups
optimizer = make_autoencoder_optimizer(config, model)
tokenizer = None
elif isinstance(model, ComposerTextToImageMMDiT):
# Check if this is training a transformer. If so, the optimizer needs different param groups
optimizer = make_transformer_optimizer(config, model)
tokenizer = model.tokenizer
else:
optimizer = hydra.utils.instantiate(config.optimizer, params=model.parameters())
tokenizer = model.tokenizer

# Load train dataset. Currently this expects to load according to the datasetHparam method.
# This means adding external datasets is currently not super easy. Will refactor or check for
Expand Down

0 comments on commit c1f953f

Please sign in to comment.