Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hf export #107

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ transformers
datasets
timm
open_clip_torch
albumentations
albumentations==1.3.1
opencv-python==4.8.0.74
opencv-python-headless==4.8.0.74
git+https://github.com/facebookresearch/segment-anything.git
Expand Down
3 changes: 2 additions & 1 deletion examples/visualize_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,10 @@ def main(rank: int = 0, world_size: int = 1):
output_fmt='NLC',
intermediates_only=True,
aggregation=args.intermediate_aggregation,
norm_alpha_scheme="none",
)
assert args.adaptor_name is None
all_feat = [o[1] for o in outputs]
all_feat = outputs
else:
output = model(p_images)
if args.adaptor_name:
Expand Down
67 changes: 64 additions & 3 deletions hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ def main():
checkpoint = torch.load(args.checkpoint_path, map_location="cpu")
model_args = checkpoint["args"]

# Remove invalid identifier.
if hasattr(model_args, "enable_cudnn_attention"):
print(f'Removing attribute: enable-cudnn-attention!')
delattr(model_args, "enable-cudnn-attention")
if hasattr(model_args, "device"):
print(f'Removing attribute: device!')
delattr(model_args, "device")

# Extract the state dict from the checkpoint.
if "state_dict_ema" in checkpoint:
state_dict = checkpoint["state_dict_ema"]
Expand Down Expand Up @@ -161,11 +169,33 @@ def main():

adaptor_configs[adaptor_name] = adaptor_config


feat_norm_sd = get_prefix_state_dict(state_dict, '_feature_normalizer.')
feature_normalizer_config = None
if feat_norm_sd:
feature_normalizer_config = {
"embed_dim": feat_norm_sd['mean'].shape[0]
}

inter_feat_norm_sd = get_prefix_state_dict(state_dict, '_intermediate_feature_normalizer.')
inter_feature_normalizer_config = None
if inter_feat_norm_sd:
inter_feature_normalizer_config = {
"num_intermediates": inter_feat_norm_sd['means'].shape[0],
"embed_dim": inter_feat_norm_sd['means'].shape[1],
"rot_per_layer": inter_feat_norm_sd['rotation'].ndim == 3,
}

model_vars = vars(model_args)
model_vars.pop('enable-cudnn-attention', None)

radio_config = RADIOConfig(
vars(model_args),
model_vars,
version=args.version,
adaptor_names=adaptor_names,
adaptor_configs=adaptor_configs,
feature_normalizer_config=feature_normalizer_config,
inter_feature_normalizer_config=inter_feature_normalizer_config,
)
radio_model = RADIOModel(radio_config)

Expand Down Expand Up @@ -194,6 +224,12 @@ def main():
get_prefix_state_dict(state_dict, "input_conditioner.")
)

# Restore feature normalizer.
if feat_norm_sd:
radio_model.radio_model.feature_normalizer.load_state_dict(feat_norm_sd)
if inter_feat_norm_sd:
radio_model.radio_model.inter_feature_normalizer.load_state_dict(inter_feat_norm_sd)

radio_model.eval().cuda()

# Sample inference with deterministic values.
Expand All @@ -215,11 +251,27 @@ def main():
hf_summary, hf_features = v.summary, v.features

print(
f"[{k}] Sample inference on tensor shape {x.shape} returned summary ",
f"[{k}] HF inference on tensor shape {x.shape} returned summary ",
f"with shape={hf_summary.shape} and std={hf_summary.std().item():.3}, ",
f"features with shape={hf_features.shape} and std={hf_features.std().item():.3}",
)

intermediates = radio_model.radio_model.forward_intermediates(
x,
indices=[-1],
return_prefix_tokens=True,
norm=False,
stop_early=False,
output_fmt='NLC',
intermediates_only=True,
aggregation="sparse",
)
print(
f"Intermediates inference returned ",
f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}",
)
#assert torch.allclose(intermediates[0].features, hf_output["backbone"].features, atol=1e-4)

# Infer using TorchHub model.
print("Infer using TorchHub model...")
torchhub_model = torch.hub.load(
Expand All @@ -244,6 +296,12 @@ def main():
torchhub_output[k].features,
)

print(
f"[{k}] TorchHub inference on tensor shape {x.shape} returned summary ",
f"with shape={torchhub_summary.shape} and std={torchhub_summary.std().item():.3}, ",
f"features with shape={torchhub_features.shape} and std={torchhub_features.std().item():.3}",
)

# Make sure the shapes are the same.
assert (
hf_summary.shape == torchhub_summary.shape
Expand All @@ -262,6 +320,10 @@ def main():

print(f"{k} outputs matched!")



print("All outputs matched!")

if args.push:
# Push to HuggingFace Hub.
huggingface_repo = args.hf_repo
Expand All @@ -273,7 +335,6 @@ def main():
)
print(f"Pushed to {commit}")


if __name__ == "__main__":
"""Call the main entrypoiny."""
main()
4 changes: 4 additions & 0 deletions mmseg/radio.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
# Standard ViT case.
patch_height, patch_width = self.base_model.model.patch_embed.patch_size
features = features.reshape(B, math.ceil(H/patch_height), math.ceil(W/patch_width), C).permute(0, 3, 1, 2).contiguous()
else:
B, _, C = features.shape
patch_height = patch_width = 16
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Don't we have a robust property self.base_model.patch_size?

features = features.reshape(B, math.ceil(H/patch_height), math.ceil(W/patch_width), C).permute(0, 3, 1, 2).contiguous()

# IMPORTANT: prevent gradients from flowing back towards the backbone.
features = features.detach()
Expand Down
2 changes: 1 addition & 1 deletion radio/enable_cpe_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from timm.models import VisionTransformer, checkpoint_seq

from radio.feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer
from .feature_normalizer import IntermediateFeatureNormalizerBase, NullIntermediateFeatureNormalizer

from .extra_models import DinoWrapper
from .vit_patch_generator import ViTPatchGenerator
Expand Down
24 changes: 24 additions & 0 deletions radio/hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
from .adaptor_mlp import create_mlp_from_config
from .adaptor_registry import adaptor_registry
from .cls_token import ClsToken
from .dinov2_arch import dinov2_vitg14_reg
from .enable_cpe_support import enable_cpe
from .enable_spectral_reparam import configure_spectral_reparam_from_args
from .eradio_model import eradio
from .feature_normalizer import FeatureNormalizer, IntermediateFeatureNormalizer
from .forward_intermediates import forward_intermediates
from .radio_model import create_model_from_args
from .radio_model import RADIOModel as RADIOModelBase, Resolution
from .input_conditioner import get_default_conditioner, InputConditioner
Expand All @@ -40,6 +43,7 @@

# Register extra models
from .extra_timm_models import *
from .extra_models import *


class RADIOConfig(PretrainedConfig):
Expand All @@ -55,6 +59,8 @@ def __init__(
adaptor_names: Union[str, List[str]] = None,
adaptor_configs: Dict[str, Dict[str, int]] = None,
vitdet_window_size: Optional[int] = None,
feature_normalizer_config: Optional[dict] = None,
inter_feature_normalizer_config: Optional[dict] = None,
**kwargs,
):
self.args = args
Expand All @@ -74,9 +80,12 @@ def __init__(
self.adaptor_names = adaptor_names
self.adaptor_configs = adaptor_configs
self.vitdet_window_size = vitdet_window_size
self.feature_normalizer_config = feature_normalizer_config
self.inter_feature_normalizer_config = inter_feature_normalizer_config
super().__init__(**kwargs)



class RADIOModel(PreTrainedModel):
"""Pretrained Hugging Face model for RADIO.

Expand Down Expand Up @@ -118,6 +127,19 @@ def __init__(self, config: RADIOConfig):
adaptor.head_idx = mlp_config["head_idx"]
adaptors[adaptor_name] = adaptor

feature_normalizer = None
if config.feature_normalizer_config is not None:
# Actual normalization values will be restored when loading checkpoint weights.
feature_normalizer = FeatureNormalizer(config.feature_normalizer_config["embed_dim"])

inter_feature_normalizer = None
if config.inter_feature_normalizer_config is not None:
inter_feature_normalizer = IntermediateFeatureNormalizer(
config.inter_feature_normalizer_config["num_intermediates"],
config.inter_feature_normalizer_config["embed_dim"],
rot_per_layer=config.inter_feature_normalizer_config["rot_per_layer"],
dtype=dtype)

self.radio_model = RADIOModelBase(
model,
input_conditioner,
Expand All @@ -127,6 +149,8 @@ def __init__(self, config: RADIOConfig):
window_size=config.vitdet_window_size,
preferred_resolution=config.preferred_resolution,
adaptors=adaptors,
feature_normalizer=feature_normalizer,
inter_feature_normalizer=inter_feature_normalizer,
)

@property
Expand Down
27 changes: 24 additions & 3 deletions test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def main():
python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio_v2.1_bf16.pth.tar --torchhub-repo NVlabs/RADIO:dev/hf
python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio-v2.5-l_half.pth.tar --torchhub-repo NVlabs/RADIO:dev/hf
python3 -m test_hf --hf-repo gheinrich/RADIO --torchhub-version ./radio-v2.5-l_half.pth.tar --adaptor-names siglip,sam
python3 -m test_hf --hf-repo gheinrich/RADIO-NORM --torchhub-version /lustre/fs6/portfolios/llmservice/users/mranzinger/output/evfm/hero/n32_8-19-24_vit-h-16_hero-v4_s3/checkpoints/last_norm_release_half.pth.tar --torchhub-repo NVlabs/RADIO:mranzinger/ship_paper
"""
parser = argparse.ArgumentParser()
parser.add_argument("--hf-repo", help="Path to the HuggingFace repo", required=True)
Expand All @@ -53,6 +54,9 @@ def main():
parser.add_argument(
"--torchhub-repo", help="Path to the Torchhub repo", default="NVlabs/RADIO"
)
parser.add_argument(
"--hf-revision", help="HuggingFace revision to checkout", default="main"
)
parser.add_argument(
"--adaptor-names",
default=None,
Expand All @@ -63,13 +67,13 @@ def main():

args = parser.parse_args()

hf_config = AutoConfig.from_pretrained(args.hf_repo, trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(args.hf_repo, revision=args.hf_revision, trust_remote_code=True)
if args.adaptor_names is not None:
# Configure adaptors if specified on the command line.
# This needs to happen before we instantiate the model.
hf_config.adaptor_names = args.adaptor_names
hf_model = AutoModel.from_pretrained(
args.hf_repo, trust_remote_code=True, config=hf_config
args.hf_repo, revision=args.hf_revision, trust_remote_code=True, config=hf_config
)
hf_model.eval().cuda()

Expand Down Expand Up @@ -126,10 +130,27 @@ def main():
assert torch.allclose(hf_summary, torchhub_summary, atol=1e-6)
assert torch.allclose(hf_features, torchhub_features, atol=1e-6)

intermediates = hf_model.radio_model.forward_intermediates(
hf_model.input_conditioner(x),
indices=[-1],
return_prefix_tokens=True,
norm=False,
stop_early=False,
output_fmt='NLC',
intermediates_only=True,
aggregation="sparse",
)
print(
f"Intermediates inference returned summary ",
f"with shape={intermediates[0].summary.shape} and std={intermediates[0].summary.std().item():.3}, ",
f"features with shape={intermediates[0].features.shape} and std={intermediates[0].features.std().item():.3}",
)
#assert torch.allclose(intermediates[0].features, torchhub_output["backbone"].features, atol=1e-6)

print("All outputs matched!")

# Infer a sample image.
image_processor = CLIPImageProcessor.from_pretrained(args.hf_repo)
image_processor = CLIPImageProcessor.from_pretrained(args.hf_repo, revision=args.hf_revision)

image = Image.open("./examples/image1.png").convert("RGB")
pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
Expand Down