dataset splitser
This commit is contained in:
parent
6582fa01d6
commit
a94720346d
103
src/experiments/dataset.py
Normal file
103
src/experiments/dataset.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
from enum import Enum
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
class Tree(Enum):
|
||||||
|
ACCASIA = 0
|
||||||
|
BERK = 1
|
||||||
|
EIK = 2
|
||||||
|
ELS = 3
|
||||||
|
ESDOORN = 4
|
||||||
|
ES = 5
|
||||||
|
LINDE = 6
|
||||||
|
PLATAAN = 7
|
||||||
|
|
||||||
|
DIR = "dataset\\input"
|
||||||
|
SEED = 10
|
||||||
|
|
||||||
|
TRAIN_TARGET = 0.7
|
||||||
|
VAL_TARGET = 0.2
|
||||||
|
TEST_TARGET = 0.1
|
||||||
|
|
||||||
|
tree_count = [0 for x in Tree]
|
||||||
|
training = [0 for x in Tree]
|
||||||
|
validation = [0 for x in Tree]
|
||||||
|
testing = [0 for x in Tree]
|
||||||
|
|
||||||
|
files = [[] for x in Tree]
|
||||||
|
|
||||||
|
for file in os.listdir(DIR):
|
||||||
|
if not file.endswith(".JPG"):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# get path
|
||||||
|
file_path = os.path.join(DIR, file)
|
||||||
|
name = file.split('_')[0].upper()
|
||||||
|
|
||||||
|
# store file names
|
||||||
|
files[Tree[name].value].append(file)
|
||||||
|
|
||||||
|
# count tree types
|
||||||
|
tree_count[Tree[name].value] += 1
|
||||||
|
|
||||||
|
# distribute
|
||||||
|
for i in range(len(Tree)):
|
||||||
|
training[i] = round(tree_count[i] * 0.7)
|
||||||
|
validation[i] = round(tree_count[i] * 0.2)
|
||||||
|
testing[i] = round(tree_count[i] * 0.1)
|
||||||
|
|
||||||
|
# display distribution
|
||||||
|
print("Total", tree_count, "=", sum(tree_count))
|
||||||
|
print("Training", training, "=", sum(training))
|
||||||
|
print("Validation", validation, "=", sum(validation))
|
||||||
|
print("Testing", testing, "=", sum(testing))
|
||||||
|
print("Total splits =", (sum(training) + sum(validation) + sum(testing)))
|
||||||
|
|
||||||
|
# create output dirs
|
||||||
|
if not os.path.exists(os.path.join(DIR, "training")):
|
||||||
|
os.mkdir(os.path.join(DIR, "training"))
|
||||||
|
if not os.path.exists(os.path.join(DIR, "validation")):
|
||||||
|
os.mkdir(os.path.join(DIR, "validation"))
|
||||||
|
if not os.path.exists(os.path.join(DIR, "testing")):
|
||||||
|
os.mkdir(os.path.join(DIR, "testing"))
|
||||||
|
|
||||||
|
# create output lists
|
||||||
|
training_names = []
|
||||||
|
validation_names = []
|
||||||
|
testing_names = []
|
||||||
|
|
||||||
|
# seed random
|
||||||
|
random.seed(SEED)
|
||||||
|
|
||||||
|
# fill output lists psuedorandomly
|
||||||
|
for tree in files:
|
||||||
|
name = tree[0].split('_')[0].upper()
|
||||||
|
|
||||||
|
# training
|
||||||
|
for i in range(training[Tree[name].value]):
|
||||||
|
idx = random.randint(0, len(tree)-1)
|
||||||
|
temp = tree.pop(idx)
|
||||||
|
training_names.append(temp)
|
||||||
|
|
||||||
|
# validation
|
||||||
|
for i in range(validation[Tree[name].value]):
|
||||||
|
idx = random.randint(0, len(tree)-1)
|
||||||
|
temp = tree.pop(idx)
|
||||||
|
validation_names.append(temp)
|
||||||
|
|
||||||
|
# testing
|
||||||
|
for i in range(testing[Tree[name].value]):
|
||||||
|
idx = random.randint(0, len(tree)-1)
|
||||||
|
temp = tree.pop(idx)
|
||||||
|
testing_names.append(temp)
|
||||||
|
|
||||||
|
# copy files to output dirs
|
||||||
|
for file in training_names:
|
||||||
|
shutil.copy(os.path.join(DIR, file), os.path.join(DIR, "training"))
|
||||||
|
|
||||||
|
for file in validation_names:
|
||||||
|
shutil.copy(os.path.join(DIR, file), os.path.join(DIR, "validation"))
|
||||||
|
|
||||||
|
for file in testing_names:
|
||||||
|
shutil.copy(os.path.join(DIR, file), os.path.join(DIR, "testing"))
|
Loading…
Reference in New Issue
Block a user