mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Merge pull request #16 from LaurentMazare/cuda-tests
Run the tensor tests for the cuda backend too.
This commit is contained in:
@ -1,7 +1,7 @@
|
||||
use crate::{CpuStorage, DType, Shape};
|
||||
use candle_kernels as kernels;
|
||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
|
||||
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
|
||||
use half::{bf16, f16};
|
||||
use std::sync::Arc;
|
||||
|
||||
@ -240,6 +240,24 @@ enum CudaStorageSlice {
|
||||
F64(CudaSlice<f64>),
|
||||
}
|
||||
|
||||
fn slice_src_and_dst<'a, T>(
|
||||
src: &'a CudaSlice<T>,
|
||||
src_offset: usize,
|
||||
dst: &'a mut CudaSlice<T>,
|
||||
dst_offset: usize,
|
||||
) -> (
|
||||
cudarc::driver::CudaView<'a, T>,
|
||||
cudarc::driver::CudaViewMut<'a, T>,
|
||||
) {
|
||||
let to_copy = dst
|
||||
.len()
|
||||
.saturating_sub(dst_offset)
|
||||
.min(src.len().saturating_sub(src_offset));
|
||||
let src = src.slice(src_offset..src_offset + to_copy);
|
||||
let dst = dst.slice_mut(dst_offset..dst_offset + to_copy);
|
||||
(src, dst)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CudaStorage {
|
||||
slice: CudaStorageSlice,
|
||||
@ -903,8 +921,7 @@ impl CudaStorage {
|
||||
let ds = dev.htod_copy([dims, src_stride].concat())?;
|
||||
match (&self.slice, &mut dst.slice) {
|
||||
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -916,8 +933,7 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -929,8 +945,7 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -942,8 +957,7 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
@ -955,8 +969,7 @@ impl CudaStorage {
|
||||
}
|
||||
}
|
||||
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
|
||||
let src = src.slice(src_offset..);
|
||||
let mut dst = dst.slice_mut(dst_offset..);
|
||||
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
|
||||
if src_shape.is_contiguous(src_stride) {
|
||||
dev.dtod_copy(&src, &mut dst)?
|
||||
} else {
|
||||
|
@ -1,18 +1,16 @@
|
||||
// TODO: Also test the cuda backend.
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
|
||||
#[test]
|
||||
fn zeros() -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, &Device::Cpu)?;
|
||||
fn zeros(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
|
||||
let (dim1, dim2) = tensor.shape().r2()?;
|
||||
assert_eq!(dim1, 5);
|
||||
assert_eq!(dim2, 2);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn add_mul() -> Result<()> {
|
||||
let tensor = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
|
||||
fn add_mul(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
let dim1 = tensor.shape().r1()?;
|
||||
assert_eq!(dim1, 3);
|
||||
let content: Vec<f32> = tensor.to_vec1()?;
|
||||
@ -26,10 +24,9 @@ fn add_mul() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_2d() -> Result<()> {
|
||||
fn tensor_2d(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let dims = tensor.shape().r2()?;
|
||||
assert_eq!(dims, (2, 5));
|
||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||
@ -37,28 +34,27 @@ fn tensor_2d() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn binary_op() -> Result<()> {
|
||||
fn binary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor2 = Tensor::new(data2, &Device::Cpu)?;
|
||||
let tensor2 = Tensor::new(data2, device)?;
|
||||
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
|
||||
let dims = tensor.shape().r2()?;
|
||||
assert_eq!(dims, (2, 5));
|
||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||
assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
|
||||
assert_eq!(content[1], [3.0, 1.5, 10.5, 12.0, 3.0]);
|
||||
#[allow(clippy::eq_op)]
|
||||
let tensor = (&tensor - &tensor)?;
|
||||
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
|
||||
assert_eq!(content[0], [0., 0., 0., 0., 0.]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_2d_transpose() -> Result<()> {
|
||||
fn transpose(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?.t()?;
|
||||
let tensor = Tensor::new(data, device)?.t()?;
|
||||
let dims = tensor.shape().r2()?;
|
||||
assert_eq!(dims, (5, 2));
|
||||
assert_eq!(
|
||||
@ -71,10 +67,9 @@ fn tensor_2d_transpose() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn softmax() -> Result<()> {
|
||||
fn softmax(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let t0 = tensor.log()?.softmax(0)?;
|
||||
let t1 = tensor.log()?.softmax(1)?;
|
||||
let t2 = tensor.log()?.softmax(2)?;
|
||||
@ -108,10 +103,9 @@ fn softmax() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sum() -> Result<()> {
|
||||
fn sum(device: &Device) -> Result<()> {
|
||||
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
tensor.sum(&[2])?.to_vec3::<u32>()?,
|
||||
&[[[8], [15]], [[10], [18]]]
|
||||
@ -132,10 +126,9 @@ fn sum() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn narrow() -> Result<()> {
|
||||
fn narrow(device: &Device) -> Result<()> {
|
||||
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
|
||||
&[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
|
||||
@ -163,10 +156,9 @@ fn narrow() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn broadcast() -> Result<()> {
|
||||
fn broadcast(device: &Device) -> Result<()> {
|
||||
let data = &[3f32, 1., 4.];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
assert_eq!(
|
||||
tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
|
||||
&[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
|
||||
@ -174,12 +166,11 @@ fn broadcast() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cat() -> Result<()> {
|
||||
fn cat(device: &Device) -> Result<()> {
|
||||
// 1D
|
||||
let t1 = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
|
||||
let t2 = Tensor::new(&[1f32, 5., 9., 2.], &Device::Cpu)?;
|
||||
let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], &Device::Cpu)?;
|
||||
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
|
||||
let t2 = Tensor::new(&[1f32, 5., 9., 2.], device)?;
|
||||
let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], device)?;
|
||||
assert_eq!(Tensor::cat(&[&t1], 0)?.to_vec1::<f32>()?, [3f32, 1., 4.],);
|
||||
assert_eq!(
|
||||
Tensor::cat(&[&t1, &t2], 0)?.to_vec1::<f32>()?,
|
||||
@ -192,9 +183,9 @@ fn cat() -> Result<()> {
|
||||
|
||||
// 2D
|
||||
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
|
||||
let t1 = Tensor::new(data, &Device::Cpu)?;
|
||||
let t1 = Tensor::new(data, device)?;
|
||||
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
|
||||
let t2 = Tensor::new(data2, &Device::Cpu)?;
|
||||
let t2 = Tensor::new(data2, device)?;
|
||||
assert_eq!(
|
||||
Tensor::cat(&[&t1, &t2], 0)?.to_vec2::<f32>()?,
|
||||
[
|
||||
@ -226,3 +217,36 @@ fn cat() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
macro_rules! test {
|
||||
// TODO: Switch to generating the two last arguments automatically once concat_idents is
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
|
||||
#[test]
|
||||
fn $test_cpu() -> Result<()> {
|
||||
$fn_name(&Device::Cpu)
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
#[test]
|
||||
fn $test_cuda() -> Result<()> {
|
||||
$fn_name(&Device::new_cuda(0)?)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
test!(zeros, zeros_cpu, zeros_gpu);
|
||||
test!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||
test!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||
test!(narrow, narrow_cpu, narrow_gpu);
|
||||
test!(broadcast, broadcast_cpu, broadcast_gpu);
|
||||
test!(cat, cat_cpu, cat_gpu);
|
||||
test!(sum, sum_cpu, sum_gpu);
|
||||
test!(transpose, transpose_cpu, transpose_gpu);
|
||||
test!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||
|
||||
// TODO: Make the test less sensitive to numerical precision and enable on the gpu.
|
||||
#[test]
|
||||
fn softmax_cpu() -> Result<()> {
|
||||
softmax(&Device::Cpu)
|
||||
}
|
||||
|
Reference in New Issue
Block a user