オッサンはDesktopが好き

自作PCや機械学習、自転車のことを脈絡無く書きます

mnistを重み付きラベルで学習してみる

mnistを学習させるとき、通常、
この画像のラベルは下記のようにします

f:id:changlikesdesktop:20190317055927j:plain

Number Label
0 0
1 0
2
3 0
4 0
5 0
6 0
7 0
8 0
9 0

これを敢えて、重み付きのラベルで学習させてみます

Number Label
0 0.2
1 0.2
2 0.2
3 0.8
4 0.2
5 0.2
6 0.2
7 0.2
8 0.2
9 0.2

傷位置検出で取り入れようとしている*1この手法が正しいか、
評価するためです

ソースをgithubにアップします*2

損失関数を重み付きのラベルでクロスエントリピーを計算しているところがポイントです

def loss(output, weight):
    soft = tf.nn.softmax(output)
    xentropy = - tf.reduce_sum(weight * tf.log(soft), 1)
    loss = tf.reduce_mean(xentropy)
    return loss

100サイクルで学習させてみました
活性化にsigmoidを使っていることもあって少し精度が低かったんですが、
ちゃんと学習できました

f:id:changlikesdesktop:20190401054641p:plain

通常の01ラベルと、出力を比較してみます

f:id:changlikesdesktop:20190401055647p:plain

01ラベルでは、正解の3が一番大きな値になるものの、
その他の値でもそこそこの確率が出力されていました
一方、重み付きラベルでは、3以外の出力は等しく低く、
横ばいになっています
狙い通りの動きですね