Add NN IPython notebook from CS231n at Stanford
This commit is contained in:
parent
021b8c556e
commit
a24e341a1c
1 changed files with 154 additions and 0 deletions
154
CS231n-1 Nearest Neighbor.ipynb
Normal file
154
CS231n-1 Nearest Neighbor.ipynb
Normal file
|
@ -0,0 +1,154 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# http://cs231n.github.io/classification/#nn"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import datasets"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Archive exists, proceeding: /Users/eryn/Code/learnmeamachine/scratch/cifar10.tar.gz\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"cifar10 = datasets.cifar10()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create two arrays: one containing images, one containing labels.\n",
|
||||
"# Partition those arrays into a training set, *tr, and a validation set, *val\n",
|
||||
"\n",
|
||||
"# Training set is the first 4 batches. (40k rows)\n",
|
||||
"Xtr = np.concatenate([b[b'data'] for b in cifar10.all_data_batches])\n",
|
||||
"Ytr = np.concatenate([b[b'labels'] for b in cifar10.all_data_batches])\n",
|
||||
"\n",
|
||||
"# Validation set is the last 1000 rows.\n",
|
||||
"Xte = cifar10.test_batch[b'data']\n",
|
||||
"Yte = cifar10.test_batch[b'labels']\n",
|
||||
"\n",
|
||||
"# Reshape the images into 1D arrays\n",
|
||||
"Xtr = Xtr.reshape(Xtr.shape[0], 32*32*3)\n",
|
||||
"Xte = Xval.reshape(Xval.shape[0], 32*32*3)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Create a NearestNeighbor classifier. Copied from the lecture material above.\n",
|
||||
"\n",
|
||||
"class NearestNeighbor(object):\n",
|
||||
" def __init__(self):\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
" def train(self, X, y):\n",
|
||||
" \"\"\" X is N x D where each row is an example. Y is 1-dimension of size N \"\"\"\n",
|
||||
" # the nearest neighbor classifier simply remembers all the training data\n",
|
||||
" self.Xtr = X\n",
|
||||
" self.ytr = y\n",
|
||||
"\n",
|
||||
" def predict(self, X):\n",
|
||||
" \"\"\" X is N x D where each row is an example we wish to predict label for \"\"\"\n",
|
||||
" num_test = X.shape[0]\n",
|
||||
" # lets make sure that the output type matches the input type\n",
|
||||
" Ypred = np.zeros(num_test, dtype = self.ytr.dtype)\n",
|
||||
"\n",
|
||||
" # loop over all test rows\n",
|
||||
" for i in range(num_test):\n",
|
||||
" # find the nearest training image to the i'th test image\n",
|
||||
" # using the L1 distance (sum of absolute value differences)\n",
|
||||
" distances = np.sum(np.abs(self.Xtr - X[i,:]), axis = 1)\n",
|
||||
" min_index = np.argmin(distances) # get the index with smallest distance\n",
|
||||
" Ypred[i] = self.ytr[min_index] # predict the label of the nearest example\n",
|
||||
"\n",
|
||||
" return Ypred"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Train the NN classifier on the data.\n",
|
||||
"nn = NearestNeighbor()\n",
|
||||
"nn.train(Xtr, Ytr)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"accuracy: 0.2492\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"Yte_predict = nn.predict(Xte)\n",
|
||||
"print('accuracy: {}'.format(np.mean(Yte_predict == Yte)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue