<!DOCTYPE HTML>
<html lang="en" class="light sidebar-visible" dir="ltr">
<head>
<!-- Book generated using mdBook -->
<meta charset="UTF-8">
<title>Candle Documentation</title>
<meta name="robots" content="noindex">
<!-- Custom HTML head -->
<meta name="description" content="">
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta name="theme-color" content="#ffffff">
<link rel="icon" href="favicon.svg">
<link rel="shortcut icon" href="favicon.png">
<link rel="stylesheet" href="css/variables.css">
<link rel="stylesheet" href="css/general.css">
<link rel="stylesheet" href="css/chrome.css">
<link rel="stylesheet" href="css/print.css" media="print">
<!-- Fonts -->
<link rel="stylesheet" href="FontAwesome/css/font-awesome.css">
<link rel="stylesheet" href="fonts/fonts.css">
<!-- Highlight.js Stylesheets -->
<link rel="stylesheet" id="highlight-css" href="highlight.css">
<link rel="stylesheet" id="tomorrow-night-css" href="tomorrow-night.css">
<link rel="stylesheet" id="ayu-highlight-css" href="ayu-highlight.css">
<!-- Custom theme stylesheets -->
<!-- Provide site root and default themes to javascript -->
<script>
const path_to_root = "";
const default_light_theme = "light";
const default_dark_theme = "navy";
</script>
<!-- Start loading toc.js asap -->
<script src="toc.js"></script>
</head>
<body>
<div id="body-container">
<!-- Work around some values being stored in localStorage wrapped in quotes -->
<script>
try {
let theme = localStorage.getItem('mdbook-theme');
let sidebar = localStorage.getItem('mdbook-sidebar');
if (theme.startsWith('"') && theme.endsWith('"')) {
localStorage.setItem('mdbook-theme', theme.slice(1, theme.length - 1));
}
if (sidebar.startsWith('"') && sidebar.endsWith('"')) {
localStorage.setItem('mdbook-sidebar', sidebar.slice(1, sidebar.length - 1));
}
} catch (e) { }
</script>
<!-- Set the theme before any content is loaded, prevents flash -->
<script>
const default_theme = window.matchMedia("(prefers-color-scheme: dark)").matches ? default_dark_theme : default_light_theme;
let theme;
try { theme = localStorage.getItem('mdbook-theme'); } catch(e) { }
if (theme === null || theme === undefined) { theme = default_theme; }
const html = document.documentElement;
html.classList.remove('light')
html.classList.add(theme);
html.classList.add("js");
</script>
<input type="checkbox" id="sidebar-toggle-anchor" class="hidden">
<!-- Hide / unhide sidebar before it is displayed -->
<script>
let sidebar = null;
const sidebar_toggle = document.getElementById("sidebar-toggle-anchor");
if (document.body.clientWidth >= 1080) {
try { sidebar = localStorage.getItem('mdbook-sidebar'); } catch(e) { }
sidebar = sidebar || 'visible';
} else {
sidebar = 'hidden';
}
sidebar_toggle.checked = sidebar === 'visible';
html.classList.remove('sidebar-visible');
html.classList.add("sidebar-" + sidebar);
</script>
<nav id="sidebar" class="sidebar" aria-label="Table of contents">
<!-- populated by js -->
<mdbook-sidebar-scrollbox class="sidebar-scrollbox"></mdbook-sidebar-scrollbox>
<noscript>
<iframe class="sidebar-iframe-outer" src="toc.html"></iframe>
</noscript>
<div id="sidebar-resize-handle" class="sidebar-resize-handle">
<div class="sidebar-resize-indicator"></div>
</div>
</nav>
<div id="page-wrapper" class="page-wrapper">
<div class="page">
<div id="menu-bar-hover-placeholder"></div>
<div id="menu-bar" class="menu-bar sticky">
<div class="left-buttons">
<label id="sidebar-toggle" class="icon-button" for="sidebar-toggle-anchor" title="Toggle Table of Contents" aria-label="Toggle Table of Contents" aria-controls="sidebar">
<i class="fa fa-bars"></i>
</label>
<button id="theme-toggle" class="icon-button" type="button" title="Change theme" aria-label="Change theme" aria-haspopup="true" aria-expanded="false" aria-controls="theme-list">
<i class="fa fa-paint-brush"></i>
</button>
<ul id="theme-list" class="theme-popup" aria-label="Themes" role="menu">
<li role="none"><button role="menuitem" class="theme" id="default_theme">Auto</button></li>
<li role="none"><button role="menuitem" class="theme" id="light">Light</button></li>
<li role="none"><button role="menuitem" class="theme" id="rust">Rust</button></li>
<li role="none"><button role="menuitem" class="theme" id="coal">Coal</button></li>
<li role="none"><button role="menuitem" class="theme" id="navy">Navy</button></li>
<li role="none"><button role="menuitem" class="theme" id="ayu">Ayu</button></li>
</ul>
<button id="search-toggle" class="icon-button" type="button" title="Search. (Shortkey: s)" aria-label="Toggle Searchbar" aria-expanded="false" aria-keyshortcuts="S" aria-controls="searchbar">
<i class="fa fa-search"></i>
</button>
</div>
<h1 class="menu-title">Candle Documentation</h1>
<div class="right-buttons">
<a href="print.html" title="Print this book" aria-label="Print this book">
<i id="print-button" class="fa fa-print"></i>
</a>
</div>
</div>
<div id="search-wrapper" class="hidden">
<form id="searchbar-outer" class="searchbar-outer">
<input type="search" id="searchbar" name="searchbar" placeholder="Search this book ..." aria-controls="searchresults-outer" aria-describedby="searchresults-header">
</form>
<div id="searchresults-outer" class="searchresults-outer hidden">
<div id="searchresults-header" class="searchresults-header"></div>
<ul id="searchresults">
</ul>
</div>
</div>
<!-- Apply ARIA attributes after the sidebar and the sidebar toggle button are added to the DOM -->
<script>
document.getElementById('sidebar-toggle').setAttribute('aria-expanded', sidebar === 'visible');
document.getElementById('sidebar').setAttribute('aria-hidden', sidebar !== 'visible');
Array.from(document.querySelectorAll('#sidebar a')).forEach(function(link) {
link.setAttribute('tabIndex', sidebar === 'visible' ? 0 : -1);
});
</script>
<div id="content" class="content">
<main>
<h1 id="introduction"><a class="header" href="#introduction">Introduction</a></h1>
<h2 id="features"><a class="header" href="#features">Features</a></h2>
<ul>
<li>Simple syntax, looks and feels like PyTorch.
<ul>
<li>Model training.</li>
<li>Embed user-defined ops/kernels, such as <a href="https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152">flash-attention v2</a>.</li>
</ul>
</li>
<li>Backends.
<ul>
<li>Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.</li>
<li>CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.</li>
<li>WASM support, run your models in a browser.</li>
</ul>
</li>
<li>Included models.
<ul>
<li>Language Models.
<ul>
<li>LLaMA v1, v2, and v3 with variants such as SOLAR-10.7B.</li>
<li>Falcon.</li>
<li>StarCoder, StarCoder2.</li>
<li>Phi 1, 1.5, 2, and 3.</li>
<li>Mamba, Minimal Mamba.</li>
<li>Gemma v1 2b and 7b+, v2 2b and 9b.</li>
<li>Mistral 7b v0.1.</li>
<li>Mixtral 8x7b v0.1.</li>
<li>StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.</li>
<li>Replit-code-v1.5-3B.</li>
<li>Bert.</li>
<li>Yi-6B and Yi-34B.</li>
<li>Qwen1.5, Qwen1.5 MoE.</li>
<li>RWKV v5 and v6.</li>
</ul>
</li>
<li>Quantized LLMs.
<ul>
<li>Llama 7b, 13b, 70b, as well as the chat and code variants.</li>
<li>Mistral 7b, and 7b instruct.</li>
<li>Mixtral 8x7b.</li>
<li>Zephyr 7b a and b (Mistral-7b based).</li>
<li>OpenChat 3.5 (Mistral-7b based).</li>
</ul>
</li>
<li>Text to text.
<ul>
<li>T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).</li>
<li>Marian MT (Machine Translation).</li>
</ul>
</li>
<li>Text to image.
<ul>
<li>Stable Diffusion v1.5, v2.1, XL v1.0.</li>
<li>Wurstchen v2.</li>
</ul>
</li>
<li>Image to text.
<ul>
<li>BLIP.</li>
<li>TrOCR.</li>
</ul>
</li>
<li>Audio.
<ul>
<li>Whisper, multi-lingual speech-to-text.</li>
<li>EnCodec, audio compression model.</li>
<li>MetaVoice-1B, text-to-speech model.</li>
<li>Parler-TTS, text-to-speech model.</li>
</ul>
</li>
<li>Computer Vision Models.
<ul>
<li>DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.</li>
<li>yolo-v3, yolo-v8.</li>
<li>Segment-Anything Model (SAM).</li>
<li>SegFormer.</li>
</ul>
</li>
</ul>
</li>
<li>File formats: load models from safetensors, npz, ggml, or PyTorch files.</li>
<li>Serverless (on CPU), small and fast deployments.</li>
<li>Quantization support using the llama.cpp quantized types.</li>
</ul>
<p>This book will introduce step by step how to use <code>candle</code>.</p>
<div style="break-before: page; page-break-before: always;"></div><h1 id="installation"><a class="header" href="#installation">Installation</a></h1>
<p><strong>With Cuda support</strong>:</p>
<ol>
<li>First, make sure that Cuda is correctly installed.</li>
</ol>
<ul>
<li><code>nvcc --version</code> should print information about your Cuda compiler driver.</li>
<li><code>nvidia-smi --query-gpu=compute_cap --format=csv</code> should print your GPU's compute capability, e.g. something
like:</li>
</ul>
<pre><code class="language-bash">compute_cap
8.9
</code></pre>
<p>You can also compile the Cuda kernels for a specific compute cap using the
<code>CUDA_COMPUTE_CAP=&lt;compute cap&gt;</code> environment variable.</p>
<p>If any of the above commands error out, please make sure to update your Cuda version.</p>
<ol start="2">
<li>Create a new app and add <a href="https://github.com/huggingface/candle/tree/main/candle-core"><code>candle-core</code></a> with Cuda support.</li>
</ol>
<p>Start by creating a new cargo project:</p>
<pre><code class="language-bash">cargo new myapp
cd myapp
</code></pre>
<p>Make sure to add the <code>candle-core</code> crate with the cuda feature:</p>
<pre><code class="language-bash">cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"
</code></pre>
<p>Run <code>cargo build</code> to make sure everything can be correctly built.</p>
<pre><code class="language-bash">cargo build
</code></pre>
<p><strong>Without Cuda support</strong>:</p>
<p>Create a new app and add <a href="https://github.com/huggingface/candle/tree/main/candle-core"><code>candle-core</code></a> as follows:</p>
<pre><code class="language-bash">cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core
</code></pre>
<p>Finally, run <code>cargo build</code> to make sure everything can be correctly built.</p>
<pre><code class="language-bash">cargo build
</code></pre>
<p><strong>With mkl support</strong>:</p>
<p>You can also enable the <code>mkl</code> feature, which can provide faster inference on CPU: see <a href="guide/./advanced/mkl.html">Using mkl</a>.</p>
<div style="break-before: page; page-break-before: always;"></div><h1 id="hello-world"><a class="header" href="#hello-world">Hello world!</a></h1>
<p>We will now create the hello world of the ML world: a model capable of solving the MNIST dataset.</p>
<p>Open <code>src/main.rs</code> and fill in this content:</p>
<pre><pre class="playground"><code class="language-rust"><span class="boring">extern crate candle_core;
</span>use candle_core::{Device, Result, Tensor};
struct Model {
first: Tensor,
second: Tensor,
}
impl Model {
fn forward(&amp;self, image: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
let x = image.matmul(&amp;self.first)?;
let x = x.relu()?;
x.matmul(&amp;self.second)
}
}
fn main() -&gt; Result&lt;()&gt; {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
let first = Tensor::randn(0f32, 1.0, (784, 100), &amp;device)?;
let second = Tensor::randn(0f32, 1.0, (100, 10), &amp;device)?;
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &amp;device)?;
let digit = model.forward(&amp;dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}</code></pre></pre>
<p>Everything should now run with:</p>
<pre><code class="language-bash">cargo run --release
</code></pre>
<h2 id="using-a-linear-layer"><a class="header" href="#using-a-linear-layer">Using a <code>Linear</code> layer.</a></h2>
<p>Now that we have this, we might want to make things a bit more complex, for instance by adding a <code>bias</code> and creating
the classical <code>Linear</code> layer. We can do so as follows:</p>
<pre><pre class="playground"><code class="language-rust"><span class="boring">#![allow(unused)]
</span><span class="boring">fn main() {
</span><span class="boring">extern crate candle_core;
</span><span class="boring">use candle_core::{Device, Result, Tensor};
</span>struct Linear{
weight: Tensor,
bias: Tensor,
}
impl Linear{
fn forward(&amp;self, x: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
let x = x.matmul(&amp;self.weight)?;
x.broadcast_add(&amp;self.bias)
}
}
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&amp;self, image: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&amp;x)
}
}
<span class="boring">}</span></code></pre></pre>
<p>This changes the code that runs the model into the following new function:</p>
<pre><pre class="playground"><code class="language-rust"><span class="boring">extern crate candle_core;
</span><span class="boring">use candle_core::{Device, Result, Tensor};
</span><span class="boring">struct Linear{
</span><span class="boring"> weight: Tensor,
</span><span class="boring"> bias: Tensor,
</span><span class="boring">}
</span><span class="boring">impl Linear{
</span><span class="boring"> fn forward(&amp;self, x: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
</span><span class="boring"> let x = x.matmul(&amp;self.weight)?;
</span><span class="boring"> x.broadcast_add(&amp;self.bias)
</span><span class="boring"> }
</span><span class="boring">}
</span><span class="boring">
</span><span class="boring">struct Model {
</span><span class="boring"> first: Linear,
</span><span class="boring"> second: Linear,
</span><span class="boring">}
</span><span class="boring">
</span><span class="boring">impl Model {
</span><span class="boring"> fn forward(&amp;self, image: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
</span><span class="boring"> let x = self.first.forward(image)?;
</span><span class="boring"> let x = x.relu()?;
</span><span class="boring"> self.second.forward(&amp;x)
</span><span class="boring"> }
</span><span class="boring">}
</span>fn main() -&gt; Result&lt;()&gt; {
// Use Device::new_cuda(0)?; to use the GPU.
// Use Device::Cpu; to use the CPU.
let device = Device::cuda_if_available(0)?;
// Creating a dummy model
let weight = Tensor::randn(0f32, 1.0, (784, 100), &amp;device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &amp;device)?;
let first = Linear{weight, bias};
let weight = Tensor::randn(0f32, 1.0, (100, 10), &amp;device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &amp;device)?;
let second = Linear{weight, bias};
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &amp;device)?;
// Inference on the model
let digit = model.forward(&amp;dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}</code></pre></pre>
<p>Now that it works, this is a great way to create your own layers.
However, most of the classical layers are already implemented in <a href="https://github.com/huggingface/candle/tree/main/candle-nn">candle-nn</a>.</p>
<h2 id="using-candle_nn"><a class="header" href="#using-candle_nn">Using <code>candle_nn</code>.</a></h2>
<p>For instance, <a href="https://github.com/huggingface/candle/blob/main/candle-nn/src/linear.rs">Linear</a> is already there.
This Linear is coded with the PyTorch layout in mind, to make it easier to reuse existing models, so it uses the transpose of the weights and not the weights directly.</p>
<p>So instead we can simplify our example:</p>
<pre><code class="language-bash">cargo add --git https://github.com/huggingface/candle.git candle-nn
</code></pre>
<p>And rewrite our example using it:</p>
<pre><pre class="playground"><code class="language-rust"><span class="boring">extern crate candle_core;
</span><span class="boring">extern crate candle_nn;
</span>use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};
struct Model {
first: Linear,
second: Linear,
}
impl Model {
fn forward(&amp;self, image: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
let x = self.first.forward(image)?;
let x = x.relu()?;
self.second.forward(&amp;x)
}
}
fn main() -&gt; Result&lt;()&gt; {
// Use Device::new_cuda(0)?; to use the GPU.
let device = Device::Cpu;
// This has changed (784, 100) -&gt; (100, 784) !
let weight = Tensor::randn(0f32, 1.0, (100, 784), &amp;device)?;
let bias = Tensor::randn(0f32, 1.0, (100, ), &amp;device)?;
let first = Linear::new(weight, Some(bias));
let weight = Tensor::randn(0f32, 1.0, (10, 100), &amp;device)?;
let bias = Tensor::randn(0f32, 1.0, (10, ), &amp;device)?;
let second = Linear::new(weight, Some(bias));
let model = Model { first, second };
let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &amp;device)?;
let digit = model.forward(&amp;dummy_image)?;
println!("Digit {digit:?} digit");
Ok(())
}</code></pre></pre>
<p>Feel free to modify this example to use <code>Conv2d</code> to create a classical convnet instead.</p>
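<p>As a minimal sketch of what that could look like (the layer sizes and the 3x3 kernel below are arbitrary choices for illustration, not part of the original example), here is one convolution followed by a linear classifier:</p>
<pre><code class="language-rust ignore">use candle_core::{Device, Result, Tensor};
use candle_nn::{Conv2d, Conv2dConfig, Linear, Module};

struct ConvModel {
    conv: Conv2d,
    fc: Linear,
}

impl ConvModel {
    fn forward(&amp;self, image: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
        // Reshape the flat (batch, 784) input into (batch, 1, 28, 28) images.
        let batch = image.dim(0)?;
        let x = image.reshape((batch, 1, 28, 28))?;
        let x = self.conv.forward(&amp;x)?.relu()?;
        // Flatten everything but the batch dimension before the classifier.
        let x = x.flatten_from(1)?;
        self.fc.forward(&amp;x)
    }
}

fn main() -&gt; Result&lt;()&gt; {
    let device = Device::Cpu;
    // Conv2d weights are laid out as (out_channels, in_channels, kernel_h, kernel_w).
    let conv_weight = Tensor::randn(0f32, 1.0, (8, 1, 3, 3), &amp;device)?;
    let conv_bias = Tensor::randn(0f32, 1.0, (8,), &amp;device)?;
    let conv = Conv2d::new(conv_weight, Some(conv_bias), Conv2dConfig::default());
    // A 3x3 kernel without padding turns 28x28 into 26x26, hence 8 * 26 * 26 inputs.
    let fc_weight = Tensor::randn(0f32, 1.0, (10, 8 * 26 * 26), &amp;device)?;
    let fc_bias = Tensor::randn(0f32, 1.0, (10,), &amp;device)?;
    let fc = Linear::new(fc_weight, Some(fc_bias));
    let model = ConvModel { conv, fc };
    let dummy_image = Tensor::randn(0f32, 1.0, (1, 784), &amp;device)?;
    let digit = model.forward(&amp;dummy_image)?;
    println!("Digit {digit:?} digit");
    Ok(())
}</code></pre>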
<p>Now that we have the running dummy code we can get to more advanced topics:</p>
<ul>
<li><a href="guide/../guide/cheatsheet.html">For PyTorch users</a></li>
<li><a href="guide/../inference/inference.html">Running existing models</a></li>
<li><a href="guide/../training/training.html">Training models</a></li>
</ul>
<div style="break-before: page; page-break-before: always;"></div><h1 id="pytorch-cheatsheet"><a class="header" href="#pytorch-cheatsheet">Pytorch cheatsheet</a></h1>
<p>Cheatsheet:</p>
<div class="table-wrapper"><table><thead><tr><th></th><th>Using PyTorch</th><th>Using Candle</th></tr></thead><tbody>
<tr><td>Creation</td><td><code>torch.Tensor([[1, 2], [3, 4]])</code></td><td><code>Tensor::new(&amp;[[1f32, 2.], [3., 4.]], &amp;Device::Cpu)?</code></td></tr>
<tr><td>Creation</td><td><code>torch.zeros((2, 2))</code></td><td><code>Tensor::zeros((2, 2), DType::F32, &amp;Device::Cpu)?</code></td></tr>
<tr><td>Indexing</td><td><code>tensor[:, :4]</code></td><td><code>tensor.i((.., ..4))?</code></td></tr>
<tr><td>Operations</td><td><code>tensor.view((2, 2))</code></td><td><code>tensor.reshape((2, 2))?</code></td></tr>
<tr><td>Operations</td><td><code>a.matmul(b)</code></td><td><code>a.matmul(&amp;b)?</code></td></tr>
<tr><td>Arithmetic</td><td><code>a + b</code></td><td><code>&amp;a + &amp;b</code></td></tr>
<tr><td>Device</td><td><code>tensor.to(device="cuda")</code></td><td><code>tensor.to_device(&amp;Device::new_cuda(0)?)?</code></td></tr>
<tr><td>Dtype</td><td><code>tensor.to(dtype=torch.float16)</code></td><td><code>tensor.to_dtype(&amp;DType::F16)?</code></td></tr>
<tr><td>Saving</td><td><code>torch.save({"A": A}, "model.bin")</code></td><td><code>candle::safetensors::save(&amp;HashMap::from([("A", A)]), "model.safetensors")?</code></td></tr>
<tr><td>Loading</td><td><code>weights = torch.load("model.bin")</code></td><td><code>candle::safetensors::load("model.safetensors", &amp;device)</code></td></tr>
</tbody></table>
</div><div style="break-before: page; page-break-before: always;"></div><h1 id="running-a-model"><a class="header" href="#running-a-model">Running a model</a></h1>
<p>In order to run an existing model, you will need to download and use existing weights.
Most models are already available on https://huggingface.co/ in <a href="https://github.com/huggingface/safetensors"><code>safetensors</code></a> format.</p>
<p>Let's get started by running an old model: <code>bert-base-uncased</code>.</p>
<div style="break-before: page; page-break-before: always;"></div><h1 id="using-the-hub"><a class="header" href="#using-the-hub">Using the hub</a></h1>
<p>Install the <a href="https://github.com/huggingface/hf-hub"><code>hf-hub</code></a> crate:</p>
<pre><code class="language-bash">cargo add hf-hub
</code></pre>
<p>Then let's start by downloading the <a href="https://huggingface.co/bert-base-uncased/tree/main">model file</a>.</p>
<pre><pre class="playground"><code class="language-rust"><span class="boring">#![allow(unused)]
</span><span class="boring">fn main() {
</span><span class="boring">extern crate candle_core;
</span><span class="boring">extern crate hf_hub;
</span>use hf_hub::api::sync::Api;
use candle_core::Device;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights = repo.get("model.safetensors").unwrap();
let weights = candle_core::safetensors::load(weights, &amp;Device::Cpu);
<span class="boring">}</span></code></pre></pre>
<p>We now have access to all the <a href="https://huggingface.co/bert-base-uncased?show_tensors=true">tensors</a> within the file.</p>
<p>You can check all the names of the tensors <a href="https://huggingface.co/bert-base-uncased?show_tensors=true">here</a>.</p>
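<p>As a quick, minimal check (this sketch re-uses the same sync API as above), you can list every tensor name together with its shape:</p>
<pre><code class="language-rust ignore">use candle_core::Device;
use hf_hub::api::sync::Api;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_path = repo.get("model.safetensors").unwrap();
let weights = candle_core::safetensors::load(weights_path, &amp;Device::Cpu).unwrap();
// Print each tensor name with its shape.
for (name, tensor) in weights.iter() {
    println!("{name}: {:?}", tensor.shape());
}</code></pre>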
<h2 id="using-async"><a class="header" href="#using-async">Using async</a></h2>
<p><code>hf-hub</code> comes with an async API.</p>
<pre><code class="language-bash">cargo add hf-hub --features tokio
</code></pre>
<pre><code class="language-rust ignore"><span class="boring">This is tested directly in examples crate because it needs external dependencies unfortunately:
</span><span class="boring">See [this](https://github.com/rust-lang/mdBook/issues/706)
</span>use candle::Device;
use hf_hub::api::tokio::Api;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights_filename, &amp;Device::Cpu).unwrap();</code></pre>
<h2 id="using-in-a-real-model"><a class="header" href="#using-in-a-real-model">Using in a real model.</a></h2>
<p>Now that we have our weights, we can use them in our bert architecture:</p>
<pre><pre class="playground"><code class="language-rust"><span class="boring">#![allow(unused)]
</span><span class="boring">fn main() {
</span><span class="boring">extern crate candle_core;
</span><span class="boring">extern crate candle_nn;
</span><span class="boring">extern crate hf_hub;
</span><span class="boring">use hf_hub::api::sync::Api;
</span><span class="boring">
</span><span class="boring">let api = Api::new().unwrap();
</span><span class="boring">let repo = api.model("bert-base-uncased".to_string());
</span><span class="boring">
</span><span class="boring">let weights = repo.get("model.safetensors").unwrap();
</span>use candle_core::{Device, Tensor, DType};
use candle_nn::{Linear, Module};
let weights = candle_core::safetensors::load(weights, &amp;Device::Cpu).unwrap();
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
let linear = Linear::new(weight.clone(), Some(bias.clone()));
let input_ids = Tensor::zeros((3, 768), DType::F32, &amp;Device::Cpu).unwrap();
let output = linear.forward(&amp;input_ids).unwrap();
<span class="boring">}</span></code></pre></pre>
<p>For a full reference, you can check out the full <a href="https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert">bert</a> example.</p>
<h2 id="memory-mapping"><a class="header" href="#memory-mapping">Memory mapping</a></h2>
<p>For more efficient loading, instead of reading the file, you could use <a href="https://docs.rs/memmap2/latest/memmap2/"><code>memmap2</code></a>.</p>
<p><strong>Note</strong>: Be careful with memory mapping: it seems to cause issues on <a href="https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893">Windows and WSL</a>,
and it will definitely be slower on a network-mounted disk because it issues more read calls.</p>
<pre><code class="language-rust ignore">use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&amp;file).unwrap() };
let weights = candle::safetensors::load_buffer(&amp;mmap[..], &amp;Device::Cpu).unwrap();</code></pre>
<p><strong>Note</strong>: This operation is <strong>unsafe</strong>. <a href="https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety">See the safety notice</a>.
In practice, model files should never be modified, and the mmaps should mostly be read-only anyway, so the caveat most likely does not apply, but always keep it in mind.</p>
<h2 id="tensor-parallel-sharding"><a class="header" href="#tensor-parallel-sharding">Tensor Parallel Sharding</a></h2>
<p>When using multiple GPUs with tensor parallelism in order to get good latency, you can load only the part of the tensor you need.</p>
<p>For that, you need to use <a href="https://crates.io/crates/safetensors"><code>safetensors</code></a> directly.</p>
<pre><code class="language-bash">cargo add safetensors
</code></pre>
<pre><code class="language-rust ignore">use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();
let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&amp;file).unwrap() };
// Use safetensors directly
let tensors = SafeTensors::deserialize(&amp;mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();
// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];
if size % world_size != 0 {
panic!("The dimension is not divisible by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;
// Everything is expressed in tensor dimensions;
// byte offsets are handled automatically by safetensors.
let iterator = view.slice(start..stop).unwrap();
tp_shape[dim] = block_size;
// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();
// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec&lt;u8&gt; = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&amp;raw, dtype, &amp;tp_shape, &amp;Device::Cpu).unwrap();</code></pre>
<div style="break-before: page; page-break-before: always;"></div><h1 id="error-management"><a class="header" href="#error-management">Error management</a></h1>
<p>You might have seen in the code base a lot of <code>.unwrap()</code> or <code>?</code>.
If you're unfamiliar with Rust, check out the <a href="https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html">Rust book</a>
for more information.</p>
<p>What's important to know, though, is that if you want to know <em>where</em> a particular operation failed,
you can simply use <code>RUST_BACKTRACE=1</code> to get the location where the model actually failed.</p>
<p>Let's look at some failing code:</p>
<pre><code class="language-rust ignore">let x = Tensor::zeros((1, 784), DType::F32, &amp;device)?;
let y = Tensor::zeros((1, 784), DType::F32, &amp;device)?;
let z = x.matmul(&amp;y)?;</code></pre>
<p>This will print at runtime:</p>
<pre><code class="language-bash">Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
</code></pre>
<p>After adding <code>RUST_BACKTRACE=1</code>:</p>
<pre><code class="language-bash">Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::&lt;impl core::ops::function::FnOnce&lt;A&gt; for &amp;F&gt;::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
</code></pre>
<p>Not super pretty at the moment, but we can see that the error occurred at <code>{ fn: "myapp::main", file: "./src/main.rs", line: 29 }</code>.</p>
<p>Another thing to note is that, since Rust is compiled, it is not necessarily as easy to recover proper stacktraces,
especially in release builds. We're using <a href="https://docs.rs/anyhow/latest/anyhow/"><code>anyhow</code></a> for that.
The library is still young; please <a href="https://github.com/LaurentMazare/candle/issues">report</a> any issues with detecting where an error is coming from.</p>
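<p>As a minimal sketch (assuming <code>anyhow</code> has been added as a dependency), the failing example above can be reproduced from a small binary whose <code>main</code> returns <code>anyhow::Result</code>, so that <code>?</code> bubbles the error up together with its backtrace when <code>RUST_BACKTRACE=1</code> is set:</p>
<pre><code class="language-rust ignore">use candle_core::{DType, Device, Tensor};

fn main() -&gt; anyhow::Result&lt;()&gt; {
    let device = Device::Cpu;
    let x = Tensor::zeros((1, 784), DType::F32, &amp;device)?;
    let y = Tensor::zeros((1, 784), DType::F32, &amp;device)?;
    // These shapes are deliberately incompatible, so this matmul fails at runtime.
    let z = x.matmul(&amp;y)?;
    println!("{z:?}");
    Ok(())
}</code></pre>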
<h2 id="cuda-error-management"><a class="header" href="#cuda-error-management">Cuda error management</a></h2>
<p>When running a model on Cuda, you might get a stacktrace that does not really represent the error.
The reason is that CUDA is asynchronous by nature, so the error might be caught while you were launching totally different kernels.</p>
<p>One way to avoid this is to set <code>CUDA_LAUNCH_BLOCKING=1</code> as an environment variable. This will force every kernel to be launched sequentially.
You might still, however, see the error happening in other kernels, as the faulty kernel might exit without an error while corrupting some pointer, in which case the error will only surface when the corresponding <code>CudaSlice</code> is dropped.</p>
<p>If this occurs, you can use <a href="https://docs.nvidia.com/compute-sanitizer/ComputeSanitizer/index.html"><code>compute-sanitizer</code></a>.
This tool is like <code>valgrind</code> but for CUDA. It will help locate the errors in the kernels.</p>
<div style="break-before: page; page-break-before: always;"></div><h1 id="training"><a class="header" href="#training">Training</a></h1>
<p>Training starts with data. We're going to use the Hugging Face Hub and
start with the hello world dataset of machine learning, MNIST.</p>
<p>Let's start by downloading <code>MNIST</code> from <a href="https://huggingface.co/datasets/mnist">huggingface</a>.</p>
<p>This requires <a href="https://github.com/huggingface/hf-hub"><code>hf-hub</code></a>.</p>
<pre><code class="language-bash">cargo add hf-hub
</code></pre>
<p>This is going to be very hands-on for now. As a minimal sketch (the exact parquet file names below are an assumption, check the dataset's <code>refs/convert/parquet</code> revision on the Hub for the actual paths), we download both splits and open them with the <code>parquet</code> crate:</p>
<pre><code class="language-rust ignore">use hf_hub::{api::sync::Api, Repo, RepoType};
use parquet::file::reader::SerializedFileReader;
use std::fs::File;

let api = Api::new()?;
// Every dataset on the Hub exposes a `refs/convert/parquet` revision with parquet files.
let repo = api.repo(Repo::with_revision(
    "mnist".to_string(),
    RepoType::Dataset,
    "refs/convert/parquet".to_string(),
));
// NOTE: these file names are an assumption, look the actual ones up on the Hub.
let test_parquet_filename = repo.get("mnist/test/0000.parquet")?;
let train_parquet_filename = repo.get("mnist/train/0000.parquet")?;
let test_parquet = SerializedFileReader::new(File::open(test_parquet_filename)?)?;
let train_parquet = SerializedFileReader::new(File::open(train_parquet_filename)?)?;</code></pre>
<p>This uses the standardized <code>parquet</code> files from the <code>refs/convert/parquet</code> branch on every dataset.
Our handles are now <code>parquet::file::serialized_reader::SerializedFileReader</code>.</p>
<p>We can inspect the content of the files with a sketch along these lines (note that iterating like this consumes the reader, so re-open it before the next step):</p>
<pre><code class="language-rust ignore">for row in test_parquet {
    for (idx, (name, field)) in row?.get_column_iter().enumerate() {
        println!("Column id {idx}, name {name}, value {field}");
    }
}</code></pre>
<p>You should see something like:</p>
<pre><code class="language-bash">Column id 1, name label, value 6
Column id 0, name image, value {bytes: [137, ....]
Column id 1, name label, value 8
Column id 0, name image, value {bytes: [137, ....]
</code></pre>
<p>So each row contains 2 columns (image, label), with the image stored as bytes.
Let's put them into a useful struct.</p>
<div style="break-before: page; page-break-before: always;"></div><h1 id="simplified"><a class="header" href="#simplified">Simplified</a></h1>
<h2 id="how-its-works"><a class="header" href="#how-its-works">How its works</a></h2>
<p>This program implements a neural network to predict the winner of the second round of elections based on the results of the first round.</p>
<p>Key points:</p>
<ol>
<li>A multilayer perceptron with two hidden layers is used. The first hidden layer has 4 neurons, the second has 2 neurons.</li>
<li>The input is a vector of 2 numbers - the percentage of votes for the first and second candidates in the first stage.</li>
<li>The output is the number 0 or 1, where 1 means that the first candidate will win in the second stage and 0 means that they will lose.</li>
<li>For training, samples with real data on the results of the first and second stages of different elections are used.</li>
<li>The model is trained by backpropagation using gradient descent and the cross-entropy loss function.</li>
<li>Model parameters (weights of neurons) are initialized randomly, then optimized during training.</li>
<li>After training, the model is tested on a deferred sample to evaluate the accuracy.</li>
<li>If the accuracy on the test set is below 100%, the model is considered underfit and the learning process is repeated.</li>
</ol>
<p>Thus, this neural network learns to find hidden relationships between the results of the first and second rounds of voting in order to make predictions for new data.</p>
<pre><code class="language-rust ignore">const VOTE_DIM: usize = 2;
const RESULTS: usize = 1;
const EPOCHS: usize = 10;
const LAYER1_OUT_SIZE: usize = 4;
const LAYER2_OUT_SIZE: usize = 2;
const LEARNING_RATE: f64 = 0.05;
#[derive(Clone)]
pub struct Dataset {
pub train_votes: Tensor,
pub train_results: Tensor,
pub test_votes: Tensor,
pub test_results: Tensor,
}
struct MultiLevelPerceptron {
ln1: Linear,
ln2: Linear,
ln3: Linear,
}
impl MultiLevelPerceptron {
fn new(vs: VarBuilder) -&gt; Result&lt;Self&gt; {
let ln1 = candle_nn::linear(VOTE_DIM, LAYER1_OUT_SIZE, vs.pp("ln1"))?;
let ln2 = candle_nn::linear(LAYER1_OUT_SIZE, LAYER2_OUT_SIZE, vs.pp("ln2"))?;
let ln3 = candle_nn::linear(LAYER2_OUT_SIZE, RESULTS + 1, vs.pp("ln3"))?;
Ok(Self { ln1, ln2, ln3 })
}
fn forward(&amp;self, xs: &amp;Tensor) -&gt; Result&lt;Tensor&gt; {
let xs = self.ln1.forward(xs)?;
let xs = xs.relu()?;
let xs = self.ln2.forward(&amp;xs)?;
let xs = xs.relu()?;
self.ln3.forward(&amp;xs)
}
}
</code></pre>
<pre><code class="language-rust ignore">fn train(m: Dataset, dev: &amp;Device) -&gt; anyhow::Result&lt;MultiLevelPerceptron&gt; {
let train_results = m.train_results.to_device(dev)?;
let train_votes = m.train_votes.to_device(dev)?;
let varmap = VarMap::new();
let vs = VarBuilder::from_varmap(&amp;varmap, DType::F32, dev);
let model = MultiLevelPerceptron::new(vs.clone())?;
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), LEARNING_RATE)?;
let test_votes = m.test_votes.to_device(dev)?;
let test_results = m.test_results.to_device(dev)?;
let mut final_accuracy: f32 = 0.0;
for epoch in 1..EPOCHS + 1 {
let logits = model.forward(&amp;train_votes)?;
let log_sm = ops::log_softmax(&amp;logits, D::Minus1)?;
let loss = loss::nll(&amp;log_sm, &amp;train_results)?;
sgd.backward_step(&amp;loss)?;
let test_logits = model.forward(&amp;test_votes)?;
let sum_ok = test_logits
.argmax(D::Minus1)?
.eq(&amp;test_results)?
.to_dtype(DType::F32)?
.sum_all()?
.to_scalar::&lt;f32&gt;()?;
let test_accuracy = sum_ok / test_results.dims1()? as f32;
final_accuracy = 100. * test_accuracy;
println!("Epoch: {epoch:3} Train loss: {:8.5} Test accuracy: {:5.2}%",
loss.to_scalar::&lt;f32&gt;()?,
final_accuracy
);
if final_accuracy == 100.0 {
break;
}
}
if final_accuracy &lt; 100.0 {
Err(anyhow::Error::msg("The model is not trained well enough."))
} else {
Ok(model)
}
}</code></pre>
<pre><code class="language-rust ignore">#[tokio::test]
async fn simplified() -&gt; anyhow::Result&lt;()&gt; {
let dev = Device::cuda_if_available(0)?;
let train_votes_vec: Vec&lt;u32&gt; = vec![
15, 10,
10, 15,
5, 12,
30, 20,
16, 12,
13, 25,
6, 14,
31, 21,
];
let train_votes_tensor = Tensor::from_vec(train_votes_vec.clone(), (train_votes_vec.len() / VOTE_DIM, VOTE_DIM), &amp;dev)?.to_dtype(DType::F32)?;
let train_results_vec: Vec&lt;u32&gt; = vec![
1,
0,
0,
1,
1,
0,
0,
1,
];
let train_results_tensor = Tensor::from_vec(train_results_vec, train_votes_vec.len() / VOTE_DIM, &amp;dev)?;
let test_votes_vec: Vec&lt;u32&gt; = vec![
13, 9,
8, 14,
3, 10,
];
let test_votes_tensor = Tensor::from_vec(test_votes_vec.clone(), (test_votes_vec.len() / VOTE_DIM, VOTE_DIM), &amp;dev)?.to_dtype(DType::F32)?;
let test_results_vec: Vec&lt;u32&gt; = vec![
1,
0,
0,
];
let test_results_tensor = Tensor::from_vec(test_results_vec.clone(), test_results_vec.len(), &amp;dev)?;
let m = Dataset {
train_votes: train_votes_tensor,
train_results: train_results_tensor,
test_votes: test_votes_tensor,
test_results: test_results_tensor,
};
let trained_model: MultiLevelPerceptron;
loop {
println!("Trying to train neural network.");
match train(m.clone(), &amp;dev) {
Ok(model) =&gt; {
trained_model = model;
break;
},
Err(e) =&gt; {
println!("Error: {}", e);
continue;
}
}
}
let real_world_votes: Vec&lt;u32&gt; = vec![
13, 22,
];
let tensor_test_votes = Tensor::from_vec(real_world_votes.clone(), (1, VOTE_DIM), &amp;dev)?.to_dtype(DType::F32)?;
let final_result = trained_model.forward(&amp;tensor_test_votes)?;
let result = final_result
.argmax(D::Minus1)?
.to_dtype(DType::F32)?
.get(0).map(|x| x.to_scalar::&lt;f32&gt;())??;
println!("real_life_votes: {:?}", real_world_votes);
println!("neural_network_prediction_result: {:?}", result);
Ok(())
}</code></pre>
<h2 id="example-output"><a class="header" href="#example-output">Example output</a></h2>
<pre><code class="language-bash">Trying to train neural network.
Epoch: 1 Train loss: 4.42555 Test accuracy: 0.00%
Epoch: 2 Train loss: 0.84677 Test accuracy: 33.33%
Epoch: 3 Train loss: 2.54335 Test accuracy: 33.33%
Epoch: 4 Train loss: 0.37806 Test accuracy: 33.33%
Epoch: 5 Train loss: 0.36647 Test accuracy: 100.00%
real_life_votes: [13, 22]
neural_network_prediction_result: 0.0
</code></pre>
<div style="break-before: page; page-break-before: always;"></div><h1 id="mnist"><a class="header" href="#mnist">MNIST</a></h1>
<p>Now that we have downloaded the MNIST parquet files, let's put them into a simple struct.</p>
<pre><code class="language-rust ignore">
let test_samples = 10_000;
let mut test_buffer_images: Vec&lt;u8&gt; = Vec::with_capacity(test_samples * 784);
let mut test_buffer_labels: Vec&lt;u8&gt; = Vec::with_capacity(test_samples);
for row in test_parquet {
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
test_buffer_images.extend(image.to_luma8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
test_buffer_labels.push(*label as u8);
}
}
}
let test_images = (Tensor::from_vec(test_buffer_images, (test_samples, 784), &amp;Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let test_labels = Tensor::from_vec(test_buffer_labels, (test_samples, ), &amp;Device::Cpu)?;
let train_samples = 60_000;
let mut train_buffer_images: Vec&lt;u8&gt; = Vec::with_capacity(train_samples * 784);
let mut train_buffer_labels: Vec&lt;u8&gt; = Vec::with_capacity(train_samples);
for row in train_parquet {
for (_name, field) in row?.get_column_iter() {
if let parquet::record::Field::Group(subrow) = field {
for (_name, field) in subrow.get_column_iter() {
if let parquet::record::Field::Bytes(value) = field {
let image = image::load_from_memory(value.data()).unwrap();
train_buffer_images.extend(image.to_luma8().as_raw());
}
}
} else if let parquet::record::Field::Long(label) = field {
train_buffer_labels.push(*label as u8);
}
}
}
let train_images = (Tensor::from_vec(train_buffer_images, (train_samples, 784), &amp;Device::Cpu)?.to_dtype(DType::F32)? / 255.)?;
let train_labels = Tensor::from_vec(train_buffer_labels, (train_samples, ), &amp;Device::Cpu)?;
let mnist = candle_datasets::vision::Dataset {
train_images,
train_labels,
test_images,
test_labels,
labels: 10,
};
</code></pre>
<p>Parsing the files and putting everything into single tensors requires the dataset to fit entirely in memory.
It is quite rudimentary, but simple enough for a small dataset like MNIST.</p>
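<p>As a final sketch (re-using the <code>mnist</code> variable defined above), here is one way the resulting dataset could be consumed, fetching the first training image and its label:</p>
<pre><code class="language-rust ignore">use candle_core::IndexOp;

// Grab the first training example: a 784-long pixel vector and its label.
let first_image = mnist.train_images.i((0, ..))?;
let first_label = mnist.train_labels.i(0)?;
println!("label: {:?}, image shape: {:?}", first_label, first_image.shape());</code></pre>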
</main>
<nav class="nav-wrapper" aria-label="Page navigation">
<!-- Mobile navigation buttons -->
<div style="clear: both"></div>
</nav>
</div>
</div>
<nav class="nav-wide-wrapper" aria-label="Page navigation">
</nav>
</div>
<script>
window.playground_copyable = true;
</script>
<script src="elasticlunr.min.js"></script>
<script src="mark.min.js"></script>
<script src="searcher.js"></script>
<script src="clipboard.min.js"></script>
<script src="highlight.js"></script>
<script src="book.js"></script>
<!-- Custom JS scripts -->
<script>
window.addEventListener('load', function() {
window.setTimeout(window.print, 100);
});
</script>
</div>
</body>
</html>