PyTorch cheatsheet
Cheatsheet:
| | Using PyTorch | Using Candle |
---|---|---|
Creation | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)? |
Creation | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)? |
Indexing | tensor[:, :4] | tensor.i((.., ..4))? |
Operations | tensor.view((2, 2)) | tensor.reshape((2, 2))? |
Operations | a.matmul(b) | a.matmul(&b)? |
Arithmetic | a + b | &a + &b |
Device | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)? |
Dtype | tensor.to(dtype=torch.float16) | tensor.to_dtype(DType::F16)? |
Saving | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")? |
Loading | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device)? |