Compare commits
2 commits: 49e4564566 ... 9205312ad5

Author | SHA1 | Date
---|---|---
 | 9205312ad5 |
 | 3d4862e725 |

7 changed files with 1001 additions and 10 deletions
160  .gitignore  vendored  Normal file
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# Custom
.DS_Store
.idea/
Report_Stuff/
data/
*.mkv
126  corner_training/coarse_training.py  Normal file
@@ -0,0 +1,126 @@
import argparse
import json
import math
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm

from models import QuantizedV2
from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_root_path")
    parser.add_argument("output_dir")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=False)

    np.random.seed(42)
    torch.manual_seed(42)
    random.seed(42)

    lr = 3e-2  # TODO: argparse
    log_steps = 10
    train_batch_size = 512
    num_train_epochs = 10

    # leave 1 shard for eval
    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
                          if re.fullmatch(r"shard_\d+", path) is not None])

    get_gtruth = get_gtruth_wrapper(4)

    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1])
    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
    train_dataloader = torch.utils.data.DataLoader(
        # train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True,
        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=False
    )

    model = QuantizedV2()

    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)

    # loss_func = nn.MSELoss()
    def loss_func(output, labels):
        # weights = (labels != 0) * 49.9 + 0.1  # arbitrary nums (good DeepStage1v0 64x64)
        weights = (labels != 0) * 99.9 + 0.1  # 4 all pts
        # weights = (labels != 0) * 199.9 + 0.1  # 1 pt 1 layer
        return (weights * (output - labels) ** 2).mean()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
        cum_loss = 0
        model.train()
        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
            output = model(batch["imgs"].to(device))
            loss = loss_func(output, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            cum_loss += loss.item()
            if (i + 1) % log_steps == 0:
                print(f"{cum_loss / log_steps=}")
                cum_loss = 0

        if i % log_steps != 0:
            print(f"{cum_loss / (i % log_steps)=}")

        if epoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # model.eval()
        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
        model.to(device)
        quantized_model.eval()  # not sure if this is needed
        eval_device = "cpu"
        with torch.no_grad():
            num_eval_exs = 5
            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))

            eval_loss = 0
            for i in range(num_eval_exs):
                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
                axs[i, 1].imshow(eval_ex["labels"])
                axs[i, 1].set_title("Ground Truth")
                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
                axs[i, 2].set_title("Prediction")
                for ax in axs[i]:
                    ax.axis("off")

            print(f"{eval_loss / num_eval_exs=}")
            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
        torch.save(quantized_model.state_dict(),
                   os.path.join(args.output_dir, f"QuantizedV3_Stage1_128_{epoch}.pt"))
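
Note on the weighted MSE above: the labels are Gaussian heatmaps that are zero almost everywhere, so an unweighted MSE is minimized well by predicting all zeros; the roughly 1000:1 weighting on nonzero-label pixels is what makes the peaks matter. A minimal numeric sketch (not from the diff; the square "peak" stands in for a real Gaussian):

import torch

labels = torch.zeros(128, 128)
labels[60:69, 60:69] = 1.0  # stand-in for a sigma=4 Gaussian peak

all_zero_pred = torch.zeros_like(labels)
weights = (labels != 0) * 99.9 + 0.1  # same weighting as loss_func above

plain_mse = ((all_zero_pred - labels) ** 2).mean()
weighted = (weights * (all_zero_pred - labels) ** 2).mean()
print(plain_mse.item(), weighted.item())  # ~0.005 vs ~0.5: missing peaks is now costly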
127  corner_training/fine_training.py  Normal file
@@ -0,0 +1,127 @@
import argparse
import json
import math
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm

from models import QuantizedV3, QuantizedV2
from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_root_path")
    parser.add_argument("output_dir")

    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=False)

    np.random.seed(42)
    torch.manual_seed(42)
    random.seed(42)

    lr = 3e-2  # TODO: argparse
    log_steps = 10
    train_batch_size = 512
    num_train_epochs = 10

    # leave 1 shard for eval
    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
                          if re.fullmatch(r"shard_\d+", path) is not None])

    get_gtruth = get_gtruth_wrapper(4)
    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1], eval=True)  # note: not flipping
    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True
    )

    # torch.backends.quantized.engine = "qnnpack"
    model = QuantizedV2()
    # model.fuse_modules(is_qat=True)  # Added for QAT

    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)

    # model.qconfig = torch.ao.quantization.default_qconfig  # Added for QAT
    # torch.ao.quantization.prepare_qat(model, inplace=True)  # Added for QAT

    def loss_func(output, labels):
        weights = (labels != 0) * 49.9 + 0.1
        # weights = (labels != 0) * 99.9 + 0.1
        return (weights * (output - labels) ** 2).mean()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
        cum_loss = 0
        model.train()
        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
            output = model(batch["imgs"].to(device))
            loss = loss_func(output, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            cum_loss += loss.item()
            if (i + 1) % log_steps == 0:
                print(f"{cum_loss / log_steps=}")
                cum_loss = 0

        if i % log_steps != 0:
            print(f"{cum_loss / (i % log_steps)=}")

        if epoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # model.eval()
        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
        model.to(device)
        quantized_model.eval()  # not sure if this is needed
        eval_device = "cpu"
        with torch.no_grad():
            num_eval_exs = 5
            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))

            eval_loss = 0
            for i in range(num_eval_exs):
                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
                axs[i, 1].imshow(eval_ex["labels"])
                axs[i, 1].set_title("Ground Truth")
                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
                axs[i, 2].set_title("Prediction")
                for ax in axs[i]:
                    ax.axis("off")

            print(f"{eval_loss / num_eval_exs=}")
            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
        torch.save(quantized_model.state_dict(),
                   os.path.join(args.output_dir, f"QuantizedV3_Stage2_128_{epoch}.pt"))
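
Note: the commented-out "Added for QAT" lines above cover only part of the eager-mode quantization-aware-training recipe. For reference, a typical full sequence looks like the sketch below; the qconfig choice ("fbgemm") is an assumption on my part, since the diff's commented code uses default_qconfig instead:

import torch
from models import QuantizedV2

model = QuantizedV2()
model.train()
model.fuse_modules(is_qat=True)  # fuse the Conv+BN+ReLU triples defined in models.py
# assumption: a QAT qconfig rather than the diff's default_qconfig
model.qconfig = torch.ao.quantization.get_default_qat_qconfig("fbgemm")
torch.ao.quantization.prepare_qat(model, inplace=True)  # insert fake-quant observers

# ... run the training loop above ...

model_int8 = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)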
181  corner_training/models.py  Normal file
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn


class UNetStage1v1(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()

        self.block0 = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, 32, kernel_size=3, padding="same"),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding="same")
        )

        self.pool0 = nn.MaxPool2d(kernel_size=2)

        self.block1 = nn.Sequential(
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, padding="same"),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding="same"),
        )

        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.block2 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding="same"),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding="same"),
        )

        self.upsample0 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
        )

        self.block3 = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, padding="same"),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding="same"),
        )

        self.upsample1 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2),
        )

        self.block4 = nn.Sequential(
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 32, kernel_size=3, padding="same"),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding="same"),
        )

    def forward(self, imgs):
        assert imgs.dim() == 4, imgs.size()
        x = self.block0(imgs)
        downsampled_x0 = self.block1(self.pool0(x))
        downsampled_x1 = self.block2(self.pool1(downsampled_x0))
        upsampled_x0 = self.block3(torch.cat([self.upsample0(downsampled_x1), downsampled_x0], dim=1))
        x = self.block4(torch.cat([self.upsample1(upsampled_x0), x], dim=1))
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)


class QuantizedV2(nn.Module):
    """ Normal convolutions with quantization """
    def __init__(self, in_channels=3):
        super().__init__()

        self.layers = nn.Sequential(
            torch.quantization.QuantStub(),
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 8, kernel_size=3, padding=1),
            torch.quantization.DeQuantStub(),
        )

    def forward(self, imgs):
        x = self.layers(imgs)
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)

    def fuse_modules(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)


class QuantizedV3(nn.Module):
    """ Depthwise convs except input layer; no inverted bottleneck """

    def __init__(self, in_channels=3):
        super().__init__()

        hidden_size = 32
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.input_conv = nn.Sequential(
            nn.Conv2d(in_channels, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
        )

        self.block1 = nn.Sequential(
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=2, groups=hidden_size, padding=2),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=1, padding=0),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
        )

        self.block2 = nn.Sequential(
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, dilation=1, groups=hidden_size, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.Conv2d(hidden_size, 8, kernel_size=1, padding=0),
        )

    def forward(self, imgs):
        x = self.quant(imgs)
        x = self.input_conv(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.dequant(x)
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)

    def fuse_modules(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        fuse_modules(self, [
            ["input_conv.0", "input_conv.1"],
            ["block1.0", "block1.1", "block1.2"],
            ["block1.3", "block1.4", "block1.5"],
            ["block2.0", "block2.1", "block2.2"],
        ], inplace=True)


class QuantizedV5(nn.Module):
    """normal convs, biasless"""
    def __init__(self, in_channels=3):
        super().__init__()

        self.layers = nn.Sequential(
            torch.quantization.QuantStub(),
            # nn.BatchNorm2d(in_channels),
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, dilation=2, padding=2, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 8, kernel_size=3, padding=1),
            torch.quantization.DeQuantStub(),
        )

    def forward(self, imgs):
        x = self.layers(imgs)
        x = torch.max(x, dim=-3)[0]  # max over channel dim
        return torch.sigmoid(x)

    def fuse_modules(self, is_qat=False):
        fuse_modules = torch.ao.quantization.fuse_modules_qat if is_qat else torch.ao.quantization.fuse_modules
        fuse_modules(self, [[f"layers.{i}", f"layers.{i+1}", f"layers.{i+2}"] for i in range(1, 6, 3)], inplace=True)
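
Note: every conv in these models preserves H and W (padding matches kernel size and dilation), and forward() collapses the channel dimension with a max, so the output is one heatmap per image. A quick shape check, assuming the float model before any quantization preparation or conversion:

import torch
from models import QuantizedV2

model = QuantizedV2().eval()
with torch.no_grad():
    out = model(torch.rand(2, 3, 128, 128))  # QuantStub is a no-op in float mode
print(out.shape)  # torch.Size([2, 128, 128]), values in (0, 1) from the sigmoid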
224  corner_training/utils.py  Normal file
@@ -0,0 +1,224 @@
import json
import math
import os
import random
import re
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm


def get_gaussian_filter(sigma, half_window_size) -> torch.Tensor:
    full_window_size = half_window_size * 2 + 1
    x_sq_offsets = torch.square(torch.arange(full_window_size) - half_window_size) \
        .expand((full_window_size, full_window_size))
    y_sq_offsets = x_sq_offsets.T
    gaussian_filter = torch.exp(-(x_sq_offsets + y_sq_offsets) / sigma)  # "normalized" so that the peak is 1
    return gaussian_filter


def get_bounded_slices(size_a: torch.Size, size_b: torch.Size, i: int, j: int) \
        -> tuple[tuple[slice, slice], tuple[slice, slice]]:
    """
    Args:
        size_a: size of the first 2D tensor (`a_tensor`)
        size_b: size of the second 2D tensor (`tensor_b`),
            should be at most `size_a`, and have odd height and width (so that center is clearly defined)
        i: the row of `tensor_a` to center `tensor_b` on
        j: the column of `tensor_a` to center `tensor_b` on

    Returns:
        ((row_slice_a, col_slice_a), (row_slice_b, col_slice_b))
        which are as large as possible (at most the size of `size_b`) while remaining in bounds.
    """
    assert len(size_a) == 2 and len(size_b) == 2 \
        and size_b[0] % 2 == size_b[1] % 2 == 1
    half_mask_width, half_mask_height = size_b[1] // 2, size_b[0] // 2
    left_offset = max(0, j - half_mask_width)
    right_offset = min(size_a[1], j + half_mask_width + 1)  # exclusive
    top_offset = max(0, i - half_mask_height)
    bottom_offset = min(size_a[0], i + half_mask_height + 1)

    return (slice(top_offset, bottom_offset), slice(left_offset, right_offset)), \
        (slice(half_mask_height + top_offset - i, half_mask_height + bottom_offset - i),
         slice(half_mask_width + left_offset - j, half_mask_width + right_offset - j))


def get_gtruth_wrapper(sigma: int, display_pts_inds: list[int] = None):
    """
    Args:
        sigma: variance for the Gaussian smoothing.
        display_pts_inds: If specified, will only display the points corresponding to these indices in `transformed_pts`.
            Otherwise, will display all points by default.

    One main reason for using a closure is to reuse the Gaussian mask for efficiency.
    """

    half_window_size = sigma
    # half_window_size = 2 * sigma
    # half_window_size = 3 * sigma
    # half_window_size = int(3 * sigma / math.sqrt(2))

    gaussian_filter = get_gaussian_filter(sigma, half_window_size)

    # def get_gtruth(transformed_pts: torchvision.tv_tensors.BoundingBoxes, img_size: tuple) -> torch.Tensor:
    def get_gtruth(transformed_pts: torch.Tensor, img_size: tuple) -> torch.Tensor:
        """
        Converts coordinates of points to a heatmap with Gaussians centered at those coordinates.

        Args:  # TODO doc
            transformed_pts:
            img_size:
        """
        smoothed_gtruth = torch.zeros(img_size)

        for ind in range(len(transformed_pts)) if display_pts_inds is None else display_pts_inds:
            # x, y, _, _ = transformed_pts[ind]
            x, y = transformed_pts[ind]

            gtruth_slice, gaussian_slice = get_bounded_slices(smoothed_gtruth.size(), gaussian_filter.size(), y, x)
            smoothed_gtruth[gtruth_slice] = torch.maximum(  # should modify smoothed_gtruth in-place
                smoothed_gtruth[gtruth_slice],
                gaussian_filter[gaussian_slice]
            )

        return smoothed_gtruth

    return get_gtruth


class FlyingFramesDataset(torch.utils.data.Dataset):
    """
    Different from the `FlyingFramesDataset` in Simulate_Data, which
    randomly generates each frame. Here we read the saved examples.
    """
    def __init__(self, data_root_path: str, get_gtruth: Callable, shard_paths: list[str] = None, eval: bool = False):
        self.data_root_path, self.get_gtruth = data_root_path, get_gtruth
        with open(os.path.join(self.data_root_path, "configs.json"), "r") as f:
            self.configs = json.load(f)

        assert self.configs["dataset_size"] % self.configs["shard_size"] == 0  # simplifying assumption
        self.shard_size = self.configs["shard_size"]

        if shard_paths is None:  # use all shards
            self.shard_paths = sorted([path for path in os.listdir(data_root_path)
                                       if re.fullmatch(r"shard_\d+", path) is not None])
        else:
            self.shard_paths = shard_paths

        assert len(self.shard_paths) * self.shard_size <= self.configs["dataset_size"], (len(self.shard_paths), self.shard_size, self.configs["dataset_size"])

        identity_transform = transforms.Lambda(lambda x: x)  # for eval=True
        # Some last-mile stuff
        # self.img_transforms = transforms.Compose([
        #     transforms.PILToTensor(),
        #     # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
        #     # if not eval else identity_transform,
        #     transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)),  # color jitter even in eval
        #     # transforms.Grayscale(),
        #     # normalize?
        #     transforms.ToDtype(torch.float32, scale=True),
        #     # transforms.ToPILImage(),  # for visualization
        # ])

        # self.unified_transforms = transforms.Compose([
        #     # transforms.Resize(self.img_size),
        #     transforms.RandomHorizontalFlip(),
        #     transforms.RandomVerticalFlip(),
        #     transforms.RandomApply([
        #         transforms.GaussianBlur(kernel_size=(3, 3))
        #     ], p=0.33),  # TODO: tune
        # ]) if not eval else identity_transform

        self.img_transforms = transforms.Compose([
            transforms.PILToTensor(),
            # transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05))
            # if not eval else identity_transform,
            transforms.ColorJitter((0.875, 1.125), (0.5, 1.5), (0.5, 1.5), (-0.05, 0.05)),  # color jitter even in eval
            # transforms.Grayscale(),
            # normalize?
            transforms.ConvertImageDtype(torch.float32),
            # transforms.ToPILImage(),  # for visualization
        ])

        self.unified_transforms = transforms.Compose([
            # transforms.Resize(self.img_size),
            unified_hflip,
            unified_vflip,
            lambda img_and_pts: (transforms.RandomApply([
                transforms.GaussianBlur(kernel_size=(3, 3))
            ], p=0.33)(img_and_pts[0]), img_and_pts[1]),  # TODO: tune
        ]) if not eval else identity_transform

    def __getitem__(self, idx):
        shard_path = self.shard_paths[idx // self.shard_size]
        shard_ind = int(re.fullmatch(r"shard_(?P<shard_ind>\d+)", shard_path).group("shard_ind"))
        img_ind = shard_ind * self.shard_size + idx % self.shard_size
        img_path = os.path.join(self.data_root_path, shard_path, f"{img_ind}.jpg")
        bbox_path = os.path.join(self.data_root_path, shard_path, f"{img_ind}_tensor.pt")

        assert os.path.isfile(img_path) and os.path.isfile(bbox_path), (img_path, bbox_path)
        img = self.img_transforms(PIL.Image.open(img_path))
        bbox = torch.load(bbox_path)

        img, bbox = self.unified_transforms((img, bbox))

        return {
            "imgs": img,
            "labels": self.get_gtruth(bbox, self.configs["img_size"]),
        }

    def __len__(self):
        return len(self.shard_paths) * self.shard_size


class NoamLR(torch.optim.lr_scheduler._LRScheduler):
    """
    Taken from https://github.com/tugstugi/pytorch-saltnet/blob/master/utils/lr_scheduler.py
    """

    def __init__(self, optimizer, warmup_steps):  # steps, not ratio
        self.warmup_steps = warmup_steps
        super().__init__(optimizer)

    def get_lr(self):
        last_epoch = max(1, self.last_epoch)
        scale = self.warmup_steps ** 0.5 * min(last_epoch ** (-0.5), last_epoch * self.warmup_steps ** (-1.5))
        return [base_lr * scale for base_lr in self.base_lrs]


def unified_hflip(img_and_pts):
    img, corner_pts = img_and_pts
    if random.random() < 0.5:
        return img, corner_pts

    _, h, w = img.size()
    img = transforms_f.hflip(img)

    # (x, y) -> (w - x, y)
    corner_pts = torch.stack([w - corner_pts[:, 0], corner_pts[:, 1]], dim=1)
    return img, corner_pts


def unified_vflip(img_and_pts):
    img, corner_pts = img_and_pts
    if random.random() < 0.5:
        return img, corner_pts

    _, h, w = img.size()
    img = transforms_f.vflip(img)

    # (x, y) -> (x, h - y)
    corner_pts = torch.stack([corner_pts[:, 0], h - corner_pts[:, 1]], dim=1)
    return img, corner_pts
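
Note: the ground-truth pipeline above turns (x, y) corner coordinates into a heatmap whose peaks are 1, writing each Gaussian with the in-bounds slices from get_bounded_slices. A minimal usage sketch with made-up coordinates:

from utils import get_gtruth_wrapper

get_gtruth = get_gtruth_wrapper(4)   # sigma=4, same setting as the training scripts
pts = [(10, 12), (100, 90)]          # (x, y) pairs, matching unified_hflip/vflip's layout
heatmap = get_gtruth(pts, (128, 128))
print(heatmap.shape, heatmap.max())  # torch.Size([128, 128]) tensor(1.)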
151  decoder_cnn.py  Normal file
@@ -0,0 +1,151 @@
import argparse
import traceback

import cv2
import numpy as np
import torch
from creedsolo import RSCodec
from matplotlib import pyplot as plt
from raptorq import Decoder

from corner_training.models import QuantizedV2, QuantizedV5
from decoding_utils import localize_corners_wrapper

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-i", "--input", help="camera device index or input video file", default=0)
parser.add_argument("-o", "--output", help="output file for decoded data", default="out")
parser.add_argument("-x", "--height", help="grid height", default=100, type=int)
parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
parser.add_argument("-s", "--size", help="number of bytes to decode", type=int)
parser.add_argument("-p", "--psize", help="packet size", type=int)
parser.add_argument("-v", "--version",
                    help="0: 10% corners w/ two-sided one-cell padding; 1: 15% corners w/ four-sided 25% padding.",
                    default=0, choices=[0, 1], type=int)
args = parser.parse_args()

assert args.version == 1
# cell borders are 0.0375% of width/height
assert args.height * 3 % 80 == args.width * 3 % 80 == 0
cheight = int(args.height * 0.15)
cwidth = int(args.width * 0.15)

frame_size = args.height * args.width - 4 * cheight * cwidth
frame_bytes = frame_size * 3 // 8
frame_xor = np.arange(frame_bytes, dtype=np.uint8)
rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(args.level * 255) - 4

rsc = RSCodec(int(args.level * 255))
decoder = Decoder.with_defaults(args.size, rs_bytes)


stage1_model_checkpt_path = "/Users/kevinzhao/Downloads/QuantizedV0_Stage1_128_9.pt"

stage1_model = QuantizedV2()
stage1_model.eval()
stage1_model.fuse_modules(is_qat=False)
stage1_model.qconfig = torch.ao.quantization.default_qconfig
torch.ao.quantization.prepare(stage1_model, inplace=True)
torch.ao.quantization.convert(stage1_model, inplace=True)
stage1_model.load_state_dict(torch.load(stage1_model_checkpt_path, map_location=torch.device('cpu')))

stage2_model = QuantizedV5()
stage2_model.eval()
stage2_model.fuse_modules(is_qat=False)
stage2_model.qconfig = torch.ao.quantization.default_qconfig
torch.ao.quantization.prepare(stage2_model, inplace=True)
torch.ao.quantization.convert(stage2_model, inplace=True)
stage2_model.load_state_dict(torch.load("/Users/kevinzhao/Downloads/QuantizedV5_Stage2_128_9.pt", map_location=torch.device('cpu')))

# stage1_size = 128
# stage2_size = 128

stage1_size = 128
stage2_size = 64

input_crop_size = 1024

localize_corners = localize_corners_wrapper(stage1_model, stage2_model, stage1_size, stage2_size)

if args.input.isdecimal():
    args.input = int(args.input)
cap = cv2.VideoCapture(args.input)
data = None
while data is None:
    try:
        ret, raw_frame = cap.read()
        if not ret:
            print("End of stream")
            break
        # # raw_frame is a uint8 BE CAREFUL
        # if type(args.input) == int:
        #     # Crop image to reduce camera distortion
        #     X, Y = raw_frame.shape[:2]
        #     raw_frame = raw_frame[X // 4 : 3 * X // 4, Y // 4 : 3 * Y // 4]
        # cv2.imshow("", raw_frame)
        # cv2.waitKey(1)
        # raw_frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2RGB)

        h, w, _ = raw_frame.shape
        cropped_frame = raw_frame[(h - input_crop_size) // 2:-(h - input_crop_size) // 2,
                                  (w - input_crop_size) // 2:-(w - input_crop_size) // 2]
        cropped_frame = cv2.cvtColor(cropped_frame, cv2.COLOR_BGR2RGB)
        (widx, ridx, gidx, bidx), (wcol, rcol, gcol, bcol) = localize_corners(cropped_frame)

        widx = widx[::-1]
        ridx = ridx[::-1]
        gidx = gidx[::-1]
        bidx = bidx[::-1]
        # plt.imshow(cropped_frame)
        # plt.scatter([widx[1]], [widx[0]], color="r")
        # plt.scatter([ridx[1]], [ridx[0]], color="g")
        # plt.scatter([gidx[1]], [gidx[0]], color="b")
        # plt.scatter([bidx[1]], [bidx[0]], color="w")
        # plt.show()

        # Find basis of color space
        origin = (rcol + gcol + bcol - wcol) / 2
        rcol -= origin
        gcol -= origin
        bcol -= origin
        F = 255 * np.linalg.inv(np.stack((rcol, gcol, bcol)).T)

        # cch = cheight / 2 - 1
        # ccw = cwidth / 2 - 1
        cch = cheight / 4 - 1
        ccw = cwidth / 4 - 1

        M = cv2.getPerspectiveTransform(
            np.float32([np.flip(widx), np.flip(ridx), np.flip(gidx), np.flip(bidx)]),
            np.float32(
                [
                    [ccw, cch],
                    [args.width - ccw - 1, cch],
                    [ccw, args.height - cch - 1],
                    [args.width - ccw - 1, args.height - cch - 1],
                ]
            ),
        )
        frame = cv2.warpPerspective(cropped_frame, M, (args.width, args.height))
        # # Convert to new color space
        # frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 128).astype(np.uint8)
        frame = (np.squeeze(F @ (frame - origin)[..., np.newaxis]) >= 192).astype(np.uint8)
        # import matplotlib.pyplot as plt
        # # plt.imshow(frame * 255)
        # plt.imshow((1 - frame) * 255)
        # plt.show()
        frame = np.concatenate(
            (
                frame[:cheight, cwidth : args.width - cwidth].flatten(),
                frame[cheight : args.height - cheight].flatten(),
                frame[args.height - cheight :, cwidth : args.width - cwidth].flatten(),
            )
        )
        data = decoder.decode(bytes(rsc.decode(bytearray(np.packbits(frame) ^ frame_xor))[0][: args.psize]))
        print("Decoded frame")
    except KeyboardInterrupt:
        break
    except:
        traceback.print_exc()
with open(args.output, "wb") as f:
    f.write(data)
cap.release()
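
Note on the color-space recovery above: each cell's measured color is modeled as origin + r*R + g*G + b*B for bits (r, g, b), so white = origin + R + G + B, which gives origin = (rcol + gcol + bcol - wcol) / 2; F then inverts the basis to read the bits back. A worked sketch with synthetic numbers (not from the diff):

import numpy as np

origin = np.array([40.0, 35.0, 30.0])  # made-up camera black level
R, G, B = np.array([150.0, 20, 10]), np.array([15.0, 140, 25]), np.array([10.0, 30, 160])
wcol, rcol, gcol, bcol = origin + R + G + B, origin + R, origin + G, origin + B

est_origin = (rcol + gcol + bcol - wcol) / 2  # same formula as the decoder
F = 255 * np.linalg.inv(np.stack((rcol - est_origin, gcol - est_origin, bcol - est_origin)).T)

measured = origin + R + B               # a cell encoding bits (1, 0, 1)
print(np.round(F @ (measured - est_origin)))  # -> [255.   0. 255.]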
42  encoder.py
@@ -12,10 +12,21 @@ parser.add_argument("-y", "--width", help="grid width", default=100, type=int)
 parser.add_argument("-l", "--level", help="error correction level", default=0.1, type=float)
 parser.add_argument("-f", "--fps", help="frame rate", default=30, type=int)
 parser.add_argument("-m", "--mix", help="mix frames with original video", action="store_true")
+parser.add_argument("-v", "--version",
+                    help="0: 10% corners w/ two-sided one-cell padding; 1: 15% corners w/ four-sided 25% padding.",
+                    default=0, choices=[0, 1], type=int)
 args = parser.parse_args()

+if args.version == 0:
+    cheight = cwidth = max(args.height // 10, args.width // 10)
+elif args.version == 1:
+    # cell borders are 0.0375% of width/height
+    assert args.height * 3 % 80 == args.width * 3 % 80 == 0  # TODO: less strict better ratio
+    cheight = int(args.height * 0.15)
+    cwidth = int(args.width * 0.15)
+else:
+    raise NotImplementedError
+
-cheight = cwidth = max(args.height // 10, args.width // 10)
 midwidth = args.width - 2 * cwidth
 frame_size = args.height * args.width - 4 * cheight * cwidth
 # Divide by 8 / 3 for 3-bit color
@@ -32,15 +43,26 @@ encoder = Encoder.with_defaults(data, rs_bytes)
 packets = encoder.get_encoded_packets(int(len(data) / rs_bytes * (1 / (1 - args.level) - 1)))

 # Make corners
-ones = np.ones((cheight - 1, cwidth - 1))
-zeros = np.zeros((cheight - 1, cwidth - 1))
-wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
-rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
-gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
-bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+if args.version == 0:
+    ones = np.ones((cheight - 1, cwidth - 1))
+    zeros = np.zeros((cheight - 1, cwidth - 1))
+    wcorner = np.pad(np.dstack((ones, ones, ones)), ((0, 1), (0, 1), (0, 0)))
+    rcorner = np.pad(np.dstack((ones, zeros, zeros)), ((0, 1), (1, 0), (0, 0)))
+    gcorner = np.pad(np.dstack((zeros, ones, zeros)), ((1, 0), (0, 1), (0, 0)))
+    bcorner = np.pad(np.dstack((zeros, zeros, ones)), ((1, 0), (1, 0), (0, 0)))
+elif args.version == 1:
+    zeros = np.zeros((cheight, cwidth, 3))
+    wcorner = zeros.copy()
+    rcorner = zeros.copy()
+    gcorner = zeros.copy()
+    bcorner = zeros.copy()
+    black_border_h, black_border_w = cheight // 4, cwidth // 4
+    for corner_arr, ones_channel_ind in [(wcorner, 0), (wcorner, 1), (wcorner, 2),
+                                         (rcorner, 0), (gcorner, 1), (bcorner, 2)]:
+        corner_arr[black_border_h:-black_border_h, black_border_w:-black_border_w, ones_channel_ind] = np.ones((cheight // 2, cwidth // 2))

 # Output flags for decoder
-print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])}", end="")
+print(f"-x {args.height} -y {args.width} -l {args.level} -s {len(data)} -p {len(packets[0])} -v {args.version}", end="")


 def mkframe(packet):
@@ -55,11 +77,11 @@ def mkframe(packet):
             (wcorner, frame[: cheight * midwidth].reshape((cheight, midwidth, 3)), rcorner),
             axis=1,
         ),
-        frame[cheight * midwidth : frame_size - cheight * midwidth].reshape(
+        frame[cheight * midwidth: frame_size - cheight * midwidth].reshape(
             (args.height - 2 * cheight, args.width, 3)
         ),
         np.concatenate(
-            (gcorner, frame[frame_size - cheight * midwidth :].reshape((cheight, midwidth, 3)), bcorner),
+            (gcorner, frame[frame_size - cheight * midwidth:].reshape((cheight, midwidth, 3)), bcorner),
             axis=1,
         ),
     )
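
Note: capacity arithmetic for the default grid (-x 100 -y 100) under version 1, following the formulas in encoder.py and decoder_cnn.py above:

height = width = 100
cheight = cwidth = int(height * 0.15)                # 15x15-cell corner blocks
frame_size = height * width - 4 * cheight * cwidth   # 9100 data cells
frame_bytes = frame_size * 3 // 8                    # 3 bits per cell -> 3412 bytes
level = 0.1
rs_bytes = frame_bytes - (frame_bytes + 254) // 255 * int(level * 255) - 4
print(frame_size, frame_bytes, rs_bytes)             # 9100 3412 3058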