dataset splitser
This commit is contained in:
parent
6582fa01d6
commit
a94720346d
103
src/experiments/dataset.py
Normal file
103
src/experiments/dataset.py
Normal file
@ -0,0 +1,103 @@
|
||||
from enum import Enum
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
|
||||
class Tree(Enum):
    """Tree types recognised by the dataset splitter.

    Members are looked up by name (``Tree["EIK"]``) and their values run
    consecutively from 0, so a value can be used directly as a list index.
    """

    ACCASIA, BERK, EIK, ELS, ESDOORN, ES, LINDE, PLATAAN = range(8)
|
||||
|
||||
# Folder that is scanned for input images. Built with os.path.join so the
# path is also valid on systems whose separator is not a backslash (a literal
# "dataset\\input" would name a single folder there).
DIR = os.path.join("dataset", "input")

# Seed handed to random.seed() so the split can be repeated.
SEED = 10

# Intended share of every tree type that goes into each split.
TRAIN_TARGET = 0.7
VAL_TARGET = 0.2
TEST_TARGET = 0.1
|
||||
|
||||
# Per-type tallies: position i belongs to the Tree member whose value is i.
tree_count = [0] * len(Tree)
training = [0] * len(Tree)
validation = [0] * len(Tree)
testing = [0] * len(Tree)

# One (initially empty) list of file names per tree type.
files = [[] for _ in Tree]
|
||||
|
||||
# Group and count the .JPG files in DIR per tree type.
# os.listdir() returns entries in arbitrary order, so they are sorted first;
# otherwise the seeded random split would differ between machines.
for file in sorted(os.listdir(DIR)):
    if not file.endswith(".JPG"):
        continue

    # the tree type is the part of the file name before the first underscore
    name = file.split('_')[0].upper()
    try:
        species = Tree[name].value
    except KeyError:
        # not a known tree type: report it and carry on instead of aborting
        print("Skipping", file, "- unknown tree type", name)
        continue

    # store file names
    files[species].append(file)

    # count tree types
    tree_count[species] += 1
|
||||
|
||||
# distribute
# Turn every tree type's file count into training/validation/testing quotas.
# Testing receives whatever is left after the other two, so the quotas always
# add up to the count. Rounding the three shares independently can ask for
# more files than exist (8 files -> 6 + 2 + 1 = 9) or leave files unassigned
# (12 files -> 8 + 2 + 1 = 11).
for i, count in enumerate(tree_count):
    training[i] = round(count * TRAIN_TARGET)
    # never hand out more than what remains after the training quota
    validation[i] = min(round(count * VAL_TARGET), count - training[i])
    testing[i] = count - training[i] - validation[i]
|
||||
|
||||
# display distribution
# One line per tally: label, the per-type numbers, and their sum.
for label, counts in (
    ("Total", tree_count),
    ("Training", training),
    ("Validation", validation),
    ("Testing", testing),
):
    print(label, counts, "=", sum(counts))

# Grand total over all three splits, to compare against the "Total" line.
print("Total splits =", sum(training + validation + testing))
|
||||
|
||||
# create output dirs
# Each split gets its own sub-folder inside DIR; existing ones are kept.
for subdir in ("training", "validation", "testing"):
    target = os.path.join(DIR, subdir)
    if not os.path.exists(target):
        os.mkdir(target)
|
||||
|
||||
# create output lists
# File names selected for each split, filled in below.
training_names, validation_names, testing_names = [], [], []

# seed random
# A fixed seed makes the random module produce the same draws on every run.
random.seed(SEED)
|
||||
|
||||
# fill output lists pseudorandomly
def _draw(pool, quota, dest):
    """Move up to *quota* randomly chosen file names from *pool* to *dest*.

    The draw stops early when *pool* runs out, so an empty tree type or a
    quota larger than the pool cannot raise.
    """
    for _ in range(min(quota, len(pool))):
        dest.append(pool.pop(random.randint(0, len(pool) - 1)))


# The position of a pool in `files` is the tree type's index, so the quotas
# can be looked up directly without re-parsing a file name (which failed with
# IndexError for a tree type that has no files).
for species, pool in enumerate(files):
    _draw(pool, training[species], training_names)
    _draw(pool, validation[species], validation_names)
    _draw(pool, testing[species], testing_names)
|
||||
|
||||
# copy files to output dirs
# Every selected file is copied (not moved) from DIR into its split's folder.
for subdir, names in (
    ("training", training_names),
    ("validation", validation_names),
    ("testing", testing_names),
):
    destination = os.path.join(DIR, subdir)
    for file in names:
        shutil.copy(os.path.join(DIR, file), destination)
|
Loading…
Reference in New Issue
Block a user