Compare commits
No commits in common. "aed4d22199b62892c1dc1a935bb4035596682a39" and "39d30708ec5e1f53ade8516e3ef5298f18e9b2d5" have entirely different histories.
aed4d22199
...
39d30708ec
51
src/suite.py
51
src/suite.py
@ -14,7 +14,6 @@ import json
|
|||||||
# OpenCV
|
# OpenCV
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import cv2
|
import cv2
|
||||||
from sklearn.metrics import confusion_matrix
|
|
||||||
from sklearn.preprocessing import (
|
from sklearn.preprocessing import (
|
||||||
MinMaxScaler,
|
MinMaxScaler,
|
||||||
StandardScaler,
|
StandardScaler,
|
||||||
@ -26,7 +25,6 @@ import joblib
|
|||||||
# GUI
|
# GUI
|
||||||
import pygubu
|
import pygubu
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import seaborn as sns
|
|
||||||
|
|
||||||
# Helpers
|
# Helpers
|
||||||
from helpers.statistics import imgStats
|
from helpers.statistics import imgStats
|
||||||
@ -87,7 +85,6 @@ class CVSuite:
|
|||||||
|
|
||||||
# Plots
|
# Plots
|
||||||
self.axs = self.createPlot(2, 2)
|
self.axs = self.createPlot(2, 2)
|
||||||
self.axs_cm = None
|
|
||||||
|
|
||||||
# UI Variables
|
# UI Variables
|
||||||
self.canny_thr1 = None
|
self.canny_thr1 = None
|
||||||
@ -136,18 +133,17 @@ class CVSuite:
|
|||||||
print(C_INFO, f"Loading model {model}")
|
print(C_INFO, f"Loading model {model}")
|
||||||
mpath = config_json["models"][model]
|
mpath = config_json["models"][model]
|
||||||
if model == "knn":
|
if model == "knn":
|
||||||
# Tuple with name, class instance and array of guesses for confusion matrix
|
self.models.append(("KNN", CVSuiteTestKNN(mpath)))
|
||||||
self.models.append(("KNN", CVSuiteTestKNN(mpath), []))
|
|
||||||
elif model == "dectree":
|
elif model == "dectree":
|
||||||
self.models.append(
|
self.models.append(
|
||||||
("Decision Tree", CVSuiteTestDecisionTree(mpath), [])
|
("Decision Tree", CVSuiteTestDecisionTree(mpath))
|
||||||
)
|
)
|
||||||
elif model == "randforest":
|
elif model == "randforest":
|
||||||
self.models.append(
|
self.models.append(
|
||||||
("Random Forest", CVSuiteTestRandomForest(mpath), [])
|
("Random Forest", CVSuiteTestRandomForest(mpath))
|
||||||
)
|
)
|
||||||
elif model == "extratree":
|
elif model == "extratree":
|
||||||
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath), []))
|
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath)))
|
||||||
else:
|
else:
|
||||||
print(
|
print(
|
||||||
C_WARN, f"Model {model} does not exist or is not supported!"
|
C_WARN, f"Model {model} does not exist or is not supported!"
|
||||||
@ -250,7 +246,7 @@ class CVSuite:
|
|||||||
self.update()
|
self.update()
|
||||||
|
|
||||||
def createPlot(self, columns, rows):
|
def createPlot(self, columns, rows):
|
||||||
fig, axs = plt.subplots(columns, rows, num=100)
|
fig, axs = plt.subplots(columns, rows)
|
||||||
return axs
|
return axs
|
||||||
|
|
||||||
def drawHist(self, image, labels, column, row):
|
def drawHist(self, image, labels, column, row):
|
||||||
@ -366,7 +362,7 @@ class CVSuite:
|
|||||||
|
|
||||||
data = np.array([data], dtype=np.float32)
|
data = np.array([data], dtype=np.float32)
|
||||||
|
|
||||||
for name, ins, guesses in self.models:
|
for name, ins in self.models:
|
||||||
output.insert("end", f"{name} Result:\n")
|
output.insert("end", f"{name} Result:\n")
|
||||||
|
|
||||||
# Predict result using model instance
|
# Predict result using model instance
|
||||||
@ -374,42 +370,9 @@ class CVSuite:
|
|||||||
|
|
||||||
# Prediciton result should be an array
|
# Prediciton result should be an array
|
||||||
for idx, value in enumerate(result):
|
for idx, value in enumerate(result):
|
||||||
if idx == 0:
|
|
||||||
guesses.append([Tree[tag.upper()].value, value])
|
|
||||||
output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n")
|
output.insert("end", f" [{idx + 1}]\t{Tree(value).name}\n")
|
||||||
|
|
||||||
print(C_DBUG, f"Guesses for {name}:", guesses)
|
|
||||||
|
|
||||||
output.configure(state="disabled")
|
output.configure(state="disabled")
|
||||||
|
|
||||||
def drawConfusionMatrix(self, event=None):
|
|
||||||
if self.axs_cm is not None:
|
|
||||||
for ays in self.axs_cm:
|
|
||||||
for graph in ays:
|
|
||||||
graph.remove()
|
|
||||||
|
|
||||||
fig, axs = plt.subplots(2, 2, num=101)
|
|
||||||
self.axs_cm = axs
|
|
||||||
|
|
||||||
for idx, ays in enumerate(axs):
|
|
||||||
for idy, graph in enumerate(ays):
|
|
||||||
# Get guesses for current model
|
|
||||||
modelnr = (idx * 2) + idy
|
|
||||||
guesses = self.models[modelnr][2]
|
|
||||||
|
|
||||||
# Convert guess array
|
|
||||||
tag_true = [guess[0] for guess in guesses ]
|
|
||||||
tag_predict = [guess[1] for guess in guesses ]
|
|
||||||
|
|
||||||
labels = [Tree(tag).name for tag in range(0, 7)]
|
|
||||||
|
|
||||||
sns.heatmap(confusion_matrix(tag_true, tag_predict), xticklabels=labels, yticklabels=labels, ax=graph, annot=True, cbar=False, fmt='g')
|
|
||||||
graph.set_title(self.models[modelnr][0])
|
|
||||||
graph.set_xlabel("Predicted")
|
|
||||||
graph.set_ylabel("Actual")
|
|
||||||
graph.set_xticklabels(labels, rotation=0)
|
|
||||||
graph.set_yticklabels(labels, rotation=0)
|
|
||||||
# exit()
|
|
||||||
|
|
||||||
def updatePath(self):
|
def updatePath(self):
|
||||||
"""
|
"""
|
||||||
@ -552,8 +515,6 @@ class CVSuite:
|
|||||||
# Write results to CSV file
|
# Write results to CSV file
|
||||||
if not part_update:
|
if not part_update:
|
||||||
self.runTest(self.log.data)
|
self.runTest(self.log.data)
|
||||||
self.drawConfusionMatrix()
|
|
||||||
|
|
||||||
self.log.update()
|
self.log.update()
|
||||||
else:
|
else:
|
||||||
self.log.clear() # Prevent partial updates from breaking log
|
self.log.clear() # Prevent partial updates from breaking log
|
||||||
|
Loading…
Reference in New Issue
Block a user