From 18707891b7266e0695dd74147e6c1ec456aacd15 Mon Sep 17 00:00:00 2001 From: laurent Date: Tue, 27 Jun 2023 09:45:38 +0100 Subject: [PATCH] Fix an error message. --- examples/cuda_basics.rs | 10 ++++------ src/cuda_backend.rs | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/examples/cuda_basics.rs b/examples/cuda_basics.rs index f288cb40..2c790ed1 100644 --- a/examples/cuda_basics.rs +++ b/examples/cuda_basics.rs @@ -3,12 +3,10 @@ use candle::{Device, Tensor}; fn main() -> Result<()> { let device = Device::new_cuda(0)?; - let x = Tensor::new(&[[11f32, 22.], [33., 44.], [55., 66.], [77., 78.]], &device)?; - println!("> {:?}", x.sum(&[0])?.to_vec2::()?); - println!("> {:?}", x.sum(&[1])?.to_vec2::()?); - println!("> {:?}", x.sum(&[0, 1])?.to_vec2::()?); - let x = x.to_dtype(candle::DType::F16)?; - println!("> {:?}", x.sum(&[0])?.to_vec2::()?); + let ids = Tensor::new(&[0u32, 3u32, 1u32], &device)?; + let t = Tensor::new(&[[0f32, 1f32], [1f32, 2f32], [2f32, 3f32]], &device)?; + let hs = Tensor::embedding(&ids, &t)?; + println!("> {:?}", hs.to_vec2::()); let x = Tensor::new(&[3f32, 1., 4., 1., 5.], &device)?; println!("{:?}", x.to_vec1::()?); diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs index 2410f1d7..caaa64b8 100644 --- a/src/cuda_backend.rs +++ b/src/cuda_backend.rs @@ -700,7 +700,7 @@ impl CudaStorage { let ids = match &self.slice { CudaStorageSlice::U32(slice) => slice, _ => Err(CudaError::UnexpectedDType { - msg: "embedding ids should be u32", + msg: "where conditions should be u32", expected: DType::U32, got: self.dtype(), })?,