3rd phase.

This commit is contained in:
Nicolas Patry
2023-07-28 12:07:39 +02:00
parent 52414ba5c8
commit 82464166e4
9 changed files with 134 additions and 5 deletions

View File

@ -12,11 +12,11 @@
- [Running a model](inference/README.md) - [Running a model](inference/README.md)
- [Using the hub](inference/hub.md) - [Using the hub](inference/hub.md)
- [Serialization](inference/serialization.md)
- [Advanced Cuda usage](inference/cuda/README.md)
- [Writing a custom kernel](inference/cuda/writing.md)
- [Porting a custom kernel](inference/cuda/porting.md)
- [Error management](error_manage.md) - [Error management](error_manage.md)
- [Advanced Cuda usage](cuda/README.md)
- [Writing a custom kernel](cuda/writing.md)
- [Porting a custom kernel](cuda/porting.md)
- [Using MKL](advanced/mkl.md)
- [Creating apps](apps/README.md) - [Creating apps](apps/README.md)
- [Creating a WASM app](apps/wasm.md) - [Creating a WASM app](apps/wasm.md)
- [Creating a REST api webserver](apps/rest.md) - [Creating a REST api webserver](apps/rest.md)
@ -24,4 +24,4 @@
- [Training](training/README.md) - [Training](training/README.md)
- [MNIST](training/mnist.md) - [MNIST](training/mnist.md)
- [Fine-tuning](training/finetuning.md) - [Fine-tuning](training/finetuning.md)
- [Using MKL](advanced/mkl.md) - [Serialization](training/serialization.md)

View File

@ -0,0 +1 @@
# Advanced Cuda usage

View File

@ -0,0 +1 @@
# Porting a custom kernel

View File

@ -0,0 +1 @@
# Writing a custom kernel

View File

@ -1 +1,39 @@
# Error management # Error management
You might have seen in the code base a lot of `.unwrap()` or `?`.
If you're unfamiliar with Rust check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
for more information.
What's important to know though, is that if you want to know *where* a particular operation failed
You can simply use `RUST_BACKTRACE=1` to get the location of where the model actually failed.
Let's see on failing code:
```rust,ignore
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
let z = x.matmul(&y)?;
```
Will print at runtime:
```bash
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
```
After adding `RUST_BACKTRACE=1`:
```bash
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```
Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`
Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
The library is still young, please [report](https://github.com/LaurentMazare/candle/issues) any issues detecting where an error is coming from.

View File

@ -1 +1,7 @@
# Running a model # Running a model
In order to run an existing model, you will need to download and use existing weights.
Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.
Let's get started by running an old model : `bert-base-uncased`.

View File

@ -1 +1,80 @@
# Using the hub # Using the hub
Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:
```bash
cargo add hf-hub
```
Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).
```rust
# extern crate candle;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle::Device;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights = repo.get("model.safetensors").unwrap();
let weights = candle::safetensors::load(weights, &Device::Cpu);
```
We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.
## Using async
`hf-hub` comes with an async API.
```bash
cargo add hf-hub --features tokio
```
```rust,ignore
# extern crate candle;
# extern crate hf_hub;
use hf_hub::api::tokio::Api;
use candle::Device;
let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights = repo.get("model.safetensors").await.unwrap();
let weights = candle::safetensors::load(weights, &Device::Cpu);
```
## Using in a real model.
Now that we have our weights, we can use them in our bert architecture:
```rust
# extern crate candle;
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
# use candle::Device;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
#
# let weights = repo.get("model.safetensors").unwrap();
use candle_nn::Linear;
let weights = candle::safetensors::load(weights, &Device::Cpu);
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
let linear = Linear::new(weight, Some(bias));
let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids);
```
For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.

View File

@ -1 +1,3 @@
# Serialization # Serialization
Once you have a r

View File

@ -0,0 +1 @@
# Serialization