PyTorch/mnist/src/mnist_linear.rs
2022-01-11 12:24:42 -06:00

43 lines
1.4 KiB
Rust

// This should reach about 91.5% test accuracy.
use anyhow::Result;
use tch::{kind, no_grad, vision, Kind, Tensor};
/// Flattened input size per image: 28 * 28 = 784 pixel values.
const IMAGE_DIM: i64 = 28 * 28;
/// Number of output classes; the model produces one logit column per class.
const LABELS: i64 = 10;
/// Trains a single-layer linear classifier on MNIST and reports progress.
///
/// Loads the dataset from the `data` directory and prints the tensor sizes of the
/// train/test splits. It then trains a weight matrix of shape `[IMAGE_DIM, LABELS]`
/// and a bias of shape `[LABELS]`, both initialised to zero. Training is full-batch
/// gradient descent with step size 1 on the negative log-likelihood of the
/// log-softmax outputs. After every epoch it prints the training loss and the
/// accuracy on the test split.
///
/// # Errors
///
/// Returns an error if the dataset cannot be loaded from `data`.
pub fn run() -> Result<()> {
    let m = vision::mnist::load_dir("data")?;
    println!("train-images: {:?}", m.train_images.size());
    println!("train-labels: {:?}", m.train_labels.size());
    println!("test-images: {:?}", m.test_images.size());
    println!("test-labels: {:?}", m.test_labels.size());

    let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true);
    let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true);

    // Epochs 1 through 199: the range end is exclusive.
    for epoch in 1..200 {
        // Forward pass over the entire training set (full batch).
        let logits = m.train_images.mm(&ws) + &bs;
        let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels);

        // Reset gradients from the previous epoch before backpropagating.
        ws.zero_grad();
        bs.zero_grad();
        loss.backward();

        // Gradient-descent step with step size 1. It runs under `no_grad` so the
        // in-place parameter update is not tracked by autograd.
        no_grad(|| {
            ws += ws.grad() * (-1);
            bs += bs.grad() * (-1);
        });

        // Evaluation only needs predictions, so it also runs under `no_grad`.
        // `ws`/`bs` require grad, so without this every test forward pass would
        // record an autograd graph that is never backpropagated through.
        let test_accuracy = no_grad(|| {
            let test_logits = m.test_images.mm(&ws) + &bs;
            test_logits
                .argmax(Some(-1), false)
                .eq_tensor(&m.test_labels)
                .to_kind(Kind::Float)
                .mean(Kind::Float)
                .double_value(&[])
        });

        println!(
            "epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
            epoch,
            loss.double_value(&[]),
            100. * test_accuracy
        );
    }
    Ok(())
}