This commit is contained in:
Tom Selier 2023-10-21 19:17:43 +02:00
commit 57906d2a44
3 changed files with 26 additions and 6 deletions

View File

@ -80,7 +80,7 @@ $ python ./src/suite.py
- Based on your system configuration, this might take a while - Based on your system configuration, this might take a while
3. Run the CVSuiteTestKNN CLI tool: 3. Run the CVSuiteTestKNN CLI tool:
```sh ```sh
$ python ./src/helpers/test/knn.py -i ./out/result-(date/time).csv -o ./out/models/model_knn.yaml $ python ./src/helpers/test/knn.py -i ./out/result-(date/time).csv -o ./out/models/
``` ```
4. Edit your `config.json` to include the newly created model 4. Edit your `config.json` to include the newly created model

View File

@ -5,6 +5,9 @@ from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, Ma
import argparse import argparse
from enum import Enum from enum import Enum
import yaml import yaml
import joblib
import datetime
import os
parser = argparse.ArgumentParser(prog='KNN Train CLI') parser = argparse.ArgumentParser(prog='KNN Train CLI')
parser.add_argument('-i', '--input', help='Input CSV file', required=True) parser.add_argument('-i', '--input', help='Input CSV file', required=True)
@ -34,6 +37,7 @@ class CVSuiteTestKNN:
def trainCSV(self, path, output): def trainCSV(self, path, output):
''' '''
Takes preprocessed data from CVSuite, normalises it and trains the model Takes preprocessed data from CVSuite, normalises it and trains the model
Output should be a folder path
Function expects first two columns of the dataset to be tag and photoId, the first row should be the CSV header Function expects first two columns of the dataset to be tag and photoId, the first row should be the CSV header
''' '''
file = open(path, mode='r') file = open(path, mode='r')
@ -85,6 +89,10 @@ class CVSuiteTestKNN:
# Replace original data array # Replace original data array
data[:, idx] = column data[:, idx] = column
# Dump the scalers
now = datetime.datetime.now()
joblib.dump(self.scale, os.path.join(output, F"scale_{now.strftime('%Y-%m-%dT%H.%M.%S')}.pkl"))
# Pass data to train function # Pass data to train function
self.train(data, tags_int, output) self.train(data, tags_int, output)
@ -97,7 +105,10 @@ class CVSuiteTestKNN:
raise EnvironmentError("Model already trained!") raise EnvironmentError("Model already trained!")
else: else:
self.knn.train(data, cv.ml.ROW_SAMPLE, tags) self.knn.train(data, cv.ml.ROW_SAMPLE, tags)
self.knn.save(output)
# Save it
now = datetime.datetime.now()
self.knn.save(os.path.join(output, F"model_knn_{now.strftime('%Y-%m-%dT%H.%M.%S')}.yaml"))
def predict(self, data): def predict(self, data):
return self.knn.predict(data) return self.knn.predict(data)

View File

@ -15,6 +15,7 @@ import json
import numpy as np import numpy as np
import cv2 import cv2
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, MaxAbsScaler from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler, MaxAbsScaler
import joblib
# GUI # GUI
import pygubu import pygubu
@ -28,7 +29,7 @@ from helpers.sift import getSiftData
# Tests # Tests
from helpers.test.knn import CVSuiteTestKNN from helpers.test.knn import CVSuiteTestKNN
from helpers.test.decision_tree import CVSuiteTestDecisionTree # from helpers.test.decision_tree import CVSuiteTestDecisionTree
## UI config load ## UI config load
PROJECT_PATH = pathlib.Path(__file__).parent PROJECT_PATH = pathlib.Path(__file__).parent
@ -93,8 +94,15 @@ class CVSuite:
) )
builder.connect_callbacks(self) builder.connect_callbacks(self)
# Attempt to load scaler
if config_json["scaler"] != "":
self.scaler = joblib.load(config_json["scaler"])
print(self.scaler)
else:
self.scaler = None
# Model tests # Model tests
if config_json["models"]["knn"] != "": if self.scaler is not None and config_json["models"]["knn"] != "":
self.test_knn = CVSuiteTestKNN(config_json["models"]["knn"]) self.test_knn = CVSuiteTestKNN(config_json["models"]["knn"])
else: else:
self.test_knn = None self.test_knn = None
@ -299,9 +307,10 @@ class CVSuite:
tag = data.pop(0) tag = data.pop(0)
photoId = data.pop(1) photoId = data.pop(1)
for idx, value in enumerate(data):
data[idx] = self.scaler[idx].transform(np.array(value).reshape(-1, 1))
print(data) print(data)
for value in data:
print(value)
if self.test_knn is not None: if self.test_knn is not None:
# Do knn test # Do knn test