From 12ad80a09245b1137e8fe9add55c97c960bafbce Mon Sep 17 00:00:00 2001 From: Arne van Iterson Date: Thu, 12 Oct 2023 14:01:14 +0200 Subject: [PATCH] KNN Heatmap display --- src/experiments/knn/knn.py | 55 ++++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/src/experiments/knn/knn.py b/src/experiments/knn/knn.py index 4269d30..c5d4617 100644 --- a/src/experiments/knn/knn.py +++ b/src/experiments/knn/knn.py @@ -1,10 +1,12 @@ import cv2 as cv import numpy as np import matplotlib.pyplot as plt +import seaborn as sns import csv from sklearn.preprocessing import MinMaxScaler from enum import Enum import random +from sklearn.metrics import confusion_matrix class Tree(Enum): ACCASIA = 0 @@ -27,17 +29,16 @@ print("CSV tags: ", header) # Get classifier tags tags_int = [] + for row in data: tree = row.pop(0) id = Tree[tree.upper()] # print("Tree name =", tree, " id =", id.value) tags_int.append(id.value) - -# Get a random number for testing -validateId = random.randint(0, len(tags_int) - 1) # Make into numpy array cus OpenCV is dumb af +tags_len = len(tags_int) tags_int = np.array(tags_int, dtype=np.int32) # Transform array for normalisation @@ -63,22 +64,42 @@ for idx, col in enumerate(data[0]): # Replace original data array data[:, idx] = column -validateTag = tags_int[validateId] -validateObj =np.array([data[validateId]]) -np.delete(tags_int, validateId) -np.delete(data, validateTag) +# # Get a random number for testing +# validateId = random.randint(0, tags_len - 1) +tag_true = [] +tag_predict = [] -print(validateTag, validateObj) +print(tags_len) -knn = cv.ml.KNearest_create() -knn.train(data, cv.ml.ROW_SAMPLE, tags_int) +for validateId in range(0, tags_len - 1): + # Remove object from train set + validateTag = tags_int[validateId] + validateObj =np.array([data[validateId]]) + np.delete(tags_int, validateId) + np.delete(data, validateTag) -# print (data) -# print('--------------------') -# print (validateObj) + tag_true.append(validateTag) -ret, results, neighbours ,dist = knn.findNearest(validateObj, 3) + # print(validateTag, validateObj) -print( "result: {}\n".format(results) ) -print( "neighbours: {}\n".format(neighbours) ) -print( "distance: {}\n".format(dist) ) \ No newline at end of file + knn = cv.ml.KNearest_create() + print(tags_int) + print(data.dtype, type(data), tags_int.dtype, type(tags_int)) + knn.train(data, cv.ml.ROW_SAMPLE, tags_int) + + # print (data) + # print('--------------------') + # print (validateObj) + + ret, results, neighbours ,dist = knn.findNearest(validateObj, 3) + tag_predict.append(results[0][0]) + + # print( "result: {}\n".format(results) ) + # print( "neighbours: {}\n".format(neighbours) ) + # print( "distance: {}\n".format(dist) ) + + +# Create a heatmap +sns.heatmap(confusion_matrix(tag_true, tag_predict), annot=True) +plt.title( "Confusion Matrix KNN" ) +plt.show() \ No newline at end of file