{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Naive Bayes Classification"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-01-24T06:22:13.130775Z",
"start_time": "2019-01-24T06:22:11.553295Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "1"
}
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"from matplotlib import pyplot as plt\n",
"from IPython import display\n",
"display.set_matplotlib_formats('svg')\n",
"import mxnet as mx\n",
"from mxnet import nd\n",
"import numpy as np\n",
"\n",
"# we go over one observation at a time (speed doesn't matter here)\n",
"def transform(data, label):\n",
" return (nd.floor(data/128)).astype(np.float32), label.astype(np.float32)\n",
"mnist_train = mx.gluon.data.vision.MNIST(train=True, transform=transform)\n",
"mnist_test = mx.gluon.data.vision.MNIST(train=False, transform=transform)"
]
},
{
"cell_type": "markdown",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"## Naive Bayes Classification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-01-24T06:22:56.094594Z",
"start_time": "2019-01-24T06:22:13.133949Z"
}
},
"outputs": [],
"source": [
"# initialize the counters\n",
"xcount = nd.ones((784,10))\n",
"ycount = nd.ones((10))\n",
"\n",
"for data, label in mnist_train:\n",
" y = int(label)\n",
" ycount[y] += 1\n",
" xcount[:,y] += data.reshape((784))\n",
"\n",
"# using broadcast again for division\n",
"py = ycount / ycount.sum()\n",
"px = (xcount / ycount.reshape(1,10))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-01-24T06:22:56.444530Z",
"start_time": "2019-01-24T06:22:56.096568Z"
},
"attributes": {
"classes": [],
"id": "",
"n": "2"
},
"slideshow": {
"slide_type": "slide"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
"