diff --git a/README.md b/README.md
new file mode 100644
index 00000000..246f5bae
--- /dev/null
+++ b/README.md
@@ -0,0 +1,2 @@
+# candle
+Minimalist ML framework for Rust
diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs
index 57ea9b3e..9cbf82be 100644
--- a/candle-core/src/cuda_backend.rs
+++ b/candle-core/src/cuda_backend.rs
@@ -1,7 +1,7 @@
 use crate::{CpuStorage, DType, Shape};
 use candle_kernels as kernels;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
-use cudarc::driver::{CudaFunction, CudaSlice, LaunchAsync, LaunchConfig};
+use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig};
 use half::{bf16, f16};
 use std::sync::Arc;
 
@@ -240,6 +240,24 @@ enum CudaStorageSlice {
     F64(CudaSlice<f64>),
 }
 
+fn slice_src_and_dst<'a, T>(
+    src: &'a CudaSlice<T>,
+    src_offset: usize,
+    dst: &'a mut CudaSlice<T>,
+    dst_offset: usize,
+) -> (
+    cudarc::driver::CudaView<'a, T>,
+    cudarc::driver::CudaViewMut<'a, T>,
+) {
+    let to_copy = dst
+        .len()
+        .saturating_sub(dst_offset)
+        .min(src.len().saturating_sub(src_offset));
+    let src = src.slice(src_offset..src_offset + to_copy);
+    let dst = dst.slice_mut(dst_offset..dst_offset + to_copy);
+    (src, dst)
+}
+
 #[derive(Debug)]
 pub struct CudaStorage {
     slice: CudaStorageSlice,
@@ -903,8 +921,7 @@ impl CudaStorage {
         let ds = dev.htod_copy([dims, src_stride].concat())?;
         match (&self.slice, &mut dst.slice) {
             (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => {
-                let src = src.slice(src_offset..);
-                let mut dst = dst.slice_mut(dst_offset..);
+                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
                 if src_shape.is_contiguous(src_stride) {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
@@ -916,8 +933,7 @@
                 }
             }
             (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => {
-                let src = src.slice(src_offset..);
-                let mut dst = dst.slice_mut(dst_offset..);
+                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
                 if src_shape.is_contiguous(src_stride) {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
@@ -929,8 +945,7 @@
                 }
             }
             (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => {
-                let src = src.slice(src_offset..);
-                let mut dst = dst.slice_mut(dst_offset..);
+                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
                 if src_shape.is_contiguous(src_stride) {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
@@ -942,8 +957,7 @@
                 }
             }
             (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => {
-                let src = src.slice(src_offset..);
-                let mut dst = dst.slice_mut(dst_offset..);
+                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
                 if src_shape.is_contiguous(src_stride) {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
@@ -955,8 +969,7 @@
                 }
             }
             (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => {
-                let src = src.slice(src_offset..);
-                let mut dst = dst.slice_mut(dst_offset..);
+                let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset);
                 if src_shape.is_contiguous(src_stride) {
                     dev.dtod_copy(&src, &mut dst)?
                 } else {
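Note on the hunks above: `slice_src_and_dst` replaces the open-ended `src.slice(src_offset..)` / `dst.slice_mut(dst_offset..)` pair with two views whose common length is clamped to what remains past each offset, so the views handed to `dtod_copy` always line up even when the two buffers differ in size (the added `DeviceSlice` import is what brings `len` into scope for `CudaSlice`). Below is a minimal sketch of the same clamping rule on plain CPU slices, not part of the patch; `clamped_copy` is a name invented here for illustration:

fn clamped_copy<T: Copy>(src: &[T], src_offset: usize, dst: &mut [T], dst_offset: usize) {
    // Copy however many elements both tails can accommodate; saturating_sub
    // yields 0 (instead of underflowing) when an offset exceeds a length.
    let to_copy = dst
        .len()
        .saturating_sub(dst_offset)
        .min(src.len().saturating_sub(src_offset));
    if to_copy == 0 {
        return; // also avoids out-of-range slicing when an offset is past the end
    }
    dst[dst_offset..dst_offset + to_copy].copy_from_slice(&src[src_offset..src_offset + to_copy]);
}

For example, with a 4-element source, a 3-element destination, and both offsets at 1, `to_copy` is `min(3, 2) = 2`.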
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 8269ace1..1ad83368 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -1,18 +1,16 @@
 // TODO: Also test the cuda backend.
 use candle::{DType, Device, Result, Tensor};
 
-#[test]
-fn zeros() -> Result<()> {
-    let tensor = Tensor::zeros((5, 2), DType::F32, &Device::Cpu)?;
+fn zeros(device: &Device) -> Result<()> {
+    let tensor = Tensor::zeros((5, 2), DType::F32, device)?;
     let (dim1, dim2) = tensor.shape().r2()?;
     assert_eq!(dim1, 5);
     assert_eq!(dim2, 2);
     Ok(())
 }
 
-#[test]
-fn add_mul() -> Result<()> {
-    let tensor = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
+fn add_mul(device: &Device) -> Result<()> {
+    let tensor = Tensor::new(&[3f32, 1., 4.], device)?;
     let dim1 = tensor.shape().r1()?;
     assert_eq!(dim1, 3);
     let content: Vec<f32> = tensor.to_vec1()?;
@@ -26,10 +24,9 @@ fn add_mul() -> Result<()> {
     Ok(())
 }
 
-#[test]
-fn tensor_2d() -> Result<()> {
+fn tensor_2d(device: &Device) -> Result<()> {
     let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
-    let tensor = Tensor::new(data, &Device::Cpu)?;
+    let tensor = Tensor::new(data, device)?;
     let dims = tensor.shape().r2()?;
     assert_eq!(dims, (2, 5));
     let content: Vec<Vec<f32>> = tensor.to_vec2()?;
@@ -37,28 +34,27 @@
     Ok(())
 }
 
-#[test]
-fn binary_op() -> Result<()> {
+fn binary_op(device: &Device) -> Result<()> {
     let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
-    let tensor = Tensor::new(data, &Device::Cpu)?;
+    let tensor = Tensor::new(data, device)?;
     let data2 = &[[5f32, 5., 5., 5., 5.], [2., 1., 7., 8., 2.]];
-    let tensor2 = Tensor::new(data2, &Device::Cpu)?;
+    let tensor2 = Tensor::new(data2, device)?;
     let tensor = (&tensor + (&tensor * &tensor)? / (&tensor + &tensor2))?;
     let dims = tensor.shape().r2()?;
     assert_eq!(dims, (2, 5));
     let content: Vec<Vec<f32>> = tensor.to_vec2()?;
     assert_eq!(content[0], [4.125, 1.1666666, 5.7777777, 1.1666666, 7.5]);
     assert_eq!(content[1], [3.0, 1.5, 10.5, 12.0, 3.0]);
+    #[allow(clippy::eq_op)]
     let tensor = (&tensor - &tensor)?;
     let content: Vec<Vec<f32>> = tensor.to_vec2()?;
     assert_eq!(content[0], [0., 0., 0., 0., 0.]);
     Ok(())
 }
 
-#[test]
-fn tensor_2d_transpose() -> Result<()> {
+fn transpose(device: &Device) -> Result<()> {
     let data = &[[3f32, 1., 4., 1., 5.], [2., 1., 7., 8., 2.]];
-    let tensor = Tensor::new(data, &Device::Cpu)?.t()?;
+    let tensor = Tensor::new(data, device)?.t()?;
     let dims = tensor.shape().r2()?;
     assert_eq!(dims, (5, 2));
     assert_eq!(
@@ -71,10 +67,9 @@ fn tensor_2d_transpose() -> Result<()> {
     Ok(())
 }
 
-#[test]
-fn softmax() -> Result<()> {
+fn softmax(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
-    let tensor = Tensor::new(data, &Device::Cpu)?;
+    let tensor = Tensor::new(data, device)?;
     let t0 = tensor.log()?.softmax(0)?;
     let t1 = tensor.log()?.softmax(1)?;
     let t2 = tensor.log()?.softmax(2)?;
@@ -108,10 +103,9 @@
     Ok(())
 }
 
-#[test]
-fn sum() -> Result<()> {
+fn sum(device: &Device) -> Result<()> {
     let data = &[[[3u32, 1, 4], [1, 5, 9]], [[2, 1, 7], [8, 2, 8]]];
-    let tensor = Tensor::new(data, &Device::Cpu)?;
+    let tensor = Tensor::new(data, device)?;
     assert_eq!(
         tensor.sum(&[2])?.to_vec3::<u32>()?,
         &[[[8], [15]], [[10], [18]]]
     );
@@ -132,10 +126,9 @@
     Ok(())
 }
 
-#[test]
-fn narrow() -> Result<()> {
+fn narrow(device: &Device) -> Result<()> {
     let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
-    let tensor = Tensor::new(data, &Device::Cpu)?;
+    let tensor = Tensor::new(data, device)?;
     assert_eq!(
         tensor.narrow(2, 1, 2)?.to_vec3::<f32>()?,
         &[[[1.0, 4.0], [5.0, 9.0]], [[1.0, 7.0], [2.0, 8.0]]],
     );
@@ -163,10 +156,9 @@
     Ok(())
 }
 
-#[test]
-fn broadcast() -> Result<()> {
+fn broadcast(device: &Device) -> Result<()> {
     let data = &[3f32, 1., 4.];
-    let tensor = Tensor::new(data, &Device::Cpu)?;
+    let tensor = Tensor::new(data, device)?;
     assert_eq!(
         tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
         &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
@@ -174,12 +166,11 @@
     Ok(())
 }
 
-#[test]
-fn cat() -> Result<()> {
+fn cat(device: &Device) -> Result<()> {
     // 1D
-    let t1 = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
-    let t2 = Tensor::new(&[1f32, 5., 9., 2.], &Device::Cpu)?;
-    let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], &Device::Cpu)?;
+    let t1 = Tensor::new(&[3f32, 1., 4.], device)?;
+    let t2 = Tensor::new(&[1f32, 5., 9., 2.], device)?;
+    let t3 = Tensor::new(&[6f32, 5., 3., 5., 8., 9.], device)?;
     assert_eq!(Tensor::cat(&[&t1], 0)?.to_vec1::<f32>()?, [3f32, 1., 4.],);
     assert_eq!(
         Tensor::cat(&[&t1, &t2], 0)?.to_vec1::<f32>()?,
@@ -192,9 +183,9 @@
 
     // 2D
     let data = &[[3f32, 1., 4., 1., 5.], [2., 7., 1., 8., 2.]];
-    let t1 = Tensor::new(data, &Device::Cpu)?;
+    let t1 = Tensor::new(data, device)?;
     let data2 = &[[5f32, 5., 5., 5., 5.], [2., 7., 1., 8., 2.]];
-    let t2 = Tensor::new(data2, &Device::Cpu)?;
+    let t2 = Tensor::new(data2, device)?;
     assert_eq!(
         Tensor::cat(&[&t1, &t2], 0)?.to_vec2::<f32>()?,
         [
@@ -226,3 +217,36 @@
     );
     Ok(())
 }
+
+macro_rules! test {
+    // TODO: Switch to generating the two last arguments automatically once concat_idents is
+    // stable. https://github.com/rust-lang/rust/issues/29599
+    ($fn_name: ident, $test_cpu: ident, $test_cuda: ident) => {
+        #[test]
+        fn $test_cpu() -> Result<()> {
+            $fn_name(&Device::Cpu)
+        }
+
+        #[cfg(feature = "cuda")]
+        #[test]
+        fn $test_cuda() -> Result<()> {
+            $fn_name(&Device::new_cuda(0)?)
+        }
+    };
+}
+
+test!(zeros, zeros_cpu, zeros_gpu);
+test!(add_mul, add_mul_cpu, add_mul_gpu);
+test!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
+test!(narrow, narrow_cpu, narrow_gpu);
+test!(broadcast, broadcast_cpu, broadcast_gpu);
+test!(cat, cat_cpu, cat_gpu);
+test!(sum, sum_cpu, sum_gpu);
+test!(transpose, transpose_cpu, transpose_gpu);
+test!(binary_op, binary_op_cpu, binary_op_gpu);
+
+// TODO: Make the test less sensitive to numerical precision and enable on the gpu.
+#[test]
+fn softmax_cpu() -> Result<()> {
+    softmax(&Device::Cpu)
+}
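For reference, each `test!` line above generates one CPU test and one feature-gated CUDA test; the two explicit test names are only needed because `concat_idents` is unstable, as the macro's TODO notes. Expanding `test!(zeros, zeros_cpu, zeros_gpu)` by hand gives roughly:

// CPU variant, always compiled and run.
#[test]
fn zeros_cpu() -> Result<()> {
    zeros(&Device::Cpu)
}

// CUDA variant, only compiled when the crate's `cuda` feature is enabled;
// it runs the same test body against the first CUDA device.
#[cfg(feature = "cuda")]
#[test]
fn zeros_gpu() -> Result<()> {
    zeros(&Device::new_cuda(0)?)
}

A plain `cargo test` therefore runs only the `*_cpu` variants; building with the `cuda` feature enabled (which the `#[cfg]` gate presupposes the crate exposes) compiles and runs the `*_gpu` variants as well.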
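On the final TODO: `softmax` stays CPU-only because its expected values are compared bit-for-bit with `assert_eq!`, which is brittle when the same math runs on a different backend. One common way to relax this, sketched here as an assumption rather than anything in the patch (`assert_close` is an invented helper name), is an element-wise comparison within an absolute tolerance:

fn assert_close(lhs: &[f32], rhs: &[f32], tol: f32) {
    // Compare element-wise within an absolute tolerance instead of exact equality.
    assert_eq!(lhs.len(), rhs.len());
    for (l, r) in lhs.iter().zip(rhs.iter()) {
        assert!((l - r).abs() <= tol, "{l} != {r} (tol {tol})");
    }
}

With a tolerance around 1e-5, the same softmax expectations could plausibly be checked on both devices and the test added to the macro list.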