This commit is contained in:
Tom Selier 2023-10-11 16:14:08 +02:00
parent 621d976d89
commit 568c293fac

View File

@ -0,0 +1,53 @@
from sklearn import tree
# from ...helpers.treenum import Tree
from enum import Enum
import csv
import random
SIFT_PATH = "..\\algorithms\\data\\sift.csv"
class Tree(Enum):
ACCASIA = 0
BERK = 1
EIK = 2
ELS = 3
ESDOORN = 4
ES = 5
LINDE = 6
PLATAAN = 7
# [[tree1_data],[tree2_data]]
# [tree1_label, tree2_label]
labels = []
dialect = csv.Dialect
i = 0
done = False
test_index = random.randint(0, 102)
print(test_index)
with open(SIFT_PATH, 'r') as file:
reader = csv.reader(file, delimiter= ',')
matrix = list(reader)
data = [[] for x in range(len(matrix)-1)]
for row in matrix[1:]:
## Remove test case
if i == test_index and done == False:
done = True
data.pop(i)
continue
## append data to lists
labels.append(Tree[row[0].upper()].value)
for element in row[1:]:
data[i].append(element)
## iterator
i += 1
clf = tree.DecisionTreeClassifier()
clf = clf.fit(data, labels)
# tree.plot_tree(clf)
print(Tree[matrix[test_index][0].upper()])
result = clf.predict([matrix[test_index][1:]])
print(Tree(result[0]).name)