T4JMISYCRZ3BODA2PEOXKJ2RSDIYK37FH5B6DHGRKVT47ABPIJUQC use burn::{config::Config,data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},module::Module,nn::loss::CrossEntropyLossConfig,optim::AdamConfig,record::CompactRecorder,tensor::{backend::{AutodiffBackend, Backend},Int, Tensor,},train::{metric::{AccuracyMetric, LossMetric},ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,},};use crate::{data::{MNISTBatch, MNISTBatcher},model::{Model, ModelConfig},};impl<B: Backend> Model<B> {pub fn forward_classification(&self,images: Tensor<B, 3>,targets: Tensor<B, 1, Int>,) -> ClassificationOutput<B> {let output = self.forward(images);let loss = CrossEntropyLossConfig::new().init(&output.device()).forward(output.clone(), targets.clone());ClassificationOutput::new(loss, output, targets)}}impl<B: AutodiffBackend> TrainStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {fn step(&self, batch: MNISTBatch<B>) -> TrainOutput<ClassificationOutput<B>> {let item = self.forward_classification(batch.images, batch.targets);TrainOutput::new(self, item.loss.backward(), item)}}impl<B: Backend> ValidStep<MNISTBatch<B>, ClassificationOutput<B>> for Model<B> {fn step(&self, batch: MNISTBatch<B>) -> ClassificationOutput<B> {self.forward_classification(batch.images, batch.targets)}}#[derive(Config)]pub struct TrainingConfig {pub model: ModelConfig,pub optimizer: AdamConfig,#[config(default = 10)]pub num_epochs: usize,#[config(default = 64)]pub batch_size: usize,#[config(default = 4)]pub num_workers: usize,#[config(default = 42)]pub seed: u64,#[config(default = 1.0e-4)]pub learning_rate: f64,}pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {std::fs::create_dir_all(artifact_dir).ok();config.save(format!("{artifact_dir}/config.json")).expect("config not saved correctly");B::seed(config.seed);let batcher_train = MNISTBatcher::<B>::new(device.clone());let batcher_valid = 
MNISTBatcher::<B::InnerBackend>::new(device.clone());let dataloader_train = DataLoaderBuilder::new(batcher_train).batch_size(config.batch_size).shuffle(config.seed).num_workers(config.num_workers).build(MNISTDataset::train());let dataloader_test = DataLoaderBuilder::new(batcher_valid).batch_size(config.batch_size).shuffle(config.seed).num_workers(config.num_workers).build(MNISTDataset::test());let model = config.model.init::<B>(&device);let learner = LearnerBuilder::new(artifact_dir).metric_train_numeric(AccuracyMetric::new()).metric_valid_numeric(AccuracyMetric::new()).metric_train_numeric(LossMetric::new()).metric_valid_numeric(LossMetric::new()).with_file_checkpointer(CompactRecorder::new()).devices(vec![device]).num_epochs(config.num_epochs).build(model, config.optimizer.init(), config.learning_rate);let model_trained = learner.fit(dataloader_train, dataloader_test);model_trained.save_file(format!("{artifact_dir}/model"), &CompactRecorder::new()).expect("Trained model should be saved successfully");}
use burn::{config::Config,module::Module,nn::{conv::{Conv2d, Conv2dConfig},pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},Dropout, DropoutConfig, Linear, LinearConfig, ReLU,},tensor::{backend::Backend, Tensor},};#[derive(Module, Debug)]pub struct Model<B: Backend> {conv1: Conv2d<B>,conv2: Conv2d<B>,pool: AdaptiveAvgPool2d,dropout: Dropout,linear1: Linear<B>,linear2: Linear<B>,activation: ReLU,}impl<B: Backend> Model<B> {/// # Shapes/// - Images [batch_size, height, width]/// - Output [batch_size, num_classes]pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {let [batch_size, height, width] = images.dims();// Create channel at second dimlet x = images.reshape([batch_size, 1, height, width]);let x = self.conv1.forward(x); // [batch_size, 8, _, _]let x = self.dropout.forward(x);let x = self.conv2.forward(x); // [batch_size, 16, _, _]let x = self.dropout.forward(x);let x = self.activation.forward(x);let x = self.pool.forward(x); // [batch_size, 16, 8, 8]let x = x.reshape([batch_size, 16 * 8 * 8]);let x = self.linear1.forward(x);let x = self.dropout.forward(x);let x = self.activation.forward(x);self.linear2.forward(x)}}#[derive(Config, Debug)]pub struct ModelConfig {num_classes: usize,hidden_size: usize,#[config(default = "0.5")]dropout: f64,}impl ModelConfig {pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {Model {conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),activation: ReLU::new(),linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),dropout: DropoutConfig::new(self.dropout).init(),}}}
println!("Hello, world!");

// Concrete backends: WGPU with f32 floats / i32 ints, wrapped for autodiff.
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
type AutodiffBackend = Autodiff<Backend>;

let device = WgpuDevice::default();

// 10 MNIST classes, 512 hidden units, default Adam settings.
let training_config =
    training::TrainingConfig::new(ModelConfig::new(10, 512), AdamConfig::new());

training::train::<AutodiffBackend>("/tmp/guide", training_config, device);
use burn::{data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},};pub struct MNISTBatcher<B: Backend> {device: B::Device,}impl<B: Backend> MNISTBatcher<B> {pub fn new(device: B::Device) -> Self {Self { device }}}#[derive(Clone, Debug)]pub struct MNISTBatch<B: Backend> {pub images: Tensor<B, 3>,pub targets: Tensor<B, 1, Int>,}impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {let images = items.iter().map(|item| Data::<f32, 2>::from(item.image)).map(|data| Tensor::<B, 2>::from_data(data.convert(), &self.device)).map(|tensor| tensor.reshape([1, 28, 28]))// Normalize: make between [0,1] and make the mean=0 and std=1// values mean=0.1307,std=0.3081 are from the PyTorch MNIST example// https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122.map(|tensor| ((tensor / 255) - 0.1307) / 0.3081).collect();let targets = items.iter().map(|item| {Tensor::<B, 1, Int>::from_data(Data::from([(item.label as i64).elem()]),&self.device,)}).collect();let images = Tensor::cat(images, 0).to_device(&self.device);let targets = Tensor::cat(targets, 0).to_device(&self.device);MNISTBatch { images, targets }}}
name = "console"version = "0.15.8"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb"dependencies = ["encode_unicode","lazy_static","libc","unicode-width","windows-sys 0.52.0",][[package]]
[[package]]name = "encode_unicode"version = "0.3.6"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f"[[package]]name = "encoding_rs"version = "0.8.33"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1"dependencies = ["cfg-if",]
][[package]]name = "h2"version = "0.3.25"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "4fbd2820c5e49886948654ab546d0688ff24530286bdcf8fca3cefb16d4618eb"dependencies = ["bytes","fnv","futures-core","futures-sink","futures-util","http","indexmap","slab","tokio","tokio-util","tracing",
name = "http"version = "0.2.12"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1"dependencies = ["bytes","fnv","itoa",][[package]]name = "http-body"version = "0.4.6"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"dependencies = ["bytes","http","pin-project-lite",][[package]]name = "httparse"version = "1.8.0"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904"[[package]]name = "httpdate"version = "1.0.3"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"[[package]]name = "hyper"version = "0.14.28"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "bf96e135eb83a2a8ddf766e426a841d8ddd7449d5f00d34ea02b41d2f19eef80"dependencies = ["bytes","futures-channel","futures-core","futures-util","h2","http","http-body","httparse","httpdate","itoa","pin-project-lite","socket2","tokio","tower-service","tracing","want",][[package]]name = "hyper-tls"version = "0.5.0"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905"dependencies = ["bytes","hyper","native-tls","tokio","tokio-native-tls",][[package]]
name = "native-tls"version = "0.2.11"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e"dependencies = ["lazy_static","libc","log","openssl","openssl-probe","openssl-sys","schannel","security-framework","security-framework-sys","tempfile",][[package]]
[[package]]name = "openssl"version = "0.10.64"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"dependencies = ["bitflags 2.5.0","cfg-if","foreign-types 0.3.2","libc","once_cell","openssl-macros","openssl-sys",][[package]]name = "openssl-macros"version = "0.1.1"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"dependencies = ["proc-macro2","quote","syn 2.0.53",][[package]]name = "openssl-probe"version = "0.1.5"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
name = "pin-project-lite"version = "0.2.13"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"[[package]]name = "pin-utils"version = "0.1.0"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"[[package]]
[[package]]name = "reqwest"version = "0.11.27"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "dd67538700a17451e7cba03ac727fb961abb7607553461627b97de0b89cf4a62"dependencies = ["base64","bytes","encoding_rs","futures-core","futures-util","h2","http","http-body","hyper","hyper-tls","ipnet","js-sys","log","mime","native-tls","once_cell","percent-encoding","pin-project-lite","rustls-pemfile","serde","serde_json","serde_urlencoded","sync_wrapper","system-configuration","tokio","tokio-native-tls","tower-service","url","wasm-bindgen","wasm-bindgen-futures","web-sys","winreg",]
[[package]]name = "security-framework"version = "2.9.2"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de"dependencies = ["bitflags 1.3.2","core-foundation","core-foundation-sys","libc","security-framework-sys",]
][[package]]name = "system-configuration"version = "0.5.1"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"dependencies = ["bitflags 1.3.2","core-foundation","system-configuration-sys",][[package]]name = "system-configuration-sys"version = "0.5.0"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"dependencies = ["core-foundation-sys","libc",
name = "tokio"version = "1.36.0"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931"dependencies = ["backtrace","bytes","libc","mio","pin-project-lite","socket2","tokio-macros","windows-sys 0.48.0",][[package]]name = "tokio-macros"version = "2.2.0"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"dependencies = ["proc-macro2","quote","syn 2.0.53",][[package]]name = "tokio-native-tls"version = "0.3.1"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2"dependencies = ["native-tls","tokio",][[package]]name = "tokio-util"version = "0.7.10"source = "registry+https://github.com/rust-lang/crates.io-index"checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15"dependencies = ["bytes","futures-core","futures-sink","pin-project-lite","tokio","tracing",][[package]]