オッサンはDesktopが好き

自作PCや機械学習、自転車のことを脈絡無く書きます

Deep Learningによる傷位置検出、損失関数

こういうの↓を作ってみようと思います

f:id:changlikesdesktop:20190304045701p:plain:w200
*1

mnist問題をやっと動かせるようになったレベルの僕には
敷居が高そうですが、、、

画像のダウンロード

上記の画像は10年位前に行われたコンペティションで使われたもので、
ここ*2からダウンロード出来ます
コンペは種類の異なる6クラスの画像全てを使って行われた様ですが、
今回は上記で示されているClass 1のみを使うことにします

基本方針

mnistは、モデルに画像を入力して分類結果を出力する動きでした
正確に言うと、ニューラルネットワークが出力するのは
分類項目についての確率密度です
f:id:changlikesdesktop:20190304055226p:plain:w200

傷位置検出も同じ考え方で構築します
つまり、画像の座標情報に対する確率密度を出力します
傷がある場所(ある確率の高い場所)の数値を大きくするという事ですね
f:id:changlikesdesktop:20190304055200p:plain:w400

損失関数

mnistのような画像分類と傷位置検出の大きな違いの一つが、
損失関数の立て方だと思います
画像分類で一般的に使われてるcross entropyは次の数式で表されます

 \displaystyle
  E = - \sum_{k}^{} t_k log y_k

正解ラベルを表す t_kは、正解ならば1、不正解ならば0となるone_hot表現です
色々と試行錯誤をする中で、このone_hot表現が傷位置検出には
合わないことが解ってきました
one_hot表現では、正解の場所(1が掛けられる)以外の出力値には
0が掛けられて無視されます
つまり、↓の2つの出力は同じ意味になるのです

f:id:changlikesdesktop:20190305050620p:plain:w200

画像分類であればこれで構わないのですが、
傷位置検出では傷が無い場所が等しく低い確率になることも重要です
傷がない場所は、等しく無いわけですから

この問題に対する答えが論文*3の中にありました

f:id:changlikesdesktop:20190304062824p:plain

読み解くのが大変なんですが、
要は傷がある場所には0.8の、傷が無い場所には0.2の重みを掛けろ、
と言っていると思います
傷の検出を重視しつつ、傷が無い場所は一様に、
という意味ですね

コーディング

以前の投稿*4と同じやり方で、画像をテンソル形式に並べます

import math
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from PIL import Image

TRAIN_DATA_SIZE = 1000
TRAIN_DATA_SIZE_WITH_DEFECT = 100
VALID_DATA_SIZE = 50
TEST_DATA_SIZE = 50
IMG_SIZE = 64

if __name__ == "__main__":

    init = tf.global_variables_initializer()
    sess = tf.Session()
    with sess.as_default():

        # remove old file
        if(os.path.exists('./data/trainImage64.txt')):
            os.remove('./data/trainImage64.txt')
        if(os.path.exists('./data/validationImage64.txt')):
            os.remove('./data/validationImage64.txt')
        if(os.path.exists('./data/testImage64.txt')):
            os.remove('./data/testImage64.txt')
        if(os.path.exists('./data/trainLABEL.txt')):
            os.remove('./data/trainLABEL.txt')
        if(os.path.exists('./data/trainWEIGHT.txt')):
            os.remove('./data/trainWEIGHT.txt')
        if(os.path.exists('./data/validationLABEL.txt')):
            os.remove('./data/validationLABEL.txt')
        if(os.path.exists('./data/validationWEIGHT.txt')):
            os.remove('./data/validationWEIGHT.txt')
        if(os.path.exists('./data/testLABEL.txt')):
            os.remove('./data/testLABEL.txt')
        if(os.path.exists('./data/testWEIGHT.txt')):
            os.remove('./data/testWEIGHT.txt')
        
        # without detection
        for k in range(TRAIN_DATA_SIZE):
            filename = './data/Class1/' + str(k + 1) + '.png'
            print(filename)
            imgtf = tf.read_file(filename)
            img = tf.image.decode_png(imgtf, channels=1)
            resized = tf.image.resize_images(img, [IMG_SIZE, IMG_SIZE], method=tf.image.ResizeMethod.BILINEAR)
            array = resized.eval()
            line = str(k)
            for i in range(IMG_SIZE):
                for j in range(IMG_SIZE):
                    line = line + ',' + str(array[i, j, 0])
            line = line + '\n'
            file = open('./data/trainImage64.txt', 'a')
            file.write(line)
            file.close()

        # # detection data
        for k in range(TRAIN_DATA_SIZE_WITH_DEFECT + VALID_DATA_SIZE):
            filename = './data/Class1_def/' + str(k + 1) + '.png'
            print(filename)
            imgtf = tf.read_file(filename)
            img = tf.image.decode_png(imgtf, channels=1)
            resized = tf.image.resize_images(img, [IMG_SIZE, IMG_SIZE], method=tf.image.ResizeMethod.BILINEAR)
            array = resized.eval()
            line = str(k + TRAIN_DATA_SIZE)
            for i in range(IMG_SIZE):
                for j in range(IMG_SIZE):
                    line = line + ',' + str(array[i, j, 0])
            line = line + '\n'
            if(k < TRAIN_DATA_SIZE_WITH_DEFECT):
                file = open('./data/trainImage64.txt', 'a')
                file.write(line)
                file.close()
            else:
                file = open('./data/validationImage64.txt', 'a')
                file.write(line)
                file.close()
                file = open('./data/testImage64.txt', 'a')
                file.write(line)
                file.close()

        # label #
        trnLABEL = []
        trnWEIGHT = []
        valLABEL = []
        valWEIGHT = []
        tstLABEL = []
        tstWEIGHT = []
        # no defection data
        for k in range(1000):
            label = np.zeros([16*16 + 1])
            label[0] = k
            trnLABEL.append(label)
            weight = np.zeros([16*16 + 1])
            weight[0] = k
            weight[1:16*16 + 1] = 0.2
            trnWEIGHT.append(weight)
    
        # defection data
        x = np.linspace(15.5, 495.5, 16)
        y = np.linspace(15.5, 495.5, 16)
        print('reading Class1_def')
        label1 = open('./data/Class1_def/labels.txt', 'r')
        for k in range(150):
            line = label1.readline()
            val = line.split('\t')
            num = int(val[0]) - 1
            mjr = float(val[1])
            mnr = float(val[2])
            rot = float(val[3])
            cnx = float(val[4])
            cny = float(val[5]) 

            # inverse rotate pixels
            label = np.zeros([16*16 + 1])
            weight = np.zeros([16*16 + 1])
            label[0] = num + 1000 # index
            weight[0] = num + 1000 # index
            for i in range(16):
                for j in range(16):
                    dist = math.sqrt((x[i] - cnx)**2 + (y[j] - cny)**2)
                    xTmp = (x[i] - cnx) * math.cos(-rot) - (y[j] - cny) * math.sin(-rot)
                    yTmp = (x[i] - cnx) * math.sin(-rot) + (y[j] - cny) * math.cos(-rot)
                    ang = math.atan(yTmp/xTmp)
                    distToEllipse = math.sqrt((mjr * math.cos(ang))**2 + (mnr * math.sin(ang))**2)
                    if(dist < distToEllipse):
                        label[i*16 + j + 1] = 1 # defection
                        weight[i*16 + j + 1] = 0.9
                    else:
                        label[i*16 + j + 1] = 0
                        weight[i*16 + j + 1] = 0.1
            
            # plot test
            if(k == 1):
                for i in range(16):
                    for j in range(16):
                        if(label[i*16 + j + 1] == 1):
                            plt.plot(i*32, j*32, '.', color='white')
                        else:
                            plt.plot(i*32, j*32, '.', color='black')
                plt.xlim(0, 512)
                plt.ylim(512, 0)
                plt.show()
            
            if(k < 100):
                trnLABEL.append(label)
                trnWEIGHT.append(weight)
            else:
                valLABEL.append(label)
                valWEIGHT.append(weight)
                tstLABEL.append(label)
                tstWEIGHT.append(weight)

        # normalize
        w_array = np.array(trnWEIGHT)
        for k in range(1100):
            s = sum(w_array[k, 1:16*16 + 1])
            w_array[k, 1:16*16 + 1] = w_array[k, 1:16*16 + 1]/s
        trnWEIGHT = w_array.tolist()
        
        w_val_array = np.array(valWEIGHT)
        w_tst_array = np.array(tstWEIGHT)
        for k in range(50):
            s = sum(w_val_array[k, 1:16*16 + 1])
            w_val_array[k, 1:16*16 + 1] = w_val_array[k, 1:16*16 + 1]/s
            s = sum(w_tst_array[k, 1:16*16 + 1])
            w_tst_array[k, 1:16*16 + 1] = w_tst_array[k, 1:16*16 + 1]/s
        
        valWEIGHT = w_val_array.tolist()
        tstWEIGHT = w_tst_array.tolist()
        
        np.savetxt('./data/trainLABEL.txt', trnLABEL, fmt='%d')
        np.savetxt('./data/trainWEIGHT.txt', trnWEIGHT, fmt='%.5f')
        np.savetxt('./data/validationLABEL.txt', valLABEL, fmt='%d')
        np.savetxt('./data/validationWEIGHT.txt', valWEIGHT, fmt='%.5f')
        np.savetxt('./data/testLABEL.txt', tstLABEL, fmt='%d')
        np.savetxt('./data/testWEIGHT.txt', tstWEIGHT, fmt='%.5f')

        sess.close()

重み付きのラベルを、trnWEIGHT、valWEIGHT、tstWEIGHTに格納し、
ファイルに出力しています
傷がある場所を0.9、傷が無い場所を0.1としているのがポイントです

ソース内で# normalizeと書いているのも重要です
tensorflowの損失関数に適用するラベルは、
確率密度関数である必要がある*5ためです

動かしてみる

損失関数のソースです

def loss(output, y, weight):

    xentropy = tf.nn.softmax_cross_entropy_with_logits(logits=output, labels=weight)
    loss = tf.reduce_mean(xentropy)
    return loss

効果を確かめるために、one_hot表現のラベルと、
重み付きのラベルの両方を損失関数に渡せるようにしました
labels=の後をyにすればone_hotに、weightにすれば重み付きになります

f:id:changlikesdesktop:20190305052603p:plain:w400

この結果だけを見るとone_hot表現の方が効率的に学習が進んで行くように見えます
しかし、100回程度計算を繰り返すと、one_hot表現の損失関数は発散します
重み付きの損失関数に、一定の効果があったと考えています
(確率密度関数にならない(=1が複数ある)one_hotラベルは、
そもそも間違っているのですが、、、)

ただ、学習速度は恐ろしく遅いです
一体何回計算を回せば良いことやら、、、(汗)

学習が進まない原因として、 ↓↓↓を予想しています
1. グラボの容量の関係で画像をresizeしている
2. ニューラルネットワークがmnistと同じで単純すぎる
3. そもそも何かが根本的に間違っている

亀の歩みで、また少しずつ調べて行きます