スポンサーリンク

【最短】PyTorchでjpeg/png画像を読みこむread_image、保存するwrite_jpeg/write_png【Python】

Image

PyTorchで画像を読み込むにはtorchvision.io.read_imageを使う。また、保存はtorchvision.io.write_jpeg/torchvision.io.write_pngを使う。

今回用いるサンプル画像をWEBからダウロード、保存します。

import requests

url = "https://github.com/opencv/opencv/blob/master/samples/data/butterfly.jpg?raw=true"
file_name = "butterfly.jpg"

response = requests.get(url)
image = response.content

with open(file_name, "wb") as f:
    f.write(image)

画像を読みこむread_image

read_imageの第一引数pathに読み込む画像のパスを指定する。

from torchvision.io import read_image
import matplotlib.pyplot as plt

image = read_image(path=file_name)

plt.imshow(image.permute(1,2,0)); # 画像の表示

グレースケール画像で読み込みたい場合は、第二引数のmodeImageReadMode.GRAYを指定する。元々はカラー画像ですが、グレー画像として読み込まれるため、チャンネル数が1であることが確認できます。

from torchvision.io import ImageReadMode

image_gray = read_image(path=file_name, mode=ImageReadMode.GRAY)

print(image_gray.shape)
# torch.Size([1, 356, 493])

plt.imshow(image_gray.squeeze(), cmap='gray'); # 画像の表示

画像を保存するwrite_jpeg/write_png

画像をjpegで保存する場合はwrite_jpeg、pngで保存する場合はwrite_pngを用いる。第一引数inputにTensor配列[チャンネル, 高さ, 幅]、第二引数filenameに出力画像のファイル名を指定する。

from torchvision.io import write_jpeg, write_png

output_name = 'out'

write_jpeg(input=image, filename=output_name + '.jpg')

write_png(input=image, filename=output_name + '.png')

関連記事、参考記事

コメント