We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
# Tested on Windows
Package Version -------------------- ------------ absl-py 1.4.0 certifi 2022.12.7 charset-normalizer 2.1.1 colorama 0.4.6 coloredlogs 15.0.1 contourpy 1.0.7 cycler 0.11.0 filelock 3.9.0 flatbuffers 23.3.3 fonttools 4.39.3 humanfriendly 10.0 idna 3.4 importlib-resources 5.12.0 Jinja2 3.1.2 kiwisolver 1.4.4 MarkupSafe 2.1.2 matplotlib 3.7.1 mpmath 1.2.1 networkx 3.0 numpy 1.23.4 onnx 1.10.0 onnxruntime 1.14.1 opencv-python packaging 23.1 Pillow 9.3.0 pip 23.0.1 prettytable 3.7.0 protobuf 3.20.3 pycocotools 2.0.6 pyparsing 3.0.9 pyreadline3 3.4.1 python-dateutil 2.8.2 pytorch-quantization 2.1.2 PyYAML 6.0 requests 2.28.1 scipy 1.10.1 setuptools 66.0.0 six 1.16.0 sphinx-glpi-theme 0.3 sympy 1.11.1 tensorrt torch 2.0.0+cu117 torchaudio 2.0.1+cu117 torchvision 0.15.1+cu117 tqdm 4.65.0 typing_extensions 4.4.0 urllib3 1.26.13 wcwidth 0.2.6 wheel 0.38.4 zipp 3.15.0
import torch
import torch.nn as nn
from torch.nn import functional as F
from segment_anything.modeling import Sam
import numpy as np
from torchvision.transforms.functional import resize, to_pil_image
from typing import Tuple
from segment_anything import sam_model_registry, SamPredictor
import cv2
import matplotlib.pyplot as plt
import warnings
import os
from itertools import islice

# Set before onnxruntime / pytorch_quantization are imported, as in the
# original script ordering.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import onnxruntime
from pytorch_quantization import quant_modules
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import calib


def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on the matplotlib axes ``ax`` as an RGBA image.

    The last two dimensions of ``mask`` are taken as (height, width). The
    overlay colour is a fixed blue with alpha 0.6, or a random RGB colour
    with alpha 0.6 when ``random_color`` is true.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red."""
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*',
               s=marker_size, edgecolor='white', linewidth=1.25)


def show_box(box, ax):
    """Draw ``box`` given as (x0, y0, x1, y1) on ``ax`` as a green outline."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green',
                               facecolor=(0, 0, 0, 0), lw=2))


def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
    """
    Compute the output size given input size and target long side length.
    """
    scale = long_side_length * 1.0 / max(oldh, oldw)
    newh, neww = oldh * scale, oldw * scale
    # Round half up to the nearest integer pixel count.
    neww = int(neww + 0.5)
    newh = int(newh + 0.5)
    return (newh, neww)


def _read_rgb_image(path):
    """Read ``path`` with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
            returns ``None`` instead of raising).
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f'could not read image: {path}')
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def pre_processing(image: np.ndarray, target_length: int, device, pixel_mean, pixel_std, img_size):
    """Resize, normalise and pad ``image`` into a 1xCxHxW tensor.

    Args:
        image: channel-last image array (it is permuted with ``(2, 0, 1)``).
        target_length: length the longest image side is resized to.
        device: device the tensor is created on.
        pixel_mean: per-channel mean, broadcastable against CxHxW.
        pixel_std: per-channel std, broadcastable against CxHxW.
        img_size: side length the result is zero-padded to (right/bottom).

    Returns:
        The normalised, padded tensor with a leading batch dimension of 1.
    """
    target_size = get_preprocess_shape(image.shape[0], image.shape[1], target_length)
    input_image = np.array(resize(to_pil_image(image), target_size))
    input_image_torch = torch.as_tensor(input_image, device=device)
    input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]

    # Normalize colors
    input_image_torch = (input_image_torch - pixel_mean) / pixel_std

    # Pad on the right and bottom only.
    h, w = input_image_torch.shape[-2:]
    padh = img_size - h
    padw = img_size - w
    input_image_torch = F.pad(input_image_torch, (0, padw, 0, padh))
    return input_image_torch


def export_embedding_model():
    """Export the SAM ViT-L image encoder to ``vit_l_embedding.onnx``."""
    sam_checkpoint = "weights/sam_vit_l_0b3195.pth"
    model_type = "vit_l"
    device = "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)

    # Use RGB here as well so the sample input matches onnx_model_example().
    image = _read_rgb_image('notebooks/images/truck.jpg')
    target_length = sam.image_encoder.img_size
    pixel_mean = sam.pixel_mean
    pixel_std = sam.pixel_std
    img_size = sam.image_encoder.img_size
    inputs = pre_processing(image, target_length, device, pixel_mean, pixel_std, img_size)

    onnx_model_path = model_type + "_" + "embedding.onnx"
    dummy_inputs = {
        "images": inputs
    }
    output_names = ["image_embeddings"]

    # Run the encoder once before exporting and report the embedding shape.
    image_embeddings = sam.image_encoder(inputs).cpu().numpy()
    print('image_embeddings', image_embeddings.shape)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                sam.image_encoder,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=13,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                # No dynamic_axes: the encoder is exported with a fixed input shape.
            )


def export_sam_model():
    """Export the prompt encoder + mask decoder to ``sam_onnx_example.onnx``.

    Only the number of prompt points (axis 1 of ``point_coords`` and
    ``point_labels``) is exported as a dynamic axis.
    """
    from segment_anything.utils.onnx import SamOnnxModel

    checkpoint = "weights/sam_vit_l_0b3195.pth"
    model_type = "vit_l"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)

    onnx_model_path = "sam_onnx_example.onnx"
    onnx_model = SamOnnxModel(sam, return_single_mask=True)

    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }

    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    # NOTE: "orig_im_size" is intentionally not an input of this export.
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
    }
    output_names = ["masks", "scores"]

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(onnx_model_path, "wb") as f:
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=13,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )


def onnx_model_example():
    """Run both exported ONNX models on the sample image and save ``demo.png``."""
    from segment_anything.utils.onnx import SamOnnxModel
    from segment_anything.utils.transforms import ResizeLongestSide

    ort_session_embedding = onnxruntime.InferenceSession(
        'vit_l_embedding.onnx', providers=['CPUExecutionProvider'])
    ort_session_sam = onnxruntime.InferenceSession(
        'sam_onnx_example.onnx', providers=['CPUExecutionProvider'])

    image = _read_rgb_image('notebooks/images/truck.jpg')

    image_size = 1024
    target_length = image_size
    # The original had a stray trailing comma after the mean list, which made
    # it a 1-tuple wrapping the list; it only worked because of the .view().
    pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
    pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
    device = "cpu"
    inputs = pre_processing(image, target_length, device, pixel_mean, pixel_std, image_size)

    ort_inputs = {
        "images": inputs.cpu().numpy()
    }
    image_embeddings = ort_session_embedding.run(None, ort_inputs)[0]

    input_point = np.array([[500, 375]])
    input_label = np.array([1])

    # Append a (0, 0) point with label -1 to the prompt.
    onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

    transf = ResizeLongestSide(image_size)
    onnx_coord = transf.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

    onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    onnx_has_mask_input = np.zeros(1, dtype=np.float32)

    ort_inputs = {
        "image_embeddings": image_embeddings,
        "point_coords": onnx_coord,
        "point_labels": onnx_label,
        "mask_input": onnx_mask_input,
        "has_mask_input": onnx_has_mask_input,
    }
    masks, _ = ort_session_sam.run(None, ort_inputs)

    # The exported graph has no "orig_im_size" input, so the masks are
    # post-processed here with the PyTorch implementation.
    checkpoint = "weights/sam_vit_l_0b3195.pth"
    model_type = "vit_l"
    sam = sam_model_registry[model_type](checkpoint=checkpoint)
    onnx_model = SamOnnxModel(sam, return_single_mask=True)
    masks = onnx_model.mask_postprocessing(torch.as_tensor(masks),
                                           torch.as_tensor(image.shape[:2]))
    # show_mask() multiplies by a NumPy colour array, so hand it an ndarray.
    masks = (masks > 0.0).cpu().numpy()

    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(masks, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.axis('off')
    plt.savefig('demo.png')


def _build_trt_engine(onnx_file, workspace_gb, dynamic_shapes=None, half=True):
    """Parse an ONNX file and write a serialized TensorRT engine next to it.

    Args:
        onnx_file: path of the ONNX model; the engine is written to the same
            path with the suffix replaced by ``.engine``.
        workspace_gb: builder workspace limit in GiB.
        dynamic_shapes: optional mapping ``input name -> (min, opt, max)``
            shapes, added as a single optimization profile.
        half: build an FP16 engine when the platform reports fast FP16.

    Raises:
        RuntimeError: if the ONNX file cannot be parsed or the build fails.
    """
    import tensorrt as trt
    from pathlib import Path

    file = Path(onnx_file)
    engine_path = file.with_suffix('.engine')  # TensorRT engine file
    onnx = file.with_suffix('.onnx')

    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    config = builder.create_builder_config()
    print("workspace: ", workspace_gb)
    workspace_bytes = workspace_gb << 30
    # max_workspace_size is deprecated in newer TensorRT releases in favour
    # of set_memory_pool_limit; support whichever one is available.
    if hasattr(config, 'set_memory_pool_limit'):
        config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace_bytes)
    else:
        config.max_workspace_size = workspace_bytes

    flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    network = builder.create_network(flag)
    parser = trt.OnnxParser(network, logger)
    if not parser.parse_from_file(str(onnx)):
        raise RuntimeError(f'failed to load ONNX file: {onnx}')

    inputs = [network.get_input(i) for i in range(network.num_inputs)]
    outputs = [network.get_output(i) for i in range(network.num_outputs)]
    for inp in inputs:
        print(f'input "{inp.name}" with shape{inp.shape} {inp.dtype}')
    for out in outputs:
        print(f'output "{out.name}" with shape{out.shape} {out.dtype}')

    if dynamic_shapes:
        profile = builder.create_optimization_profile()
        for name, (min_shape, opt_shape, max_shape) in dynamic_shapes.items():
            profile.set_shape(name, min_shape, opt_shape, max_shape)
        config.add_optimization_profile(profile)

    use_fp16 = builder.platform_has_fast_fp16 and half
    print(f'building FP{16 if use_fp16 else 32} engine as {engine_path}')
    if use_fp16:
        config.set_flag(trt.BuilderFlag.FP16)

    # Prefer build_serialized_network where available and fall back to
    # build_engine; both paths report a failed build (None) explicitly.
    if hasattr(builder, 'build_serialized_network'):
        serialized = builder.build_serialized_network(network, config)
    else:
        engine = builder.build_engine(network, config)
        serialized = engine.serialize() if engine is not None else None
    if serialized is None:
        raise RuntimeError(f'failed to build TensorRT engine from {onnx}')

    with open(engine_path, 'wb') as t:
        t.write(serialized)


def export_engine_image_encoder(f='vit_l_embedding.onnx'):
    """Build a TensorRT engine for the image-encoder ONNX model ``f``."""
    _build_trt_engine(f, workspace_gb=6)


def export_engine_prompt_encoder_and_mask_decoder(f='sam_onnx_example.onnx'):
    """Build a TensorRT engine for the prompt-encoder/mask-decoder model ``f``.

    The optimization profile allows 2 to 10 prompt points (optimum 5); all
    other inputs have fixed shapes. "orig_im_size" is not an input here.
    """
    dynamic_shapes = {
        'image_embeddings': ((1, 256, 64, 64), (1, 256, 64, 64), (1, 256, 64, 64)),
        'point_coords': ((1, 2, 2), (1, 5, 2), (1, 10, 2)),
        'point_labels': ((1, 2), (1, 5), (1, 10)),
        'mask_input': ((1, 1, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256)),
        'has_mask_input': ((1,), (1,), (1,)),
    }
    _build_trt_engine(f, workspace_gb=10, dynamic_shapes=dynamic_shapes)


def collect_stats(model, device, data_loader, num_batches):
    """Feed data to the network and collect statistic

    Exactly ``num_batches`` batches are consumed from ``data_loader`` (the
    original loop checked the limit after the forward pass and so ran one
    extra batch). Each item is unpacked as a 5-tuple whose second element is
    a NumPy image array with values in 0-255.
    """
    # tqdm was used without being imported; treat it as optional so the
    # calibration still runs when it is not installed.
    try:
        from tqdm import tqdm
    except ImportError:
        tqdm = None

    # Enable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.disable_quant()
                module.enable_calib()
            else:
                module.disable()

    batches = islice(data_loader, num_batches)
    if tqdm is not None:
        batches = tqdm(batches, total=num_batches)
    for path, im, im0s, vid_cap, s in batches:
        im = torch.from_numpy(im).to(device)
        im = im.float()
        im /= 255  # 0 - 255 to 0.0 - 1.0
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        model(im)

    # Disable calibrators
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                module.enable_quant()
                module.disable_calib()
            else:
                module.enable()


def compute_amax(model, **kwargs):
    """Load the collected calibration amax into every calibrated quantizer.

    ``kwargs`` are forwarded to ``load_calib_amax`` for every calibrator that
    is not a ``MaxCalibrator``.
    """
    # Load calib result
    for name, module in model.named_modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            if module._calibrator is not None:
                if isinstance(module._calibrator, calib.MaxCalibrator):
                    module.load_calib_amax()
                else:
                    module.load_calib_amax(**kwargs)


def Post_training_quantization():
    """Placeholder: post-training quantization is not implemented yet."""
    ...


if __name__ == '__main__':
    with torch.no_grad():
        # Pipeline steps; enable the one you need:
        # export_embedding_model()
        # export_sam_model()
        # onnx_model_example()
        export_engine_image_encoder()
        # export_engine_prompt_encoder_and_mask_decoder()
The text was updated successfully, but these errors were encountered:
No branches or pull requests
# Tested on Windows
conda virtual environment
The Python script "demo.py"
The text was updated successfully, but these errors were encountered: