added weighted std calcs

This commit is contained in:
Tom Selier 2023-09-29 17:36:56 +02:00
parent 4b510dc4ef
commit 14247b7d18
2 changed files with 84 additions and 0 deletions

View File

@ -0,0 +1,84 @@
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import math
def imgStats(img):
mean = np.zeros(3)
std = np.zeros(3)
for i in range(img.shape[2]):
mean[i] = np.mean(img[:, :, i])
std[i] = np.std(img[:, :, i])
return mean, std
def isFloat(num):
try:
float(num)
return True
except ValueError:
return False
def calcWeightedStd(values, weights):
N = len(values)
x_bar = np.average(values, weights=weights, axis=0)
numerator = 0
for i in range(N):
numerator += weights[i] * ((values[i] - x_bar) ** 2)
denominator = np.sum(weights)
return math.sqrt(numerator / denominator)
def calcNormalFunc(mean, sd, len):
f = np.zeros(len, dtype=np.longdouble)
# calculate PDF
for x in range(len):
exp = (-(x - mean) ** 2)/(2 * sd ** 2)
f[x] = 1 / math.sqrt(2 * np.pi * sd** 20 ) * (math.exp(exp))
# normalize PDF
max = np.amax(f)
min = np.amin(f)
for x in range(len):
f[x] = (f[x] - min) / (max - min)
return f
DATASET_PATH = "dataset\\"
BARK_TYPES = 8
averages = [[] for x in range(BARK_TYPES)]
weights = [[] for x in range(BARK_TYPES)]
variances = [[] for x in range(BARK_TYPES)]
i = -1
last_name = ""
for file in os.listdir(DATASET_PATH):
name = file.split('_')[0]
if(name != last_name):
last_name = name
i += 1
image = cv2.imread(os.path.join(DATASET_PATH + file), 1)
assert image is not None, "Something went wrong"
# Weighted average
averages[i].append(np.mean(image[:, :, 0]))
weights[i].append(len(image[:, :, 0]))
# print()
## Voor gewogen std:
## https://stackoverflow.com/questions/2413522/weighted-standard-deviation-in-numpy
weighted_avg = np.zeros(BARK_TYPES)
weighted_std = np.zeros(BARK_TYPES)
for i in range(BARK_TYPES):
weighted_avg[i] = np.average(averages[i], weights=weights[i], axis=0)
weighted_std[i] = calcWeightedStd(averages[i], weights[i])
print("Weighted averages: ")
print(weighted_avg)
print("Weighted standard deviation: ")
print(weighted_std)