6.8301-Project/decoder.py

129 lines
4.9 KiB
Python

import argparse
import time
import cv2
import numpy as np
from creedsolo import RSCodec
from raptorq import Decoder
# Command-line interface. ArgumentDefaultsHelpFormatter makes --help show each default.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
# --size is handed straight to Decoder.with_defaults below, so it has to be
# supplied; marking it required yields a usage error instead of a failure
# inside the decoder constructor.
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int, required=True)
# --psize may be omitted: it is only used as a slice bound, and [:None] keeps everything.
parser.add_argument("-p", "--psize", help="packet size", type=int)
args = parser.parse_args()

# Corner markers are squares whose side is 1/15 of the larger grid dimension.
cheight = cwidth = max(args.height // 15, args.width // 15)
# Grid cells that remain once the four corner squares are excluded.
frame_size = args.height * args.width - 4 * cheight * cwidth
# Each cell contributes three bits (one per colour channel), packed into whole bytes.
frame_bytes = frame_size * 3 // 8
# Per-byte mask XORed with every received frame before Reed-Solomon decoding.
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
# Number of Reed-Solomon parity bytes, as a fraction of a 255-byte block.
rs_parity = int(args.level * 255)
# Bytes left per frame after removing the parity bytes of every (possibly
# partial) 255-byte block. NOTE(review): the trailing 4 presumably accounts
# for a per-packet header of the fountain code - confirm against the encoder.
rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * rs_parity - 4
rsc = RSCodec(rs_parity)
decoder = Decoder.with_defaults(args.size, rs_bytes)
def find_corner(A, f):
    """Locate a corner marker inside an image region.

    The region is shrunk to a coarse thumbnail, ``f`` scores every thumbnail
    pixel, and the best-scoring pixel is used as the seed of a flood fill on
    the full-resolution region. The filled area is taken to be the marker.

    Args:
        A: Image region as a (rows, cols, channels) array; it is passed to
            ``cv2.floodFill`` with the mask-only flag.
        f: Callable that receives the thumbnail as float64 and returns a
            2-D score array; the maximum marks the most marker-like pixel.

    Returns:
        A pair ``(centroid, colour)``: the mean (row, col) position of the
        flood-filled pixels, and their mean colour as a float64 array.
    """
    rows, cols = A.shape[:2]
    # Shrink so the smaller dimension is roughly 8 pixels. Clamp to 1 so that
    # regions under 8 pixels do not yield scale == 0 and a division by zero.
    scale = max(1, min(rows // 8, cols // 8))
    # cv2.resize takes its target size as (width, height).
    small = cv2.resize(A, (cols // scale, rows // scale), interpolation=cv2.INTER_AREA)
    score = f(small.astype(np.float64))
    # Map the best thumbnail pixel back to the middle of its source cell in A.
    guess = np.array(np.unravel_index(np.argmax(score), small.shape[:2])) * scale + scale // 2
    # floodFill expects the seed as (x, y), hence the flip of (row, col).
    # The returned mask carries a one-pixel border on each side, which is trimmed.
    mask = cv2.floodFill(
        A,
        np.empty(0),
        tuple(np.flip(guess)),
        0,
        (100, 100, 100),
        (100, 100, 100),
        cv2.FLOODFILL_MASK_ONLY + cv2.FLOODFILL_FIXED_RANGE,
    )[2][1:-1, 1:-1].astype(bool)
    return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)
# A purely numeric --input selects a camera index rather than a file path.
if isinstance(args.input, str) and args.input.isdecimal():
    args.input = int(args.input)
cap = cv2.VideoCapture(args.input)
data = None
start_time = 0
# Frame-skipping state, used only for live camera input:
#   0 - the last processed frame failed (or nothing processed yet): skip the next frame with probability 0.5
#   1 - the last frame decoded: always skip the next frame
#   2 - the last frame was skipped: always process the next frame
status = 0
decoded = 0
while data is None:
    try:
        ret, raw_frame = cap.read()
        if not ret:
            print("End of stream")
            break
        if isinstance(args.input, int) and (status == 1 or (status == 0 and np.random.rand() < 0.5)):
            status = 2
            print("Skipped")
            continue
        # raw_frame is uint8: arithmetic on it wraps unless it is cast first.
        cv2.imshow("", raw_frame)
        cv2.waitKey(1)
        raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
        # Find positions and colors of corners, searching one quarter-size
        # region per corner: white top-left, red top-right, green bottom-left,
        # blue bottom-right.
        X, Y = raw_frame.shape[:2]
        cx, cy = X // 4, Y // 4
        widx, wcol = find_corner(raw_frame[:cx, :cy], lambda B: np.sum(B, axis=2) - 2 * np.std(B, axis=2))
        ridx, rcol = find_corner(raw_frame[:cx, Y - cy :], lambda B: B[:, :, 0] - B[:, :, 1] - B[:, :, 2])
        # Shift region-local coordinates back into full-frame coordinates.
        ridx[1] += Y - cy
        gidx, gcol = find_corner(raw_frame[X - cx :, :cy], lambda B: B[:, :, 1] - B[:, :, 2] - B[:, :, 0])
        gidx[0] += X - cx
        bidx, bcol = find_corner(raw_frame[X - cx :, Y - cy :], lambda B: B[:, :, 2] - B[:, :, 0] - B[:, :, 1])
        bidx[0] += X - cx
        bidx[1] += Y - cy
        # Find basis of color space. If white = origin + R + G + B and each
        # primary = origin + its own axis, then r + g + b - w = 2 * origin.
        origin = (rcol + gcol + bcol - wcol) / 2
        rcol -= origin
        gcol -= origin
        bcol -= origin
        # F maps a camera colour offset to coordinates in which each observed
        # primary lands on 255 along its own axis.
        F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
        # Target position of each marker inside its corner square, in grid cells.
        cch = cheight / 2 - 1
        ccw = cwidth / 2 - 1
        # Points are given as (x, y), so the (row, col) centroids are flipped.
        M = cv2.getPerspectiveTransform(
            np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
            np.float32(
                [
                    [ccw, cch],
                    [args.width - ccw - 1, cch],
                    [ccw, args.height - cch - 1],
                    [args.width - ccw - 1, args.height - cch - 1],
                ]
            ),
        )
        frame = cv2.warpPerspective(raw_frame, M, (args.width, args.height))
        # Convert to new color space and threshold each channel into one bit.
        frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 160).astype(np.uint8)
        # Gather payload cells in reading order, leaving out the four corner
        # squares: top band, full-width middle rows, bottom band.
        frame = np.packbits(
            np.concatenate(
                (
                    frame[:cheight, cwidth : args.width - cwidth].flatten(),
                    frame[cheight : args.height - cheight].flatten(),
                    frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
                )
            )
        )[:frame_bytes]
        # Reorder the whole 255-byte blocks by reading the (255, n) matrix
        # column by column; any trailing partial block is left in place.
        reshape_len = frame_bytes // 255 * 255
        frame[:reshape_len] = np.ravel(frame[:reshape_len].reshape(255, reshape_len // 255), "F")
        # Undo the XOR mask, Reed-Solomon decode, then feed the packet to the
        # fountain decoder, which returns None until the payload is complete.
        data = decoder.decode(bytes(rsc.decode(bytearray(frame ^ frame_xor))[0][: args.psize]))
        decoded += 1
        status = 1
        if start_time == 0:
            start_time = time.time()
        print("Decoded frame")
    except KeyboardInterrupt:
        break
    except Exception as e:
        # Best effort: one bad frame (missed corner, uncorrectable block, ...)
        # must not stop the capture loop.
        status = 0
        print(e)
cap.release()
print(decoded)
if data is None:
    # The loop ended (end of stream or Ctrl-C) before the payload was
    # recovered. Stop here rather than truncating the output file and then
    # failing on f.write(None).
    raise SystemExit("No data decoded")
with open(args.output, "wb") as f:
    f.write(data)
# Throughput in bits per second, measured from the first decoded frame.
elapsed = time.time() - start_time
print(8 * len(data) / elapsed if elapsed > 0 else float("inf"))