Added other trees to suite
This commit is contained in:
parent
4dcd021490
commit
58e521cb6e
@ -11,34 +11,40 @@ import sys
|
|||||||
try:
|
try:
|
||||||
# Perform relative import if included from CVSuite
|
# Perform relative import if included from CVSuite
|
||||||
from ..tags import Tree
|
from ..tags import Tree
|
||||||
from ..logger import C_DONE
|
from ..logger import C_DONE, C_ERR
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# This solution is hot garbage but I refuse to spend any more time on it
|
# This solution is hot garbage but I refuse to spend any more time on it
|
||||||
directory = os.path.dirname(os.path.realpath(__file__))
|
directory = os.path.dirname(os.path.realpath(__file__))
|
||||||
sys.path.append(os.path.join(directory, '..'))
|
sys.path.append(os.path.join(directory, ".."))
|
||||||
from tags import Tree
|
from tags import Tree
|
||||||
from logger import C_DONE
|
from logger import C_DONE
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(prog='DecisionTree CLI')
|
parser = argparse.ArgumentParser(prog="DecisionTree CLI")
|
||||||
parser.add_argument('-i', '--input', help='Input CSV file', required=True)
|
parser.add_argument("-i", "--input", help="Input CSV file", required=True)
|
||||||
parser.add_argument('-o', '--output', help='Output model folder', required=True)
|
parser.add_argument("-o", "--output", help="Output model folder", required=True)
|
||||||
parser.add_argument('-m', '--model', help='Chosen model (\'dectree\', \'randforest\' or \'extratree\')', required=True)
|
parser.add_argument(
|
||||||
parser.add_argument('-s', '--scaler', help='Scaler preprocessor', required=True)
|
"-m",
|
||||||
|
"--model",
|
||||||
|
help="Chosen model ('dectree', 'randforest' or 'extratree')",
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
parser.add_argument("-s", "--scaler", help="Scaler preprocessor", required=True)
|
||||||
|
|
||||||
|
|
||||||
class CVSuiteTestTree:
|
class CVSuiteTestTree:
|
||||||
def __init__(self, model_path = None):
|
def __init__(self, model_path=None):
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
if self.model_path is not None:
|
if self.model_path is not None:
|
||||||
self.model = load(self.model_path)
|
self.model = load(self.model_path)
|
||||||
self.scaler = None
|
self.scaler = None
|
||||||
|
|
||||||
def trainCSV(self, path, output) -> None:
|
def trainCSV(self, path, output) -> None:
|
||||||
with open(path, 'r') as file:
|
with open(path, "r") as file:
|
||||||
reader = csv.reader(file, delimiter= ',')
|
reader = csv.reader(file, delimiter=",")
|
||||||
matrix = list(reader)
|
matrix = list(reader)
|
||||||
|
|
||||||
labels = []
|
labels = []
|
||||||
data = [[] for x in range(len(matrix)-1)]
|
data = [[] for x in range(len(matrix) - 1)]
|
||||||
|
|
||||||
# Load all but the headers
|
# Load all but the headers
|
||||||
for ridx, row in enumerate(matrix[1:]):
|
for ridx, row in enumerate(matrix[1:]):
|
||||||
@ -55,7 +61,7 @@ class CVSuiteTestTree:
|
|||||||
data[ridx].append(value[0][0])
|
data[ridx].append(value[0][0])
|
||||||
|
|
||||||
# normalize data has been included in code above :D
|
# normalize data has been included in code above :D
|
||||||
#TODO: Check if data is normalized correctly
|
# TODO: Check if data is normalized correctly
|
||||||
|
|
||||||
# train model
|
# train model
|
||||||
self.train(data, labels, output)
|
self.train(data, labels, output)
|
||||||
@ -64,11 +70,11 @@ class CVSuiteTestTree:
|
|||||||
self.scaler = load(path)
|
self.scaler = load(path)
|
||||||
|
|
||||||
if self.scaler is None:
|
if self.scaler is None:
|
||||||
print("Scaler failed to load!")
|
print(C_ERR, "Scaler failed to load!")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
def train(self, data, labels, output) -> None:
|
def train(self, data, labels, output) -> None:
|
||||||
print("You called the parent class, doofus")
|
print(C_ERR, "You called the parent class, doofus")
|
||||||
|
|
||||||
def save(self, output, name) -> None:
|
def save(self, output, name) -> None:
|
||||||
path = os.path.join(output, name)
|
path = os.path.join(output, name)
|
||||||
@ -81,48 +87,52 @@ class CVSuiteTestTree:
|
|||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
class CVSuiteTestDecisionTree(CVSuiteTestTree):
|
class CVSuiteTestDecisionTree(CVSuiteTestTree):
|
||||||
def train(self, data, labels, output) -> None:
|
def train(self, data, labels, output) -> None:
|
||||||
self.model = tree.DecisionTreeClassifier(
|
self.model = tree.DecisionTreeClassifier(
|
||||||
class_weight=None,
|
class_weight=None,
|
||||||
min_samples_leaf=2,
|
min_samples_leaf=2,
|
||||||
max_depth=None,
|
max_depth=None,
|
||||||
random_state=False,
|
random_state=False,
|
||||||
criterion='gini',
|
criterion="gini",
|
||||||
splitter='best',
|
splitter="best",
|
||||||
ccp_alpha=0
|
ccp_alpha=0,
|
||||||
)
|
)
|
||||||
self.model.fit(data, labels)
|
self.model.fit(data, labels)
|
||||||
self.save(output, 'decisiontree.pkl')
|
self.save(output, "decisiontree.pkl")
|
||||||
|
|
||||||
|
|
||||||
class CVSuiteTestRandomForest(CVSuiteTestTree):
|
class CVSuiteTestRandomForest(CVSuiteTestTree):
|
||||||
def train(self, data, labels, output) -> None:
|
def train(self, data, labels, output) -> None:
|
||||||
self.model = RandomForestClassifier(
|
self.model = RandomForestClassifier(
|
||||||
n_estimators=150,
|
n_estimators=150,
|
||||||
criterion='gini',
|
criterion="gini",
|
||||||
)
|
)
|
||||||
self.model.fit(data, labels)
|
self.model.fit(data, labels)
|
||||||
self.save(output, 'randomforest.pkl')
|
self.save(output, "randomforest.pkl")
|
||||||
|
|
||||||
|
|
||||||
class CVSuiteTestExtraTrees(CVSuiteTestTree):
|
class CVSuiteTestExtraTrees(CVSuiteTestTree):
|
||||||
def train(self, data, labels, output) -> None:
|
def train(self, data, labels, output) -> None:
|
||||||
self.model = tree.ExtraTreeClassifier()
|
self.model = tree.ExtraTreeClassifier()
|
||||||
self.model.fit(data, labels)
|
self.model.fit(data, labels)
|
||||||
self.save(output, 'extratrees.pkl')
|
self.save(output, "extratrees.pkl")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.model == 'dectree':
|
if args.model == "dectree":
|
||||||
test = CVSuiteTestDecisionTree()
|
test = CVSuiteTestDecisionTree()
|
||||||
elif args.model == 'randforest':
|
elif args.model == "randforest":
|
||||||
test = CVSuiteTestRandomForest()
|
test = CVSuiteTestRandomForest()
|
||||||
elif args.model == 'extratree':
|
elif args.model == "extratree":
|
||||||
test = CVSuiteTestExtraTrees()
|
test = CVSuiteTestExtraTrees()
|
||||||
else:
|
else:
|
||||||
print("Model not found!")
|
print(C_ERR, "Model not found!")
|
||||||
exit()
|
exit()
|
||||||
|
|
||||||
test.addScaler(args.scaler)
|
test.addScaler(args.scaler)
|
||||||
test.trainCSV(args.input, args.output)
|
test.trainCSV(args.input, args.output)
|
||||||
print(C_DONE + "Model trained successfully!")
|
print(C_DONE, "Model trained successfully!")
|
||||||
|
@ -93,7 +93,8 @@ class CVSuiteTestKNN:
|
|||||||
|
|
||||||
# Dump the scalers
|
# Dump the scalers
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl"))
|
joblib.dump(self.scale, os.path.join(output, F"scaler.pkl"))
|
||||||
|
# joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl"))
|
||||||
|
|
||||||
# Pass data to train function
|
# Pass data to train function
|
||||||
self.train(data, tags_int, output)
|
self.train(data, tags_int, output)
|
||||||
@ -112,7 +113,8 @@ class CVSuiteTestKNN:
|
|||||||
|
|
||||||
# Save it
|
# Save it
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
|
self.knn.save(os.path.join(output, F"model_knn.yaml"))
|
||||||
|
# self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
|
||||||
|
|
||||||
def predict(self, data, nr = 3):
|
def predict(self, data, nr = 3):
|
||||||
ret, results, neighbours ,dist = self.knn.findNearest(data, nr)
|
ret, results, neighbours ,dist = self.knn.findNearest(data, nr)
|
||||||
|
34
src/suite.py
34
src/suite.py
@ -35,7 +35,11 @@ from helpers.tags import Tree
|
|||||||
|
|
||||||
# Tests
|
# Tests
|
||||||
from helpers.test.knn import CVSuiteTestKNN
|
from helpers.test.knn import CVSuiteTestKNN
|
||||||
from helpers.test.decision_tree import CVSuiteTestDecisionTree
|
from helpers.test.decision_tree import (
|
||||||
|
CVSuiteTestDecisionTree,
|
||||||
|
CVSuiteTestRandomForest,
|
||||||
|
CVSuiteTestExtraTrees,
|
||||||
|
)
|
||||||
|
|
||||||
## UI config load
|
## UI config load
|
||||||
PROJECT_PATH = pathlib.Path(__file__).parent
|
PROJECT_PATH = pathlib.Path(__file__).parent
|
||||||
@ -127,23 +131,25 @@ class CVSuite:
|
|||||||
for model in config_json["models"]:
|
for model in config_json["models"]:
|
||||||
if config_json["models"][model] != "":
|
if config_json["models"][model] != "":
|
||||||
print(C_INFO, f"Loading model {model}")
|
print(C_INFO, f"Loading model {model}")
|
||||||
|
mpath = config_json["models"][model]
|
||||||
if model == "knn":
|
if model == "knn":
|
||||||
self.models.append(
|
self.models.append(("KNN", CVSuiteTestKNN(mpath)))
|
||||||
("KNN", CVSuiteTestKNN(config_json["models"]["knn"]))
|
|
||||||
)
|
|
||||||
elif model == "dectree":
|
elif model == "dectree":
|
||||||
self.models.append(
|
self.models.append(
|
||||||
(
|
("Decision Tree", CVSuiteTestDecisionTree(mpath))
|
||||||
"Decision Tree",
|
|
||||||
CVSuiteTestDecisionTree(
|
|
||||||
config_json["models"]["dectree"]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
elif model == "randforest":
|
||||||
|
self.models.append(
|
||||||
|
("Random Forest", CVSuiteTestRandomForest(mpath))
|
||||||
|
)
|
||||||
|
elif model == "extratree":
|
||||||
|
self.models.append(("Extra tree", CVSuiteTestExtraTrees(mpath)))
|
||||||
else:
|
else:
|
||||||
print(C_WARN, f"Model {model} does not exist!")
|
print(
|
||||||
|
C_WARN, f"Model {model} does not exist or is not supported!"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
print(C_WARN, f"Model {model} not configured!")
|
print(C_WARN, f"Model {model} path not configured!")
|
||||||
print(C_DONE, f"{len(self.models)} models loaded!")
|
print(C_DONE, f"{len(self.models)} models loaded!")
|
||||||
|
|
||||||
print(C_DONE, "CVSuite initialised!\n")
|
print(C_DONE, "CVSuite initialised!\n")
|
||||||
@ -350,7 +356,9 @@ class CVSuite:
|
|||||||
# Normalise data using loaded scalers
|
# Normalise data using loaded scalers
|
||||||
for idx, value in enumerate(data):
|
for idx, value in enumerate(data):
|
||||||
d = np.array(value)
|
d = np.array(value)
|
||||||
data[idx] = self.scaler[idx].transform(d.astype(np.float32).reshape(1, -1))[0][0]
|
data[idx] = self.scaler[idx].transform(d.astype(np.float32).reshape(1, -1))[
|
||||||
|
0
|
||||||
|
][0]
|
||||||
|
|
||||||
data = np.array([data], dtype=np.float32)
|
data = np.array([data], dtype=np.float32)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user