mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00
Fix the tests for mkl. (#437)
This commit is contained in:
17
README.md
17
README.md
@@ -64,14 +64,17 @@ And then head over to
|
||||
## Features
|
||||
|
||||
- Simple syntax, looks and feels like PyTorch.
|
||||
- CPU and Cuda backends, m1, f16, bf16.
|
||||
- Serverless (on CPU), small and fast deployments
|
||||
- WASM support, run your models in a browser.
|
||||
- Model training.
|
||||
- Distributed computing using NCCL.
|
||||
- Model support out of the box: Llama, Whisper, Falcon, StarCoder...
|
||||
- Embed user-defined ops/kernels, such as [flash-attention
|
||||
v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
|
||||
- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
|
||||
- Backends.
|
||||
- Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
|
||||
- CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
|
||||
- WASM support, run your models in a browser.
|
||||
- Model support out of the box.
|
||||
- LLMs: Llama v1 and v2, Falcon, StarCoder.
|
||||
- Whisper.
|
||||
- Stable Diffusion.
|
||||
- Serverless (on CPU), small and fast deployments.
|
||||
|
||||
<!--- ANCHOR_END: features --->
|
||||
|
||||
|
@@ -1,4 +1,6 @@
|
||||
use candle::{Device, Result, Tensor};
|
||||
mod test_utils;
|
||||
use test_utils::to_vec0_round;
|
||||
|
||||
/* Equivalent python code:
|
||||
import torch
|
||||
@@ -27,8 +29,8 @@ fn nll_and_cross_entropy() -> Result<()> {
|
||||
|
||||
let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
|
||||
let loss = candle_nn::loss::nll(&log_softmax, &target)?;
|
||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
||||
let loss = candle_nn::loss::cross_entropy(&input, &target)?;
|
||||
assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
|
||||
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user