KNN Heatmap display
This commit is contained in:
parent
4206edd60f
commit
12ad80a092
@ -1,10 +1,12 @@
|
|||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
import csv
|
import csv
|
||||||
from sklearn.preprocessing import MinMaxScaler
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import random
|
import random
|
||||||
|
from sklearn.metrics import confusion_matrix
|
||||||
|
|
||||||
class Tree(Enum):
|
class Tree(Enum):
|
||||||
ACCASIA = 0
|
ACCASIA = 0
|
||||||
@ -27,6 +29,7 @@ print("CSV tags: ", header)
|
|||||||
|
|
||||||
# Get classifier tags
|
# Get classifier tags
|
||||||
tags_int = []
|
tags_int = []
|
||||||
|
|
||||||
for row in data:
|
for row in data:
|
||||||
tree = row.pop(0)
|
tree = row.pop(0)
|
||||||
id = Tree[tree.upper()]
|
id = Tree[tree.upper()]
|
||||||
@ -34,10 +37,8 @@ for row in data:
|
|||||||
# print("Tree name =", tree, " id =", id.value)
|
# print("Tree name =", tree, " id =", id.value)
|
||||||
tags_int.append(id.value)
|
tags_int.append(id.value)
|
||||||
|
|
||||||
# Get a random number for testing
|
|
||||||
validateId = random.randint(0, len(tags_int) - 1)
|
|
||||||
|
|
||||||
# Make into numpy array cus OpenCV is dumb af
|
# Make into numpy array cus OpenCV is dumb af
|
||||||
|
tags_len = len(tags_int)
|
||||||
tags_int = np.array(tags_int, dtype=np.int32)
|
tags_int = np.array(tags_int, dtype=np.int32)
|
||||||
|
|
||||||
# Transform array for normalisation
|
# Transform array for normalisation
|
||||||
@ -63,14 +64,27 @@ for idx, col in enumerate(data[0]):
|
|||||||
# Replace original data array
|
# Replace original data array
|
||||||
data[:, idx] = column
|
data[:, idx] = column
|
||||||
|
|
||||||
|
# # Get a random number for testing
|
||||||
|
# validateId = random.randint(0, tags_len - 1)
|
||||||
|
tag_true = []
|
||||||
|
tag_predict = []
|
||||||
|
|
||||||
|
print(tags_len)
|
||||||
|
|
||||||
|
for validateId in range(0, tags_len - 1):
|
||||||
|
# Remove object from train set
|
||||||
validateTag = tags_int[validateId]
|
validateTag = tags_int[validateId]
|
||||||
validateObj =np.array([data[validateId]])
|
validateObj =np.array([data[validateId]])
|
||||||
np.delete(tags_int, validateId)
|
np.delete(tags_int, validateId)
|
||||||
np.delete(data, validateTag)
|
np.delete(data, validateTag)
|
||||||
|
|
||||||
print(validateTag, validateObj)
|
tag_true.append(validateTag)
|
||||||
|
|
||||||
|
# print(validateTag, validateObj)
|
||||||
|
|
||||||
knn = cv.ml.KNearest_create()
|
knn = cv.ml.KNearest_create()
|
||||||
|
print(tags_int)
|
||||||
|
print(data.dtype, type(data), tags_int.dtype, type(tags_int))
|
||||||
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
|
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
|
||||||
|
|
||||||
# print (data)
|
# print (data)
|
||||||
@ -78,7 +92,14 @@ knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
|
|||||||
# print (validateObj)
|
# print (validateObj)
|
||||||
|
|
||||||
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
|
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
|
||||||
|
tag_predict.append(results[0][0])
|
||||||
|
|
||||||
print( "result: {}\n".format(results) )
|
# print( "result: {}\n".format(results) )
|
||||||
print( "neighbours: {}\n".format(neighbours) )
|
# print( "neighbours: {}\n".format(neighbours) )
|
||||||
print( "distance: {}\n".format(dist) )
|
# print( "distance: {}\n".format(dist) )
|
||||||
|
|
||||||
|
|
||||||
|
# Create a heatmap
|
||||||
|
sns.heatmap(confusion_matrix(tag_true, tag_predict), annot=True)
|
||||||
|
plt.title( "Confusion Matrix KNN" )
|
||||||
|
plt.show()
|
Loading…
Reference in New Issue
Block a user