draft 2
Esse commit está contido em:
+152
-106
@@ -8,12 +8,13 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np"
|
||||
"import numpy as np\n",
|
||||
"np.random.seed(1337)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -36,7 +37,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -48,7 +49,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -59,7 +60,7 @@
|
||||
"86"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
@@ -70,7 +71,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 57,
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -92,7 +93,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -115,7 +116,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"execution_count": 7,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -145,7 +146,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -162,7 +163,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -171,8 +172,9 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[ 0.01971421 -0.00533177 0.00142983 -0.00985955 0.00248728 -0.01736378\n",
|
||||
" 0.01687631 -0.01756024 0.00207365 -0.00083882]\n"
|
||||
"[ -2.19597294e-03 -1.11958403e-02 -7.96573514e-03 -8.29025315e-03\n",
|
||||
" -1.69046852e-02 7.16480643e-03 -1.08628224e-05 1.38136549e-03\n",
|
||||
" 8.91231078e-03 1.22455580e-02]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -185,7 +187,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"execution_count": 10,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -194,30 +196,30 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[ 7.60384659e-04 4.52720278e-04 -6.57143068e-04 4.59711552e-04\n",
|
||||
" 2.16313976e-04 -2.73453571e-04 6.45337766e-04 6.58329123e-06\n",
|
||||
" 4.52167436e-04 -6.46317115e-04 4.02199638e-04 1.99170349e-04\n",
|
||||
" 4.25310887e-04 3.50158091e-04 -6.04819130e-04 2.26877413e-04\n",
|
||||
" -1.94723297e-04 5.50639249e-04 -1.71656881e-04 -2.99860880e-04\n",
|
||||
" 6.07620521e-04 5.70905998e-06 4.26038734e-04 1.01454914e-04\n",
|
||||
" -3.65278849e-04 5.90686240e-04 1.59204750e-04 4.08708316e-04\n",
|
||||
" 2.56014525e-04 1.39087410e-06 -6.90950824e-04 7.18269467e-04\n",
|
||||
" 2.61719276e-04 4.58528196e-04 3.86335255e-04 2.70955381e-04\n",
|
||||
" -1.48929417e-04 -1.83764339e-04 -6.36673024e-04 4.00478015e-04\n",
|
||||
" -8.69493168e-05 3.73431457e-04 9.14721284e-05 4.38316766e-05\n",
|
||||
" -6.83760980e-05 1.52588617e-05 1.74420973e-04 -1.47075009e-04\n",
|
||||
" -3.62079011e-04 5.05874063e-05 6.20919471e-04 7.19441792e-05\n",
|
||||
" -1.88333333e-04 -2.98096664e-04 4.43592043e-04 -3.76377143e-04\n",
|
||||
" 2.10335682e-04 -6.41359408e-04 -4.54105970e-05 -7.58974605e-05\n",
|
||||
" 5.47988450e-05 -2.77672190e-04 2.07213272e-05 2.13088016e-04\n",
|
||||
" -5.07786842e-04 1.27347437e-04 4.24719496e-04 -2.13192324e-04\n",
|
||||
" -2.05694776e-04 -1.43861034e-04 4.50169861e-04 2.74675798e-04\n",
|
||||
" 1.31264147e-04 3.29399476e-04 -1.41866715e-04 -7.91514010e-05\n",
|
||||
" -2.96356884e-05 -3.85666773e-04 2.79322785e-04 6.82910193e-04\n",
|
||||
" -2.05519751e-04 -4.25825715e-04 -8.47060105e-05 2.66896920e-04\n",
|
||||
" 1.17914625e-04 -4.00869517e-04 2.30587114e-04 1.24662656e-05\n",
|
||||
" 2.43743504e-04 4.57417182e-04 -2.78926664e-05 1.23560813e-05\n",
|
||||
" -3.55199314e-04]\n"
|
||||
"[ -1.86941995e-04 5.75478992e-05 5.45725372e-05 -4.12423671e-04\n",
|
||||
" 3.06201649e-04 -5.06931560e-04 3.19806236e-04 -3.51163814e-05\n",
|
||||
" -3.38972137e-04 -2.39857442e-04 -8.01432457e-05 -1.04391992e-04\n",
|
||||
" -3.15399258e-04 1.48454913e-05 4.55998798e-05 1.28931375e-04\n",
|
||||
" 4.04072882e-04 -5.75504844e-04 1.57533734e-04 -4.08258042e-04\n",
|
||||
" -7.41755605e-05 -7.76690551e-05 1.87837950e-04 -2.79910204e-04\n",
|
||||
" -7.96190232e-04 -5.79876645e-05 1.26730230e-04 3.95894081e-04\n",
|
||||
" 2.76955496e-04 2.59637379e-05 -4.61015050e-04 -5.14636280e-04\n",
|
||||
" -2.56714611e-04 4.94377196e-04 -3.64149466e-04 -4.26364575e-04\n",
|
||||
" -1.25515821e-04 9.29132184e-06 4.88944584e-05 -6.31681837e-04\n",
|
||||
" -1.75121523e-05 9.51597096e-06 3.08227238e-04 1.29694758e-04\n",
|
||||
" 1.72591168e-04 -5.87989566e-05 2.74250548e-04 5.90320270e-05\n",
|
||||
" 1.64825617e-04 3.11431559e-04 3.48582321e-04 -4.19733486e-05\n",
|
||||
" -2.77062543e-04 -1.10111861e-05 -1.33642742e-05 -5.99994639e-05\n",
|
||||
" 6.07084027e-04 3.51784289e-04 -3.89265870e-05 -1.03640935e-05\n",
|
||||
" 2.94412266e-04 7.75998698e-05 -2.85416619e-04 4.26053337e-04\n",
|
||||
" -3.47809927e-04 -4.67567746e-04 -6.81980155e-05 -1.09923565e-06\n",
|
||||
" 6.26504669e-05 -2.76671607e-04 2.96818053e-04 1.87011604e-05\n",
|
||||
" 2.98832018e-04 -2.65531746e-04 -5.58804832e-04 -3.07801824e-04\n",
|
||||
" 3.75513314e-04 4.76734817e-04 -2.96633044e-04 -3.06249150e-04\n",
|
||||
" 1.22093579e-05 -1.43195488e-04 -1.58954120e-04 -1.08600298e-04\n",
|
||||
" 1.52122727e-04 -6.54969615e-04 1.60857638e-04 -9.11628405e-04\n",
|
||||
" 3.67740402e-05 2.16278686e-04 -1.98860111e-04 4.08539302e-05\n",
|
||||
" -1.47124899e-04]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -229,7 +231,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 63,
|
||||
"execution_count": 11,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -238,22 +240,22 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[ 0.01076018 0.01075687 0.01074493 0.01075694 0.01075432 0.01074906\n",
|
||||
" 0.01075894 0.01075207 0.01075686 0.01074505 0.01075632 0.01075414\n",
|
||||
" 0.01075657 0.01075576 0.0107455 0.01075444 0.0107499 0.01075792\n",
|
||||
" 0.01075015 0.01074877 0.01075853 0.01075206 0.01075658 0.01075309\n",
|
||||
" 0.01074807 0.01075835 0.01075371 0.01075639 0.01075475 0.01075201\n",
|
||||
" 0.01074457 0.01075972 0.01075481 0.01075693 0.01075615 0.01075491\n",
|
||||
" 0.0107504 0.01075002 0.01074515 0.0107563 0.01075106 0.01075601\n",
|
||||
" 0.01075298 0.01075247 0.01075126 0.01075216 0.01075387 0.01075042\n",
|
||||
" 0.0107481 0.01075254 0.01075867 0.01075277 0.01074997 0.01074879\n",
|
||||
" 0.01075677 0.01074795 0.01075426 0.0107451 0.01075151 0.01075118\n",
|
||||
" 0.01075259 0.01074901 0.01075222 0.01075429 0.01074654 0.01075337\n",
|
||||
" 0.01075656 0.0107497 0.01074978 0.01075045 0.01075684 0.01075495\n",
|
||||
" 0.01075341 0.01075554 0.01075047 0.01075115 0.01075168 0.01074785\n",
|
||||
" 0.010755 0.01075934 0.01074979 0.01074742 0.01075109 0.01075487\n",
|
||||
" 0.01075326 0.01074769 0.01075448 0.01075213 0.01075462 0.01075692\n",
|
||||
" 0.0107517 0.01075213 0.01074818]\n",
|
||||
"[ 0.01075121 0.01075383 0.0107538 0.01074878 0.01075651 0.01074777\n",
|
||||
" 0.01075666 0.01075284 0.01074957 0.01075064 0.01075235 0.01075209\n",
|
||||
" 0.01074982 0.01075338 0.01075371 0.0107546 0.01075756 0.01074703\n",
|
||||
" 0.01075491 0.01074883 0.01075242 0.01075238 0.01075524 0.01075021\n",
|
||||
" 0.01074466 0.01075259 0.01075458 0.01075747 0.01075619 0.01075349\n",
|
||||
" 0.01074826 0.01074768 0.01075046 0.01075853 0.0107493 0.01074863\n",
|
||||
" 0.01075187 0.01075332 0.01075374 0.01074643 0.01075303 0.01075332\n",
|
||||
" 0.01075653 0.01075461 0.01075507 0.01075258 0.01075617 0.01075385\n",
|
||||
" 0.01075499 0.01075656 0.01075696 0.01075276 0.01075024 0.0107531\n",
|
||||
" 0.01075307 0.01075257 0.01075975 0.010757 0.0107528 0.0107531\n",
|
||||
" 0.01075638 0.01075405 0.01075015 0.0107578 0.01074948 0.01074819\n",
|
||||
" 0.01075248 0.0107532 0.01075389 0.01075024 0.01075641 0.01075342\n",
|
||||
" 0.01075643 0.01075036 0.01074721 0.01074991 0.01075725 0.01075834\n",
|
||||
" 0.01075003 0.01074992 0.01075335 0.01075168 0.01075151 0.01075205\n",
|
||||
" 0.01075485 0.01074617 0.01075495 0.01074342 0.01075361 0.01075554\n",
|
||||
" 0.01075108 0.01075365 0.01075163]\n",
|
||||
"probabilities sum to 1.0\n"
|
||||
]
|
||||
}
|
||||
@@ -267,7 +269,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 64,
|
||||
"execution_count": 12,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -276,7 +278,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"probability assigned to the correct next character is right now: 0.0107565780746\n"
|
||||
"probability assigned to the correct next character is right now: 0.0107552356218\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -286,7 +288,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 72,
|
||||
"execution_count": 13,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -295,7 +297,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"the cross-entropy (softmax) loss is 4.53223779763\n"
|
||||
"the cross-entropy (softmax) loss is 4.53236260838\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -306,7 +308,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 74,
|
||||
"execution_count": 14,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -315,25 +317,25 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[ 0.01076018 0.01075687 0.01074493 0.01075694 0.01075432 0.01074906\n",
|
||||
" 0.01075894 0.01075207 0.01075686 0.01074505 0.01075632 0.01075414\n",
|
||||
" 0.01075657 0.01075576 0.0107455 0.01075444 0.0107499 0.01075792\n",
|
||||
" 0.01075015 0.01074877 0.01075853 0.01075206 -0.98924342 0.01075309\n",
|
||||
" 0.01074807 0.01075835 0.01075371 0.01075639 0.01075475 0.01075201\n",
|
||||
" 0.01074457 0.01075972 0.01075481 0.01075693 0.01075615 0.01075491\n",
|
||||
" 0.0107504 0.01075002 0.01074515 0.0107563 0.01075106 0.01075601\n",
|
||||
" 0.01075298 0.01075247 0.01075126 0.01075216 0.01075387 0.01075042\n",
|
||||
" 0.0107481 0.01075254 0.01075867 0.01075277 0.01074997 0.01074879\n",
|
||||
" 0.01075677 0.01074795 0.01075426 0.0107451 0.01075151 0.01075118\n",
|
||||
" 0.01075259 0.01074901 0.01075222 0.01075429 0.01074654 0.01075337\n",
|
||||
" 0.01075656 0.0107497 0.01074978 0.01075045 0.01075684 0.01075495\n",
|
||||
" 0.01075341 0.01075554 0.01075047 0.01075115 0.01075168 0.01074785\n",
|
||||
" 0.010755 0.01075934 0.01074979 0.01074742 0.01075109 0.01075487\n",
|
||||
" 0.01075326 0.01074769 0.01075448 0.01075213 0.01075462 0.01075692\n",
|
||||
" 0.0107517 0.01075213 0.01074818]\n",
|
||||
"sum of dy is 3.26128013484e-16\n",
|
||||
"the gradient for the correct character (t) is: -0.989243421925\n",
|
||||
"the gradient for the character (a) is: 0.0107544758895\n"
|
||||
"[ 0.01075121 0.01075383 0.0107538 0.01074878 0.01075651 0.01074777\n",
|
||||
" 0.01075666 0.01075284 0.01074957 0.01075064 0.01075235 0.01075209\n",
|
||||
" 0.01074982 0.01075338 0.01075371 0.0107546 0.01075756 0.01074703\n",
|
||||
" 0.01075491 0.01074883 0.01075242 0.01075238 -0.98924476 0.01075021\n",
|
||||
" 0.01074466 0.01075259 0.01075458 0.01075747 0.01075619 0.01075349\n",
|
||||
" 0.01074826 0.01074768 0.01075046 0.01075853 0.0107493 0.01074863\n",
|
||||
" 0.01075187 0.01075332 0.01075374 0.01074643 0.01075303 0.01075332\n",
|
||||
" 0.01075653 0.01075461 0.01075507 0.01075258 0.01075617 0.01075385\n",
|
||||
" 0.01075499 0.01075656 0.01075696 0.01075276 0.01075024 0.0107531\n",
|
||||
" 0.01075307 0.01075257 0.01075975 0.010757 0.0107528 0.0107531\n",
|
||||
" 0.01075638 0.01075405 0.01075015 0.0107578 0.01074948 0.01074819\n",
|
||||
" 0.01075248 0.0107532 0.01075389 0.01075024 0.01075641 0.01075342\n",
|
||||
" 0.01075643 0.01075036 0.01074721 0.01074991 0.01075725 0.01075834\n",
|
||||
" 0.01075003 0.01074992 0.01075335 0.01075168 0.01075151 0.01075205\n",
|
||||
" 0.01075485 0.01074617 0.01075495 0.01074342 0.01075361 0.01075554\n",
|
||||
" 0.01075108 0.01075365 0.01075163]\n",
|
||||
"sum of dy is -4.23272528138e-16\n",
|
||||
"the gradient for the correct character (t) is: -0.989244764378\n",
|
||||
"the gradient for the character (a) is: 0.0107549454461\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -349,7 +351,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 88,
|
||||
"execution_count": 15,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -359,17 +361,18 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"the hidden vector activations were:\n",
|
||||
"[ 0.01971421 -0.00533177 0.00142983 -0.00985955 0.00248728 -0.01736378\n",
|
||||
" 0.01687631 -0.01756024 0.00207365 -0.00083882]\n",
|
||||
"[ -2.19597294e-03 -1.11958403e-02 -7.96573514e-03 -8.29025315e-03\n",
|
||||
" -1.69046852e-02 7.16480643e-03 -1.08628224e-05 1.38136549e-03\n",
|
||||
" 8.91231078e-03 1.22455580e-02]\n",
|
||||
"the gradients are:\n",
|
||||
"[ 0.01411266 -0.00086765 -0.00501521 0.00024313 0.00070486 0.02431426\n",
|
||||
" -0.00599582 0.00532303 -0.01018747 -0.00155986]\n",
|
||||
"[ 0.010636 -0.00556332 0.01572234 -0.01013355 0.01024101 0.00893933\n",
|
||||
" 0.01156068 -0.0035879 -0.00304811 -0.00761244]\n",
|
||||
"the gradients dWhy have size: (93, 10)\n",
|
||||
"a small sample is:\n",
|
||||
"[[ 2.12128361e-04 -5.73708201e-05 1.53852052e-05 -1.06090457e-04]\n",
|
||||
" [ 2.12063107e-04 -5.73531718e-05 1.53804724e-05 -1.06057822e-04]\n",
|
||||
" [ 2.11827876e-04 -5.72895529e-05 1.53634117e-05 -1.05940178e-04]\n",
|
||||
" [ 2.12064589e-04 -5.73535728e-05 1.53805799e-05 -1.06058563e-04]]\n"
|
||||
"[[ -2.36093564e-05 -1.20368781e-04 -8.56412557e-05 -8.91302155e-05]\n",
|
||||
" [ -2.36151293e-05 -1.20398213e-04 -8.56621967e-05 -8.91520096e-05]\n",
|
||||
" [ -2.36150591e-05 -1.20397855e-04 -8.56619418e-05 -8.91517444e-05]\n",
|
||||
" [ -2.36040335e-05 -1.20341643e-04 -8.56219473e-05 -8.91101206e-05]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -389,7 +392,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 90,
|
||||
"execution_count": 16,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -399,10 +402,10 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"small sample of Whh:\n",
|
||||
"[[ 0.00013373 -0.01329709 0.01799412 0.01353991]\n",
|
||||
" [-0.00427217 -0.01052258 0.00820703 0.00452433]\n",
|
||||
" [-0.00511801 -0.0094128 0.00029513 0.0058532 ]\n",
|
||||
" [-0.00132356 0.00330526 0.00737467 -0.00426905]]\n"
|
||||
"[[-0.0181644 -0.00730641 0.00068086 0.00043998]\n",
|
||||
" [-0.00706002 -0.00657655 0.0100175 0.00198405]\n",
|
||||
" [ 0.0175388 0.01643369 -0.00161146 -0.01413518]\n",
|
||||
" [ 0.00812143 -0.00105367 0.00944295 0.01148334]]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -420,7 +423,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 95,
|
||||
"execution_count": 17,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -438,7 +441,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 96,
|
||||
"execution_count": 18,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -447,10 +450,10 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"probability assigned to the correct next character was: 0.0107565780746\n",
|
||||
"probability assigned to the correct next character is now: 0.0118772830163\n",
|
||||
"the cross-entropy (softmax) loss was 4.53223779763\n",
|
||||
"the loss is now 4.43312769353\n"
|
||||
"probability assigned to the correct next character was: 0.0107552356218\n",
|
||||
"probability assigned to the correct next character is now: 0.0118750046895\n",
|
||||
"the cross-entropy (softmax) loss was 4.53236260838\n",
|
||||
"the loss is now 4.43331953415\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -468,7 +471,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 98,
|
||||
"execution_count": 19,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
@@ -479,13 +482,13 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 101,
|
||||
"execution_count": 20,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# putting it together\n",
|
||||
"# putting it together with loops\n",
|
||||
"def lossFun(inputs, targets, hprev):\n",
|
||||
" \"\"\"\n",
|
||||
" inputs,targets are both list of integers.\n",
|
||||
@@ -495,6 +498,7 @@
|
||||
" xs, hs, ys, ps = {}, {}, {}, {}\n",
|
||||
" hs[-1] = np.copy(hprev)\n",
|
||||
" loss = 0\n",
|
||||
" \n",
|
||||
" # forward pass\n",
|
||||
" for t in xrange(len(inputs)):\n",
|
||||
" xs[t] = np.zeros((vocab_size,1)) # encode in 1-of-k representation\n",
|
||||
@@ -509,14 +513,27 @@
|
||||
" dbh, dby = np.zeros_like(bh), np.zeros_like(by)\n",
|
||||
" dhnext = np.zeros_like(hs[0])\n",
|
||||
" for t in reversed(xrange(len(inputs))):\n",
|
||||
" pass # TODO\n",
|
||||
" dy = np.copy(ps[t])\n",
|
||||
" dy[targets[t]] -= 1 # backprop into y\n",
|
||||
" dWhy += np.dot(dy, hs[t].T)\n",
|
||||
" dby += dy\n",
|
||||
" dh = np.dot(Why.T, dy) + dhnext # backprop into h\n",
|
||||
" dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity\n",
|
||||
" dbh += dhraw\n",
|
||||
" dWxh += np.dot(dhraw, xs[t].T)\n",
|
||||
" dWhh += np.dot(dhraw, hs[t-1].T)\n",
|
||||
" dhnext = np.dot(Whh.T, dhraw)\n",
|
||||
" \n",
|
||||
" # clip to mitigate exploding gradients\n",
|
||||
" for dparam in [dWxh, dWhh, dWhy, dbh, dby]:\n",
|
||||
" np.clip(dparam, -5, 5, out=dparam)\n",
|
||||
" \n",
|
||||
" return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs)-1]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 104,
|
||||
"execution_count": 21,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
@@ -525,7 +542,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"113.313778376\n"
|
||||
"113.314839864\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -536,12 +553,41 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 22,
|
||||
"metadata": {
|
||||
"collapsed": false
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# TODO: write the sampling code\n",
|
||||
"def sample(h, seed_ix, n):\n",
|
||||
" \"\"\" \n",
|
||||
" sample a sequence of integers from the model \n",
|
||||
" h is initial memory state, seed_ix is seed letter for first time step\n",
|
||||
" n is the number of time steps to sample for\n",
|
||||
" \"\"\"\n",
|
||||
" x = np.zeros((vocab_size, 1))\n",
|
||||
" x[seed_ix] = 1\n",
|
||||
" ixes = [] # sampled indices\n",
|
||||
" for t in xrange(n):\n",
|
||||
" pass # TODO: run the RNN for one time step, sample from distribution\n",
|
||||
" return ixes\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"metadata": {
|
||||
"collapsed": true
|
||||
},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
"source": [
|
||||
"# TODO: write the optimization loop\n",
|
||||
"# Loop over the dataset from beginning to end, sampling batches of characters seq_length long\n",
|
||||
"# Call the loss function and get the gradients\n",
|
||||
"# Perform a parameter update\n",
|
||||
"# Sample some examples from the model"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário