Added weighted MCC
This commit is contained in:
parent
17af113484
commit
39a39ba47e
27
src/suite.py
27
src/suite.py
@ -14,7 +14,7 @@ import json
|
||||
# OpenCV
|
||||
import numpy as np
|
||||
import cv2
|
||||
from sklearn.metrics import confusion_matrix
|
||||
from sklearn.metrics import confusion_matrix, matthews_corrcoef
|
||||
from sklearn.preprocessing import (
|
||||
MinMaxScaler,
|
||||
StandardScaler,
|
||||
@ -395,21 +395,40 @@ class CVSuite:
|
||||
modelnr = (idx * 2) + idy
|
||||
guesses = self.models[modelnr][2]
|
||||
|
||||
# Get accuracy
|
||||
guess_total = 0
|
||||
guess_ok = 0
|
||||
for guess in guesses:
|
||||
print(guess)
|
||||
guess_total += 1
|
||||
if guess[0] == guess[1]:
|
||||
guess_ok += 1
|
||||
|
||||
# Convert guess array
|
||||
tag_true = [guess[0] for guess in guesses ]
|
||||
tag_predict = [guess[1] for guess in guesses ]
|
||||
|
||||
# calculate weighted average
|
||||
data_len_class = [tag_true.count(tag) for tag in range(0, 8)]
|
||||
data_len = len(tag_true)
|
||||
data_class_weight = [data_len_class[tag] / data_len for tag in range(0, 8)]
|
||||
|
||||
tag_weighted = []
|
||||
for tag in tag_true:
|
||||
tag_weighted.append(data_class_weight[tag])
|
||||
|
||||
print(C_DBUG, f"Data length per class: {data_len_class}; Total: {data_len}")
|
||||
|
||||
# Get MCC
|
||||
mcc = matthews_corrcoef(tag_true, tag_predict, sample_weight=tag_weighted)
|
||||
|
||||
labels = [Tree(tag).name for tag in range(0, 8)]
|
||||
|
||||
cm = confusion_matrix(tag_true, tag_predict)
|
||||
cmn = cmn = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
|
||||
sns.heatmap(cmn, xticklabels=labels, yticklabels=labels, ax=graph, annot=True, cbar=False, fmt='.2f')
|
||||
|
||||
graph.set_title(self.models[modelnr][0])
|
||||
graph.set_xlabel("Predicted")
|
||||
graph.set_title(f"{self.models[modelnr][0]}; MCC: {mcc:.2f}; Acc: {((guess_ok / guess_total) * 100):.2f}%" )
|
||||
graph.set_xlabel("Predicted")
|
||||
graph.set_ylabel("Actual")
|
||||
graph.set_xticklabels(labels, rotation=0)
|
||||
graph.set_yticklabels(labels, rotation=0)
|
||||
|
Loading…
Reference in New Issue
Block a user