This commit is contained in:
Arne van Iterson 2023-10-21 19:32:19 +02:00
commit 8af553236e

View File

@ -0,0 +1,127 @@
from enum import Enum
from sklearn.preprocessing import maxabs_scale, MaxAbsScaler
from sklearn.ensemble import RandomForestClassifier
from joblib import dump, load
from sklearn import tree
import csv
import argparse
import os
# Command-line interface for the training script; both options are mandatory.
# NOTE(review): the help text calls --output a "model file", but the value is
# joined with a fixed file name when a model is saved, i.e. it is used as a
# directory -- confirm which is intended.
parser = argparse.ArgumentParser(prog="DecisionTree CLI")
parser.add_argument("-i", "--input", required=True, help="Input CSV file")
parser.add_argument("-o", "--output", required=True, help="Output model file")
class Tree(Enum):
    """Tree labels mapped to the integer class ids used for training.

    CSV labels are upper-cased and looked up by member name, so the member
    names must match the labels in the dataset exactly (including the
    existing spelling of ``ACCASIA``).  The names appear to be Dutch tree
    species names -- confirm against the dataset.
    """

    ACCASIA = 0
    BERK = 1
    EIK = 2
    ELS = 3
    ESDOORN = 4
    ES = 5
    LINDE = 6
    PLATAAN = 7
class CVSuiteTestTree:
def __init__(self, model_path = None):
self.model_path = model_path
if self.model_path is not None:
self.model = load(self.model_path)
self.scaler = None
def trainCSV(self, path, output) -> None:
with open(path, 'r') as file:
reader = csv.reader(file, delimiter= ',')
matrix = list(reader)
i = 0
labels = []
data = [[] for x in range(len(matrix)-1)]
# Load all but the headers
for row in matrix[1:]:
## append data to lists
label = row.pop(0).upper()
labels.append(Tree[label].value)
# remove ID
row.pop(0)
# append all but ID and tree
for element in row:
data[i].append(float(element))
i += 1
# normalize data
if self.scaler is not None:
norm = self.scaler.fit(data)
for row in norm:
print(len(row))
else:
raise EnvironmentError("No scaler found")
# train model
self.train(norm, labels, output)
def addScaler(self, path) -> None:
self.scaler = load(path)
def train(self, data, labels, output) -> None:
print("You called the parent class, doofus")
def save(self, output, name) -> None:
path = os.path.join(output, name)
dump(self.model, path)
pass
def predict(self, data) -> None | int:
if self.model is not None:
return self.model.predict([data])
else:
return None
class CVSuiteTestDecisionTree(CVSuiteTestTree):
    """Experiment that trains a single decision-tree classifier."""

    def train(self, data, labels, output) -> None:
        """Fit a decision tree on the data and save it.

        Args:
            data: Feature rows to fit on.
            labels: Class id per row.
            output: Location handed to ``save``; the file name is
                ``decisiontree.joblib``.
        """
        classifier = tree.DecisionTreeClassifier(
            criterion='gini',
            splitter='best',
            max_depth=None,
            min_samples_leaf=2,
            class_weight=None,
            ccp_alpha=0,
            # NOTE(review): ``False`` is an int subclass equal to 0, so this
            # presumably acts as a fixed seed of 0 -- consider writing 0.
            random_state=False,
        )
        # Store the estimator before fitting, so the attribute is set even
        # if fit() raises.
        self.model = classifier
        classifier.fit(data, labels)
        self.save(output, 'decisiontree.joblib')
class CVSuiteTestRandomForest(CVSuiteTestTree):
    """Experiment that trains a random-forest classifier."""

    def train(self, data, labels, output) -> None:
        """Fit a 150-estimator random forest on the data and save it.

        Args:
            data: Feature rows to fit on.
            labels: Class id per row.
            output: Location handed to ``save``; the file name is
                ``randomforest.joblib``.
        """
        forest = RandomForestClassifier(criterion='gini', n_estimators=150)
        # Store the estimator before fitting, so the attribute is set even
        # if fit() raises.
        self.model = forest
        forest.fit(data, labels)
        self.save(output, 'randomforest.joblib')
class CVSuiteTestExtraTrees(CVSuiteTestTree):
    """Experiment that trains an extremely randomised tree classifier.

    NOTE(review): this uses the single-tree ``tree.ExtraTreeClassifier``
    even though the class name is plural; if the ensemble was intended,
    that is a different estimator -- confirm.
    """

    def train(self, data, labels, output) -> None:
        """Fit the classifier with default parameters and save it.

        Args:
            data: Feature rows to fit on.
            labels: Class id per row.
            output: Location handed to ``save``; the file name is
                ``extratrees.joblib``.
        """
        extra_tree = tree.ExtraTreeClassifier()
        # Store the estimator before fitting, so the attribute is set even
        # if fit() raises.
        self.model = extra_tree
        extra_tree.fit(data, labels)
        self.save(output, 'extratrees.joblib')
if __name__ == "__main__":
    args = parser.parse_args()

    # --- Train a random forest from the CSV given on the command line ---
    test = CVSuiteTestRandomForest()
    # trainCSV() raises "No scaler found" unless a scaler is attached, and
    # nothing else in this script attaches one; use the max-abs scaler this
    # module already imports.
    test.scaler = MaxAbsScaler()
    test.trainCSV(args.input, args.output)

    # --- Ad-hoc prediction check against a previously saved model ---
    # NOTE(review): both paths below are hard-coded to one developer machine,
    # and a random-forest dump is loaded through the decision-tree wrapper
    # (predict() only delegates to the loaded model) -- consider turning
    # these into command-line options.
    test = CVSuiteTestDecisionTree(
        "C:\\Users\\Tom\\Desktop\\Files\\Repositories\\EV5_Beeldherk_Bomen\\models\\randomforest.joblib"
    )

    path = "C:\\Users\\Tom\\Desktop\\Files\\Repositories\\EV5_Beeldherk_Bomen\\dataset\\csv\\result-2023-10-21T09.59.24.csv"
    # The context manager closes the file even if reading the CSV fails.
    with open(path, 'r') as file:
        matrix = list(csv.reader(file, delimiter=','))

    # Third CSV row (index 2), skipping the label and ID columns.
    data = [float(x) for x in matrix[2][2:]]
    # NOTE(review): this scales the single sample on its own, whereas
    # training scales the whole feature matrix with the scaler object --
    # confirm that the two normalisations are meant to match.
    norm = maxabs_scale(data)
    print(test.predict(norm))