PyTorchで画像を読み込むにはtorchvision.io.read_image
を使う。また、保存はtorchvision.io.write_jpeg
/torchvision.io.write_png
を使う。
今回用いるサンプル画像をWEBからダウロード、保存します。
import requests
url = "https://github.com/opencv/opencv/blob/master/samples/data/butterfly.jpg?raw=true"
file_name = "butterfly.jpg"
response = requests.get(url)
image = response.content
with open(file_name, "wb") as f:
f.write(image)
画像を読みこむread_image
read_imageの第一引数pathに読み込む画像のパスを指定する。
- torchvision.io.read_image – PyTorch torchvision v0.13 Docs
- torchvision.io.ImageReadMode – PyTorch torchvision v0.13 Docs
- 関連記事 – 【PyTorch】多次元配列の次元(軸)を任意の順番に入れ替えるtorch.permute
from torchvision.io import read_image
import matplotlib.pyplot as plt
image = read_image(path=file_name)
plt.imshow(image.permute(1,2,0)); # 画像の表示
グレースケール画像で読み込みたい場合は、第二引数のmodeにImageReadMode.GRAYを指定する。元々はカラー画像ですが、グレー画像として読み込まれるため、チャンネル数が1であることが確認できます。
from torchvision.io import ImageReadMode
image_gray = read_image(path=file_name, mode=ImageReadMode.GRAY)
print(image_gray.shape)
# torch.Size([1, 356, 493])
plt.imshow(image_gray.squeeze(), cmap='gray'); # 画像の表示
画像を保存するwrite_jpeg/write_png
画像をjpegで保存する場合はwrite_jpeg、pngで保存する場合はwrite_pngを用いる。第一引数inputにTensor配列[チャンネル, 高さ, 幅]、第二引数filenameに出力画像のファイル名を指定する。
from torchvision.io import write_jpeg, write_png
output_name = 'out'
write_jpeg(input=image, filename=output_name + '.jpg')
write_png(input=image, filename=output_name + '.png')
コメント