KNNNNNNNNNNNNNNN

This commit is contained in:
Arne van Iterson 2023-10-11 13:19:29 +02:00
parent 621d976d89
commit 642450e40c
1 changed files with 84 additions and 0 deletions

View File

@ -0,0 +1,84 @@
import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import csv
from sklearn.preprocessing import MinMaxScaler
from enum import Enum
import random
class Tree(Enum):
ACCASIA = 0
BERK = 1
EIK = 2
ELS = 3
ESDOORN = 4
ES = 5
LINDE = 6
PLATAAN = 7
# Open file
file = open('./out/result-2023-10-10T15.08.36.csv', "r")
data = list(csv.reader(file, delimiter=","))
file.close()
# Remove CSV header
header = data.pop(0)
print("CSV tags: ", header)
# Get classifier tags
tags_int = []
for row in data:
tree = row.pop(0)
id = Tree[tree.upper()]
# print("Tree name =", tree, " id =", id.value)
tags_int.append(id.value)
# Get a random number for testing
validateId = random.randint(0, len(tags_int) - 1)
# Make into numpy array cus OpenCV is dumb af
tags_int = np.array(tags_int, dtype=np.int32)
# Transform array for normalisation
data = np.array(data, dtype=np.float32)
for idx, col in enumerate(data[0]):
# Get column from data
column = data[:, idx]
# Shape it to 2 dimentional
column = np.array(column).reshape(-1, 1)
# Perform Min - Max scaling
scaler = MinMaxScaler()
column = scaler.fit_transform(column)
# Reshape it back cus scaler is dumb af
column = np.array(column).reshape(len(column))
# DEBUG Print resulting column
# print("NORM", header[idx + 1], "\n", column)
# Replace original data array
data[:, idx] = column
validateTag = tags_int[validateId]
validateObj =np.array([data[validateId]])
np.delete(tags_int, validateId)
np.delete(data, validateTag)
print(validateTag, validateObj)
knn = cv.ml.KNearest_create()
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
# print (data)
# print('--------------------')
# print (validateObj)
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
print( "result: {}\n".format(results) )
print( "neighbours: {}\n".format(neighbours) )
print( "distance: {}\n".format(dist) )