Init model
This commit is contained in:
parent
8d74bff9e6
commit
11d2882c3f
14 changed files with 1608 additions and 550 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,2 +1,3 @@
|
|||
/target
|
||||
*.png
|
||||
/output
|
||||
|
|
1380
Cargo.lock
generated
1380
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -1,9 +1,9 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
members = [ "cli",
|
||||
"dataset",
|
||||
"dataset-collector",
|
||||
"model"
|
||||
"model",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
@ -14,8 +14,11 @@ repository = "https://codeberg.org/xtex/scra"
|
|||
license = "Apache-2.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
burn = { version = "0.13.2" }
|
||||
burn = { version = "0.14.0", git = "https://github.com/tracel-ai/burn" }
|
||||
clap = { version = "4.5.16", features = ["derive"] }
|
||||
tokio = { version = "1.39.3" }
|
||||
anyhow = { version = "1.0.86", features = ["backtrace"] }
|
||||
thiserror = { version = "1.0.63" }
|
||||
image = { version = "0.25.2", default-features = false, features = ["png"] }
|
||||
log = { version = "0.4.22", features = ["std"] }
|
||||
serde = { version = "1.0.209", default-features = false, features = ["derive"] }
|
||||
|
|
23
cli/Cargo.toml
Normal file
23
cli/Cargo.toml
Normal file
|
@ -0,0 +1,23 @@
|
|||
[package]
|
||||
name = "scra-mirach-cli"
|
||||
version = "0.1.0"
|
||||
description = "SCRA: Mirach - Command-line utilities"
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
repository.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn = { workspace = true, features = ["metrics", "train", "wgpu", "tui"] }
|
||||
scra-mirach-model = { path = "../model", default-features = false, features = ["image", "train"] }
|
||||
scra-mirach-dataset = { path = "../dataset" }
|
||||
anyhow = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
log = { workspace = true }
|
||||
env_logger = "0.11.5"
|
||||
serde = { workspace = true }
|
||||
serde_json = "1.0.127"
|
||||
|
||||
[features]
|
||||
libtorch = ["burn/tch"]
|
73
cli/src/main.rs
Normal file
73
cli/src/main.rs
Normal file
|
@ -0,0 +1,73 @@
|
|||
use std::{env, fs, path::PathBuf};
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use scra_mirach_model::MirachModelConfig;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use train::TrainArgs;
|
||||
|
||||
pub mod train;
|
||||
|
||||
#[derive(clap::Parser, Debug, Clone)]
|
||||
#[command(version, about)]
|
||||
struct Args {
|
||||
#[command(subcommand)]
|
||||
action: Action,
|
||||
|
||||
#[arg(short='b', long, value_enum,global=true, default_value_t = Backend::Wgpu, help="Tensor backend")]
|
||||
pub backend: Backend,
|
||||
#[cfg(feature = "libtorch")]
|
||||
#[arg(
|
||||
long,
|
||||
global = true,
|
||||
required_if_eq("backend", "Backend::LibTorchCuda"),
|
||||
help = "GPU index to use for libtorch CUDA backend"
|
||||
)]
|
||||
pub cuda_index: Option<usize>,
|
||||
|
||||
#[arg(short = 'c', long, global = true, help = "Model configuration file")]
|
||||
pub model_config: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, clap::Subcommand)]
|
||||
enum Action {
|
||||
Train(TrainArgs),
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum Backend {
|
||||
Wgpu,
|
||||
#[cfg(feature = "libtorch")]
|
||||
LibTorchCpu,
|
||||
#[cfg(feature = "libtorch")]
|
||||
LibTorchMps,
|
||||
#[cfg(feature = "libtorch")]
|
||||
LibTorchVulkan,
|
||||
#[cfg(feature = "libtorch")]
|
||||
LibTorchCuda,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
pub fn model_config(path: &Option<PathBuf>) -> Result<MirachModelConfig> {
|
||||
match path {
|
||||
Some(path) => Ok(serde_json::from_str(&fs::read_to_string(path)?)?),
|
||||
None => Ok(MirachModelConfig::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
env_logger::builder()
|
||||
.format_timestamp_secs()
|
||||
.parse_env(env_logger::Env::new().default_filter_or(
|
||||
"scra_mirach_cli=info,scra_mirach_model=info,scra_mirach_dataset=info,burn_train=warn",
|
||||
))
|
||||
.target(env_logger::Target::Stdout)
|
||||
.try_init()?;
|
||||
env::set_var("RUST_BACKTRACE", "true");
|
||||
|
||||
match args.action {
|
||||
Action::Train(args) => train::run(args),
|
||||
}
|
||||
}
|
134
cli/src/train.rs
Normal file
134
cli/src/train.rs
Normal file
|
@ -0,0 +1,134 @@
|
|||
use std::{fs, path::PathBuf, time::Instant};
|
||||
|
||||
use anyhow::Result;
|
||||
#[cfg(feature = "libtorch")]
|
||||
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
|
||||
use burn::{
|
||||
backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
|
||||
data::dataloader::DataLoaderBuilder,
|
||||
lr_scheduler::constant::ConstantLr,
|
||||
module::Module,
|
||||
optim::AdamWConfig,
|
||||
record::{CompactRecorder, DefaultRecorder},
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
metric::{HammingScore, LossMetric},
|
||||
LearnerBuilder,
|
||||
},
|
||||
};
|
||||
use log::info;
|
||||
use scra_mirach_dataset::batcher::MirachTrainBatcher;
|
||||
|
||||
use crate::{Args, Backend};
|
||||
|
||||
#[derive(clap::Parser, Debug, Clone)]
|
||||
#[command(version, about)]
|
||||
pub struct TrainArgs {
|
||||
#[arg(from_global)]
|
||||
pub backend: Backend,
|
||||
#[cfg(feature = "libtorch")]
|
||||
#[arg(from_global)]
|
||||
pub cuda_index: Option<usize>,
|
||||
#[arg(from_global)]
|
||||
pub model_config: Option<PathBuf>,
|
||||
|
||||
#[arg(long, default_value_t = 42)]
|
||||
pub seed: u64,
|
||||
#[arg(short = 'o', long, help = "Output directory")]
|
||||
pub output: PathBuf,
|
||||
#[arg(long = "continue", help = "Load existing parameters")]
|
||||
pub load_existing: bool,
|
||||
#[arg(short = 'e', long, default_value_t = 10)]
|
||||
pub epochs: usize,
|
||||
#[arg(short = 's', long, default_value_t = 64)]
|
||||
pub batch_size: usize,
|
||||
#[arg(
|
||||
short = 'j',
|
||||
long,
|
||||
default_value_t = 4,
|
||||
help = "Number of dataloader workers"
|
||||
)]
|
||||
pub workers: usize,
|
||||
#[arg(long, default_value_t = 1.0e-4)]
|
||||
pub learning_rate: f64,
|
||||
}
|
||||
|
||||
pub fn run(args: TrainArgs) -> Result<()> {
|
||||
match args.backend {
|
||||
Backend::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
|
||||
#[cfg(feature = "libtorch")]
|
||||
Backend::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
|
||||
#[cfg(feature = "libtorch")]
|
||||
Backend::LibTorchMps => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
|
||||
#[cfg(feature = "libtorch")]
|
||||
Backend::LibTorchVulkan => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
|
||||
#[cfg(feature = "libtorch")]
|
||||
Backend::LibTorchCuda => {
|
||||
let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
|
||||
train::<Autodiff<LibTorch>>(args, device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn train<B: AutodiffBackend>(args: TrainArgs, device: B::Device) -> Result<()> {
|
||||
fs::create_dir_all(&args.output)?;
|
||||
B::seed(args.seed);
|
||||
|
||||
let model_config = Args::model_config(&args.model_config)?;
|
||||
fs::write(
|
||||
args.output.join("config.json"),
|
||||
serde_json::to_string(&model_config)?,
|
||||
)?;
|
||||
let model = model_config.init(&device);
|
||||
fs::write(args.output.join("model.txt"), format!("{:#?}", model))?;
|
||||
let params_path = args.output.join("params");
|
||||
let model = if args.load_existing {
|
||||
info!("Loading parameters from file ...");
|
||||
model.load_file(¶ms_path, &DefaultRecorder::new(), &device)?
|
||||
} else {
|
||||
model
|
||||
};
|
||||
info!("Model created, totally {} parameters", model.num_params());
|
||||
|
||||
let optimizer = AdamWConfig::new().init();
|
||||
let lr_sched = ConstantLr::new(args.learning_rate);
|
||||
|
||||
let batcher_train = MirachTrainBatcher::<B>::new(device.clone());
|
||||
let dataloader_train = DataLoaderBuilder::new(batcher_train)
|
||||
.batch_size(args.batch_size)
|
||||
.shuffle(args.seed)
|
||||
.num_workers(args.workers)
|
||||
.build(scra_mirach_dataset::load_train()?);
|
||||
|
||||
let batcher_valid = MirachTrainBatcher::<B::InnerBackend>::new(device.clone());
|
||||
let dataloader_valid = DataLoaderBuilder::new(batcher_valid)
|
||||
.batch_size(args.batch_size / 2)
|
||||
.shuffle(args.seed)
|
||||
.num_workers(args.workers)
|
||||
.build(scra_mirach_dataset::load_validate()?);
|
||||
|
||||
let learner = LearnerBuilder::new(args.output)
|
||||
.metric_train_numeric(HammingScore::new())
|
||||
.metric_valid_numeric(HammingScore::new())
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.devices(vec![device.clone()])
|
||||
.num_epochs(args.epochs)
|
||||
// .renderer(MirachMetricsRenderer)
|
||||
// .with_application_logger(None)
|
||||
.summary()
|
||||
.build(model, optimizer, lr_sched);
|
||||
|
||||
let start_time = Instant::now();
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_valid);
|
||||
let end_time = Instant::now();
|
||||
|
||||
let time = end_time - start_time;
|
||||
|
||||
info!("Train completed, took {:?}", time);
|
||||
model_trained.save_file(params_path, &DefaultRecorder::new())?;
|
||||
info!("Full-precision model has been saved");
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -15,7 +15,7 @@ base64 = "0.22.1"
|
|||
clap = { workspace = true }
|
||||
env_logger = "0.11.5"
|
||||
fantoccini = "0.21.1"
|
||||
log = { version = "0.4.22", features = ["std"] }
|
||||
log = { workspace = true }
|
||||
reqwest = "0.12.7"
|
||||
serde_json = "1.0.127"
|
||||
tokio = { workspace = true, features = ["full"] }
|
||||
|
|
|
@ -9,9 +9,10 @@ repository.workspace = true
|
|||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn = { workspace = true, features = ["vision"] }
|
||||
burn = { workspace = true, features = ["dataset"] }
|
||||
tar = { version = "0.4.41", default-features = false }
|
||||
thiserror = { workspace = true }
|
||||
zstd = { version = "0.13.2", default-features = false, features = ["arrays"] }
|
||||
scra-mirach-model = { path = "../model" }
|
||||
image = { version = "0.25.2", default-features = false, features = ["png"] }
|
||||
scra-mirach-model = { path = "../model", default-features = false, features = ["image", "train"] }
|
||||
image = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
|
68
dataset/src/batcher.rs
Normal file
68
dataset/src/batcher.rs
Normal file
|
@ -0,0 +1,68 @@
|
|||
use std::fmt::Debug;
|
||||
|
||||
use burn::{
|
||||
data::dataloader::batcher::Batcher,
|
||||
prelude::Backend,
|
||||
tensor::{Float, Int, Shape, Tensor, TensorData},
|
||||
};
|
||||
use scra_mirach_model::{
|
||||
data::{CaptchaCode, MirachImageData},
|
||||
train::MirachTrainBatch,
|
||||
IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH,
|
||||
};
|
||||
|
||||
/// Mirach dataset item
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct MirachImageItem {
|
||||
/// The label of image
|
||||
pub code: CaptchaCode,
|
||||
/// The colors of image, Tensor (channels, width, height)
|
||||
pub image: MirachImageData,
|
||||
}
|
||||
|
||||
impl Debug for MirachImageItem {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("MirachImageItem")
|
||||
.field("code", &self.code)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MirachTrainBatcher<B: Backend> {
|
||||
device: B::Device,
|
||||
}
|
||||
|
||||
impl<B: Backend> MirachTrainBatcher<B> {
|
||||
pub fn new(device: B::Device) -> Self {
|
||||
Self { device }
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Batcher<MirachImageItem, MirachTrainBatch<B>> for MirachTrainBatcher<B> {
|
||||
fn batch(&self, items: Vec<MirachImageItem>) -> MirachTrainBatch<B> {
|
||||
let mut images = Vec::with_capacity(items.len());
|
||||
let mut targets = Vec::with_capacity(items.len());
|
||||
for item in items {
|
||||
let image = TensorData::new(
|
||||
item.image.unwrap(),
|
||||
Shape::new([1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]),
|
||||
);
|
||||
let image = Tensor::<B, 4, Float>::from(image.convert::<f32>()).to_device(&self.device);
|
||||
let image = image / 255;
|
||||
debug_assert_eq!(image.dims(), [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]);
|
||||
images.push(image);
|
||||
|
||||
let target = Tensor::<B, 1, Int>::from(TensorData::from(item.code.as_digits()))
|
||||
.to_device(&self.device);
|
||||
let target = target.unsqueeze_dim(0);
|
||||
debug_assert_eq!(target.dims(), [1, 4]);
|
||||
targets.push(target);
|
||||
}
|
||||
|
||||
let images = Tensor::cat(images, 0);
|
||||
let targets = Tensor::cat(targets, 0);
|
||||
|
||||
MirachTrainBatch { images, targets }
|
||||
}
|
||||
}
|
|
@ -1,11 +1,18 @@
|
|||
use std::{
|
||||
io::{self, Cursor, Read},
|
||||
string::FromUtf8Error,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use batcher::MirachImageItem;
|
||||
use burn::data::dataset::Dataset;
|
||||
use image::{ImageError, ImageFormat};
|
||||
use scra_mirach_model::CaptchaCode;
|
||||
use log::info;
|
||||
use scra_mirach_model::data::MirachImageData;
|
||||
use tar::EntryType;
|
||||
|
||||
pub mod batcher;
|
||||
pub use scra_mirach_model::{data::CaptchaCode, IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
|
||||
|
||||
/// Dataset release version
|
||||
pub const DATASET_REL: u8 = 1;
|
||||
|
@ -13,8 +20,6 @@ const TRAIN_DATASET: &[u8] = include_bytes!("../assets/SCRA: Mirach Dataset REL
|
|||
const VALIDATE_DATASET: &[u8] =
|
||||
include_bytes!("../assets/SCRA: Mirach Dataset REL 1 (validate).tar.zst");
|
||||
|
||||
pub use scra_mirach_model::{IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
|
||||
|
||||
pub fn load_train() -> Result<MirachDataset, Error> {
|
||||
load_from(TRAIN_DATASET)
|
||||
}
|
||||
|
@ -24,23 +29,30 @@ pub fn load_validate() -> Result<MirachDataset, Error> {
|
|||
}
|
||||
|
||||
pub fn load_from(dataset: &[u8]) -> Result<MirachDataset, Error> {
|
||||
let start_time = Instant::now();
|
||||
let zstd = zstd::decode_all(dataset).map_err(Error::Zstd)?;
|
||||
let mut tar = tar::Archive::new(zstd.as_slice());
|
||||
|
||||
let mut samples = Vec::new();
|
||||
for file in tar.entries().map_err(Error::Tar)? {
|
||||
let mut file = file.map_err(Error::Tar)?;
|
||||
match file.header().entry_type() {
|
||||
tar::EntryType::Regular => {}
|
||||
tar::EntryType::Directory => continue,
|
||||
ty @ _ => return Err(Error::DisallowedTarEntryType(ty)),
|
||||
}
|
||||
let path = String::from_utf8(file.header().path_bytes().to_vec())?;
|
||||
let path_bytes = path.as_bytes();
|
||||
if path.ends_with(".png")
|
||||
&& path_bytes.len() >= 9
|
||||
&& path_bytes[4] == b'_'
|
||||
&& path[0..4].chars().all(|c| c.is_numeric())
|
||||
if path.starts_with("./")
|
||||
&& path.ends_with(".png")
|
||||
&& path_bytes.len() >= 11
|
||||
&& path_bytes[6] == b'_'
|
||||
&& path[2..6].chars().all(|c| c.is_numeric())
|
||||
{
|
||||
let code = ((path_bytes[0] ^ b'0') as u16 * 1000)
|
||||
+ ((path_bytes[1] ^ b'0') as u16 * 100)
|
||||
+ ((path_bytes[2] ^ b'0') as u16 * 10)
|
||||
+ ((path_bytes[3] ^ b'0') as u16);
|
||||
let code = ((path_bytes[2] ^ b'0') as u16 * 1000)
|
||||
+ ((path_bytes[3] ^ b'0') as u16 * 100)
|
||||
+ ((path_bytes[4] ^ b'0') as u16 * 10)
|
||||
+ ((path_bytes[5] ^ b'0') as u16);
|
||||
let code = CaptchaCode::new_unchecked(code);
|
||||
|
||||
let mut image = Vec::new();
|
||||
|
@ -53,34 +65,37 @@ pub fn load_from(dataset: &[u8]) -> Result<MirachDataset, Error> {
|
|||
image.len() as usize,
|
||||
IMAGE_WIDTH * IMAGE_HEIGHT * IMAGE_CHANNELS
|
||||
);
|
||||
let image = image.into_raw();
|
||||
|
||||
samples.push(MirachImageItem {
|
||||
code,
|
||||
image: image.try_into().expect("Invalid size of CAPTCHA image"),
|
||||
image: MirachImageData::from(image),
|
||||
});
|
||||
} else {
|
||||
return Err(Error::InvalidPath { path });
|
||||
}
|
||||
}
|
||||
info!(
|
||||
"Loaded {} samples from dataset in {:?}",
|
||||
samples.len(),
|
||||
start_time.elapsed()
|
||||
);
|
||||
Ok(MirachDataset { samples })
|
||||
}
|
||||
|
||||
/// Mirach dataset errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
pub enum Error {
|
||||
/// Zstandard decompression errors
|
||||
#[error("Zstandard error")]
|
||||
#[error("Zstandard IO error: {0}")]
|
||||
Zstd(io::Error),
|
||||
/// Tar archive errors
|
||||
#[error("tar error")]
|
||||
#[error("tar IO error: {0}")]
|
||||
Tar(io::Error),
|
||||
/// Image-parsing-related errors
|
||||
#[error("image error")]
|
||||
#[error("disallowed tar entry type: {0:?}")]
|
||||
DisallowedTarEntryType(EntryType),
|
||||
#[error("image parsing error: {0}")]
|
||||
ImageError(#[from] ImageError),
|
||||
/// Tar entry file path is not valid UTF-8
|
||||
#[error("tar entry with non-UTF-8 path")]
|
||||
#[error("tar entry with non-UTF-8 path: {0}")]
|
||||
NonUtf8Path(#[from] FromUtf8Error),
|
||||
/// Invalid path found in tar archive
|
||||
#[error("invalid path")]
|
||||
#[error("invalid path: {path}")]
|
||||
InvalidPath { path: String },
|
||||
}
|
||||
|
||||
|
@ -90,15 +105,6 @@ pub struct MirachDataset {
|
|||
samples: Vec<MirachImageItem>,
|
||||
}
|
||||
|
||||
/// Mirach dataset item
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct MirachImageItem {
|
||||
/// The label of image
|
||||
pub code: CaptchaCode,
|
||||
/// The colors of image, Tensor (height, width, channels)
|
||||
pub image: [u8; IMAGE_WIDTH * IMAGE_HEIGHT * IMAGE_CHANNELS],
|
||||
}
|
||||
|
||||
impl Dataset<MirachImageItem> for MirachDataset {
|
||||
fn get(&self, index: usize) -> Option<MirachImageItem> {
|
||||
self.samples.get(index).cloned()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
[package]
|
||||
name = "scra-mirach-model"
|
||||
version = "0.1.0"
|
||||
description = "SCRA: Mirach - Model definition"
|
||||
description = "SCRA: Mirach - Model"
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
homepage.workspace = true
|
||||
|
@ -9,4 +9,11 @@ repository.workspace = true
|
|||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn = { workspace = true, features = ["tui", "wgpu", "train"] }
|
||||
burn = { workspace = true }
|
||||
image = { workspace = true, optional = true }
|
||||
serde = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = ["image"]
|
||||
image = ["dep:image"]
|
||||
train = ["burn/train", "burn/metrics"]
|
||||
|
|
74
model/src/data.rs
Normal file
74
model/src/data.rs
Normal file
|
@ -0,0 +1,74 @@
|
|||
use std::mem::MaybeUninit;
|
||||
|
||||
#[cfg(feature = "image")]
|
||||
use image::RgbImage;
|
||||
|
||||
use crate::{IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
|
||||
|
||||
/// A CAPTCHA answer stored as one number; [`CaptchaCode::new`] restricts it
/// to values below 10000, i.e. at most four decimal digits.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct CaptchaCode(u16);

impl CaptchaCode {
    /// Creates a code from `code`.
    ///
    /// # Panics
    ///
    /// Panics if `code` is not below 10000.
    #[inline]
    pub fn new(code: u16) -> Self {
        assert!(code < 10000);
        Self::new_unchecked(code)
    }

    /// Creates a code without validating the range of `code`.
    #[inline]
    pub fn new_unchecked(code: u16) -> Self {
        CaptchaCode(code)
    }

    /// Returns the code as a plain number.
    #[inline]
    pub fn as_u16(&self) -> u16 {
        let Self(raw) = *self;
        raw
    }

    /// Splits the code into thousands, hundreds, tens and ones, in that order.
    #[inline]
    pub fn as_digits(&self) -> [u8; 4] {
        let value = self.0;
        let thousands = value / 1000;
        let hundreds = value / 100 % 10;
        let tens = value / 10 % 10;
        let ones = value % 10;
        [thousands as u8, hundreds as u8, tens as u8, ones as u8]
    }
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct MirachImageData(Vec<u8>);
|
||||
|
||||
impl MirachImageData {
|
||||
pub fn new(pixels: Vec<u8>) -> Self {
|
||||
Self(pixels)
|
||||
}
|
||||
|
||||
pub fn unwrap(self) -> Vec<u8> {
|
||||
self.0
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "image")]
impl From<RgbImage> for MirachImageData {
    /// Repacks an RGB image into planar channel-major bytes.
    ///
    /// The result holds `IMAGE_CHANNELS` planes of `IMAGE_WIDTH * IMAGE_HEIGHT`
    /// bytes each. Inside a plane the pixel at `(x, y)` is stored at
    /// `x * IMAGE_HEIGHT + y`, i.e. the layout is (channels, width, height).
    ///
    /// # Panics
    ///
    /// Panics if `image` is smaller than `IMAGE_WIDTH` x `IMAGE_HEIGHT`,
    /// because pixels are read with `get_pixel`.
    fn from(image: RgbImage) -> Self {
        const PLANE: usize = IMAGE_WIDTH * IMAGE_HEIGHT;

        // A zero-initialised buffer replaces the former
        // `MaybeUninit::uninit().assume_init()`, which is undefined behaviour
        // even for plain bytes.
        let mut data = vec![0u8; PLANE * IMAGE_CHANNELS];
        for x in 0..IMAGE_WIDTH {
            for y in 0..IMAGE_HEIGHT {
                let offset = x * IMAGE_HEIGHT + y;
                let pixel = image.get_pixel(x as u32, y as u32);
                for channel in 0..IMAGE_CHANNELS {
                    data[channel * PLANE + offset] = pixel[channel];
                }
            }
        }
        Self(data)
    }
}
|
240
model/src/lib.rs
240
model/src/lib.rs
|
@ -1,31 +1,221 @@
|
|||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
pub struct CaptchaCode(u16);
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::Module,
|
||||
nn::{
|
||||
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
|
||||
conv::{Conv2d, Conv2dConfig},
|
||||
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig, MaxPool2d, MaxPool2dConfig},
|
||||
BatchNorm, BatchNormConfig, Dropout, DropoutConfig, LeakyRelu, LeakyReluConfig, Linear,
|
||||
LinearConfig, Lstm, LstmConfig, LstmState,
|
||||
},
|
||||
prelude::Backend,
|
||||
tensor::{activation::softmax, Tensor},
|
||||
};
|
||||
#[cfg(feature = "train")]
|
||||
use burn::{nn::loss::CrossEntropyLossConfig, prelude::Int};
|
||||
|
||||
impl CaptchaCode {
|
||||
#[inline]
|
||||
pub fn new(code: u16) -> Self {
|
||||
assert!(code < 10000);
|
||||
Self(code)
|
||||
}
|
||||
pub mod data;
|
||||
#[cfg(feature = "train")]
|
||||
pub mod train;
|
||||
|
||||
#[inline]
|
||||
pub fn new_unchecked(code: u16) -> Self {
|
||||
Self(code)
|
||||
}
|
||||
/// Sample image width, in pixels, currently 110
|
||||
pub const IMAGE_WIDTH: usize = 110;
|
||||
/// Sample image height, in pixels, currently 50
|
||||
pub const IMAGE_HEIGHT: usize = 50;
|
||||
/// Sample image color channels, currently 3 (RGB)
|
||||
pub const IMAGE_CHANNELS: usize = 3;
|
||||
/// Digits count of image
|
||||
pub const DIGIT_COUNT: usize = 4;
|
||||
|
||||
#[inline]
|
||||
pub fn as_u16(&self) -> u16 {
|
||||
self.0
|
||||
}
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MirachModel<B: Backend> {
|
||||
conv: Conv2d<B>,
|
||||
bn1: BatchNorm<B, 2>,
|
||||
activation: LeakyRelu,
|
||||
pool: MaxPool2d,
|
||||
lstm_gf1_linear: Linear<B>,
|
||||
gf_dropout: Dropout,
|
||||
lstm_gf2_pool: AdaptiveAvgPool2d,
|
||||
attention: MultiHeadAttention<B>,
|
||||
bn2: BatchNorm<B, 2>,
|
||||
lstm1_forward: Lstm<B>,
|
||||
lstm1_backward: Lstm<B>,
|
||||
lstm2_forward: Lstm<B>,
|
||||
lstm2_backward: Lstm<B>,
|
||||
out_linear: Linear<B>,
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn as_digits(&self) -> [u8; 4] {
|
||||
[
|
||||
(self.0 / 1000) as u8,
|
||||
((self.0 / 100) % 10) as u8,
|
||||
((self.0 / 10) % 10) as u8,
|
||||
(self.0 % 10) as u8,
|
||||
]
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MirachModelConfig {
|
||||
#[config(default = "10")]
|
||||
num_classes: usize,
|
||||
#[config(default = "4")]
|
||||
num_heads: usize,
|
||||
#[config(default = "64")]
|
||||
hidden_size1: usize,
|
||||
#[config(default = "32")]
|
||||
hidden_size2: usize,
|
||||
#[config(default = "0.05")]
|
||||
att_dropout: f64,
|
||||
#[config(default = "0.03")]
|
||||
gf_dropout: f64,
|
||||
}
|
||||
|
||||
impl MirachModelConfig {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> MirachModel<B> {
|
||||
MirachModel {
|
||||
conv: Conv2dConfig::new([3, 4], [3, 3]).init(device),
|
||||
bn1: BatchNormConfig::new(4).init(device),
|
||||
activation: LeakyReluConfig::new().init(),
|
||||
pool: MaxPool2dConfig::new([2, 2]).with_strides([2, 2]).init(),
|
||||
lstm_gf1_linear: LinearConfig::new(54 * 24, self.hidden_size1).init(device),
|
||||
gf_dropout: DropoutConfig::new(self.gf_dropout).init(),
|
||||
lstm_gf2_pool: AdaptiveAvgPool2dConfig::new([1, self.hidden_size2]).init(),
|
||||
attention: MultiHeadAttentionConfig::new(24, self.num_heads)
|
||||
.with_dropout(self.att_dropout)
|
||||
.init(device),
|
||||
bn2: BatchNormConfig::new(1).init(device),
|
||||
lstm1_forward: LstmConfig::new(24, self.hidden_size1, true).init(device),
|
||||
lstm1_backward: LstmConfig::new(24, self.hidden_size1, true).init(device),
|
||||
lstm2_forward: LstmConfig::new(self.hidden_size1, self.hidden_size2, true).init(device),
|
||||
lstm2_backward: LstmConfig::new(self.hidden_size1, self.hidden_size2, true)
|
||||
.init(device),
|
||||
out_linear: LinearConfig::new(self.hidden_size2 * 2, self.num_heads * self.num_classes)
|
||||
.init(device),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> MirachModel<B> {
|
||||
/// # Shapes
|
||||
/// - Images [batch_size, channels, width, height]
|
||||
/// - Output [batch_size, length, n_classes]
|
||||
pub fn forward(&self, images: Tensor<B, 4>) -> Tensor<B, 3> {
|
||||
let [batch_size, _channels, _width, _height] = images.dims();
|
||||
let device = images.device();
|
||||
|
||||
let x = images;
|
||||
debug_assert_eq!(x.dims(), [batch_size, 3, 110, 50]);
|
||||
let x = self.conv.forward(x);
|
||||
debug_assert_eq!(x.dims(), [batch_size, 4, 108, 48]);
|
||||
let x = self.bn1.forward(x);
|
||||
debug_assert_eq!(x.dims(), [batch_size, 4, 108, 48]);
|
||||
let x = self.activation.forward(x);
|
||||
debug_assert_eq!(x.dims(), [batch_size, 4, 108, 48]);
|
||||
let x = self.pool.forward(x);
|
||||
debug_assert_eq!(x.dims(), [batch_size, 4, 54, 24]);
|
||||
let [x1, x2, x3, x4] = x
|
||||
.iter_dim(1)
|
||||
.map(|x| x.squeeze(1))
|
||||
.collect::<Vec<Tensor<B, 3>>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
debug_assert_eq!(x1.dims(), [batch_size, 54, 24]);
|
||||
debug_assert_eq!(x2.dims(), [batch_size, 54, 24]);
|
||||
debug_assert_eq!(x3.dims(), [batch_size, 54, 24]);
|
||||
debug_assert_eq!(x4.dims(), [batch_size, 54, 24]);
|
||||
|
||||
// GF
|
||||
let att_gf = self.gf_dropout.forward(x3);
|
||||
debug_assert_eq!(att_gf.dims(), [batch_size, 54, 24]);
|
||||
|
||||
let lstm_gf1 = x4.reshape([batch_size, 54 * 24]);
|
||||
let lstm_gf1 = self.lstm_gf1_linear.forward(lstm_gf1);
|
||||
let lstm_gf1 = self.gf_dropout.forward(lstm_gf1);
|
||||
let hidden_size1 = lstm_gf1.dims()[1];
|
||||
|
||||
let lstm_gf2 = lstm_gf1.clone().unsqueeze_dims(&[1, 2]);
|
||||
let lstm_gf2 = self.lstm_gf2_pool.forward(lstm_gf2);
|
||||
let hidden_size2 = lstm_gf2.dims()[3];
|
||||
let lstm_gf2 = lstm_gf2.reshape([batch_size, hidden_size2]);
|
||||
|
||||
// MHA
|
||||
let x = self.attention.forward(MhaInput::new(x1, att_gf, x2));
|
||||
let n_heads = x.weights.dims()[1];
|
||||
let x = x.context;
|
||||
debug_assert_eq!(x.dims(), [batch_size, 54, 24]);
|
||||
|
||||
let x = self.bn2.forward(x.unsqueeze_dim::<4>(1)).squeeze(1);
|
||||
debug_assert_eq!(x.dims(), [batch_size, 54, 24]);
|
||||
|
||||
// BiLSTM 1
|
||||
let (lstm_f, _) = self.lstm1_forward.forward(
|
||||
x.clone(),
|
||||
Some(LstmState::new(
|
||||
Tensor::zeros([batch_size, hidden_size1], &device),
|
||||
lstm_gf1.clone(),
|
||||
)),
|
||||
);
|
||||
let (lstm_b, _) = self.lstm1_backward.forward(
|
||||
x.flip([1]),
|
||||
Some(LstmState::new(
|
||||
Tensor::zeros([batch_size, hidden_size1], &device),
|
||||
lstm_gf1.clone(),
|
||||
)),
|
||||
);
|
||||
let x = lstm_f + lstm_b;
|
||||
debug_assert_eq!(x.dims(), [batch_size, 54, hidden_size1]);
|
||||
|
||||
// BiLSTM 2
|
||||
let (lstm_f, _) = self.lstm2_forward.forward(
|
||||
x.clone(),
|
||||
Some(LstmState::new(
|
||||
Tensor::zeros([batch_size, hidden_size2], &device),
|
||||
lstm_gf2.clone(),
|
||||
)),
|
||||
);
|
||||
let lstm_f = lstm_f
|
||||
.slice([0..batch_size, 53..54, 0..hidden_size2])
|
||||
.squeeze::<2>(1);
|
||||
let (lstm_b, _) = self.lstm2_backward.forward(
|
||||
x.flip([1]),
|
||||
Some(LstmState::new(
|
||||
Tensor::zeros([batch_size, hidden_size2], &device),
|
||||
lstm_gf2.clone(),
|
||||
)),
|
||||
);
|
||||
let lstm_b = lstm_b
|
||||
.slice([0..batch_size, 53..54, 0..hidden_size2])
|
||||
.squeeze::<2>(1);
|
||||
let x = Tensor::cat(vec![lstm_f, lstm_b], 1);
|
||||
debug_assert_eq!(x.dims(), [batch_size, hidden_size2 * 2]);
|
||||
|
||||
let x = self.out_linear.forward(x);
|
||||
let n_classes = x.dims()[1] / n_heads;
|
||||
debug_assert_eq!(x.dims(), [batch_size, n_heads * n_classes]);
|
||||
|
||||
let x = x.reshape([batch_size, n_heads, n_classes]);
|
||||
let x = softmax(x, 2);
|
||||
|
||||
x
|
||||
}
|
||||
|
||||
#[cfg(feature = "train")]
|
||||
pub fn forward_recognition(
|
||||
&self,
|
||||
images: Tensor<B, 4>,
|
||||
targets: Tensor<B, 2, Int>,
|
||||
) -> train::MirachOutput<B> {
|
||||
use train::MirachOutput;
|
||||
|
||||
let output = self.forward(images);
|
||||
let device = output.device();
|
||||
let batch_size = output.dims()[0];
|
||||
let loss_func = CrossEntropyLossConfig::new()
|
||||
.with_logits(true)
|
||||
.init(&device);
|
||||
|
||||
let loss = output
|
||||
.clone()
|
||||
.iter_dim(1)
|
||||
.zip(targets.clone().iter_dim(1))
|
||||
.map(|(output, target)| loss_func.forward(output.squeeze(1), target.squeeze(1)))
|
||||
.fold(Tensor::zeros([batch_size], &device), |sum, loss| sum + loss);
|
||||
|
||||
MirachOutput {
|
||||
output,
|
||||
targets,
|
||||
loss,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
66
model/src/train.rs
Normal file
66
model/src/train.rs
Normal file
|
@ -0,0 +1,66 @@
|
|||
use burn::{
|
||||
prelude::Backend,
|
||||
tensor::{backend::AutodiffBackend, Int, Tensor},
|
||||
train::{
|
||||
metric::{Adaptor, HammingScoreInput, LossInput},
|
||||
TrainOutput, TrainStep, ValidStep,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::MirachModel;
|
||||
|
||||
/// A batch for training
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MirachTrainBatch<B: Backend> {
|
||||
/// Input images, in [batch size, channels (=== 3 / RGB), width (=== 110), height (=== 50)]
|
||||
pub images: Tensor<B, 4>,
|
||||
/// Image labels as digit class indices, in [batch size, length (=== 4)]
|
||||
pub targets: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> TrainStep<MirachTrainBatch<B>, MirachOutput<B>> for MirachModel<B> {
|
||||
fn step(&self, batch: MirachTrainBatch<B>) -> TrainOutput<MirachOutput<B>> {
|
||||
let item = self.forward_recognition(batch.images, batch.targets);
|
||||
TrainOutput::new(self, item.loss.backward(), item)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ValidStep<MirachTrainBatch<B>, MirachOutput<B>> for MirachModel<B> {
|
||||
fn step(&self, batch: MirachTrainBatch<B>) -> MirachOutput<B> {
|
||||
self.forward_recognition(batch.images, batch.targets)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct MirachOutput<B: Backend> {
|
||||
/// The output, in [batch_size, length, n_digits]
|
||||
pub output: Tensor<B, 3>,
|
||||
/// The targets, in [batch_size, length]
|
||||
pub targets: Tensor<B, 2, Int>,
|
||||
/// The cross-entropy loss, in [batch_size]
|
||||
pub loss: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MirachOutput<B> {
|
||||
fn adapt(&self) -> HammingScoreInput<B> {
|
||||
let batch_size = self.output.dims()[0];
|
||||
|
||||
let outputs = self
|
||||
.output
|
||||
.clone()
|
||||
.argmax(2)
|
||||
.squeeze(2)
|
||||
.equal(self.targets.clone())
|
||||
.float();
|
||||
|
||||
HammingScoreInput::new(
|
||||
outputs,
|
||||
Tensor::<B, 2, Int>::ones([batch_size, 4], &self.output.device()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for MirachOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue