Shortcuts

Source code for nets.datasets.mnist

r"""
Defines and pre-processes the MNIST dataset. The data will be converted into ``Tensor`` objects.
"""

import gzip
import os
import numpy as np
import nets
from nets.data.dataset import Dataset


class MNIST(Dataset):
    """
    One partition of the MNIST dataset (http://yann.lecun.com/exdb/mnist/).

    MNIST contains 60000 training examples and 10000 test examples of
    handwritten digits in {0, ..., 9} with their corresponding labels.
    Each image has an "original" dimension of 28x28x1 and is stored
    row-wise as 784 bytes. Pixel values are in range 0 to 255 (inclusive).

    An instance holds a single partition (one image file and its label file).
    Use :meth:`splits` to get the training and test partitions together.

    Args:
        path_data (str): path to the gzip-compressed IDX image file.
        path_label (str): path to the gzip-compressed IDX label file.
        transform (callable, optional): if given, called once on the whole
            ``uint8`` image array of shape ``(N, 28, 28)``; its return value
            is what gets wrapped in a ``Tensor``.

    Attributes:
        data: ``nets.Tensor`` built from the (optionally transformed) images.
        labels: ``nets.Tensor`` built from the ``uint8`` labels.
    """

    urls = [
        'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
        'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
    ]
    name = 'mnist-data-py'
    dirname = 'mnist'

    def __init__(self, path_data, path_label, transform=None):
        # Image files carry a 16-byte header; everything after it is pixels.
        images = self._load_mnist(path_data, header_size=16).reshape((-1, 28, 28))
        if transform is not None:
            images = transform(images)
        self.data = nets.Tensor(images)
        # Label files carry an 8-byte header; everything after it is labels.
        self.labels = nets.Tensor(self._load_mnist(path_label, header_size=8))

    def _load_mnist(self, path, header_size):
        """
        Read a gzip-compressed file and return its payload as a flat array.

        Args:
            path (str): path to the ``.gz`` file.
            header_size (int): number of leading bytes to skip.

        Returns:
            numpy.ndarray: 1-D ``uint8`` array of the bytes after the header.
        """
        with gzip.open(path, 'rb') as f:
            # NOTE(review): ``np.frombuffer`` over ``bytes`` yields a read-only
            # array; confirm ``nets.Tensor`` copies it before relying on
            # ``__setitem__`` to write in place.
            return np.frombuffer(f.read(), np.uint8, offset=header_size)

    @classmethod
    def splits(cls, root='.data',
               train_data='train-images-idx3-ubyte.gz',
               train_label='train-labels-idx1-ubyte.gz',
               test_data='t10k-images-idx3-ubyte.gz',
               test_label='t10k-labels-idx1-ubyte.gz',
               **kwargs):
        r"""
        Loads training and test partitions of the
        [mnist dataset](http://yann.lecun.com/exdb/mnist/).
        If ``root/<dirname>/<name>`` is not an existing directory,
        ``cls.download(root)`` is called and its return value is used as the
        dataset folder.

        Args:
            root (str): relative or absolute path of the dataset.
            train_data (str): file name of the training images.
            train_label (str): file name of the training labels.
            test_data (str): file name of the test images.
            test_label (str): file name of the test labels.
            **kwargs: forwarded to the constructor (e.g. ``transform``).

        Returns:
            tuple(Dataset): training and testing datasets
        """
        path = os.path.join(root, cls.dirname, cls.name)
        if not os.path.isdir(path):
            path = cls.download(root)

        train_data = os.path.join(path, train_data)
        train_label = os.path.join(path, train_label)
        test_data = os.path.join(path, test_data)
        test_label = os.path.join(path, test_label)

        # Build through ``cls`` so that subclasses get instances of themselves.
        return (cls(train_data, train_label, **kwargs),
                cls(test_data, test_label, **kwargs))

    def __getitem__(self, item):
        """Return the ``(data, label)`` pair selected by ``item``."""
        return self.data[item], self.labels[item]

    def __setitem__(self, key, value):
        """Assign a ``(data, label)`` pair at ``key``."""
        self.data[key], self.labels[key] = value

    def __len__(self):
        """Return ``len(self.data)``."""
        return len(self.data)

Docs

Access comprehensive developer documentation for Nets

View Docs

Tutorials

Get beginner tutorials and create state-of-the-art models

View Tutorials

Resources

Check the GitHub page and contribute to the project

View GitHub