mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
3rd phase.
This commit is contained in:
@ -12,11 +12,11 @@
|
|||||||
|
|
||||||
- [Running a model](inference/README.md)
|
- [Running a model](inference/README.md)
|
||||||
- [Using the hub](inference/hub.md)
|
- [Using the hub](inference/hub.md)
|
||||||
- [Serialization](inference/serialization.md)
|
|
||||||
- [Advanced Cuda usage](inference/cuda/README.md)
|
|
||||||
- [Writing a custom kernel](inference/cuda/writing.md)
|
|
||||||
- [Porting a custom kernel](inference/cuda/porting.md)
|
|
||||||
- [Error management](error_manage.md)
|
- [Error management](error_manage.md)
|
||||||
|
- [Advanced Cuda usage](cuda/README.md)
|
||||||
|
- [Writing a custom kernel](cuda/writing.md)
|
||||||
|
- [Porting a custom kernel](cuda/porting.md)
|
||||||
|
- [Using MKL](advanced/mkl.md)
|
||||||
- [Creating apps](apps/README.md)
|
- [Creating apps](apps/README.md)
|
||||||
- [Creating a WASM app](apps/wasm.md)
|
- [Creating a WASM app](apps/wasm.md)
|
||||||
- [Creating a REST api webserver](apps/rest.md)
|
- [Creating a REST api webserver](apps/rest.md)
|
||||||
@ -24,4 +24,4 @@
|
|||||||
- [Training](training/README.md)
|
- [Training](training/README.md)
|
||||||
- [MNIST](training/mnist.md)
|
- [MNIST](training/mnist.md)
|
||||||
- [Fine-tuning](training/finetuning.md)
|
- [Fine-tuning](training/finetuning.md)
|
||||||
- [Using MKL](advanced/mkl.md)
|
- [Serialization](training/serialization.md)
|
||||||
|
1
candle-book/src/cuda/README.md
Normal file
1
candle-book/src/cuda/README.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Advanced Cuda usage
|
1
candle-book/src/cuda/porting.md
Normal file
1
candle-book/src/cuda/porting.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Porting a custom kernel
|
1
candle-book/src/cuda/writing.md
Normal file
1
candle-book/src/cuda/writing.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Writing a custom kernel
|
@ -1 +1,39 @@
|
|||||||
# Error management
|
# Error management
|
||||||
|
|
||||||
|
You might have seen in the code base a lot of `.unwrap()` or `?`.
|
||||||
|
If you're unfamiliar with Rust check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
|
||||||
|
for more information.
|
||||||
|
|
||||||
|
What's important to know, though, is that if you want to know *where* a particular operation failed
|
||||||
|
You can simply use `RUST_BACKTRACE=1` to get the location of where the model actually failed.
|
||||||
|
|
||||||
|
Let's see on failing code:
|
||||||
|
|
||||||
|
```rust,ignore
|
||||||
|
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
|
||||||
|
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
|
||||||
|
let z = x.matmul(&y)?;
|
||||||
|
```
|
||||||
|
|
||||||
|
Will print at runtime:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
After adding `RUST_BACKTRACE=1`:
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: 
"/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
|
||||||
|
```
|
||||||
|
|
||||||
|
Not super pretty at the moment, but we can see the error occurred on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
|
||||||
|
|
||||||
|
|
||||||
|
Another thing to note is that, since Rust is compiled, it is not necessarily as easy to recover proper stack traces
|
||||||
|
especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
|
||||||
|
The library is still young, please [report](https://github.com/LaurentMazare/candle/issues) any issues detecting where an error is coming from.
|
||||||
|
|
||||||
|
|
||||||
|
@ -1 +1,7 @@
|
|||||||
# Running a model
|
# Running a model
|
||||||
|
|
||||||
|
|
||||||
|
In order to run an existing model, you will need to download and use existing weights.
|
||||||
|
Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
|
||||||
|
|
||||||
|
Let's get started by running an old model: `bert-base-uncased`.
|
||||||
|
@ -1 +1,80 @@
|
|||||||
# Using the hub
|
# Using the hub
|
||||||
|
|
||||||
|
Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo add hf-hub
|
||||||
|
```
|
||||||
|
|
||||||
|
Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).
|
||||||
|
|
||||||
|
|
||||||
|
```rust
|
||||||
|
# extern crate candle;
|
||||||
|
# extern crate hf_hub;
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
|
use candle::Device;
|
||||||
|
|
||||||
|
let api = Api::new().unwrap();
|
||||||
|
let repo = api.model("bert-base-uncased".to_string());
|
||||||
|
|
||||||
|
let weights = repo.get("model.safetensors").unwrap();
|
||||||
|
|
||||||
|
let weights = candle::safetensors::load(weights, &Device::Cpu);
|
||||||
|
```
|
||||||
|
|
||||||
|
We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
|
||||||
|
|
||||||
|
|
||||||
|
## Using async
|
||||||
|
|
||||||
|
`hf-hub` comes with an async API.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo add hf-hub --features tokio
|
||||||
|
```
|
||||||
|
|
||||||
|
```rust,ignore
|
||||||
|
# extern crate candle;
|
||||||
|
# extern crate hf_hub;
|
||||||
|
use hf_hub::api::tokio::Api;
|
||||||
|
use candle::Device;
|
||||||
|
|
||||||
|
let api = Api::new().unwrap();
|
||||||
|
let repo = api.model("bert-base-uncased".to_string());
|
||||||
|
|
||||||
|
let weights = repo.get("model.safetensors").await.unwrap();
|
||||||
|
|
||||||
|
let weights = candle::safetensors::load(weights, &Device::Cpu);
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## Using in a real model
|
||||||
|
|
||||||
|
Now that we have our weights, we can use them in our bert architecture:
|
||||||
|
|
||||||
|
```rust
|
||||||
|
# extern crate candle;
|
||||||
|
# extern crate candle_nn;
|
||||||
|
# extern crate hf_hub;
|
||||||
|
# use hf_hub::api::sync::Api;
|
||||||
|
# use candle::Device;
|
||||||
|
#
|
||||||
|
# let api = Api::new().unwrap();
|
||||||
|
# let repo = api.model("bert-base-uncased".to_string());
|
||||||
|
#
|
||||||
|
# let weights = repo.get("model.safetensors").unwrap();
|
||||||
|
use candle_nn::Linear;
|
||||||
|
|
||||||
|
let weights = candle::safetensors::load(weights, &Device::Cpu);
|
||||||
|
|
||||||
|
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
|
||||||
|
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
|
||||||
|
|
||||||
|
let linear = Linear::new(weight, Some(bias));
|
||||||
|
|
||||||
|
let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap();
|
||||||
|
let output = linear.forward(&input_ids);
|
||||||
|
```
|
||||||
|
|
||||||
|
For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.
|
||||||
|
@ -1 +1,3 @@
|
|||||||
# Serialization
|
# Serialization
|
||||||
|
|
||||||
|
Once you have a r
|
||||||
|
1
candle-book/src/training/serialization.md
Normal file
1
candle-book/src/training/serialization.md
Normal file
@ -0,0 +1 @@
|
|||||||
|
# Serialization
|
Reference in New Issue
Block a user