This commit is contained in:
xtex 2024-08-28 14:34:19 +08:00
parent 445aa48219
commit 299019b44f
Signed by: xtex
GPG key ID: B918086ED8045B91
10 changed files with 370 additions and 20 deletions

Cargo.lock generated

@@ -297,6 +297,15 @@ version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
[[package]]
name = "blas-src"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b95e83dc868db96e69795c0213143095f03de9dd3252f205d4ac716e4076a7e0"
dependencies = [
"openblas-src",
]
[[package]]
name = "block"
version = "0.1.6"
@@ -420,7 +429,7 @@ source = "git+https://github.com/tracel-ai/burn#79cd3d5d21cb6217dc72d4a5e60f9e0f"
dependencies = [
"csv",
"derive-new",
"dirs",
"dirs 5.0.1",
"gix-tempfile",
"image",
"r2d2",
@@ -489,6 +498,7 @@ name = "burn-ndarray"
version = "0.14.0"
source = "git+https://github.com/tracel-ai/burn#79cd3d5d21cb6217dc72d4a5e60f9e0ff885ded3"
dependencies = [
"blas-src",
"burn-autodiff",
"burn-common",
"burn-tensor",
@@ -497,6 +507,7 @@ dependencies = [
"matrixmultiply",
"ndarray 0.16.1",
"num-traits",
"openblas-src",
"portable-atomic-util",
"rand",
"spin",
@@ -664,6 +675,15 @@ dependencies = [
"rustversion",
]
[[package]]
name = "cblas-sys"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
dependencies = [
"libc",
]
[[package]]
name = "cc"
version = "1.1.14"
@@ -1093,7 +1113,7 @@ dependencies = [
"cfg_aliases 0.2.1",
"cubecl-common",
"derive-new",
"dirs",
"dirs 5.0.1",
"hashbrown 0.14.5",
"log",
"md5",
@@ -1239,13 +1259,33 @@ dependencies = [
"subtle",
]
[[package]]
name = "dirs"
version = "3.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309"
dependencies = [
"dirs-sys 0.3.7",
]
[[package]]
name = "dirs"
version = "5.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
dependencies = [
"dirs-sys",
"dirs-sys 0.4.1",
]
[[package]]
name = "dirs-sys"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6"
dependencies = [
"libc",
"redox_users",
"winapi",
]
[[package]]
@@ -2609,6 +2649,8 @@ version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
dependencies = [
"cblas-sys",
"libc",
"matrixmultiply",
"num-complex",
"num-integer",
@@ -2823,6 +2865,32 @@ version = "1.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
[[package]]
name = "openblas-build"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4b6b44095098cafc71915cfac3427135b6dd2ea85820a7d94a5871cb0d1e169"
dependencies = [
"anyhow",
"flate2",
"native-tls",
"tar",
"thiserror",
"ureq",
"walkdir",
]
[[package]]
name = "openblas-src"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa4958649f766a1013db4254a852cdf2836764869b6654fa117316905f537363"
dependencies = [
"dirs 3.0.2",
"openblas-build",
"vcpkg",
]
[[package]]
name = "openssl"
version = "0.10.66"
@@ -3492,6 +3560,19 @@ dependencies = [
"zeroize",
]
[[package]]
name = "rustls-native-certs"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "2.1.3"
@@ -3603,6 +3684,7 @@ dependencies = [
"clap",
"env_logger",
"image",
"itertools 0.13.0",
"log",
"rand",
"scra-mirach-dataset",
@@ -4068,6 +4150,7 @@ checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909"
dependencies = [
"filetime",
"libc",
"xattr",
]
[[package]]
@@ -4497,8 +4580,10 @@ dependencies = [
"base64 0.22.1",
"flate2",
"log",
"native-tls",
"once_cell",
"rustls",
"rustls-native-certs",
"rustls-pki-types",
"serde",
"serde_json",
@@ -5076,6 +5161,17 @@ dependencies = [
"syn 2.0.76",
]
[[package]]
name = "xattr"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f"
dependencies = [
"libc",
"linux-raw-sys",
"rustix",
]
[[package]]
name = "xml-rs"
version = "0.8.21"

cli/Cargo.toml

@@ -9,7 +9,7 @@ repository.workspace = true
license.workspace = true
[dependencies]
burn = { workspace = true, features = ["metrics", "train", "wgpu", "tui"] }
burn = { workspace = true, features = ["metrics", "train", "wgpu", "tui", "ndarray"] }
scra-mirach-model = { path = "../model", default-features = false, features = ["image", "train"] }
scra-mirach-dataset = { path = "../dataset" }
anyhow = { workspace = true }
@@ -21,8 +21,10 @@ serde_json = "1.0.127"
rand = "0.8.5"
image = { workspace = true, features = ["png", "jpeg"] }
torch-sys = { version = "^0", optional = true }
itertools = "0.13.0"
[features]
default = []
default = ["openblas"]
openblas = ["burn/openblas-system"]
libtorch = ["burn/tch"]
download-libtorch = ["libtorch", "dep:torch-sys", "torch-sys/download-libtorch"]

cli/src/batch_test.rs Normal file

@@ -0,0 +1,127 @@
use std::{path::PathBuf, time::Instant};
use anyhow::Result;
#[cfg(feature = "libtorch")]
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
use burn::{
backend::{ndarray::NdArrayDevice, wgpu::WgpuDevice, Autodiff, NdArray, Wgpu},
data::dataloader::Dataset,
module::Module,
prelude::Backend,
record::DefaultRecorder,
};
use itertools::Itertools;
use log::info;
use scra_mirach_model::inference::InferenceBuilder;
use crate::{Args, BackendOpt};
#[derive(clap::Parser, Debug, Clone)]
#[command(version, about = "Test a set of parameters with the builtin dataset")]
pub struct BatchTestArgs {
#[arg(from_global)]
pub backend: BackendOpt,
#[cfg(feature = "libtorch")]
#[arg(from_global)]
pub cuda_index: Option<usize>,
#[arg(from_global)]
pub model_config: Option<PathBuf>,
#[arg(short, long, help = "Don't print detailed results")]
pub quiet: bool,
#[arg(short, long, help = "Inference batch size", default_value_t = 1000)]
pub batch_size: usize,
#[arg(help = "Parameters file")]
pub params: PathBuf,
#[arg(help = "Dataset")]
pub dataset: DatasetKind,
}
#[derive(clap::ValueEnum, Debug, Clone, Copy)]
pub enum DatasetKind {
Validate,
Train,
}
pub fn run(args: BatchTestArgs) -> Result<()> {
match args.backend {
BackendOpt::Wgpu => batch_test::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
BackendOpt::NdArray => batch_test::<Autodiff<NdArray>>(args, NdArrayDevice::Cpu),
#[cfg(feature = "libtorch")]
BackendOpt::LibTorchCpu => batch_test::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
#[cfg(feature = "libtorch")]
BackendOpt::LibTorchMps => batch_test::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
#[cfg(feature = "libtorch")]
BackendOpt::LibTorchVulkan => {
batch_test::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan)
}
#[cfg(feature = "libtorch")]
BackendOpt::LibTorchCuda => {
let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
batch_test::<Autodiff<LibTorch>>(args, device)
}
}
}
pub fn batch_test<B: Backend>(args: BatchTestArgs, device: B::Device) -> Result<()> {
B::seed(rand::random());
let model_config = Args::model_config(&args.model_config)?;
let model = model_config.init::<B>(&device);
let model = model.load_file(&args.params, &DefaultRecorder::new(), &device)?;
info!("Model loaded, {} parameters", model.num_params());
let dataset = match args.dataset {
DatasetKind::Validate => scra_mirach_dataset::load_validate()?,
DatasetKind::Train => scra_mirach_dataset::load_train()?,
};
let (code, images): (Vec<_>, Vec<_>) = dataset
.iter()
.map(|image| (image.code, image.image))
.unzip();
let dataset_size = code.len();
let time = Instant::now();
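// Run inference in chunks of `batch_size` images so the whole dataset
// does not have to go through the model as a single batch.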
let inference = images
.into_iter()
.chunks(args.batch_size)
.into_iter()
.flat_map(|images| {
InferenceBuilder::new(images.collect_vec())
.device(device.clone())
.model(&model)
.inference()
})
.collect_vec();
let time = time.elapsed();
let corrects = code
.into_iter()
.zip(inference.into_iter())
.enumerate()
.inspect(|(index, (label, inference))| {
if !args.quiet {
if label == inference {
println!("{index}: {label}");
} else {
println!("{index}: {label}, got {inference}");
}
}
})
.filter(|(_, (label, inference))| label == inference)
.count();
let correct_rate = (corrects as f64) / (dataset_size as f64) * 100.;
println!("==============================");
println!("Number of parameters: {}", model.num_params());
println!("Correct rate: {correct_rate:.2} ({corrects} / {batch_size})");
println!("Elapsed time: {time:?}");
println!("==============================");
Ok(())
}

cli/src/inference.rs

@@ -4,7 +4,7 @@ use anyhow::Result;
#[cfg(feature = "libtorch")]
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
use burn::{
backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
backend::{ndarray::NdArrayDevice, wgpu::WgpuDevice, Autodiff, NdArray, Wgpu},
module::Module,
prelude::Backend,
record::DefaultRecorder,
@@ -26,15 +26,16 @@ pub struct InferenceArgs {
#[arg(from_global)]
pub model_config: Option<PathBuf>,
#[arg(short = 'p', long, help = "Parameters file")]
#[arg(help = "Parameters file")]
pub params: PathBuf,
#[arg(short = 'i', long, help = "Image")]
#[arg(help = "Image")]
pub image: PathBuf,
}
pub fn run(args: InferenceArgs) -> Result<()> {
match args.backend {
BackendOpt::Wgpu => inference::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
BackendOpt::NdArray => inference::<Autodiff<NdArray>>(args, NdArrayDevice::Cpu),
#[cfg(feature = "libtorch")]
BackendOpt::LibTorchCpu => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
#[cfg(feature = "libtorch")]
@@ -66,6 +67,7 @@ where
let start_time = Instant::now();
let result = model.forward(image);
let time = start_time.elapsed();
// We want the pure forward time here, so InferenceBuilder is deliberately not used.
info!("Inference completed, took {:?}", time);
let result = result.squeeze::<2>(0);

cli/src/main.rs

@@ -1,6 +1,7 @@
use std::{env, fs, path::PathBuf};
use anyhow::Result;
use batch_test::BatchTestArgs;
use clap::Parser;
use inference::InferenceArgs;
use log::info;
@@ -8,6 +9,7 @@ use scra_mirach_model::MirachModelConfig;
use serde::{Deserialize, Serialize};
use train::TrainArgs;
pub mod batch_test;
pub mod inference;
pub mod train;
@@ -36,11 +38,14 @@ struct Args {
enum Action {
Train(TrainArgs),
Inference(InferenceArgs),
BatchTest(BatchTestArgs),
}
#[derive(clap::ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum BackendOpt {
Wgpu,
#[value(alias = "ndarray")]
NdArray,
#[cfg(feature = "libtorch")]
LibTorchCpu,
#[cfg(feature = "libtorch")]
@@ -75,5 +80,6 @@ fn main() -> Result<()> {
match args.action {
Action::Train(args) => train::run(args),
Action::Inference(args) => inference::run(args),
Action::BatchTest(args) => batch_test::run(args),
}
}

cli/src/train.rs

@@ -4,7 +4,7 @@ use anyhow::Result;
#[cfg(feature = "libtorch")]
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
use burn::{
backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
backend::{ndarray::NdArrayDevice, wgpu::WgpuDevice, Autodiff, NdArray, Wgpu},
data::dataloader::DataLoaderBuilder,
lr_scheduler::constant::ConstantLr,
module::Module,
@@ -56,6 +56,7 @@ pub struct TrainArgs {
pub fn run(args: TrainArgs) -> Result<()> {
match args.backend {
BackendOpt::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
BackendOpt::NdArray => train::<Autodiff<NdArray>>(args, NdArrayDevice::Cpu),
#[cfg(feature = "libtorch")]
BackendOpt::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
#[cfg(feature = "libtorch")]
@@ -121,8 +122,6 @@ pub fn train<B: AutodiffBackend>(args: TrainArgs, device: B::Device) -> Result<(
.with_file_checkpointer(CompactRecorder::new())
.devices(vec![device.clone()])
.num_epochs(args.epochs)
// .renderer(MirachMetricsRenderer)
// .with_application_logger(None)
.summary()
.build(model, optimizer, lr_sched);

model/Cargo.toml

@@ -17,3 +17,8 @@ serde = { workspace = true }
default = ["image"]
image = ["dep:image"]
train = ["burn/train", "burn/metrics"]
wgpu = ["burn/wgpu"]
ndarray = ["burn/ndarray"]
ndarray-openblas = ["burn/openblas"]
libtorch = ["burn/tch"]

model/src/data.rs

@@ -1,15 +1,13 @@
use std::mem::MaybeUninit;
use std::{fmt::{Debug, Display}, mem::MaybeUninit};
use burn::{
prelude::Backend,
tensor::{Device, Tensor, TensorData},
};
#[cfg(feature = "image")]
use image::RgbImage;
use crate::{IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Clone, Copy, PartialEq, Eq)]
#[repr(transparent)]
pub struct CaptchaCode(u16);
@@ -51,6 +49,18 @@ impl CaptchaCode {
}
}
impl Debug for CaptchaCode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("CaptchaCode({:04})", self.0))
}
}
impl Display for CaptchaCode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{:04}", self.0))
}
}
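// Illustration (not part of this commit): with the impls above,
// format!("{}", CaptchaCode::new_unchecked(42)) == "0042"
// format!("{:?}", CaptchaCode::new_unchecked(42)) == "CaptchaCode(0042)"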
#[derive(Clone, PartialEq, Eq)]
pub struct MirachImageData(Vec<u8>);
@@ -65,8 +75,8 @@ impl MirachImageData {
}
#[cfg(feature = "image")]
impl From<RgbImage> for MirachImageData {
fn from(image: RgbImage) -> Self {
impl From<image::RgbImage> for MirachImageData {
fn from(image: image::RgbImage) -> Self {
// SAFETY: The array is filled immediately after allocation,
// and no read is performed before all data are fully written.
// Therefore no uninitialized data can be leaked.
@@ -87,6 +97,13 @@ impl From<RgbImage> for MirachImageData {
}
}
#[cfg(feature = "image")]
impl From<image::DynamicImage> for MirachImageData {
fn from(image: image::DynamicImage) -> Self {
image.into_rgb8().into()
}
}
impl MirachImageData {
pub fn into_tensor<B: Backend>(self, device: &Device<B>) -> Tensor<B, 4> {
let image = TensorData::new::<u8, _>(

model/src/inference.rs Normal file

@@ -0,0 +1,95 @@
use burn::{prelude::Backend, tensor::Tensor};
use crate::{
data::{CaptchaCode, MirachImageData},
MirachModel, MirachModelConfig,
};
/// Inference builder
pub struct InferenceBuilder<'model, B: Backend> {
images: Vec<MirachImageData>,
device: Option<B::Device>,
model: Option<&'model MirachModel<B>>,
}
impl<'model, B: Backend> InferenceBuilder<'model, B> {
/// Creates an inference batch for the given images
pub fn new(images: Vec<MirachImageData>) -> Self {
Self {
images,
device: None,
model: None,
}
}
/// Creates an inference batch for a single image
#[inline]
pub fn new_one(image: MirachImageData) -> Self {
Self::new(vec![image])
}
/// Sets the tensor backend device to use
pub fn device(mut self, device: B::Device) -> Self {
self.device = Some(device);
self
}
/// Sets the model to use.
///
/// The model should be initialized on the same tensor backend device to improve performance.
pub fn model(mut self, model: &'model MirachModel<B>) -> Self {
self.model = Some(model);
self
}
/// Runs the inference, returning the recognized CAPTCHA codes.
pub fn inference(self) -> Vec<CaptchaCode> {
let device = self.device.unwrap_or_default();
let batch_size = self.images.len();
let images = self
.images
.into_iter()
.map(|image| image.into_tensor::<B>(&device))
.collect();
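// Concatenate the per-image tensors along dimension 0 into a single batch.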
let images = Tensor::cat(images, 0);
let result = if let Some(model) = self.model {
model.forward(images)
} else {
MirachModelConfig::new().init(&device).forward(images)
};
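// argmax over the class dimension picks the predicted digit for each of
// the four code positions: [batch, 4, classes] -> [batch, 4] after squeeze.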
let result = result.argmax(2).squeeze::<2>(2);
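// Recombine the four per-position digits into one 0-9999 code:
// code = d0 * 1000 + d1 * 100 + d2 * 10 + d3.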
let result = (result.clone().slice([0..batch_size, 0..1]) * 1000)
+ (result.clone().slice([0..batch_size, 1..2]) * 100)
+ (result.clone().slice([0..batch_size, 2..3]) * 10)
+ (result.slice([0..batch_size, 3..4]) * 1);
let result = result.squeeze::<1>(1);
let result = result.into_data().convert::<u32>();
let result = result
.to_vec::<u32>()
.expect("Model output tensor could not be converted to u32 vector by the backend");
let result = result
.into_iter()
.map(|code| code as u16)
.map(CaptchaCode::new_unchecked)
.collect();
result
}
/// Runs the inference for a single image, returning the first result.
///
/// ## Panics
/// Panics when no input image was provided.
pub fn inference_one(self) -> CaptchaCode {
*self
.inference()
.first()
.expect("inference_one must be called with at least one image loaded")
}
}
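For context, a minimal usage sketch of the new builder (hypothetical `recognize` helper; assumes a model loaded as in cli/src/batch_test.rs and the NdArray backend added by this commit):
use burn::backend::{ndarray::NdArrayDevice, NdArray};
use scra_mirach_model::{
    data::{CaptchaCode, MirachImageData},
    inference::InferenceBuilder,
    MirachModel,
};
// Recognize one CAPTCHA image with an already-loaded model.
fn recognize(model: &MirachModel<NdArray>, image: MirachImageData) -> CaptchaCode {
    InferenceBuilder::new_one(image)
        .device(NdArrayDevice::Cpu)
        .model(model)
        .inference_one()
}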

model/src/lib.rs

@@ -15,6 +15,7 @@ use burn::{
use burn::{nn::loss::CrossEntropyLossConfig, prelude::Int};
pub mod data;
pub mod inference;
#[cfg(feature = "train")]
pub mod train;
@@ -53,11 +54,11 @@ pub struct MirachModelConfig {
num_heads: usize,
#[config(default = "64")]
hidden_size1: usize,
#[config(default = "32")]
#[config(default = "64")]
hidden_size2: usize,
#[config(default = "0.05")]
#[config(default = "0.015")]
att_dropout: f64,
#[config(default = "0.03")]
#[config(default = "0.015")]
gf_dropout: f64,
}