Compare commits

...

2 Commits

View File

@ -14,6 +14,7 @@ import json
# OpenCV
import numpy as np
import cv2
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import (
MinMaxScaler,
StandardScaler,
@ -25,6 +26,7 @@ import joblib
# GUI
import pygubu
import matplotlib.pyplot as plt
import seaborn as sns
# Helpers
from helpers.statistics import imgStats
@ -85,6 +87,7 @@ class CVSuite:
# Plots
self.axs = self.createPlot(2, 2)
self.axs_cm = None
# UI Variables
self.canny_thr1 = None
@ -133,17 +136,18 @@ class CVSuite:
print(C_INFO, f"Loading model {model}")
mpath = config_json["models"][model]
if model == "knn":
self.models.append(("KNN", CVSuiteTestKNN(mpath)))
# Tuple with name, class instance and array of guesses for confusion matrix
self.models.append(("KNN", CVSuiteTestKNN(mpath), []))
elif model == "dectree":
self.models.append(
("Decision Tree", CVSuiteTestDecisionTree(mpath))
("Decision Tree", CVSuiteTestDecisionTree(mpath), [])
)
elif model == "randforest":
self.models.append(
("Random Forest", CVSuiteTestRandomForest(mpath))
("Random Forest", CVSuiteTestRandomForest(mpath), [])
)
elif model == "extratree":
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath)))
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath), []))
else:
print(
C_WARN, f"Model {model} does not exist or is not supported!"
@ -246,7 +250,7 @@ class CVSuite:
self.update()
def createPlot(self, columns, rows):
fig, axs = plt.subplots(columns, rows)
fig, axs = plt.subplots(columns, rows, num=100)
return axs
def drawHist(self, image, labels, column, row):
@ -362,7 +366,7 @@ class CVSuite:
data = np.array([data], dtype=np.float32)
for name, ins in self.models:
for name, ins, guesses in self.models:
output.insert("end", f"{name} Result:\n")
# Predict result using model instance
@ -370,9 +374,42 @@ class CVSuite:
# Prediciton result should be an array
for idx, value in enumerate(result):
if idx == 0:
guesses.append([Tree[tag.upper()].value, value])
output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n")
print(C_DBUG, f"Guesses for {name}:", guesses)
output.configure(state="disabled")
def drawConfusionMatrix(self, event=None):
if self.axs_cm is not None:
for ays in self.axs_cm:
for graph in ays:
graph.remove()
fig, axs = plt.subplots(2, 2, num=101)
self.axs_cm = axs
for idx, ays in enumerate(axs):
for idy, graph in enumerate(ays):
# Get guesses for current model
modelnr = (idx * 2) + idy
guesses = self.models[modelnr][2]
# Convert guess array
tag_true = [guess[0] for guess in guesses ]
tag_predict = [guess[1] for guess in guesses ]
labels = [Tree(tag).name for tag in range(0, 7)]
sns.heatmap(confusion_matrix(tag_true, tag_predict), xticklabels=labels, yticklabels=labels, ax=graph, annot=True, cbar=False, fmt='g')
graph.set_title(self.models[modelnr][0])
graph.set_xlabel("Predicted")
graph.set_ylabel("Actual")
graph.set_xticklabels(labels, rotation=0)
graph.set_yticklabels(labels, rotation=0)
# exit()
def updatePath(self):
"""
@ -515,6 +552,8 @@ class CVSuite:
# Write results to CSV file
if not part_update:
self.runTest(self.log.data)
self.drawConfusionMatrix()
self.log.update()
else:
self.log.clear() # Prevent partial updates from breaking log