diff --git a/min-char-rnn-nb.ipynb b/min-char-rnn-nb.ipynb index eb3225a..b168f13 100644 --- a/min-char-rnn-nb.ipynb +++ b/min-char-rnn-nb.ipynb @@ -8,12 +8,13 @@ }, "outputs": [], "source": [ - "import numpy as np" + "import numpy as np\n", + "np.random.seed(1337)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "collapsed": false }, @@ -36,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "collapsed": true }, @@ -48,7 +49,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 4, "metadata": { "collapsed": false }, @@ -59,7 +60,7 @@ "86" ] }, - "execution_count": 10, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -70,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 5, "metadata": { "collapsed": false }, @@ -92,7 +93,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 6, "metadata": { "collapsed": false }, @@ -115,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 7, "metadata": { "collapsed": false }, @@ -145,7 +146,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 8, "metadata": { "collapsed": false }, @@ -162,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 9, "metadata": { "collapsed": false }, @@ -171,8 +172,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "[ 0.01971421 -0.00533177 0.00142983 -0.00985955 0.00248728 -0.01736378\n", - " 0.01687631 -0.01756024 0.00207365 -0.00083882]\n" + "[ -2.19597294e-03 -1.11958403e-02 -7.96573514e-03 -8.29025315e-03\n", + " -1.69046852e-02 7.16480643e-03 -1.08628224e-05 1.38136549e-03\n", + " 8.91231078e-03 1.22455580e-02]\n" ] } ], @@ -185,7 +187,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 10, "metadata": { "collapsed": false }, @@ -194,30 +196,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "[ 7.60384659e-04 4.52720278e-04 -6.57143068e-04 4.59711552e-04\n", - " 2.16313976e-04 -2.73453571e-04 6.45337766e-04 6.58329123e-06\n", - " 4.52167436e-04 -6.46317115e-04 4.02199638e-04 1.99170349e-04\n", - " 4.25310887e-04 3.50158091e-04 -6.04819130e-04 2.26877413e-04\n", - " -1.94723297e-04 5.50639249e-04 -1.71656881e-04 -2.99860880e-04\n", - " 6.07620521e-04 5.70905998e-06 4.26038734e-04 1.01454914e-04\n", - " -3.65278849e-04 5.90686240e-04 1.59204750e-04 4.08708316e-04\n", - " 2.56014525e-04 1.39087410e-06 -6.90950824e-04 7.18269467e-04\n", - " 2.61719276e-04 4.58528196e-04 3.86335255e-04 2.70955381e-04\n", - " -1.48929417e-04 -1.83764339e-04 -6.36673024e-04 4.00478015e-04\n", - " -8.69493168e-05 3.73431457e-04 9.14721284e-05 4.38316766e-05\n", - " -6.83760980e-05 1.52588617e-05 1.74420973e-04 -1.47075009e-04\n", - " -3.62079011e-04 5.05874063e-05 6.20919471e-04 7.19441792e-05\n", - " -1.88333333e-04 -2.98096664e-04 4.43592043e-04 -3.76377143e-04\n", - " 2.10335682e-04 -6.41359408e-04 -4.54105970e-05 -7.58974605e-05\n", - " 5.47988450e-05 -2.77672190e-04 2.07213272e-05 2.13088016e-04\n", - " -5.07786842e-04 1.27347437e-04 4.24719496e-04 -2.13192324e-04\n", - " -2.05694776e-04 -1.43861034e-04 4.50169861e-04 2.74675798e-04\n", - " 1.31264147e-04 3.29399476e-04 -1.41866715e-04 -7.91514010e-05\n", - " -2.96356884e-05 -3.85666773e-04 2.79322785e-04 6.82910193e-04\n", - " -2.05519751e-04 -4.25825715e-04 -8.47060105e-05 2.66896920e-04\n", - " 1.17914625e-04 -4.00869517e-04 2.30587114e-04 1.24662656e-05\n", - " 2.43743504e-04 4.57417182e-04 -2.78926664e-05 1.23560813e-05\n", - " -3.55199314e-04]\n" + "[ -1.86941995e-04 5.75478992e-05 5.45725372e-05 -4.12423671e-04\n", + " 3.06201649e-04 -5.06931560e-04 3.19806236e-04 -3.51163814e-05\n", + " -3.38972137e-04 -2.39857442e-04 -8.01432457e-05 -1.04391992e-04\n", + " -3.15399258e-04 1.48454913e-05 4.55998798e-05 1.28931375e-04\n", + " 4.04072882e-04 -5.75504844e-04 1.57533734e-04 -4.08258042e-04\n", + " -7.41755605e-05 -7.76690551e-05 1.87837950e-04 -2.79910204e-04\n", + " -7.96190232e-04 -5.79876645e-05 1.26730230e-04 3.95894081e-04\n", + " 2.76955496e-04 2.59637379e-05 -4.61015050e-04 -5.14636280e-04\n", + " -2.56714611e-04 4.94377196e-04 -3.64149466e-04 -4.26364575e-04\n", + " -1.25515821e-04 9.29132184e-06 4.88944584e-05 -6.31681837e-04\n", + " -1.75121523e-05 9.51597096e-06 3.08227238e-04 1.29694758e-04\n", + " 1.72591168e-04 -5.87989566e-05 2.74250548e-04 5.90320270e-05\n", + " 1.64825617e-04 3.11431559e-04 3.48582321e-04 -4.19733486e-05\n", + " -2.77062543e-04 -1.10111861e-05 -1.33642742e-05 -5.99994639e-05\n", + " 6.07084027e-04 3.51784289e-04 -3.89265870e-05 -1.03640935e-05\n", + " 2.94412266e-04 7.75998698e-05 -2.85416619e-04 4.26053337e-04\n", + " -3.47809927e-04 -4.67567746e-04 -6.81980155e-05 -1.09923565e-06\n", + " 6.26504669e-05 -2.76671607e-04 2.96818053e-04 1.87011604e-05\n", + " 2.98832018e-04 -2.65531746e-04 -5.58804832e-04 -3.07801824e-04\n", + " 3.75513314e-04 4.76734817e-04 -2.96633044e-04 -3.06249150e-04\n", + " 1.22093579e-05 -1.43195488e-04 -1.58954120e-04 -1.08600298e-04\n", + " 1.52122727e-04 -6.54969615e-04 1.60857638e-04 -9.11628405e-04\n", + " 3.67740402e-05 2.16278686e-04 -1.98860111e-04 4.08539302e-05\n", + " -1.47124899e-04]\n" ] } ], @@ -229,7 +231,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 11, "metadata": { "collapsed": false }, @@ -238,22 +240,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "[ 0.01076018 0.01075687 0.01074493 0.01075694 0.01075432 0.01074906\n", - " 0.01075894 0.01075207 0.01075686 0.01074505 0.01075632 0.01075414\n", - " 0.01075657 0.01075576 0.0107455 0.01075444 0.0107499 0.01075792\n", - " 0.01075015 0.01074877 0.01075853 0.01075206 0.01075658 0.01075309\n", - " 0.01074807 0.01075835 0.01075371 0.01075639 0.01075475 0.01075201\n", - " 0.01074457 0.01075972 0.01075481 0.01075693 0.01075615 0.01075491\n", - " 0.0107504 0.01075002 0.01074515 0.0107563 0.01075106 0.01075601\n", - " 0.01075298 0.01075247 0.01075126 0.01075216 0.01075387 0.01075042\n", - " 0.0107481 0.01075254 0.01075867 0.01075277 0.01074997 0.01074879\n", - " 0.01075677 0.01074795 0.01075426 0.0107451 0.01075151 0.01075118\n", - " 0.01075259 0.01074901 0.01075222 0.01075429 0.01074654 0.01075337\n", - " 0.01075656 0.0107497 0.01074978 0.01075045 0.01075684 0.01075495\n", - " 0.01075341 0.01075554 0.01075047 0.01075115 0.01075168 0.01074785\n", - " 0.010755 0.01075934 0.01074979 0.01074742 0.01075109 0.01075487\n", - " 0.01075326 0.01074769 0.01075448 0.01075213 0.01075462 0.01075692\n", - " 0.0107517 0.01075213 0.01074818]\n", + "[ 0.01075121 0.01075383 0.0107538 0.01074878 0.01075651 0.01074777\n", + " 0.01075666 0.01075284 0.01074957 0.01075064 0.01075235 0.01075209\n", + " 0.01074982 0.01075338 0.01075371 0.0107546 0.01075756 0.01074703\n", + " 0.01075491 0.01074883 0.01075242 0.01075238 0.01075524 0.01075021\n", + " 0.01074466 0.01075259 0.01075458 0.01075747 0.01075619 0.01075349\n", + " 0.01074826 0.01074768 0.01075046 0.01075853 0.0107493 0.01074863\n", + " 0.01075187 0.01075332 0.01075374 0.01074643 0.01075303 0.01075332\n", + " 0.01075653 0.01075461 0.01075507 0.01075258 0.01075617 0.01075385\n", + " 0.01075499 0.01075656 0.01075696 0.01075276 0.01075024 0.0107531\n", + " 0.01075307 0.01075257 0.01075975 0.010757 0.0107528 0.0107531\n", + " 0.01075638 0.01075405 0.01075015 0.0107578 0.01074948 0.01074819\n", + " 0.01075248 0.0107532 0.01075389 0.01075024 0.01075641 0.01075342\n", + " 0.01075643 0.01075036 0.01074721 0.01074991 0.01075725 0.01075834\n", + " 0.01075003 0.01074992 0.01075335 0.01075168 0.01075151 0.01075205\n", + " 0.01075485 0.01074617 0.01075495 0.01074342 0.01075361 0.01075554\n", + " 0.01075108 0.01075365 0.01075163]\n", "probabilities sum to 1.0\n" ] } @@ -267,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": 12, "metadata": { "collapsed": false }, @@ -276,7 +278,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "probability assigned to the correct next character is right now: 0.0107565780746\n" + "probability assigned to the correct next character is right now: 0.0107552356218\n" ] } ], @@ -286,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 13, "metadata": { "collapsed": false }, @@ -295,7 +297,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "the cross-entropy (softmax) loss is 4.53223779763\n" + "the cross-entropy (softmax) loss is 4.53236260838\n" ] } ], @@ -306,7 +308,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 14, "metadata": { "collapsed": false }, @@ -315,25 +317,25 @@ "name": "stdout", "output_type": "stream", "text": [ - "[ 0.01076018 0.01075687 0.01074493 0.01075694 0.01075432 0.01074906\n", - " 0.01075894 0.01075207 0.01075686 0.01074505 0.01075632 0.01075414\n", - " 0.01075657 0.01075576 0.0107455 0.01075444 0.0107499 0.01075792\n", - " 0.01075015 0.01074877 0.01075853 0.01075206 -0.98924342 0.01075309\n", - " 0.01074807 0.01075835 0.01075371 0.01075639 0.01075475 0.01075201\n", - " 0.01074457 0.01075972 0.01075481 0.01075693 0.01075615 0.01075491\n", - " 0.0107504 0.01075002 0.01074515 0.0107563 0.01075106 0.01075601\n", - " 0.01075298 0.01075247 0.01075126 0.01075216 0.01075387 0.01075042\n", - " 0.0107481 0.01075254 0.01075867 0.01075277 0.01074997 0.01074879\n", - " 0.01075677 0.01074795 0.01075426 0.0107451 0.01075151 0.01075118\n", - " 0.01075259 0.01074901 0.01075222 0.01075429 0.01074654 0.01075337\n", - " 0.01075656 0.0107497 0.01074978 0.01075045 0.01075684 0.01075495\n", - " 0.01075341 0.01075554 0.01075047 0.01075115 0.01075168 0.01074785\n", - " 0.010755 0.01075934 0.01074979 0.01074742 0.01075109 0.01075487\n", - " 0.01075326 0.01074769 0.01075448 0.01075213 0.01075462 0.01075692\n", - " 0.0107517 0.01075213 0.01074818]\n", - "sum of dy is 3.26128013484e-16\n", - "the gradient for the correct character (t) is: -0.989243421925\n", - "the gradient for the character (a) is: 0.0107544758895\n" + "[ 0.01075121 0.01075383 0.0107538 0.01074878 0.01075651 0.01074777\n", + " 0.01075666 0.01075284 0.01074957 0.01075064 0.01075235 0.01075209\n", + " 0.01074982 0.01075338 0.01075371 0.0107546 0.01075756 0.01074703\n", + " 0.01075491 0.01074883 0.01075242 0.01075238 -0.98924476 0.01075021\n", + " 0.01074466 0.01075259 0.01075458 0.01075747 0.01075619 0.01075349\n", + " 0.01074826 0.01074768 0.01075046 0.01075853 0.0107493 0.01074863\n", + " 0.01075187 0.01075332 0.01075374 0.01074643 0.01075303 0.01075332\n", + " 0.01075653 0.01075461 0.01075507 0.01075258 0.01075617 0.01075385\n", + " 0.01075499 0.01075656 0.01075696 0.01075276 0.01075024 0.0107531\n", + " 0.01075307 0.01075257 0.01075975 0.010757 0.0107528 0.0107531\n", + " 0.01075638 0.01075405 0.01075015 0.0107578 0.01074948 0.01074819\n", + " 0.01075248 0.0107532 0.01075389 0.01075024 0.01075641 0.01075342\n", + " 0.01075643 0.01075036 0.01074721 0.01074991 0.01075725 0.01075834\n", + " 0.01075003 0.01074992 0.01075335 0.01075168 0.01075151 0.01075205\n", + " 0.01075485 0.01074617 0.01075495 0.01074342 0.01075361 0.01075554\n", + " 0.01075108 0.01075365 0.01075163]\n", + "sum of dy is -4.23272528138e-16\n", + "the gradient for the correct character (t) is: -0.989244764378\n", + "the gradient for the character (a) is: 0.0107549454461\n" ] } ], @@ -349,7 +351,7 @@ }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 15, "metadata": { "collapsed": false }, @@ -359,17 +361,18 @@ "output_type": "stream", "text": [ "the hidden vector activations were:\n", - "[ 0.01971421 -0.00533177 0.00142983 -0.00985955 0.00248728 -0.01736378\n", - " 0.01687631 -0.01756024 0.00207365 -0.00083882]\n", + "[ -2.19597294e-03 -1.11958403e-02 -7.96573514e-03 -8.29025315e-03\n", + " -1.69046852e-02 7.16480643e-03 -1.08628224e-05 1.38136549e-03\n", + " 8.91231078e-03 1.22455580e-02]\n", "the gradients are:\n", - "[ 0.01411266 -0.00086765 -0.00501521 0.00024313 0.00070486 0.02431426\n", - " -0.00599582 0.00532303 -0.01018747 -0.00155986]\n", + "[ 0.010636 -0.00556332 0.01572234 -0.01013355 0.01024101 0.00893933\n", + " 0.01156068 -0.0035879 -0.00304811 -0.00761244]\n", "the gradients dWhy have size: (93, 10)\n", "a small sample is:\n", - "[[ 2.12128361e-04 -5.73708201e-05 1.53852052e-05 -1.06090457e-04]\n", - " [ 2.12063107e-04 -5.73531718e-05 1.53804724e-05 -1.06057822e-04]\n", - " [ 2.11827876e-04 -5.72895529e-05 1.53634117e-05 -1.05940178e-04]\n", - " [ 2.12064589e-04 -5.73535728e-05 1.53805799e-05 -1.06058563e-04]]\n" + "[[ -2.36093564e-05 -1.20368781e-04 -8.56412557e-05 -8.91302155e-05]\n", + " [ -2.36151293e-05 -1.20398213e-04 -8.56621967e-05 -8.91520096e-05]\n", + " [ -2.36150591e-05 -1.20397855e-04 -8.56619418e-05 -8.91517444e-05]\n", + " [ -2.36040335e-05 -1.20341643e-04 -8.56219473e-05 -8.91101206e-05]]\n" ] } ], @@ -389,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 16, "metadata": { "collapsed": false }, @@ -399,10 +402,10 @@ "output_type": "stream", "text": [ "small sample of Whh:\n", - "[[ 0.00013373 -0.01329709 0.01799412 0.01353991]\n", - " [-0.00427217 -0.01052258 0.00820703 0.00452433]\n", - " [-0.00511801 -0.0094128 0.00029513 0.0058532 ]\n", - " [-0.00132356 0.00330526 0.00737467 -0.00426905]]\n" + "[[-0.0181644 -0.00730641 0.00068086 0.00043998]\n", + " [-0.00706002 -0.00657655 0.0100175 0.00198405]\n", + " [ 0.0175388 0.01643369 -0.00161146 -0.01413518]\n", + " [ 0.00812143 -0.00105367 0.00944295 0.01148334]]\n" ] } ], @@ -420,7 +423,7 @@ }, { "cell_type": "code", - "execution_count": 95, + "execution_count": 17, "metadata": { "collapsed": true }, @@ -438,7 +441,7 @@ }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 18, "metadata": { "collapsed": false }, @@ -447,10 +450,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "probability assigned to the correct next character was: 0.0107565780746\n", - "probability assigned to the correct next character is now: 0.0118772830163\n", - "the cross-entropy (softmax) loss was 4.53223779763\n", - "the loss is now 4.43312769353\n" + "probability assigned to the correct next character was: 0.0107552356218\n", + "probability assigned to the correct next character is now: 0.0118750046895\n", + "the cross-entropy (softmax) loss was 4.53236260838\n", + "the loss is now 4.43331953415\n" ] } ], @@ -468,7 +471,7 @@ }, { "cell_type": "code", - "execution_count": 98, + "execution_count": 19, "metadata": { "collapsed": true }, @@ -479,13 +482,13 @@ }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ - "# putting it together\n", + "# putting it together with loops\n", "def lossFun(inputs, targets, hprev):\n", " \"\"\"\n", " inputs,targets are both list of integers.\n", @@ -495,6 +498,7 @@ " xs, hs, ys, ps = {}, {}, {}, {}\n", " hs[-1] = np.copy(hprev)\n", " loss = 0\n", + " \n", " # forward pass\n", " for t in xrange(len(inputs)):\n", " xs[t] = np.zeros((vocab_size,1)) # encode in 1-of-k representation\n", @@ -509,14 +513,27 @@ " dbh, dby = np.zeros_like(bh), np.zeros_like(by)\n", " dhnext = np.zeros_like(hs[0])\n", " for t in reversed(xrange(len(inputs))):\n", - " pass # TODO\n", + " dy = np.copy(ps[t])\n", + " dy[targets[t]] -= 1 # backprop into y\n", + " dWhy += np.dot(dy, hs[t].T)\n", + " dby += dy\n", + " dh = np.dot(Why.T, dy) + dhnext # backprop into h\n", + " dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity\n", + " dbh += dhraw\n", + " dWxh += np.dot(dhraw, xs[t].T)\n", + " dWhh += np.dot(dhraw, hs[t-1].T)\n", + " dhnext = np.dot(Whh.T, dhraw)\n", + " \n", + " # clip to mitigate exploding gradients\n", + " for dparam in [dWxh, dWhh, dWhy, dbh, dby]:\n", + " np.clip(dparam, -5, 5, out=dparam)\n", " \n", " return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs)-1]" ] }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 21, "metadata": { "collapsed": false }, @@ -525,7 +542,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "113.313778376\n" + "113.314839864\n" ] } ], @@ -536,12 +553,41 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, + "metadata": { + "collapsed": false + }, + "outputs": [], + "source": [ + "# TODO: write the sampling code\n", + "def sample(h, seed_ix, n):\n", + " \"\"\" \n", + " sample a sequence of integers from the model \n", + " h is initial memory state, seed_ix is seed letter for first time step\n", + " n is the number of time steps to sample for\n", + " \"\"\"\n", + " x = np.zeros((vocab_size, 1))\n", + " x[seed_ix] = 1\n", + " ixes = [] # sampled indices\n", + " for t in xrange(n):\n", + " pass # TODO: run the RNN for one time step, sample from distribution\n", + " return ixes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, "metadata": { "collapsed": true }, "outputs": [], - "source": [] + "source": [ + "# TODO: write the optimization loop\n", + "# Loop over the dataset from beginning to end, sampling batches of characters seq_length long\n", + "# Call the loss function and get the gradients\n", + "# Perform a parameter update\n", + "# Sample some examples from the model" + ] } ], "metadata": {