6.8301-Project/decoder.py

132 lines
5.2 KiB
Python

import argparse
import sys
import traceback
import cv2
import matplotlib.pyplot as plt
import numpy as np
from creedsolo import RSCodec
from raptorq import Decoder
# Command-line interface.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# The default is the *string* "0" rather than the integer 0: the value is later
# inspected with str.isdecimal() to tell a camera index from a file path, and
# an int default would make that call fail whenever -i is omitted.
parser.add_argument("-i", "--input", help="camera device index or input video file", default="0")
parser.add_argument("-o", "--output", help="output file for decoded data")
parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-f", "--fps", help="frame rate", default=30, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
parser.add_argument("-s", "--size", help="number of bytes to decode", default=2**16, type=int)
parser.add_argument("-e", "--erasure", help="detect erasures", action="store_true")
args = parser.parse_args()

# Side length (in grid cells) of the square region reserved in each of the
# four grid corners; those cells carry no payload.
cheight = cwidth = max(args.height // 10, args.width // 10)
# Number of payload cells per frame: the whole grid minus the four corners.
frame_size = args.height * args.width - 4 * cheight * cwidth
# Two cells are packed into one byte, so a frame carries frame_size // 2 bytes.
# Each received byte is XORed with this position-dependent mask before
# Reed-Solomon decoding (uint8 arange, so the mask repeats every 256 bytes).
frame_xor = np.arange(frame_size // 2, dtype=np.uint8)
# Parity bytes per Reed-Solomon block of up to 255 bytes.
ecc_len = int(args.level * 255)
# Payload bytes left per frame once every 255-byte block has given up ecc_len
# parity bytes; the extra 4 bytes are presumably the RaptorQ packet header
# (TODO confirm against the encoder).
rs_size = frame_size // 2 - (frame_size // 2 + 254) // 255 * ecc_len - 4
if rs_size <= 0:
    parser.error("grid size and error correction level leave no room for payload data")
rsc = RSCodec(ecc_len)
decoder = Decoder.with_defaults(args.size, rs_size)
def find_corner(A, f):
    """Locate a uniformly coloured marker inside the image region ``A``.

    The region is first downscaled so that a coarse score map can pick a seed
    pixel, then the connected area around that seed is flood-filled in the
    full-resolution region.

    Args:
        A: Image region indexed as (row, column, channel). It is not modified
            (the flood fill runs with ``FLOODFILL_MASK_ONLY``).
        f: Scoring callable. It receives the downscaled copy of ``A`` (same
            dtype as ``A``, so callers must avoid integer wrap-around
            themselves) and returns a 2-D score map; the cell with the
            highest score becomes the seed.

    Returns:
        A pair ``(position, colour)``: the mean (row, column) of the filled
        pixels as a float array, and their mean colour as ``float64``.
    """
    rows, cols = A.shape[:2]
    # Downscale so the smaller dimension is about 5 cells. Clamp the factor to
    # at least 1 so regions smaller than 5 pixels do not divide by zero.
    scale = max(1, min(rows // 5, cols // 5))
    small = cv2.resize(A, (cols // scale, rows // scale), interpolation=cv2.INTER_AREA)
    best_cell = np.unravel_index(np.argmax(f(small)), small.shape[:2])
    # Map the winning coarse cell back to the centre of its full-size patch.
    seed_row, seed_col = (int(i) * scale + scale // 2 for i in best_cell)
    # OpenCV points are (x, y), i.e. (column, row).
    fill_mask = cv2.floodFill(
        A, np.empty(0), (seed_col, seed_row), 1, 10, 10, cv2.FLOODFILL_MASK_ONLY
    )[2]
    # The returned mask has a one-pixel border on every side; strip it so the
    # mask lines up with A.
    mask = fill_mask[1:-1, 1:-1].astype(bool)
    return np.average(np.where(mask), axis=1), np.average(A[mask], axis=0).astype(np.float64)
# A purely numeric input selects a camera by index; anything else is treated
# as a video file path. The isinstance check keeps a non-string value (such as
# an integer default) usable as-is instead of failing on .isdecimal().
if isinstance(args.input, str) and args.input.isdecimal():
    args.input = int(args.input)
cap = cv2.VideoCapture(args.input)


def _white_score(B):
    """Score bright, low-saturation pixels: channel sum where the spread is small."""
    return (np.std(B, axis=2) < 35) * np.sum(B, axis=2)


def _channel_score(channel):
    """Build a scorer that favours pixels where ``channel`` dominates the other two."""

    def score(B):
        # B is uint8: subtracting channels directly would wrap modulo 256 and
        # let dark pixels outscore the real marker, so widen first.
        wide = B.astype(np.int32)
        # channel - (sum of the other two channels)
        return 2 * wide[:, :, channel] - np.sum(wide, axis=2)

    return score


# Bilinear interpolation weights that map every grid cell onto the quadrilateral
# spanned by the four detected corner positions. They depend only on the grid
# geometry, so compute them once instead of once per frame. The parameter range
# extends slightly beyond [0, 1] because the detected positions are marker
# centres, which lie inside the grid rather than on its outer edge.
x_margin = (cheight / 2 - 1) / (args.height - cheight + 1)
y_margin = (cwidth / 2 - 1) / (args.width - cwidth + 1)
xv = np.linspace(-x_margin, 1 + x_margin, args.height)
yv = np.linspace(-y_margin, 1 + y_margin, args.width)
w_weight = np.outer(1 - xv, 1 - yv)  # top-left (white) corner
r_weight = np.outer(1 - xv, yv)  # top-right (red) corner
g_weight = np.outer(xv, 1 - yv)  # bottom-left (green) corner
b_weight = np.outer(xv, yv)  # bottom-right (blue) corner

data = None
try:
    while data is None:
        try:
            ret, raw_frame = cap.read()
            if not ret:
                print("End of stream")
                break
            cv2.imshow("", raw_frame)
            cv2.waitKey(1)
            # raw_frame is uint8: arithmetic on it wraps modulo 256.
            raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
            X, Y = raw_frame.shape[:2]
            # Search each quarter-size corner of the image for its marker, then
            # shift the positions back into full-frame coordinates.
            cx, cy = X // 4, Y // 4
            widx, wcol = find_corner(raw_frame[:cx, :cy], _white_score)
            ridx, rcol = find_corner(raw_frame[:cx, Y - cy :], _channel_score(0))
            ridx[1] += Y - cy
            gidx, gcol = find_corner(raw_frame[X - cx :, :cy], _channel_score(1))
            gidx[0] += X - cx
            bidx, bcol = find_corner(raw_frame[X - cx :, Y - cy :], _channel_score(2))
            bidx[0] += X - cx
            bidx[1] += Y - cy

            # Find basis of color space. Treating observed white as
            # origin + red + green + blue offsets gives the origin (observed
            # black) as (r + g + b - w) / 2.
            origin = (rcol + gcol + bcol - wcol) / 2
            rcol -= origin
            gcol -= origin
            bcol -= origin
            # F maps an observed colour offset to 0..255 coordinates along the
            # observed red/green/blue axes.
            F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)

            # Dumb perspective transform: bilinear blend of the corner positions.
            xp = w_weight * widx[0] + r_weight * ridx[0] + g_weight * gidx[0] + b_weight * bidx[0]
            yp = w_weight * widx[1] + r_weight * ridx[1] + g_weight * gidx[1] + b_weight * bidx[1]
            # Debug aid:
            # plt.scatter(widx[1], widx[0])
            # plt.scatter(ridx[1], ridx[0])
            # plt.scatter(gidx[1], gidx[0])
            # plt.scatter(bidx[1], bidx[0])
            # plt.scatter(yp, xp)
            # plt.imshow(raw_frame.astype(np.uint8))
            # plt.show()

            # Nearest-neighbour sample of one pixel per grid cell.
            sample_rows = np.clip(np.round(xp).astype(np.int64), 0, X - 1)
            sample_cols = np.clip(np.round(yp).astype(np.int64), 0, Y - 1)
            frame = raw_frame[sample_rows, sample_cols, :]
            # Colour-correct every sampled pixel, then quantise back to uint8.
            frame = np.clip((F @ (frame - origin)[..., np.newaxis])[..., 0], 0, 255).astype(np.uint8)
            # Pack each cell into a 4-bit value: red MSB -> bit 0,
            # green top two bits -> bits 1-2, blue MSB -> bit 3.
            frame = (frame[:, :, 0] >> 7) + (frame[:, :, 1] >> 5 & 0b0110) + (frame[:, :, 2] >> 4 & 0b1000)
            # Collect the payload cells row by row, skipping the four corners.
            frame = np.concatenate(
                (
                    frame[:cheight, cwidth : args.width - cwidth].flatten(),
                    frame[cheight : args.height - cheight].flatten(),
                    frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
                )
            )
            # Two 4-bit cells per byte, first cell in the high nibble.
            frame = (frame[::2] << 4) + frame[1::2]
            # Undo the 255-way interleaving: pad to a multiple of 255, read the
            # (255, n) matrix column by column, and trim back to the byte count.
            frame = np.pad(frame, (0, (len(frame) + 254) // 255 * 255 - len(frame)))
            frame = np.ravel(frame.reshape(255, len(frame) // 255), "F")[: frame_size // 2]
            # With --erasure, bytes that read as zero (before unmasking) are
            # reported to the Reed-Solomon decoder as erasures.
            erase_pos = list(np.where(frame == 0)[0]) if args.erasure else []
            packet = bytes(rsc.decode(frame ^ frame_xor, erase_pos=erase_pos)[0])
            # The loop keeps feeding packets while decode() returns None.
            data = decoder.decode(packet)
            print("Decoded frame")
        except Exception:
            # A frame that cannot be located or corrected is expected now and
            # then; report it and move on to the next frame.
            traceback.print_exc()
except KeyboardInterrupt:
    sys.exit()
finally:
    # Release the capture device however the loop ended.
    cap.release()

if data is None:
    # The stream ended before the transfer completed; do not touch the output file.
    sys.exit("Stream ended before the data could be fully decoded; nothing written")
if args.output is None:
    sys.exit("Decoding finished but no --output file was given; decoded data discarded")
with open(args.output, "wb") as f:
    f.write(data)