Inference

xtex, 2024-08-27 21:29:55 +08:00
commit 1316ff5de4, parent 11d2882c3f
Signed by: xtex (GPG key ID: B918086ED8045B91)
7 changed files with 133 additions and 20 deletions

Cargo.lock (generated): 2 additions

@@ -3602,7 +3602,9 @@ dependencies = [
  "burn",
  "clap",
  "env_logger",
+ "image",
  "log",
+ "rand",
  "scra-mirach-dataset",
  "scra-mirach-model",
  "serde",

cli/Cargo.toml

@@ -18,6 +18,9 @@ log = { workspace = true }
 env_logger = "0.11.5"
 serde = { workspace = true }
 serde_json = "1.0.127"
+rand = "0.8.5"
+image = { workspace = true, features = ["png", "jpeg"] }
+
 [features]
 default = ["libtorch"]
 libtorch = ["burn/tch"]

cli/src/inference.rs (new file, 85 lines)

@@ -0,0 +1,85 @@
use std::{path::PathBuf, time::Instant};

use anyhow::Result;
#[cfg(feature = "libtorch")]
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
use burn::{
    backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
    module::Module,
    prelude::Backend,
    record::DefaultRecorder,
};
use log::info;
use scra_mirach_dataset::CaptchaCode;
use scra_mirach_model::data::MirachImageData;

use crate::{Args, BackendOpt};

#[derive(clap::Parser, Debug, Clone)]
#[command(version, about)]
pub struct InferenceArgs {
    #[arg(from_global)]
    pub backend: BackendOpt,
    #[cfg(feature = "libtorch")]
    #[arg(from_global)]
    pub cuda_index: Option<usize>,
    #[arg(from_global)]
    pub model_config: Option<PathBuf>,
    #[arg(short = 'p', long, help = "Parameters file")]
    pub params: PathBuf,
    #[arg(short = 'i', long, help = "Image")]
    pub image: PathBuf,
}

pub fn run(args: InferenceArgs) -> Result<()> {
    match args.backend {
        BackendOpt::Wgpu => inference::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchCpu => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchMps => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchVulkan => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchCuda => {
            let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
            inference::<Autodiff<LibTorch>>(args, device)
        }
    }
}

pub fn inference<B: Backend>(args: InferenceArgs, device: B::Device) -> Result<()> {
    B::seed(rand::random());
    let model_config = Args::model_config(&args.model_config)?;
    let model = model_config.init::<B>(&device);
    let model = model.load_file(&args.params, &DefaultRecorder::new(), &device)?;
    info!("Model loaded, {} parameters", model.num_params());

    let image = MirachImageData::from(image::open(args.image)?.into_rgb8()).into_tensor(&device);
    info!("Image loaded");

    let start_time = Instant::now();
    let result = model.forward(image);
    let time = start_time.elapsed();
    info!("Inference completed, took {:?}", time);

    // Drop the batch dimension, then take the argmax over the digit-class axis
    // to get one predicted digit per captcha position.
    let result = result.squeeze::<2>(0);
    let result = result.argmax(1);
    let result = result.squeeze::<1>(1);
    let result = result.into_data().convert::<u8>();
    let result = result
        .to_vec::<u8>()
        .expect("Model output tensor could not be converted to u8 vector");
    let result = result
        .try_into()
        .expect("Model output dimension must be [1, 4]");
    let result = CaptchaCode::from_digits(result);
    info!("Code: {}", result.as_u16());
    Ok(())
}
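The post-processing above collapses the model's per-position digit scores to four digits with an argmax. Here is a dependency-free sketch of just that step, on made-up scores; the [4, 10] shape (four captcha positions, ten digit classes) is an assumption inferred from the surrounding code, not stated by the diff:

fn main() {
    // Fake scores: 4 captcha positions x 10 digit classes (assumed shape).
    let mut scores = [[0.0f32; 10]; 4];
    scores[0][1] = 9.0;
    scores[1][3] = 9.0;
    scores[2][3] = 9.0;
    scores[3][7] = 9.0;

    // Mirror of `result.argmax(1)`: index of the maximum along the class axis.
    let digits: Vec<u8> = scores
        .iter()
        .map(|row| {
            row.iter()
                .enumerate()
                .max_by(|a, b| a.1.total_cmp(b.1))
                .map(|(i, _)| i as u8)
                .unwrap()
        })
        .collect();

    assert_eq!(digits, vec![1, 3, 3, 7]);
}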

cli/src/main.rs

@@ -2,11 +2,13 @@ use std::{env, fs, path::PathBuf};
 use anyhow::Result;
 use clap::Parser;
+use inference::InferenceArgs;
 use scra_mirach_model::MirachModelConfig;
 use serde::{Deserialize, Serialize};
 use train::TrainArgs;

 pub mod train;
+pub mod inference;

 #[derive(clap::Parser, Debug, Clone)]
 #[command(version, about)]
@@ -14,8 +16,8 @@ struct Args {
     #[command(subcommand)]
     action: Action,
-    #[arg(short = 'b', long, value_enum, global = true, default_value_t = Backend::Wgpu, help = "Tensor backend")]
-    pub backend: Backend,
+    #[arg(short = 'b', long, value_enum, global = true, default_value_t = BackendOpt::Wgpu, help = "Tensor backend")]
+    pub backend: BackendOpt,
     #[cfg(feature = "libtorch")]
     #[arg(
         long,
@@ -32,10 +34,11 @@ struct Args {
 #[derive(Debug, Clone, clap::Subcommand)]
 enum Action {
     Train(TrainArgs),
+    Inference(InferenceArgs),
 }

 #[derive(clap::ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
-pub enum Backend {
+pub enum BackendOpt {
     Wgpu,
     #[cfg(feature = "libtorch")]
     LibTorchCpu,
@@ -69,5 +72,6 @@ fn main() -> Result<()> {
     match args.action {
         Action::Train(args) => train::run(args),
+        Action::Inference(args) => inference::run(args),
     }
 }
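The dispatch above relies on clap's global-argument plumbing: `--backend` is declared once on the top-level `Args` with `global = true`, and each subcommand re-exposes it via `from_global`. A minimal self-contained sketch of that pattern (clap 4 with the derive feature; everything except the flag is trimmed away):

use clap::Parser;

#[derive(clap::ValueEnum, Debug, Clone, Copy)]
enum BackendOpt {
    Wgpu,
}

#[derive(clap::Parser, Debug)]
struct Args {
    #[command(subcommand)]
    action: Action,
    // Declared once here, visible to every subcommand.
    #[arg(short = 'b', long, value_enum, global = true, default_value_t = BackendOpt::Wgpu)]
    backend: BackendOpt,
}

#[derive(Debug, clap::Subcommand)]
enum Action {
    Inference(InferenceArgs),
}

#[derive(clap::Parser, Debug)]
struct InferenceArgs {
    // Pulls the global --backend value into the subcommand's own args.
    #[arg(from_global)]
    backend: BackendOpt,
}

fn main() {
    println!("{:?}", Args::parse());
}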

cli/src/train.rs

@@ -19,13 +19,13 @@ use burn::{
 use log::info;
 use scra_mirach_dataset::batcher::MirachTrainBatcher;

-use crate::{Args, Backend};
+use crate::{Args, BackendOpt};

 #[derive(clap::Parser, Debug, Clone)]
 #[command(version, about)]
 pub struct TrainArgs {
     #[arg(from_global)]
-    pub backend: Backend,
+    pub backend: BackendOpt,
     #[cfg(feature = "libtorch")]
     #[arg(from_global)]
     pub cuda_index: Option<usize>,
@@ -55,15 +55,15 @@ pub struct TrainArgs {
 pub fn run(args: TrainArgs) -> Result<()> {
     match args.backend {
-        Backend::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
+        BackendOpt::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
         #[cfg(feature = "libtorch")]
-        Backend::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
+        BackendOpt::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
         #[cfg(feature = "libtorch")]
-        Backend::LibTorchMps => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
+        BackendOpt::LibTorchMps => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
         #[cfg(feature = "libtorch")]
-        Backend::LibTorchVulkan => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
+        BackendOpt::LibTorchVulkan => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
         #[cfg(feature = "libtorch")]
-        Backend::LibTorchCuda => {
+        BackendOpt::LibTorchCuda => {
             let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
             train::<Autodiff<LibTorch>>(args, device)
         }

scra-mirach-dataset: batcher.rs

@@ -3,12 +3,11 @@ use std::fmt::Debug;
 use burn::{
     data::dataloader::batcher::Batcher,
     prelude::Backend,
-    tensor::{Float, Int, Shape, Tensor, TensorData},
+    tensor::{Int, Tensor, TensorData},
 };
 use scra_mirach_model::{
     data::{CaptchaCode, MirachImageData},
     train::MirachTrainBatch,
-    IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH,
 };
 /// Mirach dataset item
@@ -44,14 +43,7 @@ impl<B: Backend> Batcher<MirachImageItem, MirachTrainBatch<B>> for MirachTrainBatcher
         let mut images = Vec::with_capacity(items.len());
         let mut targets = Vec::with_capacity(items.len());
         for item in items {
-            let image = TensorData::new(
-                item.image.unwrap(),
-                Shape::new([1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]),
-            );
-            let image = Tensor::<B, 4, Float>::from(image.convert::<f32>()).to_device(&self.device);
-            let image = image / 255;
-            debug_assert_eq!(image.dims(), [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]);
-            images.push(image);
+            images.push(item.image.into_tensor(&self.device));
             let target = Tensor::<B, 1, Int>::from(TensorData::from(item.code.as_digits()))
                 .to_device(&self.device);
scra-mirach-model: data.rs

@@ -1,5 +1,9 @@
 use std::mem::MaybeUninit;

+use burn::{
+    prelude::Backend,
+    tensor::{Device, Tensor, TensorData},
+};
 #[cfg(feature = "image")]
 use image::RgbImage;
@@ -35,6 +39,16 @@ impl CaptchaCode {
             (self.0 % 10) as u8,
         ]
     }
+
+    #[inline]
+    pub fn from_digits(digits: [u8; 4]) -> Self {
+        Self::new_unchecked(
+            (digits[0] as u16 * 1000)
+                + (digits[1] as u16 * 100)
+                + (digits[2] as u16 * 10)
+                + digits[3] as u16,
+        )
+    }
 }

 #[derive(Clone, PartialEq, Eq)]
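`from_digits` is the inverse of the existing `as_digits`. A standalone round-trip sketch with the digit arithmetic copied into free functions so it runs without the crate (the first three divisors of `as_digits` are inferred; only its last line is visible in this diff):

fn from_digits(d: [u8; 4]) -> u16 {
    (d[0] as u16 * 1000) + (d[1] as u16 * 100) + (d[2] as u16 * 10) + d[3] as u16
}

fn as_digits(code: u16) -> [u8; 4] {
    [
        (code / 1000) as u8,
        (code / 100 % 10) as u8,
        (code / 10 % 10) as u8,
        (code % 10) as u8,
    ]
}

fn main() {
    let code = from_digits([1, 3, 3, 7]);
    assert_eq!(code, 1337);
    assert_eq!(as_digits(code), [1, 3, 3, 7]);
}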
@@ -72,3 +86,16 @@ impl From<RgbImage> for MirachImageData {
         Self(Vec::from(data as Box<[u8]>))
     }
 }

+impl MirachImageData {
+    pub fn into_tensor<B: Backend>(self, device: &Device<B>) -> Tensor<B, 4> {
+        let image = TensorData::new::<u8, _>(
+            self.unwrap(),
+            [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT],
+        );
+        let image = Tensor::<B, 4>::from(image.convert::<B::FloatElem>()).to_device(device);
+        let image = image / 255;
+        debug_assert_eq!(image.dims(), [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]);
+        image
+    }
+}
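The `image / 255` step maps raw u8 pixel bytes into [0, 1] floats before the tensor reaches the model. A dependency-free sketch of the same normalization (inputs chosen so the expected outputs are exact in f32):

fn main() {
    let bytes: [u8; 4] = [0, 51, 204, 255];
    let normalized: Vec<f32> = bytes.iter().map(|&b| f32::from(b) / 255.0).collect();
    assert_eq!(normalized, vec![0.0, 0.2, 0.8, 1.0]);
}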