Split out CIFAR10 data batches and test batch

This commit is contained in:
Eryn Wells 2018-01-26 23:26:15 -08:00
parent 9a4fccb5a6
commit 021b8c556e

View file

@ -24,27 +24,38 @@ class CIFAR10:
self.path = path
@property
def batch1(self):
return self._do_batch(1)
def data_batch1(self):
return self._do_data_batch(1)
@property
def batch2(self):
return self._do_batch(2)
def data_batch2(self):
return self._do_data_batch(2)
@property
def batch3(self):
return self._do_batch(3)
def data_batch3(self):
return self._do_data_batch(3)
@property
def batch4(self):
return self._do_batch(4)
def data_batch4(self):
return self._do_data_batch(4)
@property
def batch5(self):
return self._do_batch(5)
def data_batch5(self):
return self._do_data_batch(5)
def _do_batch(self, idx):
attr = '__batch{}'.format(idx)
@property
def test_batch(self):
if not getattr(self, '__test_batch', None):
path = os.path.join(self.path, 'test_batch')
self.__test_batch = self._unpickle(path)
return self.__test_batch
@property
def all_data_batches(self):
return [getattr(self, 'data_batch{}'.format(i)) for i in range(1, 6)]
def _do_data_batch(self, idx):
attr = '__data_batch{}'.format(idx)
if not getattr(self, attr, None):
path = os.path.join(self.path, 'data_batch_{}'.format(idx))
setattr(self, attr, self._unpickle(path))