Inference
commit 1316ff5de4 (parent 11d2882c3f)
7 changed files with 133 additions and 20 deletions

Cargo.lock (generated)
@@ -3602,7 +3602,9 @@ dependencies = [
 "burn",
 "clap",
 "env_logger",
 "image",
 "log",
 "rand",
 "scra-mirach-dataset",
 "scra-mirach-model",
 "serde",

cli/Cargo.toml
@@ -18,6 +18,9 @@ log = { workspace = true }
env_logger = "0.11.5"
serde = { workspace = true }
serde_json = "1.0.127"
rand = "0.8.5"
image = { workspace = true, features = ["png", "jpeg"] }

[features]
default = ["libtorch"]
libtorch = ["burn/tch"]
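
Note: default = ["libtorch"] keeps the tch-backed match arms compiled in by default, while the WGPU path is always available. Assuming a standard cargo setup, the two build flavors would be:

    cargo build --no-default-features    # WGPU backend only
    cargo build                          # default, pulls in burn/tch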

cli/src/inference.rs (new file, 85 lines)
@@ -0,0 +1,85 @@
use std::{path::PathBuf, time::Instant};

use anyhow::Result;
#[cfg(feature = "libtorch")]
use burn::backend::{libtorch::LibTorchDevice, LibTorch};
use burn::{
    backend::{wgpu::WgpuDevice, Autodiff, Wgpu},
    module::Module,
    prelude::Backend,
    record::DefaultRecorder,
};
use log::info;
use scra_mirach_dataset::CaptchaCode;
use scra_mirach_model::data::MirachImageData;

use crate::{Args, BackendOpt};

#[derive(clap::Parser, Debug, Clone)]
#[command(version, about)]
pub struct InferenceArgs {
    #[arg(from_global)]
    pub backend: BackendOpt,
    #[cfg(feature = "libtorch")]
    #[arg(from_global)]
    pub cuda_index: Option<usize>,
    #[arg(from_global)]
    pub model_config: Option<PathBuf>,

    #[arg(short = 'p', long, help = "Parameters file")]
    pub params: PathBuf,
    #[arg(short = 'i', long, help = "Image")]
    pub image: PathBuf,
}

pub fn run(args: InferenceArgs) -> Result<()> {
    match args.backend {
        BackendOpt::Wgpu => inference::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchCpu => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchMps => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchVulkan => inference::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
        #[cfg(feature = "libtorch")]
        BackendOpt::LibTorchCuda => {
            let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
            inference::<Autodiff<LibTorch>>(args, device)
        }
    }
}
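
Note: run() turns the runtime BackendOpt value into a compile-time backend type; each match arm instantiates inference::<B>() for a different burn backend, so the generic body is monomorphized once per backend. A minimal standalone sketch of that dispatch shape (toy trait and types, not the burn API):

    // Enum-to-generic dispatch: one copy of infer::<B>() exists per backend type.
    trait Backend {
        fn name() -> &'static str;
    }
    struct Cpu;
    struct Gpu;
    impl Backend for Cpu { fn name() -> &'static str { "cpu" } }
    impl Backend for Gpu { fn name() -> &'static str { "gpu" } }

    enum BackendOpt { Cpu, Gpu }

    fn infer<B: Backend>() {
        println!("running on {}", B::name());
    }

    fn run(opt: BackendOpt) {
        match opt {
            BackendOpt::Cpu => infer::<Cpu>(),
            BackendOpt::Gpu => infer::<Gpu>(),
        }
    }

    fn main() {
        run(BackendOpt::Gpu);
    }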

pub fn inference<B: Backend>(args: InferenceArgs, device: B::Device) -> Result<()> {
    B::seed(rand::random());

    let model_config = Args::model_config(&args.model_config)?;
    let model = model_config.init::<B>(&device);
    let model = model.load_file(&args.params, &DefaultRecorder::new(), &device)?;
    info!("Model loaded, {} parameters", model.num_params());

    let image = MirachImageData::from(image::open(args.image)?.into_rgb8()).into_tensor(&device);
    info!("Image loaded");

    let start_time = Instant::now();
    let result = model.forward(image);
    let time = start_time.elapsed();
    info!("Inference completed, took {:?}", time);

    let result = result.squeeze::<2>(0);
    let result = result.argmax(1);
    let result = result.squeeze::<1>(1);
    let result = result.into_data().convert::<u8>();
    let result = result
        .to_vec::<u8>()
        .expect("Model output tensor could not be converted to u8 vector");
    let result = result
        .try_into()
        .expect("Model output dimension must be [1, 4]");
    let result = CaptchaCode::from_digits(result);
    info!("Code: {}", result.as_u16());

    Ok(())
}
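
Note: the decode chain assumes the model emits a [1, 4, 10] tensor, one 10-way score row per digit position: drop the batch axis, argmax over the class axis, and read the four resulting indices as digits. The same reduction on plain arrays, as an illustrative sketch:

    // Argmax-decode [4][10] per-digit scores into 4 digits, mirroring the
    // tensor pipeline above (standalone illustration).
    fn decode(scores: [[f32; 10]; 4]) -> [u8; 4] {
        scores.map(|row| {
            let mut best = 0;
            for (i, v) in row.iter().enumerate() {
                if *v > row[best] {
                    best = i;
                }
            }
            best as u8
        })
    }

    fn main() {
        let mut scores = [[0.0f32; 10]; 4];
        scores[0][1] = 1.0;
        scores[1][9] = 1.0;
        scores[2][8] = 1.0;
        scores[3][4] = 1.0;
        assert_eq!(decode(scores), [1, 9, 8, 4]);
    }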

cli/src/main.rs
@@ -2,11 +2,13 @@ use std::{env, fs, path::PathBuf};

use anyhow::Result;
use clap::Parser;
use inference::InferenceArgs;
use scra_mirach_model::MirachModelConfig;
use serde::{Deserialize, Serialize};
use train::TrainArgs;

pub mod train;
pub mod inference;

#[derive(clap::Parser, Debug, Clone)]
#[command(version, about)]

@@ -14,8 +16,8 @@ struct Args {
    #[command(subcommand)]
    action: Action,

    #[arg(short='b', long, value_enum,global=true, default_value_t = Backend::Wgpu, help="Tensor backend")]
    pub backend: Backend,
    #[arg(short='b', long, value_enum,global=true, default_value_t = BackendOpt::Wgpu, help="Tensor backend")]
    pub backend: BackendOpt,
    #[cfg(feature = "libtorch")]
    #[arg(
        long,

@@ -32,10 +34,11 @@ struct Args {
#[derive(Debug, Clone, clap::Subcommand)]
enum Action {
    Train(TrainArgs),
    Inference(InferenceArgs),
}

#[derive(clap::ValueEnum, Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Backend {
pub enum BackendOpt {
    Wgpu,
    #[cfg(feature = "libtorch")]
    LibTorchCpu,

@@ -69,5 +72,6 @@ fn main() -> Result<()> {
    match args.action {
        Action::Train(args) => train::run(args),
        Action::Inference(args) => inference::run(args),
    }
}
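
Note: the shared flags are declared once on the parent command with global=true, and each subcommand's args struct mirrors them with #[arg(from_global)]; that is how InferenceArgs above receives backend, cuda_index, and model_config. A minimal sketch of that clap pattern (toy command, not this CLI):

    // Parent declares a global flag; the subcommand pulls it in via from_global.
    use clap::Parser;

    #[derive(clap::Parser, Debug)]
    struct Args {
        #[command(subcommand)]
        action: Action,
        #[arg(short = 'b', long, global = true, default_value = "wgpu")]
        backend: String,
    }

    #[derive(Debug, clap::Subcommand)]
    enum Action {
        Inference(InferenceArgs),
    }

    #[derive(clap::Parser, Debug)]
    struct InferenceArgs {
        #[arg(from_global)]
        backend: String,
    }

    fn main() {
        let args = Args::parse();
        match args.action {
            Action::Inference(inference) => println!("backend = {}", inference.backend),
        }
    }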

cli/src/train.rs
@@ -19,13 +19,13 @@ use burn::{
use log::info;
use scra_mirach_dataset::batcher::MirachTrainBatcher;

use crate::{Args, Backend};
use crate::{Args, BackendOpt};

#[derive(clap::Parser, Debug, Clone)]
#[command(version, about)]
pub struct TrainArgs {
    #[arg(from_global)]
    pub backend: Backend,
    pub backend: BackendOpt,
    #[cfg(feature = "libtorch")]
    #[arg(from_global)]
    pub cuda_index: Option<usize>,

@@ -55,15 +55,15 @@ pub struct TrainArgs {

pub fn run(args: TrainArgs) -> Result<()> {
    match args.backend {
        Backend::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
        BackendOpt::Wgpu => train::<Autodiff<Wgpu>>(args, WgpuDevice::BestAvailable),
        #[cfg(feature = "libtorch")]
        Backend::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
        BackendOpt::LibTorchCpu => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Cpu),
        #[cfg(feature = "libtorch")]
        Backend::LibTorchMps => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
        BackendOpt::LibTorchMps => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Mps),
        #[cfg(feature = "libtorch")]
        Backend::LibTorchVulkan => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
        BackendOpt::LibTorchVulkan => train::<Autodiff<LibTorch>>(args, LibTorchDevice::Vulkan),
        #[cfg(feature = "libtorch")]
        Backend::LibTorchCuda => {
        BackendOpt::LibTorchCuda => {
            let device = LibTorchDevice::Cuda(args.cuda_index.unwrap());
            train::<Autodiff<LibTorch>>(args, device)
        }

scra_mirach_dataset::batcher
@@ -3,12 +3,11 @@ use std::fmt::Debug;
use burn::{
    data::dataloader::batcher::Batcher,
    prelude::Backend,
    tensor::{Float, Int, Shape, Tensor, TensorData},
    tensor::{Int, Tensor, TensorData},
};
use scra_mirach_model::{
    data::{CaptchaCode, MirachImageData},
    train::MirachTrainBatch,
    IMAGE_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH,
};

/// Mirach dataset item

@@ -44,14 +43,7 @@ impl<B: Backend> Batcher<MirachImageItem, MirachTrainBatch<B>> for MirachTrainBatcher {
        let mut images = Vec::with_capacity(items.len());
        let mut targets = Vec::with_capacity(items.len());
        for item in items {
            let image = TensorData::new(
                item.image.unwrap(),
                Shape::new([1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]),
            );
            let image = Tensor::<B, 4, Float>::from(image.convert::<f32>()).to_device(&self.device);
            let image = image / 255;
            debug_assert_eq!(image.dims(), [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]);
            images.push(image);
            images.push(item.image.into_tensor(&self.device));

            let target = Tensor::<B, 1, Int>::from(TensorData::from(item.code.as_digits()))
                .to_device(&self.device);

scra_mirach_model::data
@@ -1,5 +1,9 @@
use std::mem::MaybeUninit;

use burn::{
    prelude::Backend,
    tensor::{Device, Tensor, TensorData},
};
#[cfg(feature = "image")]
use image::RgbImage;

@@ -35,6 +39,16 @@ impl CaptchaCode {
            (self.0 % 10) as u8,
        ]
    }

    #[inline]
    pub fn from_digits(digits: [u8; 4]) -> Self {
        Self::new_unchecked(
            (digits[0] as u16 * 1000)
                + (digits[1] as u16 * 100)
                + (digits[2] as u16 * 10)
                + digits[3] as u16,
        )
    }
}

#[derive(Clone, PartialEq, Eq)]
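
Note: from_digits is the inverse of as_digits, folding four base-10 digits back into the underlying u16. The round-trip arithmetic as a freestanding sketch (plain functions, not the CaptchaCode API, and without new_unchecked's validation contract):

    // Round-trip a 4-digit code through its digit decomposition.
    fn to_digits(code: u16) -> [u8; 4] {
        [
            (code / 1000 % 10) as u8,
            (code / 100 % 10) as u8,
            (code / 10 % 10) as u8,
            (code % 10) as u8,
        ]
    }

    fn from_digits(d: [u8; 4]) -> u16 {
        d[0] as u16 * 1000 + d[1] as u16 * 100 + d[2] as u16 * 10 + d[3] as u16
    }

    fn main() {
        assert_eq!(from_digits(to_digits(1234)), 1234);
        assert_eq!(to_digits(42), [0, 0, 4, 2]);
    }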

@@ -72,3 +86,16 @@ impl From<RgbImage> for MirachImageData {
        Self(Vec::from(data as Box<[u8]>))
    }
}

impl MirachImageData {
    pub fn into_tensor<B: Backend>(self, device: &Device<B>) -> Tensor<B, 4> {
        let image = TensorData::new::<u8, _>(
            self.unwrap(),
            [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT],
        );
        let image = Tensor::<B, 4>::from(image.convert::<B::FloatElem>()).to_device(device);
        let image = image / 255;
        debug_assert_eq!(image.dims(), [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT]);
        image
    }
}
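
Note: into_tensor centralizes the conversion that both the batcher and the inference path need: view the raw RGB bytes as [1, IMAGE_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT], convert to the backend's float element, and scale into [0, 1] by dividing by 255. The normalization step alone, sketched on a plain byte buffer:

    // Scale u8 pixels into [0.0, 1.0], as `image / 255` does on the tensor above.
    fn normalize(pixels: &[u8]) -> Vec<f32> {
        pixels.iter().map(|&p| p as f32 / 255.0).collect()
    }

    fn main() {
        let n = normalize(&[0, 128, 255]);
        assert_eq!(n[0], 0.0);
        assert_eq!(n[2], 1.0);
        assert!((n[1] - 128.0 / 255.0).abs() < f32::EPSILON);
    }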