117 lines
3.8 KiB
Python
117 lines
3.8 KiB
Python
from enum import Enum
|
|
from sklearn.ensemble import RandomForestClassifier
|
|
from joblib import dump, load
|
|
from sklearn import tree
|
|
import numpy as np
|
|
import csv
|
|
import argparse
|
|
import os
|
|
|
|
from ..tags import Tree
|
|
|
|
# Command-line options as (short flag, long flag, help text); all are mandatory.
_CLI_OPTIONS = (
    ('-i', '--input', "Input CSV file"),
    ('-o', '--output', "Output model folder"),
    ('-m', '--model', "Chosen model ('dectree', 'randforest' or 'extratree')"),
    ('-s', '--scaler', "Scaler preprocesser"),
)

parser = argparse.ArgumentParser(prog='DecisionTree CLI')
for _short_flag, _long_flag, _help_text in _CLI_OPTIONS:
    parser.add_argument(_short_flag, _long_flag, help=_help_text, required=True)
|
|
|
|
class CVSuiteTestTree:
|
|
def __init__(self, model_path = None):
|
|
self.model_path = model_path
|
|
if self.model_path is not None:
|
|
self.model = load(self.model_path)
|
|
self.scaler = None
|
|
|
|
def trainCSV(self, path, output) -> None:
|
|
with open(path, 'r') as file:
|
|
reader = csv.reader(file, delimiter= ',')
|
|
matrix = list(reader)
|
|
|
|
labels = []
|
|
data = [[] for x in range(len(matrix)-1)]
|
|
|
|
# Load all but the headers
|
|
for ridx, row in enumerate(matrix[1:]):
|
|
## append data to lists
|
|
label = row.pop(0).upper()
|
|
labels.append(Tree[label].value)
|
|
|
|
# remove ID
|
|
row.pop(0)
|
|
|
|
# append all but ID and tree
|
|
for idx, element in enumerate(row):
|
|
value = self.scaler[idx].transform([[float(element)]])
|
|
data[ridx].append(value[0][0])
|
|
|
|
# normalize data has been included in code above :D
|
|
#TODO: Check if data is normalized correctly
|
|
|
|
# train model
|
|
self.train(data, labels, output)
|
|
|
|
def addScaler(self, path) -> None:
|
|
self.scaler = load(path)
|
|
|
|
if self.scaler is None:
|
|
print("Scaler failed to load!")
|
|
exit()
|
|
|
|
def train(self, data, labels, output) -> None:
|
|
print("You called the parent class, doofus")
|
|
|
|
def save(self, output, name) -> None:
|
|
path = os.path.join(output, name)
|
|
dump(self.model, path)
|
|
pass
|
|
|
|
def predict(self, data) -> None | int:
|
|
if self.model is not None:
|
|
return self.model.predict(data)
|
|
else:
|
|
return None
|
|
|
|
class CVSuiteTestDecisionTree(CVSuiteTestTree):
    """Variant that trains a single ``tree.DecisionTreeClassifier``."""

    def train(self, data, labels, output) -> None:
        """Fit a decision tree and save it as ``decisiontree.joblib``.

        Args:
            data: Feature rows passed to ``fit``.
            labels: One label value per feature row.
            output: Folder in which the fitted model is saved.
        """
        self.model = tree.DecisionTreeClassifier(
            class_weight=None,
            min_samples_leaf=2,
            max_depth=None,
            # Explicit integer seed. ``False`` was passed here before, which
            # only acted as seed 0 because bool is a subclass of int.
            random_state=0,
            criterion='gini',
            splitter='best',
            # Float literal: no cost-complexity pruning.
            ccp_alpha=0.0,
        )
        self.model.fit(data, labels)
        self.save(output, 'decisiontree.joblib')
|
|
|
|
class CVSuiteTestRandomForest(CVSuiteTestTree):
    """Variant that trains a ``RandomForestClassifier``."""

    def train(self, data, labels, output) -> None:
        """Fit a 150-estimator gini forest and save it as ``randomforest.joblib``.

        Args:
            data: Feature rows passed to ``fit``.
            labels: One label value per feature row.
            output: Folder in which the fitted model is saved.
        """
        forest_options = {'n_estimators': 150, 'criterion': 'gini'}
        # Store the estimator on the instance before fitting it.
        forest = self.model = RandomForestClassifier(**forest_options)
        forest.fit(data, labels)
        self.save(output, 'randomforest.joblib')
|
|
|
|
class CVSuiteTestExtraTrees(CVSuiteTestTree):
    """Variant that trains a ``tree.ExtraTreeClassifier`` with default settings.

    NOTE(review): despite the plural class and file name, this builds one
    ``ExtraTreeClassifier`` rather than an ensemble -- confirm this is the
    intended estimator.
    """

    def train(self, data, labels, output) -> None:
        """Fit the extra-tree model and save it as ``extratrees.joblib``.

        Args:
            data: Feature rows passed to ``fit``.
            labels: One label value per feature row.
            output: Folder in which the fitted model is saved.
        """
        # Store the estimator on the instance before fitting it.
        extra_tree = self.model = tree.ExtraTreeClassifier()
        extra_tree.fit(data, labels)
        self.save(output, 'extratrees.joblib')
|
|
|
|
if __name__ == "__main__":
    # Script entry point: pick the trainer named by --model, load the
    # scaler given by --scaler, then train on --input and save to --output.
    import sys  # local import: only needed for the failure exit below

    args = parser.parse_args()

    # Map each accepted --model value to the class that trains it.
    model_classes = {
        'dectree': CVSuiteTestDecisionTree,
        'randforest': CVSuiteTestRandomForest,
        'extratree': CVSuiteTestExtraTrees,
    }
    model_class = model_classes.get(args.model)

    if model_class is None:
        print("Model not found!")
        # sys.exit rather than the site-provided exit() helper, with a
        # non-zero status so shells and callers see the failure.
        sys.exit(1)

    test = model_class()
    test.addScaler(args.scaler)
    test.trainCSV(args.input, args.output)