diff --git a/datasets.py b/datasets.py new file mode 100644 index 0000000..60a815f --- /dev/null +++ b/datasets.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +import hashlib +import os +import pickle +import requests +import tarfile + +def requires_scratch(f): + '''Decorator that ensures the scratch dir exists before calling the decorated function.''' + def _wrapped(*args, **kwargs): + _ensure_scratch_path() + return f(*args, **kwargs) + _wrapped.__name__ = f.__name__ + _wrapped.__doc__ = f.__doc__ + return _wrapped + +# +# Dataset getters +# + +class CIFAR10: + def __init__(self, path): + self.path = path + + @property + def batch1(self): + return self._do_batch(1) + + @property + def batch2(self): + return self._do_batch(2) + + @property + def batch3(self): + return self._do_batch(3) + + @property + def batch4(self): + return self._do_batch(4) + + @property + def batch5(self): + return self._do_batch(5) + + def _do_batch(self, idx): + attr = '__batch{}'.format(idx) + if not getattr(self, attr, None): + path = os.path.join(self.path, 'data_batch_{}'.format(idx)) + setattr(self, attr, self._unpickle(path)) + return getattr(self, attr) + + def _unpickle(self, path): + with open(path, 'rb') as f: + data = pickle.load(f, encoding='bytes') + return data + + +@requires_scratch +def cifar10(): + '''Download, extract, and return the CIFAR-10 archive.''' + url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' + file_md5 = 'c58f30108f718f92721af3b95e74349a' + + archive_path = os.path.join(_scratch_path(), 'cifar10.tar.gz') + + if not os.path.exists(archive_path): + # Download the file to our scratch path. + print('Downloading from {}'.format(url)) + r = requests.get(url) + r.raise_for_status() + with open(archive_path, 'wb') as f: + print('Writing to {}'.format(archive_path)) + f.write(r.content) + else: + print('Archive exists, proceeding: {}'.format(archive_path)) + + # TODO: Validate MD5 sum + + root = os.path.join(_scratch_path(), 'cifar-10-batches-py') + if not os.path.isdir(root): + with tarfile.open(archive_path) as f: + print('Extracting') + f.extractall(_scratch_path()) + + return CIFAR10(root) + +# +# Scratch helpers +# + +def _ensure_scratch_path(): + try: + os.makedirs(_scratch_path()) + except OSError as exc: + if exc.errno != os.errno.EEXIST: + raise + +def _scratch_path(): + return os.path.join(_script_path(), 'scratch') + +def _script_path(): + return os.path.dirname(os.path.abspath(__file__))