Perlで学ぶディープラーニング入門

クロスエントロピー誤差を求める - 損失関数

クロスエントロピー誤差を求める関数をPerlで書いてみましょう。クロスエントロピー誤差は、出力結果と期待される出力結果(正しい答え)の誤差を計算するために使用する損失関数の一つです。

use strict;
use warnings;

# クロスエントロピー誤差
sub cross_entropy_cost {
  my ($outputs, $desired_outputs) = @_;
  
  if (@$outputs != @$desired_outputs) {
    die "Outputs length is different from Desired length";
  }
  
  my $cross_entropy_cost = 0;
  
  for (my $i = 0; $i < @$outputs; $i++) {
    $cross_entropy_cost += -$desired_outputs->[$i] * log($outputs->[$i]) - (1 - $desired_outputs->[$i]) * log(1 - $outputs->[$i]);
  }
  
  return $cross_entropy_cost;
}

my $outputs = [0.7, 0.2, 0.1];
my $desired_outputs = [1, 0, 0];
my $cross_entropy_cost = cross_entropy_cost($outputs, $desired_outputs);

print "$cross_entropy_cost\n";

ディープラーニングでは、損失関数で求められた誤差が小さくなるように、重みとバイアスのパラメーターが調整されていきます。

パターン認識の問題における損失関数としては、偏微分の形が難しく計算が複雑になるので二乗和誤差より、クロスエントロピー誤差を使うほうが、望ましいようです。

クロスエントロピー誤差の偏微分関数

クロスエントロピー誤差の偏微分関数をPerlで書いてみましょう。損失関数の偏微分関数は、逆誤伝播法を実装するときに必要になります。

損失関数の偏微分の戻り値は、配列になることに注意してください。戻り値が一つの値である損失関数と異なります。

use strict;
use warnings;

sub cross_entropy_cost_delta {
  my ($outputs, $activate_outputs, $desired_outputs) = @_;

  if (@$activate_outputs != @$desired_outputs) {
    die "Outputs length is different from Desired length";
  }
  
  my $cross_entropy_cost_delta = [];
  for (my $i = 0; $i < @$activate_outputs; $i++) {
    $cross_entropy_cost_delta->[$i] = $activate_outputs->[$i] - $desired_outputs->[$i];
  }
  
  return $cross_entropy_cost_delta;
}

my $activate_outputs = [0.6, 0, 0.2];
my $desired_outputs = [1, 0, 0];
my $cross_entropy_cost = cross_entropy_cost_delta(undef, $activate_outputs, $desired_outputs);

print "@$cross_entropy_cost\n";

ソフトウェアエンジニアにとっての、偏微分のイメージは、個々の入力を少し変化させた場合に対する、出力(損失関数の値)の変化の割合だと考えてください。

最初の入力値を0.01増やしてみください。出力は、0.3増えました。傾きは「0.3 / 0.01」で、30です。

次の入力値を0.01増やしてみください。出力は、0.5増えました。傾きは「0.5 / 0.01」で、50です。

偏微分という難しい言葉に脳みそがやられてしまうかもしれませんが、実は簡単なことなのです。

「入力の変化に対する出力の変化の割合」という意味を、ピッタリと表現する言葉が採用されていれば、数学はもっと簡単なものだったかもしれませんね。

Perlで学ぶディープラーニング入門のご紹介

Side Bar