diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 56ff08ae..53c7ecf1 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -158,6 +158,71 @@ fn copy_strided_src_( } } +fn matmul_impl( + lhs: &[T], + rhs: &[T], + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + rhs_stride: &[usize], +) -> Result> { + let a_skip: usize = m * k; + let b_skip: usize = n * k; + let c_skip: usize = m * n; + + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; + + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + if lhs_stride.len() > 2 { + let lhs_batch_stride = &lhs_stride[..rank - 2]; + let rhs_batch_stride = &rhs_stride[..rank - 2]; + + if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] { + // Temporary error before we support abitrary striding. + return Err(Error::UnexpectedStriding); + } + } + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + let mut dst = vec![T::zero(); b * m * n]; + for step in 0..b { + let lhs_p = &lhs[step * a_skip..]; + let rhs_p = &rhs[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + /* m: usize = */ m, + /* n: usize = */ n, + /* k: usize = */ k, + /* dst: *mut T = */ dst_p.as_mut_ptr(), + /* dst_cs: isize = */ dst_cs as isize, + /* dst_rs: isize = */ dst_rs as isize, + /* read_dst: bool = */ false, + /* lhs: *const T = */ lhs_p.as_ptr(), + /* lhs_cs: isize = */ lhs_cs as isize, + /* lhs_rs: isize = */ lhs_rs as isize, + /* rhs: *const T = */ rhs_p.as_ptr(), + /* rhs_cs: isize = */ rhs_cs as isize, + /* rhs_rs: isize = */ rhs_rs as isize, + /* alpha: T = */ T::zero(), + /* beta: T = */ T::one(), + /* conj_dst: bool = */ false, + /* conj_lhs: bool = */ false, + /* conj_rhs: bool = */ false, + Parallelism::Rayon(crate::utils::get_num_threads()), + ) + } + } + Ok(dst) +} + impl CpuStorage { pub fn dtype(&self) -> DType { match self { @@ -593,199 +658,28 @@ impl CpuStorage { pub(crate) fn matmul_impl( &self, rhs: &Self, - (b, m, n, k): (usize, usize, usize, usize), + bmnk: (usize, usize, usize, usize), lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { - let a_skip: usize = m * k; - let b_skip: usize = n * k; - let c_skip: usize = m * n; - - let rank = lhs_stride.len(); - let lhs_cs = lhs_stride[rank - 1]; - let lhs_rs = lhs_stride[rank - 2]; - - let rhs_cs = rhs_stride[rank - 1]; - let rhs_rs = rhs_stride[rank - 2]; - - if lhs_stride.len() > 2 { - let lhs_batch_stride = &lhs_stride[..rank - 2]; - let rhs_batch_stride = &rhs_stride[..rank - 2]; - - if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] { - // Temporary error before we support abitrary striding. - return Err(Error::UnexpectedStriding); - } - } - - let dst_shape: Shape = (m, n).into(); - let dst_strides = dst_shape.stride_contiguous(); - let dst_rs = dst_strides[0]; - let dst_cs = dst_strides[1]; - match (self, rhs) { (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => { - let mut dst = vec![f16::ZERO; b * m * n]; - for step in 0..b { - let lhs_p = &lhs[step * a_skip..]; - let rhs_p = &rhs[step * b_skip..]; - let dst_p = &mut dst[step * c_skip..]; - unsafe { - gemm( - // m: usize, - m, - // n: usize, - n, - // k: usize, - k, - // dst: *mut T, - dst_p.as_mut_ptr(), - // dst_cs: isize, - dst_cs as isize, - // dst_rs: isize, - dst_rs as isize, - // read_dst: bool, - false, - // lhs: *const T, - lhs_p.as_ptr(), - // lhs_cs: isize, - lhs_cs as isize, - // lhs_rs: isize, - lhs_rs as isize, - // rhs: *const T, - rhs_p.as_ptr(), - // rhs_cs: isize, - rhs_cs as isize, - // rhs_rs: isize, - rhs_rs as isize, - // alpha: T, - f16::ONE, - // beta: T, - f16::ONE, - // conj_dst: bool, - false, - // conj_lhs: bool, - false, - // conj_rhs: bool, - true, - // parallelism: Parallelism - Parallelism::Rayon(crate::utils::get_num_threads()), - ) - } - } - + let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; Ok(Self::F16(dst)) } (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => { - let mut dst = vec![0f32; b * m * n]; - for step in 0..b { - let lhs_p = &lhs[step * a_skip..]; - let rhs_p = &rhs[step * b_skip..]; - let dst_p = &mut dst[step * c_skip..]; - unsafe { - gemm( - // m: usize, - m, - // n: usize, - n, - // k: usize, - k, - // dst: *mut T, - dst_p.as_mut_ptr(), - // dst_cs: isize, - dst_cs as isize, - // dst_rs: isize, - dst_rs as isize, - // read_dst: bool, - false, - // lhs: *const T, - lhs_p.as_ptr(), - // lhs_cs: isize, - lhs_cs as isize, - // lhs_rs: isize, - lhs_rs as isize, - // rhs: *const T, - rhs_p.as_ptr(), - // rhs_cs: isize, - rhs_cs as isize, - // rhs_rs: isize, - rhs_rs as isize, - // alpha: T, - 1f32, - // beta: T, - 1f32, - // conj_dst: bool, - false, - // conj_lhs: bool, - false, - // conj_rhs: bool, - true, - // parallelism: Parallelism - Parallelism::Rayon(crate::utils::get_num_threads()), - ) - } - } + let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; Ok(Self::F32(dst)) } (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { - let mut dst = vec![0f64; b * m * n]; - for step in 0..b { - let lhs_p = &lhs[step * a_skip..]; - let rhs_p = &rhs[step * b_skip..]; - let dst_p = &mut dst[step * c_skip..]; - unsafe { - gemm( - // m: usize, - m, - // n: usize, - n, - // k: usize, - k, - // dst: *mut T, - dst_p.as_mut_ptr(), - // dst_cs: isize, - dst_cs as isize, - // dst_rs: isize, - dst_rs as isize, - // read_dst: bool, - false, - // lhs: *const T, - lhs_p.as_ptr(), - // lhs_cs: isize, - lhs_cs as isize, - // lhs_rs: isize, - lhs_rs as isize, - // rhs: *const T, - rhs_p.as_ptr(), - // rhs_cs: isize, - rhs_cs as isize, - // rhs_rs: isize, - rhs_rs as isize, - // alpha: T, - 1f64, - // beta: T, - 1f64, - // conj_dst: bool, - false, - // conj_lhs: bool, - false, - // conj_rhs: bool, - true, - // parallelism: Parallelism - Parallelism::Rayon(crate::utils::get_num_threads()), - ) - } - } + let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; Ok(Self::F64(dst)) } - _ => { - // This should be covered by the dtype check above. - Err(Error::DTypeMismatchBinaryOp { - lhs: self.dtype(), - rhs: rhs.dtype(), - op: "matmul", - }) - } + _ => Err(Error::DTypeMismatchBinaryOp { + lhs: self.dtype(), + rhs: rhs.dtype(), + op: "matmul", + }), } } @@ -841,45 +735,3 @@ impl CpuStorage { } } } - -#[cfg(test)] -mod tests { - use super::*; - use crate::{Device, Tensor}; - - #[test] - fn simple_matmul() -> Result<()> { - let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?; - let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?; - - let c = a.matmul(&b)?; - assert_eq!(c.to_vec2::()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]); - - let data = vec![1.0f32, 2.0]; - let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?; - let data = vec![3.0f32, 4.0]; - let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?; - let c = a.matmul(&b)?; - assert_eq!(c.to_vec2::()?, &[&[3.0, 4.0], &[6.0, 8.0]]); - - let data: Vec<_> = (0..6).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?; - let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?; - let c = a.matmul(&b)?; - assert_eq!(c.to_vec2::()?, &[&[16., 19.], &[52., 64.]]); - - let data: Vec<_> = (0..12).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?; - let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?; - let c = a.matmul(&b)?; - assert_eq!( - c.to_vec3::()?, - &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]] - ); - Ok(()) - } -} diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 286c12e3..78ca4b05 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -219,6 +219,63 @@ fn cat(device: &Device) -> Result<()> { Ok(()) } +fn embeddings(device: &Device) -> Result<()> { + let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?; + let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?; + let hs = Tensor::embedding(&ids, &t)?; + assert_eq!(hs.to_vec2::()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + Ok(()) +} + +fn matmul(device: &Device) -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), device)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), device)?; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[[7.0f32, 10.0], [15.0, 22.0]]); + + let data = vec![1.0f32, 2.0]; + let a = Tensor::from_slice(&data, (2, 1), device)?; + let data = vec![3.0f32, 4.0]; + let b = Tensor::from_slice(&data, (1, 2), device)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[3.0, 4.0], &[6.0, 8.0]]); + + let data: Vec<_> = (0..6).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 3), device)?; + let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (3, 2), device)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[16., 19.], &[52., 64.]]); + + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 2, 3), device)?; + let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (2, 3, 2), device)?; + let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]]; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec3::()?, &expected); + + // Also perform the matmul on contiguous transposed versions. + let a_tt = a.t()?.contiguous()?.t()?; + assert!(!a_tt.is_contiguous()); + assert_eq!(a.dims(), a_tt.dims()); + assert_eq!(a_tt.stride(), &[6, 1, 2]); + + let b_tt = b.t()?.contiguous()?.t()?; + assert!(!b_tt.is_contiguous()); + assert_eq!(b.dims(), b_tt.dims()); + assert_eq!(b_tt.stride(), &[6, 1, 3]); + + assert_eq!(a_tt.matmul(&b)?.to_vec3::()?, &expected); + assert_eq!(a.matmul(&b_tt)?.to_vec3::()?, &expected); + assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::()?, &expected); + Ok(()) +} + test_device!(zeros, zeros_cpu, zeros_gpu); test_device!(add_mul, add_mul_cpu, add_mul_gpu); test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu); @@ -229,3 +286,5 @@ test_device!(sum, sum_cpu, sum_gpu); test_device!(transpose, transpose_cpu, transpose_gpu); test_device!(binary_op, binary_op_cpu, binary_op_gpu); test_device!(softmax, softmax_cpu, softmax_gpu); +test_device!(embeddings, embeddings_cpu, embeddings_gpu); +test_device!(matmul, matmul_cpu, matmul_gpu);