Add confusion matrix for all model predictions

This commit is contained in:
Arne van Iterson 2023-10-22 20:13:30 +02:00
parent 58e521cb6e
commit ad1892112f

View File

@ -14,6 +14,7 @@ import json
# OpenCV # OpenCV
import numpy as np import numpy as np
import cv2 import cv2
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import ( from sklearn.preprocessing import (
MinMaxScaler, MinMaxScaler,
StandardScaler, StandardScaler,
@ -25,6 +26,7 @@ import joblib
# GUI # GUI
import pygubu import pygubu
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns
# Helpers # Helpers
from helpers.statistics import imgStats from helpers.statistics import imgStats
@ -85,6 +87,7 @@ class CVSuite:
# Plots # Plots
self.axs = self.createPlot(2, 2) self.axs = self.createPlot(2, 2)
self.axs_cm = None
# UI Variables # UI Variables
self.canny_thr1 = None self.canny_thr1 = None
@ -133,17 +136,18 @@ class CVSuite:
print(C_INFO, f"Loading model {model}") print(C_INFO, f"Loading model {model}")
mpath = config_json["models"][model] mpath = config_json["models"][model]
if model == "knn": if model == "knn":
self.models.append(("KNN", CVSuiteTestKNN(mpath))) # Tuple with name, class instance and array of guesses for confusion matrix
self.models.append(("KNN", CVSuiteTestKNN(mpath), []))
elif model == "dectree": elif model == "dectree":
self.models.append( self.models.append(
("Decision Tree", CVSuiteTestDecisionTree(mpath)) ("Decision Tree", CVSuiteTestDecisionTree(mpath), [])
) )
elif model == "randforest": elif model == "randforest":
self.models.append( self.models.append(
("Random Forest", CVSuiteTestRandomForest(mpath)) ("Random Forest", CVSuiteTestRandomForest(mpath), [])
) )
elif model == "extratree": elif model == "extratree":
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath))) self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath), []))
else: else:
print( print(
C_WARN, f"Model {model} does not exist or is not supported!" C_WARN, f"Model {model} does not exist or is not supported!"
@ -246,7 +250,7 @@ class CVSuite:
self.update() self.update()
def createPlot(self, columns, rows): def createPlot(self, columns, rows):
fig, axs = plt.subplots(columns, rows) fig, axs = plt.subplots(columns, rows, num=100)
return axs return axs
def drawHist(self, image, labels, column, row): def drawHist(self, image, labels, column, row):
@ -362,7 +366,7 @@ class CVSuite:
data = np.array([data], dtype=np.float32) data = np.array([data], dtype=np.float32)
for name, ins in self.models: for name, ins, guesses in self.models:
output.insert("end", f"{name} Result:\n") output.insert("end", f"{name} Result:\n")
# Predict result using model instance # Predict result using model instance
@ -370,9 +374,42 @@ class CVSuite:
# Prediciton result should be an array # Prediciton result should be an array
for idx, value in enumerate(result): for idx, value in enumerate(result):
if idx == 0:
guesses.append([Tree[tag.upper()].value, value])
output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n") output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n")
print(C_DBUG, f"Guesses for {name}:", guesses)
output.configure(state="disabled") output.configure(state="disabled")
def drawConfusionMatrix(self, event=None):
if self.axs_cm is not None:
for ays in self.axs_cm:
for graph in ays:
graph.remove()
fig, axs = plt.subplots(2, 2, num=101)
self.axs_cm = axs
for idx, ays in enumerate(axs):
for idy, graph in enumerate(ays):
# Get guesses for current model
modelnr = (idx * 2) + idy
guesses = self.models[modelnr][2]
# Convert guess array
tag_true = [guess[0] for guess in guesses ]
tag_predict = [guess[1] for guess in guesses ]
labels = [Tree(tag).name for tag in range(0, 7)]
sns.heatmap(confusion_matrix(tag_true, tag_predict), xticklabels=labels, yticklabels=labels, ax=graph, annot=True, cbar=False, fmt='g')
graph.set_title(self.models[modelnr][0])
graph.set_xlabel("Predicted")
graph.set_ylabel("Actual")
graph.set_xticklabels(labels, rotation=0)
graph.set_yticklabels(labels, rotation=0)
# exit()
def updatePath(self): def updatePath(self):
""" """
@ -515,6 +552,8 @@ class CVSuite:
# Write results to CSV file # Write results to CSV file
if not part_update: if not part_update:
self.runTest(self.log.data) self.runTest(self.log.data)
self.drawConfusionMatrix()
self.log.update() self.log.update()
else: else:
self.log.clear() # Prevent partial updates from breaking log self.log.clear() # Prevent partial updates from breaking log