it's still broken, don't sue me
This commit is contained in:
parent
5115d4c8a2
commit
9ab9837b25
127
src/helpers/test/decision_tree.py
Normal file
127
src/helpers/test/decision_tree.py
Normal file
@ -0,0 +1,127 @@
|
||||
from enum import Enum
|
||||
from sklearn.preprocessing import maxabs_scale, MaxAbsScaler
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from joblib import dump, load
|
||||
from sklearn import tree
|
||||
import csv
|
||||
import argparse
|
||||
import os
|
||||
|
||||
parser = argparse.ArgumentParser(prog='DecisionTree CLI')
|
||||
parser.add_argument('-i', '--input', help='Input CSV file', required=True)
|
||||
parser.add_argument('-o', '--output', help='Output model file', required=True)
|
||||
|
||||
class Tree(Enum):
|
||||
ACCASIA = 0
|
||||
BERK = 1
|
||||
EIK = 2
|
||||
ELS = 3
|
||||
ESDOORN = 4
|
||||
ES = 5
|
||||
LINDE = 6
|
||||
PLATAAN = 7
|
||||
|
||||
class CVSuiteTestTree:
|
||||
def __init__(self, model_path = None):
|
||||
self.model_path = model_path
|
||||
if self.model_path is not None:
|
||||
self.model = load(self.model_path)
|
||||
self.scaler = None
|
||||
|
||||
def trainCSV(self, path, output) -> None:
|
||||
with open(path, 'r') as file:
|
||||
reader = csv.reader(file, delimiter= ',')
|
||||
matrix = list(reader)
|
||||
|
||||
i = 0
|
||||
labels = []
|
||||
data = [[] for x in range(len(matrix)-1)]
|
||||
|
||||
# Load all but the headers
|
||||
for row in matrix[1:]:
|
||||
## append data to lists
|
||||
label = row.pop(0).upper()
|
||||
labels.append(Tree[label].value)
|
||||
|
||||
# remove ID
|
||||
row.pop(0)
|
||||
|
||||
# append all but ID and tree
|
||||
for element in row:
|
||||
data[i].append(float(element))
|
||||
i += 1
|
||||
|
||||
# normalize data
|
||||
if self.scaler is not None:
|
||||
norm = self.scaler.fit(data)
|
||||
for row in norm:
|
||||
print(len(row))
|
||||
else:
|
||||
raise EnvironmentError("No scaler found")
|
||||
|
||||
# train model
|
||||
self.train(norm, labels, output)
|
||||
|
||||
def addScaler(self, path) -> None:
|
||||
self.scaler = load(path)
|
||||
|
||||
def train(self, data, labels, output) -> None:
|
||||
print("You called the parent class, doofus")
|
||||
|
||||
def save(self, output, name) -> None:
|
||||
path = os.path.join(output, name)
|
||||
dump(self.model, path)
|
||||
pass
|
||||
|
||||
def predict(self, data) -> None | int:
|
||||
if self.model is not None:
|
||||
return self.model.predict([data])
|
||||
else:
|
||||
return None
|
||||
|
||||
class CVSuiteTestDecisionTree(CVSuiteTestTree):
|
||||
def train(self, data, labels, output) -> None:
|
||||
self.model = tree.DecisionTreeClassifier(
|
||||
class_weight=None,
|
||||
min_samples_leaf=2,
|
||||
max_depth=None,
|
||||
random_state=False,
|
||||
criterion='gini',
|
||||
splitter='best',
|
||||
ccp_alpha=0
|
||||
)
|
||||
self.model.fit(data, labels)
|
||||
self.save(output, 'decisiontree.joblib')
|
||||
|
||||
class CVSuiteTestRandomForest(CVSuiteTestTree):
|
||||
def train(self, data, labels, output) -> None:
|
||||
self.model = RandomForestClassifier(
|
||||
n_estimators=150,
|
||||
criterion='gini',
|
||||
)
|
||||
self.model.fit(data, labels)
|
||||
self.save(output, 'randomforest.joblib')
|
||||
|
||||
class CVSuiteTestExtraTrees(CVSuiteTestTree):
|
||||
def train(self, data, labels, output) -> None:
|
||||
self.model = tree.ExtraTreeClassifier()
|
||||
self.model.fit(data, labels)
|
||||
self.save(output, 'extratrees.joblib')
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
test = CVSuiteTestRandomForest()
|
||||
test.trainCSV(args.input, args.output)
|
||||
test = CVSuiteTestDecisionTree(
|
||||
"C:\\Users\\Tom\\Desktop\\Files\\Repositories\\EV5_Beeldherk_Bomen\\models\\randomforest.joblib"
|
||||
)
|
||||
path = "C:\\Users\\Tom\\Desktop\\Files\\Repositories\\EV5_Beeldherk_Bomen\\dataset\\csv\\result-2023-10-21T09.59.24.csv"
|
||||
file = open(path, 'r')
|
||||
reader = csv.reader(file, delimiter=',')
|
||||
matrix = list(reader)
|
||||
file.close()
|
||||
|
||||
data = [float(x) for x in matrix[2][2:]]
|
||||
norm = maxabs_scale(data)
|
||||
|
||||
print(test.predict(norm))
|
Loading…
Reference in New Issue
Block a user