を見ました
画像を切り抜く作業をやっていた事があって非常に気になって実際に試してみた
環境はgoogle coloboratoryというgoogle先生の機械学習が試せるサイトでやりました
coloboratoryを知らない人は下記の記事を参考にしてください
masalib.hatenablog.com
解説
ライブラリーインストール系
# ipywidgetsがはいっていないので # ipywidgetsをインストールする、セッションがきれたら毎回おこなう !pip install ipywidgets # ipywidgets有効にする、セッションがきれたら毎回おこなう !jupyter nbextension enable --py widgetsnbextension import collections import os import StringIO import sys import tarfile import tempfile import urllib from IPython import display from ipywidgets import interact from ipywidgets import interactive from matplotlib import gridspec from matplotlib import pyplot as plt import numpy as np from PIL import Image import tensorflow as tf # tensorflowが古いと動かない if tf.__version__ < '1.5.0': raise ImportError('Please upgrade your tensorflow installation to v1.5.0 or newer!')
gdriveにアクセスためのライブラリーインストールと設定
# セッションがきれたら毎回おこなう # google-drive-ocamlfuseのインストール # https://github.com/astrada/google-drive-ocamlfuse !apt-get install -y -qq software-properties-common python-software-properties module-init-tools !add-apt-repository -y ppa:alessandro-strada/ppa 2>&1 > /dev/null !apt-get update -qq 2>&1 > /dev/null !apt-get -y install -qq google-drive-ocamlfuse fuse # Colab用のAuth token作成 from google.colab import auth auth.authenticate_user() # Drive FUSE library用のcredential生成 from oauth2client.client import GoogleCredentials creds = GoogleCredentials.get_application_default() import getpass !google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} < /dev/null 2>&1 | grep URL vcode = getpass.getpass() !echo {vcode} | google-drive-ocamlfuse -headless -id={creds.client_id} -secret={creds.client_secret} # drive/ を作り、そこにGoogle Driveをマウントする !mkdir -p drive !google-drive-ocamlfuse drive
qiita.com
このソースは下記のブログを丸パクリしています
本当にありがとうございます
この記事にも記載してありますが
!ls drive/
で確認するとgdriveの直下が表示されるます
「DeepLab-V3+1」のデータ・セットにアクセスためのライブラリーインストールと設定
https://github.com/tensorflow/models/tree/master/research/deeplabをローカルに落としてきて
gdriveにアップします
#googleDriveに deeplabのライブラリをパス追加する sys.path.append('drive/deeplab/utils') # googleDriveにあるファイルをインポートする import get_dataset_colormap
「DeepLab-V3+1」を呼び出す準備
_MODEL_URLS = { 'xception_coco_voctrainaug': 'http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz', 'xception_coco_voctrainval': 'http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz', } Config = collections.namedtuple('Config', 'model_url, model_dir') def get_config(model_name, model_dir): return Config(_MODEL_URLS[model_name], model_dir) config_widget = interactive(get_config, model_name=_MODEL_URLS.keys(), model_dir='') #エラーがでますが、進めます。原因がわからず・・・諦めました display.display(config_widget) _MODEL_URLS = { 'xception_coco_voctrainaug': 'http://download.tensorflow.org/models/deeplabv3_pascal_train_aug_2018_01_04.tar.gz', 'xception_coco_voctrainval': 'http://download.tensorflow.org/models/deeplabv3_pascal_trainval_2018_01_04.tar.gz', } Config = collections.namedtuple('Config', 'model_url, model_dir') def get_config(model_name, model_dir): return Config(_MODEL_URLS[model_name], model_dir) config_widget = interactive(get_config, model_name=_MODEL_URLS.keys(), model_dir='') #エラーになってわからず・・・ display.display(config_widget) # Check configuration and download the model _TARBALL_NAME = 'deeplab_model.tar.gz' config = config_widget.result model_dir = config.model_dir or tempfile.mkdtemp() tf.gfile.MakeDirs(model_dir) download_path = os.path.join(model_dir, _TARBALL_NAME) print 'downloading model to %s, this might take a while...' % download_path urllib.urlretrieve(config.model_url, download_path) print 'download completed!' _FROZEN_GRAPH_NAME = 'frozen_inference_graph' class DeepLabModel(object): """Class to load deeplab model and run inference.""" INPUT_TENSOR_NAME = 'ImageTensor:0' OUTPUT_TENSOR_NAME = 'SemanticPredictions:0' #リサイズに影響される出力するファイルを変更したい場合はここを修正 INPUT_SIZE = 513 def __init__(self, tarball_path): """Creates and loads pretrained deeplab model.""" self.graph = tf.Graph() graph_def = None # Extract frozen graph from tar archive. tar_file = tarfile.open(tarball_path) for tar_info in tar_file.getmembers(): if _FROZEN_GRAPH_NAME in os.path.basename(tar_info.name): file_handle = tar_file.extractfile(tar_info) graph_def = tf.GraphDef.FromString(file_handle.read()) break tar_file.close() if graph_def is None: raise RuntimeError('Cannot find inference graph in tar archive.') with self.graph.as_default(): tf.import_graph_def(graph_def, name='') self.sess = tf.Session(graph=self.graph) def run(self, image): """Runs inference on a single image. Args: image: A PIL.Image object, raw input image. Returns: resized_image: RGB image resized from original input image. seg_map: Segmentation map of `resized_image`. """ width, height = image.size resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height) print 'self.INPUT_SIZE: %s ' % self.INPUT_SIZE print 'resize_ratio: %s ' % resize_ratio # リサイズに影響される出力するファイルを変更したい場合はここを修正 target_size = (int(resize_ratio * width), int(resize_ratio * height)) resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS) batch_seg_map = self.sess.run( self.OUTPUT_TENSOR_NAME, feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]}) seg_map = batch_seg_map[0] return resized_image, seg_map model = DeepLabModel(download_path) LABEL_NAMES = np.asarray([ 'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tv' ]) FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1) FULL_COLOR_MAP = get_dataset_colormap.label_to_color_image(FULL_LABEL_MAP) def vis_segmentation(image, seg_map): plt.figure(figsize=(15, 5)) grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1]) plt.subplot(grid_spec[0]) plt.imshow(image) plt.axis('off') plt.title('input image') plt.subplot(grid_spec[1]) seg_image = get_dataset_colormap.label_to_color_image( seg_map, get_dataset_colormap.get_pascal_name()).astype(np.uint8) im = np.array(seg_image) # pil_img = Image.fromarray(im) pil_img.save('drive/deeplab/g3doc/img/image1_4.jpg') pil_img.show() plt.imshow(seg_image) plt.axis('off') plt.title('segmentation map') plt.subplot(grid_spec[2]) plt.imshow(image) plt.imshow(seg_image, alpha=0.7) plt.axis('off') plt.title('segmentation overlay') unique_labels = np.unique(seg_map) ax = plt.subplot(grid_spec[3]) plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest') ax.yaxis.tick_right() plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels]) plt.xticks([], []) ax.tick_params(width=0)
vis_segmentationの関数の
pil_img.save('drive/deeplab/g3doc/img/image1_4.jpg')は結果ファイルを保存しています
元画像をリサイズして実行しています
もと画像のままやりたい場合はリサイズの設定変更が必要みたいです
ローカルファイルで実行する
GDRIVEですが、ローカルにあるファイルをもとに
作成する部分です
IMAGE_DIR = 'drive/deeplab/g3doc/img' def run_demo_image(image_name): try: image_path = os.path.join(IMAGE_DIR, image_name) orignal_im = Image.open(image_path) except IOError: print 'Failed to read image from %s.' % image_path return print 'running deeplab on image %s...' % image_name resized_im, seg_map = model.run(orignal_im) vis_segmentation(resized_im, seg_map) _ = interact(run_demo_image, image_name=['image1.jpg'])
実行結果は以下のとおりです
urlで指定した画像で実行する
自分のブログで使った画像をテストしました
url = 'https://cdn-ak.f.st-hatena.com/images/fotolife/m/masalib/20150818/20150818000442.jpg' def get_an_internet_image(url): print 'test url: ' + url if not url: return try: # Prefix with 'file://' for local file. if os.path.exists(url): url = 'file://' + url f = urllib.urlopen(url) jpeg_str = f.read() except IOError: print 'invalid url: ' + url return orignal_im = Image.open(StringIO.StringIO(jpeg_str)) print 'running deeplab on image %s...' % url resized_im, seg_map = model.run(orignal_im) print 'run end ' vis_segmentation(resized_im, seg_map) ##_ = interact(get_an_internet_image, url) get_an_internet_image(url)
結果
当たり前ですが、きれいに分類されました
実際に試したい人は・・・
下記のURLより試す事ができます
https://colab.research.google.com/drive/14ElD58dMoBtsUyr0FghVAeCISWe-ePr-
今後について
vpsとかのサーバーとかで動かせるようにしたい