Introduction
Features
- Simple syntax, looks and feels like PyTorch.
- Model training.
- Embed user-defined ops/kernels, such as flash-attention v2.
- Backends.
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macOS.
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
- WASM support, run your models in a browser.
- Included models.
- Language Models.
- LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.
- Falcon.
- StarCoder, StarCoder2.
- Phi 1, 1.5, 2, and 3.
- Mamba, Minimal Mamba.
- Gemma v1 2b and 7b+, v2 2b and 9b.
- Mistral 7b v0.1.
- Mixtral 8x7b v0.1.
- StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
- Replit-code-v1.5-3B.
- Bert.
- Yi-6B and Yi-34B.
- Qwen1.5, Qwen1.5 MoE.
- RWKV v5 and v6.
- Quantized LLMs.
- Llama 7b, 13b, 70b, as well as the chat and code variants.
- Mistral 7b, and 7b instruct.
- Mixtral 8x7b.
- Zephyr 7b a and b (Mistral-7b based).
- OpenChat 3.5 (Mistral-7b based).
- Text to text.
- T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
- Marian MT (Machine Translation).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- TrOCR.
- Audio.
- Whisper, multi-lingual speech-to-text.
- EnCodec, audio compression model.
- MetaVoice-1B, text-to-speech model.
- Parler-TTS, text-to-speech model.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT, ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
This book will introduce step by step how to use candle.
Installation
With Cuda support:
- First, make sure that Cuda is correctly installed.
nvcc --version
should print information about your Cuda compiler driver, and
nvidia-smi --query-gpu=compute_cap --format=csv
should print your GPU's compute capability, e.g. something like:
compute_cap
8.9
You can also compile the Cuda kernels for a specific compute capability using the CUDA_COMPUTE_CAP=<compute cap> environment variable.
If any of the above commands errors out, please make sure to update your Cuda version.
- Create a new app and add candle-core with Cuda support.
Start by creating a new cargo:
cargo new myapp
cd myapp
Make sure to add the candle-core crate with the cuda feature:
cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
Run cargo build to make sure everything can be correctly built.
cargo build
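If you want to quickly check that candle can actually talk to your GPU, a minimal program along these lines (a sketch, using only the candle-core crate added above) should run without errors:
extern crate candle_core;
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    // This returns an error at runtime if the Cuda setup is not functional.
    let device = Device::new_cuda(0)?;
    let t = Tensor::zeros((2, 2), DType::F32, &device)?;
    println!("created a tensor of shape {:?} on the GPU", t.shape());
    Ok(())
}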
Without Cuda support:
Create a new app and add candle-core as follows:
cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
Finally, run cargo build to make sure everything can be correctly built.
cargo build
With mkl support
You can also look at the mkl feature, which can be interesting to get faster inference on CPU. See the Using mkl section for details.
Hello world!
We will now create the hello world of the ML world: a model capable of classifying the MNIST dataset.
Open src/main.rs
and fill in this content:
extern crate candle_core;
use candle_core::{Device, Result, Tensor};

struct Model {
    first: Tensor,
    second: Tensor,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = image.matmul(&self.first)?;
        let x = x.relu()?;
        x.matmul(&self.second)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    let device = Device::Cpu;

    let first = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
    let second = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
Everything should now run with:
cargo run --release
Using a Linear layer.
Now that we have this, we might want to make things a bit more complex, for instance by adding a bias and creating the classical Linear layer. We can do so as follows:
extern crate candle_core;
use candle_core::{Device, Result, Tensor};

struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight)?;
        x.broadcast_add(&self.bias)
    }
}

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}
This changes the model-running code as follows:
extern crate candle_core;
use candle_core::{Device, Result, Tensor};

struct Linear {
    weight: Tensor,
    bias: Tensor,
}

impl Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let x = x.matmul(&self.weight)?;
        x.broadcast_add(&self.bias)
    }
}

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    // Use Device::Cpu; to use the CPU.
    let device = Device::cuda_if_available(0)?;

    // Creating a dummy model
    let weight = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
    let first = Linear { weight, bias };
    let weight = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
    let second = Linear { weight, bias };
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    // Inference on the model
    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
Now that it works, this is a great way to create your own layers. But most of the classical layers are already implemented in candle-nn.
Using candle_nn.
For instance, Linear is already there. This Linear is coded with the PyTorch layout in mind, so as to better reuse existing models out there; it therefore uses the transpose of the weights rather than the weights directly.
So instead we can simplify our example:
cargo add --git https://github.com/huggingface/candle.git candle-nn
And rewrite our example using it:
extern crate candle_core;
extern crate candle_nn;
use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

struct Model {
    first: Linear,
    second: Linear,
}

impl Model {
    fn forward(&self, image: &Tensor) -> Result<Tensor> {
        let x = self.first.forward(image)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

fn main() -> Result<()> {
    // Use Device::new_cuda(0)?; to use the GPU.
    let device = Device::Cpu;

    // This has changed (784, 100) -> (100, 784) !
    let weight = Tensor::randn(0f32, 1.0, (100, 784), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (100, ), &device)?;
    let first = Linear::new(weight, Some(bias));
    let weight = Tensor::randn(0f32, 1.0, (10, 100), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (10, ), &device)?;
    let second = Linear::new(weight, Some(bias));
    let model = Model { first, second };

    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

    let digit = model.forward(&dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}
Feel free to modify this example to use Conv2d to create a classical convnet instead.
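As a starting point, here is a minimal, hedged sketch of a single convolution layer built with random weights; it uses candle_nn's Conv2d::new constructor and the default Conv2dConfig, and a real convnet would stack several such layers followed by a final Linear:
extern crate candle_core;
extern crate candle_nn;
use candle_core::{Device, Result, Tensor};
use candle_nn::{Conv2d, Conv2dConfig, Module};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // 1 input channel, 8 output channels, 3x3 kernel:
    // conv weights use the (out_channels, in_channels, kernel_h, kernel_w) layout.
    let weight = Tensor::randn(0f32, 1.0, (8, 1, 3, 3), &device)?;
    let bias = Tensor::randn(0f32, 1.0, (8, ), &device)?;
    let conv = Conv2d::new(weight, Some(bias), Conv2dConfig::default());
    // MNIST images seen as a (batch, channels, height, width) tensor.
    let dummy_image = Tensor::randn(0f32, 1.0, (1, 1, 28, 28), &device)?;
    let features = conv.forward(&dummy_image)?.relu()?;
    println!("feature map shape: {:?}", features.shape());
    Ok(())
}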
Now that we have the running dummy code we can get to more advanced topics:
Pytorch cheatsheet
Cheatsheet:
| | Using PyTorch | Using Candle |
|---|---|---|
| Creation | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)? |
| Creation | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)? |
| Indexing | tensor[:, :4] | tensor.i((.., ..4))? |
| Operations | tensor.view((2, 2)) | tensor.reshape((2, 2))? |
| Operations | a.matmul(b) | a.matmul(&b)? |
| Arithmetic | a + b | &a + &b |
| Device | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)? |
| Dtype | tensor.to(dtype=torch.float16) | tensor.to_dtype(&DType::F16)? |
| Saving | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")? |
| Loading | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device) |
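To see a few of these entries in action, a short program along these lines should work (a sketch using only candle-core):
extern crate candle_core;
use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Creation
    let a = Tensor::new(&[[1f32, 2.], [3., 4.]], &device)?;
    let b = Tensor::zeros((2, 2), DType::F32, &device)?;
    // Operations and arithmetic
    let c = a.matmul(&b)?;
    let d = (&a + &b)?;
    // Indexing: all rows, first column only (tensor[:, :1] in PyTorch)
    let e = a.i((.., ..1))?;
    println!("{c}\n{d}\n{e}");
    Ok(())
}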
Running a model
In order to run an existing model, you will need to download and use existing weights.
Most models are already available on https://huggingface.co/ in the safetensors format.
Let's get started by running an old model: bert-base-uncased.
Using the hub
Install the hf-hub crate:
cargo add hf-hub
Then let's start by downloading the model file.
extern crate candle_core;
extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights = repo.get("model.safetensors").unwrap();

let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
We now have access to all the tensors within the file.
You can check all the names of the tensors here.
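Since load returns a HashMap from tensor names to tensors, you can also list what the file contains directly, for example with a small sketch building on the snippet above:
for (name, tensor) in weights.iter() {
    println!("{name}: {:?}", tensor.shape());
}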
Using async
hf-hub comes with an async API.
cargo add hf-hub --features tokio
Unfortunately, this is tested directly in the examples crate because it needs external dependencies. See this mdBook issue: https://github.com/rust-lang/mdBook/issues/706
use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
Using in a real model.
Now that we have our weights, we can use them in our bert architecture:
extern crate candle_core;
extern crate candle_nn;
extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module};

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights = repo.get("model.safetensors").unwrap();

let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();

let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();

let linear = Linear::new(weight.clone(), Some(bias.clone()));

let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids).unwrap();
For a full reference, you can check out the full bert example.
Memory mapping
For more efficient loading, instead of reading the file, you could use memmap2.
Note: Be careful with memory mapping; it seems to cause issues on Windows and WSL, and it will definitely be slower on a network-mounted disk because it will issue more read calls.
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
Note: This operation is unsafe. See the safety notice. In practice, model files should never be modified, and the mmaps should be mostly read-only anyway, so the caveat most likely does not apply, but always keep it in mind.
Tensor Parallel Sharding
When using multiple GPUs for tensor parallelism in order to get good latency, you can load only the part of the tensor you need.
For that, you need to use safetensors directly.
cargo add safetensors
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisible by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimensions;
// byte offsets are handled automatically by safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
Error management
You might have seen in the code base a lot of .unwrap() or ?.
If you're unfamiliar with Rust, check out the Rust book for more information.
What's important to know, though, is that if you want to find out where a particular operation failed, you can simply set RUST_BACKTRACE=1 to get the location where the model actually failed.
Let's look at some failing code:
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
let z = x.matmul(&y)?;
This will print at runtime:
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
After adding RUST_BACKTRACE=1:
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
Not super pretty at the moment, but we can see that the error occurred at { fn: "myapp::main", file: "./src/main.rs", line: 29 }.
Another thing to note is that, since Rust is compiled, it is not necessarily as easy to recover proper stacktraces, especially in release builds. We're using anyhow for that.
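As an illustration, here is a small, hypothetical helper (assuming anyhow has been added as a dependency) that attaches extra context to a candle error before bubbling it up:
use anyhow::Context;
use candle_core::{Device, Tensor};
use std::collections::HashMap;

// Hypothetical helper: wrap the candle error with the path that was being loaded.
fn load_weights(path: &str) -> anyhow::Result<HashMap<String, Tensor>> {
    candle_core::safetensors::load(path, &Device::Cpu)
        .with_context(|| format!("failed to load weights from {path}"))
}
The original candle error is preserved as the source of the anyhow error, so no information is lost by adding context.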
The library is still young, please report any issues detecting where an error is coming from.
Cuda error management
When running a model on Cuda, you might get a stacktrace that does not really represent the error. The reason is that CUDA is async by nature, and therefore the error might be caught while you were launching totally different kernels.
One way to avoid this is to use CUDA_LAUNCH_BLOCKING=1 as an environment variable. This will force every kernel to be launched sequentially.
However, you might still see the error reported on other kernels, as the faulty kernel might exit without an error while corrupting some pointer, and the error will only surface when the CudaSlice is dropped.
If this occurs, you can use compute-sanitizer. This tool is like valgrind but for CUDA, and it will help locate errors in the kernels.
Training
Training starts with data. We're going to use the huggingface hub and start with the Hello world dataset of machine learning, MNIST.
Let's start by downloading MNIST from huggingface.
This requires hf-hub.
cargo add hf-hub
This is going to be very hands-on for now.
This uses the standardized parquet files from the refs/convert/parquet branch on every dataset.
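A rough sketch of what the download could look like, using hf-hub's sync API and the parquet crate (cargo add parquet); the dataset id "mnist" and the file paths inside the repository are assumptions and may need adjusting:
use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::serialized_reader::SerializedFileReader;
use std::fs::File;

let api = Api::new().unwrap();
// The parquet conversions live on the refs/convert/parquet revision of the dataset repo.
let repo = api.repo(Repo::with_revision(
    "mnist".to_string(),
    RepoType::Dataset,
    "refs/convert/parquet".to_string(),
));
let test_parquet_filename = repo.get("mnist/test/0000.parquet").unwrap();
let train_parquet_filename = repo.get("mnist/train/0000.parquet").unwrap();
let test_parquet = SerializedFileReader::new(File::open(test_parquet_filename).unwrap()).unwrap();
let train_parquet = SerializedFileReader::new(File::open(train_parquet_filename).unwrap()).unwrap();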
Our handles are now parquet::file::serialized_reader::SerializedFileReader.
We can inspect the content of the files with:
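For example, something along these lines (a sketch that iterates over the rows of the test_parquet reader from the snippet above):
for row in test_parquet {
    // Print every column of every row, using the column's position as its id.
    for (idx, (name, field)) in row?.get_column_iter().enumerate() {
        println!("Column id {idx}, name {name}, value {field}");
    }
}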
You should see something like:
Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
So each row contains 2 columns (image, label), with the image being saved as bytes. Let's put them into a useful struct.
Simplified
How it works
This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.
Key points:
- A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.
- The input is a vector of 2 numbers: the percentage of votes for the first and second candidates in the first round.
- The output is the number 0 or 1, where 1 means that the first candidate will win in the second round and 0 means that they will lose.
- For training, samples with real data on the results of the first and second rounds of different elections are used.
- The model is trained by backpropagation using gradient descent and the cross-entropy loss function.
- Model parameters (the weights of the neurons) are initialized randomly, then optimized during training.
- After training, the model is tested on a held-out sample to evaluate its accuracy.
- If the accuracy on the test set is below 100%, the model is considered underfit and training is repeated.
Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.
const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;
#[derive(Clone)]
pub struct Dataset {
pub train_votes: Tensor,
pub train_results: Tensor,
pub test_votes: Tensor,
pub test_results: Tensor,
}
struct MultiLevelPerceptron {
ln1: Linear,
ln2: Linear,
ln3: Linear,
}
impl MultiLevelPerceptron {
fn new(vs: VarBuilder) -> Result<Self> {
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
Ok(Self { ln1, ln2, ln3 })
}
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.ln1.forward(xs)?;
let xs = xs.relu()?;
let xs = self.ln2.forward(&xs)?;
let xs = xs.relu()?;
self.ln3.forward(&xs)
}
}
fn train(m: Dataset, dev: &Device) -> anyhow::Result<MultiLevelPerceptron> {
let train_results = m.train_results.to_device(dev)?;
let train_votes = m.train_votes.to_device(dev)?;
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&varmap, DType::F32, dev);
let model = MultiLevelPerceptron::new(vs.clone())?;
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
let test_votes = m.test_votes.to_device(dev)?;
let test_results = m.test_results.to_device(dev)?;
let mut final_accuracy: f32 = 0.0;
for epoch in 1..EPOCHS + 1 {
let logits = model.forward(&train_votes)?;
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
let loss = loss::nll(&log_sm, &train_results)?;
sgd.backward_step(&loss)?;
let test_logits = model.forward(&test_votes)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&test_results)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::<f32>()?;
let test_accuracy = sum_ok / test_results.dims1()? as f32;
final_accuracy = 100. * test_accuracy;
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
loss.to_scalar::<f32>()?,
final_accuracy
);
if final_accuracy == 100.0 {
break;
}
}
if final_accuracy < 100.0 {
Err(anyhow::Error::msg("The model is not trained well enough."))
} else {
Ok(model)
}
}
#[tokio::test]
async fn simplified() -> anyhow::Result<()> {
let dev = Device::cuda_if_available(0)?;
let train_votes_vec: Vec<u32> = vec![
15, 10,
10, 15,
5, 12,
30, 20,
16, 12,
13, 25,
6, 14,
31, 21,
];
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let train_results_vec: Vec<u32> = vec![
1,
0,
0,
1,
1,
0,
0,
1,
];
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &dev)?;
let test_votes_vec: Vec<u32> = vec![
13, 9,
8, 14,
3, 10,
];
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let test_results_vec: Vec<u32> = vec![
1,
0,
0,
];
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &dev)?;
let m = Dataset {
train_votes: train_votes_tensor,
train_results: train_results_tensor,
test_votes: test_votes_tensor,
test_results: test_results_tensor,
};
let trained_model: MultiLevelPerceptron;
loop {
println!("Trying to train neural network.");
match train(m.clone(), &dev) {
Ok(model) => {
trained_model = model;
break;
},
Err(e) => {
println!("Error: {}", e);
continue;
}
}
}
let real_world_votes: Vec<u32> = vec![
13, 22,
];
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &dev)?.to_dtype(DType::F32)?;
let final_result = trained_model.forward(&tensor_test_votes)?;
let result = final_result
.argmax(D::Minus1)?
.to_dtype(DType::F32)?
.get(0).map(|x| x.to_scalar::<f32>())??;
println!("real_life_votes: {:?}", real_world_votes);
println!("neural_network_prediction_result: {:?}", result);
Ok(())
}
Example output
Trying to train neural network.
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
MNIST
Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.
let test_samples = 10_000;
let mut test_buffer_images: Vec<u8> = Vec::with_capacity(test_samples * 784);
let mut test_buffer_labels: Vec<u8> = Vec::with_capacity(test_samples);
for row in test_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
test_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
test_buffer_labels.push(*label as u8);
}
}
}
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &Device::Cpu)?;
let train_samples = 60_000;
let mut train_buffer_images: Vec<u8> = Vec::with_capacity(train_samples * 784);
let mut train_buffer_labels: Vec<u8> = Vec::with_capacity(train_samples);
for row in train_parquet{
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
train_buffer_images.extend(image.to_luma8().as_raw());
}
}
}else if let parquet::record::Field::Long(label) = field {
train_buffer_labels.push(*label as u8);
}
}
}
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &Device::Cpu)?;
let mnist = candle_datasets::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
};
Parsing the files and putting them into single tensors requires the dataset to fit entirely in memory. It is quite rudimentary, but simple enough for a small dataset like MNIST.
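As a quick illustration of how these in-memory tensors might then be consumed, here is a hedged sketch that slices one mini-batch out of the training set with Tensor::narrow (the batch size is arbitrary):
// Take the first 64 samples along the batch dimension (dim 0).
let batch_size = 64;
let batch_images = mnist.train_images.narrow(0, 0, batch_size)?;
let batch_labels = mnist.train_labels.narrow(0, 0, batch_size)?;
println!("batch images: {:?}, batch labels: {:?}", batch_images.shape(), batch_labels.shape());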