Add confusion matrix for all model predictions
This commit is contained in:
parent
58e521cb6e
commit
ad1892112f
51
src/suite.py
51
src/suite.py
@ -14,6 +14,7 @@ import json
|
||||
# OpenCV
|
||||
import numpy as np
|
||||
import cv2
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.preprocessing import (
|
||||
MinMaxScaler,
|
||||
StandardScaler,
|
||||
@ -25,6 +26,7 @@ import joblib
|
||||
# GUI
|
||||
import pygubu
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
# Helpers
|
||||
from helpers.statistics import imgStats
|
||||
@ -85,6 +87,7 @@ class CVSuite:
|
||||
|
||||
# Plots
|
||||
self.axs = self.createPlot(2, 2)
|
||||
self.axs_cm = None
|
||||
|
||||
# UI Variables
|
||||
self.canny_thr1 = None
|
||||
@ -133,17 +136,18 @@ class CVSuite:
|
||||
print(C_INFO, f"Loading model {model}")
|
||||
mpath = config_json["models"][model]
|
||||
if model == "knn":
|
||||
self.models.append(("KNN", CVSuiteTestKNN(mpath)))
|
||||
# Tuple with name, class instance and array of guesses for confusion matrix
|
||||
self.models.append(("KNN", CVSuiteTestKNN(mpath), []))
|
||||
elif model == "dectree":
|
||||
self.models.append(
|
||||
("Decision Tree", CVSuiteTestDecisionTree(mpath))
|
||||
("Decision Tree", CVSuiteTestDecisionTree(mpath), [])
|
||||
)
|
||||
elif model == "randforest":
|
||||
self.models.append(
|
||||
("Random Forest", CVSuiteTestRandomForest(mpath))
|
||||
("Random Forest", CVSuiteTestRandomForest(mpath), [])
|
||||
)
|
||||
elif model == "extratree":
|
||||
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath)))
|
||||
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath), []))
|
||||
else:
|
||||
print(
|
||||
C_WARN, f"Model {model} does not exist or is not supported!"
|
||||
@ -246,7 +250,7 @@ class CVSuite:
|
||||
self.update()
|
||||
|
||||
def createPlot(self, columns, rows):
|
||||
fig, axs = plt.subplots(columns, rows)
|
||||
fig, axs = plt.subplots(columns, rows, num=100)
|
||||
return axs
|
||||
|
||||
def drawHist(self, image, labels, column, row):
|
||||
@ -362,7 +366,7 @@ class CVSuite:
|
||||
|
||||
data = np.array([data], dtype=np.float32)
|
||||
|
||||
for name, ins in self.models:
|
||||
for name, ins, guesses in self.models:
|
||||
output.insert("end", f"{name} Result:\n")
|
||||
|
||||
# Predict result using model instance
|
||||
@ -370,10 +374,43 @@ class CVSuite:
|
||||
|
||||
# Prediciton result should be an array
|
||||
for idx, value in enumerate(result):
|
||||
if idx == 0:
|
||||
guesses.append([Tree[tag.upper()].value, value])
|
||||
output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n")
|
||||
|
||||
print(C_DBUG, f"Guesses for {name}:", guesses)
|
||||
|
||||
output.configure(state="disabled")
|
||||
|
||||
def drawConfusionMatrix(self, event=None):
|
||||
if self.axs_cm is not None:
|
||||
for ays in self.axs_cm:
|
||||
for graph in ays:
|
||||
graph.remove()
|
||||
|
||||
fig, axs = plt.subplots(2, 2, num=101)
|
||||
self.axs_cm = axs
|
||||
|
||||
for idx, ays in enumerate(axs):
|
||||
for idy, graph in enumerate(ays):
|
||||
# Get guesses for current model
|
||||
modelnr = (idx * 2) + idy
|
||||
guesses = self.models[modelnr][2]
|
||||
|
||||
# Convert guess array
|
||||
tag_true = [guess[0] for guess in guesses ]
|
||||
tag_predict = [guess[1] for guess in guesses ]
|
||||
|
||||
labels = [Tree(tag).name for tag in range(0, 7)]
|
||||
|
||||
sns.heatmap(confusion_matrix(tag_true, tag_predict), xticklabels=labels, yticklabels=labels, ax=graph, annot=True, cbar=False, fmt='g')
|
||||
graph.set_title(self.models[modelnr][0])
|
||||
graph.set_xlabel("Predicted")
|
||||
graph.set_ylabel("Actual")
|
||||
graph.set_xticklabels(labels, rotation=0)
|
||||
graph.set_yticklabels(labels, rotation=0)
|
||||
# exit()
|
||||
|
||||
def updatePath(self):
|
||||
"""
|
||||
Only update image name and path
|
||||
@ -515,6 +552,8 @@ class CVSuite:
|
||||
# Write results to CSV file
|
||||
if not part_update:
|
||||
self.runTest(self.log.data)
|
||||
self.drawConfusionMatrix()
|
||||
|
||||
self.log.update()
|
||||
else:
|
||||
self.log.clear() # Prevent partial updates from breaking log
|
||||
|
Loading…
Reference in New Issue
Block a user