Compare commits

...

2 Commits

2 changed files with 43 additions and 20 deletions

View File

@ -52,7 +52,7 @@ class CVSuiteTestKNN:
for row in data:
tree = row.pop(0)
# photoId = row.pop(1)
photoId = row.pop(1)
id = Tree[tree.upper()]
# print("Tree name =", tree, " id =", id.value)
@ -104,14 +104,16 @@ class CVSuiteTestKNN:
if self.trained:
raise EnvironmentError("Model already trained!")
else:
print(data)
print(data.shape)
self.knn.train(data, cv.ml.ROW_SAMPLE, tags)
# Save it
now = datetime.datetime.now()
self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
def predict(self, data):
return self.knn.predict(data)
def predict(self, data, nr = 3):
return self.knn.findNearest(data, nr)
if __name__ == "__main__":
args = parser.parse_args()

View File

@ -14,7 +14,12 @@ import json
# OpenCV
import numpy as np
import cv2
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, MaxAbsScaler
from sklearn.preprocessing import (
MinMaxScaler,
StandardScaler,
RobustScaler,
MaxAbsScaler,
)
import joblib
# GUI
@ -26,9 +31,11 @@ from helpers.statistics import imgStats
from helpers.logger import CVSuiteLogger, C_DBUG, C_WARN
from helpers.canvas import CVSuiteCanvas
from helpers.sift import getSiftData
from helpers.tags import Tree
# Tests
from helpers.test.knn import CVSuiteTestKNN
# from helpers.test.decision_tree import CVSuiteTestDecisionTree
## UI config load
@ -41,6 +48,7 @@ CONFIG_PATH = "./src/config/config.json"
config_file = open(CONFIG_PATH, encoding="utf-8")
config_json = json.load(config_file)
## UI class setup
class CVSuite:
def __init__(self, master=None):
@ -54,7 +62,7 @@ class CVSuite:
# Canvas for output images
self.canvas = CVSuiteCanvas(builder.get_object("output_canvas"))
# Log file
self.log = CVSuiteLogger(config_json["out"]["log"])
@ -97,7 +105,6 @@ class CVSuite:
# Attempt to load scaler
if config_json["scaler"] != "":
self.scaler = joblib.load(config_json["scaler"])
print(self.scaler)
else:
self.scaler = None
@ -108,7 +115,9 @@ class CVSuite:
self.test_knn = None
if config_json["models"]["dectree"] != "":
self.test_dectree = CVSuiteTestDecisionTree(config_json["models"]["dectree"])
self.test_dectree = CVSuiteTestDecisionTree(
config_json["models"]["dectree"]
)
else:
self.test_dectree = None
@ -302,27 +311,40 @@ class CVSuite:
output = self.builder.get_object("testdata")
output.configure(state="normal")
output.delete(1.0, "end")
# Normalise data
# Remove tag and photoId
tag = data.pop(0)
photoId = data.pop(1)
# Add actual name
output.insert("end", f"Actual:\n\t{tag.upper()}\n")
# Normalise data using loaded scalers
for idx, value in enumerate(data):
data[idx] = self.scaler[idx].transform(np.array(value).reshape(-1, 1))
d = np.array(value)
data[idx] = self.scaler[idx].transform(d.astype(np.float32).reshape(1, -1))[0][0]
print(data)
data = np.array([data], dtype=np.float32)
if self.test_knn is not None:
# Do knn test
output.insert("end", "KNN Result:\n")
pass
ret, results, neighbours ,dist = self.test_knn.predict(data)
for idx, res_id in enumerate(neighbours[0]):
output.insert("end", f" {idx}:\t{Tree(res_id).name}\n")
print(C_DBUG, "KNN Result:")
print("\t\tresult: \t{}".format(results) )
print("\t\tneighbours:\t{}".format(neighbours) )
print("\t\tdistance:\t{}".format(dist) )
else:
print(C_WARN, "KNN Model not configured!")
if self.test_dectree is not None:
print(self.test_dectree.predict(data))
output.insert("end", "Decision Tree Result:\n")
pass
else:
print(C_WARN, "Decison Tree Model not configured!")
@ -367,11 +389,11 @@ class CVSuite:
print("Full update forced!")
if self.updatePath():
print(C_DBUG, F"Processing {self.img_name}")
self.mainwindow.title(F"{TITLE} - {self.img_name}")
print(C_DBUG, f"Processing {self.img_name}")
self.mainwindow.title(f"{TITLE} - {self.img_name}")
self.log.add("Tree", self.img_name.split("_")[0])
self.log.add("ID", self.img_name.split("_")[1].split('.')[0])
self.log.add("ID", self.img_name.split("_")[1].split(".")[0])
# Get all user vars
ct1 = self.canny_thr1.get()
@ -467,11 +489,9 @@ class CVSuite:
self.log.add("SIFT total response", siftData[5])
self.log.add("SIFT average response", siftData[6])
# Run tests
self.runTest(self.log.data)
# Write results to CSV file
if not part_update:
self.runTest(self.log.data)
self.log.update()
else:
self.log.clear() # Prevent partial updates from breaking log
@ -480,6 +500,7 @@ class CVSuite:
plt.show(block=False) ## Graphs
self.canvas.draw(size) ## Images
if __name__ == "__main__":
app = CVSuite()
app.run()