各層の重みとバイアスのパラメーターの初期値の設定方法
各層の重みとバイアスのパラメーターの初期値の設定方法方法です。
中間層における計算 - m個の入力をn個の出力に変換するでは、各層のパラメータに適当な値を使いました。
実際は、学習が進みやすいパラメータの初期値の設定方法が発見されています。
学習が進みやすいパラメータの初期値の設定方法
定数項であるバイアスの初期値
定数項であるバイアス「$b」の初期値は0にします。
my $b = [ 0, 0, 0, ];
重みの初期値 - Heの初期値
重みの初期値について解説します。
まず重みの初期値は、ランダムに決めるということをイメージしておいてください。残るは、どのようなランダム値を使うかです。
まず入力の個数が、m個であるとします。この場合に、平均0、標準偏差「√(2/m)」とする正規分布に従ったランダム値を使用します。これはHeの初期値と呼ばれています。
数学の言葉が出てきても、驚かないでください。ソフトウェアエンジニアにとっては、必要なことは、正規分布に従ったランダム値を返す関数を知ることだけです。
正規分布が何かということは、数学者や統計学者にゆだねればよいのです。
# 正規分布に従う乱数を求める関数 # $sigma は標準偏差、$m は平均 sub randn { my ($m, $sigma) = @_; my ($r1, $r2) = (rand(), rand()); while ($r1 == 0) { $r1 = rand(); } return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m; } # 入力数 my $m = 3; # 重みの初期値 my $w_init_value = randn(0, sqrt(2/$m));
上記の関数を繰り返して100回出力してみましょう。どんな値がでてくるでしょうか?
use strict; use warnings; # 正規分布に従う乱数を求める関数 # $sigma は標準偏差、$m は平均 sub randn { my ($m, $sigma) = @_; my ($r1, $r2) = (rand(), rand()); while ($r1 == 0) { $r1 = rand(); } return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m; } # 入力数 my $m = 3; for (my $i = 0; $i < 100; $i++) { # 重みの初期値 my $w_init_value = randn(0, sqrt(2/$m)); print "$w_init_value\n"; }
出力結果のサンプル。
0.374333645619917 -0.375115681029835 -1.17350573904798 0.895734411096277 -2.45164088502897 0.21248128329348 -0.0612350673715995 -1.58015005611323 0.505326889380243 -0.599354648099927 0.0587914154389874 1.04458353028099 0.479197565615491 0.612832356529786 0.576332381567135 0.264180405541669 0.481818440413769 -0.0832424817326749 0.428427943276402 0.160565847218012 -0.737467690504583 -0.239595285587741 0.710885587813834 -0.726817014379143 -0.55686611444891 1.05190934231497 -0.2010231522538 0.827332432871366 0.295783340949512 0.0930341726064953 -0.133498161233327 -0.115434247541999 -0.0747196705878676 0.126032374885604 2.25242259680805 -0.402090471436117 -0.29409432069849 0.0108104319532479 1.94316372084671 0.353921285897365 -0.135311049425417 1.28256925366589 0.934265904171614 -0.353899801206181 -0.0177130218202863 0.876570751158322 0.93130446050329 0.565140594046581 0.0568720224046188 -1.3106498157376 0.294591391101043 0.302177509705135 -0.418728932729381 -0.890311118269986 -0.169201690565162 -0.726949054345823 -2.15992967386058 -0.0121016476938005 1.13903710052864 -0.334434247842771 -0.293389954684819 0.251851723770236 -0.65162175207737 -0.0542221278450117 -1.11968362212673 -0.756353974237116 -0.92565540502834 0.145613214400026 0.29465043488826 -0.451796034009891 0.815640908027956 0.148295326632814 -0.0099329408191555 0.78619620747882 0.599917749077822 0.62396415920206 -0.784434692735128 0.737015631093392 -1.27206848071826 1.024780494607 0.600957083530113 0.426820457119949 -0.475642547197279 -0.342178324138676 1.35634621998836 -0.168375479156423 -0.895146617883268 -0.275289833219868 -0.00169255909802365 -0.0556849290838152 0.670938086243496 0.231285743569269 0.560481150229232 0.122633939226874 0.190196194044245 -1.18023990429147 1.49262860806553 -1.13516745715629 0.612636896807322 0.99369257431466
感じ取ってください。-1を少し超えるくらいから、1くらいまでの数が出力されています。
次は、0.1刻みで、個数を数えてみましょう。
use strict; use warnings; # 正規分布に従う乱数を求める関数 # $sigma は標準偏差、$m は平均 sub randn { my ($m, $sigma) = @_; my ($r1, $r2) = (rand(), rand()); while ($r1 == 0) { $r1 = rand(); } return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m; } # 入力数 my $m = 3; my $count = {}; for (my $i = 0; $i < 100; $i++) { # 重みの初期値 my $w_init_value = randn(0, sqrt(2/$m)); my $range = int($w_init_value * 10) / 10; $count->{$range}++; } for my $range (sort { $a <=> $b } keys %$count) { print "Range: $range, Count:$count->{$range}\n"; }
出力結果。
Range: -2, Count:3 Range: -1.6, Count:2 Range: -1.4, Count:1 Range: -1.3, Count:1 Range: -1.2, Count:3 Range: -1.1, Count:2 Range: -1, Count:2 Range: -0.9, Count:3 Range: -0.8, Count:1 Range: -0.7, Count:2 Range: -0.6, Count:6 Range: -0.5, Count:3 Range: -0.4, Count:3 Range: -0.3, Count:1 Range: -0.2, Count:7 Range: -0.1, Count:3 Range: 0, Count:8 Range: 0.1, Count:4 Range: 0.2, Count:3 Range: 0.3, Count:8 Range: 0.4, Count:6 Range: 0.5, Count:5 Range: 0.6, Count:2 Range: 0.7, Count:3 Range: 0.8, Count:2 Range: 0.9, Count:4 Range: 1, Count:1 Range: 1.1, Count:3 Range: 1.3, Count:3 Range: 1.4, Count:3 Range: 1.5, Count:1 Range: 1.8, Count:1
0付近の個数がなんとなく多い感じですね。値の絶対値が大きくなるにしたがって、件数が0に近づいていく感じですね。これが正規分布のイメージです。
正規分布の理論はわからなくっても、分布のばらつき具合は、イメージできましたね。
重みを設定するサンプル
重みを設定して、複数の入力から複数の出力を得るサンプルを書いてみます。
中間層における計算 - m個の入力をn個の出力に変換するのループ化のサンプルを、適切な初期値を使って書いてみます。Data::Dumperで、標準エラー出力に、重みの初期値も出力しています。
use strict; use warnings; use Data::Dumper; my $x = [0.5, 0.8]; my $y = [0, 0, 0]; my $x_len = @$x; my $y_len = @$y; my $w = []; for (my $i = 0; $i < $x_len * $y_len; $i++) { my $w_init_value = randn(0, sqrt(2/$x_len)); push @$w, $w_init_value; } print STDERR Dumper($w); my $b = [ 0, 0, 0 ]; for (my $y_index = 0; $y_index < $y_len; $y_index++) { my $total = 0; for (my $x_index = 0; $x_index < $x_len; $x_index++) { $total += ($w->[$x_len * $y_index + $x_index] * $x->[$x_index]); } $total += $b->[$y_index]; $y->[$y_index] = $total > 0 ? $total : 0; } print "($y->[0], $y->[1], $y->[2])\n"; sub randn { my ($m, $sigma) = @_; my ($r1, $r2) = (rand(), rand()); while ($r1 == 0) { $r1 = rand(); } return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m; }
出力結果の例。
$VAR1 = [ '0.152110884289137', '2.40437412125725', '1.47474871999698', '0.30283258213298', '-0.274498796676187', '-1.27516026508991' ]; (1.99955473915037, 0.979640425704874, 0)
重みの初期値 - Xivierの初期値
Heの初期値以外に、Xivierの初期値と呼ばれる、平均0、標準偏差「1/sqrt(n)」である正規分布から重みをランダムに生成する方法があります。
活性化関数がReLU以外の場合に、利用されるようです。