55 lines
1.7 KiB
Rust
55 lines
1.7 KiB
Rust
// CNN model. This should reach 99.1% accuracy.
|
|
|
|
use anyhow::Result;
|
|
use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
|
|
|
|
#[derive(Debug)]
|
|
struct Net {
|
|
conv1: nn::Conv2D,
|
|
conv2: nn::Conv2D,
|
|
fc1: nn::Linear,
|
|
fc2: nn::Linear,
|
|
}
|
|
|
|
impl Net {
|
|
fn new(vs: &nn::Path) -> Net {
|
|
let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
|
|
let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
|
|
let fc1 = nn::linear(vs, 1024, 1024, Default::default());
|
|
let fc2 = nn::linear(vs, 1024, 10, Default::default());
|
|
Net { conv1, conv2, fc1, fc2 }
|
|
}
|
|
}
|
|
|
|
impl nn::ModuleT for Net {
|
|
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
|
|
xs.view([-1, 1, 28, 28])
|
|
.apply(&self.conv1)
|
|
.max_pool2d_default(2)
|
|
.apply(&self.conv2)
|
|
.max_pool2d_default(2)
|
|
.view([-1, 1024])
|
|
.apply(&self.fc1)
|
|
.relu()
|
|
.dropout_(0.5, train)
|
|
.apply(&self.fc2)
|
|
}
|
|
}
|
|
|
|
pub fn run() -> Result<()> {
|
|
let m = tch::vision::mnist::load_dir("data")?;
|
|
let vs = nn::VarStore::new(Device::cuda_if_available());
|
|
let net = Net::new(&vs.root());
|
|
let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
|
|
for epoch in 1..100 {
|
|
for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) {
|
|
let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels);
|
|
opt.backward_step(&loss);
|
|
}
|
|
let test_accuracy =
|
|
net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024);
|
|
println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,);
|
|
}
|
|
Ok(())
|
|
}
|