영상처리/3D computer vision
RAFT: Recurrent All-Pairs Field Transforms for Optical Flow 돌려보기
난개발자
2023. 10. 30. 20:56
728x90
RAFT를 이용해 dense optical flow를 계산해보고자 한다.
아래와 같이 torchvision.models.optical_flow에서 사전 학습된(pretrained) RAFT 모델을 가져와, 이미지 위에 optical flow를 overlay해 보았다.
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from torchvision.models.optical_flow import Raft_Large_Weights
from torchvision.models.optical_flow import raft_large
from torchvision.models.optical_flow import raft_small
def preprocess(batch):
    """Prepare an image batch for the RAFT model used in this script.

    Three transforms are applied in order: conversion to ``float32``,
    normalization with mean 0.5 and std 0.5, and a resize to 512x512.
    The transformed batch is returned.
    """
    to_float = T.ConvertImageDtype(torch.float32)
    # With inputs in [0, 1], mean=0.5 / std=0.5 maps them into [-1, 1].
    normalize = T.Normalize(mean=0.5, std=0.5)
    resize = T.Resize(size=(512, 512))
    pipeline = T.Compose([to_float, normalize, resize])
    return pipeline(batch)
def add_arrow_line(img, flow):
    """Draw a sparse grid of flow arrows on ``img`` and return ``img``.

    ``flow`` is indexed as ``flow[channel][y][x]``: channel 0 is used as the
    horizontal displacement and channel 1 as the vertical displacement. One
    arrow is drawn every 32 pixels, starting 16 pixels in from the top/left
    edge and stopping 16 pixels short of the bottom/right edge. Displacements
    are truncated to integers with ``int()`` before drawing.
    """
    margin, stride = 16, 32
    colour = (255, 0, 0)
    for row in range(margin, flow.shape[1] - margin, stride):
        for col in range(margin, flow.shape[2] - margin, stride):
            dx = int(flow[0][row][col])
            dy = int(flow[1][row][col])
            tip = (col + dx, row + dy)
            cv2.arrowedLine(img, (col, row), tip, colour)
    return img
# Script: estimate dense optical flow between two frames with a pretrained
# RAFT model and show the flow as arrows on top of a blend of both frames.
device = "cuda" if torch.cuda.is_available() else "cpu"

# cv2.imread returns None (it does not raise) when a file is missing or
# unreadable, so fail here with a clear message instead of deep inside torch.
img1 = cv2.imread('20231014_233154.jpg')
if img1 is None:
    raise FileNotFoundError("cannot read image: 20231014_233154.jpg")
img2 = cv2.imread('20231014_233150.jpg')
if img2 is None:
    raise FileNotFoundError("cannot read image: 20231014_233150.jpg")

# model = raft_small(progress=False).to(device)
# `weights=` replaces the deprecated `pretrained=True` flag.
model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
model = model.eval()

# OpenCV decodes images as BGR, while the pretrained torchvision weights were
# trained on RGB input, so swap the channel order before building the batch.
# The original BGR arrays are kept untouched for display with cv2.imshow.
img1_rgb = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
img2_rgb = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img1_stack = torch.stack([torch.permute(torch.tensor(img1_rgb), (2, 0, 1))])  # tensor shape (N, C, H, W)
img2_stack = torch.stack([torch.permute(torch.tensor(img2_rgb), (2, 0, 1))])  # tensor shape (N, C, H, W)

# Inference only: skip building the autograd graph to save memory.
with torch.no_grad():
    flow = model(preprocess(img1_stack).to(device), preprocess(img2_stack).to(device))

# RAFT returns one flow prediction per refinement iteration; the last entry is
# the most refined one. Take batch element 0 and copy it to the CPU once so
# the drawing loop does not index a GPU tensor element by element.
final_flow = flow[-1][0].cpu()

# plot
# Both frames are blended 50/50 and resized to the same 512x512 resolution
# that preprocess() feeds to the model, so flow coordinates line up.
overlay = cv2.resize(cv2.addWeighted(img1, 0.5, img2, 0.5, 0), (512, 512))
cv2.imshow('overlay', add_arrow_line(overlay, final_flow))
cv2.waitKey()
728x90