Added other trees to suite

This commit is contained in:
Arne van Iterson 2023-10-22 15:12:03 +02:00
parent 4dcd021490
commit 58e521cb6e
3 changed files with 72 additions and 52 deletions

View File

@ -11,19 +11,25 @@ import sys
try: try:
# Perform relative import if included from CVSuite # Perform relative import if included from CVSuite
from ..tags import Tree from ..tags import Tree
from ..logger import C_DONE from ..logger import C_DONE, C_ERR
except ImportError: except ImportError:
# This solution is hot garbage but I refuse to spend any more time on it # This solution is hot garbage but I refuse to spend any more time on it
directory = os.path.dirname(os.path.realpath(__file__)) directory = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(directory, '..')) sys.path.append(os.path.join(directory, ".."))
from tags import Tree from tags import Tree
from logger import C_DONE from logger import C_DONE
parser = argparse.ArgumentParser(prog='DecisionTree CLI') parser = argparse.ArgumentParser(prog="DecisionTree CLI")
parser.add_argument('-i', '--input', help='Input CSV file', required=True) parser.add_argument("-i", "--input", help="Input CSV file", required=True)
parser.add_argument('-o', '--output', help='Output model folder', required=True) parser.add_argument("-o", "--output", help="Output model folder", required=True)
parser.add_argument('-m', '--model', help='Chosen model (\'dectree\', \'randforest\' or \'extratree\')', required=True) parser.add_argument(
parser.add_argument('-s', '--scaler', help='Scaler preprocessor', required=True) "-m",
"--model",
help="Chosen model ('dectree', 'randforest' or 'extratree')",
required=True,
)
parser.add_argument("-s", "--scaler", help="Scaler preprocessor", required=True)
class CVSuiteTestTree: class CVSuiteTestTree:
def __init__(self, model_path=None): def __init__(self, model_path=None):
@ -33,8 +39,8 @@ class CVSuiteTestTree:
self.scaler = None self.scaler = None
def trainCSV(self, path, output) -> None: def trainCSV(self, path, output) -> None:
with open(path, 'r') as file: with open(path, "r") as file:
reader = csv.reader(file, delimiter= ',') reader = csv.reader(file, delimiter=",")
matrix = list(reader) matrix = list(reader)
labels = [] labels = []
@ -64,11 +70,11 @@ class CVSuiteTestTree:
self.scaler = load(path) self.scaler = load(path)
if self.scaler is None: if self.scaler is None:
print("Scaler failed to load!") print(C_ERR, "Scaler failed to load!")
exit() exit()
def train(self, data, labels, output) -> None: def train(self, data, labels, output) -> None:
print("You called the parent class, doofus") print(C_ERR, "You called the parent class, doofus")
def save(self, output, name) -> None: def save(self, output, name) -> None:
path = os.path.join(output, name) path = os.path.join(output, name)
@ -81,6 +87,7 @@ class CVSuiteTestTree:
else: else:
return None return None
class CVSuiteTestDecisionTree(CVSuiteTestTree): class CVSuiteTestDecisionTree(CVSuiteTestTree):
def train(self, data, labels, output) -> None: def train(self, data, labels, output) -> None:
self.model = tree.DecisionTreeClassifier( self.model = tree.DecisionTreeClassifier(
@ -88,41 +95,44 @@ class CVSuiteTestDecisionTree(CVSuiteTestTree):
min_samples_leaf=2, min_samples_leaf=2,
max_depth=None, max_depth=None,
random_state=False, random_state=False,
criterion='gini', criterion="gini",
splitter='best', splitter="best",
ccp_alpha=0 ccp_alpha=0,
) )
self.model.fit(data, labels) self.model.fit(data, labels)
self.save(output, 'decisiontree.pkl') self.save(output, "decisiontree.pkl")
class CVSuiteTestRandomForest(CVSuiteTestTree): class CVSuiteTestRandomForest(CVSuiteTestTree):
def train(self, data, labels, output) -> None: def train(self, data, labels, output) -> None:
self.model = RandomForestClassifier( self.model = RandomForestClassifier(
n_estimators=150, n_estimators=150,
criterion='gini', criterion="gini",
) )
self.model.fit(data, labels) self.model.fit(data, labels)
self.save(output, 'randomforest.pkl') self.save(output, "randomforest.pkl")
class CVSuiteTestExtraTrees(CVSuiteTestTree): class CVSuiteTestExtraTrees(CVSuiteTestTree):
def train(self, data, labels, output) -> None: def train(self, data, labels, output) -> None:
self.model = tree.ExtraTreeClassifier() self.model = tree.ExtraTreeClassifier()
self.model.fit(data, labels) self.model.fit(data, labels)
self.save(output, 'extratrees.pkl') self.save(output, "extratrees.pkl")
if __name__ == "__main__": if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
if args.model == 'dectree': if args.model == "dectree":
test = CVSuiteTestDecisionTree() test = CVSuiteTestDecisionTree()
elif args.model == 'randforest': elif args.model == "randforest":
test = CVSuiteTestRandomForest() test = CVSuiteTestRandomForest()
elif args.model == 'extratree': elif args.model == "extratree":
test = CVSuiteTestExtraTrees() test = CVSuiteTestExtraTrees()
else: else:
print("Model not found!") print(C_ERR, "Model not found!")
exit() exit()
test.addScaler(args.scaler) test.addScaler(args.scaler)
test.trainCSV(args.input, args.output) test.trainCSV(args.input, args.output)
print(C_DONE + "Model trained successfully!") print(C_DONE, "Model trained successfully!")

View File

@ -93,7 +93,8 @@ class CVSuiteTestKNN:
# Dump the scalers # Dump the scalers
now = datetime.datetime.now() now = datetime.datetime.now()
joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl")) joblib.dump(self.scale, os.path.join(output, F"scaler.pkl"))
# joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl"))
# Pass data to train function # Pass data to train function
self.train(data, tags_int, output) self.train(data, tags_int, output)
@ -112,7 +113,8 @@ class CVSuiteTestKNN:
# Save it # Save it
now = datetime.datetime.now() now = datetime.datetime.now()
self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml")) self.knn.save(os.path.join(output, F"model_knn.yaml"))
# self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
def predict(self, data, nr = 3): def predict(self, data, nr = 3):
ret, results, neighbours ,dist = self.knn.findNearest(data, nr) ret, results, neighbours ,dist = self.knn.findNearest(data, nr)

View File

@ -35,7 +35,11 @@ from helpers.tags import Tree
# Tests # Tests
from helpers.test.knn import CVSuiteTestKNN from helpers.test.knn import CVSuiteTestKNN
from helpers.test.decision_tree import CVSuiteTestDecisionTree from helpers.test.decision_tree import (
CVSuiteTestDecisionTree,
CVSuiteTestRandomForest,
CVSuiteTestExtraTrees,
)
## UI config load ## UI config load
PROJECT_PATH = pathlib.Path(__file__).parent PROJECT_PATH = pathlib.Path(__file__).parent
@ -127,23 +131,25 @@ class CVSuite:
for model in config_json["models"]: for model in config_json["models"]:
if config_json["models"][model] != "": if config_json["models"][model] != "":
print(C_INFO, f"Loading model {model}") print(C_INFO, f"Loading model {model}")
mpath = config_json["models"][model]
if model == "knn": if model == "knn":
self.models.append( self.models.append(("KNN", CVSuiteTestKNN(mpath)))
("KNN", CVSuiteTestKNN(config_json["models"]["knn"]))
)
elif model == "dectree": elif model == "dectree":
self.models.append( self.models.append(
( ("Decision Tree", CVSuiteTestDecisionTree(mpath))
"Decision Tree",
CVSuiteTestDecisionTree(
config_json["models"]["dectree"]
),
) )
elif model == "randforest":
self.models.append(
("Random Forest", CVSuiteTestRandomForest(mpath))
)
elif model == "extratree":
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath)))
else:
print(
C_WARN, f"Model {model} does not exist or is not supported!"
) )
else: else:
print(C_WARN, f"Model {model} does not exist!") print(C_WARN, f"Model {model} path not configured!")
else:
print(C_WARN, f"Model {model} not configured!")
print(C_DONE, f"{len(self.models)} models loaded!") print(C_DONE, f"{len(self.models)} models loaded!")
print(C_DONE, "CVSuite initialised!\n") print(C_DONE, "CVSuite initialised!\n")
@ -350,7 +356,9 @@ class CVSuite:
# Normalise data using loaded scalers # Normalise data using loaded scalers
for idx, value in enumerate(data): for idx, value in enumerate(data):
d = np.array(value) d = np.array(value)
data[idx] = self.scaler[idx].transform(d.astype(np.float32).reshape(1, -1))[0][0] data[idx] = self.scaler[idx].transform(d.astype(np.float32).reshape(1, -1))[
0
][0]
data = np.array([data], dtype=np.float32) data = np.array([data], dtype=np.float32)