Compare commits
3 commits
9205312ad5
...
3a40a8fb03
Author | SHA1 | Date | |
---|---|---|---|
|
3a40a8fb03 | ||
|
9d9c891d74 | ||
7a2c685c15 |
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -158,3 +158,4 @@ cython_debug/
|
|||
Report_Stuff/
|
||||
data/
|
||||
*.mkv
|
||||
*.slurm
|
||||
|
|
BIN
checkpts/QuantizedV2_Stage1_128_9.pt
Normal file
BIN
checkpts/QuantizedV2_Stage1_128_9.pt
Normal file
Binary file not shown.
BIN
checkpts/QuantizedV5_Stage2_128_9.pt
Normal file
BIN
checkpts/QuantizedV5_Stage2_128_9.pt
Normal file
Binary file not shown.
171
decoder.py
171
decoder.py
|
@ -1,10 +1,14 @@
|
|||
import argparse
|
||||
import traceback
|
||||
import time
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from creedsolo import RSCodec
|
||||
from raptorq import Decoder
|
||||
|
||||
from corner_training.models import QuantizedV2, QuantizedV5
|
||||
from decoding_utils import localize_corners_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
|
||||
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
|
||||
|
@ -13,9 +17,19 @@ parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
|
|||
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
|
||||
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
|
||||
parser.add_argument("-p", "--psize", help="packet size", type=int)
|
||||
parser.add_argument("-v", "--version", help="0 - original; 1 - CNN", default=0, choices=[0, 1], type=int)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
cheight = cwidth = max(args.height // 10, args.width // 10)
|
||||
if args.version == 0:
|
||||
cheight = cwidth = max(args.height // 10, args.width // 10)
|
||||
elif args.version == 1:
|
||||
assert args.height * 3 % 80 == args.width * 3 % 80 == 0
|
||||
cheight = int(args.height * 0.15)
|
||||
cwidth = int(args.width * 0.15)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
frame_size = args.height * args.width - 4 * cheight * cwidth
|
||||
frame_bytes = frame_size * 3 // 8
|
||||
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
|
||||
|
@ -24,63 +38,41 @@ rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4
|
|||
rsc = RSCodec(int(args.level * 255))
|
||||
decoder = Decoder.with_defaults(args.size, rs_bytes)
|
||||
|
||||
input_crop_size = 1024
|
||||
|
||||
def find_corner(A, f):
|
||||
cx, cy = A.shape[:2]
|
||||
# Resize so smaller dim is 8
|
||||
scale = min(cx // 8, cy // 8)
|
||||
B = cv2.resize(A, (cy // scale, cx // scale), interpolation=cv2.INTER_AREA)
|
||||
guess = np.array(np.unravel_index(np.argmax(f(B.astype(np.float64))), B.shape[:2])) * scale + scale // 2
|
||||
mask = cv2.floodFill(
|
||||
A,
|
||||
np.empty(0),
|
||||
tuple(np.flip(guess)),
|
||||
0,
|
||||
(100, 100, 100),
|
||||
(100, 100, 100),
|
||||
cv2.FLOODFILL_MASK_ONLY + cv2.FLOODFILL_FIXED_RANGE,
|
||||
)[2][1:-1, 1:-1].astype(bool)
|
||||
return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)
|
||||
if args.version == 0:
|
||||
def find_corner(A, f):
|
||||
cx, cy = A.shape[:2]
|
||||
# Resize so smaller dim is 8
|
||||
scale = min(cx // 8, cy // 8)
|
||||
B = cv2.resize(A, (cy // scale, cx // scale), interpolation=cv2.INTER_AREA)
|
||||
guess = np.array(np.unravel_index(np.argmax(f(B.astype(np.float64))), B.shape[:2])) * scale + scale // 2
|
||||
mask = cv2.floodFill(
|
||||
A,
|
||||
np.empty(0),
|
||||
tuple(np.flip(guess)),
|
||||
0,
|
||||
(100, 100, 100),
|
||||
(100, 100, 100),
|
||||
cv2.FLOODFILL_MASK_ONLY + cv2.FLOODFILL_FIXED_RANGE,
|
||||
)[2][1:-1, 1:-1].astype(bool)
|
||||
return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)
|
||||
|
||||
|
||||
if args.input.isdecimal():
|
||||
args.input = int(args.input)
|
||||
cap = cv2.VideoCapture(args.input)
|
||||
data = None
|
||||
while data is None:
|
||||
try:
|
||||
ret, raw_frame = cap.read()
|
||||
if not ret:
|
||||
print("End of stream")
|
||||
break
|
||||
# raw_frame is a uint8 BE CAREFUL
|
||||
if type(args.input) == int:
|
||||
# Crop image to reduce camera distortion
|
||||
X, Y = raw_frame.shape[:2]
|
||||
raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
|
||||
cv2.imshow("", raw_frame)
|
||||
cv2.waitKey(1)
|
||||
raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Find positions and colors of corners
|
||||
X, Y = raw_frame.shape[:2]
|
||||
def localize_corners(cropped_frame):
|
||||
"""
|
||||
Returns (reconstructed grid, (wcol, rcol, gcol, bcol))
|
||||
"""
|
||||
X, Y = cropped_frame.shape[:2]
|
||||
cx, cy = X // 3, Y // 3
|
||||
widx, wcol = find_corner(raw_frame[:cx, :cy], lambda B: np.sum(B, axis=2) - np.std(B, axis=2))
|
||||
ridx, rcol = find_corner(raw_frame[:cx, Y - cy :], lambda B: B[:, :, 0] - B[:, :, 1] - B[:, :, 2])
|
||||
widx, wcol = find_corner(cropped_frame[:cx, :cy], lambda B: np.sum(B, axis=2) - 2 * np.std(B, axis=2))
|
||||
ridx, rcol = find_corner(cropped_frame[:cx, Y - cy:], lambda B: B[:, :, 0] - B[:, :, 1] - B[:, :, 2])
|
||||
ridx[1] += Y - cy
|
||||
gidx, gcol = find_corner(raw_frame[X - cx :, :cy], lambda B: B[:, :, 1] - B[:, :, 2] - B[:, :, 0])
|
||||
gidx, gcol = find_corner(cropped_frame[X - cx:, :cy], lambda B: B[:, :, 1] - B[:, :, 2] - B[:, :, 0])
|
||||
gidx[0] += X - cx
|
||||
bidx, bcol = find_corner(raw_frame[X - cx :, Y - cy :], lambda B: B[:, :, 2] - B[:, :, 0] - B[:, :, 1])
|
||||
bidx, bcol = find_corner(cropped_frame[X - cx:, Y - cy:], lambda B: B[:, :, 2] - B[:, :, 0] - B[:, :, 1])
|
||||
bidx[0] += X - cx
|
||||
bidx[1] += Y - cy
|
||||
|
||||
# Find basis of color space
|
||||
origin = (rcol + gcol + bcol - wcol) / 2
|
||||
rcol -= origin
|
||||
gcol -= origin
|
||||
bcol -= origin
|
||||
F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
|
||||
|
||||
cch = cheight / 2 - 1
|
||||
ccw = cwidth / 2 - 1
|
||||
M = cv2.getPerspectiveTransform(
|
||||
|
@ -89,30 +81,77 @@ while data is None:
|
|||
[
|
||||
[ccw, cch],
|
||||
[args.width - ccw - 1, cch],
|
||||
[ccw, args.height - ccw - 1],
|
||||
[ccw, args.height - cch - 1],
|
||||
[args.width - ccw - 1, args.height - cch - 1],
|
||||
]
|
||||
),
|
||||
)
|
||||
frame = cv2.warpPerspective(raw_frame, M, (args.width, args.height))
|
||||
|
||||
frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
|
||||
return frame, (wcol, rcol, gcol, bcol)
|
||||
|
||||
elif args.version == 1:
|
||||
localize_corners = localize_corners_wrapper(args, input_crop_size)
|
||||
|
||||
if args.input.isdecimal():
|
||||
args.input = int(args.input)
|
||||
cap = cv2.VideoCapture(args.input)
|
||||
data = None
|
||||
start_time = time.time()
|
||||
while data is None:
|
||||
try:
|
||||
ret, raw_frame = cap.read()
|
||||
if not ret:
|
||||
print("End of stream")
|
||||
break
|
||||
|
||||
if args.version == 0:
|
||||
# raw_frame is a uint8 BE CAREFUL
|
||||
if type(args.input) == int:
|
||||
# Crop image to reduce camera distortion
|
||||
X, Y = raw_frame.shape[:2]
|
||||
raw_frame = raw_frame[X // 4: 3 * X // 4, Y // 4: 3 * Y // 4]
|
||||
elif args.version == 1:
|
||||
h, w, _ = raw_frame.shape
|
||||
raw_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
|
||||
(w - input_crop_size) // 2:-(w - input_crop_size) // 2]
|
||||
|
||||
cv2.imshow("", raw_frame)
|
||||
cv2.waitKey(1)
|
||||
raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
frame, (wcol, rcol, gcol, bcol) = localize_corners(raw_frame)
|
||||
|
||||
# Find basis of color space
|
||||
origin = (rcol + gcol + bcol - wcol) / 2
|
||||
rcol -= origin
|
||||
gcol -= origin
|
||||
bcol -= origin
|
||||
F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
|
||||
|
||||
# Convert to new color space
|
||||
frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.imshow(frame * 255)
|
||||
# plt.show()
|
||||
frame = np.concatenate(
|
||||
(
|
||||
frame[:cheight, cwidth : args.width - cwidth].flatten(),
|
||||
frame[cheight : args.height - cheight].flatten(),
|
||||
frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
|
||||
frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
|
||||
import matplotlib.pyplot as plt
|
||||
plt.imshow(frame * 255)
|
||||
plt.show()
|
||||
frame = np.packbits(
|
||||
np.concatenate(
|
||||
(
|
||||
frame[:cheight, cwidth: args.width - cwidth].flatten(),
|
||||
frame[cheight: args.height - cheight].flatten(),
|
||||
frame[args.height - cheight:, cwidth: args.width - cwidth].flatten(),
|
||||
)
|
||||
)
|
||||
)
|
||||
data = decoder.decode(bytes(rsc.decode(bytearray(np.packbits(frame) ^ frame_xor))[0][: args.psize]))
|
||||
reshape_len = frame_bytes // 255 * 255
|
||||
frame[:reshape_len] = np.ravel(frame[:reshape_len].reshape(255, reshape_len // 255), "F")
|
||||
data = decoder.decode(bytes(rsc.decode(bytearray(frame ^ frame_xor))[0][: args.psize]))
|
||||
print("Decoded frame")
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except:
|
||||
traceback.print_exc()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
cap.release()
|
||||
with open(args.output, "wb") as f:
|
||||
f.write(data)
|
||||
cap.release()
|
||||
print(8 * len(data) / (time.time() - start_time))
|
||||
|
|
151
decoder_cnn.py
151
decoder_cnn.py
|
@ -1,151 +0,0 @@
|
|||
import argparse
|
||||
import traceback
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from creedsolo import RSCodec
|
||||
from matplotlib import pyplot as plt
|
||||
from raptorq import Decoder
|
||||
|
||||
from corner_training.models import QuantizedV2, QuantizedV5
|
||||
from decoding_utils import localize_corners_wrapper
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
|
||||
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
|
||||
parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
|
||||
parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
|
||||
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
|
||||
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
|
||||
parser.add_argument("-p", "--psize", help="packet size", type=int)
|
||||
parser.add_argument("-v", "--version",
|
||||
help="0: 10% corners w/ two-sided one-cell padding; 1: 15% corners w/ four-sided 25% padding.",
|
||||
default=0, choices=[0, 1], type=int)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.version == 1
|
||||
# cell borders are 0.0375% of width/height
|
||||
assert args.height * 3 % 80 == args.width * 3 % 80 == 0
|
||||
cheight = int(args.height * 0.15)
|
||||
cwidth = int(args.width * 0.15)
|
||||
|
||||
frame_size = args.height * args.width - 4 * cheight * cwidth
|
||||
frame_bytes = frame_size * 3 // 8
|
||||
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
|
||||
rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4
|
||||
|
||||
rsc = RSCodec(int(args.level * 255))
|
||||
decoder = Decoder.with_defaults(args.size, rs_bytes)
|
||||
|
||||
|
||||
stage1_model_checkpt_path = "/Users/kevinzhao/Downloads/QuantizedV0_Stage1_128_9.pt"
|
||||
|
||||
stage1_model = QuantizedV2()
|
||||
stage1_model.eval()
|
||||
stage1_model.fuse_modules(is_qat=False)
|
||||
stage1_model.qconfig = torch.ao.quantization.default_qconfig
|
||||
torch.ao.quantization.prepare(stage1_model, inplace=True)
|
||||
torch.ao.quantization.convert(stage1_model, inplace=True)
|
||||
stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))
|
||||
|
||||
stage2_model = QuantizedV5()
|
||||
stage2_model.eval()
|
||||
stage2_model.fuse_modules(is_qat=False)
|
||||
stage2_model.qconfig = torch.ao.quantization.default_qconfig
|
||||
torch.ao.quantization.prepare(stage2_model, inplace=True)
|
||||
torch.ao.quantization.convert(stage2_model, inplace=True)
|
||||
stage2_model.load_state_dict(torch.load("/Users/kevinzhao/Downloads/QuantizedV5_Stage2_128_9.pt", map_location=torch.device('cpu')))
|
||||
|
||||
# stage1_size = 128
|
||||
# stage2_size = 128
|
||||
|
||||
stage1_size = 128
|
||||
stage2_size = 64
|
||||
|
||||
input_crop_size = 1024
|
||||
|
||||
localize_corners = localize_corners_wrapper(stage1_model, stage2_model, stage1_size, stage2_size)
|
||||
|
||||
if args.input.isdecimal():
|
||||
args.input = int(args.input)
|
||||
cap = cv2.VideoCapture(args.input)
|
||||
data = None
|
||||
while data is None:
|
||||
try:
|
||||
ret, raw_frame = cap.read()
|
||||
if not ret:
|
||||
print("End of stream")
|
||||
break
|
||||
# # raw_frame is a uint8 BE CAREFUL
|
||||
# if type(args.input) == int:
|
||||
# # Crop image to reduce camera distortion
|
||||
# X, Y = raw_frame.shape[:2]
|
||||
# raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
|
||||
# cv2.imshow("", raw_frame)
|
||||
# cv2.waitKey(1)
|
||||
# raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
h, w, _ = raw_frame.shape
|
||||
cropped_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
|
||||
(w - input_crop_size) // 2:-(w - input_crop_size) // 2]
|
||||
cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
|
||||
(widx, ridx, gidx, bidx), (wcol, rcol, gcol, bcol) = localize_corners(cropped_frame)
|
||||
|
||||
widx = widx[::-1]
|
||||
ridx = ridx[::-1]
|
||||
gidx = gidx[::-1]
|
||||
bidx = bidx[::-1]
|
||||
# plt.imshow(cropped_frame)
|
||||
# plt.scatter([widx[1]], [widx[0]], color="r")
|
||||
# plt.scatter([ridx[1]], [ridx[0]], color="g")
|
||||
# plt.scatter([gidx[1]], [gidx[0]], color="b")
|
||||
# plt.scatter([bidx[1]], [bidx[0]], color="w")
|
||||
# plt.show()
|
||||
|
||||
# Find basis of color space
|
||||
origin = (rcol + gcol + bcol - wcol) / 2
|
||||
rcol -= origin
|
||||
gcol -= origin
|
||||
bcol -= origin
|
||||
F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
|
||||
|
||||
# cch = cheight / 2 - 1
|
||||
# ccw = cwidth / 2 - 1
|
||||
cch = cheight / 4 - 1
|
||||
ccw = cwidth / 4 - 1
|
||||
|
||||
M = cv2.getPerspectiveTransform(
|
||||
np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
|
||||
np.float32(
|
||||
[
|
||||
[ccw, cch],
|
||||
[args.width - ccw - 1, cch],
|
||||
[ccw, args.height - cch - 1],
|
||||
[args.width - ccw - 1, args.height - cch - 1],
|
||||
]
|
||||
),
|
||||
)
|
||||
frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
|
||||
# # Convert to new color space
|
||||
# frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
|
||||
frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
|
||||
# import matplotlib.pyplot as plt
|
||||
# # plt.imshow(frame * 255)
|
||||
# plt.imshow((1 - frame) * 255)
|
||||
# plt.show()
|
||||
frame = np.concatenate(
|
||||
(
|
||||
frame[:cheight, cwidth : args.width - cwidth].flatten(),
|
||||
frame[cheight : args.height - cheight].flatten(),
|
||||
frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
|
||||
)
|
||||
)
|
||||
data = decoder.decode(bytes(rsc.decode(bytearray(np.packbits(frame) ^ frame_xor))[0][: args.psize]))
|
||||
print("Decoded frame")
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except:
|
||||
traceback.print_exc()
|
||||
with open(args.output, "wb") as f:
|
||||
f.write(data)
|
||||
cap.release()
|
341
decoding_utils.py
Normal file
341
decoding_utils.py
Normal file
|
@ -0,0 +1,341 @@
|
|||
import itertools
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms.v2 as transforms
|
||||
import torchvision.transforms.v2.functional as transforms_f
|
||||
|
||||
from corner_training.models import QuantizedV2, QuantizedV5
|
||||
from corner_training.utils import get_gaussian_filter, get_bounded_slices
|
||||
|
||||
|
||||
torch.backends.quantized.engine = 'qnnpack'
|
||||
|
||||
|
||||
def localize_corners_wrapper(args, input_crop_size, debug=False):
|
||||
stage1_model_checkpt_path = "checkpts/QuantizedV2_Stage1_128_9.pt"
|
||||
stage2_model_checkpt_path = "checkpts/QuantizedV5_Stage2_128_9.pt"
|
||||
|
||||
stage1_model = QuantizedV2()
|
||||
stage1_model.eval()
|
||||
stage1_model.fuse_modules(is_qat=False)
|
||||
stage1_model.qconfig = torch.ao.quantization.default_qconfig
|
||||
torch.ao.quantization.prepare(stage1_model, inplace=True)
|
||||
torch.ao.quantization.convert(stage1_model, inplace=True)
|
||||
stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))
|
||||
|
||||
stage2_model = QuantizedV5()
|
||||
stage2_model.eval()
|
||||
stage2_model.fuse_modules(is_qat=False)
|
||||
stage2_model.qconfig = torch.ao.quantization.default_qconfig
|
||||
torch.ao.quantization.prepare(stage2_model, inplace=True)
|
||||
torch.ao.quantization.convert(stage2_model, inplace=True)
|
||||
stage2_model.load_state_dict(torch.load(stage2_model_checkpt_path, map_location=torch.device('cpu')))
|
||||
|
||||
stage1_size = 128
|
||||
stage2_size = input_crop_size // 16
|
||||
|
||||
assert stage1_size & 1 == 0, "Assuming even size when dividing into quadrants"
|
||||
assert stage2_size & 1 == 0, "Assuming even size when center cropping"
|
||||
stage1_model.eval()
|
||||
stage2_model.eval()
|
||||
|
||||
preprocess_img_stage1 = transforms.Compose([
|
||||
transforms.Lambda(lambda img: cv2.resize(img, (stage1_size, stage1_size), interpolation=cv2.INTER_NEAREST)),
|
||||
transforms.ToImage(),
|
||||
transforms.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
|
||||
gaussian_filter = get_gaussian_filter(4, 4) # for stage1 NMS heuristic
|
||||
|
||||
preprocess_img_stage2 = transforms.Compose([
|
||||
transforms.ToImage(),
|
||||
transforms.ToDtype(torch.float32, scale=True),
|
||||
])
|
||||
|
||||
# Transform cropped corners until they all look like top left corners, as that's what the model is trained on
|
||||
transforms_by_corner = [
|
||||
lambda img: img, # identity
|
||||
transforms_f.hflip,
|
||||
transforms_f.vflip,
|
||||
lambda img: transforms_f.vflip(transforms_f.hflip(img))
|
||||
]
|
||||
|
||||
inv_transforms_by_corner = transforms_by_corner # flipping is a self-inverse
|
||||
|
||||
def localize_corners(cropped_frame: np.ndarray):
|
||||
"""
|
||||
Args:
|
||||
cropped_frame: Square numpy array
|
||||
"""
|
||||
orig_h, orig_w, _ = cropped_frame.shape
|
||||
assert orig_w == orig_h, "Assuming square img"
|
||||
assert orig_w % stage1_size == 0
|
||||
upscale_factor = orig_w // stage1_size # for stage 2
|
||||
|
||||
start_time = time.time()
|
||||
stage1_img = preprocess_img_stage1(cropped_frame)
|
||||
if debug:
|
||||
print(54, time.time() - start_time)
|
||||
|
||||
with torch.no_grad():
|
||||
stage1_pred = stage1_model(stage1_img.unsqueeze(0)).squeeze(0)
|
||||
|
||||
if debug:
|
||||
print(57, time.time() - start_time)
|
||||
|
||||
quad_size = stage1_size // 2
|
||||
|
||||
corners_by_quad = dict()
|
||||
|
||||
for top_half in (0, 1): # TODO: bot/right to remove all 1 minuses
|
||||
for left_half in (0, 1):
|
||||
quad_i_start = quad_size * (1 - top_half)
|
||||
quad_j_start = quad_size * (1 - left_half)
|
||||
curr_quad_preds = stage1_pred[
|
||||
quad_i_start: quad_i_start + quad_size,
|
||||
quad_j_start: quad_j_start + quad_size,
|
||||
].clone()
|
||||
|
||||
max_locs = []
|
||||
for i in range(6): # expect 4 points, but get top 6 to be safe
|
||||
max_ind = torch.argmax(curr_quad_preds).item() # TODO: more efficient like segtree, maybe account for neighbors too
|
||||
max_loc = (max_ind // quad_size, max_ind % quad_size)
|
||||
max_locs.append(max_loc)
|
||||
|
||||
# TODO: improve, maybe scale Gaussian peak to val of max_loc, probably better to not subtract from a location multiple times
|
||||
preds_slice, gaussian_slice = get_bounded_slices((quad_size, quad_size), gaussian_filter.size(),
|
||||
*max_loc)
|
||||
curr_quad_preds[preds_slice] -= gaussian_filter[gaussian_slice]
|
||||
|
||||
if debug:
|
||||
print(f"{max_locs=}")
|
||||
|
||||
min_cost = 1e9
|
||||
min_square = None
|
||||
for potential_combo in itertools.combinations(max_locs, 4): # TODO: don't repeat symmetrical squares
|
||||
curr_pts, curr_cost = score_combo(potential_combo)
|
||||
if curr_cost < min_cost:
|
||||
min_cost = curr_cost
|
||||
min_square = curr_pts
|
||||
|
||||
if min_square is None:
|
||||
print("all collinear")
|
||||
return None
|
||||
corners_by_quad[(1 - top_half) * 2 + (1 - left_half)] = [(i + quad_i_start, j + quad_j_start) for (i, j)
|
||||
in min_square]
|
||||
if debug:
|
||||
print(92, time.time() - start_time)
|
||||
print(corners_by_quad)
|
||||
|
||||
outer_corners = []
|
||||
corner_colors = [] # by center, currently rounding to the pixel in the original image
|
||||
origin = (quad_size, quad_size)
|
||||
for quad in range(4): # TODO: consistent (x, y) or (i, j)
|
||||
outer_corners.append(max((l2_dist(corner, origin), corner) for corner in corners_by_quad[quad])[1])
|
||||
corner_colors.append(cropped_frame[int((sum(corner[0] for corner in corners_by_quad[quad]) / 4 * upscale_factor)),
|
||||
int((sum(corner[1] for corner in corners_by_quad[quad]) / 4 * upscale_factor))]
|
||||
.astype(np.float64))
|
||||
|
||||
stage2_imgs = []
|
||||
|
||||
for top_half in (0, 1): # TODO: bot/right to remove all 1 minuses
|
||||
for left_half in (0, 1):
|
||||
corner_ind = top_half * 2 + left_half
|
||||
y, x = outer_corners[corner_ind]
|
||||
upscaled_y, upscaled_x = y * upscale_factor, x * upscale_factor
|
||||
|
||||
top = max(0, upscaled_y - stage2_size // 2)
|
||||
bottom = min(orig_h, upscaled_y + stage2_size // 2)
|
||||
left = max(0, upscaled_x - stage2_size // 2)
|
||||
right = min(orig_w, upscaled_x + stage2_size // 2)
|
||||
|
||||
# Need padding if detected corner is within `stage2_size // 2` of border
|
||||
corner_padding = [0] * 4 # pad the side that does not affect extracted coordinates
|
||||
corner_padding[(1 - top_half) * 2 + 1] = stage2_size - (bottom - top)
|
||||
corner_padding[(1 - left_half) * 2] = stage2_size - (right - left)
|
||||
cropped_corner_img = transforms_f.pad( # TODO: don't pad since that should speed up inference
|
||||
preprocess_img_stage2(cropped_frame[top:bottom, left:right]),
|
||||
corner_padding
|
||||
)
|
||||
stage2_imgs.append(cropped_corner_img)
|
||||
|
||||
transformed_corner_imgs = torch.stack([transforms_by_corner[corner_ind](stage2_img)
|
||||
for corner_ind, stage2_img in enumerate(stage2_imgs)])
|
||||
|
||||
if debug:
|
||||
print(121, time.time() - start_time)
|
||||
|
||||
with torch.no_grad():
|
||||
transformed_preds = stage2_model(transformed_corner_imgs)
|
||||
|
||||
if debug:
|
||||
print(125, time.time() - start_time)
|
||||
|
||||
transformed_pred_pts = [
|
||||
torchvision.tv_tensors.BoundingBoxes(
|
||||
[(max_ind := pred.argmax()) % stage2_size, max_ind // stage2_size, 0, 0],
|
||||
format="XYWH", canvas_size=(stage2_size, stage2_size)
|
||||
)
|
||||
for pred in transformed_preds
|
||||
]
|
||||
|
||||
stage2_pred_pts = [inv_transforms_by_corner[corner_ind](transformed_pred_pt)[0, :2].tolist()
|
||||
for corner_ind, transformed_pred_pt in enumerate(transformed_pred_pts)]
|
||||
|
||||
if debug:
|
||||
print(137, time.time() - start_time)
|
||||
|
||||
orig_pred_pts = [(orig_x * upscale_factor + stage2_pred_x - stage2_size // 2,
|
||||
orig_y * upscale_factor + stage2_pred_y - stage2_size // 2)
|
||||
for (orig_y, orig_x), (stage2_pred_x, stage2_pred_y) in zip(outer_corners, stage2_pred_pts)]
|
||||
|
||||
if debug:
|
||||
print(142, time.time() - start_time)
|
||||
|
||||
cch = int(args.height * 0.15) / 4 - 1
|
||||
ccw = int(args.width * 0.15) / 4 - 1
|
||||
|
||||
M = cv2.getPerspectiveTransform(
|
||||
np.float32(orig_pred_pts),
|
||||
np.float32(
|
||||
[
|
||||
[ccw, cch],
|
||||
[args.width - ccw - 1, cch],
|
||||
[ccw, args.height - cch - 1],
|
||||
[args.width - ccw - 1, args.height - cch - 1],
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
cropped_frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
|
||||
|
||||
return cropped_frame, corner_colors
|
||||
|
||||
return localize_corners
|
||||
|
||||
|
||||
def l2_dist(loc, origin):
|
||||
""" No sqrt """
|
||||
return (loc[0] - origin[0]) ** 2 + (loc[1] - origin[1]) ** 2
|
||||
|
||||
|
||||
def score_combo(combo):
|
||||
"""
|
||||
Plan:
|
||||
1. Check if pts are convex. If no, very bad quadrilateral.
|
||||
2. Check if diagonal lengths are within a factor of 1.5. If no, somewhat bad since far from right angles.
|
||||
3. If the above are satisfied, then simply return how close the side lengths are to being equal.
|
||||
"""
|
||||
hull = convex_hull([Point(x, y) for x, y in combo]) # TODO: check how collinear case is handled
|
||||
hull = [(pt.x, pt.y) for pt in hull] # convert back to tuple
|
||||
if len(hull) != 4:
|
||||
return None, 1e9
|
||||
|
||||
squared_diag0 = l2_dist(hull[0], hull[2])
|
||||
squared_diag1 = l2_dist(hull[1], hull[3])
|
||||
if squared_diag0 < squared_diag1: # swap so that diag0 is larger
|
||||
squared_diag0, squared_diag1 = squared_diag1, squared_diag0
|
||||
|
||||
if squared_diag0 / squared_diag1 > 1.5**2:
|
||||
return hull, 1e8
|
||||
|
||||
cyclic_pts = hull + [hull[0]]
|
||||
side_lens = [l2_dist(cyclic_pts[i], cyclic_pts[i + 1]) for i in range(4)]
|
||||
|
||||
return hull, (max(side_lens) - min(side_lens)) / min(side_lens)
|
||||
|
||||
|
||||
# Gift wrapping code, adapted from GeeksForGeeks.
|
||||
# "This code is contributed by Akarsh Somani, IIIT Kalyani"
|
||||
class Point:
|
||||
def __init__(self, x, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
|
||||
|
||||
def left_index(points):
|
||||
"""
|
||||
Finding the left most point
|
||||
"""
|
||||
minn = 0
|
||||
for i in range(1,len(points)):
|
||||
if points[i].x < points[minn].x:
|
||||
minn = i
|
||||
elif points[i].x == points[minn].x:
|
||||
if points[i].y > points[minn].y:
|
||||
minn = i
|
||||
return minn
|
||||
|
||||
|
||||
def orientation(p, q, r):
|
||||
"""
|
||||
To find orientation of ordered triplet (p, q, r).
|
||||
The function returns following values
|
||||
0 --> p, q and r are collinear
|
||||
1 --> Clockwise
|
||||
2 --> Counterclockwise
|
||||
"""
|
||||
val = (q.y - p.y) * (r.x - q.x) - \
|
||||
(q.x - p.x) * (r.y - q.y)
|
||||
|
||||
if val == 0:
|
||||
return 0
|
||||
elif val > 0:
|
||||
return 1
|
||||
else:
|
||||
return 2
|
||||
|
||||
|
||||
def convex_hull(points):
|
||||
n = len(points)
|
||||
assert n >= 3, "There must be at least 3 points."
|
||||
|
||||
# Find the leftmost point
|
||||
l = left_index(points)
|
||||
|
||||
hull = []
|
||||
|
||||
'''
|
||||
Start from leftmost point, keep moving counterclockwise
|
||||
until reach the start point again. This loop runs O(h)
|
||||
times where h is number of points in result or output.
|
||||
'''
|
||||
p = l
|
||||
q = 0
|
||||
while True:
|
||||
# Add current point to result
|
||||
hull.append(points[p])
|
||||
|
||||
'''
|
||||
Search for a point 'q' such that orientation(p, q,
|
||||
x) is counterclockwise for all points 'x'. The idea
|
||||
is to keep track of last visited most counterclock-
|
||||
wise point in q. If any point 'i' is more counterclock-
|
||||
wise than q, then update q.
|
||||
'''
|
||||
q = (p + 1) % n
|
||||
|
||||
for i in range(n):
|
||||
# If i is more counterclockwise
|
||||
# than current q, then update q
|
||||
if(orientation(points[p],
|
||||
points[i], points[q]) == 2):
|
||||
q = i
|
||||
|
||||
'''
|
||||
Now q is the most counterclockwise with respect to p
|
||||
Set p as q for next iteration, so that q is added to
|
||||
result 'hull'
|
||||
'''
|
||||
p = q
|
||||
|
||||
# While we don't come to first point
|
||||
if p == l:
|
||||
break
|
||||
|
||||
return hull
|
|
@ -67,7 +67,11 @@ print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(
|
|||
|
||||
def mkframe(packet):
|
||||
frame = np.array(rsc.encode(bytearray(packet)))
|
||||
frame = np.unpackbits(np.pad(frame, (0, frame_bytes - len(frame))) ^ frame_xor)
|
||||
frame = np.pad(frame, (0, frame_bytes - len(frame))) ^ frame_xor
|
||||
reshape_len = frame_bytes // 255 * 255
|
||||
# Space out elements in each size 255 chunk
|
||||
frame[:reshape_len] = np.ravel(frame[:reshape_len].reshape(reshape_len // 255, 255), "F")
|
||||
frame = np.unpackbits(frame)
|
||||
# Pad to be multiple of 3 so we can reshape into RGB channels
|
||||
frame = np.pad(frame, (0, (3 - len(frame)) % 3))
|
||||
frame = np.reshape(frame, (frame_size, 3))
|
||||
|
|
Loading…
Reference in a new issue