Split out CIFAR10 data batches and test batch
This commit is contained in:
parent
9a4fccb5a6
commit
021b8c556e
1 changed files with 23 additions and 12 deletions
35
datasets.py
35
datasets.py
|
@ -24,27 +24,38 @@ class CIFAR10:
|
|||
self.path = path
|
||||
|
||||
@property
|
||||
def batch1(self):
|
||||
return self._do_batch(1)
|
||||
def data_batch1(self):
|
||||
return self._do_data_batch(1)
|
||||
|
||||
@property
|
||||
def batch2(self):
|
||||
return self._do_batch(2)
|
||||
def data_batch2(self):
|
||||
return self._do_data_batch(2)
|
||||
|
||||
@property
|
||||
def batch3(self):
|
||||
return self._do_batch(3)
|
||||
def data_batch3(self):
|
||||
return self._do_data_batch(3)
|
||||
|
||||
@property
|
||||
def batch4(self):
|
||||
return self._do_batch(4)
|
||||
def data_batch4(self):
|
||||
return self._do_data_batch(4)
|
||||
|
||||
@property
|
||||
def batch5(self):
|
||||
return self._do_batch(5)
|
||||
def data_batch5(self):
|
||||
return self._do_data_batch(5)
|
||||
|
||||
def _do_batch(self, idx):
|
||||
attr = '__batch{}'.format(idx)
|
||||
@property
|
||||
def test_batch(self):
|
||||
if not getattr(self, '__test_batch', None):
|
||||
path = os.path.join(self.path, 'test_batch')
|
||||
self.__test_batch = self._unpickle(path)
|
||||
return self.__test_batch
|
||||
|
||||
@property
|
||||
def all_data_batches(self):
|
||||
return [getattr(self, 'data_batch{}'.format(i)) for i in range(1, 6)]
|
||||
|
||||
def _do_data_batch(self, idx):
|
||||
attr = '__data_batch{}'.format(idx)
|
||||
if not getattr(self, attr, None):
|
||||
path = os.path.join(self.path, 'data_batch_{}'.format(idx))
|
||||
setattr(self, attr, self._unpickle(path))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue