mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 20:38:06 +00:00
Fix the tests for mkl. (#437)
This commit is contained in:
19
README.md
19
README.md
@ -64,14 +64,17 @@ And then head over to
|
|||||||
## Features
|
## Features
|
||||||
|
|
||||||
- Simple syntax, looks and feels like PyTorch.
|
- Simple syntax, looks and feels like PyTorch.
|
||||||
- CPU and Cuda backends, m1, f16, bf16.
|
- Model training.
|
||||||
- Serverless (on CPU), small and fast deployments
|
- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
|
||||||
- WASM support, run your models in a browser.
|
- Backends.
|
||||||
- Model training.
|
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
|
||||||
- Distributed computing using NCCL.
|
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
|
||||||
- Model support out of the box: Llama, Whisper, Falcon, StarCoder...
|
- WASM support, run your models in a browser.
|
||||||
- Embed user-defined ops/kernels, such as [flash-attention
|
- Model support out of the box.
|
||||||
v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
|
- LLMs: Llama v1 and v2, Falcon, StarCoder.
|
||||||
|
- Whisper.
|
||||||
|
- Stable Diffusion.
|
||||||
|
- Serverless (on CPU), small and fast deployments.
|
||||||
|
|
||||||
<!--- ANCHOR_END: features --->
|
<!--- ANCHOR_END: features --->
|
||||||
|
|
||||||
|
@ -1,4 +1,6 @@
|
|||||||
use candle::{Device, Result, Tensor};
|
use candle::{Device, Result, Tensor};
|
||||||
|
mod test_utils;
|
||||||
|
use test_utils::to_vec0_round;
|
||||||
|
|
||||||
/* Equivalent python code:
|
/* Equivalent python code:
|
||||||
import torch
|
import torch
|
||||||
@ -27,8 +29,8 @@ fn nll_and_cross_entropy() -> Result<()> {
|
|||||||
|
|
||||||
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
||||||
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
||||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
||||||
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
|
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
|
||||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user