learn-me-a-machine/datasets.py

114 lines
2.9 KiB
Python

#!/usr/bin/env python3
import functools
import hashlib
import os
import pickle
import tarfile

import requests
def requires_scratch(f):
    '''Decorator that ensures the scratch dir exists before calling the decorated function.

    Args:
        f: The callable to wrap.

    Returns:
        A wrapper that calls ``_ensure_scratch_path()`` and then forwards all
        positional and keyword arguments to ``f``, returning its result.
    '''
    # functools.wraps copies __name__ and __doc__ like the previous manual
    # assignments did, and additionally carries over __module__,
    # __qualname__ and __dict__ and records the original as __wrapped__.
    @functools.wraps(f)
    def _wrapped(*args, **kwargs):
        _ensure_scratch_path()
        return f(*args, **kwargs)
    return _wrapped
#
# Dataset getters
#
class CIFAR10:
    '''Lazy accessor for the pickled batch files inside a directory.

    Each batch file is unpickled the first time it is requested and the
    result is kept in memory, so repeated reads of the same property return
    the same object without touching the disk again.
    '''

    def __init__(self, path):
        '''Remember the directory that holds the batch files.

        Args:
            path: Directory containing the files ``data_batch_1`` through
                ``data_batch_5`` and ``test_batch``.
        '''
        self.path = path
        # Maps batch file name -> unpickled contents. A plain dict is used
        # instead of double-underscore attributes: inside a class body
        # ``self.__x`` is name-mangled but the string passed to getattr() is
        # not, which previously made the test-batch cache never hit.
        self._batches = {}

    @property
    def data_batch1(self):
        '''Contents of ``data_batch_1``.'''
        return self._do_data_batch(1)

    @property
    def data_batch2(self):
        '''Contents of ``data_batch_2``.'''
        return self._do_data_batch(2)

    @property
    def data_batch3(self):
        '''Contents of ``data_batch_3``.'''
        return self._do_data_batch(3)

    @property
    def data_batch4(self):
        '''Contents of ``data_batch_4``.'''
        return self._do_data_batch(4)

    @property
    def data_batch5(self):
        '''Contents of ``data_batch_5``.'''
        return self._do_data_batch(5)

    @property
    def test_batch(self):
        '''Contents of ``test_batch``.'''
        return self._load_batch('test_batch')

    @property
    def all_data_batches(self):
        '''List of the five training batches, in order 1 through 5.'''
        return [self._do_data_batch(i) for i in range(1, 6)]

    def _do_data_batch(self, idx):
        '''Return the contents of ``data_batch_<idx>``, loading it on first use.'''
        return self._load_batch('data_batch_{}'.format(idx))

    def _load_batch(self, filename):
        '''Return the unpickled contents of ``filename``, caching the result.'''
        # Membership test rather than truthiness, so a batch that happens to
        # unpickle to a falsy value is not reloaded on every access.
        if filename not in self._batches:
            full_path = os.path.join(self.path, filename)
            self._batches[filename] = self._unpickle(full_path)
        return self._batches[filename]

    def _unpickle(self, path):
        '''Unpickle and return the object stored in the file at ``path``.'''
        # SECURITY NOTE(review): pickle can execute arbitrary code while
        # loading. Only point this class at files from a trusted, verified
        # source.
        with open(path, 'rb') as f:
            # encoding='bytes' makes pickle hand back Python 2 8-bit strings
            # as bytes objects instead of trying to decode them
            # (presumably the batch files were written by Python 2 -- confirm).
            data = pickle.load(f, encoding='bytes')
        return data
def _md5_hexdigest(path, chunk_size=1 << 20):
    '''Return the hex MD5 digest of the file at ``path``, read in chunks.'''
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


@requires_scratch
def cifar10():
    '''Download, extract, and return the CIFAR-10 archive.

    The archive is downloaded into the scratch directory unless it is already
    there, and extracted unless the extracted directory already exists. The
    archive's MD5 sum is verified before it is extracted.

    Returns:
        A ``CIFAR10`` instance rooted at the extracted ``cifar-10-batches-py``
        directory.

    Raises:
        requests.HTTPError: If the download returns an error status.
        ValueError: If the archive's MD5 sum does not match the expected one.
    '''
    url = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
    file_md5 = 'c58f30108f718f92721af3b95e74349a'
    archive_path = os.path.join(_scratch_path(), 'cifar10.tar.gz')
    if not os.path.exists(archive_path):
        # Download the file to our scratch path.
        print('Downloading from {}'.format(url))
        # Stream to a temporary name and rename once complete, so an
        # interrupted download never leaves a truncated file at archive_path
        # that later calls would mistake for a finished archive.
        partial_path = archive_path + '.part'
        r = requests.get(url, stream=True, timeout=60)
        try:
            r.raise_for_status()
            with open(partial_path, 'wb') as f:
                print('Writing to {}'.format(archive_path))
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        finally:
            r.close()
        os.replace(partial_path, archive_path)
    else:
        print('Archive exists, proceeding: {}'.format(archive_path))
    root = os.path.join(_scratch_path(), 'cifar-10-batches-py')
    if not os.path.isdir(root):
        # Verify integrity only when the archive is about to be used; this
        # avoids re-hashing the whole file on every call.
        actual_md5 = _md5_hexdigest(archive_path)
        if actual_md5 != file_md5:
            raise ValueError(
                'MD5 mismatch for {}: expected {}, got {}. '
                'Delete the file and try again.'.format(
                    archive_path, file_md5, actual_md5))
        with tarfile.open(archive_path) as f:
            print('Extracting')
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters reject members that would escape the
                # destination directory; only available on newer Pythons.
                f.extractall(_scratch_path(), filter='data')
            else:
                f.extractall(_scratch_path())
    return CIFAR10(root)
#
# Scratch helpers
#
def _ensure_scratch_path():
    '''Create the scratch directory if it does not already exist.'''
    # exist_ok=True replaces the former try/except around makedirs: that
    # handler compared against os.errno.EEXIST, but os.errno was removed in
    # Python 3.7, so an already-existing directory raised AttributeError
    # instead of being ignored.
    os.makedirs(_scratch_path(), exist_ok=True)
def _scratch_path():
    '''Return the path of the ``scratch`` directory next to this file.'''
    base_dir = _script_path()
    return os.path.join(base_dir, 'scratch')
def _script_path():
    '''Return the absolute path of the directory containing this file.'''
    this_file = os.path.abspath(__file__)
    return os.path.dirname(this_file)