Compare commits
2 Commits
12ad80a092
...
6feaab7f68
Author | SHA1 | Date | |
---|---|---|---|
6feaab7f68 | |||
f9374d02fa |
@ -1,8 +1,13 @@
|
||||
from sklearn import tree
|
||||
from sklearn import metrics
|
||||
from sklearn import preprocessing
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
# from ...helpers.treenum import Tree
|
||||
from enum import Enum
|
||||
import csv
|
||||
import random
|
||||
from matplotlib import pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
SIFT_PATH = "..\\algorithms\\data\\sift.csv"
|
||||
|
||||
@ -20,34 +25,49 @@ class Tree(Enum):
|
||||
# [tree1_label, tree2_label]
|
||||
|
||||
labels = []
|
||||
dialect = csv.Dialect
|
||||
i = 0
|
||||
done = False
|
||||
test_index = random.randint(0, 102)
|
||||
print(test_index)
|
||||
|
||||
with open(SIFT_PATH, 'r') as file:
|
||||
reader = csv.reader(file, delimiter= ',')
|
||||
matrix = list(reader)
|
||||
|
||||
data = [[] for x in range(len(matrix)-1)]
|
||||
for row in matrix[1:]:
|
||||
## Remove test case
|
||||
if i == test_index and done == False:
|
||||
done = True
|
||||
data.pop(i)
|
||||
continue
|
||||
|
||||
## append data to lists
|
||||
labels.append(Tree[row[0].upper()].value)
|
||||
for element in row[1:]:
|
||||
data[i].append(element)
|
||||
|
||||
## iterator
|
||||
data[i].append(float(element))
|
||||
i += 1
|
||||
|
||||
clf = tree.DecisionTreeClassifier()
|
||||
clf = clf.fit(data, labels)
|
||||
# tree.plot_tree(clf)
|
||||
print(Tree[matrix[test_index][0].upper()])
|
||||
result = clf.predict([matrix[test_index][1:]])
|
||||
print(Tree(result[0]).name)
|
||||
# Werkt niet met genormaliseerde data
|
||||
normalized = preprocessing.normalize(data, axis=0, norm='max')
|
||||
norm = list(normalized.tolist())
|
||||
|
||||
actual = []
|
||||
predicted = []
|
||||
for i in range(75):
|
||||
test_index = random.randint(1, 101)
|
||||
temp_data = data.pop(test_index)
|
||||
temp_label = labels.pop(test_index)
|
||||
|
||||
# dec_tree = tree.DecisionTreeClassifier(
|
||||
# criterion='entropy',
|
||||
# splitter='best')
|
||||
dec_tree = RandomForestClassifier(max_depth=None)
|
||||
dec_tree = dec_tree.fit(data, labels)
|
||||
result = dec_tree.predict([matrix[test_index][1:]])
|
||||
|
||||
# normalized_list.append(temp_data)
|
||||
data.append(temp_data)
|
||||
labels.append(temp_label)
|
||||
|
||||
actual.append(temp_label)
|
||||
predicted.append(result[0])
|
||||
|
||||
c_matrix = metrics.confusion_matrix(actual, predicted)
|
||||
cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=c_matrix)
|
||||
cm_display.plot()
|
||||
plt.show(block=False)
|
||||
# print("Testdata: \t" + Tree[matrix[test_index][0].upper()].name)
|
||||
# print("Predicted: \t" + Tree(result[0]).name)
|
Loading…
Reference in New Issue
Block a user