update
This commit is contained in:
parent
445aa48219
commit
299019b44f
10 changed files with 370 additions and 20 deletions
102
Cargo.lock
generated
102
Cargo.lock
generated
|
@ -297,6 +297,15 @@ version = "2.5.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
|
||||
|
||||
[[package]]
|
||||
name = "blas-src"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b95e83dc868db96e69795c0213143095f03de9dd3252f205d4ac716e4076a7e0"
|
||||
dependencies = [
|
||||
"openblas-src",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "block"
|
||||
version = "0.1.6"
|
||||
|
@ -420,7 +429,7 @@ source = "git+https://github.com/tracel-ai/burn#79cd3d5d21cb6217dc72d4a5e60f9e0f
|
|||
dependencies = [
|
||||
"csv",
|
||||
"derive-new",
|
||||
"dirs",
|
||||
"dirs 5.0.1",
|
||||
"gix-tempfile",
|
||||
"image",
|
||||
"r2d2",
|
||||
|
@ -489,6 +498,7 @@ name = "burn-ndarray"
|
|||
version = "0.14.0"
|
||||
source = "git+https://github.com/tracel-ai/burn#79cd3d5d21cb6217dc72d4a5e60f9e0ff885ded3"
|
||||
dependencies = [
|
||||
"blas-src",
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
"burn-tensor",
|
||||
|
@ -497,6 +507,7 @@ dependencies = [
|
|||
"matrixmultiply",
|
||||
"ndarray 0.16.1",
|
||||
"num-traits",
|
||||
"openblas-src",
|
||||
"portable-atomic-util",
|
||||
"rand",
|
||||
"spin",
|
||||
|
@ -664,6 +675,15 @@ dependencies = [
|
|||
"rustversion",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cblas-sys"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.1.14"
|
||||
|
@ -1093,7 +1113,7 @@ dependencies = [
|
|||
"cfg_aliases 0.2.1",
|
||||
"cubecl-common",
|
||||
"derive-new",
|
||||
"dirs",
|
||||
"dirs 5.0.1",
|
||||
"hashbrown 0.14.5",
|
||||
"log",
|
||||
"md5",
|
||||
|
@ -1239,13 +1259,33 @@ dependencies = [
|
|||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "3.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309"
|
||||
dependencies = [
|
||||
"dirs-sys 0.3.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs"
|
||||
version = "5.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225"
|
||||
dependencies = [
|
||||
"dirs-sys",
|
||||
"dirs-sys 0.4.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dirs-sys"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"redox_users",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -2609,6 +2649,8 @@ version = "0.16.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841"
|
||||
dependencies = [
|
||||
"cblas-sys",
|
||||
"libc",
|
||||
"matrixmultiply",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
|
@ -2823,6 +2865,32 @@ version = "1.19.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||
|
||||
[[package]]
|
||||
name = "openblas-build"
|
||||
version = "0.10.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4b6b44095098cafc71915cfac3427135b6dd2ea85820a7d94a5871cb0d1e169"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"flate2",
|
||||
"native-tls",
|
||||
"tar",
|
||||
"thiserror",
|
||||
"ureq",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openblas-src"
|
||||
version = "0.10.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa4958649f766a1013db4254a852cdf2836764869b6654fa117316905f537363"
|
||||
dependencies = [
|
||||
"dirs 3.0.2",
|
||||
"openblas-build",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl"
|
||||
version = "0.10.66"
|
||||
|
@ -3492,6 +3560,19 @@ dependencies = [
|
|||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-native-certs"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa"
|
||||
dependencies = [
|
||||
"openssl-probe",
|
||||
"rustls-pemfile",
|
||||
"rustls-pki-types",
|
||||
"schannel",
|
||||
"security-framework",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "2.1.3"
|
||||
|
@ -3603,6 +3684,7 @@ dependencies = [
|
|||
"clap",
|
||||
"env_logger",
|
||||
"image",
|
||||
"itertools 0.13.0",
|
||||
"log",
|
||||
"rand",
|
||||
"scra-mirach-dataset",
|
||||
|
@ -4068,6 +4150,7 @@ checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909"
|
|||
dependencies = [
|
||||
"filetime",
|
||||
"libc",
|
||||
"xattr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -4497,8 +4580,10 @@ dependencies = [
|
|||
"base64 0.22.1",
|
||||
"flate2",
|
||||
"log",
|
||||
"native-tls",
|
||||
"once_cell",
|
||||
"rustls",
|
||||
"rustls-native-certs",
|
||||
"rustls-pki-types",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
@ -5076,6 +5161,17 @@ dependencies = [
|
|||
"syn 2.0.76",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xattr"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"rustix",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xml-rs"
|
||||
version = "0.8.21"
|
||||
|
|
|
@ -9,7 +9,7 @@ repository.workspace = true
|
|||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn = { workspace = true, features = ["metrics", "train", "wgpu", "tui"] }
|
||||
burn = { workspace = true, features = ["metrics", "train", "wgpu", "tui", "ndarray"] }
|
||||
scra-mirach-model = { path = "../model", default-features = false, features = ["image", "train"] }
|
||||
scra-mirach-dataset = { path = "../dataset" }
|
||||
anyhow = { workspace = true }
|
||||
|
@ -21,8 +21,10 @@ serde_json = "1.0.127"
|
|||
rand = "0.8.5"
|
||||
image = { workspace = true, features = ["png", "jpeg"] }
|
||||
torch-sys = { version = "^0", optional = true }
|
||||
itertools = "0.13.0"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
default = ["openblas"]
|
||||
openblas = ["burn/openblas-system"]
|
||||
libtorch = ["burn/tch"]
|
||||
download-libtorch = ["libtorch", "dep:torch-sys", "torch-sys/download-libtorch"]
|
||||
|
|
127
cli/src/batch_test.rs
Normal file
127
cli/src/batch_test.rs
Normal file
|
@ -0,0 +1,127 @@
|
|||
use std::{path::PathBuf, time::Instant};
|
||||
|
||||
use anyhow::Result;
|
||||
#[cfg(feature = "libtorch")]
|
||||
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
|
||||
use burn::{
|
||||
backend::{ndarray::NdArrayDevice, wgpu::WgpuDevice, Autodiff, NdArray, Wgpu},
|
||||
data::dataloader::Dataset,
|
||||
module::Module,
|
||||
prelude::Backend,
|
||||
record::DefaultRecorder,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use log::info;
|
||||
use scra_mirach_model::inference::InferenceBuilder;
|
||||
|
||||
use crate::{Args, BackendOpt};
|
||||
|
||||
#[derive(clap::Parser, Debug, Clone)]
|
||||
#[command(version, about = "Test a set of parameter with builtin dataset")]
|
||||
pub struct BatchTestArgs {
|
||||
#[arg(from_global)]
|
||||
pub backend: BackendOpt,
|
||||
#[cfg(feature = "libtorch")]
|
||||
#[arg(from_global)]
|
||||
pub cuda_index: Option<usize>,
|
||||
#[arg(from_global)]
|
||||
pub model_config: Option<PathBuf>,
|
||||
|
||||
#[arg(short, long, help = "Don't print detailed result")]
|
||||
pub quiet: bool,
|
||||
#[arg(short, long, help = "Inference batch size", default_value_t = 1000)]
|
||||
pub batch_size: usize,
|
||||
|
||||
#[arg(help = "Parameters file")]
|
||||
pub params: PathBuf,
|
||||
#[arg(help = "Dataset")]
|
||||
pub dataset: DatasetKind,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone, Copy)]
|
||||
pub enum DatasetKind {
|
||||
Validate,
|
||||
Train,
|
||||
}
|
||||
|
||||
pub fn run(args: BatchTestArgs) -> Result<()> {
|
||||
match args.backend {
|
||||
BackendOpt::Wgpu => batch_test::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
|
||||
BackendOpt::NdArray => batch_test::<Autodiff<NdArray>>(args, NdArrayDevice::Cpu),
|
||||
#[cfg(feature = "libtorch")]
|
||||
BackendOpt::LibTorchCpu => batch_test::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
|
||||
#[cfg(feature = "libtorch")]
|
||||
BackendOpt::LibTorchMps => batch_test::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
|
||||
#[cfg(feature = "libtorch")]
|
||||
BackendOpt::LibTorchVulkan => {
|
||||
batch_test::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan)
|
||||
}
|
||||
#[cfg(feature = "libtorch")]
|
||||
BackendOpt::LibTorchCuda => {
|
||||
let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
|
||||
batch_test::<Autodiff<LibTorch>>(args, device)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn batch_test<B: Backend>(args: BatchTestArgs, device: B::Device) -> Result<()>
|
||||
where
|
||||
B::IntElem:,
|
||||
{
|
||||
B::seed(rand::random());
|
||||
|
||||
let model_config = Args::model_config(&args.model_config)?;
|
||||
let model = model_config.init::<B>(&device);
|
||||
let model = model.load_file(&args.params, &DefaultRecorder::new(), &device)?;
|
||||
info!("Model loaded, {} parameters", model.num_params());
|
||||
|
||||
let dataset = match args.dataset {
|
||||
DatasetKind::Validate => scra_mirach_dataset::load_validate()?,
|
||||
DatasetKind::Train => scra_mirach_dataset::load_train()?,
|
||||
};
|
||||
let (code, images): (Vec<_>, Vec<_>) = dataset
|
||||
.iter()
|
||||
.map(|image| (image.code, image.image))
|
||||
.unzip();
|
||||
let batch_size = code.len();
|
||||
|
||||
let time = Instant::now();
|
||||
let inference = images
|
||||
.into_iter()
|
||||
.chunks(args.batch_size)
|
||||
.into_iter()
|
||||
.flat_map(|images| {
|
||||
InferenceBuilder::new(images.collect_vec())
|
||||
.device(device.clone())
|
||||
.model(&model)
|
||||
.inference()
|
||||
})
|
||||
.collect_vec();
|
||||
let time = time.elapsed();
|
||||
|
||||
let corrects = code
|
||||
.into_iter()
|
||||
.zip(inference.into_iter())
|
||||
.enumerate()
|
||||
.inspect(|(index, (label, inference))| {
|
||||
if !args.quiet {
|
||||
if label == inference {
|
||||
println!("{index}: {label}");
|
||||
} else {
|
||||
println!("{index}: {label}, got {inference}");
|
||||
}
|
||||
}
|
||||
})
|
||||
.filter(|(_, (label, inference))| label == inference)
|
||||
.count();
|
||||
|
||||
let correct_rate = (corrects as f64) / (batch_size as f64) * 100.;
|
||||
|
||||
println!("==============================");
|
||||
println!("Number of parameters: {}", model.num_params());
|
||||
println!("Correct rate: {correct_rate:.2} ({corrects} / {batch_size})");
|
||||
println!("Elapsed time: {time:?}");
|
||||
println!("==============================");
|
||||
|
||||
Ok(())
|
||||
}
|
|
@ -4,7 +4,7 @@ use anyhow::Result;
|
|||
#[cfg(feature = "libtorch")]
|
||||
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
|
||||
use burn::{
|
||||
backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
|
||||
backend::{ndarray::NdArrayDevice, wgpu::WgpuDevice, Autodiff, NdArray, Wgpu},
|
||||
module::Module,
|
||||
prelude::Backend,
|
||||
record::DefaultRecorder,
|
||||
|
@ -26,15 +26,16 @@ pub struct InferenceArgs {
|
|||
#[arg(from_global)]
|
||||
pub model_config: Option<PathBuf>,
|
||||
|
||||
#[arg(short = 'p', long, help = "Parameters file")]
|
||||
#[arg(help = "Parameters file")]
|
||||
pub params: PathBuf,
|
||||
#[arg(short = 'i', long, help = "Image")]
|
||||
#[arg(help = "Image")]
|
||||
pub image: PathBuf,
|
||||
}
|
||||
|
||||
pub fn run(args: InferenceArgs) -> Result<()> {
|
||||
match args.backend {
|
||||
BackendOpt::Wgpu => inference::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
|
||||
BackendOpt::NdArray => inference::<Autodiff<NdArray>>(args, NdArrayDevice::Cpu),
|
||||
#[cfg(feature = "libtorch")]
|
||||
BackendOpt::LibTorchCpu => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
|
||||
#[cfg(feature = "libtorch")]
|
||||
|
@ -66,6 +67,7 @@ where
|
|||
let start_time = Instant::now();
|
||||
let result = model.forward(image);
|
||||
let time = start_time.elapsed();
|
||||
// we want the pure forward time, so not using InferenceBuilder here
|
||||
info!("Inference completed, took {:?}", time);
|
||||
|
||||
let result = result.squeeze::<2>(0);
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use std::{env, fs, path::PathBuf};
|
||||
|
||||
use anyhow::Result;
|
||||
use batch_test::BatchTestArgs;
|
||||
use clap::Parser;
|
||||
use inference::InferenceArgs;
|
||||
use log::info;
|
||||
|
@ -8,6 +9,7 @@ use scra_mirach_model::MirachModelConfig;
|
|||
use serde::{Deserialize, Serialize};
|
||||
use train::TrainArgs;
|
||||
|
||||
pub mod batch_test;
|
||||
pub mod inference;
|
||||
pub mod train;
|
||||
|
||||
|
@ -36,11 +38,14 @@ struct Args {
|
|||
enum Action {
|
||||
Train(TrainArgs),
|
||||
Inference(InferenceArgs),
|
||||
BatchTest(BatchTestArgs),
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum BackendOpt {
|
||||
Wgpu,
|
||||
#[value(alias = "ndarray")]
|
||||
NdArray,
|
||||
#[cfg(feature = "libtorch")]
|
||||
LibTorchCpu,
|
||||
#[cfg(feature = "libtorch")]
|
||||
|
@ -75,5 +80,6 @@ fn main() -> Result<()> {
|
|||
match args.action {
|
||||
Action::Train(args) => train::run(args),
|
||||
Action::Inference(args) => inference::run(args),
|
||||
Action::BatchTest(args) => batch_test::run(args),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use anyhow::Result;
|
|||
#[cfg(feature = "libtorch")]
|
||||
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
|
||||
use burn::{
|
||||
backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
|
||||
backend::{ndarray::NdArrayDevice, wgpu::WgpuDevice, Autodiff, NdArray, Wgpu},
|
||||
data::dataloader::DataLoaderBuilder,
|
||||
lr_scheduler::constant::ConstantLr,
|
||||
module::Module,
|
||||
|
@ -56,6 +56,7 @@ pub struct TrainArgs {
|
|||
pub fn run(args: TrainArgs) -> Result<()> {
|
||||
match args.backend {
|
||||
BackendOpt::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
|
||||
BackendOpt::NdArray => train::<Autodiff<NdArray>>(args, NdArrayDevice::Cpu),
|
||||
#[cfg(feature = "libtorch")]
|
||||
BackendOpt::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
|
||||
#[cfg(feature = "libtorch")]
|
||||
|
@ -121,8 +122,6 @@ pub fn train<B: AutodiffBackend>(args: TrainArgs, device: B::Device) -> Result<(
|
|||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.devices(vec![device.clone()])
|
||||
.num_epochs(args.epochs)
|
||||
// .renderer(MirachMetricsRenderer)
|
||||
// .with_application_logger(None)
|
||||
.summary()
|
||||
.build(model, optimizer, lr_sched);
|
||||
|
||||
|
|
|
@ -17,3 +17,8 @@ serde = { workspace = true }
|
|||
default = ["image"]
|
||||
image = ["dep:image"]
|
||||
train = ["burn/train", "burn/metrics"]
|
||||
|
||||
wgpu = ["burn/wgpu"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-openblas = ["burn/openblas"]
|
||||
libtorch = ["burn/tch"]
|
||||
|
|
|
@ -1,15 +1,13 @@
|
|||
use std::mem::MaybeUninit;
|
||||
use std::{fmt::{Debug, Display}, mem::MaybeUninit};
|
||||
|
||||
use burn::{
|
||||
prelude::Backend,
|
||||
tensor::{Device, Tensor, TensorData},
|
||||
};
|
||||
#[cfg(feature = "image")]
|
||||
use image::RgbImage;
|
||||
|
||||
use crate::{IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH};
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
#[repr(transparent)]
|
||||
pub struct CaptchaCode(u16);
|
||||
|
||||
|
@ -51,6 +49,18 @@ impl CaptchaCode {
|
|||
}
|
||||
}
|
||||
|
||||
impl Debug for CaptchaCode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!("CaptchaCode({:04})", self.0))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for CaptchaCode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!("{:04}", self.0))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct MirachImageData(Vec<u8>);
|
||||
|
||||
|
@ -65,8 +75,8 @@ impl MirachImageData {
|
|||
}
|
||||
|
||||
#[cfg(feature = "image")]
|
||||
impl From<RgbImage> for MirachImageData {
|
||||
fn from(image: RgbImage) -> Self {
|
||||
impl From<image::RgbImage> for MirachImageData {
|
||||
fn from(image: image::RgbImage) -> Self {
|
||||
// SAFETY: The array will be filled are allocation immediately,
|
||||
// and no read will be performed before all data are fully written.
|
||||
// Therefore no data can be leaked.
|
||||
|
@ -87,6 +97,13 @@ impl From<RgbImage> for MirachImageData {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "image")]
|
||||
impl From<image::DynamicImage> for MirachImageData {
|
||||
fn from(image: image::DynamicImage) -> Self {
|
||||
image.into_rgb8().into()
|
||||
}
|
||||
}
|
||||
|
||||
impl MirachImageData {
|
||||
pub fn into_tensor<B: Backend>(self, device: &Device<B>) -> Tensor<B, 4> {
|
||||
let image = TensorData::new::<u8, _>(
|
||||
|
|
95
model/src/inference.rs
Normal file
95
model/src/inference.rs
Normal file
|
@ -0,0 +1,95 @@
|
|||
use burn::{prelude::Backend, tensor::Tensor};
|
||||
|
||||
use crate::{
|
||||
data::{CaptchaCode, MirachImageData},
|
||||
MirachModel, MirachModelConfig,
|
||||
};
|
||||
|
||||
/// Inference builder
|
||||
pub struct InferenceBuilder<'model, B: Backend> {
|
||||
images: Vec<MirachImageData>,
|
||||
device: Option<B::Device>,
|
||||
model: Option<&'model MirachModel<B>>,
|
||||
}
|
||||
|
||||
impl<'model, B: Backend> InferenceBuilder<'model, B> {
|
||||
/// Creates a inference batch for given images
|
||||
pub fn new(images: Vec<MirachImageData>) -> Self {
|
||||
Self {
|
||||
images,
|
||||
device: None,
|
||||
model: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a inference batch for one image
|
||||
#[inline]
|
||||
pub fn new_one(image: MirachImageData) -> Self {
|
||||
Self::new(vec![image])
|
||||
}
|
||||
|
||||
/// Sets the tensor backend device to use
|
||||
pub fn device(mut self, device: B::Device) -> Self {
|
||||
self.device = Some(device);
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets the model to use.
|
||||
///
|
||||
/// The model should be initialized on the same tensor backend device to improve performance.
|
||||
pub fn model(mut self, model: &'model MirachModel<B>) -> Self {
|
||||
self.model = Some(model);
|
||||
self
|
||||
}
|
||||
|
||||
/// Do the inference, returning recognized CAPTCHA codes.
|
||||
pub fn inference(self) -> Vec<CaptchaCode> {
|
||||
let device = self.device.unwrap_or_default();
|
||||
|
||||
let batch_size = self.images.len();
|
||||
|
||||
let images = self
|
||||
.images
|
||||
.into_iter()
|
||||
.map(|image| image.into_tensor::<B>(&device))
|
||||
.collect();
|
||||
let images = Tensor::cat(images, 0);
|
||||
|
||||
let result = if let Some(model) = self.model {
|
||||
model.forward(images)
|
||||
} else {
|
||||
MirachModelConfig::new().init(&device).forward(images)
|
||||
};
|
||||
|
||||
let result = result.argmax(2).squeeze::<2>(2);
|
||||
let result = (result.clone().slice([0..batch_size, 0..1]) * 1000)
|
||||
+ (result.clone().slice([0..batch_size, 1..2]) * 100)
|
||||
+ (result.clone().slice([0..batch_size, 2..3]) * 10)
|
||||
+ (result.slice([0..batch_size, 3..4]) * 1);
|
||||
let result = result.squeeze::<1>(1);
|
||||
|
||||
let result = result.into_data().convert::<u32>();
|
||||
let result = result
|
||||
.to_vec::<u32>()
|
||||
.expect("Model output tensor could not be converted to u32 vector by the backend");
|
||||
|
||||
let result = result
|
||||
.into_iter()
|
||||
.map(|code| code as u16)
|
||||
.map(CaptchaCode::new_unchecked)
|
||||
.collect();
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Inferences only one image.
|
||||
///
|
||||
/// ## Panics
|
||||
/// Panic when there are no input image.
|
||||
pub fn inference_one(self) -> CaptchaCode {
|
||||
*self
|
||||
.inference()
|
||||
.first()
|
||||
.expect("inference_one must be called with at least one image loaded")
|
||||
}
|
||||
}
|
|
@ -15,6 +15,7 @@ use burn::{
|
|||
use burn::{nn::loss::CrossEntropyLossConfig, prelude::Int};
|
||||
|
||||
pub mod data;
|
||||
pub mod inference;
|
||||
#[cfg(feature = "train")]
|
||||
pub mod train;
|
||||
|
||||
|
@ -53,11 +54,11 @@ pub struct MirachModelConfig {
|
|||
num_heads: usize,
|
||||
#[config(default = "64")]
|
||||
hidden_size1: usize,
|
||||
#[config(default = "32")]
|
||||
#[config(default = "64")]
|
||||
hidden_size2: usize,
|
||||
#[config(default = "0.05")]
|
||||
#[config(default = "0.015")]
|
||||
att_dropout: f64,
|
||||
#[config(default = "0.03")]
|
||||
#[config(default = "0.015")]
|
||||
gf_dropout: f64,
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue