# EV5_Beeldherk_Bomen/src/helpers/test/decision_tree.py
#
# 117 lines
# 3.8 KiB
# Python

from enum import Enum
from sklearn.ensemble import RandomForestClassifier
from joblib import dump, load
from sklearn import tree
import numpy as np
import csv
import argparse
import os
from ..tags import Tree
def _build_parser():
    """Create the command-line parser; all four options are required."""
    cli = argparse.ArgumentParser(prog="DecisionTree CLI")
    # (short flag, long flag, help text) for every option, in display order.
    options = (
        ("-i", "--input", "Input CSV file"),
        ("-o", "--output", "Output model folder"),
        (
            "-m",
            "--model",
            "Chosen model ('dectree', 'randforest' or 'extratree')",
        ),
        ("-s", "--scaler", "Scaler preprocesser"),
    )
    for short_flag, long_flag, description in options:
        cli.add_argument(
            short_flag, long_flag, help=description, required=True
        )
    return cli


parser = _build_parser()
class CVSuiteTestTree:
def __init__(self, model_path = None):
self.model_path = model_path
if self.model_path is not None:
self.model = load(self.model_path)
self.scaler = None
def trainCSV(self, path, output) -> None:
with open(path, 'r') as file:
reader = csv.reader(file, delimiter= ',')
matrix = list(reader)
labels = []
data = [[] for x in range(len(matrix)-1)]
# Load all but the headers
for ridx, row in enumerate(matrix[1:]):
## append data to lists
label = row.pop(0).upper()
labels.append(Tree[label].value)
# remove ID
row.pop(0)
# append all but ID and tree
for idx, element in enumerate(row):
value = self.scaler[idx].transform([[float(element)]])
data[ridx].append(value[0][0])
# normalize data has been included in code above :D
#TODO: Check if data is normalized correctly
# train model
self.train(data, labels, output)
def addScaler(self, path) -> None:
self.scaler = load(path)
if self.scaler is None:
print("Scaler failed to load!")
exit()
def train(self, data, labels, output) -> None:
print("You called the parent class, doofus")
def save(self, output, name) -> None:
path = os.path.join(output, name)
dump(self.model, path)
pass
def predict(self, data) -> None | int:
if self.model is not None:
return self.model.predict(data)
else:
return None
class CVSuiteTestDecisionTree(CVSuiteTestTree):
    """Trainer that fits a scikit-learn ``DecisionTreeClassifier``."""

    def train(self, data, labels, output) -> None:
        """Fit a decision tree on the given samples and save it.

        Args:
            data: Feature rows, one sequence of numbers per sample.
            labels: Class label of each sample, in the same order as
                ``data``.
            output: Folder passed to ``self.save`` together with the file
                name ``decisiontree.joblib``.
        """
        self.model = tree.DecisionTreeClassifier(
            class_weight=None,
            min_samples_leaf=2,
            max_depth=None,
            # Fixed integer seed so repeated runs build the same tree.
            # Spelled as an int: a bool here is only usable as a seed
            # because False == 0, and reads as if it disabled seeding.
            random_state=0,
            criterion='gini',
            splitter='best',
            # 0.0 disables cost-complexity pruning.
            ccp_alpha=0.0,
        )
        self.model.fit(data, labels)
        self.save(output, 'decisiontree.joblib')
class CVSuiteTestRandomForest(CVSuiteTestTree):
    """Trainer that fits a scikit-learn ``RandomForestClassifier``."""

    def train(self, data, labels, output) -> None:
        """Fit a 150-estimator forest and pass it to ``self.save``.

        Args:
            data: Feature rows handed to ``fit``.
            labels: Targets handed to ``fit``.
            output: Folder passed to ``self.save`` together with the file
                name ``randomforest.joblib``.
        """
        settings = {'n_estimators': 150, 'criterion': 'gini'}
        forest = RandomForestClassifier(**settings)
        # Publish the estimator on the instance before fitting it.
        self.model = forest
        forest.fit(data, labels)
        self.save(output, 'randomforest.joblib')
class CVSuiteTestExtraTrees(CVSuiteTestTree):
    """Trainer that fits ``tree.ExtraTreeClassifier`` with default settings."""

    def train(self, data, labels, output) -> None:
        """Fit the classifier and pass it to ``self.save``.

        Args:
            data: Feature rows handed to ``fit``.
            labels: Targets handed to ``fit``.
            output: Folder passed to ``self.save`` together with the file
                name ``extratrees.joblib``.
        """
        # NOTE(review): this is the single-tree estimator from sklearn.tree,
        # not the ensemble ExtraTreesClassifier - confirm that is intended.
        classifier = tree.ExtraTreeClassifier()
        self.model = classifier
        classifier.fit(data, labels)
        self.save(output, 'extratrees.joblib')
if __name__ == "__main__":
    # Script entry point: pick the trainer named by --model, load the
    # scaler, then train on the input CSV.
    args = parser.parse_args()

    # Accepted --model values and the trainer class each one selects.
    trainers = {
        'dectree': CVSuiteTestDecisionTree,
        'randforest': CVSuiteTestRandomForest,
        'extratree': CVSuiteTestExtraTrees,
    }
    trainer_cls = trainers.get(args.model)
    if trainer_cls is None:
        print("Model not found!")
        # Exit with a failure status so shells and CI notice the error;
        # the site-provided exit() helper would report status 0.
        raise SystemExit(1)

    test = trainer_cls()
    test.addScaler(args.scaler)
    test.trainCSV(args.input, args.output)