Merge branch 'main' of https://arnweb.nl/gitea/arne/EV5_Beeldherk_Bomen
This commit is contained in:
commit
4206edd60f
84
src/experiments/knn/knn.py
Normal file
84
src/experiments/knn/knn.py
Normal file
@ -0,0 +1,84 @@
|
||||
import cv2 as cv
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import csv
|
||||
from sklearn.preprocessing import MinMaxScaler
|
||||
from enum import Enum
|
||||
import random
|
||||
|
||||
class Tree(Enum):
|
||||
ACCASIA = 0
|
||||
BERK = 1
|
||||
EIK = 2
|
||||
ELS = 3
|
||||
ESDOORN = 4
|
||||
ES = 5
|
||||
LINDE = 6
|
||||
PLATAAN = 7
|
||||
|
||||
# Open file
|
||||
file = open('./out/result-2023-10-10T15.08.36.csv', "r")
|
||||
data = list(csv.reader(file, delimiter=","))
|
||||
file.close()
|
||||
|
||||
# Remove CSV header
|
||||
header = data.pop(0)
|
||||
print("CSV tags: ", header)
|
||||
|
||||
# Get classifier tags
|
||||
tags_int = []
|
||||
for row in data:
|
||||
tree = row.pop(0)
|
||||
id = Tree[tree.upper()]
|
||||
|
||||
# print("Tree name =", tree, " id =", id.value)
|
||||
tags_int.append(id.value)
|
||||
|
||||
# Get a random number for testing
|
||||
validateId = random.randint(0, len(tags_int) - 1)
|
||||
|
||||
# Make into numpy array cus OpenCV is dumb af
|
||||
tags_int = np.array(tags_int, dtype=np.int32)
|
||||
|
||||
# Transform array for normalisation
|
||||
data = np.array(data, dtype=np.float32)
|
||||
|
||||
for idx, col in enumerate(data[0]):
|
||||
# Get column from data
|
||||
column = data[:, idx]
|
||||
|
||||
# Shape it to 2 dimentional
|
||||
column = np.array(column).reshape(-1, 1)
|
||||
|
||||
# Perform Min - Max scaling
|
||||
scaler = MinMaxScaler()
|
||||
column = scaler.fit_transform(column)
|
||||
|
||||
# Reshape it back cus scaler is dumb af
|
||||
column = np.array(column).reshape(len(column))
|
||||
|
||||
# DEBUG Print resulting column
|
||||
# print("NORM", header[idx + 1], "\n", column)
|
||||
|
||||
# Replace original data array
|
||||
data[:, idx] = column
|
||||
|
||||
validateTag = tags_int[validateId]
|
||||
validateObj =np.array([data[validateId]])
|
||||
np.delete(tags_int, validateId)
|
||||
np.delete(data, validateTag)
|
||||
|
||||
print(validateTag, validateObj)
|
||||
|
||||
knn = cv.ml.KNearest_create()
|
||||
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
|
||||
|
||||
# print (data)
|
||||
# print('--------------------')
|
||||
# print (validateObj)
|
||||
|
||||
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
|
||||
|
||||
print( "result: {}\n".format(results) )
|
||||
print( "neighbours: {}\n".format(neighbours) )
|
||||
print( "distance: {}\n".format(dist) )
|
Loading…
Reference in New Issue
Block a user