Why does exporting the .engine model fail with the "Plugin not found..." issue? #7

Open
mingj2021 opened this issue Jun 8, 2023 · 0 comments

mingj2021 (Owner) commented Jun 8, 2023

Tested on Windows.

Conda virtual environment:

Package              Version
-------------------- ------------
absl-py              1.4.0
certifi              2022.12.7
charset-normalizer   2.1.1
colorama             0.4.6
coloredlogs          15.0.1
contourpy            1.0.7
cycler               0.11.0
filelock             3.9.0
flatbuffers          23.3.3
fonttools            4.39.3
humanfriendly        10.0
idna                 3.4
importlib-resources  5.12.0
Jinja2               3.1.2
kiwisolver           1.4.4
MarkupSafe           2.1.2
matplotlib           3.7.1
mpmath               1.2.1
networkx             3.0
numpy                1.23.4
onnx                 1.10.0
onnxruntime          1.14.1
opencv-python        4.7.0.72
packaging            23.1
Pillow               9.3.0
pip                  23.0.1
prettytable          3.7.0
protobuf             3.20.3
pycocotools          2.0.6
pyparsing            3.0.9
pyreadline3          3.4.1
python-dateutil      2.8.2
pytorch-quantization 2.1.2
PyYAML               6.0
requests             2.28.1
scipy                1.10.1
setuptools           66.0.0
six                  1.16.0
sphinx-glpi-theme    0.3
sympy                1.11.1
tensorrt             8.5.1.7
torch                2.0.0+cu117
torchaudio           2.0.1+cu117
torchvision          0.15.1+cu117
tqdm                 4.65.0
typing_extensions    4.4.0
urllib3              1.26.13
wcwidth              0.2.6
wheel                0.38.4
zipp                 3.15.0

The Python script "demo.py":

import torch
import torch.nn as nn
from torch.nn import functional as F
from segment_anything.modeling import Sam
import numpy as np
from torchvision.transforms.functional import resize, to_pil_image
from typing import Tuple
from segment_anything import sam_model_registry, SamPredictor
import cv2
import matplotlib.pyplot as plt
import warnings
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import onnxruntime
from tqdm import tqdm  # used in collect_stats() below

from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))  


def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
    """
    Compute the output size given input size and target long side length.
    """
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)
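
# Worked example: for a 1200x1800 input with long_side_length=1024,
# scale = 1024/1800, so get_preprocess_shape(1200, 1800, 1024) == (683, 1024).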

# @torch.no_grad()
def pre_processing(image: np.ndarray, target_length: int, device, pixel_mean, pixel_std, img_size):
    target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length)
    input_image = np.array(resize(to_pil_image(image), target_size))
    input_image_torch = torch.as_tensor(input_image, device=device)
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    # Normalize colors
    input_image_torch = (input_image_torch - pixel_mean) / pixel_std

    # Pad
    h, w = input_image_torch.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh))
    return input_image_torch
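
# pre_processing returns a (1, 3, img_size, img_size) float tensor: the long
# side is resized to target_length, colors are normalized, and the bottom/right
# are zero-padded up to img_size.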

def export_embedding_model():
    sam_checkpoint = "weights/sam_vit_l_0b3195.pth"
    model_type = "vit_l"

    device = "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    # image_encoder = EmbeddingOnnxModel(sam)
    image = cv2.imread('notebooks/images/truck.jpg')
    target_length = sam.image_encoder.img_size
    pixel_mean = sam.pixel_mean 
    pixel_std = sam.pixel_std
    img_size = sam.image_encoder.img_size
    inputs = pre_processing(image, target_length, device, pixel_mean, pixel_std, img_size)
    onnx_model_path = model_type + "_" + "embedding.onnx"
    dummy_inputs = {
        "images": inputs,
    }
    output_names = ["image_embeddings"]
    image_embeddings = sam.image_encoder(inputs).cpu().numpy()
    print('image_embeddings', image_embeddings.shape)
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                sam.image_encoder,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=13,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                # dynamic_axes=dynamic_axes,
            )    
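
# Optional sanity check (a sketch; the `onnx` package is installed per the list
# above but never imported in this script):
#   import onnx
#   onnx.checker.check_model(onnx.load(onnx_model_path))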

def export_sam_model():
    from segment_anything.utils.onnx import SamOnnxModel
    checkpoint = "weights/sam_vit_l_0b3195.pth"
    model_type = "vit_l"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    onnx_model_path = "sam_onnx_example.onnx"

    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        # "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
    # output_names = ["masks", "iou_predictions", "low_res_masks"]
    output_names = ["masks", "scores"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=13,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )  

def onnx_model_example():
    ort_session_embedding = onnxruntime.InferenceSession('vit_l_embedding.onnx', providers=['CPUExecutionProvider'])
    ort_session_sam = onnxruntime.InferenceSession('sam_onnx_example.onnx', providers=['CPUExecutionProvider'])

    image = cv2.imread('notebooks/images/truck.jpg')
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image2 = image.copy()
    prompt_embed_dim = 256
    image_size = 1024
    vit_patch_size = 16
    image_embedding_size = image_size // vit_patch_size
    target_length = image_size
    pixel_mean = [123.675, 116.28, 103.53]
    pixel_std = [58.395, 57.12, 57.375]
    pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
    pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
    device = "cpu"
    inputs = pre_processing(image, target_length, device, pixel_mean, pixel_std, image_size)
    ort_inputs = {
        "images": inputs.cpu().numpy()
    }
    image_embeddings = ort_session_embedding.run(None, ort_inputs)[0]
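    # The heavy embedding network runs once per image; the decoder session
    # below can then be re-run cheaply for each new prompt.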

    input_point = np.array([[500, 375]])
    input_label = np.array([1])

    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)
    from segment_anything.utils.transforms import ResizeLongestSide
    transf = ResizeLongestSide(image_size)
    onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": image_embeddings,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
        # "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
    }

    masks, _ = ort_session_sam.run(None, ort_inputs)

    from segment_anything.utils.onnx import SamOnnxModel
    checkpoint = "weights/sam_vit_l_0b3195.pth"
    model_type = "vit_l"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    onnx_model_path = "sam_onnx_example.onnx"

    onnx_model = SamOnnxModel(sam, return_single_mask=True)
    masks = onnx_model.mask_postprocessing(torch.as_tensor(masks), torch.as_tensor(image.shape[:2]))
    masks = masks > 0.0
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    # show_box(input_box, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.savefig('demo.png')

def export_engine_image_encoder(f='vit_l_embedding.onnx'):
    import tensorrt as trt
    from pathlib import Path
    file = Path(f)
    f = file.with_suffix('.engine')  # TensorRT engine file
    onnx = file.with_suffix('.onnx')
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    workspace = 6
    print("workspace: ", workspace)
    config.max_workspace_size = workspace * 1 << 30
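    # `*` binds tighter than `<<`, so this is (workspace * 1) << 30 bytes, i.e. 6 GiB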
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        print(f'output "{out.name}" with shape{out.shape} {out.dtype}')

    half = True
    print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())
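
# Note (assumption, not from the original post): with the tensorrt 8.5.1.7
# listed above, builder.build_engine() and config.max_workspace_size are
# deprecated; the newer equivalents are
#   serialized = builder.build_serialized_network(network, config)
#   config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30)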


def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx'):
    import tensorrt as trt
    from pathlib import Path
    file = Path(f)
    f = file.with_suffix('.engine')  # TensorRT engine file
    onnx = file.with_suffix('.onnx')
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    workspace = 10
    print("workspace: ", workspace)
    config.max_workspace_size = workspace * 1 << 30
    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        print(f'output "{out.name}" with shape{out.shape} {out.dtype}')

    profile = builder.create_optimization_profile()
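    # set_shape(name, min, opt, max): only point_coords / point_labels vary,
    # matching the dynamic "num_points" axis declared in export_sam_model()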
    profile.set_shape('image_embeddings', (1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64))
    profile.set_shape('point_coords', (1, 2, 2), (1, 5, 2), (1, 10, 2))
    profile.set_shape('point_labels', (1, 2), (1, 5), (1, 10))
    profile.set_shape('mask_input', (1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256))
    profile.set_shape('has_mask_input', (1,), (1,), (1,))
    # # profile.set_shape_input('orig_im_size', (416,416), (1024,1024), (1500, 2250))
    # profile.set_shape_input('orig_im_size', (2,), (2,), (2, ))
    config.add_optimization_profile(profile)

    half = True
    print(f'building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine as {f}')
    if builder.platform_has_fast_fp16 and half:
        config.set_flag(trt.BuilderFlag.FP16)
    with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
        t.write(engine.serialize())


def collect_stats(model, device, data_loader, num_batches):
    """Feed data to the network and collect statistic"""

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    for i, (path, im, im0s, vid_cap, s) in tqdm(enumerate(data_loader), total=num_batches):
        im = torch.from_numpy(im).to(device)
        im = im.float()
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        model(im)
        if i >= num_batches:
            break

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()

def compute_amax(model, **kwargs):
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)
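
# Typical PTQ flow with pytorch-quantization (a sketch under the assumption of
# a YOLO-style data_loader as unpacked in collect_stats(); not exercised here):
#   quant_modules.initialize()           # call before constructing the model
#   model = ...                          # build the model to be quantized
#   with torch.no_grad():
#       collect_stats(model, device, data_loader, num_batches=64)
#       compute_amax(model, method="percentile", percentile=99.99)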


def Post_training_quantization():
    ...
    

if __name__ == '__main__':
    with torch.no_grad():
        # export_embedding_model()
        # export_sam_model()
        # onnx_model_example()
        export_engine_image_encoder()
        # export_engine_prompt_encoder_and_mask_decoder()
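
A hedged note on the "Plugin not found..." error mentioned in the title: one
possible cause (an assumption, not verified against this model) is that the
ONNX parser maps an op to a TensorRT plugin while the plugin registry was never
initialized. Registering the built-in plugins before creating the parser is a
minimal thing to try:

import tensorrt as trt

logger = trt.Logger(trt.Logger.INFO)
# Register TensorRT's built-in plugins (empty namespace) so the ONNX parser
# can resolve plugin-backed ops during parse_from_file().
trt.init_libnvinfer_plugins(logger, "")
# ...then proceed exactly as in export_engine_image_encoder().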