From 021b8c556e2892ef194420b87397d3252efac18e Mon Sep 17 00:00:00 2001 From: Eryn Wells Date: Fri, 26 Jan 2018 23:26:15 -0800 Subject: [PATCH] Split out CIFAR10 data batches and test batch --- datasets.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/datasets.py b/datasets.py index 60a815f..f004f8b 100644 --- a/datasets.py +++ b/datasets.py @@ -24,27 +24,38 @@ class CIFAR10: self.path = path @property - def batch1(self): - return self._do_batch(1) + def data_batch1(self): + return self._do_data_batch(1) @property - def batch2(self): - return self._do_batch(2) + def data_batch2(self): + return self._do_data_batch(2) @property - def batch3(self): - return self._do_batch(3) + def data_batch3(self): + return self._do_data_batch(3) @property - def batch4(self): - return self._do_batch(4) + def data_batch4(self): + return self._do_data_batch(4) @property - def batch5(self): - return self._do_batch(5) + def data_batch5(self): + return self._do_data_batch(5) - def _do_batch(self, idx): - attr = '__batch{}'.format(idx) + @property + def test_batch(self): + if not getattr(self, '__test_batch', None): + path = os.path.join(self.path, 'test_batch') + self.__test_batch = self._unpickle(path) + return self.__test_batch + + @property + def all_data_batches(self): + return [getattr(self, 'data_batch{}'.format(i)) for i in range(1, 6)] + + def _do_data_batch(self, idx): + attr = '__data_batch{}'.format(idx) if not getattr(self, attr, None): path = os.path.join(self.path, 'data_batch_{}'.format(idx)) setattr(self, attr, self._unpickle(path))