masalibの日記

システム開発、運用と猫の写真ブログです

TensorflowにおけるDeepLabを用いた意味的画像セグメンテーションをやってみた

research.googleblog.com

shiropen.com

を見ました
画像を切り抜く作業をやっていた事があって非常に気になって実際に試してみた
環境はgoogle coloboratoryというgoogle先生機械学習が試せるサイトでやりました

coloboratoryを知らない人は下記の記事を参考にしてください
masalib.hatenablog.com

「DeepLab-V3+1」とは?

f:id:masalib:20180315015307j:plain

画像認識で写っているものを人か動物なのかを判別してくれるものです
他にもそういった画像認識はあるのですが
ピクセルレベルで認識できます

f:id:masalib:20180315011928p:plain

試し方について

基本的には下記のipythonnotebookを参考にしていますが
coloboratory特有の問題も追加しています

github.com

解説

ライブラリーインストール系
# ipywidgetsがはいっていないので
# ipywidgetsをインストールする、セッションがきれたら毎回おこなう
!pip install ipywidgets

# ipywidgets有効にする、セッションがきれたら毎回おこなう
!jupyter nbextension enable --py widgetsnbextension

import collections
import os
import StringIO
import sys
import tarfile
import tempfile
import urllib

from IPython import display
from ipywidgets import interact
from ipywidgets import interactive
from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

import tensorflow as tf

# tensorflowが古いと動かない
if tf.__version__ < '1.5.0':
    raise ImportError('Please upgrade your tensorflow installation to v1.5.0 or newer!')
gdriveにアクセスためのライブラリーインストールと設定
# セッションがきれたら毎回おこなう
# google-drive-ocamlfuseのインストール
# https://github.com/astrada/google-drive-ocamlfuse
!apt-get install -y -qq software-properties-common python-software-properties module-init-tools
!add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null
!apt-get update -qq 2>&1 > /dev/null
!apt-get -y install -qq google-drive-ocamlfuse fuse

# Colab用のAuth token作成
from google.colab import auth
auth.authenticate_user()

# Drive FUSE library用のcredential生成
from oauth2client.client import GoogleCredentials
creds = GoogleCredentials.get_application_default()
import getpass
!google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL
vcode = getpass.getpass()
!echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret}

# drive/ を作り、そこにGoogle Driveをマウントする
!mkdir -p drive
!google-drive-ocamlfuse drive

qiita.com
このソースは下記のブログを丸パクリしています
本当にありがとうございます
この記事にも記載してありますが

!ls drive/

で確認するとgdriveの直下が表示されるます

「DeepLab-V3+1」のデータ・セットにアクセスためのライブラリーインストールと設定

https://github.com/tensorflow/models/tree/master/research/deeplabをローカルに落としてきて
gdriveにアップします

f:id:masalib:20180315012626j:plain

#googleDriveに deeplabのライブラリをパス追加する
sys.path.append('drive/deeplab/utils')

# googleDriveにあるファイルをインポートする
import get_dataset_colormap
「DeepLab-V3+1」を呼び出す準備
_MODEL_URLS = {
    'xception_coco_voctrainaug': 'http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    'xception_coco_voctrainval': 'http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}

Config = collections.namedtuple('Config', 'model_url, model_dir')

def get_config(model_name, model_dir):
    return Config(_MODEL_URLS[model_name], model_dir)

config_widget = interactive(get_config, model_name=_MODEL_URLS.keys(), model_dir='')

#エラーがでますが、進めます。原因がわからず・・・諦めました
display.display(config_widget)

_MODEL_URLS = {
    'xception_coco_voctrainaug': 'http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz',
    'xception_coco_voctrainval': 'http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz',
}

Config = collections.namedtuple('Config', 'model_url, model_dir')

def get_config(model_name, model_dir):
    return Config(_MODEL_URLS[model_name], model_dir)

config_widget = interactive(get_config, model_name=_MODEL_URLS.keys(), model_dir='')

#エラーになってわからず・・・
display.display(config_widget)

# Check configuration and download the model

_TARBALL_NAME = 'deeplab_model.tar.gz'

config = config_widget.result

model_dir = config.model_dir or tempfile.mkdtemp()
tf.gfile.MakeDirs(model_dir)

download_path = os.path.join(model_dir, _TARBALL_NAME)
print 'downloading model to %s, this might take a while...' % download_path
urllib.urlretrieve(config.model_url, download_path)
print 'download completed!'

_FROZEN_GRAPH_NAME = 'frozen_inference_graph'


class DeepLabModel(object):
    """Class to load deeplab model and run inference."""
    
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    #リサイズに影響される出力するファイルを変更したい場合はここを修正
    INPUT_SIZE = 513

    def __init__(self, tarball_path):
        """Creates and loads pretrained deeplab model."""
        self.graph = tf.Graph()
        
        graph_def = None
        # Extract frozen graph from tar archive.
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if _FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break

        tar_file.close()
        
        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')

        with self.graph.as_default():      
            tf.import_graph_def(graph_def, name='')
        
        self.sess = tf.Session(graph=self.graph)
            
    def run(self, image):
        """Runs inference on a single image.
        
        Args:
            image: A PIL.Image object, raw input image.
            
        Returns:
            resized_image: RGB image resized from original input image.
            seg_map: Segmentation map of `resized_image`.
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        print 'self.INPUT_SIZE: %s ' % self.INPUT_SIZE
        print 'resize_ratio: %s ' % resize_ratio
# リサイズに影響される出力するファイルを変更したい場合はここを修正
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map

model = DeepLabModel(download_path)

LABEL_NAMES = np.asarray([
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
    'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa',
    'train', 'tv'
])

FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
FULL_COLOR_MAP = get_dataset_colormap.label_to_color_image(FULL_LABEL_MAP)


def vis_segmentation(image, seg_map):
    plt.figure(figsize=(15, 5))
    grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])

    plt.subplot(grid_spec[0])
    plt.imshow(image)
    plt.axis('off')
    plt.title('input image')
    
    plt.subplot(grid_spec[1])
    seg_image = get_dataset_colormap.label_to_color_image(
        seg_map, get_dataset_colormap.get_pascal_name()).astype(np.uint8)

    im = np.array(seg_image)
    # 
    pil_img = Image.fromarray(im)
    pil_img.save('drive/deeplab/g3doc/img/image1_4.jpg')
    pil_img.show()

    
    plt.imshow(seg_image)
    plt.axis('off')
    plt.title('segmentation map')

    plt.subplot(grid_spec[2])
    plt.imshow(image)
    plt.imshow(seg_image, alpha=0.7)
    plt.axis('off')
    plt.title('segmentation overlay')
    
    unique_labels = np.unique(seg_map)
    ax = plt.subplot(grid_spec[3])
    plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
    ax.yaxis.tick_right()
    plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
    plt.xticks([], [])
    ax.tick_params(width=0)


vis_segmentationの関数の
pil_img.save('drive/deeplab/g3doc/img/image1_4.jpg')は結果ファイルを保存しています

元画像をリサイズして実行しています
もと画像のままやりたい場合はリサイズの設定変更が必要みたいです

ローカルファイルで実行する

GDRIVEですが、ローカルにあるファイルをもとに
作成する部分です

IMAGE_DIR = 'drive/deeplab/g3doc/img'

def run_demo_image(image_name):
    try:
        image_path = os.path.join(IMAGE_DIR, image_name)
        orignal_im = Image.open(image_path)
    except IOError:
        print 'Failed to read image from %s.' % image_path 
        return 
    print 'running deeplab on image %s...' % image_name
    resized_im, seg_map = model.run(orignal_im)
    
    vis_segmentation(resized_im, seg_map)

_ = interact(run_demo_image, image_name=['image1.jpg'])

実行結果は以下のとおりです
f:id:masalib:20180315012600j:plain

urlで指定した画像で実行する

自分のブログで使った画像をテストしました

url = 'https://cdn-ak.f.st-hatena.com/images/fotolife/m/masalib/20150818/20150818000442.jpg'
def get_an_internet_image(url):
    print 'test url: ' + url
    if not url:
        return

    try:
        # Prefix with 'file://' for local file.
        if os.path.exists(url):
            url = 'file://' + url
        f = urllib.urlopen(url)
        jpeg_str = f.read()
    except IOError:
        print 'invalid url: ' + url
        return

    orignal_im = Image.open(StringIO.StringIO(jpeg_str))
    print 'running deeplab on image %s...' % url
    resized_im, seg_map = model.run(orignal_im)
    print 'run end '

    vis_segmentation(resized_im, seg_map)
    
##_ = interact(get_an_internet_image, url)

get_an_internet_image(url)


結果
f:id:masalib:20180315012517j:plain

当たり前ですが、きれいに分類されました

実際に試したい人は・・・

下記のURLより試す事ができます
https://colab.research.google.com/drive/14ElD58dMoBtsUyr0FghVAeCISWe-ePr-

今後について

vpsとかのサーバーとかで動かせるようにしたい