IT WORKS
This commit is contained in:
parent
621d976d89
commit
568c293fac
53
src/experiments/decision_tree/decision_tree.py
Normal file
53
src/experiments/decision_tree/decision_tree.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
from sklearn import tree
|
||||||
|
# from ...helpers.treenum import Tree
|
||||||
|
from enum import Enum
|
||||||
|
import csv
|
||||||
|
import random
|
||||||
|
|
||||||
|
SIFT_PATH = "..\\algorithms\\data\\sift.csv"
|
||||||
|
|
||||||
|
class Tree(Enum):
|
||||||
|
ACCASIA = 0
|
||||||
|
BERK = 1
|
||||||
|
EIK = 2
|
||||||
|
ELS = 3
|
||||||
|
ESDOORN = 4
|
||||||
|
ES = 5
|
||||||
|
LINDE = 6
|
||||||
|
PLATAAN = 7
|
||||||
|
|
||||||
|
# [[tree1_data],[tree2_data]]
|
||||||
|
# [tree1_label, tree2_label]
|
||||||
|
|
||||||
|
labels = []
|
||||||
|
dialect = csv.Dialect
|
||||||
|
i = 0
|
||||||
|
done = False
|
||||||
|
test_index = random.randint(0, 102)
|
||||||
|
print(test_index)
|
||||||
|
|
||||||
|
with open(SIFT_PATH, 'r') as file:
|
||||||
|
reader = csv.reader(file, delimiter= ',')
|
||||||
|
matrix = list(reader)
|
||||||
|
data = [[] for x in range(len(matrix)-1)]
|
||||||
|
for row in matrix[1:]:
|
||||||
|
## Remove test case
|
||||||
|
if i == test_index and done == False:
|
||||||
|
done = True
|
||||||
|
data.pop(i)
|
||||||
|
continue
|
||||||
|
|
||||||
|
## append data to lists
|
||||||
|
labels.append(Tree[row[0].upper()].value)
|
||||||
|
for element in row[1:]:
|
||||||
|
data[i].append(element)
|
||||||
|
|
||||||
|
## iterator
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
clf = tree.DecisionTreeClassifier()
|
||||||
|
clf = clf.fit(data, labels)
|
||||||
|
# tree.plot_tree(clf)
|
||||||
|
print(Tree[matrix[test_index][0].upper()])
|
||||||
|
result = clf.predict([matrix[test_index][1:]])
|
||||||
|
print(Tree(result[0]).name)
|
Loading…
Reference in New Issue
Block a user