Merge pull request #16 from LaurentMazare/cuda-tests

Run the tensor tests for the cuda backend too.
This commit is contained in:
Laurent Mazare
2023-06-27 15:51:28 +01:00
committed by GitHub
3 changed files with 85 additions and 46 deletions

2
README.md Normal file
View File

@ -0,0 +1,2 @@
# candle
Minimalist ML framework for Rust

View File

@ -1,7 +1,7 @@
use crate::{CpuStorage, DType, Shape};
use candle_kernels as kernels;
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
use half::{bf16, f16};
use std::sync::Arc;
@ -240,6 +240,24 @@ enum CudaStorageSlice {
F64(CudaSlice<f64>),
}
fn slice_src_and_dst<'a, T>(
src: &'a CudaSlice<T>,
src_offset: usize,
dst: &'a mut CudaSlice<T>,
dst_offset: usize,
) -> (
cudarc::driver::CudaView<'a, T>,
cudarc::driver::CudaViewMut<'a, T>,
) {
let to_copy = dst
.len()
.saturating_sub(dst_offset)
.min(src.len().saturating_sub(src_offset));
let src = src.slice(src_offset..src_offset + to_copy);
let dst = dst.slice_mut(dst_offset..dst_offset + to_copy);
(src, dst)
}
#[derive(Debug)]
pub struct CudaStorage {
slice: CudaStorageSlice,
@ -903,8 +921,7 @@ impl CudaStorage {
let ds = dev.htod_copy([dims, src_stride].concat())?;
match (&self.slice, &mut dst.slice) {
(CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -916,8 +933,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -929,8 +945,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -942,8 +957,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {
@ -955,8 +969,7 @@ impl CudaStorage {
}
}
(CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
let src = src.slice(src_offset..);
let mut dst = dst.slice_mut(dst_offset..);
let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
if src_shape.is_contiguous(src_stride) {
dev.dtod_copy(&src, &mut dst)?
} else {

View File

@ -1,18 +1,16 @@
// TODO: Also test the cuda backend.
use candle::{DType, Device, Result, Tensor};
#[test]
fn zeros() -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, &Device::Cpu)?;
fn zeros(device: &Device) -> Result<()> {
let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
let (dim1, dim2) = tensor.shape().r2()?;
assert_eq!(dim1, 5);
assert_eq!(dim2, 2);
Ok(())
}
#[test]
fn add_mul() -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
fn add_mul(device: &Device) -> Result<()> {
let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
let dim1 = tensor.shape().r1()?;
assert_eq!(dim1, 3);
let content: Vec<f32> = tensor.to_vec1()?;
@ -26,10 +24,9 @@ fn add_mul() -> Result<()> {
Ok(())
}
#[test]
fn tensor_2d() -> Result<()> {
fn tensor_2d(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = Tensor::new(data, device)?;
let dims = tensor.shape().r2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
@ -37,28 +34,27 @@ fn tensor_2d() -> Result<()> {
Ok(())
}
#[test]
fn binary_op() -> Result<()> {
fn binary_op(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = Tensor::new(data, device)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
let tensor2 = Tensor::new(data2, &Device::Cpu)?;
let tensor2 = Tensor::new(data2, device)?;
let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
let dims = tensor.shape().r2()?;
assert_eq!(dims, (2, 5));
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
assert_eq!(content[1], [3.0, 1.5, 10.5, 12.0, 3.0]);
#[allow(clippy::eq_op)]
let tensor = (&tensor - &tensor)?;
let content: Vec<Vec<f32>> = tensor.to_vec2()?;
assert_eq!(content[0], [0., 0., 0., 0., 0.]);
Ok(())
}
#[test]
fn tensor_2d_transpose() -> Result<()> {
fn transpose(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
let tensor = Tensor::new(data, &Device::Cpu)?.t()?;
let tensor = Tensor::new(data, device)?.t()?;
let dims = tensor.shape().r2()?;
assert_eq!(dims, (5, 2));
assert_eq!(
@ -71,10 +67,9 @@ fn tensor_2d_transpose() -> Result<()> {
Ok(())
}
#[test]
fn softmax() -> Result<()> {
fn softmax(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = Tensor::new(data, device)?;
let t0 = tensor.log()?.softmax(0)?;
let t1 = tensor.log()?.softmax(1)?;
let t2 = tensor.log()?.softmax(2)?;
@ -108,10 +103,9 @@ fn softmax() -> Result<()> {
Ok(())
}
#[test]
fn sum() -> Result<()> {
fn sum(device: &Device) -> Result<()> {
let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.sum(&[2])?.to_vec3::<u32>()?,
&[[[8], [15]], [[10], [18]]]
@ -132,10 +126,9 @@ fn sum() -> Result<()> {
Ok(())
}
#[test]
fn narrow() -> Result<()> {
fn narrow(device: &Device) -> Result<()> {
let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
&[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
@ -163,10 +156,9 @@ fn narrow() -> Result<()> {
Ok(())
}
#[test]
fn broadcast() -> Result<()> {
fn broadcast(device: &Device) -> Result<()> {
let data = &[3f32, 1., 4.];
let tensor = Tensor::new(data, &Device::Cpu)?;
let tensor = Tensor::new(data, device)?;
assert_eq!(
tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
&[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
@ -174,12 +166,11 @@ fn broadcast() -> Result<()> {
Ok(())
}
#[test]
fn cat() -> Result<()> {
fn cat(device: &Device) -> Result<()> {
// 1D
let t1 = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
let t2 = Tensor::new(&[1f32, 5., 9., 2.], &Device::Cpu)?;
let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], &Device::Cpu)?;
let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
let t2 = Tensor::new(&[1f32, 5., 9., 2.], device)?;
let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], device)?;
assert_eq!(Tensor::cat(&[&t1], 0)?.to_vec1::<f32>()?, [3f32, 1., 4.],);
assert_eq!(
Tensor::cat(&[&t1, &t2], 0)?.to_vec1::<f32>()?,
@ -192,9 +183,9 @@ fn cat() -> Result<()> {
// 2D
let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
let t1 = Tensor::new(data, &Device::Cpu)?;
let t1 = Tensor::new(data, device)?;
let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
let t2 = Tensor::new(data2, &Device::Cpu)?;
let t2 = Tensor::new(data2, device)?;
assert_eq!(
Tensor::cat(&[&t1, &t2], 0)?.to_vec2::<f32>()?,
[
@ -226,3 +217,36 @@ fn cat() -> Result<()> {
);
Ok(())
}
macro_rules! test {
// TODO: Switch to generating the two last arguments automatically once concat_idents is
// stable. https://github.com/rust-lang/rust/issues/29599
($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
#[test]
fn $test_cpu() -> Result<()> {
$fn_name(&Device::Cpu)
}
#[cfg(feature = "cuda")]
#[test]
fn $test_cuda() -> Result<()> {
$fn_name(&Device::new_cuda(0)?)
}
};
}
test!(zeros, zeros_cpu, zeros_gpu);
test!(add_mul, add_mul_cpu, add_mul_gpu);
test!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test!(narrow, narrow_cpu, narrow_gpu);
test!(broadcast, broadcast_cpu, broadcast_gpu);
test!(cat, cat_cpu, cat_gpu);
test!(sum, sum_cpu, sum_gpu);
test!(transpose, transpose_cpu, transpose_gpu);
test!(binary_op, binary_op_cpu, binary_op_gpu);
// TODO: Make the test less sensitive to numerical precision and enable on the gpu.
#[test]
fn softmax_cpu() -> Result<()> {
softmax(&Device::Cpu)
}