candle/README.md
2023-08-23 08:32:59 +00:00

candle

Candle is a minimalist ML framework for Rust with a focus on performance (including GPU support) and ease of use. Try our online demos: whisper, LLaMA2, yolo.

Installation

  • With Cuda support:
  1. To install candle with Cuda support, first make sure that Cuda is correctly installed.
  • nvcc --version should print information about your Cuda compiler driver.
  • nvidia-smi --query-gpu=compute_cap --format=csv should print your GPU's compute capability, e.g. something like:
compute_cap
8.9

If any of the above commands errors out, please make sure to update your CUDA version.

  2. Create a new app and add candle-core with Cuda support:
cargo new myapp
cd myapp

Next make sure to add the candle-core crate with the cuda feature:

cargo add --git https://github.com/huggingface/candle.git candle-core --features "cuda"

Finally, run cargo build to make sure everything can be correctly built.

cargo build

Now you can run the example as shown in the next section!

  • Without Cuda support:

Create a new app and add candle-core as follows:

cargo new myapp
cd myapp
cargo add --git https://github.com/huggingface/candle.git candle-core

Finally, run cargo build to make sure everything can be correctly built.

cargo build
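
The compute capability check from the Cuda steps above can also be scripted. The following is a minimal sketch in plain Rust (standard library only, no candle). The helper names parse_compute_caps and query_compute_caps are hypothetical and introduced here for illustration; the sketch assumes nvidia-smi prints a compute_cap header followed by one major.minor value per GPU, as in the sample output above.

```rust
use std::process::Command;

/// Parse the CSV printed by
/// `nvidia-smi --query-gpu=compute_cap --format=csv`
/// into one (major, minor) pair per GPU, skipping the `compute_cap` header.
fn parse_compute_caps(csv: &str) -> Vec<(u32, u32)> {
    csv.lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && *line != "compute_cap")
        .filter_map(|line| {
            let (major, minor) = line.split_once('.')?;
            Some((major.parse::<u32>().ok()?, minor.parse::<u32>().ok()?))
        })
        .collect()
}

/// Run `nvidia-smi` and return the parsed capabilities, or `None` if the
/// command is missing or exits with an error.
fn query_compute_caps() -> Option<Vec<(u32, u32)>> {
    let output = Command::new("nvidia-smi")
        .args(["--query-gpu=compute_cap", "--format=csv"])
        .output()
        .ok()?;
    if !output.status.success() {
        return None;
    }
    Some(parse_compute_caps(&String::from_utf8_lossy(&output.stdout)))
}
```

Calling query_compute_caps() from a small binary returns None when nvidia-smi cannot be run, which is a hint that the Cuda setup needs attention before building with the cuda feature.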

Get started

Having installed candle-core as described in Installation, we can now run a simple matrix multiplication.

First, let's add the anyhow package to our app.

cd myapp
cargo add anyhow

Next, write the following to your myapp/src/main.rs file:

use anyhow::Result;
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
    let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;

    let c = a.matmul(&b)?;
    println!("{c}");
    Ok(())
}

cargo run should display a tensor of shape Tensor[[2, 4], f32].
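
To see why the result has shape [2, 4], here is a small sketch of the same multiplication in plain Rust, without candle. It is only meant to illustrate the shape rule (an m x k matrix times a k x n matrix gives an m x n matrix); the matmul and shape helpers below are hypothetical and are not candle's implementation.

```rust
/// Multiply an m x k matrix by a k x n matrix (row-major nested Vecs).
/// Returns `None` when the inner dimensions do not match.
fn matmul(a: &[Vec<f32>], b: &[Vec<f32>]) -> Option<Vec<Vec<f32>>> {
    let k = b.len();
    let n = b.first().map_or(0, |row| row.len());
    // Every row of `a` needs k entries and every row of `b` needs n entries.
    if a.iter().any(|row| row.len() != k) || b.iter().any(|row| row.len() != n) {
        return None;
    }
    let c: Vec<Vec<f32>> = a
        .iter()
        .map(|row| {
            (0..n)
                .map(|j| (0..k).map(|p| row[p] * b[p][j]).sum::<f32>())
                .collect::<Vec<f32>>()
        })
        .collect();
    Some(c)
}

/// Shape of a rectangular nested Vec as (rows, columns).
fn shape(m: &[Vec<f32>]) -> (usize, usize) {
    (m.len(), m.first().map_or(0, |row| row.len()))
}
```

With a 2 x 3 input and a 3 x 4 input, shape reports (2, 4) for the product, matching the tensor shape printed by the candle example.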

Having installed candle with Cuda support, you can create the tensors on GPU instead as follows:

- let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
- let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
+ let a = Tensor::randn(0f32, 1., (2, 3), &Device::new_cuda(0)?)?;
+ let b = Tensor::randn(0f32, 1., (3, 4), &Device::new_cuda(0)?)?;

For more advanced examples, please have a look at the following sections.

Check out our examples

The following examples are available:

  • Whisper: speech recognition model.
  • LLaMA and LLaMA-v2: general LLM.
  • Falcon: general LLM.
  • Bert: useful for sentence embeddings.
  • StarCoder: LLM specialized to code generation.
  • Stable Diffusion: text to image generative model.
  • DINOv2: computer vision model trained using self-supervision (can be used for imagenet classification, depth evaluation, segmentation).
  • Quantized LLaMA: quantized version of the LLaMA model using the same quantization techniques as llama.cpp.
  • yolo-v3 and yolo-v8: object detection models.

Run them using the following commands:

cargo run --example whisper --release
cargo run --example llama --release
cargo run --example falcon --release
cargo run --example bert --release
cargo run --example bigcode --release
cargo run --example stable-diffusion --release -- --prompt "a rusty robot holding a fire torch"
cargo run --example dinov2 --release -- --image path/to/myinput.jpg
cargo run --example quantized --release
cargo run --example yolo-v3 --release -- myimage.jpg
cargo run --example yolo-v8 --release -- myimage.jpg

To use CUDA, add --features cuda to the example command line. If you have cuDNN installed, use --features cudnn for even more speedups.

There are also some wasm examples for whisper and llama2.c. You can either build them with trunk or try them online: whisper, llama2.

For LLaMA2, run the following command to retrieve the weight files and start a test server:

cd candle-wasm-examples/llama2-c
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/model.bin
wget https://huggingface.co/spaces/lmz/candle-llama2/resolve/main/tokenizer.json
trunk serve --release --public-url /candle-llama2/ --port 8081

And then head over to http://localhost:8081/candle-llama2.

Features

  • Simple syntax, looks and feels like PyTorch.
  • Backends.
    • Optimized CPU backend with optional MKL support for x86 and Accelerate for Macs.
    • CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
    • WASM support, run your models in a browser.
  • Included models.
    • LLMs: LLaMA v1 and v2, Falcon, StarCoder.
    • Whisper (multi-lingual support).
    • Stable Diffusion.
    • Computer Vision: DINOv2.
  • File formats: load models from safetensors, npz, ggml, or PyTorch files.
  • Serverless (on CPU), small and fast deployments.
  • Quantization support using the llama.cpp quantized types.

How to use

Cheatsheet:

|            | Using PyTorch                    | Using Candle                                                                       |
|------------|----------------------------------|------------------------------------------------------------------------------------|
| Creation   | torch.Tensor([[1, 2], [3, 4]])   | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?                                |
| Creation   | torch.zeros((2, 2))              | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)?                                   |
| Indexing   | tensor[:, :4]                    | tensor.i((.., ..4))?                                                               |
| Operations | tensor.view((2, 2))              | tensor.reshape((2, 2))?                                                            |
| Operations | a.matmul(b)                      | a.matmul(&b)?                                                                      |
| Arithmetic | a + b                            | &a + &b                                                                            |
| Device     | tensor.to(device="cuda")         | tensor.to_device(&Device::new_cuda(0)?)?                                           |
| Dtype      | tensor.to(dtype=torch.float16)   | tensor.to_dtype(DType::F16)?                                                       |
| Saving     | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")?       |
| Loading    | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device)                           |

Structure

FAQ

Why should I use Candle?

Candle's core goal is to make serverless inference possible. Full machine learning frameworks like PyTorch are very large, which makes creating instances on a cluster slow. Candle allows deployment of lightweight binaries.

Secondly, Candle lets you remove Python from production workloads. Python overhead can seriously hurt performance, and the GIL is a notorious source of headaches.

Finally, Rust is cool! A lot of the HF ecosystem already has Rust crates, like safetensors and tokenizers.

Other ML frameworks

  • dfdx is a formidable crate, with shapes being included in types. This prevents a lot of headaches by getting the compiler to complain about shape mismatches right off the bat. However, we found that some features still require nightly, and writing code can be a bit daunting for non-Rust experts.

    We're leveraging and contributing to other core crates for the runtime so hopefully both crates can benefit from each other.

  • burn is a general crate that can leverage multiple backends so you can choose the best engine for your workload.

  • tch-rs provides bindings to the torch library in Rust. They are extremely versatile, but they bring the entire torch library into the runtime. The main contributor of tch-rs is also involved in the development of candle.

Common Errors

Missing symbols when compiling with the mkl feature.

If you get some missing symbols when compiling binaries/tests using the mkl or accelerate features, e.g. for mkl you get:

  = note: /usr/bin/ld: (....o): in function `blas::sgemm':
          .../blas-0.22.0/src/lib.rs:1944: undefined reference to `sgemm_' collect2: error: ld returned 1 exit status

  = note: some `extern` functions couldn't be found; some native libraries may need to be installed or have their path specified
  = note: use the `-l` flag to specify native libraries to link
  = note: use the `cargo:rustc-link-lib` directive to specify the native libraries to link with Cargo

or for accelerate:

Undefined symbols for architecture arm64:
            "_dgemm_", referenced from:
                candle_core::accelerate::dgemm::h1b71a038552bcabe in libcandle_core...
            "_sgemm_", referenced from:
                candle_core::accelerate::sgemm::h2cf21c592cba3c47 in libcandle_core...
          ld: symbol(s) not found for architecture arm64

This is likely due to a missing linker flag that was needed to enable the mkl library. You can try adding the following for mkl at the top of your binary:

extern crate intel_mkl_src;

or for accelerate:

extern crate accelerate_src;
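
If the same binary should build both with and without these backends, the declarations can be gated on Cargo features. The following is a minimal sketch, assuming your own Cargo.toml declares features named mkl and accelerate that pull in the corresponding crates (those feature names, and the blas_backend helper, are assumptions introduced for illustration).

```rust
// Place these at the top of src/main.rs. The `mkl` and `accelerate` feature
// names are assumptions: they must be declared in your own Cargo.toml and
// enable the intel_mkl_src / accelerate_src dependencies respectively.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

/// Report which BLAS backend this binary was compiled to link against.
fn blas_backend() -> &'static str {
    if cfg!(feature = "mkl") {
        "mkl"
    } else if cfg!(feature = "accelerate") {
        "accelerate"
    } else {
        "default"
    }
}
```

When neither feature is enabled, both extern crate lines are compiled out and blas_backend() returns "default", so the same source file works on machines without MKL or Accelerate.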

Cannot run the LLaMA examples: access to source requires login credentials

Error: request error: https://huggingface.co/meta-llama/Llama-2-7b-hf/resolve/main/tokenizer.json: status code 401

This is likely because you have not been granted access to the LLaMA-v2 model. To fix this, register on the huggingface-hub, accept the LLaMA-v2 model conditions, and set up your authentication token. See issue #350 for more details.

Tracking down errors

You can set RUST_BACKTRACE=1 to be provided with backtraces when a candle error is generated.
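
For context, the Rust standard library decides whether to capture a backtrace from two environment variables: RUST_LIB_BACKTRACE is consulted first, then RUST_BACKTRACE, and any value other than 0 enables capture. The sketch below mirrors that documented rule in plain Rust. It assumes, without confirmation from the text above, that candle's errors rely on the standard library's capture mechanism; the helper names are hypothetical.

```rust
use std::backtrace::{Backtrace, BacktraceStatus};

/// Mirror of the standard library's documented rule: `RUST_LIB_BACKTRACE`
/// takes precedence over `RUST_BACKTRACE`, and any value other than "0"
/// enables capture.
fn backtrace_requested(lib_var: Option<&str>, var: Option<&str>) -> bool {
    match lib_var.or(var) {
        Some(value) => value != "0",
        None => false,
    }
}

/// Capture a backtrace the way an error type could, returning it only when
/// capture is actually enabled for this process.
fn capture_if_enabled() -> Option<Backtrace> {
    let bt = Backtrace::capture();
    match bt.status() {
        BacktraceStatus::Captured => Some(bt),
        _ => None,
    }
}
```

Under this rule, running the program with RUST_BACKTRACE=1 set in the environment makes capture_if_enabled() return a backtrace, while leaving both variables unset returns None.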