Split out CIFAR10 data batches and test batch
This commit is contained in:
		
							parent
							
								
									9a4fccb5a6
								
							
						
					
					
						commit
						021b8c556e
					
				
					 1 changed files with 23 additions and 12 deletions
				
			
		
							
								
								
									
										35
									
								
								datasets.py
									
										
									
									
									
								
							
							
						
						
									
										35
									
								
								datasets.py
									
										
									
									
									
								
							|  | @ -24,27 +24,38 @@ class CIFAR10: | |||
|         self.path = path | ||||
| 
 | ||||
|     @property | ||||
|     def batch1(self): | ||||
|         return self._do_batch(1) | ||||
|     def data_batch1(self): | ||||
|         return self._do_data_batch(1) | ||||
| 
 | ||||
|     @property | ||||
|     def batch2(self): | ||||
|         return self._do_batch(2) | ||||
|     def data_batch2(self): | ||||
|         return self._do_data_batch(2) | ||||
| 
 | ||||
|     @property | ||||
|     def batch3(self): | ||||
|         return self._do_batch(3) | ||||
|     def data_batch3(self): | ||||
|         return self._do_data_batch(3) | ||||
| 
 | ||||
|     @property | ||||
|     def batch4(self): | ||||
|         return self._do_batch(4) | ||||
|     def data_batch4(self): | ||||
|         return self._do_data_batch(4) | ||||
| 
 | ||||
|     @property | ||||
|     def batch5(self): | ||||
|         return self._do_batch(5) | ||||
|     def data_batch5(self): | ||||
|         return self._do_data_batch(5) | ||||
| 
 | ||||
|     def _do_batch(self, idx): | ||||
|         attr = '__batch{}'.format(idx) | ||||
|     @property | ||||
|     def test_batch(self): | ||||
|         if not getattr(self, '__test_batch', None): | ||||
|             path = os.path.join(self.path, 'test_batch') | ||||
|             self.__test_batch = self._unpickle(path) | ||||
|         return self.__test_batch | ||||
| 
 | ||||
|     @property | ||||
|     def all_data_batches(self): | ||||
|         return [getattr(self, 'data_batch{}'.format(i)) for i in range(1, 6)] | ||||
| 
 | ||||
|     def _do_data_batch(self, idx): | ||||
|         attr = '__data_batch{}'.format(idx) | ||||
|         if not getattr(self, attr, None): | ||||
|             path = os.path.join(self.path, 'data_batch_{}'.format(idx)) | ||||
|             setattr(self, attr, self._unpickle(path)) | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue