From a94720346d35984c3ba8ad00731dd31497165bb0 Mon Sep 17 00:00:00 2001 From: Tom Selier Date: Sat, 21 Oct 2023 12:31:24 +0200 Subject: [PATCH] dataset splitser --- src/experiments/dataset.py | 103 +++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 src/experiments/dataset.py diff --git a/src/experiments/dataset.py b/src/experiments/dataset.py new file mode 100644 index 0000000..f251d85 --- /dev/null +++ b/src/experiments/dataset.py @@ -0,0 +1,103 @@ +from enum import Enum +import os +import random +import shutil + +class Tree(Enum): + ACCASIA = 0 + BERK = 1 + EIK = 2 + ELS = 3 + ESDOORN = 4 + ES = 5 + LINDE = 6 + PLATAAN = 7 + +DIR = "dataset\\input" +SEED = 10 + +TRAIN_TARGET = 0.7 +VAL_TARGET = 0.2 +TEST_TARGET = 0.1 + +tree_count = [0 for x in Tree] +training = [0 for x in Tree] +validation = [0 for x in Tree] +testing = [0 for x in Tree] + +files = [[] for x in Tree] + +for file in os.listdir(DIR): + if not file.endswith(".JPG"): + continue + + # get path + file_path = os.path.join(DIR, file) + name = file.split('_')[0].upper() + + # store file names + files[Tree[name].value].append(file) + + # count tree types + tree_count[Tree[name].value] += 1 + +# distribute +for i in range(len(Tree)): + training[i] = round(tree_count[i] * 0.7) + validation[i] = round(tree_count[i] * 0.2) + testing[i] = round(tree_count[i] * 0.1) + +# display distribution +print("Total", tree_count, "=", sum(tree_count)) +print("Training", training, "=", sum(training)) +print("Validation", validation, "=", sum(validation)) +print("Testing", testing, "=", sum(testing)) +print("Total splits =", (sum(training) + sum(validation) + sum(testing))) + +# create output dirs +if not os.path.exists(os.path.join(DIR, "training")): + os.mkdir(os.path.join(DIR, "training")) +if not os.path.exists(os.path.join(DIR, "validation")): + os.mkdir(os.path.join(DIR, "validation")) +if not os.path.exists(os.path.join(DIR, "testing")): + os.mkdir(os.path.join(DIR, "testing")) + +# create output lists +training_names = [] +validation_names = [] +testing_names = [] + +# seed random +random.seed(SEED) + +# fill output lists psuedorandomly +for tree in files: + name = tree[0].split('_')[0].upper() + + # training + for i in range(training[Tree[name].value]): + idx = random.randint(0, len(tree)-1) + temp = tree.pop(idx) + training_names.append(temp) + + # validation + for i in range(validation[Tree[name].value]): + idx = random.randint(0, len(tree)-1) + temp = tree.pop(idx) + validation_names.append(temp) + + # testing + for i in range(testing[Tree[name].value]): + idx = random.randint(0, len(tree)-1) + temp = tree.pop(idx) + testing_names.append(temp) + +# copy files to output dirs +for file in training_names: + shutil.copy(os.path.join(DIR, file), os.path.join(DIR, "training")) + +for file in validation_names: + shutil.copy(os.path.join(DIR, file), os.path.join(DIR, "validation")) + +for file in testing_names: + shutil.copy(os.path.join(DIR, file), os.path.join(DIR, "testing")) \ No newline at end of file