diff --git a/src/experiments/knn/knn.py b/src/experiments/knn/knn.py new file mode 100644 index 0000000..4269d30 --- /dev/null +++ b/src/experiments/knn/knn.py @@ -0,0 +1,84 @@ +import cv2 as cv +import numpy as np +import matplotlib.pyplot as plt +import csv +from sklearn.preprocessing import MinMaxScaler +from enum import Enum +import random + +class Tree(Enum): + ACCASIA = 0 + BERK = 1 + EIK = 2 + ELS = 3 + ESDOORN = 4 + ES = 5 + LINDE = 6 + PLATAAN = 7 + +# Open file +file = open('./out/result-2023-10-10T15.08.36.csv', "r") +data = list(csv.reader(file, delimiter=",")) +file.close() + +# Remove CSV header +header = data.pop(0) +print("CSV tags: ", header) + +# Get classifier tags +tags_int = [] +for row in data: + tree = row.pop(0) + id = Tree[tree.upper()] + + # print("Tree name =", tree, " id =", id.value) + tags_int.append(id.value) + +# Get a random number for testing +validateId = random.randint(0, len(tags_int) - 1) + +# Make into numpy array cus OpenCV is dumb af +tags_int = np.array(tags_int, dtype=np.int32) + +# Transform array for normalisation +data = np.array(data, dtype=np.float32) + +for idx, col in enumerate(data[0]): + # Get column from data + column = data[:, idx] + + # Shape it to 2 dimentional + column = np.array(column).reshape(-1, 1) + + # Perform Min - Max scaling + scaler = MinMaxScaler() + column = scaler.fit_transform(column) + + # Reshape it back cus scaler is dumb af + column = np.array(column).reshape(len(column)) + + # DEBUG Print resulting column + # print("NORM", header[idx + 1], "\n", column) + + # Replace original data array + data[:, idx] = column + +validateTag = tags_int[validateId] +validateObj =np.array([data[validateId]]) +np.delete(tags_int, validateId) +np.delete(data, validateTag) + +print(validateTag, validateObj) + +knn = cv.ml.KNearest_create() +knn.train(data, cv.ml.ROW_SAMPLE, tags_int) + +# print (data) +# print('--------------------') +# print (validateObj) + +ret, results, neighbours ,dist = knn.findNearest(validateObj, 3) + +print( "result: {}\n".format(results) ) +print( "neighbours: {}\n".format(neighbours) ) +print( "distance: {}\n".format(dist) ) \ No newline at end of file