This commit is contained in:
Tom Selier 2023-10-13 14:29:17 +02:00
commit 6feaab7f68

View File

@ -1,10 +1,12 @@
import cv2 as cv import cv2 as cv
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns
import csv import csv
from sklearn.preprocessing import MinMaxScaler from sklearn.preprocessing import MinMaxScaler
from enum import Enum from enum import Enum
import random import random
from sklearn.metrics import confusion_matrix
class Tree(Enum): class Tree(Enum):
ACCASIA = 0 ACCASIA = 0
@ -27,6 +29,7 @@ print("CSV tags: ", header)
# Get classifier tags # Get classifier tags
tags_int = [] tags_int = []
for row in data: for row in data:
tree = row.pop(0) tree = row.pop(0)
id = Tree[tree.upper()] id = Tree[tree.upper()]
@ -34,10 +37,8 @@ for row in data:
# print("Tree name =", tree, " id =", id.value) # print("Tree name =", tree, " id =", id.value)
tags_int.append(id.value) tags_int.append(id.value)
# Get a random number for testing
validateId = random.randint(0, len(tags_int) - 1)
# Make into numpy array cus OpenCV is dumb af # Make into numpy array cus OpenCV is dumb af
tags_len = len(tags_int)
tags_int = np.array(tags_int, dtype=np.int32) tags_int = np.array(tags_int, dtype=np.int32)
# Transform array for normalisation # Transform array for normalisation
@ -63,22 +64,42 @@ for idx, col in enumerate(data[0]):
# Replace original data array # Replace original data array
data[:, idx] = column data[:, idx] = column
validateTag = tags_int[validateId] # # Get a random number for testing
validateObj =np.array([data[validateId]]) # validateId = random.randint(0, tags_len - 1)
np.delete(tags_int, validateId) tag_true = []
np.delete(data, validateTag) tag_predict = []
print(validateTag, validateObj) print(tags_len)
knn = cv.ml.KNearest_create() for validateId in range(0, tags_len - 1):
knn.train(data, cv.ml.ROW_SAMPLE, tags_int) # Remove object from train set
validateTag = tags_int[validateId]
validateObj =np.array([data[validateId]])
np.delete(tags_int, validateId)
np.delete(data, validateTag)
# print (data) tag_true.append(validateTag)
# print('--------------------')
# print (validateObj)
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3) # print(validateTag, validateObj)
print( "result: {}\n".format(results) ) knn = cv.ml.KNearest_create()
print( "neighbours: {}\n".format(neighbours) ) print(tags_int)
print( "distance: {}\n".format(dist) ) print(data.dtype, type(data), tags_int.dtype, type(tags_int))
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
# print (data)
# print('--------------------')
# print (validateObj)
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
tag_predict.append(results[0][0])
# print( "result: {}\n".format(results) )
# print( "neighbours: {}\n".format(neighbours) )
# print( "distance: {}\n".format(dist) )
# Create a heatmap
sns.heatmap(confusion_matrix(tag_true, tag_predict), annot=True)
plt.title( "Confusion Matrix KNN" )
plt.show()