# QAT training script: trains QuantizedV2 on FlyingFramesDataset shards and
# saves a per-epoch converted (quantized) checkpoint plus validation plots.
import argparse
import json
import math
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f

# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm

from models import QuantizedV3, QuantizedV2
from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper
def parse_args(argv=None):
    """Parse command-line arguments for the training script.

    Args:
        argv: Optional list of argument strings. Defaults to None, in which
            case argparse reads sys.argv[1:] — preserving the original CLI
            behavior. Passing a list makes the function unit-testable.

    Returns:
        argparse.Namespace with attributes `data_root_path` and `output_dir`.
    """
    parser = argparse.ArgumentParser(
        description="QAT training on FlyingFramesDataset shards."
    )
    parser.add_argument(
        "data_root_path",
        help="Directory containing shard_<N> subdirectories.",
    )
    parser.add_argument(
        "output_dir",
        help="Directory to create for checkpoints and plots (must not already exist).",
    )

    return parser.parse_args(argv)
if __name__ == "__main__":
    args = parse_args()
    # exist_ok=False: fail fast rather than overwrite an earlier run's output.
    os.makedirs(args.output_dir, exist_ok=False)

    # Seed every RNG in play (numpy, torch, stdlib) for reproducibility.
    np.random.seed(42)
    torch.manual_seed(42)
    random.seed(42)

    lr = 3e-2  # TODO: argparse
    log_steps = 10
    train_batch_size = 512
    num_train_epochs = 10

    # Shard directories are named shard_<N>; leave 1 shard for eval.
    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
                          if re.fullmatch(r"shard_\d+", path) is not None])

    get_gtruth = get_gtruth_wrapper(4)
    # eval=True on the train split too — note: not flipping (no augmentation).
    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1], eval=True)
    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True
    )

    # torch.backends.quantized.engine = "qnnpack"
    model = QuantizedV2()
    # model.fuse_modules(is_qat=True)  # Added for QAT

    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    # NOTE(review): scheduler is constructed but scheduler.step() is never
    # called below, so the Noam warmup never takes effect — confirm intent.
    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)

    # model.qconfig = torch.ao.quantization.default_qconfig  # Added for QAT
    # torch.ao.quantization.prepare_qat(model, inplace=True)  # Added for QAT

    def loss_func(output, labels):
        """Pixel-weighted MSE: non-zero label pixels get weight 50.0 and zero
        (background) pixels 0.1, countering foreground/background imbalance."""
        weights = (labels != 0) * 49.9 + 0.1
        # weights = (labels != 0) * 99.9 + 0.1
        return (weights * (output - labels) ** 2).mean()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
        cum_loss = 0
        i = -1  # keeps the leftover-loss check below safe on an empty dataloader
        model.train()
        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
            output = model(batch["imgs"].to(device))
            loss = loss_func(output, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            cum_loss += loss.item()
            if (i + 1) % log_steps == 0:
                print(f"{cum_loss / log_steps=}")
                cum_loss = 0

        # Batches accumulated since the last log are (i + 1) % log_steps.
        # BUG FIX: the old check `i % log_steps != 0` used the wrong count —
        # it printed a spurious 0.0 right after a full log window and skipped
        # the final partial window whenever i was a multiple of log_steps.
        leftover_steps = (i + 1) % log_steps
        if leftover_steps > 0:
            print(f"{cum_loss / leftover_steps=}")

        if epoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # Convert a CPU copy to a true quantized model for evaluation; the
        # float model goes back to `device` for the next training epoch.
        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
        model.to(device)
        quantized_model.eval()  # not sure if this is needed
        eval_device = "cpu"
        with torch.no_grad():
            num_eval_exs = 5
            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))

            eval_loss = 0
            for i in range(num_eval_exs):
                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
                axs[i, 1].imshow(eval_ex["labels"])
                axs[i, 1].set_title("Ground Truth")
                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
                axs[i, 2].set_title("Prediction")
                for ax in axs[i]:
                    ax.axis("off")

            print(f"{eval_loss / num_eval_exs=}")
            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
            # Close the figure so per-epoch figures don't accumulate in memory.
            plt.close(fig)

        # NOTE(review): filename says QuantizedV3 but the model is QuantizedV2 —
        # kept byte-identical for downstream compatibility; confirm which is right.
        torch.save(quantized_model.state_dict(),
                   os.path.join(args.output_dir, f"QuantizedV3_Stage2_128_{epoch}.pt"))