Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding a pretrained tf 2.0 BigGAN #15

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions pytorch_pretrained_biggan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@
import copy
import json

PRETRAINED_CONFIG_ARCHIVE_MAP = {
'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}

CONFIG_NAME = 'config.json'

class BigGANConfig(object):
""" Configuration class to store the configuration of a `BigGAN`.
Defaults are for the 128x128 model.
Expand Down Expand Up @@ -42,6 +50,28 @@ def __init__(self,
self.eps = eps
self.n_stats = n_stats

def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
if pretrained_model_name_or_path in PRETRAINED_CONFIG_ARCHIVE_MAP:
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

try:
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error("Wrong config name, should be a valid path to a folder containing "
"a {} file or a config name in {}".format(
CONFIG_NAME, PRETRAINED_CONFIG_ARCHIVE_MAP.keys()))
raise

logger.info("loading config {} from cache at {}".format(pretrained_model_name_or_path, resolved_config_file))

# Load config
config = BigGANConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))

return config

@classmethod
def from_dict(cls, json_object):
"""Constructs a `BigGANConfig` from a Python dictionary of parameters."""
Expand Down
20 changes: 4 additions & 16 deletions pytorch_pretrained_biggan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,7 @@
'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}

WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'


def snconv2d(eps=1e-12, **kwargs):
Expand Down Expand Up @@ -252,28 +245,23 @@ class BigGAN(nn.Module):

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
config = BigGANConfig.from_pretrained(pretrained_model_name_or_path, cache_dir=None)

if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

try:
resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
except EnvironmentError:
logger.error("Wrong model name, should be a valid path to a folder containing "
"a {} file and a {} file or a model name in {}".format(
WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
"a {} file or a model name in {}".format(
WEIGHTS_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
raise

logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

# Load config
config = BigGANConfig.from_json_file(resolved_config_file)
logger.info("Model config {}".format(config))

# Instantiate model.
model = cls(config, *inputs, **kwargs)
state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
Expand Down
Loading