Dataset splitter

This commit is contained in:
Tom Selier 2023-10-21 12:31:24 +02:00
parent 6582fa01d6
commit a94720346d
1 changed files with 103 additions and 0 deletions

103
src/experiments/dataset.py Normal file
View File

@ -0,0 +1,103 @@
from enum import Enum
import os
import random
import shutil
class Tree(Enum):
    """Tree species recognised in the dataset.

    A member's value is the position of that species in the per-species
    lists kept by this script. Members are looked up by name from the
    upper-cased file-name prefix (the text before the first underscore).
    """

    # The spelling of each name must match the file-name prefixes exactly,
    # because members are resolved with Tree[<prefix>].
    ACCASIA, BERK, EIK, ELS, ESDOORN, ES, LINDE, PLATAAN = range(8)
# Directory holding the source images, relative to the working directory.
# Built with os.path.join so the separator is right on every platform; a
# literal backslash only forms a two-level path on Windows.
DIR = os.path.join("dataset", "input")
# Seed passed to random.seed so the pseudorandom split is repeatable.
SEED = 10
# Intended share of each species' images per split
# (training / validation / testing).
TRAIN_TARGET = 0.7
VAL_TARGET = 0.2
TEST_TARGET = 0.1
# Per-species bookkeeping; every list is indexed by Tree member value.
tree_count = [0] * len(Tree)   # images found per species
training = [0] * len(Tree)     # images to put in the training split
validation = [0] * len(Tree)   # images to put in the validation split
testing = [0] * len(Tree)      # images to put in the testing split
files = [[] for _ in Tree]     # image file names per species

# os.listdir returns entries in arbitrary order; sorting them makes the
# seeded selection later in this script reproducible across machines.
for file in sorted(os.listdir(DIR)):
    if not file.endswith(".JPG"):
        continue
    # The species is the part of the file name before the first underscore.
    # A prefix that is not a Tree member raises KeyError here.
    species = Tree[file.split('_')[0].upper()].value
    # store file names
    files[species].append(file)
    # count tree types
    tree_count[species] += 1
# distribute
# Rounding all three shares independently can make them add up to more or
# fewer images than a species has (8 images -> 6 + 2 + 1 = 9), so only the
# training and validation shares are rounded and testing takes whatever is
# left. TEST_TARGET is therefore implied by the other two targets.
for i, count in enumerate(tree_count):
    training[i] = min(round(count * TRAIN_TARGET), count)
    validation[i] = min(round(count * VAL_TARGET), count - training[i])
    testing[i] = count - training[i] - validation[i]
# Report the per-species image counts for the whole set and for each split,
# then the grand total across the three splits.
for label, counts in (
    ("Total", tree_count),
    ("Training", training),
    ("Validation", validation),
    ("Testing", testing),
):
    print(label, counts, "=", sum(counts))
print("Total splits =", sum(training + validation + testing))
# create output dirs
# makedirs(exist_ok=True) creates each split directory inside DIR when it is
# missing and does nothing when the directory already exists, so no separate
# existence check is needed.
for split_dir in ("training", "validation", "testing"):
    os.makedirs(os.path.join(DIR, split_dir), exist_ok=True)
# create output lists
training_names = []
validation_names = []
testing_names = []
# seed random
random.seed(SEED)
# fill output lists pseudorandomly
# files, training, validation and testing are all indexed by species, so
# they are walked in lockstep instead of re-deriving the species from the
# first file name (which fails for a species that has no images).
for tree, quotas in zip(files, zip(training, validation, testing)):
    buckets = (training_names, validation_names, testing_names)
    for quota, bucket in zip(quotas, buckets):
        # Never draw more names than remain, so an over-sized quota cannot
        # call randint on an empty range.
        for _ in range(min(quota, len(tree))):
            # Chosen names are removed from the pool so that no image ends
            # up in more than one split.
            idx = random.randint(0, len(tree) - 1)
            bucket.append(tree.pop(idx))
# Copy every selected image from DIR into the sub-directory of its split.
for split_dir, names in (
    ("training", training_names),
    ("validation", validation_names),
    ("testing", testing_names),
):
    destination = os.path.join(DIR, split_dir)
    for file in names:
        shutil.copy(os.path.join(DIR, file), destination)