diff --git a/src/suite.py b/src/suite.py index d7be901..e458516 100644 --- a/src/suite.py +++ b/src/suite.py @@ -14,6 +14,7 @@ import json # OpenCV import numpy as np import cv2 +from sklearn.metrics import confusion_matrix from sklearn.preprocessing import ( MinMaxScaler, StandardScaler, @@ -25,6 +26,7 @@ import joblib # GUI import pygubu import matplotlib.pyplot as plt +import seaborn as sns # Helpers from helpers.statistics import imgStats @@ -85,6 +87,7 @@ class CVSuite: # Plots self.axs = self.createPlot(2, 2) + self.axs_cm = None # UI Variables self.canny_thr1 = None @@ -133,17 +136,18 @@ class CVSuite: print(C_INFO, f"Loading model {model}") mpath = config_json["models"][model] if model == "knn": - self.models.append(("KNN", CVSuiteTestKNN(mpath))) + # Tuple with name, class instance and array of guesses for confusion matrix + self.models.append(("KNN", CVSuiteTestKNN(mpath), [])) elif model == "dectree": self.models.append( - ("Decision Tree", CVSuiteTestDecisionTree(mpath)) + ("Decision Tree", CVSuiteTestDecisionTree(mpath), []) ) elif model == "randforest": self.models.append( - ("Random Forest", CVSuiteTestRandomForest(mpath)) + ("Random Forest", CVSuiteTestRandomForest(mpath), []) ) elif model == "extratree": - self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath))) + self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath), [])) else: print( C_WARN, f"Model {model} does not exist or is not supported!" @@ -246,7 +250,7 @@ class CVSuite: self.update() def createPlot(self, columns, rows): - fig, axs = plt.subplots(columns, rows) + fig, axs = plt.subplots(columns, rows, num=100) return axs def drawHist(self, image, labels, column, row): @@ -362,7 +366,7 @@ class CVSuite: data = np.array([data], dtype=np.float32) - for name, ins in self.models: + for name, ins, guesses in self.models: output.insert("end", f"{name} Result:\n") # Predict result using model instance @@ -370,9 +374,42 @@ class CVSuite: # Prediciton result should be an array for idx, value in enumerate(result): + if idx == 0: + guesses.append([Tree[tag.upper()].value, value]) output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n") + + print(C_DBUG, f"Guesses for {name}:", guesses) output.configure(state="disabled") + + def drawConfusionMatrix(self, event=None): + if self.axs_cm is not None: + for ays in self.axs_cm: + for graph in ays: + graph.remove() + + fig, axs = plt.subplots(2, 2, num=101) + self.axs_cm = axs + + for idx, ays in enumerate(axs): + for idy, graph in enumerate(ays): + # Get guesses for current model + modelnr = (idx * 2) + idy + guesses = self.models[modelnr][2] + + # Convert guess array + tag_true = [guess[0] for guess in guesses ] + tag_predict = [guess[1] for guess in guesses ] + + labels = [Tree(tag).name for tag in range(0, 7)] + + sns.heatmap(confusion_matrix(tag_true, tag_predict), xticklabels=labels, yticklabels=labels, ax=graph, annot=True, cbar=False, fmt='g') + graph.set_title(self.models[modelnr][0]) + graph.set_xlabel("Predicted") + graph.set_ylabel("Actual") + graph.set_xticklabels(labels, rotation=0) + graph.set_yticklabels(labels, rotation=0) + # exit() def updatePath(self): """ @@ -515,6 +552,8 @@ class CVSuite: # Write results to CSV file if not part_update: self.runTest(self.log.data) + self.drawConfusionMatrix() + self.log.update() else: self.log.clear() # Prevent partial updates from breaking log