diff --git a/src/experiments/decision_tree/decision_tree.py b/src/experiments/decision_tree/decision_tree.py new file mode 100644 index 0000000..2133454 --- /dev/null +++ b/src/experiments/decision_tree/decision_tree.py @@ -0,0 +1,53 @@ +from sklearn import tree +# from ...helpers.treenum import Tree +from enum import Enum +import csv +import random + +SIFT_PATH = "..\\algorithms\\data\\sift.csv" + +class Tree(Enum): + ACCASIA = 0 + BERK = 1 + EIK = 2 + ELS = 3 + ESDOORN = 4 + ES = 5 + LINDE = 6 + PLATAAN = 7 + +# [[tree1_data],[tree2_data]] +# [tree1_label, tree2_label] + +labels = [] +dialect = csv.Dialect +i = 0 +done = False +test_index = random.randint(0, 102) +print(test_index) + +with open(SIFT_PATH, 'r') as file: + reader = csv.reader(file, delimiter= ',') + matrix = list(reader) + data = [[] for x in range(len(matrix)-1)] + for row in matrix[1:]: + ## Remove test case + if i == test_index and done == False: + done = True + data.pop(i) + continue + + ## append data to lists + labels.append(Tree[row[0].upper()].value) + for element in row[1:]: + data[i].append(element) + + ## iterator + i += 1 + +clf = tree.DecisionTreeClassifier() +clf = clf.fit(data, labels) +# tree.plot_tree(clf) +print(Tree[matrix[test_index][0].upper()]) +result = clf.predict([matrix[test_index][1:]]) +print(Tree(result[0]).name) \ No newline at end of file