| | import time |
| | import numpy as np |
| |
|
| | import Classify |
| | |
| |
|
| |
|
| | |
| | |
class Results:
    """Holder for the summary metrics of one binary-classification test run.

    The four attributes start out as ``None``; callers fill them in using
    the ``calculate_*`` helpers defined on this class.
    """

    def __init__(self):
        # Nothing has been measured yet, so every metric is unset.
        self.accuracy = None
        self.precision = None
        self.recall = None
        self.F1 = None

    def calculate_accuracy(self, correct_pos, correct_neg, total):
        """Return ``(correct_pos + correct_neg) / total``, or 0 unless ``total > 0``."""
        if total > 0:
            return (correct_pos + correct_neg) / total
        return 0

    def calculate_precision(self, correct_pos, false_pos):
        """Return ``correct_pos / (correct_pos + false_pos)``, or 0 unless the sum is positive."""
        predicted_positive = correct_pos + false_pos
        if predicted_positive > 0:
            return correct_pos / predicted_positive
        return 0

    def calculate_recall(self, correct_pos, false_neg):
        """Return ``correct_pos / (correct_pos + false_neg)``, or 0 unless the sum is positive."""
        actual_positive = correct_pos + false_neg
        if actual_positive > 0:
            return correct_pos / actual_positive
        return 0

    def calculateF1(self, precision, recall):
        """Return the harmonic mean of ``precision`` and ``recall``, or 0 unless their sum is positive."""
        combined = precision + recall
        if combined > 0:
            return 2 * (precision * recall) / combined
        return 0
| |
|
| |
|
| | |
class Evaluate:
    """Run a binary classifier over a data loader and tally confusion-matrix counts.

    Attributes:
        model: Model under test; must be set before ``run_test`` is called.
        dataLoader: Iterable yielding ``(image_batch, ground_truth_batch)`` pairs.
        threshold: Scores strictly greater than this value count as positive.
    """

    def __init__(self):
        self.model = None
        self.dataLoader = None
        self.threshold = 0.5

    @staticmethod
    def _confusion_masks(scores, labels, threshold):
        """Return element-wise ``(true_pos, true_neg, false_pos, false_neg)`` boolean masks.

        A score is a positive prediction when ``score > threshold`` and a
        negative prediction when ``score <= threshold``; the same rule is used
        for every mask so a score exactly at the threshold is always treated
        as negative. Labels equal to 1 are positive, labels equal to 0 are
        negative; any other label value falls into none of the masks.

        Args:
            scores: 1-D array of prediction scores.
            labels: 1-D array of ground-truth labels, same length as ``scores``.
            threshold: Decision threshold.
        """
        predicted_pos = scores > threshold
        predicted_neg = scores <= threshold
        label_pos = labels == 1
        label_neg = labels == 0
        return (
            predicted_pos & label_pos,
            predicted_neg & label_neg,
            predicted_pos & label_neg,
            predicted_neg & label_pos,
        )

    def run_test(self, verbose=True, visual=False):
        """Evaluate ``self.model`` on every batch of ``self.dataLoader``.

        Args:
            verbose: When true, print a summary of the counts, the accuracy and
                the mean inference time per image.
            visual: When true, pass every misclassified image to
                ``Classify.infer_and_display`` after the run.

        Returns:
            Tuple ``(correct_pos, correct_neg, false_pos, false_neg, total)``.

        Raises:
            AttributeError: If ``model`` or ``dataLoader`` has not been set.
        """
        if self.model is None or self.dataLoader is None:
            raise AttributeError("Please choose a model to test before running the test")

        self.model.eval()

        total = correct_pos = correct_neg = false_pos = false_neg = 0
        total_inference_time = 0.0
        incorrect = []

        for image, gt in self.dataLoader:
            # perf_counter is monotonic, so the interval cannot go negative if
            # the wall clock is adjusted mid-run.
            started = time.perf_counter()
            prediction = Classify.infer(image, self.model)
            total_inference_time += time.perf_counter() - started

            # assumes column 0 of `prediction` is the positive-class score and
            # column 0 of `gt` is the 0/1 label - TODO confirm against Classify.infer
            scores = prediction[:, 0]
            labels = gt[:, 0].detach().numpy()
            tp_mask, tn_mask, fp_mask, fn_mask = self._confusion_masks(
                scores, labels, self.threshold
            )

            correct_pos += int(tp_mask.sum())
            correct_neg += int(tn_mask.sum())
            false_pos += int(fp_mask.sum())
            false_neg += int(fn_mask.sum())
            # len(image) is the real batch size, including a short final batch,
            # and does not depend on dataLoader.batch_size being set.
            total += len(image)

            # Collect only batches that actually contain a misclassification;
            # the mask's length is the batch size, so it must be tested with any().
            for error_mask in (fp_mask, fn_mask):
                if error_mask.any():
                    incorrect.append((image[error_mask], gt[error_mask]))

        if verbose:
            # Guard the divisions so an empty loader reports zeros instead of crashing.
            accuracy_pct = ((correct_pos + correct_neg) / total) * 100 if total else 0.0
            mean_time = total_inference_time / total if total else 0.0
            print(f"Total Images Processed: [{total}],"
                  f" \nAccuracy: [{accuracy_pct:.2f}%],"
                  f" \nCorrect Positives: [{correct_pos}], Correct Negatives: [{correct_neg}],"
                  f" \nFalse Positives: [{false_pos}], False Negatives [{false_neg}],"
                  f" \nAverage Running Time (s) per image: [{mean_time}]")

        if visual:
            for img_set, lab_set in incorrect:
                for img, lab in zip(img_set, lab_set):
                    # NOTE(review): the literal 0.5 is kept from the original; if this
                    # argument is a decision threshold it should probably be
                    # self.threshold - confirm against Classify.infer_and_display.
                    Classify.infer_and_display(img, 0.5, lab)

        return (correct_pos, correct_neg, false_pos, false_neg, total)

    def test_MobileNet3_default(self, model_state_dict, test_num=1, verbose=True, visual=False) -> "Results":
        """Test a MobileNetV3 classifier on a random sample of the dataset's tail.

        Samples ``test_num`` indices from the last 5% of ``MobileNetV3.dataset``
        (indices from ``int(len * 0.95)`` to the end), loads the classifier from
        ``model_state_dict``, runs ``run_test`` and converts the counts to metrics.

        Args:
            model_state_dict: Passed unchanged to ``Classify.load_mobileNet_classifier``.
            test_num: Number of samples to draw; reduced to
                ``int((len(dataset) - 1) * 0.05)`` when it exceeds 5% of the dataset.
            verbose: Forwarded to ``run_test``.
            visual: Forwarded to ``run_test``.

        Returns:
            A ``Results`` instance with accuracy, precision, recall and F1 set.
        """
        # Imported inside the method as in the original - presumably to avoid
        # loading the dataset when this module is merely imported.
        import MobileNetV3 as mn3

        dataset_size = len(mn3.dataset)
        if test_num > dataset_size * 0.05:
            test_num = int((dataset_size - 1) * 0.05)

        holdout_indices = list(range(int(dataset_size * 0.95), dataset_size))
        test_loader = mn3.DataLoader(
            mn3.Subset(mn3.dataset, mn3.random.sample(holdout_indices, test_num)),
            batch_size=mn3.batch_size,
            shuffle=False,
        )
        test_model = Classify.load_mobileNet_classifier(model_state_dict)

        self.model = test_model
        self.dataLoader = test_loader
        try:
            counts = self.run_test(verbose=verbose, visual=visual)
        finally:
            # Always release the model and loader, even if the run fails, so a
            # later call cannot silently reuse them.
            self.model = None
            self.dataLoader = None
        correct_pos, correct_neg, false_pos, false_neg, total = counts

        test_results = Results()
        test_results.accuracy = test_results.calculate_accuracy(correct_pos, correct_neg, total)
        test_results.precision = test_results.calculate_precision(correct_pos, false_pos)
        test_results.recall = test_results.calculate_recall(correct_pos, false_neg)
        test_results.F1 = test_results.calculateF1(test_results.precision, test_results.recall)

        return test_results
| | |
| |
|
# Shared module-level evaluator. NOTE: the name shadows the builtin `eval`
# inside this module; it is left unchanged because it is a module-public name.
eval = Evaluate()

if __name__ == "__main__":
    # Script entry point: test the saved MobileNetV3 weights and display
    # any misclassified images.
    mn3_test_results = eval.test_MobileNet3_default(
        "MobileNetV3_state_dict_big_train.pth",
        test_num=10000,
        visual=True,
    )
| |
|