6.8301-Project/decoder.py
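
"""Decoder for the project's grid-based visual data channel.

Reads frames from a camera or video file, locates the four colored calibration
corner squares, rectifies each frame to an args.height x args.width color grid,
and recovers the payload via XOR de-whitening, Reed-Solomon decoding, and
RaptorQ fountain decoding. Decoded bytes are written to the output file.
"""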

import argparse
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from creedsolo import RSCodec
from raptorq import Decoder
from decoding_utils import localize_corners_wrapper

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
parser.add_argument("-p", "--psize", help="packet size", type=int)
parser.add_argument("-v", "--version", help="0 - original; 1 - CNN", default=0, choices=[0, 1], type=int)
args = parser.parse_args()
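
# Example invocation (hypothetical capture file and sizes; -s and -p must match the encoder):
#   python decoder.py -i capture.mkv -o out.bin -s 123456 -p 1024 -l 0.1 -x 100 -y 100 -v 0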

if args.version == 0:
    cheight = cwidth = max(args.height // 10, args.width // 10)
elif args.version == 1:
    cheight = cwidth = int(max(args.height, args.width) * 0.16)
else:
    raise NotImplementedError

# Payload cells per frame: the full grid minus the four corner calibration squares,
# packed at 3 bits per cell (one thresholded bit per RGB channel)
frame_size = args.height * args.width - 4 * cheight * cwidth
frame_bytes = frame_size * 3 // 8
# Repeating 0..255 ramp XORed with the frame bytes to undo whitening
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
# Payload bytes left after Reed-Solomon parity (int(level * 255) bytes per 255-byte block)
# and 4 bytes of packet overhead
rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4
rsc = RSCodec(int(args.level * 255))
decoder = Decoder.with_defaults(args.size, rs_bytes)
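
# Worked example with the defaults (100x100 grid, version 0, level 0.1):
#   cheight = cwidth = 10, frame_size = 100 * 100 - 4 * 10 * 10 = 9600 cells
#   frame_bytes = 9600 * 3 // 8 = 3600 bytes (3 bits per cell)
#   RS parity = int(0.1 * 255) = 25 bytes for each of (3600 + 254) // 255 = 15 blocks
#   rs_bytes = 3600 - 15 * 25 - 4 = 3221 payload bytes per frame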

if args.version == 0:
    def find_corner(A, f):
        """Locate one calibration square inside region A.

        f scores a downsampled copy of A; the best-scoring cell seeds a flood
        fill that segments the square. Returns its centroid (row, col) and
        mean color.
        """
        cx, cy = A.shape[:2]
        # Resize so smaller dim is 8
        scale = min(cx // 8, cy // 8)
        B = cv2.resize(A, (cy // scale, cx // scale), interpolation=cv2.INTER_AREA)
        guess = np.array(np.unravel_index(np.argmax(f(B.astype(np.float64))), B.shape[:2])) * scale + scale // 2
        mask = cv2.floodFill(
            A,
            np.empty(0),
            tuple(np.flip(guess)),
            0,
            (100, 100, 100),
            (100, 100, 100),
            cv2.FLOODFILL_MASK_ONLY + cv2.FLOODFILL_FIXED_RANGE,
        )[2][1:-1, 1:-1].astype(bool)
        return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)

    def localize_corners(cropped_frame):
        """
        Returns (reconstructed grid, (wcol, rcol, gcol, bcol))
        """
        X, Y = cropped_frame.shape[:2]
        cx, cy = X // 3, Y // 3
        # Search each corner third of the frame for its calibration square:
        # white (top-left), red (top-right), green (bottom-left), blue (bottom-right)
        widx, wcol = find_corner(cropped_frame[:cx, :cy], lambda B: np.sum(B, axis=2) - 2 * np.std(B, axis=2))
        ridx, rcol = find_corner(cropped_frame[:cx, Y - cy:], lambda B: B[:, :, 0] - B[:, :, 1] - B[:, :, 2])
        ridx[1] += Y - cy
        gidx, gcol = find_corner(cropped_frame[X - cx:, :cy], lambda B: B[:, :, 1] - B[:, :, 2] - B[:, :, 0])
        gidx[0] += X - cx
        bidx, bcol = find_corner(cropped_frame[X - cx:, Y - cy:], lambda B: B[:, :, 2] - B[:, :, 0] - B[:, :, 1])
        bidx[0] += X - cx
        bidx[1] += Y - cy
        # Centers of the corner squares in grid coordinates
        cch = cheight / 2 - 1
        ccw = cwidth / 2 - 1
        # Map the four detected centroids onto the corner-square centers and
        # rectify the frame to the height x width grid
        M = cv2.getPerspectiveTransform(
            np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
            np.float32(
                [
                    [ccw, cch],
                    [args.width - ccw - 1, cch],
                    [ccw, args.height - cch - 1],
                    [args.width - ccw - 1, args.height - cch - 1],
                ]
            ),
        )
        frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
        return frame, (wcol, rcol, gcol, bcol)
elif args.version == 1:
    # CNN-based corner localization (see decoding_utils)
    localize_corners = localize_corners_wrapper(args)

# ####
# gtruth_frames = []
# cap = cv2.VideoCapture("vid_mid_v1.mkv")
# data = None
# while data is None:
#     ret, raw_frame = cap.read()
#     if not ret:
#         print("End of stream")
#         break
#     gtruth_frames.append(cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB))
# ####
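
# Main decode loop: read frames until the RaptorQ decoder reports the transfer is
# complete. Each frame is corner-localized, perspective-rectified to the color grid,
# thresholded to 3 bits per cell, de-interleaved, XOR-de-whitened, Reed-Solomon
# decoded, and fed to the RaptorQ decoder as a single packet.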

# Accept either a camera index ("0") or a path to a video file
if isinstance(args.input, str) and args.input.isdecimal():
    args.input = int(args.input)
cap = cv2.VideoCapture(args.input)
data = None
start_time = time.time()
while data is None:
    try:
        ret, raw_frame = cap.read()
        if not ret:
            print("End of stream")
            break
        if args.version == 0:
            # raw_frame is a uint8 BE CAREFUL
            if isinstance(args.input, int):
                # Crop image to reduce camera distortion
                X, Y = raw_frame.shape[:2]
                raw_frame = raw_frame[X // 4: 3 * X // 4, Y // 4: 3 * Y // 4]
        elif args.version == 1:
            pass
            # h, w, _ = raw_frame.shape
            # raw_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
            #                       (w - input_crop_size) // 2:-(w - input_crop_size) // 2]
        # cv2.imshow("", raw_frame)
        # cv2.waitKey(1)
        raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
        frame, (wcol, rcol, gcol, bcol) = localize_corners(raw_frame)
        # Find basis of color space: ideally white = origin + R + G + B and
        # red = origin + R (same for green/blue), so origin = (r + g + b - w) / 2
        origin = (rcol + gcol + bcol - wcol) / 2
        rcol -= origin
        gcol -= origin
        bcol -= origin
        F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
        # Convert to new color space and threshold each channel to one bit per cell
        # calibrated_frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
        calibrated_frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
        # fig, axs = plt.subplots(1, 2)
        # axs[0].imshow(frame)
        # axs[1].imshow(calibrated_frame * 255)
        # plt.show()
        #
        # closest_ind = None
        # closest_diff = 1
        # for i, gtruth_frame in enumerate(gtruth_frames):
        #     diff = (gtruth_frame != calibrated_frame * 255).any(axis=2).mean()
        #     if diff < closest_diff:
        #         closest_ind = i
        #         closest_diff = diff
        #
        # gtruth = gtruth_frames[closest_ind]
        # fig, axs = plt.subplots(1, 2)
        # correct_mask = np.logical_not((calibrated_frame * 255 != gtruth).any(axis=2))
        # calibrated_frame_copy = calibrated_frame.copy()
        # gtruth_copy = gtruth.copy()
        # calibrated_frame_copy[correct_mask] = [0, 0, 0]
        # gtruth_copy[correct_mask] = [0, 0, 0]
        # axs[0].imshow(gtruth_copy)
        # axs[1].imshow(calibrated_frame_copy * 255)
        # plt.show()
        # Pack the payload cells (everything outside the corner squares) into bytes:
        # top strip, middle rows, bottom strip
        calibrated_frame = np.packbits(
            np.concatenate(
                (
                    calibrated_frame[:cheight, cwidth: args.width - cwidth].flatten(),
                    calibrated_frame[cheight: args.height - cheight].flatten(),
                    calibrated_frame[args.height - cheight:, cwidth: args.width - cwidth].flatten(),
                )
            )
        )
        # Undo the byte interleaving so each 255-byte Reed-Solomon block is contiguous
        reshape_len = frame_bytes // 255 * 255
        calibrated_frame[:reshape_len] = np.ravel(calibrated_frame[:reshape_len].reshape(255, reshape_len // 255), "F")
        # Remove the XOR whitening, RS-decode, and hand the packet to the RaptorQ decoder
        data = decoder.decode(bytes(rsc.decode(bytearray(calibrated_frame ^ frame_xor))[0][: args.psize]))
        print("Decoded frame")
    except KeyboardInterrupt:
        break
    except Exception as e:
        print(e)
cap.release()

if data is None:
    raise SystemExit("No complete transmission was decoded")
with open(args.output, "wb") as f:
    f.write(data)
# Effective throughput in bits per second
print(8 * len(data) / (time.time() - start_time))