3D Deep Learning について学ぶ - pytorch3d

はじめに

3D オブジェクトを扱うDeep Learning技術について知りたいと
思ったので、pytorch3d1チュートリアルを行った。

概要

1次元的なデータを取り扱う自然言語処理や音声信号処理、
2次元的なデータを取り扱う画像処理だけでなく、
3Dオブジェクトを取り扱う Deep Learning アルゴリズム
近年多く提案されている。

今回は3Dオブジェクトを取り扱う処理がまとめられている、
pytorch3dライブラリのチュートリアルを行うことで3D Deep Learningに触れてみる。

pytorch3d 2

pytorch3dはpytorchをベースに3D Deep Learningタスクにおいて、
必要な処理が実装、最適化されているライブラリである。

  • メッシュ・テクスチャの入出力、汎用処理
  • 微分可能なrenderer
  • 損失関数

などを提供している。

カメラ位置最適化 (チュートリアル)

Camera position optimizationチュートリアルをやってみる。

このチュートリアルでは、ティーポットオブジェクトに対して、
カメラの撮影場所をrendererによって出力される画像から
最適化するというものである。

流れとしては

  1. 目標となるカメラ位置からティーポット画像を撮影する(目標画像)。
  2. 適当な位置にカメラを置く。
  3. 適当な位置からティーポット画像を撮影して目標画像との誤差を計算。
  4. 誤差を減らすようにbackpropagationでカメラ位置を調整

となる。

pytorch3dは微分可能なrendererが実装されているので
このようなこともできるよ、というチュートリアルだと思う。
(実用的には画像の美しさを定量化して、 それを最大化する
カメラ位置を求める、とかに利用できなくもない?)

環境設定

Google Colab上で検証した。

環境設定・モジュールimport

!pip install torch torchvision
!pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
:
import os
import torch
import numpy as np
from tqdm.notebook import tqdm
import imageio
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from skimage import img_as_ubyte

# io utils
from pytorch3d.io import load_obj

# datastructures
from pytorch3d.structures import Meshes, Textures

# 3D transformations functions
from pytorch3d.transforms import Rotate, Translate

# rendering components
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation, 
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights
)

データ読み込み

!mkdir -p data
!wget -P data https://dl.fbaipublicfiles.com/pytorch3d/data/teapot/teapot.obj

# Set the cuda device 
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")

# Load the obj and ignore the textures and materials.
verts, faces_idx, _ = load_obj("./data/teapot.obj")
faces = faces_idx.verts_idx

# Initialize each vertex to be white in color.
verts_rgb = torch.ones_like(verts)[None]  # (1, V, 3)
textures = Textures(verts_rgb=verts_rgb.to(device))

# Create a Meshes object for the teapot. Here we have only one mesh in the batch.
teapot_mesh = Meshes(
    verts=[verts.to(device)],   
    faces=[faces.to(device)], 
    textures=textures
)

obj形式のファイルを読み込み、頂点情報(verts)と面情報(faces_idx)を得る。
その後頂点に対応するテクスチャをそれぞれ1で初期化し、
3Dオブジェクトに対応するMeshesオブジェクトを構成している。
pytorch3dではこのMeshesオブジェクトを起点に様々な処理を行う。

rendererの設定

rendererの設定

# Initialize an OpenGL perspective camera.
cameras = OpenGLPerspectiveCameras(device=device)

# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of 
# edges. Refer to blending.py for more details. 
blend_params = BlendParams(sigma=1e-4, gamma=1e-4)

# Define the settings for rasterization and shading. Here we set the output image to be of size
# 256x256. To form the blended image we use 100 faces for each pixel. We also set bin_size and max_faces_per_bin to None which ensure that 
# the faster coarse-to-fine rasterization method is used. Refer to rasterize_meshes.py for 
# explanations of these parameters. Refer to docs/notes/renderer.md for an explanation of 
# the difference between naive and coarse-to-fine rasterization. 
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=np.log(1. / 1e-4 - 1.) * blend_params.sigma, 
    faces_per_pixel=100, 
)

# Create a silhouette mesh renderer by composing a rasterizer and a shader. 
silhouette_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=SoftSilhouetteShader(blend_params=blend_params)
)


# We will also create a phong renderer. This is simpler and only needs to render one face per pixel.
raster_settings = RasterizationSettings(
    image_size=256, 
    blur_radius=0.0, 
    faces_per_pixel=1, 
)
# We can add a point light in front of the object. 
lights = PointLights(device=device, location=((2.0, 2.0, -2.0),))
phong_renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras, 
        raster_settings=raster_settings
    ),
    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)
)

カメラ、ライト、シェーダーなどからrendererを設定している。
ティーポットのシルエットのみを表示する silhouette_renderer と
陰影情報も表示するphong_rendererの2つのrendererを用意している。

描画

# Select the viewpoint using spherical angles  
distance = 3   # distance from camera to the object
elevation = 50.0   # angle of elevation in degrees
azimuth = 0.0  # No rotation so the camera is positioned on the +Z axis. 

# Get the position of the camera based on the spherical angles
R, T = look_at_view_transform(distance, elevation, azimuth, device=device)

# Render the teapot providing the values of R and T. 
silhouete = silhouette_renderer(meshes_world=teapot_mesh, R=R, T=T)
image_ref = phong_renderer(meshes_world=teapot_mesh, R=R, T=T)

silhouete = silhouete.cpu().numpy()
image_ref = image_ref.cpu().numpy()

plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(silhouete.squeeze()[..., 3])  # only plot the alpha channel of the RGBA image
plt.grid(False)
plt.subplot(1, 2, 2)
plt.imshow(image_ref.squeeze())
plt.grid(False)

f:id:nakamrnk:20200721113842p:plain

左: silhouette_renderer(matplotlibにより着色) , 右: phong_rendererの描画結果

上の画像を目標としてカメラ位置を動かすモデルを構築する。

モデルと損失関数

class Model(nn.Module):
    def __init__(self, meshes, renderer, image_ref):
        super().__init__()
        self.meshes = meshes
        self.device = meshes.device
        self.renderer = renderer
        
        # Get the silhouette of the reference RGB image by finding all the non zero values. 
        image_ref = torch.from_numpy((image_ref[..., :3].max(-1) != 0).astype(np.float32))
        self.register_buffer('image_ref', image_ref)
        
        # Create an optimizable parameter for the x, y, z position of the camera. 
        self.camera_position = nn.Parameter(
            torch.from_numpy(np.array([3.0,  6.9, +2.5], dtype=np.float32)).to(meshes.device))

    def forward(self):
        
        # Render the image using the updated camera position. Based on the new position of the 
        # camer we calculate the rotation and translation matrices
        R = look_at_rotation(self.camera_position[None, :], device=self.device)  # (1, 3, 3)
        T = -torch.bmm(R.transpose(1, 2), self.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        
        image = self.renderer(meshes_world=self.meshes.clone(), R=R, T=T)
        
        # Calculate the silhouette loss
        loss = torch.sum((image[..., 3] - self.image_ref) ** 2)
        return loss, image

# We will save images periodically and compose them into a GIF.
filename_output = "./teapot_optimization_demo.gif"
writer = imageio.get_writer(filename_output, mode='I', duration=0.3)

# Initialize a model using the renderer, mesh and reference image
model = Model(meshes=teapot_mesh, renderer=silhouette_renderer, image_ref=image_ref).to(device)

# Create an optimizer. Here we are using Adam and we pass in the parameters of the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

モデルは最適化するパラメータとしてカメラ位置(camera_position)を持っている。
損失関数は現在のカメラ位置でrenderingした画像と参考画像との二乗誤差であり、
これの最小化を目指す。
位置さえ分かればいいので、出力のうちアルファチャンネルのみを利用している。

学習

loop = tqdm(range(200))
for i in loop:
    optimizer.zero_grad()
    loss, _ = model()
    loss.backward()
    optimizer.step()
    
    loop.set_description('Optimizing (loss %.4f)' % loss.data)
    
    if loss.item() < 200:
        break
    
    # Save outputs to create a GIF. 
    if i % 10 == 0:
        R = look_at_rotation(model.camera_position[None, :], device=model.device)
        T = -torch.bmm(R.transpose(1, 2), model.camera_position[None, :, None])[:, :, 0]   # (1, 3)
        image = phong_renderer(meshes_world=model.meshes.clone(), R=R, T=T)
        image = image[0, ..., :3].detach().squeeze().cpu().numpy()
        image = img_as_ubyte(image)
        writer.append_data(image)
        
        plt.figure()
        plt.imshow(image[..., :3])
        plt.title("iter: %d, loss: %0.2f" % (i, loss.data))
        plt.grid("off")
        plt.axis("off")
    
writer.close()

結果

学習経過ムービー

f:id:nakamrnk:20200721123504g:plain

学習初期に遠くにあったティーポットに徐々にカメラが近づいている。
最終的には元の参考画像に近い位置に落ち着いている。

まとめ

今回はpytorch3dのチュートリアルに軽く触れてみた。
基本的な学習プロセス(前処理、モデルの構築、学習)は
pytorchと変わらないので、比較的触りやすかった。
rendererがbackpropagation可能なところはおもしろく、
目的関数を工夫すればいろいろできそうだと思った。

参考文献