Merge branch 'main' of https://arnweb.nl/gitea/arne/EV5_Beeldherk_Bomen
This commit is contained in:
commit
8af553236e
127
src/helpers/test/decision_tree.py
Normal file
127
src/helpers/test/decision_tree.py
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
from enum import Enum
|
||||||
|
from sklearn.preprocessing import maxabs_scale, MaxAbsScaler
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from joblib import dump, load
|
||||||
|
from sklearn import tree
|
||||||
|
import csv
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(prog='DecisionTree CLI')
|
||||||
|
parser.add_argument('-i', '--input', help='Input CSV file', required=True)
|
||||||
|
parser.add_argument('-o', '--output', help='Output model file', required=True)
|
||||||
|
|
||||||
|
class Tree(Enum):
|
||||||
|
ACCASIA = 0
|
||||||
|
BERK = 1
|
||||||
|
EIK = 2
|
||||||
|
ELS = 3
|
||||||
|
ESDOORN = 4
|
||||||
|
ES = 5
|
||||||
|
LINDE = 6
|
||||||
|
PLATAAN = 7
|
||||||
|
|
||||||
|
class CVSuiteTestTree:
|
||||||
|
def __init__(self, model_path = None):
|
||||||
|
self.model_path = model_path
|
||||||
|
if self.model_path is not None:
|
||||||
|
self.model = load(self.model_path)
|
||||||
|
self.scaler = None
|
||||||
|
|
||||||
|
def trainCSV(self, path, output) -> None:
|
||||||
|
with open(path, 'r') as file:
|
||||||
|
reader = csv.reader(file, delimiter= ',')
|
||||||
|
matrix = list(reader)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
labels = []
|
||||||
|
data = [[] for x in range(len(matrix)-1)]
|
||||||
|
|
||||||
|
# Load all but the headers
|
||||||
|
for row in matrix[1:]:
|
||||||
|
## append data to lists
|
||||||
|
label = row.pop(0).upper()
|
||||||
|
labels.append(Tree[label].value)
|
||||||
|
|
||||||
|
# remove ID
|
||||||
|
row.pop(0)
|
||||||
|
|
||||||
|
# append all but ID and tree
|
||||||
|
for element in row:
|
||||||
|
data[i].append(float(element))
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# normalize data
|
||||||
|
if self.scaler is not None:
|
||||||
|
norm = self.scaler.fit(data)
|
||||||
|
for row in norm:
|
||||||
|
print(len(row))
|
||||||
|
else:
|
||||||
|
raise EnvironmentError("No scaler found")
|
||||||
|
|
||||||
|
# train model
|
||||||
|
self.train(norm, labels, output)
|
||||||
|
|
||||||
|
def addScaler(self, path) -> None:
|
||||||
|
self.scaler = load(path)
|
||||||
|
|
||||||
|
def train(self, data, labels, output) -> None:
|
||||||
|
print("You called the parent class, doofus")
|
||||||
|
|
||||||
|
def save(self, output, name) -> None:
|
||||||
|
path = os.path.join(output, name)
|
||||||
|
dump(self.model, path)
|
||||||
|
pass
|
||||||
|
|
||||||
|
def predict(self, data) -> None | int:
|
||||||
|
if self.model is not None:
|
||||||
|
return self.model.predict([data])
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class CVSuiteTestDecisionTree(CVSuiteTestTree):
|
||||||
|
def train(self, data, labels, output) -> None:
|
||||||
|
self.model = tree.DecisionTreeClassifier(
|
||||||
|
class_weight=None,
|
||||||
|
min_samples_leaf=2,
|
||||||
|
max_depth=None,
|
||||||
|
random_state=False,
|
||||||
|
criterion='gini',
|
||||||
|
splitter='best',
|
||||||
|
ccp_alpha=0
|
||||||
|
)
|
||||||
|
self.model.fit(data, labels)
|
||||||
|
self.save(output, 'decisiontree.joblib')
|
||||||
|
|
||||||
|
class CVSuiteTestRandomForest(CVSuiteTestTree):
|
||||||
|
def train(self, data, labels, output) -> None:
|
||||||
|
self.model = RandomForestClassifier(
|
||||||
|
n_estimators=150,
|
||||||
|
criterion='gini',
|
||||||
|
)
|
||||||
|
self.model.fit(data, labels)
|
||||||
|
self.save(output, 'randomforest.joblib')
|
||||||
|
|
||||||
|
class CVSuiteTestExtraTrees(CVSuiteTestTree):
|
||||||
|
def train(self, data, labels, output) -> None:
|
||||||
|
self.model = tree.ExtraTreeClassifier()
|
||||||
|
self.model.fit(data, labels)
|
||||||
|
self.save(output, 'extratrees.joblib')
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
args = parser.parse_args()
|
||||||
|
test = CVSuiteTestRandomForest()
|
||||||
|
test.trainCSV(args.input, args.output)
|
||||||
|
test = CVSuiteTestDecisionTree(
|
||||||
|
"C:\\Users\\Tom\\Desktop\\Files\\Repositories\\EV5_Beeldherk_Bomen\\models\\randomforest.joblib"
|
||||||
|
)
|
||||||
|
path = "C:\\Users\\Tom\\Desktop\\Files\\Repositories\\EV5_Beeldherk_Bomen\\dataset\\csv\\result-2023-10-21T09.59.24.csv"
|
||||||
|
file = open(path, 'r')
|
||||||
|
reader = csv.reader(file, delimiter=',')
|
||||||
|
matrix = list(reader)
|
||||||
|
file.close()
|
||||||
|
|
||||||
|
data = [float(x) for x in matrix[2][2:]]
|
||||||
|
norm = maxabs_scale(data)
|
||||||
|
|
||||||
|
print(test.predict(norm))
|
Loading…
Reference in New Issue
Block a user