Init model

This commit is contained in:
xtex 2024-08-27 19:51:14 +08:00
parent 8d74bff9e6
commit 11d2882c3f
Signed by: xtex
GPG key ID: B918086ED8045B91
14 changed files with 1608 additions and 550 deletions

1
.gitignore vendored
View file

@ -1,2 +1,3 @@
/target
*.png
/output

1380
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,9 +1,9 @@
[workspace]
resolver = "2"
members = [
members = [ "cli",
"dataset",
"dataset-collector",
"model"
"model",
]
[workspace.package]
@ -14,8 +14,11 @@ repository = "https://codeberg.org/xtex/scra"
license = "Apache-2.0"
[workspace.dependencies]
burn = { version = "0.13.2" }
burn = { version = "0.14.0", git = "https://github.com/tracel-ai/burn" }
clap = { version = "4.5.16", features = ["derive"] }
tokio = { version = "1.39.3" }
anyhow = { version = "1.0.86", features = ["backtrace"] }
thiserror = { version = "1.0.63" }
image = { version = "0.25.2", default-features = false, features = ["png"] }
log = { version = "0.4.22", features = ["std"] }
serde = { version = "1.0.209", default-features = false, features = ["derive"] }

23
cli/Cargo.toml Normal file
View file

@ -0,0 +1,23 @@
[package]
name = "scra-mirach-cli"
version = "0.1.0"
description = "SCRA: Mirach - Command-line utilities"
edition.workspace = true
authors.workspace = true
homepage.workspace = true
repository.workspace = true
license.workspace = true
[dependencies]
burn = { workspace = true, features = ["metrics", "train", "wgpu", "tui"] }
scra-mirach-model = { path = "../model", default-features = false, features = ["image", "train"] }
scra-mirach-dataset = { path = "../dataset" }
anyhow = { workspace = true }
clap = { workspace = true }
log = { workspace = true }
env_logger = "0.11.5"
serde = { workspace = true }
serde_json = "1.0.127"
[features]
libtorch = ["burn/tch"]

73
cli/src/main.rs Normal file
View file

@ -0,0 +1,73 @@
use std::{env, fs, path::PathBuf};
use anyhow::Result;
use clap::Parser;
use scra_mirach_model::MirachModelConfig;
use serde::{Deserialize, Serialize};
use train::TrainArgs;
pub mod train;
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about)]
struct Args {
#[command(subcommand)]
action: Action,
#[arg(short='b', long, value_enum,global=true, default_value_t = Backend::Wgpu, help="Tensor backend")]
pub backend: Backend,
#[cfg(feature = "libtorch")]
#[arg(
long,
global = true,
required_if_eq("backend", "Backend::LibTorchCuda"),
help = "GPU index to use for libtorch CUDA backend"
)]
pub cuda_index: Option<usize>,
#[arg(short = 'c', long, global = true, help = "Model configuration file")]
pub model_config: Option<PathBuf>,
}
#[derive(Debug, Clone, clap::Subcommand)]
enum Action {
Train(TrainArgs),
}
#[derive(clap::ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Backend {
Wgpu,
#[cfg(feature = "libtorch")]
LibTorchCpu,
#[cfg(feature = "libtorch")]
LibTorchMps,
#[cfg(feature = "libtorch")]
LibTorchVulkan,
#[cfg(feature = "libtorch")]
LibTorchCuda,
}
impl Args {
pub fn model_config(path: &Option<PathBuf>) -> Result<MirachModelConfig> {
match path {
Some(path) => Ok(serde_json::from_str(&fs::read_to_string(path)?)?),
None => Ok(MirachModelConfig::new()),
}
}
}
fn main() -> Result<()> {
let args = Args::parse();
env_logger::builder()
.format_timestamp_secs()
.parse_env(env_logger::Env::new().default_filter_or(
"scra_mirach_cli=info,scra_mirach_model=info,scra_mirach_dataset=info,burn_train=warn",
))
.target(env_logger::Target::Stdout)
.try_init()?;
env::set_var("RUST_BACKTRACE", "true");
match args.action {
Action::Train(args) => train::run(args),
}
}

134
cli/src/train.rs Normal file
View file

@ -0,0 +1,134 @@
use std::{fs, path::PathBuf, time::Instant};
use anyhow::Result;
#[cfg(feature = "libtorch")]
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
use burn::{
backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
data::dataloader::DataLoaderBuilder,
lr_scheduler::constant::ConstantLr,
module::Module,
optim::AdamWConfig,
record::{CompactRecorder, DefaultRecorder},
tensor::backend::AutodiffBackend,
train::{
metric::{HammingScore, LossMetric},
LearnerBuilder,
},
};
use log::info;
use scra_mirach_dataset::batcher::MirachTrainBatcher;
use crate::{Args, Backend};
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about)]
pub struct TrainArgs {
#[arg(from_global)]
pub backend: Backend,
#[cfg(feature = "libtorch")]
#[arg(from_global)]
pub cuda_index: Option<usize>,
#[arg(from_global)]
pub model_config: Option<PathBuf>,
#[arg(long, default_value_t = 42)]
pub seed: u64,
#[arg(short = 'o', long, help = "Output directory")]
pub output: PathBuf,
#[arg(long = "continue", help = "Load existing parameters")]
pub load_existing: bool,
#[arg(short = 'e', long, default_value_t = 10)]
pub epochs: usize,
#[arg(short = 's', long, default_value_t = 64)]
pub batch_size: usize,
#[arg(
short = 'j',
long,
default_value_t = 4,
help = "Number of dataloader workers"
)]
pub workers: usize,
#[arg(long, default_value_t = 1.0e-4)]
pub learning_rate: f64,
}
pub fn run(args: TrainArgs) -> Result<()> {
match args.backend {
Backend::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
#[cfg(feature = "libtorch")]
Backend::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
#[cfg(feature = "libtorch")]
Backend::LibTorchMps => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
#[cfg(feature = "libtorch")]
Backend::LibTorchVulkan => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
#[cfg(feature = "libtorch")]
Backend::LibTorchCuda => {
let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
train::<Autodiff<LibTorch>>(args, device)
}
}
}
pub fn train<B: AutodiffBackend>(args: TrainArgs, device: B::Device) -> Result<()> {
fs::create_dir_all(&args.output)?;
B::seed(args.seed);
let model_config = Args::model_config(&args.model_config)?;
fs::write(
args.output.join("config.json"),
serde_json::to_string(&model_config)?,
)?;
let model = model_config.init(&device);
fs::write(args.output.join("model.txt"), format!("{:#?}", model))?;
let params_path = args.output.join("params");
let model = if args.load_existing {
info!("Loading parameters from file ...");
model.load_file(&params_path, &DefaultRecorder::new(), &device)?
} else {
model
};
info!("Model created, totally {} parameters", model.num_params());
let optimizer = AdamWConfig::new().init();
let lr_sched = ConstantLr::new(args.learning_rate);
let batcher_train = MirachTrainBatcher::<B>::new(device.clone());
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(args.batch_size)
.shuffle(args.seed)
.num_workers(args.workers)
.build(scra_mirach_dataset::load_train()?);
let batcher_valid = MirachTrainBatcher::<B::InnerBackend>::new(device.clone());
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
.batch_size(args.batch_size / 2)
.shuffle(args.seed)
.num_workers(args.workers)
.build(scra_mirach_dataset::load_validate()?);
let learner = LearnerBuilder::new(args.output)
.metric_train_numeric(HammingScore::new())
.metric_valid_numeric(HammingScore::new())
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(args.epochs)
// .renderer(MirachMetricsRenderer)
// .with_application_logger(None)
.summary()
.build(model, optimizer, lr_sched);
let start_time = Instant::now();
let model_trained = learner.fit(dataloader_train, dataloader_valid);
let end_time = Instant::now();
let time = end_time - start_time;
info!("Train completed, took {:?}", time);
model_trained.save_file(params_path, &DefaultRecorder::new())?;
info!("Full-precision model has been saved");
Ok(())
}

View file

@ -15,7 +15,7 @@ base64 = "0.22.1"
clap = { workspace = true }
env_logger = "0.11.5"
fantoccini = "0.21.1"
log = { version = "0.4.22", features = ["std"] }
log = { workspace = true }
reqwest = "0.12.7"
serde_json = "1.0.127"
tokio = { workspace = true, features = ["full"] }

View file

@ -9,9 +9,10 @@ repository.workspace = true
license.workspace = true
[dependencies]
burn = { workspace = true, features = ["vision"] }
burn = { workspace = true, features = ["dataset"] }
tar = { version = "0.4.41", default-features = false }
thiserror = { workspace = true }
zstd = { version = "0.13.2", default-features = false, features = ["arrays"] }
scra-mirach-model = { path = "../model" }
image = { version = "0.25.2", default-features = false, features = ["png"] }
scra-mirach-model = { path = "../model", default-features = false, features = ["image", "train"] }
image = { workspace = true }
log = { workspace = true }

68
dataset/src/batcher.rs Normal file
View file

@ -0,0 +1,68 @@
use std::fmt::Debug;
use burn::{
data::dataloader::batcher::Batcher,
prelude::Backend,
tensor::{Float, Int, Shape, Tensor, TensorData},
};
use scra_mirach_model::{
data::{CaptchaCode, MirachImageData},
train::MirachTrainBatch,
IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH,
};
/// Mirach dataset item
#[derive(Clone, PartialEq, Eq)]
pub struct MirachImageItem {
/// The label of image
pub code: CaptchaCode,
/// The colors of image, Tensor (channels, height, width)
pub image: MirachImageData,
}
impl Debug for MirachImageItem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MirachImageItem")
.field("code", &self.code)
.finish()
}
}
#[derive(Debug, Clone)]
pub struct MirachTrainBatcher<B: Backend> {
device: B::Device,
}
impl<B: Backend> MirachTrainBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
}
impl<B: Backend> Batcher<MirachImageItem, MirachTrainBatch<B>> for MirachTrainBatcher<B> {
fn batch(&self, items: Vec<MirachImageItem>) -> MirachTrainBatch<B> {
let mut images = Vec::with_capacity(items.len());
let mut targets = Vec::with_capacity(items.len());
for item in items {
let image = TensorData::new(
item.image.unwrap(),
Shape::new([1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]),
);
let image = Tensor::<B, 4, Float>::from(image.convert::<f32>()).to_device(&self.device);
let image = image / 255;
debug_assert_eq!(image.dims(), [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]);
images.push(image);
let target = Tensor::<B, 1, Int>::from(TensorData::from(item.code.as_digits()))
.to_device(&self.device);
let target = target.unsqueeze_dim(0);
debug_assert_eq!(target.dims(), [1, 4]);
targets.push(target);
}
let images = Tensor::cat(images, 0);
let targets = Tensor::cat(targets, 0);
MirachTrainBatch { images, targets }
}
}

View file

@ -1,11 +1,18 @@
use std::{
io::{self, Cursor, Read},
string::FromUtf8Error,
time::Instant,
};
use batcher::MirachImageItem;
use burn::data::dataset::Dataset;
use image::{ImageError, ImageFormat};
use scra_mirach_model::CaptchaCode;
use log::info;
use scra_mirach_model::data::MirachImageData;
use tar::EntryType;
pub mod batcher;
pub use scra_mirach_model::{data::CaptchaCode, IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
/// Dataset release version
pub const DATASET_REL: u8 = 1;
@ -13,8 +20,6 @@ const TRAIN_DATASET: &[u8] = include_bytes!("../assets/SCRA: Mirach Dataset REL
const VALIDATE_DATASET: &[u8] =
include_bytes!("../assets/SCRA: Mirach Dataset REL 1 (validate).tar.zst");
pub use scra_mirach_model::{IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
pub fn load_train() -> Result<MirachDataset, Error> {
load_from(TRAIN_DATASET)
}
@ -24,23 +29,30 @@ pub fn load_validate() -> Result<MirachDataset, Error> {
}
pub fn load_from(dataset: &[u8]) -> Result<MirachDataset, Error> {
let start_time = Instant::now();
let zstd = zstd::decode_all(dataset).map_err(Error::Zstd)?;
let mut tar = tar::Archive::new(zstd.as_slice());
let mut samples = Vec::new();
for file in tar.entries().map_err(Error::Tar)? {
let mut file = file.map_err(Error::Tar)?;
match file.header().entry_type() {
tar::EntryType::Regular => {}
tar::EntryType::Directory => continue,
ty @ _ => return Err(Error::DisallowedTarEntryType(ty)),
}
let path = String::from_utf8(file.header().path_bytes().to_vec())?;
let path_bytes = path.as_bytes();
if path.ends_with(".png")
&& path_bytes.len() >= 9
&& path_bytes[4] == b'_'
&& path[0..4].chars().all(|c| c.is_numeric())
if path.starts_with("./")
&& path.ends_with(".png")
&& path_bytes.len() >= 11
&& path_bytes[6] == b'_'
&& path[2..6].chars().all(|c| c.is_numeric())
{
let code = ((path_bytes[0] ^ b'0') as u16 * 1000)
+ ((path_bytes[1] ^ b'0') as u16 * 100)
+ ((path_bytes[2] ^ b'0') as u16 * 10)
+ ((path_bytes[3] ^ b'0') as u16);
let code = ((path_bytes[2] ^ b'0') as u16 * 1000)
+ ((path_bytes[3] ^ b'0') as u16 * 100)
+ ((path_bytes[4] ^ b'0') as u16 * 10)
+ ((path_bytes[5] ^ b'0') as u16);
let code = CaptchaCode::new_unchecked(code);
let mut image = Vec::new();
@ -53,34 +65,37 @@ pub fn load_from(dataset: &[u8]) -> Result<MirachDataset, Error> {
image.len() as usize,
IMAGE_WIDTH * IMAGE_HEIGHT * IMAGE_CHANNELS
);
let image = image.into_raw();
samples.push(MirachImageItem {
code,
image: image.try_into().expect("Invalid size of CAPTCHA image"),
image: MirachImageData::from(image),
});
} else {
return Err(Error::InvalidPath { path });
}
}
info!(
"Loaded {} samples from dataset in {:?}",
samples.len(),
start_time.elapsed()
);
Ok(MirachDataset { samples })
}
/// Mirach dataset errors
#[derive(thiserror::Error, Debug)]
pub enum Error {
/// Zstandard decompression errors
#[error("Zstandard error")]
#[error("Zstandard IO error: {0}")]
Zstd(io::Error),
/// Tar archive errors
#[error("tar error")]
#[error("tar IO error: {0}")]
Tar(io::Error),
/// Image-parsing-related errors
#[error("image error")]
#[error("disallowed tar entry type: {0:?}")]
DisallowedTarEntryType(EntryType),
#[error("image parsing error: {0}")]
ImageError(#[from] ImageError),
/// Tar entry file path is not valid UTF-8
#[error("tar entry with non-UTF-8 path")]
#[error("tar entry with non-UTF-8 path: {0}")]
NonUtf8Path(#[from] FromUtf8Error),
/// Invalid path found in tar archive
#[error("invalid path")]
#[error("invalid path: {path}")]
InvalidPath { path: String },
}
@ -90,15 +105,6 @@ pub struct MirachDataset {
samples: Vec<MirachImageItem>,
}
/// Mirach dataset item
#[derive(Clone, PartialEq, Eq)]
pub struct MirachImageItem {
/// The label of image
pub code: CaptchaCode,
/// The colors of image, Tensor (height, width, channels)
pub image: [u8; IMAGE_WIDTH * IMAGE_HEIGHT * IMAGE_CHANNELS],
}
impl Dataset<MirachImageItem> for MirachDataset {
fn get(&self, index: usize) -> Option<MirachImageItem> {
self.samples.get(index).cloned()

View file

@ -1,7 +1,7 @@
[package]
name = "scra-mirach-model"
version = "0.1.0"
description = "SCRA: Mirach - Model definition"
description = "SCRA: Mirach - Model"
edition.workspace = true
authors.workspace = true
homepage.workspace = true
@ -9,4 +9,11 @@ repository.workspace = true
license.workspace = true
[dependencies]
burn = { workspace = true, features = ["tui", "wgpu", "train"] }
burn = { workspace = true }
image = { workspace = true, optional = true }
serde = { workspace = true }
[features]
default = ["image"]
image = ["dep:image"]
train = ["burn/train", "burn/metrics"]

74
model/src/data.rs Normal file
View file

@ -0,0 +1,74 @@
use std::mem::MaybeUninit;
#[cfg(feature = "image")]
use image::RgbImage;
use crate::{IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct CaptchaCode(u16);
impl CaptchaCode {
#[inline]
pub fn new(code: u16) -> Self {
assert!(code < 10000);
Self(code)
}
#[inline]
pub fn new_unchecked(code: u16) -> Self {
Self(code)
}
#[inline]
pub fn as_u16(&self) -> u16 {
self.0
}
#[inline]
pub fn as_digits(&self) -> [u8; 4] {
[
(self.0 / 1000) as u8,
((self.0 / 100) % 10) as u8,
((self.0 / 10) % 10) as u8,
(self.0 % 10) as u8,
]
}
}
#[derive(Clone, PartialEq, Eq)]
pub struct MirachImageData(Vec<u8>);
impl MirachImageData {
pub fn new(pixels: Vec<u8>) -> Self {
Self(pixels)
}
pub fn unwrap(self) -> Vec<u8> {
self.0
}
}
#[cfg(feature = "image")]
impl From<RgbImage> for MirachImageData {
fn from(image: RgbImage) -> Self {
// SAFETY: The array will be filled are allocation immediately,
// and no read will be performed before all data are fully written.
// Therefore no data can be leaked.
#[allow(invalid_value)]
let mut data: Box<[u8; IMAGE_HEIGHT * IMAGE_WIDTH * IMAGE_CHANNELS]> =
unsafe { Box::new(MaybeUninit::uninit().assume_init()) };
for x in 0..IMAGE_WIDTH {
let offset = x * IMAGE_HEIGHT;
for y in 0..IMAGE_HEIGHT {
let offset = offset + y;
let pixel = image.get_pixel(x as u32, y as u32);
data[(0 * (IMAGE_WIDTH * IMAGE_HEIGHT)) + offset] = pixel[0];
data[(1 * (IMAGE_WIDTH * IMAGE_HEIGHT)) + offset] = pixel[1];
data[(2 * (IMAGE_WIDTH * IMAGE_HEIGHT)) + offset] = pixel[2];
}
}
Self(Vec::from(data as Box<[u8]>))
}
}

View file

@ -1,31 +1,221 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct CaptchaCode(u16);
use burn::{
config::Config,
module::Module,
nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig, MaxPool2d, MaxPool2dConfig},
BatchNorm, BatchNormConfig, Dropout, DropoutConfig, LeakyRelu, LeakyReluConfig, Linear,
LinearConfig, Lstm, LstmConfig, LstmState,
},
prelude::Backend,
tensor::{activation::softmax, Tensor},
};
#[cfg(feature = "train")]
use burn::{nn::loss::CrossEntropyLossConfig, prelude::Int};
impl CaptchaCode {
#[inline]
pub fn new(code: u16) -> Self {
assert!(code < 10000);
Self(code)
}
pub mod data;
#[cfg(feature = "train")]
pub mod train;
#[inline]
pub fn new_unchecked(code: u16) -> Self {
Self(code)
}
/// Sample image width, in pixels, currently 110
pub const IMAGE_WIDTH: usize = 110;
/// Sample image height, in pixels, currently 50
pub const IMAGE_HEIGHT: usize = 50;
/// Sample image color channels, currently 3 (RGB)
pub const IMAGE_CHANNELS: usize = 3;
/// Digits count of image
pub const DIGIT_COUNT: usize = 4;
#[inline]
pub fn as_u16(&self) -> u16 {
self.0
}
#[derive(Module, Debug)]
pub struct MirachModel<B: Backend> {
conv: Conv2d<B>,
bn1: BatchNorm<B, 2>,
activation: LeakyRelu,
pool: MaxPool2d,
lstm_gf1_linear: Linear<B>,
gf_dropout: Dropout,
lstm_gf2_pool: AdaptiveAvgPool2d,
attention: MultiHeadAttention<B>,
bn2: BatchNorm<B, 2>,
lstm1_forward: Lstm<B>,
lstm1_backward: Lstm<B>,
lstm2_forward: Lstm<B>,
lstm2_backward: Lstm<B>,
out_linear: Linear<B>,
}
#[inline]
pub fn as_digits(&self) -> [u8; 4] {
[
(self.0 / 1000) as u8,
((self.0 / 100) % 10) as u8,
((self.0 / 10) % 10) as u8,
(self.0 % 10) as u8,
]
#[derive(Config, Debug)]
pub struct MirachModelConfig {
#[config(default = "10")]
num_classes: usize,
#[config(default = "4")]
num_heads: usize,
#[config(default = "64")]
hidden_size1: usize,
#[config(default = "32")]
hidden_size2: usize,
#[config(default = "0.05")]
att_dropout: f64,
#[config(default = "0.03")]
gf_dropout: f64,
}
impl MirachModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> MirachModel<B> {
MirachModel {
conv: Conv2dConfig::new([3, 4], [3, 3]).init(device),
bn1: BatchNormConfig::new(4).init(device),
activation: LeakyReluConfig::new().init(),
pool: MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init(),
lstm_gf1_linear: LinearConfig::new(54 * 24, self.hidden_size1).init(device),
gf_dropout: DropoutConfig::new(self.gf_dropout).init(),
lstm_gf2_pool: AdaptiveAvgPool2dConfig::new([1, self.hidden_size2]).init(),
attention: MultiHeadAttentionConfig::new(24, self.num_heads)
.with_dropout(self.att_dropout)
.init(device),
bn2: BatchNormConfig::new(1).init(device),
lstm1_forward: LstmConfig::new(24, self.hidden_size1, true).init(device),
lstm1_backward: LstmConfig::new(24, self.hidden_size1, true).init(device),
lstm2_forward: LstmConfig::new(self.hidden_size1, self.hidden_size2, true).init(device),
lstm2_backward: LstmConfig::new(self.hidden_size1, self.hidden_size2, true)
.init(device),
out_linear: LinearConfig::new(self.hidden_size2 * 2, self.num_heads * self.num_classes)
.init(device),
}
}
}
impl<B: Backend> MirachModel<B> {
/// # Shapes
/// - Images [batch_size, channels, width, height]
/// - Output [batch_size, length, n_classes]
pub fn forward(&self, images: Tensor<B, 4>) -> Tensor<B, 3> {
let [batch_size, _channels, _width, _height] = images.dims();
let device = images.device();
let x = images;
debug_assert_eq!(x.dims(), [batch_size, 3, 110, 50]);
let x = self.conv.forward(x);
debug_assert_eq!(x.dims(), [batch_size, 4, 108, 48]);
let x = self.bn1.forward(x);
debug_assert_eq!(x.dims(), [batch_size, 4, 108, 48]);
let x = self.activation.forward(x);
debug_assert_eq!(x.dims(), [batch_size, 4, 108, 48]);
let x = self.pool.forward(x);
debug_assert_eq!(x.dims(), [batch_size, 4, 54, 24]);
let [x1, x2, x3, x4] = x
.iter_dim(1)
.map(|x| x.squeeze(1))
.collect::<Vec<Tensor<B, 3>>>()
.try_into()
.unwrap();
debug_assert_eq!(x1.dims(), [batch_size, 54, 24]);
debug_assert_eq!(x2.dims(), [batch_size, 54, 24]);
debug_assert_eq!(x3.dims(), [batch_size, 54, 24]);
debug_assert_eq!(x4.dims(), [batch_size, 54, 24]);
// GF
let att_gf = self.gf_dropout.forward(x3);
debug_assert_eq!(att_gf.dims(), [batch_size, 54, 24]);
let lstm_gf1 = x4.reshape([batch_size, 54 * 24]);
let lstm_gf1 = self.lstm_gf1_linear.forward(lstm_gf1);
let lstm_gf1 = self.gf_dropout.forward(lstm_gf1);
let hidden_size1 = lstm_gf1.dims()[1];
let lstm_gf2 = lstm_gf1.clone().unsqueeze_dims(&[1, 2]);
let lstm_gf2 = self.lstm_gf2_pool.forward(lstm_gf2);
let hidden_size2 = lstm_gf2.dims()[3];
let lstm_gf2 = lstm_gf2.reshape([batch_size, hidden_size2]);
// MHA
let x = self.attention.forward(MhaInput::new(x1, att_gf, x2));
let n_heads = x.weights.dims()[1];
let x = x.context;
debug_assert_eq!(x.dims(), [batch_size, 54, 24]);
let x = self.bn2.forward(x.unsqueeze_dim::<4>(1)).squeeze(1);
debug_assert_eq!(x.dims(), [batch_size, 54, 24]);
// BiLSTM 1
let (lstm_f, _) = self.lstm1_forward.forward(
x.clone(),
Some(LstmState::new(
Tensor::zeros([batch_size, hidden_size1], &device),
lstm_gf1.clone(),
)),
);
let (lstm_b, _) = self.lstm1_backward.forward(
x.flip([1]),
Some(LstmState::new(
Tensor::zeros([batch_size, hidden_size1], &device),
lstm_gf1.clone(),
)),
);
let x = lstm_f + lstm_b;
debug_assert_eq!(x.dims(), [batch_size, 54, hidden_size1]);
// BiLSTM 2
let (lstm_f, _) = self.lstm2_forward.forward(
x.clone(),
Some(LstmState::new(
Tensor::zeros([batch_size, hidden_size2], &device),
lstm_gf2.clone(),
)),
);
let lstm_f = lstm_f
.slice([0..batch_size, 53..54, 0..hidden_size2])
.squeeze::<2>(1);
let (lstm_b, _) = self.lstm2_backward.forward(
x.flip([1]),
Some(LstmState::new(
Tensor::zeros([batch_size, hidden_size2], &device),
lstm_gf2.clone(),
)),
);
let lstm_b = lstm_b
.slice([0..batch_size, 53..54, 0..hidden_size2])
.squeeze::<2>(1);
let x = Tensor::cat(vec![lstm_f, lstm_b], 1);
debug_assert_eq!(x.dims(), [batch_size, hidden_size2 * 2]);
let x = self.out_linear.forward(x);
let n_classes = x.dims()[1] / n_heads;
debug_assert_eq!(x.dims(), [batch_size, n_heads * n_classes]);
let x = x.reshape([batch_size, n_heads, n_classes]);
let x = softmax(x, 2);
x
}
#[cfg(feature = "train")]
pub fn forward_recognition(
&self,
images: Tensor<B, 4>,
targets: Tensor<B, 2, Int>,
) -> train::MirachOutput<B> {
use train::MirachOutput;
let output = self.forward(images);
let device = output.device();
let batch_size = output.dims()[0];
let loss_func = CrossEntropyLossConfig::new()
.with_logits(true)
.init(&device);
let loss = output
.clone()
.iter_dim(1)
.zip(targets.clone().iter_dim(1))
.map(|(output, target)| loss_func.forward(output.squeeze(1), target.squeeze(1)))
.fold(Tensor::zeros([batch_size], &device), |sum, loss| sum + loss);
MirachOutput {
output,
targets,
loss,
}
}
}

66
model/src/train.rs Normal file
View file

@ -0,0 +1,66 @@
use burn::{
prelude::Backend,
tensor::{backend::AutodiffBackend, Int, Tensor},
train::{
metric::{Adaptor, HammingScoreInput, LossInput},
TrainOutput, TrainStep, ValidStep,
},
};
use crate::MirachModel;
/// A batch for trainning
#[derive(Clone, Debug)]
pub struct MirachTrainBatch<B: Backend> {
/// Input images, in [batch size, channels (=== 3 / RGB), width (=== 110), height (=== 50)]
pub images: Tensor<B, 4>,
/// Image labels as logits, in [batch size, length (=== 4)]
pub targets: Tensor<B, 2, Int>,
}
impl<B: AutodiffBackend> TrainStep<MirachTrainBatch<B>, MirachOutput<B>> for MirachModel<B> {
fn step(&self, batch: MirachTrainBatch<B>) -> TrainOutput<MirachOutput<B>> {
let item = self.forward_recognition(batch.images, batch.targets);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> ValidStep<MirachTrainBatch<B>, MirachOutput<B>> for MirachModel<B> {
fn step(&self, batch: MirachTrainBatch<B>) -> MirachOutput<B> {
self.forward_recognition(batch.images, batch.targets)
}
}
pub struct MirachOutput<B: Backend> {
/// The output, in [batch_size, length, n_digits]
pub output: Tensor<B, 3>,
/// The targets, in [batch_size, length]
pub targets: Tensor<B, 2, Int>,
/// The cross-entropy loss, in [batch_size]
pub loss: Tensor<B, 1>,
}
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MirachOutput<B> {
fn adapt(&self) -> HammingScoreInput<B> {
let batch_size = self.output.dims()[0];
let outputs = self
.output
.clone()
.argmax(2)
.squeeze(2)
.equal(self.targets.clone())
.float();
HammingScoreInput::new(
outputs,
Tensor::<B, 2, Int>::ones([batch_size, 4], &self.output.device()),
)
}
}
impl<B: Backend> Adaptor<LossInput<B>> for MirachOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}