| | import os |
| | import argparse |
| | import random |
| | import onnxruntime |
| | import numpy as np |
| |
|
| | import torch |
| | from torch.nn import functional as F |
| | from torch.utils import data |
| |
|
| | import cv2 |
| | from PIL import Image |
| | from tqdm import tqdm |
| |
|
| | from utils import input_transform, pad_image, resize_image, preprocess, get_confusion_matrix |
| |
|
# Command-line interface of the evaluation script.
parser = argparse.ArgumentParser(description='HRNet')
parser.add_argument(
    '-m', '--onnx-model',
    type=str,
    default='',
    help='Path to onnx model.',
)
parser.add_argument(
    '-r', '--root',
    type=str,
    default='',
    help='Path to dataset root.',
)
parser.add_argument(
    '-l', '--list_path',
    type=str,
    default='',
    help='Path to dataset list.',
)
parser.add_argument(
    "--ipu",
    action="store_true",
    help="Use IPU for inference.",
)
parser.add_argument(
    "--provider_config",
    type=str,
    default="vaip_config.json",
    help="Path of the config file for seting provider_options.",
)
# Parsed at module level: the __main__ section reads this global directly.
args = parser.parse_args()

# Indexed as [height, width] by run_onnx_inference when scaling the padding.
INPUT_SIZE = [512, 1024]
# Side length of the confusion matrix built in testval.
NUM_CLASSES = 19
# Passed to the dataset and to get_confusion_matrix as the ignore label.
IGNORE_LABEL = 255
| |
|
| |
|
class Cityscapes(data.Dataset):
    """Image/label pairs of a Cityscapes split described by a list file.

    Every non-empty line of the list file holds two whitespace-separated
    paths (image first, label second) that are resolved against ``root``.
    Raw label ids are remapped through ``label_mapping`` to the contiguous
    ids 0-18; every other id in the table becomes ``ignore_label``.

    Args:
        root: Dataset root. It is prefixed verbatim to ``list_path`` and
            joined with each listed image/label path.
        list_path: Path of the list file, appended to ``root``.
        num_classes: Number of classes; stored on the instance only.
        downsample_rate: Factor by which ``gen_sample`` shrinks the label
            map; ``1`` leaves the label size untouched.
        ignore_label: Value written for raw ids that have no class.
    """

    def __init__(self,
                 root,
                 list_path,
                 num_classes=19,
                 downsample_rate=8,
                 ignore_label=-1):
        self.root = root
        self.list_path = list_path
        self.num_classes = num_classes
        self.downsample_rate = downsample_rate

        # ``root + list_path`` (not os.path.join) is kept on purpose so that
        # existing invocations keep resolving to the same file.  The context
        # manager closes the handle, and blank lines are dropped so a
        # trailing newline cannot break read_files().
        with open(root + list_path) as list_file:
            entries = (line.split() for line in list_file)
            self.img_list = [entry for entry in entries if entry]

        self.files = self.read_files()

        self.label_mapping = {-1: ignore_label, 0: ignore_label,
                              1: ignore_label, 2: ignore_label,
                              3: ignore_label, 4: ignore_label,
                              5: ignore_label, 6: ignore_label,
                              7: 0, 8: 1, 9: ignore_label,
                              10: ignore_label, 11: 2, 12: 3,
                              13: 4, 14: ignore_label, 15: ignore_label,
                              16: ignore_label, 17: 5, 18: ignore_label,
                              19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
                              25: 12, 26: 13, 27: 14, 28: 15,
                              29: ignore_label, 30: ignore_label,
                              31: 16, 32: 17, 33: 18}

    def read_files(self):
        """Turn ``img_list`` into dicts with ``img``, ``label`` and ``name``.

        ``name`` is the label file's base name without its extension.

        Raises:
            ValueError: If an entry does not consist of exactly two paths.
        """
        files = []
        for entry in self.img_list:
            if len(entry) != 2:
                raise ValueError(
                    "expected '<image> <label>' per list line, got: {!r}"
                    .format(" ".join(entry)))
            image_path, label_path = entry
            name = os.path.splitext(os.path.basename(label_path))[0]
            files.append({
                "img": image_path,
                "label": label_path,
                "name": name,
            })
        return files

    def __len__(self):
        return len(self.files)

    def convert_label(self, label, inverse=False):
        """Remap ids in ``label`` through ``label_mapping``.

        The array is modified in place and also returned.  Comparisons are
        made against an untouched copy so that an already remapped pixel is
        never remapped a second time.

        Args:
            label: Integer array of label ids.
            inverse: If true, map the values of ``label_mapping`` back to
                its keys instead (for values shared by several keys the
                last key in the table wins).
        """
        temp = label.copy()
        if inverse:
            for v, k in self.label_mapping.items():
                label[temp == k] = v
        else:
            for k, v in self.label_mapping.items():
                label[temp == k] = v
        return label

    def _imread(self, rel_path, flags):
        """Read ``rel_path`` below ``root`` with cv2, failing loudly.

        Raises:
            FileNotFoundError: If cv2 cannot read the file (cv2.imread
                signals this by returning ``None`` instead of raising).
        """
        full_path = os.path.join(self.root, rel_path)
        array = cv2.imread(full_path, flags)
        if array is None:
            raise FileNotFoundError(
                "cv2.imread could not read '{}'".format(full_path))
        return array

    def __getitem__(self, index):
        """Return ``(image, label)`` for sample ``index`` as fresh arrays."""
        item = self.files[index]
        image = self._imread(item["img"], cv2.IMREAD_COLOR)
        label = self._imread(item["label"], cv2.IMREAD_GRAYSCALE)
        # Widen before remapping: the grayscale read yields 8-bit data,
        # which cannot hold a negative ignore_label such as the default -1.
        label = self.convert_label(label.astype(np.int32))
        image, label = self.gen_sample(image, label)

        return image.copy(), label.copy()

    def gen_sample(self, image, label):
        """Cast the label to int32 and shrink it by ``downsample_rate``.

        The image is returned unchanged.
        """
        label = self.label_transform(label)

        if self.downsample_rate != 1:
            # cv2.resize takes scale factors, so shrinking by
            # ``downsample_rate`` needs its reciprocal.  Nearest-neighbour
            # keeps the result a valid label map (no blended ids).
            scale = 1.0 / self.downsample_rate
            label = cv2.resize(
                label,
                None,
                fx=scale,
                fy=scale,
                interpolation=cv2.INTER_NEAREST
            )

        return image, label

    def label_transform(self, label):
        """Return ``label`` as an int32 numpy array."""
        return np.array(label).astype('int32')
| |
|
| |
|
def run_onnx_inference(ort_session, img):
    """Infer one image with an ONNX session and crop away the padding.

    Args:
        ort_session: ONNX Runtime session; its first input is fed.
        img (ndarray): Image to be inferred; handed to ``preprocess`` as is.

    Returns:
        ndarray: First model output in channels-last layout
        ``(N, H, W, C)``, with the rows/columns that correspond to the
        padding removed.
    """
    # NOTE(review): assumes preprocess returns a channels-first 3-D array
    # plus the padding (in input pixels) added at the bottom/right --
    # confirm against utils.preprocess.
    pre_img, pad_h, pad_w = preprocess(img)

    # Add the batch axis, then move channels last for the model input.
    batch = np.transpose(np.expand_dims(pre_img, 0), (0, 2, 3, 1))

    input_name = ort_session.get_inputs()[0].name
    logits = ort_session.run(None, {input_name: batch})[0]

    # The output is channels-last (testval converts it with
    # transpose(0, 3, 1, 2)), so height and width are axes 1 and 2, not the
    # trailing two.  Padding is rescaled from input to output resolution.
    h, w = logits.shape[1:3]
    h_cut = int(h / INPUT_SIZE[0] * pad_h)
    w_cut = int(w / INPUT_SIZE[1] * pad_w)
    return logits[:, :h - h_cut, :w - w_cut, :]
| |
|
| |
|
def testval(ort_session, root, list_path):
    """Evaluate an ONNX segmentation model on a Cityscapes list.

    Each sample is run through ``run_onnx_inference``; the prediction is
    resized to the label resolution when needed and accumulated into a
    ``NUM_CLASSES x NUM_CLASSES`` confusion matrix.

    Args:
        ort_session: ONNX Runtime session used for inference.
        root: Dataset root passed to ``Cityscapes``.
        list_path: Dataset list file passed to ``Cityscapes``.

    Returns:
        tuple: ``(mean_IoU, IoU_array, pixel_acc, mean_acc)`` where
        ``IoU_array`` holds one IoU value per class.
    """
    test_dataset = Cityscapes(
        root=root,
        list_path=list_path,
        num_classes=NUM_CLASSES,
        ignore_label=IGNORE_LABEL,
        downsample_rate=1)

    testloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
        pin_memory=True)

    confusion_matrix = np.zeros((NUM_CLASSES, NUM_CLASSES))
    for image, label in tqdm(testloader):
        size = label.size()
        # batch_size is 1: drop the batch axis before inference.
        out = run_onnx_inference(ort_session, image.numpy()[0])

        # Channels-last -> channels-first.  ``pred`` is bound on every
        # iteration so a model that already emits label-sized output is
        # scored on its own result, not on a missing or stale variable.
        pred = torch.from_numpy(out.transpose(0, 3, 1, 2))
        if pred.shape[2] != size[1] or pred.shape[3] != size[2]:
            pred = F.interpolate(
                pred, size=size[1:],
                mode='bilinear'
            )

        confusion_matrix += get_confusion_matrix(
            label,
            pred,
            size,
            NUM_CLASSES,
            IGNORE_LABEL)

    # Row sums, column sums and the diagonal (correctly classified pixels).
    pos = confusion_matrix.sum(1)
    res = confusion_matrix.sum(0)
    tp = np.diag(confusion_matrix)
    # Clamp denominators to 1 so empty classes (or an empty dataset) give 0
    # instead of a division by zero.
    pixel_acc = tp.sum() / np.maximum(1.0, pos.sum())
    mean_acc = (tp / np.maximum(1.0, pos)).mean()
    IoU_array = (tp / np.maximum(1.0, pos + res - tp))
    mean_IoU = IoU_array.mean()

    return mean_IoU, IoU_array, pixel_acc, mean_acc
| |
|
| |
|
if __name__ == "__main__":
    # --ipu selects the Vitis AI provider (configured through the provider
    # config file); otherwise CUDA is requested with CPU listed after it.
    use_ipu = args.ipu
    providers = (
        ["VitisAIExecutionProvider"]
        if use_ipu
        else ['CUDAExecutionProvider', 'CPUExecutionProvider']
    )
    provider_options = (
        [{"config_file": args.provider_config}] if use_ipu else None
    )

    ort_session = onnxruntime.InferenceSession(
        args.onnx_model,
        providers=providers,
        provider_options=provider_options,
    )

    # The per-class IoU array is returned as well but not reported here.
    mean_IoU, _iou_per_class, pixel_acc, mean_acc = testval(
        ort_session, args.root, args.list_path)

    print('MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, Mean_Acc: {: 4.4f}'.format(
        mean_IoU, pixel_acc, mean_acc))
| |
|