# 6.8301-Project/corner_training/fine_training.py
# (snapshot metadata: 2024-05-05 10:33:51 -04:00, 128 lines, 4.5 KiB, Python)
import argparse
import json
import math
import os
import random
import re
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as transforms_f
# import torchvision.transforms.v2 as transforms
# import torchvision.transforms.v2.functional as transforms_f
from tqdm.auto import tqdm
from models import QuantizedV3, QuantizedV2
from utils import FlyingFramesDataset, NoamLR, get_gtruth_wrapper
def parse_args():
    """Parse the two required positional CLI arguments.

    Returns:
        argparse.Namespace with ``data_root_path`` (dataset shard root)
        and ``output_dir`` (where plots/checkpoints are written).
    """
    cli = argparse.ArgumentParser()
    for positional in ("data_root_path", "output_dir"):
        cli.add_argument(positional)
    return cli.parse_args()
if __name__ == "__main__":
    args = parse_args()
    # exist_ok=False: fail loudly rather than overwrite a previous run's
    # checkpoints and validation plots.
    os.makedirs(args.output_dir, exist_ok=False)

    # Seed every RNG in play so runs are reproducible.
    np.random.seed(42)
    torch.manual_seed(42)
    random.seed(42)

    lr = 3e-2  # TODO: argparse
    log_steps = 10
    train_batch_size = 512
    num_train_epochs = 10

    # Shards are entries named shard_<N> under the data root;
    # leave 1 shard for eval.
    shard_paths = sorted([path for path in os.listdir(args.data_root_path)
                          if re.fullmatch(r"shard_\d+", path) is not None])
    get_gtruth = get_gtruth_wrapper(4)
    train_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth,
                                        shard_paths[:-1], eval=True)  # note: not flipping
    eval_dataset = FlyingFramesDataset(args.data_root_path, get_gtruth,
                                       shard_paths[-1:], eval=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=train_batch_size, num_workers=2, pin_memory=True
    )

    # torch.backends.quantized.engine = "qnnpack"
    model = QuantizedV2()
    # model.fuse_modules(is_qat=True)  # Added for QAT
    print(f"{sum(p.numel() for p in model.parameters() if p.requires_grad)=}")
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = NoamLR(optimizer, warmup_steps=log_steps * 3)
    # model.qconfig = torch.ao.quantization.default_qconfig  # Added for QAT
    # torch.ao.quantization.prepare_qat(model, inplace=True)  # Added for QAT

    def loss_func(output, labels):
        """Weighted per-pixel MSE: non-zero (labelled) pixels get weight 50.0
        vs 0.1 for background, so sparse targets are not drowned out by the
        all-zero background."""
        weights = (labels != 0) * 49.9 + 0.1
        # weights = (labels != 0) * 99.9 + 0.1
        return (weights * (output - labels) ** 2).mean()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    for epoch in tqdm(range(num_train_epochs), position=0, leave=True):
        cum_loss = 0
        model.train()
        for i, batch in enumerate(tqdm(train_dataloader, position=1, leave=False)):
            output = model(batch["imgs"].to(device))
            loss = loss_func(output, batch["labels"].to(device))
            loss.backward()
            optimizer.step()
            # FIX: the NoamLR scheduler was constructed but never stepped, so
            # its warmup/decay schedule had no effect and training ran at the
            # base LR forever. Advance it once per optimizer step.
            scheduler.step()
            optimizer.zero_grad()
            cum_loss += loss.item()
            if (i + 1) % log_steps == 0:
                print(f"{cum_loss / log_steps=}")
                cum_loss = 0
        # Report the partial logging window left over at the end of the epoch.
        # FIX: the leftover batch count since the last reset is
        # (i + 1) % log_steps, not i % log_steps — the old guard was off by
        # one and printed a stale 0.0 when the epoch ended exactly on a
        # logging boundary.
        if (i + 1) % log_steps != 0:
            print(f"{cum_loss / ((i + 1) % log_steps)=}")

        if epoch > 3:
            # Freeze quantizer parameters
            model.apply(torch.ao.quantization.disable_observer)
        if epoch > 2:
            # Freeze batch norm mean and variance estimates
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # model.eval()
        # Convert an eval-mode CPU copy into a quantized model for evaluation
        # and checkpointing, then move the training model back to the device.
        quantized_model = torch.ao.quantization.convert(model.cpu().eval(), inplace=False)
        model.to(device)
        quantized_model.eval()  # not sure if this is needed
        eval_device = "cpu"  # converted/quantized kernels run on CPU here

        with torch.no_grad():
            num_eval_exs = 5
            fig, axs = plt.subplots(num_eval_exs, 3, figsize=(12, 4 * num_eval_exs))
            eval_loss = 0
            for i in range(num_eval_exs):
                eval_ex = eval_dataset[i]  # slicing doesn't work yet...
                axs[i, 0].imshow(transforms_f.to_pil_image(eval_ex["imgs"]))
                preds = quantized_model(eval_ex["imgs"].to(eval_device).unsqueeze(0))
                eval_loss += loss_func(preds, eval_ex["labels"].to(eval_device)).item()
                axs[i, 1].imshow(eval_ex["labels"])
                axs[i, 1].set_title("Ground Truth")
                axs[i, 2].imshow(preds.detach().cpu().numpy().squeeze(0))
                axs[i, 2].set_title("Prediction")
                for ax in axs[i]:
                    ax.axis("off")
            print(f"{eval_loss / num_eval_exs=}")
            plt.savefig(os.path.join(args.output_dir, f"validation_results{epoch}.png"), bbox_inches="tight")
            # FIX: close the per-epoch figure so figures don't accumulate in
            # pyplot's global registry across epochs.
            plt.close(fig)

        # NOTE(review): filename says QuantizedV3 but the model instantiated
        # above is QuantizedV2 — confirm the checkpoint name is intentional.
        torch.save(quantized_model.state_dict(),
                   os.path.join(args.output_dir, f"QuantizedV3_Stage2_128_{epoch}.pt"))