Fix the tests for mkl. (#437)

Laurent Mazare
2023-08-14 08:09:27 +01:00
committed by GitHub
parent 9e7e6e0288
commit eab54e4490
2 changed files with 15 additions and 10 deletions

README.md

```diff
@@ -64,14 +64,17 @@ And then head over to
 
 ## Features
 
 - Simple syntax, looks and feels like PyTorch.
-- CPU and Cuda backends, m1, f16, bf16.
-- Serverless (on CPU), small and fast deployments
-- WASM support, run your models in a browser.
 - Model training.
-- Distributed computing using NCCL.
-- Model support out of the box: Llama, Whisper, Falcon, StarCoder...
-- Embed user-defined ops/kernels, such as [flash-attention
-  v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
+- Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
+- Backends.
+  - Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
+  - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
+  - WASM support, run your models in a browser.
+- Model support out of the box.
+  - LLMs: Llama v1 and v2, Falcon, StarCoder.
+  - Whisper.
+  - Stable Diffusion.
+- Serverless (on CPU), small and fast deployments.
 
 <!--- ANCHOR_END: features --->
```

candle-nn/tests/loss.rs

```diff
@@ -1,4 +1,6 @@
 use candle::{Device, Result, Tensor};
+mod test_utils;
+use test_utils::to_vec0_round;
 
 /* Equivalent python code:
 import torch
@@ -27,8 +29,8 @@ fn nll_and_cross_entropy() -> Result<()> {
     let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
     let loss = candle_nn::loss::nll(&log_softmax, &target)?;
-    assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
+    assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
     let loss = candle_nn::loss::cross_entropy(&input, &target)?;
-    assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
+    assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
     Ok(())
 }
```
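With MKL enabled, the low-order bits of these loss values can differ from the default CPU backend, so exact float equality fails; rounding to four decimal places keeps the assertion backend-agnostic. The `test_utils` module itself is not part of this diff; a minimal sketch of what a `to_vec0_round` helper could look like, with the name and signature inferred from the call sites above:

```rust
use candle::{Result, Tensor};

/// Extracts the scalar value of a rank-0 tensor and rounds it to
/// `digits` decimal places, so assertions tolerate the tiny
/// floating-point differences between backends (e.g. MKL vs. default).
/// Sketch only: the real helper lives in the crate's test_utils module.
pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
    let value = t.to_vec0::<f32>()?; // fails if the tensor is not rank 0
    let scale = 10f32.powi(digits);
    Ok((value * scale).round() / scale)
}
```

A tolerance-based comparison such as `(a - b).abs() < 1e-4` would work equally well; rounding just keeps the expected values readable in the test source.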