-
Notifications
You must be signed in to change notification settings - Fork 0
/
data.py
53 lines (39 loc) · 2.08 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Generates datasets."""
from typing import *
import numpy as np
class Dataset():
    """Randomly sampled inputs of a function together with its outputs.

    Rows are stored in one array and split, in order, into a training part
    (the first ``round(training_ratio * dataset_size)`` rows) and a testing
    part (all remaining rows). ``training``, ``testing`` and ``shuffle`` all
    use that same split point, so no row is ever shared between the parts.
    """

    def __init__(
        self,
        function: Callable[..., np.ndarray],
        dataset_size: int,
        batch_size: int,
        input_size: int,
        input_range: Tuple[float, float],
        training_ratio: float = 0.8,
    ) -> None:
        """
        Inputs:
        `function`: The function used to generate the dataset. It is called with
            `input_size` positional arguments, each an array of shape
            `(dataset_size, 1, 1)` holding one input column.
        `dataset_size`: The number of data to generate.
        `batch_size`: Number of data to train on at once. Must be positive and
            divide `dataset_size` evenly.
        `input_size`: The number of input arguments to the function.
        `input_range`: The minimum and maximum input values to input into the
            function; inputs are drawn uniformly from `[minimum, maximum)`.
        `training_ratio`: Fraction of the data used for training (default 0.8);
            the remainder is used for testing. Must lie in `[0, 1]`.

        Raises:
        `ValueError`: If `batch_size` is not positive, `dataset_size` is not
            divisible by `batch_size`, or `training_ratio` is outside `[0, 1]`.
        """
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if batch_size <= 0:
            raise ValueError(f'The batch size must be positive, got {batch_size}.')
        if dataset_size % batch_size != 0:
            raise ValueError(
                f'The dataset size ({dataset_size}) must be divisible by the '
                f'batch size ({batch_size}).'
            )
        if not 0.0 <= training_ratio <= 1.0:
            raise ValueError(f'The training ratio must be in [0, 1], got {training_ratio}.')
        low, high = input_range[0], input_range[1]
        # Scale uniform [0, 1) samples into [low, high).
        self.x = np.random.rand(dataset_size, input_size, 1) * (high - low) + low
        # Each input column is passed as its own argument; slicing with i:i+1
        # keeps the 3-D shape instead of dropping the middle axis.
        self.y = function(*[self.x[:, i:i + 1, :] for i in range(input_size)])
        # NOTE(review): `y` is assumed to have `dataset_size` rows along axis 0,
        # since batching and shuffling index it the same way as `x`.
        self.dataset_size = dataset_size
        self.batch_size = batch_size
        self.training_ratio = training_ratio

    def _training_count(self) -> int:
        """Return the number of leading rows that form the training split."""
        return round(self.training_ratio * self.dataset_size)

    def _batches(self, start: int, stop: int) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
        """Yield `(x, y)` batches covering rows `[start, stop)`, never reading past `stop`."""
        for lo in range(start, stop, self.batch_size):
            # Clamp so the final batch of a split does not spill into the next split.
            hi = min(lo + self.batch_size, stop)
            yield self.x[lo:hi], self.y[lo:hi]

    def training(self) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
        """Return a generator of `(x, y)` training batches.

        The batches cover exactly the training rows; the last batch is shorter
        when the training split is not a multiple of `batch_size`.
        """
        return self._batches(0, self._training_count())

    def testing(self) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
        """Return a generator of `(x, y)` testing batches.

        The batches cover every row after the training split; the last batch
        may be shorter than `batch_size`.
        """
        return self._batches(self._training_count(), self.dataset_size)

    def shuffle(self) -> None:
        """Shuffle rows in place, separately within the training and testing splits.

        `x` and `y` receive the same permutation, so input/output pairs stay
        aligned, and no row moves from one split to the other.
        """
        split = self._training_count()
        for part in (slice(0, split), slice(split, self.dataset_size)):
            order = np.random.permutation(part.stop - part.start)
            # Advanced indexing on the right-hand side makes a copy, so the
            # in-place assignment does not read rows it has already overwritten.
            self.x[part] = self.x[part][order]
            self.y[part] = self.y[part][order]