191 lines
7.2 KiB
Python
191 lines
7.2 KiB
Python
import argparse
|
|
import time
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
from creedsolo import RSCodec
|
|
from raptorq import Decoder
|
|
|
|
from decoding_utils import localize_corners_wrapper
|
|
|
|
# Command-line interface of the decoder script.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# The default is the *string* "0" (not the int 0): the script later calls
# str.isdecimal() on this value to tell a camera index from a file path, and
# an int default would raise AttributeError there.
parser.add_argument("-i", "--input", help="camera device index or input video file", default="0")
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
parser.add_argument("-p", "--psize", help="packet size", type=int)
parser.add_argument("-v", "--version", help="0 - original; 1 - CNN", default=0, choices=[0, 1], type=int)
|
|
|
|
args = parser.parse_args()

# --size has no default, but the fountain decoder below cannot be built
# without it; fail with a normal usage error instead of an opaque exception
# from Decoder.with_defaults(None, ...).
if args.size is None:
    parser.error("the following arguments are required: -s/--size")

# Side length (in grid cells) of the square block reserved in each of the
# four grid corners; the two versions size it differently.
if args.version == 0:
    cheight = cwidth = max(args.height // 10, args.width // 10)
elif args.version == 1:
    cheight = cwidth = int(max(args.height, args.width) * 0.16)
else:
    raise NotImplementedError

# Number of grid cells left once the four corner blocks are removed.
frame_size = args.height * args.width - 4 * cheight * cwidth
# Each cell contributes 3 bits (one per colour channel), packed into bytes.
frame_bytes = frame_size * 3 // 8
# Byte mask XOR-ed with every received frame before error correction.
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
# Bytes left per frame after removing int(level * 255) Reed-Solomon symbols
# for every (rounded-up) 255-byte block, minus 4 further bytes.
# NOTE(review): the purpose of the extra 4 bytes is not visible here —
# presumably per-packet overhead; confirm against the encoder.
rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4

rsc = RSCodec(int(args.level * 255))
decoder = Decoder.with_defaults(args.size, rs_bytes)
|
|
|
|
if args.version == 0:

    def find_corner(A, f):
        """Locate one coloured corner patch inside an image region.

        Args:
            A: image region (rows x cols x channels) searched for the patch.
            f: callable that maps the downscaled float64 region to a
                per-pixel score; the highest-scoring pixel seeds a flood fill.

        Returns:
            A tuple ``(centroid, colour)``: the mean (row, col) position of
            the flood-filled pixels in ``A`` coordinates and their mean
            colour as a float64 array.
        """
        cx, cy = A.shape[:2]
        # Downscale so the smaller dimension is about 8 pixels. The scale is
        # clamped to 1 so a region smaller than 8 pixels does not yield a
        # zero scale and a division by zero below.
        scale = max(1, min(cx // 8, cy // 8))
        B = cv2.resize(A, (cy // scale, cx // scale), interpolation=cv2.INTER_AREA)
        # Best-scoring coarse pixel, mapped back to the middle of its cell in A.
        guess = np.array(np.unravel_index(np.argmax(f(B.astype(np.float64))), B.shape[:2])) * scale + scale // 2
        # Flood fill from the guess; the seed is given as (x, y), hence the
        # flip of (row, col). Only the mask result is used, with its outer
        # one-pixel border stripped so that it lines up with A.
        mask = cv2.floodFill(
            A,
            np.empty(0),
            tuple(np.flip(guess)),
            0,
            (100, 100, 100),
            (100, 100, 100),
            cv2.FLOODFILL_MASK_ONLY + cv2.FLOODFILL_FIXED_RANGE,
        )[2][1:-1, 1:-1].astype(bool)
        return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)

    def localize_corners(cropped_frame):
        """
        Returns (reconstructed grid, (wcol, rcol, gcol, bcol))

        The four corner patches are searched for in the outer thirds of
        ``cropped_frame``: white top-left, red top-right, green bottom-left
        and blue bottom-right. Their centroids define a perspective
        transform that warps the frame onto an ``args.width`` x
        ``args.height`` grid; the mean colour measured for each patch is
        returned alongside the warped frame.
        """
        X, Y = cropped_frame.shape[:2]
        cx, cy = X // 3, Y // 3
        # White: bright in every channel with little spread between channels.
        widx, wcol = find_corner(cropped_frame[:cx, :cy], lambda B: np.sum(B, axis=2) - 2 * np.std(B, axis=2))
        # Red / green / blue: the wanted channel minus the other two.
        ridx, rcol = find_corner(cropped_frame[:cx, Y - cy:], lambda B: B[:, :, 0] - B[:, :, 1] - B[:, :, 2])
        # Centroids are relative to the searched slice; shift them back into
        # full-frame coordinates.
        ridx[1] += Y - cy
        gidx, gcol = find_corner(cropped_frame[X - cx:, :cy], lambda B: B[:, :, 1] - B[:, :, 2] - B[:, :, 0])
        gidx[0] += X - cx
        bidx, bcol = find_corner(cropped_frame[X - cx:, Y - cy:], lambda B: B[:, :, 2] - B[:, :, 0] - B[:, :, 1])
        bidx[0] += X - cx
        bidx[1] += Y - cy

        # Target positions of the patch centres in grid coordinates.
        cch = cheight / 2 - 1
        ccw = cwidth / 2 - 1
        # OpenCV points are (x, y), so the (row, col) centroids are flipped.
        M = cv2.getPerspectiveTransform(
            np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
            np.float32(
                [
                    [ccw, cch],
                    [args.width - ccw - 1, cch],
                    [ccw, args.height - cch - 1],
                    [args.width - ccw - 1, args.height - cch - 1],
                ]
            ),
        )

        frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
        return frame, (wcol, rcol, gcol, bcol)

elif args.version == 1:
    # CNN-based localisation supplied by decoding_utils.
    localize_corners = localize_corners_wrapper(args)
|
|
|
|
# ####
|
|
# gtruth_frames = []
|
|
# cap = cv2.VideoCapture("vid_mid_v1.mkv")
|
|
# data = None
|
|
# while data is None:
|
|
# ret, raw_frame = cap.read()
|
|
# if not ret:
|
|
# print("End of stream")
|
|
# break
|
|
# gtruth_frames.append(cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB))
|
|
# ####
|
|
|
|
# A purely numeric input selects a camera by index; anything else is passed
# to OpenCV unchanged. The isinstance guard keeps a non-string value from
# crashing on .isdecimal().
if isinstance(args.input, str) and args.input.isdecimal():
    args.input = int(args.input)
cap = cv2.VideoCapture(args.input)
data = None
start_time = time.time()
# Keep reading frames until the fountain decoder returns the full payload.
while data is None:
    try:
        ret, raw_frame = cap.read()
        if not ret:
            print("End of stream")
            break

        if args.version == 0:
            # raw_frame is a uint8 BE CAREFUL
            if isinstance(args.input, int):
                # Crop image to reduce camera distortion
                X, Y = raw_frame.shape[:2]
                raw_frame = raw_frame[X // 4: 3 * X // 4, Y // 4: 3 * Y // 4]
        elif args.version == 1:
            pass
            # h, w, _ = raw_frame.shape
            # raw_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
            #                       (w - input_crop_size) // 2:-(w - input_crop_size) // 2]

        # cv2.imshow("", raw_frame)
        # cv2.waitKey(1)
        raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)

        frame, (wcol, rcol, gcol, bcol) = localize_corners(raw_frame)

        # Find basis of color space
        origin = (rcol + gcol + bcol - wcol) / 2
        rcol -= origin
        gcol -= origin
        bcol -= origin
        # F maps a colour offset from `origin` so that the measured red,
        # green and blue reference colours land on 255 along their own axis.
        F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)

        # Convert to new color space
        # calibrated_frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
        calibrated_frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)

        # fig, axs = plt.subplots(1, 2)
        # axs[0].imshow(frame)
        # axs[1].imshow(calibrated_frame * 255)
        # plt.show()
        #
        # closest_ind = None
        # closest_diff = 1
        # for i, gtruth_frame in enumerate(gtruth_frames):
        #     diff = (gtruth_frame != calibrated_frame * 255).any(axis=2).mean()
        #     if diff < closest_diff:
        #         closest_ind = i
        #         closest_diff = diff
        #
        # gtruth = gtruth_frames[closest_ind]
        # fig, axs = plt.subplots(1, 2)
        # correct_mask = np.logical_not((calibrated_frame * 255 != gtruth).any(axis=2))
        # calibrated_frame_copy = calibrated_frame.copy()
        # gtruth_copy = gtruth.copy()
        # calibrated_frame_copy[correct_mask] = [0, 0, 0]
        # gtruth_copy[correct_mask] = [0, 0, 0]
        # axs[0].imshow(gtruth_copy)
        # axs[1].imshow(calibrated_frame_copy * 255)
        # plt.show()

        # Gather the data bits in row order, skipping the four corner blocks:
        # top band between the corners, full-width middle rows, bottom band.
        calibrated_frame = np.packbits(
            np.concatenate(
                (
                    calibrated_frame[:cheight, cwidth: args.width - cwidth].flatten(),
                    calibrated_frame[cheight: args.height - cheight].flatten(),
                    calibrated_frame[args.height - cheight:, cwidth: args.width - cwidth].flatten(),
                )
            )
        )
        # packbits pads a trailing partial byte; drop it so the length always
        # matches frame_xor (otherwise the XOR below fails to broadcast).
        calibrated_frame = calibrated_frame[:frame_bytes]
        # Reorder the leading whole multiple of 255 bytes: view it as 255 rows
        # and read it back column by column.
        reshape_len = frame_bytes // 255 * 255
        calibrated_frame[:reshape_len] = np.ravel(calibrated_frame[:reshape_len].reshape(255, reshape_len // 255), "F")
        # Remove the XOR mask, Reed-Solomon decode, keep the first psize bytes
        # and feed that packet to the fountain decoder, which returns None
        # until it has enough packets.
        data = decoder.decode(bytes(rsc.decode(bytearray(calibrated_frame ^ frame_xor))[0][: args.psize]))
        print("Decoded frame")
    except KeyboardInterrupt:
        break
    except Exception as e:
        # Best effort: a frame that cannot be localised or corrected is
        # reported and skipped; the next frame may succeed.
        print(e)

cap.release()
if data is None:
    # The stream ended or the user interrupted before decoding finished;
    # there is nothing to write and no throughput to report.
    print("No data decoded")
else:
    with open(args.output, "wb") as f:
        f.write(data)
    # Throughput in bits per second.
    print(8 * len(data) / (time.time() - start_time))
|