import argparse
import os
import random
import re

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm

from models import QuantizedV3, QuantizedV2
from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_root_path")
    parser.add_argument("output_dir")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=False)

    np.random.seed(42)
    torch.manual_seed(42)
    random.seed(42)

    lr = 3e-2  # TODO: argparse
    log_steps = 10
    train_batch_size = 512
    num_train_epochs = 10

    # leave 1 shard for eval
    shard_paths = sorted(
        path for path in os.listdir(args.data_root_path)
        if re.fullmatch(r"shard_\d+", path) is not None
    )
    get_gtruth = get_gtruth_wrapper(4)
    # note: eval=True for the train split too, so the flip augmentation is not applied
    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[:-1], eval=True)
    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth, shard_paths[-1:], eval=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True
    )

    # torch.backends.quantized.engine = "qnnpack"
    model = QuantizedV2()
    # model.fuse_modules(is_qat=True)  # Added for QAT
    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)

    # model.qconfig = torch.ao.quantization.default_qconfig  # Added for QAT
    # torch.ao.quantization.prepare_qat(model, inplace=True)  # Added for QAT

    def loss_func(output, labels):
        # upweight the sparse non-zero label pixels relative to background
        weights = (labels != 0) * 49.9 + 0.1
        # weights = (labels != 0) * 99.9 + 0.1
        return (weights * (output - labels) ** 2).mean()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
        cum_loss = 0
        model.train()
        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
            output = model(batch["imgs"].to(device))
            loss = loss_func(output, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            scheduler.step()  # NoamLR warms up per optimizer step, so step it every batch
            optimizer.zero_grad()
            cum_loss += loss.item()
            if (i + 1) % log_steps == 0:
                print(f"{cum_loss / log_steps=}")
                cum_loss = 0
        # report the leftover partial window if the last batch didn't land on a log boundary
        if (i + 1) % log_steps != 0:
            print(f"{cum_loss / ((i + 1) % log_steps)=}")

        if epoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # model.eval()
        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
        model.to(device)
        quantized_model.eval()  # not sure if this is needed
        eval_device = "cpu"

        with torch.no_grad():
            num_eval_exs = 5
            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
            eval_loss = 0
            for i in range(num_eval_exs):
                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
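                # panels: input frame (left), ground truth (middle), prediction (right);
                # label the input panel to match the titled panels below
                axs[i, 0].set_title("Input")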
                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
                axs[i, 1].imshow(eval_ex["labels"])
                axs[i, 1].set_title("Ground Truth")
                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
                axs[i, 2].set_title("Prediction")
                for ax in axs[i]:
                    ax.axis("off")
            print(f"{eval_loss / num_eval_exs=}")

        plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
        plt.close(fig)  # release the figure; a new one is created every epoch
        torch.save(quantized_model.state_dict(), os.path.join(args.output_dir, f"QuantizedV3_Stage2_128_{epoch}.pt"))
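
    # A minimal sketch of how one of the checkpoints saved above could be
    # reloaded for inference. It assumes the commented-out QAT lines above are
    # enabled, so that convert() rebuilds the same quantized module structure
    # the state dict was saved from; `ckpt_path` is a hypothetical path to one
    # of the files written by torch.save above.
    #
    #   infer_model = QuantizedV2()
    #   infer_model.qconfig = torch.ao.quantization.default_qconfig
    #   torch.ao.quantization.prepare_qat(infer_model.train(), inplace=True)
    #   infer_model = torch.ao.quantization.convert(infer_model.eval(), inplace=False)
    #   infer_model.load_state_dict(torch.load(ckpt_path))
    #   infer_model.eval()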