KNN Heatmap display

This commit is contained in:
Arne van Iterson 2023-10-12 14:01:14 +02:00
parent 4206edd60f
commit 12ad80a092

View File

@ -1,10 +1,12 @@
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import csv
from sklearn.preprocessing import MinMaxScaler
from enum import Enum
import random
from sklearn.metrics import confusion_matrix
class Tree(Enum):
ACCASIA = 0
@ -27,6 +29,7 @@ print("CSV tags: ", header)
# Get classifier tags
tags_int = []
for row in data:
tree = row.pop(0)
id = Tree[tree.upper()]
@ -34,10 +37,8 @@ for row in data:
# print("Tree name =", tree, " id =", id.value)
tags_int.append(id.value)
# Get a random number for testing
validateId = random.randint(0, len(tags_int) - 1)
# Make into numpy array cus OpenCV is dumb af
tags_len = len(tags_int)
tags_int = np.array(tags_int, dtype=np.int32)
# Transform array for normalisation
@ -63,14 +64,27 @@ for idx, col in enumerate(data[0]):
# Replace original data array
data[:, idx] = column
# # Get a random number for testing
# validateId = random.randint(0, tags_len - 1)
tag_true = []
tag_predict = []
print(tags_len)
for validateId in range(0, tags_len - 1):
# Remove object from train set
validateTag = tags_int[validateId]
validateObj =np.array([data[validateId]])
np.delete(tags_int, validateId)
np.delete(data, validateTag)
print(validateTag, validateObj)
tag_true.append(validateTag)
# print(validateTag, validateObj)
knn = cv.ml.KNearest_create()
print(tags_int)
print(data.dtype, type(data), tags_int.dtype, type(tags_int))
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
# print (data)
@ -78,7 +92,14 @@ knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
# print (validateObj)
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
tag_predict.append(results[0][0])
print( "result: {}\n".format(results) )
print( "neighbours: {}\n".format(neighbours) )
print( "distance: {}\n".format(dist) )
# print( "result: {}\n".format(results) )
# print( "neighbours: {}\n".format(neighbours) )
# print( "distance: {}\n".format(dist) )
# Create a heatmap
sns.heatmap(confusion_matrix(tag_true, tag_predict), annot=True)
plt.title( "Confusion Matrix KNN" )
plt.show()