Using the hub
Install the hf-hub crate:
cargo add hf-hub
Then let's start by downloading the model file.
use candle_core::Device;
use hf_hub::api::sync::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
// Fetches the file from the local cache, or downloads it, and returns its path.
let weights = repo.get("model.safetensors").unwrap();
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
We now have access to all the tensors within the file.
You can check the names of all the tensors here.
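Tensor names live in the file's JSON header. As a minimal sketch that assumes only the documented safetensors layout (an 8-byte little-endian `u64` header length followed by that many bytes of UTF-8 JSON), the header can be read with the standard library alone, without touching any tensor data. `read_safetensors_header` is a hypothetical helper for illustration, not part of candle or hf-hub.

```rust
use std::fs::File;
use std::io::{self, Read};

/// Hypothetical helper: reads the JSON header of a safetensors stream.
/// Layout assumed: 8-byte little-endian u64 length, then that many bytes of JSON.
fn read_safetensors_header<R: Read>(reader: &mut R) -> io::Result<String> {
    let mut len_bytes = [0u8; 8];
    reader.read_exact(&mut len_bytes)?;
    let len = u64::from_le_bytes(len_bytes) as usize;
    let mut header = vec![0u8; len];
    reader.read_exact(&mut header)?;
    String::from_utf8(header).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
}

fn main() {
    // Hypothetical default path: pass the path returned by `repo.get(...)` instead.
    let path = std::env::args()
        .nth(1)
        .unwrap_or_else(|| "model.safetensors".to_string());
    match File::open(&path).and_then(|mut f| read_safetensors_header(&mut f)) {
        // The JSON maps each tensor name to its dtype, shape and data_offsets.
        Ok(header) => println!("{header}"),
        Err(e) => eprintln!("could not read {path}: {e}"),
    }
}
```

Once unwrapped, the value returned by `candle_core::safetensors::load` is a `HashMap<String, Tensor>`, so `weights.keys()` yields the same names after the data has been loaded.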
Using async
hf-hub also comes with an async API, enabled by the tokio feature:
cargo add hf-hub --features tokio
Unfortunately, this snippet is tested directly in the examples crate because it needs external dependencies.
See [this](https://github.com/rust-lang/mdBook/issues/706).
use candle::Device;
use hf_hub::api::tokio::Api;
// `.await` requires an async context, e.g. an async fn running on a tokio runtime.
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
Using in a real model
Now that we have our weights, we can use them in our BERT architecture:
use candle_core::{DType, Device, Tensor};
use candle_nn::{Linear, Module};
use hf_hub::api::sync::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights = repo.get("model.safetensors").unwrap();
let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();
// Query projection of the first self-attention layer.
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
let linear = Linear::new(weight.clone(), Some(bias.clone()));
// Dummy hidden states: 3 positions with 768 features each.
let hidden_states = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&hidden_states).unwrap();
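For intuition on the shapes: `candle_nn::Linear` stores its weight as `(out_features, in_features)` and computes `x * W^T + b`, which is why a `(3, 768)` input against BERT's `(768, 768)` query weight yields a `(3, 768)` output. The sketch below reproduces that arithmetic in plain Rust on flat row-major slices; `linear_forward` is a hypothetical helper for illustration, not candle's implementation.

```rust
/// Hypothetical plain-Rust linear layer: y = x * W^T + b.
/// `x` is (batch, in_features), `w` is (out_features, in_features) and
/// `b` is (out_features), all stored row-major in flat slices.
fn linear_forward(
    x: &[f32],
    w: &[f32],
    b: Option<&[f32]>,
    batch: usize,
    in_features: usize,
    out_features: usize,
) -> Vec<f32> {
    assert_eq!(x.len(), batch * in_features);
    assert_eq!(w.len(), out_features * in_features);
    let mut y = vec![0.0f32; batch * out_features];
    for i in 0..batch {
        for o in 0..out_features {
            // Dot product of input row i with weight row o (column o of W^T).
            let mut acc: f32 = (0..in_features)
                .map(|k| x[i * in_features + k] * w[o * in_features + k])
                .sum();
            if let Some(b) = b {
                acc += b[o];
            }
            y[i * out_features + o] = acc;
        }
    }
    y
}

fn main() {
    // x: (2, 3), W: (2, 3) so out_features = 2, b: (2)
    let x: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let w: [f32; 6] = [1.0, 0.0, 0.0, 0.0, 1.0, 1.0];
    let b: [f32; 2] = [0.5, -1.0];
    let y = linear_forward(&x, &w, Some(&b[..]), 2, 3, 2);
    println!("{y:?}"); // [1.5, 4.0, 4.5, 10.0]
}
```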
For a full reference, check out the complete BERT example.
Memory mapping
For more efficient loading, you can memory-map the file with memmap2 instead of reading it.
Note: Be careful with memory mapping. It seems to cause issues on Windows and WSL, and it will definitely be slower on network-mounted disks because it issues more read calls.
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
// Unsafe: the mapped file must not be modified while the mapping is alive.
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
Note: This operation is unsafe. See the safety notice. In practice, model files should never be modified, and the mappings should be mostly read-only anyway, so the caveat most likely does not apply, but always keep it in mind.
Tensor Parallel Sharding
When using multiple GPUs with tensor parallelism to get good latency, you can load only the part of each tensor that a given GPU needs.
For that, you need to use safetensors directly.
cargo add safetensors
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
    .tensor("bert.encoder.layer.0.attention.self.query.weight")
    .unwrap();
// We're going to load the shard with rank 1, within a world_size of 4.
// We split along dimension 0, i.e. VIEW[start..stop, :].
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[dim];
if size % world_size != 0 {
    panic!("The dimension is not divisible by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimensions;
// byte offsets are handled automatically by safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
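To make the index arithmetic above concrete, here is a minimal sketch in plain Rust of which rows each rank owns and which bytes of the tensor's data those rows cover, assuming a contiguous row-major layout. In that layout a split along dimension 0 always maps to one contiguous byte range, whereas splits along other dimensions generally do not, which is presumably why the code above flattens an iterator of byte chunks. `shard_rows` and `shard_bytes` are hypothetical helpers, not part of safetensors or candle, and the byte range is relative to the start of that tensor's data.

```rust
use std::ops::Range;

/// Hypothetical helper: rows owned by `rank` when `size` rows are split evenly
/// across `world_size` ranks along dimension 0. `None` if the split is invalid.
fn shard_rows(size: usize, rank: usize, world_size: usize) -> Option<Range<usize>> {
    if world_size == 0 || rank >= world_size || size % world_size != 0 {
        return None;
    }
    let block_size = size / world_size;
    Some(rank * block_size..(rank + 1) * block_size)
}

/// Hypothetical helper: byte range covered by `rows` in a contiguous row-major
/// tensor of the given `shape`, with `dtype_size` bytes per element.
fn shard_bytes(shape: &[usize], rows: &Range<usize>, dtype_size: usize) -> Range<usize> {
    assert!(!shape.is_empty(), "a scalar has no dimension 0 to split");
    // Bytes in one index along dimension 0 (product of the remaining dims).
    let row_bytes: usize = shape[1..].iter().product::<usize>() * dtype_size;
    rows.start * row_bytes..rows.end * row_bytes
}

fn main() {
    // Same setup as above: a (768, 768) F32 weight (4 bytes per element), rank 1 of 4.
    let shape: [usize; 2] = [768, 768];
    let rows = shard_rows(shape[0], 1, 4).expect("768 is divisible by 4");
    let bytes = shard_bytes(&shape, &rows, 4);
    println!("rows {rows:?} -> bytes {bytes:?}"); // rows 192..384 -> bytes 589824..1179648
}
```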