T4JMISYCRZ3BODA2PEOXKJ2RSDIYK37FH5B6DHGRKVT47ABPIJUQC
use burn::{
config::Config,
data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},
module::Module,
nn::loss::CrossEntropyLossConfig,
optim::AdamConfig,
record::CompactRecorder,
tensor::{
backend::{AutodiffBackend, Backend},
Int, Tensor,
},
train::{
metric::{AccuracyMetric, LossMetric},
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
},
};
use crate::{
data::{MNISTBatch, MNISTBatcher},
model::{Model, ModelConfig},
};
impl<B: Backend> Model<B> {
    /// Runs the model on `images` and scores the result against `targets`.
    ///
    /// The predictions come from [`Model::forward`]. A cross-entropy loss built from the
    /// default `CrossEntropyLossConfig` is initialised on the device of those predictions
    /// and evaluated on them. The loss, the predictions and the targets are returned
    /// together as a `ClassificationOutput`.
    pub fn forward_classification(
        &self,
        images: Tensor<B, 3>,
        targets: Tensor<B, 1, Int>,
    ) -> ClassificationOutput<B> {
        let logits = self.forward(images);
        let criterion = CrossEntropyLossConfig::new().init(&logits.device());
        // Both tensors are cloned because they are also stored in the returned output.
        let loss = criterion.forward(logits.clone(), targets.clone());

        ClassificationOutput::new(loss, logits, targets)
    }
}
impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
let item = self.forward_classification(batch.images, batch.targets);
TrainOutput::new(self, item.loss.backward(), item)
}
}
impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {
fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {
self.forward_classification(batch.images, batch.targets)
}
}
/// Settings consumed by `train`: the model and optimizer configurations plus the
/// hyper-parameters of the training run.
#[derive(Config)]
pub struct TrainingConfig {
/// Configuration from which the model is initialised.
pub model: ModelConfig,
/// Configuration from which the Adam optimizer is initialised.
pub optimizer: AdamConfig,
/// Number of epochs given to the learner (default: 10).
#[config(default = 10)]
pub num_epochs: usize,
/// Batch size used by both data loaders (default: 64).
#[config(default = 64)]
pub batch_size: usize,
/// Number of workers used by both data loaders (default: 4).
#[config(default = 4)]
pub num_workers: usize,
/// Seed given to the backend and to the data-loader shuffling (default: 42).
#[config(default = 42)]
pub seed: u64,
/// Learning rate given to the learner (default: 1.0e-4).
#[config(default = 1.0e-4)]
pub learning_rate: f64,
}
/// Trains the MNIST model described by `config` on `device` and stores the artifacts
/// in `artifact_dir`.
///
/// The function, in order:
/// 1. creates `artifact_dir` and saves `config` as `config.json` inside it;
/// 2. seeds the backend with `config.seed`;
/// 3. builds the training and test data loaders from `MNISTDataset`;
/// 4. builds a learner with accuracy and loss metrics and a file checkpointer;
/// 5. fits the model and saves it as `{artifact_dir}/model` with a `CompactRecorder`.
///
/// # Panics
///
/// Panics if the artifact directory cannot be created, if the configuration cannot be
/// saved, or if the trained model cannot be saved.
pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {
    // Report a directory-creation failure here, with its cause, instead of discarding it
    // and letting the config save below fail with an unrelated message.
    std::fs::create_dir_all(artifact_dir).unwrap_or_else(|err| {
        panic!("artifact directory `{artifact_dir}` should be creatable: {err}")
    });
    config
        .save(format!("{artifact_dir}/config.json"))
        .expect("config not saved correctly");

    B::seed(config.seed);

    // The training batcher uses the autodiff backend; the validation batcher uses its
    // inner backend.
    let batcher_train = MNISTBatcher::<B>::new(device.clone());
    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());

    let dataloader_train = DataLoaderBuilder::new(batcher_train)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MNISTDataset::train());
    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(MNISTDataset::test());

    let model = config.model.init::<B>(&device);
    let learner = LearnerBuilder::new(artifact_dir)
        .metric_train_numeric(AccuracyMetric::new())
        .metric_valid_numeric(AccuracyMetric::new())
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .devices(vec![device])
        .num_epochs(config.num_epochs)
        .build(model, config.optimizer.init(), config.learning_rate);

    let model_trained = learner.fit(dataloader_train, dataloader_test);
    model_trained
        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
        .expect("Trained model should be saved successfully");
}
use burn::{
config::Config,
module::Module,
nn::{
conv::{Conv2d, Conv2dConfig},
pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},
Dropout, DropoutConfig, Linear, LinearConfig, ReLU,
},
tensor::{backend::Backend, Tensor},
};
/// Image classifier made of two convolutions, an adaptive average pooling layer and two
/// linear layers, with one dropout layer and one ReLU activation reused between them.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
/// First convolution, applied to the single-channel input.
conv1: Conv2d<B>,
/// Second convolution, applied after the first one.
conv2: Conv2d<B>,
/// Adaptive average pooling applied after the convolutions.
pool: AdaptiveAvgPool2d,
/// Dropout layer reused after both convolutions and after the first linear layer.
dropout: Dropout,
/// First linear layer, applied to the flattened pooled features.
linear1: Linear<B>,
/// Second linear layer, producing the model output.
linear2: Linear<B>,
/// ReLU activation reused after the second convolution and the first linear layer.
activation: ReLU,
}
impl<B: Backend> Model<B> {
    /// Computes the class scores for a batch of images.
    ///
    /// # Shapes
    /// - Images [batch_size, height, width]
    /// - Output [batch_size, num_classes]
    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {
        let [batch_size, height, width] = images.dims();

        // Insert a channel dimension of size 1 as the second dimension.
        let input = images.reshape([batch_size, 1, height, width]);

        // Convolutional part: dropout follows each convolution, and the activation is
        // applied once after the second one.
        let features = self.conv1.forward(input); // [batch_size, 8, _, _]
        let features = self.dropout.forward(features);
        let features = self.conv2.forward(features); // [batch_size, 16, _, _]
        let features = self.dropout.forward(features);
        let features = self.activation.forward(features);

        // Pool to a fixed spatial size and flatten for the linear layers.
        let pooled = self.pool.forward(features); // [batch_size, 16, 8, 8]
        let flattened = pooled.reshape([batch_size, 16 * 8 * 8]);

        let hidden = self.linear1.forward(flattened);
        let hidden = self.dropout.forward(hidden);
        let hidden = self.activation.forward(hidden);

        self.linear2.forward(hidden)
    }
}
/// Configuration from which a `Model` is built.
#[derive(Config, Debug)]
pub struct ModelConfig {
/// Output size of the last linear layer.
num_classes: usize,
/// Output size of the first linear layer and input size of the last one.
hidden_size: usize,
/// Value passed to `DropoutConfig::new` (default: 0.5).
#[config(default = "0.5")]
dropout: f64,
}
impl ModelConfig {
    /// Builds a [`Model`] on `device` from this configuration.
    ///
    /// The convolutions use 3x3 kernels with 1 -> 8 and 8 -> 16 channels, the pooling layer
    /// outputs 8x8 feature maps, and the linear layers map `16 * 8 * 8` features to
    /// `hidden_size` and then to `num_classes`.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {
        // Parameterised layers are created in the same order as before:
        // conv1, conv2, linear1, linear2.
        let conv1 = Conv2dConfig::new([1, 8], [3, 3]).init(device);
        let conv2 = Conv2dConfig::new([8, 16], [3, 3]).init(device);
        let pool = AdaptiveAvgPool2dConfig::new([8, 8]).init();
        let activation = ReLU::new();
        let linear1 = LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device);
        let linear2 = LinearConfig::new(self.hidden_size, self.num_classes).init(device);
        let dropout = DropoutConfig::new(self.dropout).init();

        Model {
            conv1,
            conv2,
            pool,
            dropout,
            linear1,
            linear2,
            activation,
        }
    }
}
// Entry-point statements: print a greeting, select the backend types and start training.
println!("Hello, world!");
// Wgpu backend parameterised with `AutoGraphicsApi`, `f32` and `i32`.
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
// Autodiff wrapper around that backend; `train` requires an `AutodiffBackend`.
type AutodiffBackend = Autodiff<Backend>;
let device = WgpuDevice::default();
// Train with artifacts written to "/tmp/guide".
// NOTE(review): `ModelConfig::new(10, 512)` presumably sets num_classes = 10 and
// hidden_size = 512, following the field order of `ModelConfig` — confirm.
training::train::<AutodiffBackend>(
"/tmp/guide",
training::TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new()),
device,
);
use burn::{
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
};
/// Turns MNIST items into batched tensors on a fixed device.
pub struct MNISTBatcher<B: Backend> {
/// Device on which the batch tensors are created.
device: B::Device,
}
impl<B: Backend> MNISTBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
}
/// One batch of MNIST samples as produced by `MNISTBatcher::batch`.
#[derive(Clone, Debug)]
pub struct MNISTBatch<B: Backend> {
/// Normalized images, concatenated along the first dimension: [batch_size, 28, 28].
pub images: Tensor<B, 3>,
/// One integer label per image, concatenated along the first dimension: [batch_size].
pub targets: Tensor<B, 1, Int>,
}
impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
let images = items
.iter()
.map(|item| Data::<f32, 2>::from(item.image))
.map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device))
.map(|tensor| tensor.reshape([1, 28, 28]))
// Normalize: make between [0,1] and make the mean=0 and std=1
// values mean=0.1307,std=0.3081 are from the PyTorch MNIST example
// https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122
.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)
.collect();
let targets = items
.iter()
.map(|item| {
Tensor::<B, 1, Int>::from_data(
Data::from([(item.label as i64).elem()]),
&self.device,
)
})
.collect();
let images = Tensor::cat(images, 0).to_device(&self.device);
let targets = Tensor::cat(targets, 0).to_device(&self.device);
MNISTBatch { images, targets }
}
}
name = "console"
version = "0.15.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb"
dependencies = [
"encode_unicode",
"lazy_static",
"libc",
"unicode-width",
"windows-sys 0.52.0",
]
[[package]]
[[package]]
name = "encode_unicode"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"
[[package]]
name = "encoding_rs"
version = "0.8.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1"
dependencies = [
"cfg-if",
]
]
[[package]]
name = "h2"
version = "0.3.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
name = "http"
version = "0.2.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [
"bytes",
"http",
"pin-project-lite",
]
[[package]]
name = "httparse"
version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"
[[package]]
name = "httpdate"
version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hyper"
version = "0.14.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80"
dependencies = [
"bytes",
"futures-channel",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"socket2",
"tokio",
"tower-service",
"tracing",
"want",
]
[[package]]
name = "hyper-tls"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"
dependencies = [
"bytes",
"hyper",
"native-tls",
"tokio",
"tokio-native-tls",
]
[[package]]
name = "native-tls"
version = "0.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"
dependencies = [
"lazy_static",
"libc",
"log",
"openssl",
"openssl-probe",
"openssl-sys",
"schannel",
"security-framework",
"security-framework-sys",
"tempfile",
]
[[package]]
[[package]]
name = "openssl"
version = "0.10.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"
dependencies = [
"bitflags 2.5.0",
"cfg-if",
"foreign-types 0.3.2",
"libc",
"once_cell",
"openssl-macros",
"openssl-sys",
]
[[package]]
name = "openssl-macros"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
]
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
[[package]]
name = "reqwest"
version = "0.11.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"
dependencies = [
"base64",
"bytes",
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"hyper-tls",
"ipnet",
"js-sys",
"log",
"mime",
"native-tls",
"once_cell",
"percent-encoding",
"pin-project-lite",
"rustls-pemfile",
"serde",
"serde_json",
"serde_urlencoded",
"sync_wrapper",
"system-configuration",
"tokio",
"tokio-native-tls",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"winreg",
]
[[package]]
name = "security-framework"
version = "2.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
]
[[package]]
name = "system-configuration"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [
"bitflags 1.3.2",
"core-foundation",
"system-configuration-sys",
]
[[package]]
name = "system-configuration-sys"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
dependencies = [
"core-foundation-sys",
"libc",
name = "tokio"
version = "1.36.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931"
dependencies = [
"backtrace",
"bytes",
"libc",
"mio",
"pin-project-lite",
"socket2",
"tokio-macros",
"windows-sys 0.48.0",
]
[[package]]
name = "tokio-macros"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.53",
]
[[package]]
name = "tokio-native-tls"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"
dependencies = [
"native-tls",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
"tracing",
]
[[package]]