Compare commits
2 Commits
39d30708ec
...
aed4d22199
Author | SHA1 | Date | |
---|---|---|---|
aed4d22199 | |||
ad1892112f |
51
src/suite.py
51
src/suite.py
@ -14,6 +14,7 @@ import json
|
|||||||
# OpenCV
|
# OpenCV
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
from sklearn.preprocessing import (
|
from sklearn.preprocessing import (
|
||||||
MinMaxScaler,
|
MinMaxScaler,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
@ -25,6 +26,7 @@ import joblib
|
|||||||
# GUI
|
# GUI
|
||||||
import pygubu
|
import pygubu
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
# Helpers
|
# Helpers
|
||||||
from helpers.statistics import imgStats
|
from helpers.statistics import imgStats
|
||||||
@ -85,6 +87,7 @@ class CVSuite:
|
|||||||
|
|
||||||
# Plots
|
# Plots
|
||||||
self.axs = self.createPlot(2, 2)
|
self.axs = self.createPlot(2, 2)
|
||||||
|
self.axs_cm = None
|
||||||
|
|
||||||
# UI Variables
|
# UI Variables
|
||||||
self.canny_thr1 = None
|
self.canny_thr1 = None
|
||||||
@ -133,17 +136,18 @@ class CVSuite:
|
|||||||
print(C_INFO, f"Loading model {model}")
|
print(C_INFO, f"Loading model {model}")
|
||||||
mpath = config_json["models"][model]
|
mpath = config_json["models"][model]
|
||||||
if model == "knn":
|
if model == "knn":
|
||||||
self.models.append(("KNN", CVSuiteTestKNN(mpath)))
|
# Tuple with name, class instance and array of guesses for confusion matrix
|
||||||
|
self.models.append(("KNN", CVSuiteTestKNN(mpath), []))
|
||||||
elif model == "dectree":
|
elif model == "dectree":
|
||||||
self.models.append(
|
self.models.append(
|
||||||
("Decision Tree", CVSuiteTestDecisionTree(mpath))
|
("Decision Tree", CVSuiteTestDecisionTree(mpath), [])
|
||||||
)
|
)
|
||||||
elif model == "randforest":
|
elif model == "randforest":
|
||||||
self.models.append(
|
self.models.append(
|
||||||
("Random Forest", CVSuiteTestRandomForest(mpath))
|
("Random Forest", CVSuiteTestRandomForest(mpath), [])
|
||||||
)
|
)
|
||||||
elif model == "extratree":
|
elif model == "extratree":
|
||||||
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath)))
|
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath), []))
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
C_WARN, f"Model {model} does not exist or is not supported!"
|
C_WARN, f"Model {model} does not exist or is not supported!"
|
||||||
@ -246,7 +250,7 @@ class CVSuite:
|
|||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
def createPlot(self, columns, rows):
|
def createPlot(self, columns, rows):
|
||||||
fig, axs = plt.subplots(columns, rows)
|
fig, axs = plt.subplots(columns, rows, num=100)
|
||||||
return axs
|
return axs
|
||||||
|
|
||||||
def drawHist(self, image, labels, column, row):
|
def drawHist(self, image, labels, column, row):
|
||||||
@ -362,7 +366,7 @@ class CVSuite:
|
|||||||
|
|
||||||
data = np.array([data], dtype=np.float32)
|
data = np.array([data], dtype=np.float32)
|
||||||
|
|
||||||
for name, ins in self.models:
|
for name, ins, guesses in self.models:
|
||||||
output.insert("end", f"{name} Result:\n")
|
output.insert("end", f"{name} Result:\n")
|
||||||
|
|
||||||
# Predict result using model instance
|
# Predict result using model instance
|
||||||
@ -370,9 +374,42 @@ class CVSuite:
|
|||||||
|
|
||||||
# Prediciton result should be an array
|
# Prediciton result should be an array
|
||||||
for idx, value in enumerate(result):
|
for idx, value in enumerate(result):
|
||||||
|
if idx == 0:
|
||||||
|
guesses.append([Tree[tag.upper()].value, value])
|
||||||
output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n")
|
output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n")
|
||||||
|
|
||||||
|
print(C_DBUG, f"Guesses for {name}:", guesses)
|
||||||
|
|
||||||
output.configure(state="disabled")
|
output.configure(state="disabled")
|
||||||
|
|
||||||
|
def drawConfusionMatrix(self, event=None):
|
||||||
|
if self.axs_cm is not None:
|
||||||
|
for ays in self.axs_cm:
|
||||||
|
for graph in ays:
|
||||||
|
graph.remove()
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(2, 2, num=101)
|
||||||
|
self.axs_cm = axs
|
||||||
|
|
||||||
|
for idx, ays in enumerate(axs):
|
||||||
|
for idy, graph in enumerate(ays):
|
||||||
|
# Get guesses for current model
|
||||||
|
modelnr = (idx * 2) + idy
|
||||||
|
guesses = self.models[modelnr][2]
|
||||||
|
|
||||||
|
# Convert guess array
|
||||||
|
tag_true = [guess[0] for guess in guesses ]
|
||||||
|
tag_predict = [guess[1] for guess in guesses ]
|
||||||
|
|
||||||
|
labels = [Tree(tag).name for tag in range(0, 7)]
|
||||||
|
|
||||||
|
sns.heatmap(confusion_matrix(tag_true, tag_predict), xticklabels=labels, yticklabels=labels, ax=graph, annot=True, cbar=False, fmt='g')
|
||||||
|
graph.set_title(self.models[modelnr][0])
|
||||||
|
graph.set_xlabel("Predicted")
|
||||||
|
graph.set_ylabel("Actual")
|
||||||
|
graph.set_xticklabels(labels, rotation=0)
|
||||||
|
graph.set_yticklabels(labels, rotation=0)
|
||||||
|
# exit()
|
||||||
|
|
||||||
def updatePath(self):
|
def updatePath(self):
|
||||||
"""
|
"""
|
||||||
@ -515,6 +552,8 @@ class CVSuite:
|
|||||||
# Write results to CSV file
|
# Write results to CSV file
|
||||||
if not part_update:
|
if not part_update:
|
||||||
self.runTest(self.log.data)
|
self.runTest(self.log.data)
|
||||||
|
self.drawConfusionMatrix()
|
||||||
|
|
||||||
self.log.update()
|
self.log.update()
|
||||||
else:
|
else:
|
||||||
self.log.clear() # Prevent partial updates from breaking log
|
self.log.clear() # Prevent partial updates from breaking log
|
||||||
|
Loading…
Reference in New Issue
Block a user