Add Rust MNIST code
This commit is contained in:
parent
6be866b0b8
commit
3c47bb63da
459
mnist/Cargo.lock
generated
Normal file
459
mnist/Cargo.lock
generated
Normal file
|
@ -0,0 +1,459 @@
|
|||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "84450d0b4a8bd1ba4144ce8ce718fbc5d071358b1e5384bace6536b3d1f2d5b3"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
|
||||
|
||||
[[package]]
|
||||
name = "byteorder"
|
||||
version = "1.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610"
|
||||
|
||||
[[package]]
|
||||
name = "bzip2"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6afcd980b5f3a45017c57e57a2fcccbb351cc43a356ce117ef760ef8052b89b0"
|
||||
dependencies = [
|
||||
"bzip2-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bzip2-sys"
|
||||
version = "0.1.11+1.0.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.72"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "22a9137b95ea06864e018375b72adfb7db6e6f68cfc8df5a04d00288050485ee"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "738c290dfaea84fc1ca15ad9c168d083b05a714e1efddd8edaab678dc28d2836"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curl"
|
||||
version = "0.4.42"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7de97b894edd5b5bcceef8b78d7da9b75b1d2f2f9a910569d0bde3dd31d84939"
|
||||
dependencies = [
|
||||
"curl-sys",
|
||||
"libc",
|
||||
"openssl-probe",
|
||||
"openssl-sys",
|
||||
"schannel",
|
||||
"socket2",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "curl-sys"
|
||||
version = "0.4.52+curl-7.81.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "14b8c2d1023ea5fded5b7b892e4b8e95f70038a421126a056761a84246a28971"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"libz-sys",
|
||||
"openssl-sys",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e6988e897c1c9c485f43b47a529cef42fde0547f9d8d41a7062518f1d8fc53f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crc32fast",
|
||||
"libc",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "1.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.112"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b03d17f364a3a042d5e5d46b053bbbf82c92c9430c592dd4c064dc6ee997125"
|
||||
|
||||
[[package]]
|
||||
name = "libz-sys"
|
||||
version = "1.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de5435b8549c16d423ed0c03dbaafe57cf6c3344744f1242520d59c9d8ecec66"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matrixmultiply"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "add85d4dd35074e6fedc608f8c8f513a3548619a9024b751949ef0e8e45a4d84"
|
||||
dependencies = [
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a92518e98c078586bc6c934028adcca4c92a53d6a958196de835170a01d84e4b"
|
||||
dependencies = [
|
||||
"adler",
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mnist"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"tch",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ndarray"
|
||||
version = "0.15.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dec23e6762830658d2b3d385a75aa212af2f67a4586d4442907144f3bb6a1ca8"
|
||||
dependencies = [
|
||||
"matrixmultiply",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"rawpointer",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26873667bbbb7c5182d4a37c1add32cdf09f841af72da53318fdb81543c15085"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.44"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2cc698a63b549a70bc047073d2949cce27cd1c7b0a4a862d08a8031bc2801db"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "openssl-probe"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28988d872ab76095a6e6ac88d99b54fd267702734fd7ffe610ca27f533ddb95a"
|
||||
|
||||
[[package]]
|
||||
name = "openssl-sys"
|
||||
version = "0.9.72"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e46109c383602735fa0a2e48dd2b7c892b048e1bf69e5c3b1d804b7d9c203cb"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cc",
|
||||
"libc",
|
||||
"pkg-config",
|
||||
"vcpkg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "58893f751c9b0412871a09abd62ecd2a00298c6c83befa223ef98c52aef40cbe"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c7342d5883fbccae1cc37a2353b09c87c9b0f3afd73f5fb9bba687a1f733b029"
|
||||
dependencies = [
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "47aa80447ce4daf1717500037052af176af5d38cc3e571d9ec1c7353fc10c87d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e7573632e6454cf6b99d7aac4ccca54be06da05aca2ef7423d22d27d4d4bcd8"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
"rand_hc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d34f1408f55294453790c48b2f1ebbb1c5b4b7563eb1f418bcfcfdbb06ebb4e7"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_hc"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d51e9f596de227fda2ea6c84607f5558e196eeaf43c986b724ba4fb8fdf497e7"
|
||||
dependencies = [
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rawpointer"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
|
||||
|
||||
[[package]]
|
||||
name = "schannel"
|
||||
version = "0.1.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f05ba609c234e60bee0d547fe94a4c7e9da733d1c962cf6e59efa4cd9c8bc75"
|
||||
dependencies = [
|
||||
"lazy_static",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5dc90fe6c7be1a323296982db1836d1ea9e47b6839496dde9a541bc496df3516"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.85"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a684ac3dcd8913827e18cd09a68384ee66c1de24157e3c556c9ab16d85695fb7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-xid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tch"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2b73f876b186599e22b01fa6ebfeea2dee2f11e8083463ab3572933d8201436b"
|
||||
dependencies = [
|
||||
"half",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"ndarray",
|
||||
"rand",
|
||||
"thiserror",
|
||||
"torch-sys",
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "854babe52e4df1653706b98fcfc05843010039b406875930a70e4d9644e5c417"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa32fd3f627f367fe16f893e2597ae3c05020f8bba2666a4e6ea73d377e5714b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ca8a50ef2360fbd1eeb0ecd46795a87a19024eb4b53c5dc916ca1fd95fe62438"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "torch-sys"
|
||||
version = "0.6.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34cc0f21b1aad5d71d529e9fe4dbbbdbf53918d7b4bde946f523839aa32cffae"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"cc",
|
||||
"curl",
|
||||
"libc",
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ccb82d61f80a663efe1f787a51b16b5a51e3314d6ac365b08639f52387b33f3"
|
||||
|
||||
[[package]]
|
||||
name = "vcpkg"
|
||||
version = "0.2.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.10.2+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fd6fbd9a79829dd1ad0cc20627bf1ed606756a7f77edff7b66b7064f9cb327c6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.5.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93ab48844d61251bb3835145c521d88aa4031d7139e8485990f60ca911fa0815"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bzip2",
|
||||
"crc32fast",
|
||||
"flate2",
|
||||
"thiserror",
|
||||
"time",
|
||||
]
|
10
mnist/Cargo.toml
Normal file
10
mnist/Cargo.toml
Normal file
|
@ -0,0 +1,10 @@
|
|||
[package]
|
||||
name = "mnist"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0"
|
||||
tch = "0.6.1"
|
25
mnist/src/main.rs
Normal file
25
mnist/src/main.rs
Normal file
|
@ -0,0 +1,25 @@
|
|||
/* Some very simple models trained on the MNIST dataset.
|
||||
The 4 following dataset files can be downloaded from http://yann.lecun.com/exdb/mnist/
|
||||
These files should be extracted in the 'data' directory.
|
||||
train-images-idx3-ubyte.gz
|
||||
train-labels-idx1-ubyte.gz
|
||||
t10k-images-idx3-ubyte.gz
|
||||
t10k-labels-idx1-ubyte.gz
|
||||
*/
|
||||
|
||||
use anyhow::Result;
|
||||
|
||||
mod mnist_conv;
|
||||
mod mnist_linear;
|
||||
mod mnist_nn;
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let model = if args.len() < 2 { None } else { Some(args[1].as_str()) };
|
||||
match model {
|
||||
None => mnist_nn::run(),
|
||||
Some("linear") => mnist_linear::run(),
|
||||
Some("conv") => mnist_conv::run(),
|
||||
Some(_) => mnist_nn::run(),
|
||||
}
|
||||
}
|
54
mnist/src/mnist_conv.rs
Normal file
54
mnist/src/mnist_conv.rs
Normal file
|
@ -0,0 +1,54 @@
|
|||
// CNN model. This should rearch 99.1% accuracy.
|
||||
|
||||
use anyhow::Result;
|
||||
use tch::{nn, nn::ModuleT, nn::OptimizerConfig, Device, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct Net {
|
||||
conv1: nn::Conv2D,
|
||||
conv2: nn::Conv2D,
|
||||
fc1: nn::Linear,
|
||||
fc2: nn::Linear,
|
||||
}
|
||||
|
||||
impl Net {
|
||||
fn new(vs: &nn::Path) -> Net {
|
||||
let conv1 = nn::conv2d(vs, 1, 32, 5, Default::default());
|
||||
let conv2 = nn::conv2d(vs, 32, 64, 5, Default::default());
|
||||
let fc1 = nn::linear(vs, 1024, 1024, Default::default());
|
||||
let fc2 = nn::linear(vs, 1024, 10, Default::default());
|
||||
Net { conv1, conv2, fc1, fc2 }
|
||||
}
|
||||
}
|
||||
|
||||
impl nn::ModuleT for Net {
|
||||
fn forward_t(&self, xs: &Tensor, train: bool) -> Tensor {
|
||||
xs.view([-1, 1, 28, 28])
|
||||
.apply(&self.conv1)
|
||||
.max_pool2d_default(2)
|
||||
.apply(&self.conv2)
|
||||
.max_pool2d_default(2)
|
||||
.view([-1, 1024])
|
||||
.apply(&self.fc1)
|
||||
.relu()
|
||||
.dropout_(0.5, train)
|
||||
.apply(&self.fc2)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run() -> Result<()> {
|
||||
let m = tch::vision::mnist::load_dir("data")?;
|
||||
let vs = nn::VarStore::new(Device::cuda_if_available());
|
||||
let net = Net::new(&vs.root());
|
||||
let mut opt = nn::Adam::default().build(&vs, 1e-4)?;
|
||||
for epoch in 1..100 {
|
||||
for (bimages, blabels) in m.train_iter(256).shuffle().to_device(vs.device()) {
|
||||
let loss = net.forward_t(&bimages, true).cross_entropy_for_logits(&blabels);
|
||||
opt.backward_step(&loss);
|
||||
}
|
||||
let test_accuracy =
|
||||
net.batch_accuracy_for_logits(&m.test_images, &m.test_labels, vs.device(), 1024);
|
||||
println!("epoch: {:4} test acc: {:5.2}%", epoch, 100. * test_accuracy,);
|
||||
}
|
||||
Ok(())
|
||||
}
|
42
mnist/src/mnist_linear.rs
Normal file
42
mnist/src/mnist_linear.rs
Normal file
|
@ -0,0 +1,42 @@
|
|||
// This should rearch 91.5% accuracy.
|
||||
|
||||
use anyhow::Result;
|
||||
use tch::{kind, no_grad, vision, Kind, Tensor};
|
||||
|
||||
const IMAGE_DIM: i64 = 784;
|
||||
const LABELS: i64 = 10;
|
||||
|
||||
pub fn run() -> Result<()> {
|
||||
let m = vision::mnist::load_dir("data")?;
|
||||
println!("train-images: {:?}", m.train_images.size());
|
||||
println!("train-labels: {:?}", m.train_labels.size());
|
||||
println!("test-images: {:?}", m.test_images.size());
|
||||
println!("test-labels: {:?}", m.test_labels.size());
|
||||
let mut ws = Tensor::zeros(&[IMAGE_DIM, LABELS], kind::FLOAT_CPU).set_requires_grad(true);
|
||||
let mut bs = Tensor::zeros(&[LABELS], kind::FLOAT_CPU).set_requires_grad(true);
|
||||
for epoch in 1..200 {
|
||||
let logits = m.train_images.mm(&ws) + &bs;
|
||||
let loss = logits.log_softmax(-1, Kind::Float).nll_loss(&m.train_labels);
|
||||
ws.zero_grad();
|
||||
bs.zero_grad();
|
||||
loss.backward();
|
||||
no_grad(|| {
|
||||
ws += ws.grad() * (-1);
|
||||
bs += bs.grad() * (-1);
|
||||
});
|
||||
let test_logits = m.test_images.mm(&ws) + &bs;
|
||||
let test_accuracy = test_logits
|
||||
.argmax(Some(-1), false)
|
||||
.eq_tensor(&m.test_labels)
|
||||
.to_kind(Kind::Float)
|
||||
.mean(Kind::Float)
|
||||
.double_value(&[]);
|
||||
println!(
|
||||
"epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
epoch,
|
||||
loss.double_value(&[]),
|
||||
100. * test_accuracy
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
34
mnist/src/mnist_nn.rs
Normal file
34
mnist/src/mnist_nn.rs
Normal file
|
@ -0,0 +1,34 @@
|
|||
// This should rearch 97% accuracy.
|
||||
|
||||
use anyhow::Result;
|
||||
use tch::{nn, nn::Module, nn::OptimizerConfig, Device};
|
||||
|
||||
const IMAGE_DIM: i64 = 784;
|
||||
const HIDDEN_NODES: i64 = 128;
|
||||
const LABELS: i64 = 10;
|
||||
|
||||
fn net(vs: &nn::Path) -> impl Module {
|
||||
nn::seq()
|
||||
.add(nn::linear(vs / "layer1", IMAGE_DIM, HIDDEN_NODES, Default::default()))
|
||||
.add_fn(|xs| xs.relu())
|
||||
.add(nn::linear(vs, HIDDEN_NODES, LABELS, Default::default()))
|
||||
}
|
||||
|
||||
pub fn run() -> Result<()> {
|
||||
let m = tch::vision::mnist::load_dir("data")?;
|
||||
let vs = nn::VarStore::new(Device::Cpu);
|
||||
let net = net(&vs.root());
|
||||
let mut opt = nn::Adam::default().build(&vs, 1e-3)?;
|
||||
for epoch in 1..1000 {
|
||||
let loss = net.forward(&m.train_images).cross_entropy_for_logits(&m.train_labels);
|
||||
opt.backward_step(&loss);
|
||||
let test_accuracy = net.forward(&m.test_images).accuracy_for_logits(&m.test_labels);
|
||||
println!(
|
||||
"epoch: {:4} train loss: {:8.5} test acc: {:5.2}%",
|
||||
epoch,
|
||||
f64::from(&loss),
|
||||
100. * f64::from(&test_accuracy),
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
Loading…
Reference in a new issue