KNN Heatmap display
This commit is contained in:
parent
4206edd60f
commit
12ad80a092
@ -1,10 +1,12 @@
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import csv
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from enum import Enum
|
||||
import random
|
||||
from sklearn.metrics import confusion_matrix
|
||||
|
||||
class Tree(Enum):
|
||||
ACCASIA = 0
|
||||
@ -27,17 +29,16 @@ print("CSV tags: ", header)
|
||||
|
||||
# Get classifier tags
|
||||
tags_int = []
|
||||
|
||||
for row in data:
|
||||
tree = row.pop(0)
|
||||
id = Tree[tree.upper()]
|
||||
|
||||
# print("Tree name =", tree, " id =", id.value)
|
||||
tags_int.append(id.value)
|
||||
|
||||
# Get a random number for testing
|
||||
validateId = random.randint(0, len(tags_int) - 1)
|
||||
|
||||
# Make into numpy array cus OpenCV is dumb af
|
||||
tags_len = len(tags_int)
|
||||
tags_int = np.array(tags_int, dtype=np.int32)
|
||||
|
||||
# Transform array for normalisation
|
||||
@ -63,22 +64,42 @@ for idx, col in enumerate(data[0]):
|
||||
# Replace original data array
|
||||
data[:, idx] = column
|
||||
|
||||
validateTag = tags_int[validateId]
|
||||
validateObj =np.array([data[validateId]])
|
||||
np.delete(tags_int, validateId)
|
||||
np.delete(data, validateTag)
|
||||
# # Get a random number for testing
|
||||
# validateId = random.randint(0, tags_len - 1)
|
||||
tag_true = []
|
||||
tag_predict = []
|
||||
|
||||
print(validateTag, validateObj)
|
||||
print(tags_len)
|
||||
|
||||
knn = cv.ml.KNearest_create()
|
||||
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
|
||||
for validateId in range(0, tags_len - 1):
|
||||
# Remove object from train set
|
||||
validateTag = tags_int[validateId]
|
||||
validateObj =np.array([data[validateId]])
|
||||
np.delete(tags_int, validateId)
|
||||
np.delete(data, validateTag)
|
||||
|
||||
# print (data)
|
||||
# print('--------------------')
|
||||
# print (validateObj)
|
||||
tag_true.append(validateTag)
|
||||
|
||||
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
|
||||
# print(validateTag, validateObj)
|
||||
|
||||
print( "result: {}\n".format(results) )
|
||||
print( "neighbours: {}\n".format(neighbours) )
|
||||
print( "distance: {}\n".format(dist) )
|
||||
knn = cv.ml.KNearest_create()
|
||||
print(tags_int)
|
||||
print(data.dtype, type(data), tags_int.dtype, type(tags_int))
|
||||
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
|
||||
|
||||
# print (data)
|
||||
# print('--------------------')
|
||||
# print (validateObj)
|
||||
|
||||
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
|
||||
tag_predict.append(results[0][0])
|
||||
|
||||
# print( "result: {}\n".format(results) )
|
||||
# print( "neighbours: {}\n".format(neighbours) )
|
||||
# print( "distance: {}\n".format(dist) )
|
||||
|
||||
|
||||
# Create a heatmap
|
||||
sns.heatmap(confusion_matrix(tag_true, tag_predict), annot=True)
|
||||
plt.title( "Confusion Matrix KNN" )
|
||||
plt.show()
|
Loading…
Reference in New Issue
Block a user