From a24e341a1c421d0b3ae3b70b594e111dc87d0c1d Mon Sep 17 00:00:00 2001 From: Eryn Wells Date: Sat, 27 Jan 2018 07:56:52 -0800 Subject: [PATCH] Add NN IPython notebook from CS231n at Stanford --- CS231n-1 Nearest Neighbor.ipynb | 154 ++++++++++++++++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 CS231n-1 Nearest Neighbor.ipynb diff --git a/CS231n-1 Nearest Neighbor.ipynb b/CS231n-1 Nearest Neighbor.ipynb new file mode 100644 index 0000000..62da387 --- /dev/null +++ b/CS231n-1 Nearest Neighbor.ipynb @@ -0,0 +1,154 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# http://cs231n.github.io/classification/#nn" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Archive exists, proceeding: /Users/eryn/Code/learnmeamachine/scratch/cifar10.tar.gz\n" + ] + } + ], + "source": [ + "cifar10 = datasets.cifar10()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# Create two arrays: one containing images, one containing labels.\n", + "# Partition those arrays into a training set, *tr, and a validation set, *val\n", + "\n", + "# Training set is the first 4 batches. (40k rows)\n", + "Xtr = np.concatenate([b[b'data'] for b in cifar10.all_data_batches])\n", + "Ytr = np.concatenate([b[b'labels'] for b in cifar10.all_data_batches])\n", + "\n", + "# Validation set is the last 1000 rows.\n", + "Xte = cifar10.test_batch[b'data']\n", + "Yte = cifar10.test_batch[b'labels']\n", + "\n", + "# Reshape the images into 1D arrays\n", + "Xtr = Xtr.reshape(Xtr.shape[0], 32*32*3)\n", + "Xte = Xval.reshape(Xval.shape[0], 32*32*3)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a NearestNeighbor classifier. Copied from the lecture material above.\n", + "\n", + "class NearestNeighbor(object):\n", + " def __init__(self):\n", + " pass\n", + "\n", + " def train(self, X, y):\n", + " \"\"\" X is N x D where each row is an example. Y is 1-dimension of size N \"\"\"\n", + " # the nearest neighbor classifier simply remembers all the training data\n", + " self.Xtr = X\n", + " self.ytr = y\n", + "\n", + " def predict(self, X):\n", + " \"\"\" X is N x D where each row is an example we wish to predict label for \"\"\"\n", + " num_test = X.shape[0]\n", + " # lets make sure that the output type matches the input type\n", + " Ypred = np.zeros(num_test, dtype = self.ytr.dtype)\n", + "\n", + " # loop over all test rows\n", + " for i in range(num_test):\n", + " # find the nearest training image to the i'th test image\n", + " # using the L1 distance (sum of absolute value differences)\n", + " distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)\n", + " min_index = np.argmin(distances) # get the index with smallest distance\n", + " Ypred[i] = self.ytr[min_index] # predict the label of the nearest example\n", + "\n", + " return Ypred" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# Train the NN classifier on the data.\n", + "nn = NearestNeighbor()\n", + "nn.train(Xtr, Ytr)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "accuracy: 0.2492\n" + ] + } + ], + "source": [ + "Yte_predict = nn.predict(Xte)\n", + "print('accuracy: {}'.format(np.mean(Yte_predict == Yte)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}