decision treees

This commit is contained in:
Tom Selier 2023-10-13 14:29:04 +02:00
parent 4206edd60f
commit f9374d02fa

View File

@ -1,8 +1,13 @@
from sklearn import tree
from sklearn import metrics
from sklearn import preprocessing
from sklearn.ensemble import RandomForestClassifier
# from ...helpers.treenum import Tree
from enum import Enum
import csv
import random
from matplotlib import pyplot as plt
import numpy as np
SIFT_PATH = "..\\algorithms\\data\\sift.csv"
@ -20,34 +25,49 @@ class Tree(Enum):
# [tree1_label, tree2_label]
labels = []
dialect = csv.Dialect
i = 0
done = False
test_index = random.randint(0, 102)
print(test_index)
with open(SIFT_PATH, 'r') as file:
reader = csv.reader(file, delimiter= ',')
matrix = list(reader)
data = [[] for x in range(len(matrix)-1)]
for row in matrix[1:]:
## Remove test case
if i == test_index and done == False:
done = True
data.pop(i)
continue
## append data to lists
labels.append(Tree[row[0].upper()].value)
for element in row[1:]:
data[i].append(element)
## iterator
data[i].append(float(element))
i += 1
clf = tree.DecisionTreeClassifier()
clf = clf.fit(data, labels)
# tree.plot_tree(clf)
print(Tree[matrix[test_index][0].upper()])
result = clf.predict([matrix[test_index][1:]])
print(Tree(result[0]).name)
# Werkt niet met genormaliseerde data
normalized = preprocessing.normalize(data, axis=0, norm='max')
norm = list(normalized.tolist())
actual = []
predicted = []
for i in range(75):
test_index = random.randint(1, 101)
temp_data = data.pop(test_index)
temp_label = labels.pop(test_index)
# dec_tree = tree.DecisionTreeClassifier(
# criterion='entropy',
# splitter='best')
dec_tree = RandomForestClassifier(max_depth=None)
dec_tree = dec_tree.fit(data, labels)
result = dec_tree.predict([matrix[test_index][1:]])
# normalized_list.append(temp_data)
data.append(temp_data)
labels.append(temp_label)
actual.append(temp_label)
predicted.append(result[0])
c_matrix = metrics.confusion_matrix(actual, predicted)
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=c_matrix)
cm_display.plot()
plt.show(block=False)
# print("Testdata: \t" + Tree[matrix[test_index][0].upper()].name)
# print("Predicted: \t" + Tree(result[0]).name)