diff --git a/src/experiments/decision_tree/decision_tree.py b/src/experiments/decision_tree/decision_tree.py index fa928ea..969dd4b 100644 --- a/src/experiments/decision_tree/decision_tree.py +++ b/src/experiments/decision_tree/decision_tree.py @@ -1,15 +1,16 @@ +from enum import Enum from sklearn import tree from sklearn import metrics from sklearn import preprocessing -from sklearn.ensemble import RandomForestClassifier -# from ...helpers.treenum import Tree -from enum import Enum -import csv -import random +import sklearn from matplotlib import pyplot as plt +import pandas as pd import numpy as np +import random +import csv SIFT_PATH = "..\\algorithms\\data\\sift.csv" +# SIFT_PATH = "C:\\Users\\Tom\\Desktop\\Files\\Repositories\\EV5_Beeldherk_Bomen\datacsv\\result-2023-10-13T14.46.23.csv" class Tree(Enum): ACCASIA = 0 @@ -24,6 +25,27 @@ class Tree(Enum): # [[tree1_data],[tree2_data]] # [tree1_label, tree2_label] +def roc_auc_score_multiclass(actual_class, pred_class, average = "macro"): + + #creating a set of all the unique classes using the actual class list + unique_class = set(actual_class) + roc_auc_dict = {} + for per_class in unique_class: + + #creating a list of all the classes except the current class + other_class = [x for x in unique_class if x != per_class] + + #marking the current class as 1 and all other classes as 0 + new_actual_class = [0 if x in other_class else 1 for x in actual_class] + new_pred_class = [0 if x in other_class else 1 for x in pred_class] + + #using the sklearn metrics method to calculate the roc_auc_score + roc_auc = metrics.roc_auc_score(new_actual_class, new_pred_class, average = average) + roc_auc_dict[per_class] = roc_auc + + return roc_auc_dict + + labels = [] i = 0 done = False @@ -44,30 +66,65 @@ with open(SIFT_PATH, 'r') as file: normalized = preprocessing.normalize(data, axis=0, norm='max') norm = list(normalized.tolist()) -actual = [] -predicted = [] -for i in range(75): - test_index = random.randint(1, 101) - temp_data = data.pop(test_index) - temp_label = labels.pop(test_index) +steps = np.linspace(2, 20, 10, dtype=np.int64) +accuracy = [] +precision = [] +recall = [] +roc = [] - # dec_tree = tree.DecisionTreeClassifier( - # criterion='entropy', - # splitter='best') - dec_tree = RandomForestClassifier(max_depth=None) - dec_tree = dec_tree.fit(data, labels) - result = dec_tree.predict([matrix[test_index][1:]]) +for step in steps: + actual = [] + predicted = [] - # normalized_list.append(temp_data) - data.append(temp_data) - labels.append(temp_label) + for i in range(100): + test_index = random.randint(1, 101) + temp_data = data.pop(test_index) + temp_label = labels.pop(test_index) + del dec_tree - actual.append(temp_label) - predicted.append(result[0]) + dec_tree = tree.DecisionTreeClassifier( + min_samples_leaf=2, + max_depth=None, + random_state=False, + criterion='gini', + splitter='best') + dec_tree = dec_tree.fit(data, labels) + result = dec_tree.predict([matrix[test_index][1:]]) + # normalized_list.append(temp_data) + data.append(temp_data) + labels.append(temp_label) + + actual.append(temp_label) + predicted.append(result[0]) + + accuracy.append(metrics.accuracy_score(actual, predicted)) + precision.append(metrics.precision_score(actual, predicted, average='macro')) + recall.append(metrics.recall_score(actual, predicted, average='macro')) + roc.append(roc_auc_score_multiclass(actual, predicted)) + + print(step) + +# Scores +# https://www.evidentlyai.com/classification-metrics/multi-class-metrics +plt.plot(accuracy) +plt.title("Accuracy") +plt.show() +plt.plot(precision) +plt.title("Precision") +plt.show() +plt.plot(recall) +plt.title("Recall") +plt.show() +df = pd.DataFrame(roc) +plt.figure() +for i in range(7): + plt.plot(df[i], label=Tree(i).name) +plt.legend() +plt.show() + +# Confusion matrix c_matrix = metrics.confusion_matrix(actual, predicted) cm_display = metrics.ConfusionMatrixDisplay(confusion_matrix=c_matrix) cm_display.plot() -plt.show(block=False) -# print("Testdata: \t" + Tree[matrix[test_index][0].upper()].name) -# print("Predicted: \t" + Tree(result[0]).name) \ No newline at end of file +plt.show(block=False) \ No newline at end of file