Added other trees to suite

This commit is contained in:
Arne van Iterson 2023-10-22 15:12:03 +02:00
parent 4dcd021490
commit 58e521cb6e
3 changed files with 72 additions and 52 deletions

View File

@ -11,19 +11,25 @@ import sys
try:
# Perform relative import if included from CVSuite
from ..tags import Tree
from ..logger import C_DONE
from ..logger import C_DONE, C_ERR
except ImportError:
# This solution is hot garbage but I refuse to spend any more time on it
directory = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(directory, '..'))
sys.path.append(os.path.join(directory, ".."))
from tags import Tree
from logger import C_DONE
parser = argparse.ArgumentParser(prog='DecisionTree CLI')
parser.add_argument('-i', '--input', help='Input CSV file', required=True)
parser.add_argument('-o', '--output', help='Output model folder', required=True)
parser.add_argument('-m', '--model', help='Chosen model (\'dectree\', \'randforest\' or \'extratree\')', required=True)
parser.add_argument('-s', '--scaler', help='Scaler preprocessor', required=True)
parser = argparse.ArgumentParser(prog="DecisionTree CLI")
parser.add_argument("-i", "--input", help="Input CSV file", required=True)
parser.add_argument("-o", "--output", help="Output model folder", required=True)
parser.add_argument(
"-m",
"--model",
help="Chosen model ('dectree', 'randforest' or 'extratree')",
required=True,
)
parser.add_argument("-s", "--scaler", help="Scaler preprocessor", required=True)
class CVSuiteTestTree:
def __init__(self, model_path=None):
@ -33,8 +39,8 @@ class CVSuiteTestTree:
self.scaler = None
def trainCSV(self, path, output) -> None:
with open(path, 'r') as file:
reader = csv.reader(file, delimiter= ',')
with open(path, "r") as file:
reader = csv.reader(file, delimiter=",")
matrix = list(reader)
labels = []
@ -64,11 +70,11 @@ class CVSuiteTestTree:
self.scaler = load(path)
if self.scaler is None:
print("Scaler failed to load!")
print(C_ERR, "Scaler failed to load!")
exit()
def train(self, data, labels, output) -> None:
print("You called the parent class, doofus")
print(C_ERR, "You called the parent class, doofus")
def save(self, output, name) -> None:
path = os.path.join(output, name)
@ -81,6 +87,7 @@ class CVSuiteTestTree:
else:
return None
class CVSuiteTestDecisionTree(CVSuiteTestTree):
def train(self, data, labels, output) -> None:
self.model = tree.DecisionTreeClassifier(
@ -88,41 +95,44 @@ class CVSuiteTestDecisionTree(CVSuiteTestTree):
min_samples_leaf=2,
max_depth=None,
random_state=False,
criterion='gini',
splitter='best',
ccp_alpha=0
criterion="gini",
splitter="best",
ccp_alpha=0,
)
self.model.fit(data, labels)
self.save(output, 'decisiontree.pkl')
self.save(output, "decisiontree.pkl")
class CVSuiteTestRandomForest(CVSuiteTestTree):
def train(self, data, labels, output) -> None:
self.model = RandomForestClassifier(
n_estimators=150,
criterion='gini',
criterion="gini",
)
self.model.fit(data, labels)
self.save(output, 'randomforest.pkl')
self.save(output, "randomforest.pkl")
class CVSuiteTestExtraTrees(CVSuiteTestTree):
def train(self, data, labels, output) -> None:
self.model = tree.ExtraTreeClassifier()
self.model.fit(data, labels)
self.save(output, 'extratrees.pkl')
self.save(output, "extratrees.pkl")
if __name__ == "__main__":
args = parser.parse_args()
if args.model == 'dectree':
if args.model == "dectree":
test = CVSuiteTestDecisionTree()
elif args.model == 'randforest':
elif args.model == "randforest":
test = CVSuiteTestRandomForest()
elif args.model == 'extratree':
elif args.model == "extratree":
test = CVSuiteTestExtraTrees()
else:
print("Model not found!")
print(C_ERR, "Model not found!")
exit()
test.addScaler(args.scaler)
test.trainCSV(args.input, args.output)
print(C_DONE + "Model trained successfully!")
print(C_DONE, "Model trained successfully!")

View File

@ -93,7 +93,8 @@ class CVSuiteTestKNN:
# Dump the scalers
now = datetime.datetime.now()
joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl"))
joblib.dump(self.scale, os.path.join(output, F"scaler.pkl"))
# joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl"))
# Pass data to train function
self.train(data, tags_int, output)
@ -112,7 +113,8 @@ class CVSuiteTestKNN:
# Save it
now = datetime.datetime.now()
self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
self.knn.save(os.path.join(output, F"model_knn.yaml"))
# self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
def predict(self, data, nr = 3):
ret, results, neighbours ,dist = self.knn.findNearest(data, nr)

View File

@ -35,7 +35,11 @@ from helpers.tags import Tree
# Tests
from helpers.test.knn import CVSuiteTestKNN
from helpers.test.decision_tree import CVSuiteTestDecisionTree
from helpers.test.decision_tree import (
CVSuiteTestDecisionTree,
CVSuiteTestRandomForest,
CVSuiteTestExtraTrees,
)
## UI config load
PROJECT_PATH = pathlib.Path(__file__).parent
@ -127,23 +131,25 @@ class CVSuite:
for model in config_json["models"]:
if config_json["models"][model] != "":
print(C_INFO, f"Loading model {model}")
mpath = config_json["models"][model]
if model == "knn":
self.models.append(
("KNN", CVSuiteTestKNN(config_json["models"]["knn"]))
)
self.models.append(("KNN", CVSuiteTestKNN(mpath)))
elif model == "dectree":
self.models.append(
(
"Decision Tree",
CVSuiteTestDecisionTree(
config_json["models"]["dectree"]
),
("Decision Tree", CVSuiteTestDecisionTree(mpath))
)
elif model == "randforest":
self.models.append(
("Random Forest", CVSuiteTestRandomForest(mpath))
)
elif model == "extratree":
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath)))
else:
print(
C_WARN, f"Model {model} does not exist or is not supported!"
)
else:
print(C_WARN, f"Model {model} does not exist!")
else:
print(C_WARN, f"Model {model} not configured!")
print(C_WARN, f"Model {model} path not configured!")
print(C_DONE, f"{len(self.models)} models loaded!")
print(C_DONE, "CVSuite initialised!\n")
@ -350,7 +356,9 @@ class CVSuite:
# Normalise data using loaded scalers
for idx, value in enumerate(data):
d = np.array(value)
data[idx] = self.scaler[idx].transform(d.astype(np.float32).reshape(1, -1))[0][0]
data[idx] = self.scaler[idx].transform(d.astype(np.float32).reshape(1, -1))[
0
][0]
data = np.array([data], dtype=np.float32)