6.8301-Project/decoding_utils.py

import itertools
import math
import time
import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as transforms_f
from matplotlib import pyplot as plt
from corner_training.models import QuantizedV2, QuantizedV5
from corner_training.utils import get_gaussian_filter, get_bounded_slices
torch.backends.quantized.engine = 'qnnpack'  # quantized inference backend (qnnpack targets ARM/mobile CPUs)


def localize_corners_wrapper(args, debug=False):
    """Load the two-stage quantized corner-detection models and return a `localize_corners` closure."""
    stage1_model_checkpt_path = "checkpts/QuantizedV2_Stage1_128_9.pt"
    stage2_model_checkpt_path = "checkpts/QuantizedV5_Stage2_128_9.pt"
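    # The checkpoints hold already-quantized weights, so each model is fused, prepared, and
    # converted to its quantized form below before load_state_dict fills in the parameters.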
    stage1_model = QuantizedV2()
    stage1_model.eval()
    stage1_model.fuse_modules(is_qat=False)
    stage1_model.qconfig = torch.ao.quantization.default_qconfig
    torch.ao.quantization.prepare(stage1_model, inplace=True)
    torch.ao.quantization.convert(stage1_model, inplace=True)
    stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))

    stage2_model = QuantizedV5()
    stage2_model.eval()
    stage2_model.fuse_modules(is_qat=False)
    stage2_model.qconfig = torch.ao.quantization.default_qconfig
    torch.ao.quantization.prepare(stage2_model, inplace=True)
    torch.ao.quantization.convert(stage2_model, inplace=True)
    stage2_model.load_state_dict(torch.load(stage2_model_checkpt_path, map_location=torch.device('cpu')))

    stage1_model.eval()
    stage2_model.eval()

    stage1_size = 128
    assert stage1_size & 1 == 0, "Assuming even size when dividing into quadrants"

    np_to_fp32_tensor = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])
    preprocess_img_stage1 = transforms.Compose([
        transforms.Lambda(lambda img: resize_keep_aspect(img, stage1_size)),
        np_to_fp32_tensor,
    ])

    gaussian_filter = get_gaussian_filter(4, 4)  # for stage1 NMS heuristic
    # Flip each cropped corner patch so that it looks like a top-left corner, since that's what the model is trained on
    transforms_by_corner = [
        lambda img: img,  # identity
        transforms_f.hflip,
        transforms_f.vflip,
        lambda img: transforms_f.vflip(transforms_f.hflip(img))
    ]
    inv_transforms_by_corner = transforms_by_corner  # flipping is a self-inverse

    def localize_corners(cropped_frame: np.ndarray):
        """
        Args:
            cropped_frame: Square numpy array

        Returns:
            The perspective-rectified frame and the mean colors of the four corner
            calibration squares, or None if no valid corner quadrilateral is found.
        """
        orig_h, orig_w, _ = cropped_frame.shape
        stage2_size = max(orig_h, orig_w) // 16
        upscale_factor = min(orig_w, orig_h) / stage1_size  # for stage 2

        start_time = time.time()
        stage1_img = preprocess_img_stage1(cropped_frame)
        if debug:
            print(54, time.time() - start_time)

        with torch.no_grad():
            stage1_pred = stage1_model(stage1_img.unsqueeze(0)).squeeze(0)
        # plt.imshow(stage1_pred.detach().cpu())
        # plt.show()
        if debug:
            print(57, time.time() - start_time)

        quad_h = stage1_img.size(1) // 2  # might miss 1 pixel on edge if odd
        quad_w = stage1_img.size(2) // 2
        corners_by_quad = dict()
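        # Each quadrant of the stage-1 heatmap is expected to contain the 4 keypoints of one corner
        # marker. Take the top 6 peaks per quadrant (suppressing each chosen peak with a Gaussian as
        # a cheap non-maximum suppression) and keep the 4-point combination that best forms a square
        # (scored by score_combo below).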
        for top_half in (0, 1):  # TODO: bot/right to remove all 1 minuses
            for left_half in (0, 1):
                quad_i_start = quad_h * (1 - top_half)
                quad_j_start = quad_w * (1 - left_half)
                curr_quad_preds = stage1_pred[
                    quad_i_start: quad_i_start + quad_h,
                    quad_j_start: quad_j_start + quad_w,
                ].clone()

                max_locs = []
                for i in range(6):  # expect 4 points, but get top 6 to be safe
                    max_ind = torch.argmax(curr_quad_preds).item()  # TODO: more efficient like segtree, maybe account for neighbors too
                    max_loc = (max_ind // quad_w, max_ind % quad_w)
                    max_locs.append(max_loc)
                    # TODO: improve, maybe scale Gaussian peak to val of max_loc, probably better to not subtract from a location multiple times
                    preds_slice, gaussian_slice = get_bounded_slices((quad_h, quad_w), gaussian_filter.size(),
                                                                     *max_loc)
                    curr_quad_preds[preds_slice] -= gaussian_filter[gaussian_slice]

                if debug:
                    print(f"{max_locs=}")

                min_cost = 1e9
                min_square = None
                for potential_combo in itertools.combinations(max_locs, 4):
                    curr_pts, curr_cost = score_combo(potential_combo)
                    if curr_cost < min_cost:
                        min_cost = curr_cost
                        min_square = curr_pts

                if min_square is None:
                    print("no convex 4-point combination found (points collinear or degenerate)")
                    return None

                corners_by_quad[(1 - top_half) * 2 + (1 - left_half)] = [(i + quad_i_start, j + quad_j_start)
                                                                         for (i, j) in min_square]

        if debug:
            print(92, time.time() - start_time)
            print(corners_by_quad)

        outer_corners = []
        # corner_colors = []  # by center, currently rounding to the pixel in the original image
        origin = (quad_h, quad_w)
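        # In each quadrant, keep only the keypoint farthest from the image center; these four
        # points are treated as the outermost corners of the encoded region.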
        for quad in range(4):  # TODO: consistent (x, y) or (i, j)
            outer_corners.append(max((l2_dist(corner, origin), corner) for corner in corners_by_quad[quad])[1])
            # corner_colors.append(cropped_frame[int((sum(corner[0] for corner in corners_by_quad[quad]) / 4 * upscale_factor)),
            #                                    int((sum(corner[1] for corner in corners_by_quad[quad]) / 4 * upscale_factor))]
            #                      .astype(np.float64))

        stage2_imgs = []
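        # Crop a stage2_size x stage2_size patch from the full-resolution frame around each coarse
        # corner; the stage-2 model then refines the corner location within that patch.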
        for top_half in (0, 1):  # TODO: bot/right to remove all 1 minuses
            for left_half in (0, 1):
                corner_ind = top_half * 2 + left_half
                y, x = outer_corners[corner_ind]
                upscaled_y, upscaled_x = round(y * upscale_factor), round(x * upscale_factor)
                top = max(0, upscaled_y - stage2_size // 2)
                bottom = min(orig_h, upscaled_y + stage2_size // 2)
                left = max(0, upscaled_x - stage2_size // 2)
                right = min(orig_w, upscaled_x + stage2_size // 2)

                # Need padding if detected corner is within `stage2_size // 2` of border
                corner_padding = [0] * 4  # pad the side that does not affect extracted coordinates
                corner_padding[(1 - top_half) * 2 + 1] = stage2_size - (bottom - top)
                corner_padding[(1 - left_half) * 2] = stage2_size - (right - left)
                cropped_corner_img = transforms_f.pad(  # TODO: don't pad since that should speed up inference
                    np_to_fp32_tensor(cropped_frame[top:bottom, left:right]),
                    corner_padding
                )
                stage2_imgs.append(cropped_corner_img)

        transformed_corner_imgs = torch.stack([transforms_by_corner[corner_ind](stage2_img)
                                               for corner_ind, stage2_img in enumerate(stage2_imgs)])
        if debug:
            print(121, time.time() - start_time)

        with torch.no_grad():
            transformed_preds = stage2_model(transformed_corner_imgs)
        if debug:
            print(125, time.time() - start_time)
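        # Each prediction's argmax is wrapped in a zero-sized XYWH bounding box so that the same
        # flip transforms used on the input patches can be applied in reverse, mapping the
        # predicted corner back into the un-flipped patch's coordinates.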
        transformed_pred_pts = [
            torchvision.tv_tensors.BoundingBoxes(
                [(max_ind := pred.argmax()) % stage2_size, max_ind // stage2_size, 0, 0],
                format="XYWH", canvas_size=(stage2_size, stage2_size)
            )
            for pred in transformed_preds
        ]
        stage2_pred_pts = [inv_transforms_by_corner[corner_ind](transformed_pred_pt)[0, :2].tolist()
                           for corner_ind, transformed_pred_pt in enumerate(transformed_pred_pts)]
        if debug:
            print(137, time.time() - start_time)

        # Convert each refined (x, y) offset within its patch back to full-resolution image coordinates
        orig_pred_pts = [(round(orig_x * upscale_factor) + stage2_pred_x - stage2_size // 2,
                          round(orig_y * upscale_factor) + stage2_pred_y - stage2_size // 2)
                         for (orig_y, orig_x), (stage2_pred_x, stage2_pred_y) in zip(outer_corners, stage2_pred_pts)]
        if debug:
            print(142, time.time() - start_time)

        # plt.imshow(cropped_frame)
        # plt.scatter(np.array(orig_pred_pts).T[0], np.array(orig_pred_pts).T[1])
        # plt.show()
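        # Warp the frame so the four refined corners land on a fixed grid: each target point sits
        # qtr_corner_size in from the corresponding image edge (offset by half a pixel, presumably
        # to account for pixel centers).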
        corner_size = int(max(args.height, args.width) * 0.16)
        qtr_corner_size = corner_size // 4
        grid_coords = np.float32([
            [qtr_corner_size, qtr_corner_size],
            [args.width - qtr_corner_size, qtr_corner_size],
            [qtr_corner_size, args.height - qtr_corner_size],
            [args.width - qtr_corner_size, args.height - qtr_corner_size],
        ])
        grid_coords -= 1/2

        M = cv2.getPerspectiveTransform(
            np.float32(orig_pred_pts),
            grid_coords,
        )
        cropped_frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
        # cropped_frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height), flags=cv2.INTER_NEAREST)

        padding = math.ceil(max(args.height, args.width) / 80)  # arbitrary
        # guessing wildly on +/- 1s
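        # Sample the mean color of each corner calibration square (white, red, green, blue),
        # inset by `padding` pixels on every side to avoid bleed from neighboring regions.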
        white_sq = cropped_frame[qtr_corner_size + padding: corner_size - qtr_corner_size - padding,
                                 qtr_corner_size + padding: corner_size - qtr_corner_size - padding]
        red_sq = cropped_frame[qtr_corner_size + padding: corner_size - qtr_corner_size - padding,
                               args.width - corner_size + qtr_corner_size + padding: args.width - qtr_corner_size - padding]
        green_sq = cropped_frame[args.height - corner_size + qtr_corner_size + padding: args.height - qtr_corner_size - padding,
                                 qtr_corner_size + padding: corner_size - qtr_corner_size - padding]
        blue_sq = cropped_frame[args.height - corner_size + qtr_corner_size + padding: args.height - qtr_corner_size - padding,
                                args.width - corner_size + qtr_corner_size + padding: args.width - qtr_corner_size - padding]
        corner_colors = [white_sq.mean(axis=(0, 1)), red_sq.mean(axis=(0, 1)),
                         green_sq.mean(axis=(0, 1)), blue_sq.mean(axis=(0, 1))]

        # fig, axs = plt.subplots(2, 3)
        #
        # axs[0, 2].imshow(cropped_frame)
        # axs[0, 0].imshow(white_sq)
        # axs[0, 1].imshow(red_sq)
        # axs[1, 0].imshow(green_sq)
        # axs[1, 1].imshow(blue_sq)
        # plt.show()

        return cropped_frame, corner_colors

    return localize_corners


def l2_dist(loc, origin):
    """Squared Euclidean distance (no sqrt)."""
    return (loc[0] - origin[0]) ** 2 + (loc[1] - origin[1]) ** 2


def score_combo(combo):
    """Score how close a 4-point combination is to being a square; lower cost is better.

    Plan:
    1. Check if the points are in convex position. If not, it is a very bad quadrilateral (cost 1e9).
    2. Check if the diagonal lengths are within a factor of 1.5. If not, it is somewhat bad,
       since the corners are far from right angles (cost 1e8).
    3. If the above are satisfied, return how close the (squared) side lengths are to being equal.

    Returns (hull, cost); hull is None when the points are not in convex position.
    """
    hull = convex_hull([Point(x, y) for x, y in combo])  # TODO: check how collinear case is handled
    hull = [(pt.x, pt.y) for pt in hull]  # convert back to tuples
    if len(hull) != 4:
        return None, 1e9

    squared_diag0 = l2_dist(hull[0], hull[2])
    squared_diag1 = l2_dist(hull[1], hull[3])
    if squared_diag0 < squared_diag1:  # swap so that diag0 is larger
        squared_diag0, squared_diag1 = squared_diag1, squared_diag0
    if squared_diag0 / squared_diag1 > 1.5**2:
        return hull, 1e8

    cyclic_pts = hull + [hull[0]]
    side_lens = [l2_dist(cyclic_pts[i], cyclic_pts[i + 1]) for i in range(4)]
    return hull, (max(side_lens) - min(side_lens)) / min(side_lens)
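# For example, an axis-aligned unit square gets the best possible cost:
# score_combo([(0, 0), (0, 1), (1, 1), (1, 0)]) returns (hull, 0.0).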


def resize_keep_aspect(img: np.ndarray, min_len: int) -> np.ndarray:
    """Resize so that the shorter side equals min_len, preserving the aspect ratio."""
    h, w, _ = img.shape
    if h < w:
        output_size = (round(min_len * w / h), min_len)  # cv2.resize expects (width, height)
    else:
        output_size = (min_len, round(min_len * h / w))
    return cv2.resize(img, output_size, interpolation=cv2.INTER_NEAREST)


# Gift wrapping (Jarvis march) convex hull code, adapted from GeeksForGeeks.
# "This code is contributed by Akarsh Somani, IIIT Kalyani"
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


def left_index(points):
    """Find the index of the left-most point, breaking ties by the larger y."""
    minn = 0
    for i in range(1, len(points)):
        if points[i].x < points[minn].x:
            minn = i
        elif points[i].x == points[minn].x:
            if points[i].y > points[minn].y:
                minn = i
    return minn


def orientation(p, q, r):
    """
    Find the orientation of the ordered triplet (p, q, r).
    Returns the following values:
    0 --> p, q and r are collinear
    1 --> clockwise
    2 --> counterclockwise
    """
    val = (q.y - p.y) * (r.x - q.x) - (q.x - p.x) * (r.y - q.y)
    if val == 0:
        return 0
    elif val > 0:
        return 1
    else:
        return 2
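# For example, orientation(Point(0, 0), Point(1, 0), Point(1, 1)) returns 2,
# since the triplet turns counterclockwise.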


def convex_hull(points):
    n = len(points)
    assert n >= 3, "There must be at least 3 points."

    # Find the leftmost point
    l = left_index(points)

    hull = []
    '''
    Start from the leftmost point and keep moving counterclockwise
    until we reach the start point again. This loop runs O(h) times,
    where h is the number of points on the hull.
    '''
    p = l
    q = 0
    while True:
        # Add current point to result
        hull.append(points[p])

        '''
        Search for a point 'q' such that orientation(p, q, x) is
        counterclockwise for all points 'x'. The idea is to keep track
        of the last visited most counterclockwise point in q. If any
        point 'i' is more counterclockwise than q, then update q.
        '''
        q = (p + 1) % n
        for i in range(n):
            # If i is more counterclockwise than current q, then update q
            if orientation(points[p], points[i], points[q]) == 2:
                q = i

        '''
        Now q is the most counterclockwise with respect to p.
        Set p as q for the next iteration, so that q is added to the result 'hull'.
        '''
        p = q

        # Stop when we come back to the first point
        if p == l:
            break

    return hull
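

# Minimal usage sketch: build the corner localizer and run it on a single image. Only
# `args.width` and `args.height` are required by the code above; the argparse flag names
# and the BGR -> RGB conversion below are assumptions for illustration.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("image_path")
    parser.add_argument("--width", type=int, required=True)   # assumed flag name
    parser.add_argument("--height", type=int, required=True)  # assumed flag name
    args = parser.parse_args()

    localize_corners = localize_corners_wrapper(args, debug=True)
    frame = cv2.cvtColor(cv2.imread(args.image_path), cv2.COLOR_BGR2RGB)  # assuming RGB input is expected
    result = localize_corners(frame)
    if result is None:
        print("Corner localization failed")
    else:
        rectified_frame, corner_colors = result
        print("Rectified frame shape:", rectified_frame.shape)
        print("Mean corner colors (white, red, green, blue):", corner_colors)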