PerlとSPVMで学ぶディープラーニング入門 - 国産AI 日本の期待の星プロジェクト

各層のパラメーターの初期値の設定方法

各層のパラメーターの初期値の設定方法方法です。

中間層における計算 - m個の入力をn個の出力に変換するでは、各層のパラメータに適当な値を使いました。

実際は、学習が進みやすいパラメータの初期値の設定方法が発見されています。

学習が進みやすいパラメータの初期値の設定方法

定数項であるバイアスの初期値

定数項であるバイアス「$b」の初期値は0にします。

my $b = [
  0,
  0,
  0,
];

重みの初期値 - Heの初期値

重みの初期値について解説します。

まず重みの初期値は、ランダムに決めるということをイメージしておいてください。残るは、どのようなランダム値を使うかです。

まず入力の個数が、m個であるとします。この場合に、平均0、標準偏差「√(2/m)」とする正規分布に従ったランダム値を使用します。これはHeの初期値と呼ばれています。

数学の言葉が出てきても、驚かないでください。ソフトウェアエンジニアにとっては、必要なことは、正規分布に従ったランダム値を返す関数を知ることだけです。

正規分布が何かということは、数学者や統計学者にゆだねればよいのです。

# 正規分布に従う乱数を求める関数
# $sigma は標準偏差、$m は平均
sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  while ($r1 == 0) { $r1 = rand(); }
  return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m;
}

# 入力数
my $m = 3;

# 重みの初期値
my $w_init_value = randn(0, sqrt(2/$m));

上記の関数を繰り返して100回出力してみましょう。どんな値がでてくるでしょうか?

use strict;
use warnings;

# 正規分布に従う乱数を求める関数
# $sigma は標準偏差、$m は平均
sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  while ($r1 == 0) { $r1 = rand(); }
  return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m;
}

# 入力数
my $m = 3;

for (my $i = 0; $i < 100; $i++) {
  # 重みの初期値
  my $w_init_value = randn(0, sqrt(2/$m));
  print "$w_init_value\n";
}

出力結果のサンプル。

0.374333645619917
-0.375115681029835
-1.17350573904798
0.895734411096277
-2.45164088502897
0.21248128329348
-0.0612350673715995
-1.58015005611323
0.505326889380243
-0.599354648099927
0.0587914154389874
1.04458353028099
0.479197565615491
0.612832356529786
0.576332381567135
0.264180405541669
0.481818440413769
-0.0832424817326749
0.428427943276402
0.160565847218012
-0.737467690504583
-0.239595285587741
0.710885587813834
-0.726817014379143
-0.55686611444891
1.05190934231497
-0.2010231522538
0.827332432871366
0.295783340949512
0.0930341726064953
-0.133498161233327
-0.115434247541999
-0.0747196705878676
0.126032374885604
2.25242259680805
-0.402090471436117
-0.29409432069849
0.0108104319532479
1.94316372084671
0.353921285897365
-0.135311049425417
1.28256925366589
0.934265904171614
-0.353899801206181
-0.0177130218202863
0.876570751158322
0.93130446050329
0.565140594046581
0.0568720224046188
-1.3106498157376
0.294591391101043
0.302177509705135
-0.418728932729381
-0.890311118269986
-0.169201690565162
-0.726949054345823
-2.15992967386058
-0.0121016476938005
1.13903710052864
-0.334434247842771
-0.293389954684819
0.251851723770236
-0.65162175207737
-0.0542221278450117
-1.11968362212673
-0.756353974237116
-0.92565540502834
0.145613214400026
0.29465043488826
-0.451796034009891
0.815640908027956
0.148295326632814
-0.0099329408191555
0.78619620747882
0.599917749077822
0.62396415920206
-0.784434692735128
0.737015631093392
-1.27206848071826
1.024780494607
0.600957083530113
0.426820457119949
-0.475642547197279
-0.342178324138676
1.35634621998836
-0.168375479156423
-0.895146617883268
-0.275289833219868
-0.00169255909802365
-0.0556849290838152
0.670938086243496
0.231285743569269
0.560481150229232
0.122633939226874
0.190196194044245
-1.18023990429147
1.49262860806553
-1.13516745715629
0.612636896807322
0.99369257431466

感じ取ってください。-1を少し超えるくらいから、1くらいまでの数が出力されています。

次は、0.1刻みで、個数を数えてみましょう。

use strict;
use warnings;

# 正規分布に従う乱数を求める関数
# $sigma は標準偏差、$m は平均
sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  while ($r1 == 0) { $r1 = rand(); }
  return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m;
}

# 入力数
my $m = 3;

my $count = {};
for (my $i = 0; $i < 100; $i++) {
  # 重みの初期値
  my $w_init_value = randn(0, sqrt(2/$m));
  
  my $range = int($w_init_value * 10) / 10;
  
  $count->{$range}++;
}

for my $range (sort { $a <=> $b } keys %$count) {
  print "Range: $range, Count:$count->{$range}\n";
}

出力結果。

Range: -2, Count:3
Range: -1.6, Count:2
Range: -1.4, Count:1
Range: -1.3, Count:1
Range: -1.2, Count:3
Range: -1.1, Count:2
Range: -1, Count:2
Range: -0.9, Count:3
Range: -0.8, Count:1
Range: -0.7, Count:2
Range: -0.6, Count:6
Range: -0.5, Count:3
Range: -0.4, Count:3
Range: -0.3, Count:1
Range: -0.2, Count:7
Range: -0.1, Count:3
Range: 0, Count:8
Range: 0.1, Count:4
Range: 0.2, Count:3
Range: 0.3, Count:8
Range: 0.4, Count:6
Range: 0.5, Count:5
Range: 0.6, Count:2
Range: 0.7, Count:3
Range: 0.8, Count:2
Range: 0.9, Count:4
Range: 1, Count:1
Range: 1.1, Count:3
Range: 1.3, Count:3
Range: 1.4, Count:3
Range: 1.5, Count:1
Range: 1.8, Count:1

0付近の個数がなんとなく多い感じですね。値の絶対値が大きくなるにしたがって、件数が0に近づいていく感じですね。これが正規分布のイメージです。

正規分布の理論はわからなくっても、分布のばらつき具合は、イメージできましたね。

重みを設定するサンプル

重みを設定して、複数の入力から複数の出力を得るサンプルを書いてみます。

中間層における計算 - m個の入力をn個の出力に変換するのループ化のサンプルを、適切な初期値を使って書いてみます。Data::Dumperで、標準エラー出力に、重みの初期値も出力しています。

use strict;
use warnings;
use Data::Dumper;

my $x = [0.5, 0.8];
my $y = [0, 0, 0];
my $x_len = @$x;
my $y_len = @$y;

my $w = [];
for (my $i = 0; $i < $x_len * $y_len; $i++) {
  my $w_init_value = randn(0, sqrt(2/$x_len));
  push @$w, $w_init_value;
}

print STDERR Dumper($w);

my $b = [
  0,
  0,
  0
];

for (my $y_index = 0; $y_index < $y_len; $y_index++) {
  my $total = 0;
  for (my $x_index = 0; $x_index < $x_len; $x_index++) {
    $total += ($w->[$x_len * $y_index + $x_index] * $x->[$x_index]);
  }
  $total +=  $b->[$y_index];
  $y->[$y_index] = $total > 0 ? $total : 0;
}

print "($y->[0], $y->[1], $y->[2])\n";

sub randn {
  my ($m, $sigma) = @_;
  my ($r1, $r2) = (rand(), rand());
  while ($r1 == 0) { $r1 = rand(); }
  return ($sigma * sqrt(-2 * log($r1)) * sin(2 * 3.14159265359 * $r2)) + $m;
}

出力結果の例。

$VAR1 = [
          '0.152110884289137',
          '2.40437412125725',
          '1.47474871999698',
          '0.30283258213298',
          '-0.274498796676187',
          '-1.27516026508991'
        ];
(1.99955473915037, 0.979640425704874, 0)
Side Bar