Compare commits

...

2 commits

Author       SHA1        Message                          Date
Kevin Zhao   9205312ad5  Preliminary CNN decoding         2024-05-07 14:02:26 -04:00
Kevin Zhao   3d4862e725  Add CNN arch and training code   2024-05-05 10:33:51 -04:00
7 changed files with 1001 additions and 10 deletions

160
.gitignore vendored Normal file
View file

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# Custom
.DS_Store
.idea/
Report_Stuff/
data/
*.mkv

View file

@@ -0,0 +1,126 @@
import argparse
import json
import math
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm

from models import QuantizedV2
from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_root_path")
    parser.add_argument("output_dir")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=False)

    np.random.seed(42)
    torch.manual_seed(42)
    random.seed(42)

    lr = 3e-2  # TODO: argparse
    log_steps = 10
    train_batch_size = 512
    num_train_epochs = 10

    # leave 1 shard for eval
    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
                          if re.fullmatch(r"shard_\d+", path) is not None])
    get_gtruth = get_gtruth_wrapper(4)
    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1])
    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
    train_dataloader = torch.utils.data.DataLoader(
        # train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True,
        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=False
    )

    model = QuantizedV2()
    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)

    # loss_func = nn.MSELoss()
    def loss_func(output, labels):
        # weights = (labels != 0) * 49.9 + 0.1  # arbitrary nums (good DeepStage1v0 64x64)
        weights = (labels != 0) * 99.9 + 0.1  # 4 all pts
        # weights = (labels != 0) * 199.9 + 0.1  # 1 pt 1 layer
        return (weights * (output - labels) ** 2).mean()
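    # The heatmap labels are zero almost everywhere, so an unweighted MSE would be
    # dominated by background pixels; the weights upweight the rare nonzero
    # (near-corner) pixels. The 99.9 / 0.1 split is an empirically picked ratio.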
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
        cum_loss = 0
        model.train()
        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
            output = model(batch["imgs"].to(device))
            loss = loss_func(output, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            scheduler.step()  # advance the Noam warmup/decay schedule each optimizer step
            optimizer.zero_grad()

            cum_loss += loss.item()
            if (i + 1) % log_steps == 0:
                print(f"{cum_loss / log_steps=}")
                cum_loss = 0
        if (i + 1) % log_steps != 0:  # log any leftover partial window
            print(f"{cum_loss / ((i + 1) % log_steps)=}")

        if epoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # model.eval()
        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
        model.to(device)
        quantized_model.eval()  # not sure if this is needed
        eval_device = "cpu"
        with torch.no_grad():
            num_eval_exs = 5
            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
            eval_loss = 0
            for i in range(num_eval_exs):
                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
                axs[i, 1].imshow(eval_ex["labels"])
                axs[i, 1].set_title("Ground Truth")
                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
                axs[i, 2].set_title("Prediction")
                for ax in axs[i]:
                    ax.axis("off")
            print(f"{eval_loss / num_eval_exs=}")
            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
            plt.close(fig)  # avoid accumulating open figures across epochs

        torch.save(quantized_model.state_dict(),
                   os.path.join(args.output_dir, f"QuantizedV3_Stage1_128_{epoch}.pt"))

View file

@@ -0,0 +1,127 @@
import argparse
import json
import math
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm

from models import QuantizedV3, QuantizedV2
from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_root_path")
    parser.add_argument("output_dir")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=False)

    np.random.seed(42)
    torch.manual_seed(42)
    random.seed(42)

    lr = 3e-2  # TODO: argparse
    log_steps = 10
    train_batch_size = 512
    num_train_epochs = 10

    # leave 1 shard for eval
    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
                          if re.fullmatch(r"shard_\d+", path) is not None])
    get_gtruth = get_gtruth_wrapper(4)
    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1], eval=True)  # note: not flipping
    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True
    )

    # torch.backends.quantized.engine = "qnnpack"
    model = QuantizedV2()
    # model.fuse_modules(is_qat=True)  # Added for QAT
    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)
    # model.qconfig = torch.ao.quantization.default_qconfig  # Added for QAT
    # torch.ao.quantization.prepare_qat(model, inplace=True)  # Added for QAT
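    # The "Added for QAT" lines above are the eager-mode QAT hooks (module fusion,
    # qconfig, prepare_qat). Left commented out, this script trains in plain float,
    # and the convert() call below should find nothing observed to swap out.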
    def loss_func(output, labels):
        weights = (labels != 0) * 49.9 + 0.1
        # weights = (labels != 0) * 99.9 + 0.1
        return (weights * (output - labels) ** 2).mean()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
        cum_loss = 0
        model.train()
        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
            output = model(batch["imgs"].to(device))
            loss = loss_func(output, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            scheduler.step()  # advance the Noam warmup/decay schedule each optimizer step
            optimizer.zero_grad()

            cum_loss += loss.item()
            if (i + 1) % log_steps == 0:
                print(f"{cum_loss / log_steps=}")
                cum_loss = 0
        if (i + 1) % log_steps != 0:  # log any leftover partial window
            print(f"{cum_loss / ((i + 1) % log_steps)=}")

        if epoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # model.eval()
        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
        model.to(device)
        quantized_model.eval()  # not sure if this is needed
        eval_device = "cpu"
        with torch.no_grad():
            num_eval_exs = 5
            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
            eval_loss = 0
            for i in range(num_eval_exs):
                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
                axs[i, 1].imshow(eval_ex["labels"])
                axs[i, 1].set_title("Ground Truth")
                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
                axs[i, 2].set_title("Prediction")
                for ax in axs[i]:
                    ax.axis("off")
            print(f"{eval_loss / num_eval_exs=}")
            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
            plt.close(fig)  # avoid accumulating open figures across epochs

        torch.save(quantized_model.state_dict(),
                   os.path.join(args.output_dir, f"QuantizedV3_Stage2_128_{epoch}.pt"))

181
corner_training/models.py Normal file
View file

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn


class UNetStage1v1(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        self.block0 = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, 32, kernel_size=3, padding="same"),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding="same")
        )
        self.pool0 = nn.MaxPool2d(kernel_size=2)
        self.block1 = nn.Sequential(
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding="same"),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding="same"),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.block2 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding="same"),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding="same"),
        )
        self.upsample0 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
        )
        self.block3 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding="same"),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding="same"),
        )
        self.upsample1 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
        )
        self.block4 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding="same"),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding="same"),
        )

    def forward(self, imgs):
        assert imgs.dim() == 4, imgs.size()
        x = self.block0(imgs)
        downsampled_x0 = self.block1(self.pool0(x))
        downsampled_x1 = self.block2(self.pool1(downsampled_x0))
        upsampled_x0 = self.block3(torch.cat([self.upsample0(downsampled_x1), downsampled_x0], dim=1))
        x = self.block4(torch.cat([self.upsample1(upsampled_x0), x], dim=1))
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)
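    # Classic U-Net wiring: each torch.cat in forward() stitches a skip connection
    # onto the upsampled path, which is why block3 and block4 take 128 and 64 input
    # channels (64 + 64 and 32 + 32 after concatenation).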
class QuantizedV2(nn.Module):
    """ Normal convolutions with quantization """

    def __init__(self, in_channels=3):
        super().__init__()
        self.layers = nn.Sequential(
            torch.quantization.QuantStub(),
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 8, kernel_size=3, padding=1),
            torch.quantization.DeQuantStub(),
        )

    def forward(self, imgs):
        x = self.layers(imgs)
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)

    def fuse_modules(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)
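    # range(1, 6, 3) yields i = 1 and 4: the Conv+BN+ReLU triples at layers (1, 2, 3)
    # and (4, 5, 6) are fused into single modules so Conv/BN folding can happen during
    # quantization; the final conv (layer 7) has no BN/ReLU to fuse with.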
class QuantizedV3(nn.Module):
    """ Depthwise convs except input layer; no inverted bottleneck """

    def __init__(self, in_channels=3):
        super().__init__()
        hidden_size = 32
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
        )
        self.block1 = nn.Sequential(
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=2, groups=hidden_size, padding=2),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=1, padding=0),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=1, groups=hidden_size, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.Conv2d(hidden_size, 8, kernel_size=1, padding=0),
        )
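        # groups=hidden_size makes each 3x3 conv depthwise (one filter per channel);
        # the 1x1 convs that follow mix channels, forming depthwise-separable pairs
        # that need far fewer parameters and multiplies than full 3x3 convolutions.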
    def forward(self, imgs):
        x = self.quant(imgs)
        x = self.input_conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.dequant(x)
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)

    def fuse_modules(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        fuse_modules(self, [
            ["input_conv.0", "input_conv.1"],
            ["block1.0", "block1.1", "block1.2"],
            ["block1.3", "block1.4", "block1.5"],
            ["block2.0", "block2.1", "block2.2"],
        ], inplace=True)
class QuantizedV5(nn.Module):
    """normal convs, biasless"""

    def __init__(self, in_channels=3):
        super().__init__()
        self.layers = nn.Sequential(
            torch.quantization.QuantStub(),
            # nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 8, kernel_size=3, padding=1),
            torch.quantization.DeQuantStub(),
        )

    def forward(self, imgs):
        x = self.layers(imgs)
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)

    def fuse_modules(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)
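For reference, this is the eager-mode QAT flow that the QuantStub/fuse_modules hooks above are designed around, as a minimal sketch; the qnnpack backend choice is an assumption, and the training scripts in this commit only wire this up partially (their QAT hooks are commented out):

import torch
from corner_training.models import QuantizedV2

model = QuantizedV2()
model.train()
model.fuse_modules(is_qat=True)  # fold Conv+BN(+ReLU) triples into fused modules
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("qnnpack")  # assumed backend
torch.ao.quantization.prepare_qat(model, inplace=True)  # insert fake-quant observers
# ... train as usual ...
model_int8 = torch.ao.quantization.convert(model.cpu().eval())  # swap in int8 kernels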

224
corner_training/utils.py Normal file
View file

@@ -0,0 +1,224 @@
import json
import math
import os
import random
import re
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm


def get_gaussian_filter(sigma, half_window_size) -> torch.Tensor:
    full_window_size = half_window_size * 2 + 1
    x_sq_offsets = torch.square(torch.arange(full_window_size) - half_window_size) \
        .expand((full_window_size, full_window_size))
    y_sq_offsets = x_sq_offsets.T
    gaussian_filter = torch.exp(-(x_sq_offsets + y_sq_offsets) / sigma)  # "normalized" so that the peak is 1
    return gaussian_filter
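# e.g. get_gaussian_filter(4, 4) gives a 9x9 filter whose center is exactly 1.0
# (exp(0)) and whose corners fall to exp(-(16 + 16) / 4) = exp(-8), about 3e-4.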
def get_bounded_slices(size_a: torch.Size, size_b: torch.Size, i: int, j: int) \
        -> tuple[tuple[slice, slice], tuple[slice, slice]]:
    """
    Args:
        size_a: size of the first 2D tensor (`tensor_a`)
        size_b: size of the second 2D tensor (`tensor_b`),
            should be at most `size_a`, and have odd height and width (so that the center is clearly defined)
        i: the row of `tensor_a` to center `tensor_b` on
        j: the column of `tensor_a` to center `tensor_b` on

    Returns:
        ((row_slice_a, col_slice_a), (row_slice_b, col_slice_b))
        which are as large as possible (at most the size of `size_b`) while remaining in bounds.
    """
    assert len(size_a) == 2 and len(size_b) == 2 \
        and size_b[0] % 2 == size_b[1] % 2 == 1
    half_mask_width, half_mask_height = size_b[1] // 2, size_b[0] // 2
    left_offset = max(0, j - half_mask_width)
    right_offset = min(size_a[1], j + half_mask_width + 1)  # exclusive
    top_offset = max(0, i - half_mask_height)
    bottom_offset = min(size_a[0], i + half_mask_height + 1)
    return (slice(top_offset, bottom_offset), slice(left_offset, right_offset)), \
           (slice(half_mask_height + top_offset - i, half_mask_height + bottom_offset - i),
            slice(half_mask_width + left_offset - j, half_mask_width + right_offset - j))
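# e.g. centering a 9x9 filter (size_b) at i = j = 0 of a 128x128 tensor (size_a)
# returns a-slices (0:5, 0:5) and b-slices (4:9, 4:9): only the filter's
# bottom-right quadrant lands in bounds.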
def get_gtruth_wrapper(sigma: int, display_pts_inds: list[int] = None):
    """
    Args:
        sigma: scale of the Gaussian smoothing (the filter is exp(-(dx^2 + dy^2) / sigma))
        display_pts_inds: If specified, will only display the points corresponding to these indices in `transformed_pts`.
            Otherwise, will display all points by default.

    One main reason for using a closure is to reuse the Gaussian mask for efficiency.
    """
    half_window_size = sigma
    # half_window_size = 2 * sigma
    # half_window_size = 3 * sigma
    # half_window_size = int(3 * sigma / math.sqrt(2))
    gaussian_filter = get_gaussian_filter(sigma, half_window_size)

    # def get_gtruth(transformed_pts: torchvision.tv_tensors.BoundingBoxes, img_size: tuple) -> torch.Tensor:
    def get_gtruth(transformed_pts: torch.Tensor, img_size: tuple) -> torch.Tensor:
        """
        Converts coordinates of points to a heatmap with Gaussians centered at those coordinates.

        Args:
            transformed_pts: (N, 2) tensor of (x, y) point coordinates
            img_size: (height, width) of the output heatmap
        """
        smoothed_gtruth = torch.zeros(img_size)
        for ind in range(len(transformed_pts)) if display_pts_inds is None else display_pts_inds:
            # x, y, _, _ = transformed_pts[ind]
            x, y = transformed_pts[ind]
            gtruth_slice, gaussian_slice = get_bounded_slices(smoothed_gtruth.size(), gaussian_filter.size(), y, x)
            smoothed_gtruth[gtruth_slice] = torch.maximum(  # should modify smoothed_gtruth in-place
                smoothed_gtruth[gtruth_slice],
                gaussian_filter[gaussian_slice]
            )
        return smoothed_gtruth

    return get_gtruth
class FlyingFramesDataset(torch.utils.data.Dataset):
    """
    Different from the `FlyingFramesDataset` in Simulate_Data, which
    randomly generates each frame. Here we read the saved examples.
    """

    def __init__(self, data_root_path: str, get_gtruth: Callable, shard_paths: list[str] = None, eval: bool = False):
        self.data_root_path, self.get_gtruth = data_root_path, get_gtruth
        with open(os.path.join(self.data_root_path, "configs.json"), "r") as f:
            self.configs = json.load(f)
        assert self.configs["dataset_size"] % self.configs["shard_size"] == 0  # simplifying assumption
        self.shard_size = self.configs["shard_size"]
        if shard_paths is None:  # use all shards
            self.shard_paths = sorted([path for path in os.listdir(data_root_path)
                                       if re.fullmatch(r"shard_\d+", path) is not None])
        else:
            self.shard_paths = shard_paths
            assert len(self.shard_paths) * self.shard_size <= self.configs["dataset_size"], \
                (len(self.shard_paths), self.shard_size, self.configs["dataset_size"])

        identity_transform = transforms.Lambda(lambda x: x)  # for eval=True

        # Some last-mile stuff
        # self.img_transforms = transforms.Compose([
        #     transforms.PILToTensor(),
        #     # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
        #     # if not eval else identity_transform,
        #     transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)),  # color jitter even in eval
        #     # transforms.Grayscale(),
        #     # normalize?
        #     transforms.ToDtype(torch.float32, scale=True),
        #     # transforms.ToPILImage(),  # for visualization
        # ])
        # self.unified_transforms = transforms.Compose([
        #     # transforms.Resize(self.img_size),
        #     transforms.RandomHorizontalFlip(),
        #     transforms.RandomVerticalFlip(),
        #     transforms.RandomApply([
        #         transforms.GaussianBlur(kernel_size=(3, 3))
        #     ], p=0.33),  # TODO: tune
        # ]) if not eval else identity_transform

        self.img_transforms = transforms.Compose([
            transforms.PILToTensor(),
            # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
            # if not eval else identity_transform,
            transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)),  # color jitter even in eval
            # transforms.Grayscale(),
            # normalize?
            transforms.ConvertImageDtype(torch.float32),
            # transforms.ToPILImage(),  # for visualization
        ])
        self.unified_transforms = transforms.Compose([
            # transforms.Resize(self.img_size),
            unified_hflip,
            unified_vflip,
            lambda img_and_pts: (transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=(3, 3))
            ], p=0.33)(img_and_pts[0]), img_and_pts[1]),  # TODO: tune
        ]) if not eval else identity_transform

    def __getitem__(self, idx):
        shard_path = self.shard_paths[idx // self.shard_size]
        shard_ind = int(re.fullmatch(r"shard_(?P<shard_ind>\d+)", shard_path).group("shard_ind"))
        img_ind = shard_ind * self.shard_size + idx % self.shard_size
        img_path = os.path.join(self.data_root_path, shard_path, f"{img_ind}.jpg")
        bbox_path = os.path.join(self.data_root_path, shard_path, f"{img_ind}_tensor.pt")
        assert os.path.isfile(img_path) and os.path.isfile(bbox_path), (img_path, bbox_path)
        img = self.img_transforms(PIL.Image.open(img_path))
        bbox = torch.load(bbox_path)
        img, bbox = self.unified_transforms((img, bbox))
        return {
            "imgs": img,
            "labels": self.get_gtruth(bbox, self.configs["img_size"]),
        }

    def __len__(self):
        return len(self.shard_paths) * self.shard_size
class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Taken from https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py
    """

    def __init__(self, optimizer, warmup_steps):  # steps, not ratio
        self.warmup_steps = warmup_steps
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        return [base_lr * scale for base_lr in self.base_lrs]
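        # Noam schedule: scale = sqrt(warmup) * min(1 / sqrt(t), t / warmup ** 1.5),
        # i.e. linear warmup reaching base_lr at t == warmup_steps, then 1/sqrt(t) decay.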
def unified_hflip(img_and_pts):
    img, corner_pts = img_and_pts
    if random.random() < 0.5:
        return img, corner_pts
    _, h, w = img.size()
    img = transforms_f.hflip(img)
    # (x, y) -> (w - x, y)
    corner_pts = torch.stack([w - corner_pts[:, 0], corner_pts[:, 1]], dim=1)
    return img, corner_pts


def unified_vflip(img_and_pts):
    img, corner_pts = img_and_pts
    if random.random() < 0.5:
        return img, corner_pts
    _, h, w = img.size()
    img = transforms_f.vflip(img)
    # (x, y) -> (x, h - y)
    corner_pts = torch.stack([corner_pts[:, 0], h - corner_pts[:, 1]], dim=1)
    return img, corner_pts
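# Note: the flips take and return (img, pts) pairs so the corner coordinates stay
# in sync with the pixels; torchvision's v1 random flips only see the image, which
# is presumably why the transforms.v2 imports above are still commented out.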

151
decoder_cnn.py Normal file
View file

@@ -0,0 +1,151 @@
import argparse
import traceback

import cv2
import numpy as np
import torch
from creedsolo import RSCodec
from matplotlib import pyplot as plt
from raptorq import Decoder

from corner_training.models import QuantizedV2, QuantizedV5
from decoding_utils import localize_corners_wrapper

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-i", "--input", help="camera device index or input video file", default="0")  # str, converted below
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
parser.add_argument("-p", "--psize", help="packet size", type=int)
parser.add_argument("-v", "--version",
                    help="0: 10%% corners w/ two-sided one-cell padding; 1: 15%% corners w/ four-sided 25%% padding.",
                    default=0, choices=[0, 1], type=int)
args = parser.parse_args()

assert args.version == 1
# cell borders are 3.75% of width/height
assert args.height * 3 % 80 == args.width * 3 % 80 == 0
cheight = int(args.height * 0.15)
cwidth = int(args.width * 0.15)
frame_size = args.height * args.width - 4 * cheight * cwidth
frame_bytes = frame_size * 3 // 8
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4

rsc = RSCodec(int(args.level * 255))
decoder = Decoder.with_defaults(args.size, rs_bytes)
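# Payload budget per frame: frame_bytes splits into ceil(frame_bytes / 255)
# Reed-Solomon codewords, each reserving int(level * 255) parity bytes; the
# trailing "- 4" is assumed to cover the 4-byte RaptorQ payload ID per packet.
# frame_xor is a fixed byte pattern XORed over the packed bits, presumably to
# whiten long constant runs.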
stage1_model_checkpt_path = "/Users/kevinzhao/Downloads/QuantizedV0_Stage1_128_9.pt"
stage1_model = QuantizedV2()
stage1_model.eval()
stage1_model.fuse_modules(is_qat=False)
stage1_model.qconfig = torch.ao.quantization.default_qconfig
torch.ao.quantization.prepare(stage1_model, inplace=True)
torch.ao.quantization.convert(stage1_model, inplace=True)
stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))
stage2_model = QuantizedV5()
stage2_model.eval()
stage2_model.fuse_modules(is_qat=False)
stage2_model.qconfig = torch.ao.quantization.default_qconfig
torch.ao.quantization.prepare(stage2_model, inplace=True)
torch.ao.quantization.convert(stage2_model, inplace=True)
stage2_model.load_state_dict(torch.load("/Users/kevinzhao/Downloads/QuantizedV5_Stage2_128_9.pt", map_location=torch.device('cpu')))
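# The checkpoints appear to be int8 state dicts saved from converted models, so
# each float model is rebuilt into its quantized form (fuse -> qconfig -> prepare
# -> convert) before load_state_dict; the sequence has to mirror how the
# checkpoint was exported or the parameter keys won't match.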
# stage1_size = 128
# stage2_size = 128
stage1_size = 128
stage2_size = 64
input_crop_size = 1024
localize_corners = localize_corners_wrapper(stage1_model, stage2_model, stage1_size, stage2_size)
if args.input.isdecimal():
    args.input = int(args.input)
cap = cv2.VideoCapture(args.input)
data = None
while data is None:
    try:
        ret, raw_frame = cap.read()
        if not ret:
            print("End of stream")
            break
        # # raw_frame is a uint8 BE CAREFUL
        # if type(args.input) == int:
        #     # Crop image to reduce camera distortion
        #     X, Y = raw_frame.shape[:2]
        #     raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
        #     cv2.imshow("", raw_frame)
        #     cv2.waitKey(1)
        # raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)
        h, w, _ = raw_frame.shape
        cropped_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
                                  (w - input_crop_size) // 2:-(w - input_crop_size) // 2]
        cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
        (widx, ridx, gidx, bidx), (wcol, rcol, gcol, bcol) = localize_corners(cropped_frame)
        widx = widx[::-1]
        ridx = ridx[::-1]
        gidx = gidx[::-1]
        bidx = bidx[::-1]
        # plt.imshow(cropped_frame)
        # plt.scatter([widx[1]], [widx[0]], color="r")
        # plt.scatter([ridx[1]], [ridx[0]], color="g")
        # plt.scatter([gidx[1]], [gidx[0]], color="b")
        # plt.scatter([bidx[1]], [bidx[0]], color="w")
        # plt.show()

        # Find basis of color space
        origin = (rcol + gcol + bcol - wcol) / 2
        rcol -= origin
        gcol -= origin
        bcol -= origin
        F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)
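        # Derivation: under an affine color model, each measured corner color is
        # origin + its ideal basis vector, and white is origin + r + g + b, so
        # (rcol + gcol + bcol) - wcol = 2 * origin. After subtracting the origin,
        # F maps a measured pixel back to (scaled) R/G/B coordinates for thresholding.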
        # cch = cheight / 2 - 1
        # ccw = cwidth / 2 - 1
        cch = cheight / 4 - 1
        ccw = cwidth / 4 - 1
        M = cv2.getPerspectiveTransform(
            np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
            np.float32(
                [
                    [ccw, cch],
                    [args.width - ccw - 1, cch],
                    [ccw, args.height - cch - 1],
                    [args.width - ccw - 1, args.height - cch - 1],
                ]
            ),
        )
        frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
        # # Convert to new color space
        # frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
        frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
        # import matplotlib.pyplot as plt
        # # plt.imshow(frame * 255)
        # plt.imshow((1 - frame) * 255)
        # plt.show()
        frame = np.concatenate(
            (
                frame[:cheight, cwidth : args.width - cwidth].flatten(),
                frame[cheight : args.height - cheight].flatten(),
                frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
            )
        )
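        # Unwrap the grid into the encoder's bit order: top rows minus the two
        # corner blocks, the full middle rows, then bottom rows minus the two
        # corner blocks.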
        data = decoder.decode(bytes(rsc.decode(bytearray(np.packbits(frame) ^ frame_xor))[0][: args.psize]))
        print("Decoded frame")
    except KeyboardInterrupt:
        break
    except Exception:
        traceback.print_exc()

if data is not None:
    with open(args.output, "wb") as f:
        f.write(data)
cap.release()

View file

@@ -12,10 +12,21 @@ parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
 parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
 parser.add_argument("-f", "--fps", help="frame rate", default=30, type=int)
 parser.add_argument("-m", "--mix", help="mix frames with original video", action="store_true")
+parser.add_argument("-v", "--version",
+                    help="0: 10%% corners w/ two-sided one-cell padding; 1: 15%% corners w/ four-sided 25%% padding.",
+                    default=0, choices=[0, 1], type=int)
 args = parser.parse_args()
+if args.version == 0:
+    cheight = cwidth = max(args.height // 10, args.width // 10)
+elif args.version == 1:
+    # cell borders are 3.75% of width/height
+    assert args.height * 3 % 80 == args.width * 3 % 80 == 0  # TODO: less strict better ratio
+    cheight = int(args.height * 0.15)
+    cwidth = int(args.width * 0.15)
+else:
+    raise NotImplementedError
-cheight = cwidth = max(args.height // 10, args.width // 10)
 midwidth = args.width - 2 * cwidth
 frame_size = args.height * args.width - 4 * cheight * cwidth
 # Divide by 8 / 3 for 3-bit color
@@ -32,15 +43,26 @@ encoder = Encoder.with_defaults(data, rs_bytes)
 packets = encoder.get_encoded_packets(int(len(data) / rs_bytes * (1 / (1 - args.level) - 1)))
 # Make corners
-ones = np.ones((cheight - 1, cwidth - 1))
-zeros = np.zeros((cheight - 1, cwidth - 1))
-wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
-rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
-gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
-bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+if args.version == 0:
+    ones = np.ones((cheight - 1, cwidth - 1))
+    zeros = np.zeros((cheight - 1, cwidth - 1))
+    wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
+    rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
+    gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
+    bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+elif args.version == 1:
+    zeros = np.zeros((cheight, cwidth, 3))
+    wcorner = zeros.copy()
+    rcorner = zeros.copy()
+    gcorner = zeros.copy()
+    bcorner = zeros.copy()
+    black_border_h, black_border_w = cheight // 4, cwidth // 4
+    for corner_arr, ones_channel_ind in [(wcorner, 0), (wcorner, 1), (wcorner, 2),
+                                         (rcorner, 0), (gcorner, 1), (bcorner, 2)]:
+        corner_arr[black_border_h:-black_border_h, black_border_w:-black_border_w, ones_channel_ind] = np.ones((cheight // 2, cwidth // 2))
 # Output flags for decoder
-print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])}", end="")
+print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])} -v {args.version}", end="")

 def mkframe(packet):
@@ -55,11 +77,11 @@ def mkframe(packet):
                 (wcorner, frame[: cheight * midwidth].reshape((cheight, midwidth, 3)), rcorner),
                 axis=1,
             ),
-            frame[cheight * midwidth : frame_size - cheight * midwidth].reshape(
+            frame[cheight * midwidth: frame_size - cheight * midwidth].reshape(
                 (args.height - 2 * cheight, args.width, 3)
             ),
             np.concatenate(
-                (gcorner, frame[frame_size - cheight * midwidth :].reshape((cheight, midwidth, 3)), bcorner),
+                (gcorner, frame[frame_size - cheight * midwidth:].reshape((cheight, midwidth, 3)), bcorner),
                 axis=1,
             ),
         )