Merge branch 'main' of https://arnweb.nl/gitea/arne/EV5_Beeldherk_Bomen
This commit is contained in:
commit
4206edd60f
84
src/experiments/knn/knn.py
Normal file
84
src/experiments/knn/knn.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import cv2 as cv
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import csv
|
||||||
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
|
from enum import Enum
|
||||||
|
import random
|
||||||
|
|
||||||
|
class Tree(Enum):
|
||||||
|
ACCASIA = 0
|
||||||
|
BERK = 1
|
||||||
|
EIK = 2
|
||||||
|
ELS = 3
|
||||||
|
ESDOORN = 4
|
||||||
|
ES = 5
|
||||||
|
LINDE = 6
|
||||||
|
PLATAAN = 7
|
||||||
|
|
||||||
|
# Open file
|
||||||
|
file = open('./out/result-2023-10-10T15.08.36.csv', "r")
|
||||||
|
data = list(csv.reader(file, delimiter=","))
|
||||||
|
file.close()
|
||||||
|
|
||||||
|
# Remove CSV header
|
||||||
|
header = data.pop(0)
|
||||||
|
print("CSV tags: ", header)
|
||||||
|
|
||||||
|
# Get classifier tags
|
||||||
|
tags_int = []
|
||||||
|
for row in data:
|
||||||
|
tree = row.pop(0)
|
||||||
|
id = Tree[tree.upper()]
|
||||||
|
|
||||||
|
# print("Tree name =", tree, " id =", id.value)
|
||||||
|
tags_int.append(id.value)
|
||||||
|
|
||||||
|
# Get a random number for testing
|
||||||
|
validateId = random.randint(0, len(tags_int) - 1)
|
||||||
|
|
||||||
|
# Make into numpy array cus OpenCV is dumb af
|
||||||
|
tags_int = np.array(tags_int, dtype=np.int32)
|
||||||
|
|
||||||
|
# Transform array for normalisation
|
||||||
|
data = np.array(data, dtype=np.float32)
|
||||||
|
|
||||||
|
for idx, col in enumerate(data[0]):
|
||||||
|
# Get column from data
|
||||||
|
column = data[:, idx]
|
||||||
|
|
||||||
|
# Shape it to 2 dimentional
|
||||||
|
column = np.array(column).reshape(-1, 1)
|
||||||
|
|
||||||
|
# Perform Min - Max scaling
|
||||||
|
scaler = MinMaxScaler()
|
||||||
|
column = scaler.fit_transform(column)
|
||||||
|
|
||||||
|
# Reshape it back cus scaler is dumb af
|
||||||
|
column = np.array(column).reshape(len(column))
|
||||||
|
|
||||||
|
# DEBUG Print resulting column
|
||||||
|
# print("NORM", header[idx + 1], "\n", column)
|
||||||
|
|
||||||
|
# Replace original data array
|
||||||
|
data[:, idx] = column
|
||||||
|
|
||||||
|
validateTag = tags_int[validateId]
|
||||||
|
validateObj =np.array([data[validateId]])
|
||||||
|
np.delete(tags_int, validateId)
|
||||||
|
np.delete(data, validateTag)
|
||||||
|
|
||||||
|
print(validateTag, validateObj)
|
||||||
|
|
||||||
|
knn = cv.ml.KNearest_create()
|
||||||
|
knn.train(data, cv.ml.ROW_SAMPLE, tags_int)
|
||||||
|
|
||||||
|
# print (data)
|
||||||
|
# print('--------------------')
|
||||||
|
# print (validateObj)
|
||||||
|
|
||||||
|
ret, results, neighbours ,dist = knn.findNearest(validateObj, 3)
|
||||||
|
|
||||||
|
print( "result: {}\n".format(results) )
|
||||||
|
print( "neighbours: {}\n".format(neighbours) )
|
||||||
|
print( "distance: {}\n".format(dist) )
|
Loading…
Reference in New Issue
Block a user