確率的勾配降下法(SGD) - 重みとバイアスのパラメータの更新

確率的勾配降下法(SGD)は、重みとバイアスのパラメータを更新するための一つの手法です。それぞれの重みとバイアスのパラメータの微小変化に対する損失関数の微小変化が求めた後に、この値を、学習率を考慮して、現在の重みとバイアスのパラメータから減算します。訓練データをランダムで並び変えることが特徴です。

ディープラーニングでは、それぞれの重みとバイアスのパラメータの微小変化に対する損失関数の微小変化は、逆誤伝播法によって求めますが、勾配降下法は、逆誤伝播法を実行した後に行われる処理になります。

勾配というのは、それぞれの重みとバイアスのパラメータの微小変化に対する損失関数の微小変化のことだとイメージしてください。

降下というのは、減算することだとイメージしてください。

下の図で、重みとパラメータが、どの位置にあるかを図示しました。入出力と損失関数についても、図示しています。

● ● ● ● ●

5個の入力から7個の出力へ変換する関数(重みパラメータは7行5列の行列とバイアスパラメーターは5要素の列ベクトル)

● ● ● ● ● ● ●

7個の入力から4個の出力へ変換する関数(重みパラメータは4行7列の行列とバイアスパラメーターは4要素の列ベクトル)

● ● ● ●

損失関数を適用(4つの入力は1つの誤差を表す指標へ)

●

ひとつの重みやバイアスのパラメータを少しだけ増やしたときに、損失関数の値が少し増えたと考えてください。この値を使って、パラメータの微小変化に対する損失関数の微小変化「損失関数の微小変化 / パラメータの微小変化」が求まります。これは、手動で求めましたが、ディープラーニングでは、逆誤伝播法というアルゴリズムを使って求めます。

勾配降下法は、更新されたパラメータを、いくつまとめて更新するか(バッチサイズ)ということも考慮します。勾配降下法では、訓練データは、ランダムにシャッフルすると、学習が止まりにくくなります。

勾配降下法のソースコード

勾配降下法の部分だけを取り出したPerlのソースコードです。ミニバッチサイズで、それぞれの重みとバイアスのパラメータの傾きの合計を求めて、学習率とミニバッチサイズを考慮して、重みとバイアスを更新します。

# エポックの回数だけ訓練セットを実行
for (my $epoch_index = 0; $epoch_index < $epoch_count; $epoch_index++) {
  
  # 訓練データのインデックスをシャッフル(ランダムに学習させた方が汎用化するらしい)
  my @training_data_indexes_shuffle = shuffle @training_data_indexes;
  
  my $count = 0;
  
  # ミニバッチサイズ単位で学習
  my $backprop_count = 0;

  while (my @indexed_for_mini_batch = splice(@training_data_indexes_shuffle, 0, $mini_batch_size)) {
    
    # ミニバッチにおける各変換関数のバイアスの傾きの合計とミニバッチにおける各変換関数の重みの傾きの合計を0で初期化
    for (my $m_to_n_func_index = 0; $m_to_n_func_index < @$m_to_n_func_mini_batch_infos; $m_to_n_func_index++) {
      my $m_to_n_func_info = $m_to_n_func_infos->[$m_to_n_func_index];
      my $biases = $m_to_n_func_info->{biases};
      my $weights_mat = $m_to_n_func_info->{weights_mat};
      
      # ミニバッチにおける各変換関数のバイアスの傾きの合計を0で初期化して作成
      $m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{biase_grad_totals} = array_new_zero(scalar @{$m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{biase_grad_totals}});

      # ミニバッチにおける各変換関数の重みの傾きの合計を0で初期化して作成
      $m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{weight_grad_totals_mat}{values} = array_new_zero(scalar @{$m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{weight_grad_totals_mat}{values}});
    }
    
    for my $training_data_index (@indexed_for_mini_batch) {
      # 逆誤伝播法を使って重みとバイアスの損失関数に関する傾きを取得
      my $m_to_n_func_grad_infos = backprop($m_to_n_func_infos, $mnist_train_image_info, $mnist_train_label_info, $training_data_index);
      
      # バイアスの損失関数に関する傾き
      my $biase_grads = $m_to_n_func_grad_infos->{biases};
      
      # 重みの損失関数に関する傾き
      my $weight_grads_mat = $m_to_n_func_grad_infos->{weights_mat};

      # ミニバッチにおける各変換関数のバイアスの傾きの合計とミニバッチにおける各変換関数の重みの傾きを加算
      for (my $m_to_n_func_index = 0; $m_to_n_func_index < @$m_to_n_func_mini_batch_infos; $m_to_n_func_index++) {
        my $m_to_n_func_info = $m_to_n_func_infos->[$m_to_n_func_index];
        
        # ミニバッチにおける各変換関数のバイアスの傾きを加算
        array_add_inplace($m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{biase_grad_totals}, $biase_grads->[$m_to_n_func_index]);

        # ミニバッチにおける各変換関数の重みの傾きを加算
        array_add_inplace($m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{weight_grad_totals_mat}{values}, $weight_grads_mat->[$m_to_n_func_index]{values});
      }
    }

    # 各変換関数のバイアスと重みをミニバッチの傾きの合計を使って更新
    for (my $m_to_n_func_index = 0; $m_to_n_func_index < @$m_to_n_func_infos; $m_to_n_func_index++) {
      
      # 各変換関数のバイアスを更新(学習率を考慮し、ミニバッチ数で割る)
      update_params($m_to_n_func_infos->[$m_to_n_func_index]{biases}, $m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{biase_grad_totals}, $learning_rate, $mini_batch_size);
      
      # 各変換関数の重みを更新(学習率を考慮し、傾きの合計をミニバッチ数で、ミニバッチ数で割る)
      update_params($m_to_n_func_infos->[$m_to_n_func_index]{weights_mat}{values}, $m_to_n_func_mini_batch_infos->[$m_to_n_func_index]{weight_grad_totals_mat}{values}, $learning_rate, $mini_batch_size);
    }
  }
}

# 配列の各要素の和を最初の引数に足す
sub array_add_inplace {
  my ($nums1, $nums2) = @_;
  
  if (@$nums1 != @$nums2) {
    die "Array length is diffent";
  }
  
  for (my $i = 0; $i < @$nums1; $i++) {
    $nums1->[$i] += $nums2->[$i];
  }
}

# 学習率とミニバッチ数を考慮してパラメーターを更新
sub update_params {
  my ($params, $param_grads, $learning_rate, $mini_batch_size) = @_;
  
  for (my $param_index = 0; $param_index < @$params; $param_index++) {
    my $update_diff = ($learning_rate / $mini_batch_size) * $param_grads->[$param_index];
    $params->[$param_index] -= $update_diff;
  }
}

# 配列を0で初期化して作成
sub array_new_zero {
  my ($length) = @_;
  
  my $nums = [(0) x $length];
  
  return $nums;
}
Perlテキスト処理のエッセンス
  • 初級者向け・テキスト処理と正規表現の基本をマスター
業務に役立つPerl
  • 実務者向け・ログ解析など日本語を含むテキスト処理の実践!
Perlクラブ
  • 仲間と出会い
    ゆとりあるエンジニアライフを送る