diff --git a/min-char-rnn-nb.ipynb b/min-char-rnn-nb.ipynb new file mode 100644 index 0000000..eb3225a --- /dev/null +++ b/min-char-rnn-nb.ipynb @@ -0,0 +1,568 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "import numpy as np" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "data has 3291648 characters, 93 unique.\n" + ] + } + ], + "source": [ + "# data I/O\n", + "data = open('warpeace.txt', 'r').read() # should be simple plain text file\n", + "chars = list(set(data))\n", + "data_size, vocab_size = len(data), len(chars)\n", + "print 'data has %d characters, %d unique.' % (data_size, vocab_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "char_to_ix = { ch:i for i,ch in enumerate(chars) }\n", + "ix_to_char = { i:ch for i,ch in enumerate(chars) }" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "86" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "char_to_ix['a']" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "at the twitching cheeks o\n" + ] + } + ], + "source": [ + "# lets sample a batch of data\n", + "seq_length = 25 # number of characters in the batch\n", + "p = 220000 # point in the book to sample from\n", + "print data[p:p+seq_length] # print a chunk of data" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[86, 22, 0, 22, 18, 87, 0, 22, 45, 88, 22, 40, 18, 88, 64, 41, 0, 40, 18, 87, 87, 42, 44, 0, 43]\n", + "[22, 0, 22, 18, 87, 0, 22, 45, 88, 22, 40, 18, 88, 64, 41, 0, 40, 18, 87, 87, 42, 44, 0, 43, 62]\n" + ] + } + ], + "source": [ + "inputs = [char_to_ix[ch] for ch in data[p:p+seq_length]]\n", + "targets = [char_to_ix[ch] for ch in data[p+1:p+seq_length+1]]\n", + "print inputs\n", + "print targets" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", + " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", + " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", + " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.\n", + " 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.\n", + " 0. 0. 0.]\n" + ] + } + ], + "source": [ + "# lets plug the first character into the RNN\n", + "ix_input = inputs[0]\n", + "ix_target = targets[0]\n", + "# encode the input character with a 1-hot representation\n", + "x = np.zeros((vocab_size,1))\n", + "x[ix_input] = 1\n", + "print x.ravel()" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# create random starting parameters\n", + "hidden_size = 10\n", + "Wxh = np.random.randn(hidden_size, vocab_size)*0.01 # input to hidden\n", + "Whh = np.random.randn(hidden_size, hidden_size)*0.01 # hidden to hidden\n", + "Why = np.random.randn(vocab_size, hidden_size)*0.01 # hidden to output\n", + "bh = np.zeros((hidden_size, 1)) # hidden bias\n", + "by = np.zeros((vocab_size, 1)) # output bias" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0.01971421 -0.00533177 0.00142983 -0.00985955 0.00248728 -0.01736378\n", + " 0.01687631 -0.01756024 0.00207365 -0.00083882]\n" + ] + } + ], + "source": [ + "# compute the hidden state\n", + "h_prev = np.zeros((hidden_size, 1))\n", + "h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h_prev + bh))\n", + "print h.ravel()" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 7.60384659e-04 4.52720278e-04 -6.57143068e-04 4.59711552e-04\n", + " 2.16313976e-04 -2.73453571e-04 6.45337766e-04 6.58329123e-06\n", + " 4.52167436e-04 -6.46317115e-04 4.02199638e-04 1.99170349e-04\n", + " 4.25310887e-04 3.50158091e-04 -6.04819130e-04 2.26877413e-04\n", + " -1.94723297e-04 5.50639249e-04 -1.71656881e-04 -2.99860880e-04\n", + " 6.07620521e-04 5.70905998e-06 4.26038734e-04 1.01454914e-04\n", + " -3.65278849e-04 5.90686240e-04 1.59204750e-04 4.08708316e-04\n", + " 2.56014525e-04 1.39087410e-06 -6.90950824e-04 7.18269467e-04\n", + " 2.61719276e-04 4.58528196e-04 3.86335255e-04 2.70955381e-04\n", + " -1.48929417e-04 -1.83764339e-04 -6.36673024e-04 4.00478015e-04\n", + " -8.69493168e-05 3.73431457e-04 9.14721284e-05 4.38316766e-05\n", + " -6.83760980e-05 1.52588617e-05 1.74420973e-04 -1.47075009e-04\n", + " -3.62079011e-04 5.05874063e-05 6.20919471e-04 7.19441792e-05\n", + " -1.88333333e-04 -2.98096664e-04 4.43592043e-04 -3.76377143e-04\n", + " 2.10335682e-04 -6.41359408e-04 -4.54105970e-05 -7.58974605e-05\n", + " 5.47988450e-05 -2.77672190e-04 2.07213272e-05 2.13088016e-04\n", + " -5.07786842e-04 1.27347437e-04 4.24719496e-04 -2.13192324e-04\n", + " -2.05694776e-04 -1.43861034e-04 4.50169861e-04 2.74675798e-04\n", + " 1.31264147e-04 3.29399476e-04 -1.41866715e-04 -7.91514010e-05\n", + " -2.96356884e-05 -3.85666773e-04 2.79322785e-04 6.82910193e-04\n", + " -2.05519751e-04 -4.25825715e-04 -8.47060105e-05 2.66896920e-04\n", + " 1.17914625e-04 -4.00869517e-04 2.30587114e-04 1.24662656e-05\n", + " 2.43743504e-04 4.57417182e-04 -2.78926664e-05 1.23560813e-05\n", + " -3.55199314e-04]\n" + ] + } + ], + "source": [ + "# compute the scores for next character\n", + "y = np.dot(Why, h) + by\n", + "print y.ravel()" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0.01076018 0.01075687 0.01074493 0.01075694 0.01075432 0.01074906\n", + " 0.01075894 0.01075207 0.01075686 0.01074505 0.01075632 0.01075414\n", + " 0.01075657 0.01075576 0.0107455 0.01075444 0.0107499 0.01075792\n", + " 0.01075015 0.01074877 0.01075853 0.01075206 0.01075658 0.01075309\n", + " 0.01074807 0.01075835 0.01075371 0.01075639 0.01075475 0.01075201\n", + " 0.01074457 0.01075972 0.01075481 0.01075693 0.01075615 0.01075491\n", + " 0.0107504 0.01075002 0.01074515 0.0107563 0.01075106 0.01075601\n", + " 0.01075298 0.01075247 0.01075126 0.01075216 0.01075387 0.01075042\n", + " 0.0107481 0.01075254 0.01075867 0.01075277 0.01074997 0.01074879\n", + " 0.01075677 0.01074795 0.01075426 0.0107451 0.01075151 0.01075118\n", + " 0.01075259 0.01074901 0.01075222 0.01075429 0.01074654 0.01075337\n", + " 0.01075656 0.0107497 0.01074978 0.01075045 0.01075684 0.01075495\n", + " 0.01075341 0.01075554 0.01075047 0.01075115 0.01075168 0.01074785\n", + " 0.010755 0.01075934 0.01074979 0.01074742 0.01075109 0.01075487\n", + " 0.01075326 0.01074769 0.01075448 0.01075213 0.01075462 0.01075692\n", + " 0.0107517 0.01075213 0.01074818]\n", + "probabilities sum to 1.0\n" + ] + } + ], + "source": [ + "# the scores are unnormalized log probabilities. compute the probabilities\n", + "p = np.exp(y) / np.sum(np.exp(y))\n", + "print p.ravel()\n", + "print 'probabilities sum to ', p.sum()" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "probability assigned to the correct next character is right now: 0.0107565780746\n" + ] + } + ], + "source": [ + "print 'probability assigned to the correct next character is right now: ', p[ix_target,0]" + ] + }, + { + "cell_type": "code", + "execution_count": 72, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the cross-entropy (softmax) loss is 4.53223779763\n" + ] + } + ], + "source": [ + "loss = -np.log(p[ix_target,0])\n", + "print 'the cross-entropy (softmax) loss is ', loss" + ] + }, + { + "cell_type": "code", + "execution_count": 74, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 0.01076018 0.01075687 0.01074493 0.01075694 0.01075432 0.01074906\n", + " 0.01075894 0.01075207 0.01075686 0.01074505 0.01075632 0.01075414\n", + " 0.01075657 0.01075576 0.0107455 0.01075444 0.0107499 0.01075792\n", + " 0.01075015 0.01074877 0.01075853 0.01075206 -0.98924342 0.01075309\n", + " 0.01074807 0.01075835 0.01075371 0.01075639 0.01075475 0.01075201\n", + " 0.01074457 0.01075972 0.01075481 0.01075693 0.01075615 0.01075491\n", + " 0.0107504 0.01075002 0.01074515 0.0107563 0.01075106 0.01075601\n", + " 0.01075298 0.01075247 0.01075126 0.01075216 0.01075387 0.01075042\n", + " 0.0107481 0.01075254 0.01075867 0.01075277 0.01074997 0.01074879\n", + " 0.01075677 0.01074795 0.01075426 0.0107451 0.01075151 0.01075118\n", + " 0.01075259 0.01074901 0.01075222 0.01075429 0.01074654 0.01075337\n", + " 0.01075656 0.0107497 0.01074978 0.01075045 0.01075684 0.01075495\n", + " 0.01075341 0.01075554 0.01075047 0.01075115 0.01075168 0.01074785\n", + " 0.010755 0.01075934 0.01074979 0.01074742 0.01075109 0.01075487\n", + " 0.01075326 0.01074769 0.01075448 0.01075213 0.01075462 0.01075692\n", + " 0.0107517 0.01075213 0.01074818]\n", + "sum of dy is 3.26128013484e-16\n", + "the gradient for the correct character (t) is: -0.989243421925\n", + "the gradient for the character (a) is: 0.0107544758895\n" + ] + } + ], + "source": [ + "# compute the gradient on y\n", + "dy = np.copy(p)\n", + "dy[ix_target] -= 1\n", + "print dy.ravel()\n", + "print 'sum of dy is ', dy.sum()\n", + "print 'the gradient for the correct character (%s) is: %s' % (ix_to_char[ix_target], dy[ix_target,0])\n", + "print 'the gradient for the character (a) is: ', dy[char_to_ix['a'],0]" + ] + }, + { + "cell_type": "code", + "execution_count": 88, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the hidden vector activations were:\n", + "[ 0.01971421 -0.00533177 0.00142983 -0.00985955 0.00248728 -0.01736378\n", + " 0.01687631 -0.01756024 0.00207365 -0.00083882]\n", + "the gradients are:\n", + "[ 0.01411266 -0.00086765 -0.00501521 0.00024313 0.00070486 0.02431426\n", + " -0.00599582 0.00532303 -0.01018747 -0.00155986]\n", + "the gradients dWhy have size: (93, 10)\n", + "a small sample is:\n", + "[[ 2.12128361e-04 -5.73708201e-05 1.53852052e-05 -1.06090457e-04]\n", + " [ 2.12063107e-04 -5.73531718e-05 1.53804724e-05 -1.06057822e-04]\n", + " [ 2.11827876e-04 -5.72895529e-05 1.53634117e-05 -1.05940178e-04]\n", + " [ 2.12064589e-04 -5.73535728e-05 1.53805799e-05 -1.06058563e-04]]\n" + ] + } + ], + "source": [ + "# we computed [y = np.dot(Why, h) + by]; Backpropagate to Why, h, and by\n", + "dWhy = np.dot(dy, h.T)\n", + "dh = np.dot(Why.T, dy)\n", + "dby = np.copy(dy)\n", + "print 'the hidden vector activations were:'\n", + "print h.ravel()\n", + "print 'the gradients are:'\n", + "print dh.ravel()\n", + "print 'the gradients dWhy have size: ', dWhy.shape\n", + "print 'a small sample is:'\n", + "print dWhy[:4,:4]" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "small sample of Whh:\n", + "[[ 0.00013373 -0.01329709 0.01799412 0.01353991]\n", + " [-0.00427217 -0.01052258 0.00820703 0.00452433]\n", + " [-0.00511801 -0.0094128 0.00029513 0.0058532 ]\n", + " [-0.00132356 0.00330526 0.00737467 -0.00426905]]\n" + ] + } + ], + "source": [ + "# we computed [h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h_prev + bh))]; \n", + "# Backprop into Wxh, x, Whh, h_prev, bh:\n", + "dh_before_tanh = (1-h**2)*dh\n", + "dbh = np.copy(dh_before_tanh)\n", + "dWxh = np.dot(dh_before_tanh, x.T)\n", + "dWhh = np.dot(dh_before_tanh, h.T)\n", + "dh_prev = np.dot(Whh.T, dh_before_tanh)\n", + "print 'small sample of Whh:'\n", + "print Whh[:4,:4]" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# we now have the gradients for all parameters! (Wxh, Whh, Why, bh, by)\n", + "# lets do a parameter update\n", + "learning_rate = 0.1\n", + "Wxh2 = Wxh - learning_rate * dWxh\n", + "Whh2 = Whh - learning_rate * dWhh\n", + "Why2 = Why - learning_rate * dWhy\n", + "bh2 = bh - learning_rate * dbh\n", + "by2 = by - learning_rate * dby" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "probability assigned to the correct next character was: 0.0107565780746\n", + "probability assigned to the correct next character is now: 0.0118772830163\n", + "the cross-entropy (softmax) loss was 4.53223779763\n", + "the loss is now 4.43312769353\n" + ] + } + ], + "source": [ + "# these parameters should be much better! lets try it out:\n", + "h2 = np.tanh(np.dot(Wxh2, x) + np.dot(Whh2, h_prev + bh2))\n", + "y2 = np.dot(Why2, h2) + by2\n", + "p2 = np.exp(y2) / np.sum(np.exp(y2))\n", + "print 'probability assigned to the correct next character was: ', p[ix_target,0]\n", + "print 'probability assigned to the correct next character is now: ', p2[ix_target,0]\n", + "loss2 = -np.log(p2[ix_target,0])\n", + "print 'the cross-entropy (softmax) loss was ', loss\n", + "print 'the loss is now ', loss2" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# note: the probability for the correct character went up! (and the loss went down)" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# putting it together\n", + "def lossFun(inputs, targets, hprev):\n", + " \"\"\"\n", + " inputs,targets are both list of integers.\n", + " hprev is Hx1 array of initial hidden state\n", + " returns the loss, gradients on model parameters, and last hidden state\n", + " \"\"\"\n", + " xs, hs, ys, ps = {}, {}, {}, {}\n", + " hs[-1] = np.copy(hprev)\n", + " loss = 0\n", + " # forward pass\n", + " for t in xrange(len(inputs)):\n", + " xs[t] = np.zeros((vocab_size,1)) # encode in 1-of-k representation\n", + " xs[t][inputs[t]] = 1\n", + " hs[t] = np.tanh(np.dot(Wxh, xs[t]) + np.dot(Whh, hs[t-1]) + bh) # hidden state\n", + " ys[t] = np.dot(Why, hs[t]) + by # unnormalized log probabilities for next chars\n", + " ps[t] = np.exp(ys[t]) / np.sum(np.exp(ys[t])) # probabilities for next chars\n", + " loss += -np.log(ps[t][targets[t],0]) # softmax (cross-entropy loss)\n", + " \n", + " # backward pass: compute gradients going backwards\n", + " dWxh, dWhh, dWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)\n", + " dbh, dby = np.zeros_like(bh), np.zeros_like(by)\n", + " dhnext = np.zeros_like(hs[0])\n", + " for t in reversed(xrange(len(inputs))):\n", + " pass # TODO\n", + " \n", + " return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs)-1]" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "metadata": { + "collapsed": false + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "113.313778376\n" + ] + } + ], + "source": [ + "loss, dWxh, dWhh, dWhy, dbh, dby, hnew = lossFun(inputs, targets, h_prev)\n", + "print loss" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}