"""Train tree-based classifiers on a tree-species feature CSV.

Expected CSV layout (as consumed by ``CVSuiteTestTree.trainCSV``): one header
row, then one row per sample where column 0 is the species label (a member
name of ``Tree``, case-insensitive), column 1 is an ID that is ignored, and
every remaining column is a numeric feature.
"""
import argparse
import csv
import os
from enum import Enum

from joblib import dump, load
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import maxabs_scale, MaxAbsScaler

parser = argparse.ArgumentParser(prog='DecisionTree CLI')
parser.add_argument('-i', '--input', help='Input CSV file', required=True)
# The value is joined with a fixed file name in save(), so it is a directory.
parser.add_argument('-o', '--output', help='Output directory for the model file', required=True)
parser.add_argument(
    '-s', '--scaler', default=None,
    help='Optional joblib file containing a scaler (default: a new MaxAbsScaler)',
)


class Tree(Enum):
    """Species labels; the enum value is the integer class used for training."""
    ACCASIA = 0
    BERK = 1
    EIK = 2
    ELS = 3
    ESDOORN = 4
    ES = 5
    LINDE = 6
    PLATAAN = 7


class CVSuiteTestTree:
    """Base class: CSV loading, scaling, persistence and prediction.

    Subclasses implement ``train`` to build, fit and save a concrete model.
    """

    def __init__(self, model_path = None):
        """Optionally load a previously saved model from ``model_path``.

        NOTE: joblib.load is pickle-based and can execute arbitrary code;
        only pass paths to files you trust.
        """
        self.model_path = model_path
        # Always define the attribute so predict() can test it safely.
        self.model = load(model_path) if model_path is not None else None
        self.scaler = None

    def trainCSV(self, path, output) -> None:
        """Read samples from the CSV at ``path``, scale them and train.

        Args:
            path: CSV file with a header row; see the module docstring.
            output: Directory handed on to ``train`` for saving the model.

        Raises:
            EnvironmentError: if no scaler has been set on this instance.
            KeyError: if a label is not a member name of ``Tree``.
            ValueError: if a feature cell cannot be parsed as a float.
        """
        with open(path, 'r', newline='') as file:
            matrix = list(csv.reader(file, delimiter=','))

        labels = []
        data = []
        # Skip the header row.
        for row in matrix[1:]:
            if not row:
                # csv.reader yields [] for blank lines.
                continue
            labels.append(Tree[row[0].strip().upper()].value)
            # Column 1 is the sample ID; features start at column 2.
            data.append([float(element) for element in row[2:]])

        if self.scaler is None:
            raise EnvironmentError("No scaler found")

        # fit() returns the scaler itself; fit_transform() returns the scaled data.
        norm = self.scaler.fit_transform(data)

        self.train(norm, labels, output)

    def addScaler(self, path) -> None:
        """Load a scaler object from the joblib file at ``path``.

        NOTE: joblib.load is pickle-based; only load trusted files.
        """
        self.scaler = load(path)

    def train(self, data, labels, output) -> None:
        """Fit a model on ``data``/``labels`` and save it under ``output``.

        Must be overridden; the base class has no model to train.
        """
        raise NotImplementedError("train() must be implemented by a subclass")

    def save(self, output, name) -> None:
        """Dump ``self.model`` to the file ``name`` inside directory ``output``."""
        path = os.path.join(output, name)
        dump(self.model, path)

    def predict(self, data):
        """Predict the class of a single sample.

        Args:
            data: One feature vector (it is wrapped in a list before being
                passed to the model).

        Returns:
            Whatever ``self.model.predict`` returns for the one-sample batch,
            or ``None`` when no model has been loaded or trained.
        """
        if self.model is None:
            return None
        return self.model.predict([data])


class CVSuiteTestDecisionTree(CVSuiteTestTree):
    """Single decision tree, saved as ``decisiontree.joblib``."""

    def train(self, data, labels, output) -> None:
        self.model = tree.DecisionTreeClassifier(
            class_weight=None,
            min_samples_leaf=2,
            max_depth=None,
            random_state=0,  # fixed seed (was the bool False, i.e. 0)
            criterion='gini',
            splitter='best',
            ccp_alpha=0
        )
        self.model.fit(data, labels)
        self.save(output, 'decisiontree.joblib')


class CVSuiteTestRandomForest(CVSuiteTestTree):
    """Random forest with 150 estimators, saved as ``randomforest.joblib``."""

    def train(self, data, labels, output) -> None:
        self.model = RandomForestClassifier(
            n_estimators=150,
            criterion='gini',
        )
        self.model.fit(data, labels)
        self.save(output, 'randomforest.joblib')


class CVSuiteTestExtraTrees(CVSuiteTestTree):
    """Single extremely randomized tree, saved as ``extratrees.joblib``."""

    def train(self, data, labels, output) -> None:
        self.model = tree.ExtraTreeClassifier()
        self.model.fit(data, labels)
        self.save(output, 'extratrees.joblib')


if __name__ == "__main__":
    args = parser.parse_args()

    # Train a random forest on the input CSV.
    trainer = CVSuiteTestRandomForest()
    if args.scaler is not None:
        trainer.addScaler(args.scaler)
    else:
        trainer.scaler = MaxAbsScaler()
    trainer.trainCSV(args.input, args.output)

    # Reload the model that was just written to check the save/load round trip.
    test = CVSuiteTestRandomForest(os.path.join(args.output, 'randomforest.joblib'))

    with open(args.input, 'r', newline='') as file:
        matrix = list(csv.reader(file, delimiter=','))

    # Smoke test on the second data row (index 2; index 0 is the header).
    if len(matrix) > 2:
        data = [float(x) for x in matrix[2][2:]]
        # Scale with the scaler fitted during training so the sample is
        # normalised exactly like the training data.
        norm = trainer.scaler.transform([data])[0]
        print(test.predict(norm))