From eab54e449004bf0d5544ee0261dc89d2c3062b77 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Mon, 14 Aug 2023 08:09:27 +0100
Subject: [PATCH] Fix the tests for mkl. (#437)

---
 README.md               | 19 +++++++++++--------
 candle-nn/tests/loss.rs |  6 ++++--
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 9644b15c..02c46c01 100644
--- a/README.md
+++ b/README.md
@@ -64,14 +64,17 @@ And then head over to
 ## Features
 
 - Simple syntax, looks and feels like PyTorch.
-- CPU and Cuda backends, m1, f16, bf16.
-- Serverless (on CPU), small and fast deployments
-- WASM support, run your models in a browser.
-- Model training.
-- Distributed computing using NCCL.
-- Model support out of the box: Llama, Whisper, Falcon, StarCoder...
-- Embed user-defined ops/kernels, such as [flash-attention
-  v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
+  - Model training.
+  - Embed user-defined ops/kernels, such as [flash-attention v2](https://github.com/huggingface/candle/blob/89ba005962495f2bfbda286e185e9c3c7f5300a3/candle-flash-attn/src/lib.rs#L152).
+- Backends.
+  - Optimized CPU backend with optional MKL support for x86 and Accelerate for macs.
+  - CUDA backend for efficiently running on GPUs, multiple GPU distribution via NCCL.
+  - WASM support, run your models in a browser.
+- Model support out of the box.
+  - LLMs: Llama v1 and v2, Falcon, StarCoder.
+  - Whisper.
+  - Stable Diffusion.
+- Serverless (on CPU), small and fast deployments.
 
diff --git a/candle-nn/tests/loss.rs b/candle-nn/tests/loss.rs
index 0811fa39..9df73c94 100644
--- a/candle-nn/tests/loss.rs
+++ b/candle-nn/tests/loss.rs
@@ -1,4 +1,6 @@
 use candle::{Device, Result, Tensor};
+mod test_utils;
+use test_utils::to_vec0_round;
 
 /* Equivalent python code:
 import torch
@@ -27,8 +29,8 @@ fn nll_and_cross_entropy() -> Result<()> {
 
     let log_softmax = candle_nn::ops::log_softmax(&input, 1)?;
     let loss = candle_nn::loss::nll(&log_softmax, &target)?;
-    assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
+    assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
     let loss = candle_nn::loss::cross_entropy(&input, &target)?;
-    assert_eq!(loss.to_vec0::<f32>()?, 1.1312335);
+    assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
     Ok(())
 }
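
Note on the test change: the patch switches the loss assertions from exact `f32` equality to comparison after rounding, because MKL-backed builds can produce slightly different low-order bits for the same computation. The `test_utils::to_vec0_round` helper itself is not shown in this patch; the sketch below is an assumption of what such a helper could look like, built only on the `Tensor::to_vec0::<f32>` accessor already used in the old assertions.

    use candle::{Result, Tensor};

    /// Hypothetical helper (not part of this patch): extract the scalar from a
    /// rank-0 tensor and round it to `digits` decimal places, so that tests
    /// tolerate small backend-dependent differences (plain CPU vs. MKL).
    pub fn to_vec0_round(t: &Tensor, digits: i32) -> Result<f32> {
        let b = 10f32.powi(digits);
        let v = t.to_vec0::<f32>()?;
        Ok(f32::round(v * b) / b)
    }

With 4 digits, 1.1312335 rounds to 1.1312 on any backend whose result agrees to roughly 1e-4, which is why both the nll and cross_entropy assertions now compare against the rounded value.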