import itertools
import math
import time

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms.v2 as transforms
import torchvision.transforms.v2.functional as transforms_f
from matplotlib import pyplot as plt

from corner_training.models import QuantizedV2, QuantizedV5
from corner_training.utils import get_gaussian_filter, get_bounded_slices

torch.backends.quantized.engine = 'qnnpack'


def localize_corners_wrapper(args, debug=False):
    """Load the two quantized corner-detection models and return a closure that
    localizes the four outer grid corners in a cropped frame."""
    stage1_model_checkpt_path = "checkpts/QuantizedV2_Stage1_128_9.pt"
    stage2_model_checkpt_path = "checkpts/QuantizedV5_Stage2_128_9.pt"

    stage1_model = QuantizedV2()
    stage1_model.eval()
    stage1_model.fuse_modules(is_qat=False)
    stage1_model.qconfig = torch.ao.quantization.default_qconfig
    torch.ao.quantization.prepare(stage1_model, inplace=True)
    torch.ao.quantization.convert(stage1_model, inplace=True)
    stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))

    stage2_model = QuantizedV5()
    stage2_model.eval()
    stage2_model.fuse_modules(is_qat=False)
    stage2_model.qconfig = torch.ao.quantization.default_qconfig
    torch.ao.quantization.prepare(stage2_model, inplace=True)
    torch.ao.quantization.convert(stage2_model, inplace=True)
    stage2_model.load_state_dict(torch.load(stage2_model_checkpt_path, map_location=torch.device('cpu')))

    stage1_model.eval()
    stage2_model.eval()

    stage1_size = 128
    assert stage1_size & 1 == 0, "Assuming even size when dividing into quadrants"

    np_to_fp32_tensor = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
    ])

    preprocess_img_stage1 = transforms.Compose([
        transforms.Lambda(lambda img: resize_keep_aspect(img, stage1_size)),
        np_to_fp32_tensor,
    ])

    gaussian_filter = get_gaussian_filter(4, 4)  # for stage1 NMS heuristic

    # Transform cropped corners until they all look like top-left corners, as that's what the model is trained on
    transforms_by_corner = [
        lambda img: img,  # identity
        transforms_f.hflip,
        transforms_f.vflip,
        lambda img: transforms_f.vflip(transforms_f.hflip(img)),
    ]
    inv_transforms_by_corner = transforms_by_corner  # flipping is a self-inverse

    def localize_corners(cropped_frame: np.ndarray):
        """
        Args:
            cropped_frame: Square numpy array
        """
        orig_h, orig_w, _ = cropped_frame.shape
        stage2_size = max(orig_h, orig_w) // 16
        upscale_factor = min(orig_w, orig_h) / stage1_size  # for stage 2

        start_time = time.time()
        stage1_img = preprocess_img_stage1(cropped_frame)
        if debug:
            print(54, time.time() - start_time)

        with torch.no_grad():
            stage1_pred = stage1_model(stage1_img.unsqueeze(0)).squeeze(0)
        # plt.imshow(stage1_pred.detach().cpu())
        # plt.show()
        if debug:
            print(57, time.time() - start_time)

        quad_h = stage1_img.size(1) // 2  # might miss 1 pixel on edge if odd
        quad_w = stage1_img.size(2) // 2

        corners_by_quad = dict()
        for top_half in (0, 1):  # TODO: bot/right to remove all 1 minuses
            for left_half in (0, 1):
                quad_i_start = quad_h * (1 - top_half)
                quad_j_start = quad_w * (1 - left_half)
                curr_quad_preds = stage1_pred[
                    quad_i_start: quad_i_start + quad_h,
                    quad_j_start: quad_j_start + quad_w,
                ].clone()

                max_locs = []
                for i in range(6):  # expect 4 points, but get top 6 to be safe
                    max_ind = torch.argmax(curr_quad_preds).item()  # TODO: more efficient like segtree, maybe account for neighbors too
                    max_loc = (max_ind // quad_w, max_ind % quad_w)
                    max_locs.append(max_loc)
                    # TODO: improve, maybe scale Gaussian peak to val of max_loc,
                    #  probably better to not subtract from a location multiple times
                    preds_slice, gaussian_slice = get_bounded_slices(
                        (quad_h, quad_w), gaussian_filter.size(), *max_loc)
                    curr_quad_preds[preds_slice] -= gaussian_filter[gaussian_slice]

                if debug:
                    print(f"{max_locs=}")

                min_cost = 1e9
                min_square = None
                for potential_combo in itertools.combinations(max_locs, 4):
                    curr_pts, curr_cost = score_combo(potential_combo)
                    if curr_cost < min_cost:
                        min_cost = curr_cost
                        min_square = curr_pts

                if min_square is None:
                    print("all collinear")
                    return None

                corners_by_quad[(1 - top_half) * 2 + (1 - left_half)] = [
                    (i + quad_i_start, j + quad_j_start) for (i, j) in min_square
                ]

        if debug:
            print(92, time.time() - start_time)
            print(corners_by_quad)

        outer_corners = []
        # corner_colors = []  # by center, currently rounding to the pixel in the original image
        origin = (quad_h, quad_w)
        for quad in range(4):  # TODO: consistent (x, y) or (i, j)
            outer_corners.append(max((l2_dist(corner, origin), corner) for corner in corners_by_quad[quad])[1])
            # corner_colors.append(cropped_frame[int((sum(corner[0] for corner in corners_by_quad[quad]) / 4 * upscale_factor)),
            #                                    int((sum(corner[1] for corner in corners_by_quad[quad]) / 4 * upscale_factor))]
            #                      .astype(np.float64))

        stage2_imgs = []
        for top_half in (0, 1):  # TODO: bot/right to remove all 1 minuses
            for left_half in (0, 1):
                corner_ind = top_half * 2 + left_half
                y, x = outer_corners[corner_ind]
                upscaled_y, upscaled_x = round(y * upscale_factor), round(x * upscale_factor)

                top = max(0, upscaled_y - stage2_size // 2)
                bottom = min(orig_h, upscaled_y + stage2_size // 2)
                left = max(0, upscaled_x - stage2_size // 2)
                right = min(orig_w, upscaled_x + stage2_size // 2)

                # Need padding if detected corner is within `stage2_size // 2` of border
                corner_padding = [0] * 4  # pad the side that does not affect extracted coordinates
                corner_padding[(1 - top_half) * 2 + 1] = stage2_size - (bottom - top)
                corner_padding[(1 - left_half) * 2] = stage2_size - (right - left)

                cropped_corner_img = transforms_f.pad(  # TODO: don't pad since that should speed up inference
                    np_to_fp32_tensor(cropped_frame[top:bottom, left:right]),
                    corner_padding
                )
                stage2_imgs.append(cropped_corner_img)

        transformed_corner_imgs = torch.stack([transforms_by_corner[corner_ind](stage2_img)
                                               for corner_ind, stage2_img in enumerate(stage2_imgs)])

        if debug:
            print(121, time.time() - start_time)
        with torch.no_grad():
            transformed_preds = stage2_model(transformed_corner_imgs)
        if debug:
            print(125, time.time() - start_time)

        transformed_pred_pts = [
            torchvision.tv_tensors.BoundingBoxes(
                [(max_ind := pred.argmax()) % stage2_size, max_ind // stage2_size, 0, 0],
                format="XYWH", canvas_size=(stage2_size, stage2_size)
            )
            for pred in transformed_preds
        ]

        stage2_pred_pts = [inv_transforms_by_corner[corner_ind](transformed_pred_pt)[0, :2].tolist()
                           for corner_ind, transformed_pred_pt in enumerate(transformed_pred_pts)]
        if debug:
            print(137, time.time() - start_time)

        orig_pred_pts = [(round(orig_x * upscale_factor) + stage2_pred_x - stage2_size // 2,
                          round(orig_y * upscale_factor) + stage2_pred_y - stage2_size // 2)
                         for (orig_y, orig_x), (stage2_pred_x, stage2_pred_y) in zip(outer_corners, stage2_pred_pts)]

        if debug:
            print(142, time.time() - start_time)
        # plt.imshow(cropped_frame)
        # plt.scatter(np.array(orig_pred_pts).T[0], np.array(orig_pred_pts).T[1])
        # plt.show()

        corner_size = int(max(args.height, args.width) * 0.16)
        qtr_corner_size = corner_size // 4
        grid_coords = np.float32([
            [qtr_corner_size, qtr_corner_size],
            [args.width - qtr_corner_size, qtr_corner_size],
            [qtr_corner_size, args.height - qtr_corner_size],
            [args.width - qtr_corner_size, args.height - qtr_corner_size],
        ])
        grid_coords -= 1 / 2
        M = cv2.getPerspectiveTransform(
            np.float32(orig_pred_pts),
            grid_coords,
        )
        cropped_frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
        # cropped_frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height), flags=cv2.INTER_NEAREST)

        padding = math.ceil(max(args.height, args.width) / 80)  # arbitrary
        # guessing wildly on +/- 1s
        white_sq = cropped_frame[qtr_corner_size + padding: corner_size - qtr_corner_size - padding,
                                 qtr_corner_size + padding: corner_size - qtr_corner_size - padding]
        red_sq = cropped_frame[qtr_corner_size + padding: corner_size - qtr_corner_size - padding,
                               args.width - corner_size + qtr_corner_size + padding: args.width - qtr_corner_size - padding]
        green_sq = cropped_frame[args.height - corner_size + qtr_corner_size + padding: args.height - qtr_corner_size - padding,
                                 qtr_corner_size + padding: corner_size - qtr_corner_size - padding]
        blue_sq = cropped_frame[args.height - corner_size + qtr_corner_size + padding: args.height - qtr_corner_size - padding,
                                args.width - corner_size + qtr_corner_size + padding: args.width - qtr_corner_size - padding]

        corner_colors = [white_sq.mean(axis=(0, 1)), red_sq.mean(axis=(0, 1)),
                         green_sq.mean(axis=(0, 1)), blue_sq.mean(axis=(0, 1))]

        # fig, axs = plt.subplots(2, 3)
        # # axs[0, 2].imshow(cropped_frame)
        # axs[0, 0].imshow(white_sq)
        # axs[0, 1].imshow(red_sq)
        # axs[1, 0].imshow(green_sq)
        # axs[1, 1].imshow(blue_sq)
        # plt.show()

        return cropped_frame, corner_colors

    return localize_corners


def l2_dist(loc, origin):
    """
    Squared Euclidean distance (no sqrt, since it is only used for comparisons).
    """
    return (loc[0] - origin[0]) ** 2 + (loc[1] - origin[1]) ** 2


def score_combo(combo):
    """
    Plan:
        1. Check if pts are convex. If not, very bad quadrilateral.
        2. Check if diagonal lengths are within a factor of 1.5. If not, somewhat bad since far from right angles.
        3. If the above are satisfied, then simply return how close the side lengths are to being equal.
    """
    hull = convex_hull([Point(x, y) for x, y in combo])  # TODO: check how collinear case is handled
    hull = [(pt.x, pt.y) for pt in hull]  # convert back to tuple
    if len(hull) != 4:
        return None, 1e9

    squared_diag0 = l2_dist(hull[0], hull[2])
    squared_diag1 = l2_dist(hull[1], hull[3])
    if squared_diag0 < squared_diag1:  # swap so that diag0 is larger
        squared_diag0, squared_diag1 = squared_diag1, squared_diag0

    if squared_diag0 / squared_diag1 > 1.5 ** 2:
        return hull, 1e8

    cyclic_pts = hull + [hull[0]]
    side_lens = [l2_dist(cyclic_pts[i], cyclic_pts[i + 1]) for i in range(4)]
    return hull, (max(side_lens) - min(side_lens)) / min(side_lens)


def resize_keep_aspect(img: np.ndarray, min_len: int) -> np.ndarray:
    """
    Resize `img` so its shorter side equals `min_len`, preserving the aspect ratio.
    """
    h, w, _ = img.shape
    if h < w:
        output_size = (round(min_len * w / h), min_len)  # cv2 dsize is (width, height)
    else:
        output_size = (min_len, round(min_len * h / w))

    return cv2.resize(img, output_size, interpolation=cv2.INTER_NEAREST)


# Gift wrapping code, adapted from GeeksForGeeks.
# "This code is contributed by Akarsh Somani, IIIT Kalyani"
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y


def left_index(points):
    """
    Find the leftmost point.
    """
    minn = 0
    for i in range(1, len(points)):
        if points[i].x < points[minn].x:
            minn = i
        elif points[i].x == points[minn].x:
            if points[i].y > points[minn].y:
                minn = i
    return minn


def orientation(p, q, r):
    """
    Find the orientation of the ordered triplet (p, q, r).
    Returns the following values:
        0 --> p, q and r are collinear
        1 --> clockwise
        2 --> counterclockwise
    """
    val = (q.y - p.y) * (r.x - q.x) - (q.x - p.x) * (r.y - q.y)

    if val == 0:
        return 0
    elif val > 0:
        return 1
    else:
        return 2


def convex_hull(points):
    n = len(points)
    assert n >= 3, "There must be at least 3 points."

    # Find the leftmost point
    l = left_index(points)

    hull = []
    '''
    Start from leftmost point, keep moving counterclockwise
    until reach the start point again. This loop runs O(h)
    times where h is number of points in result or output.
    '''
    p = l
    q = 0
    while True:
        # Add current point to result
        hull.append(points[p])

        '''
        Search for a point 'q' such that orientation(p, q, x)
        is counterclockwise for all points 'x'. The idea is to
        keep track of the last visited most counterclockwise
        point in q. If any point 'i' is more counterclockwise
        than q, then update q.
        '''
        q = (p + 1) % n

        for i in range(n):
            # If i is more counterclockwise
            # than current q, then update q
            if orientation(points[p], points[i], points[q]) == 2:
                q = i

        '''
        Now q is the most counterclockwise with respect to p.
        Set p as q for the next iteration, so that q is added
        to the result 'hull'.
        '''
        p = q

        # While we don't come to first point
        if p == l:
            break

    return hull
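

# Minimal usage sketch (an addition, not part of the original pipeline): shows one way the
# closure returned by localize_corners_wrapper might be driven from a script. The image path
# and output dimensions below are hypothetical placeholders; `args` only needs the `height`
# and `width` fields read inside localize_corners, and the checkpoint files referenced above
# must exist on disk.
if __name__ == "__main__":
    from types import SimpleNamespace

    example_args = SimpleNamespace(height=720, width=720)  # hypothetical warp output size
    localize = localize_corners_wrapper(example_args, debug=True)

    frame = cv2.imread("example_frame.png")  # hypothetical square crop containing the grid
    if frame is not None:
        result = localize(frame)
        if result is not None:
            warped_frame, corner_colors = result
            print("mean corner colors:", corner_colors)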