This commit is contained in:
Tom Selier 2023-10-21 19:17:43 +02:00
commit 57906d2a44
3 changed files with 26 additions and 6 deletions

View File

@ -80,7 +80,7 @@ $ python ./src/suite.py
- Based on your system configuration, this might take a while
3. Run the CVSuiteTestKNN CLI tool:
```sh
$ python ./src/helpers/test/knn.py -i ./out/result-(date/time).csv -o ./out/models/model_knn.yaml
$ python ./src/helpers/test/knn.py -i ./out/result-(date/time).csv -o ./out/models/
```
4. Edit your `config.json` to include the newly created model

View File

@ -5,6 +5,9 @@ from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, Ma
import argparse
from enum import Enum
import yaml
import joblib
import datetime
import os
parser = argparse.ArgumentParser(prog='KNN Train CLI')
parser.add_argument('-i', '--input', help='Input CSV file', required=True)
@ -34,6 +37,7 @@ class CVSuiteTestKNN:
def trainCSV(self, path, output):
'''
Takes preprocessed data from CVSuite, normalises it and trains the model
Output should be a folder path
Function expects first two columns of the dataset to be tag and photoId, the first row should be the CSV header
'''
file = open(path, mode='r')
@ -85,6 +89,10 @@ class CVSuiteTestKNN:
# Replace original data array
data[:, idx] = column
# Dump the scalers
now = datetime.datetime.now()
joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl"))
# Pass data to train function
self.train(data, tags_int, output)
@ -97,7 +105,10 @@ class CVSuiteTestKNN:
raise EnvironmentError("Model already trained!")
else:
self.knn.train(data, cv.ml.ROW_SAMPLE, tags)
self.knn.save(output)
# Save it
now = datetime.datetime.now()
self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
def predict(self, data):
return self.knn.predict(data)

View File

@ -15,6 +15,7 @@ import json
import numpy as np
import cv2
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, MaxAbsScaler
import joblib
# GUI
import pygubu
@ -28,7 +29,7 @@ from helpers.sift import getSiftData
# Tests
from helpers.test.knn import CVSuiteTestKNN
from helpers.test.decision_tree import CVSuiteTestDecisionTree
# from helpers.test.decision_tree import CVSuiteTestDecisionTree
## UI config load
PROJECT_PATH = pathlib.Path(__file__).parent
@ -93,8 +94,15 @@ class CVSuite:
)
builder.connect_callbacks(self)
# Attempt to load scaler
if config_json["scaler"] != "":
self.scaler = joblib.load(config_json["scaler"])
print(self.scaler)
else:
self.scaler = None
# Model tests
if config_json["models"]["knn"] != "":
if self.scaler is not None and config_json["models"]["knn"] != "":
self.test_knn = CVSuiteTestKNN(config_json["models"]["knn"])
else:
self.test_knn = None
@ -299,9 +307,10 @@ class CVSuite:
tag = data.pop(0)
photoId = data.pop(1)
for idx, value in enumerate(data):
data[idx] = self.scaler[idx].transform(np.array(value).reshape(-1, 1))
print(data)
for value in data:
print(value)
if self.test_knn is not None:
# Do knn test