mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Merge pull request #22 from LaurentMazare/more-cuda-testing2
Again more cuda testing.
This commit is contained in:
@ -158,6 +158,71 @@ fn copy_strided_src_<T: Copy + std::fmt::Display>(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn matmul_impl<T: 'static + num_traits::Num + Copy>(
|
||||||
|
lhs: &[T],
|
||||||
|
rhs: &[T],
|
||||||
|
(b, m, n, k): (usize, usize, usize, usize),
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
) -> Result<Vec<T>> {
|
||||||
|
let a_skip: usize = m * k;
|
||||||
|
let b_skip: usize = n * k;
|
||||||
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
|
let rank = lhs_stride.len();
|
||||||
|
let lhs_cs = lhs_stride[rank - 1];
|
||||||
|
let lhs_rs = lhs_stride[rank - 2];
|
||||||
|
|
||||||
|
let rhs_cs = rhs_stride[rank - 1];
|
||||||
|
let rhs_rs = rhs_stride[rank - 2];
|
||||||
|
|
||||||
|
if lhs_stride.len() > 2 {
|
||||||
|
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
||||||
|
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
||||||
|
|
||||||
|
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
||||||
|
// Temporary error before we support abitrary striding.
|
||||||
|
return Err(Error::UnexpectedStriding);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let dst_shape: Shape = (m, n).into();
|
||||||
|
let dst_strides = dst_shape.stride_contiguous();
|
||||||
|
let dst_rs = dst_strides[0];
|
||||||
|
let dst_cs = dst_strides[1];
|
||||||
|
|
||||||
|
let mut dst = vec![T::zero(); b * m * n];
|
||||||
|
for step in 0..b {
|
||||||
|
let lhs_p = &lhs[step * a_skip..];
|
||||||
|
let rhs_p = &rhs[step * b_skip..];
|
||||||
|
let dst_p = &mut dst[step * c_skip..];
|
||||||
|
unsafe {
|
||||||
|
gemm(
|
||||||
|
/* m: usize = */ m,
|
||||||
|
/* n: usize = */ n,
|
||||||
|
/* k: usize = */ k,
|
||||||
|
/* dst: *mut T = */ dst_p.as_mut_ptr(),
|
||||||
|
/* dst_cs: isize = */ dst_cs as isize,
|
||||||
|
/* dst_rs: isize = */ dst_rs as isize,
|
||||||
|
/* read_dst: bool = */ false,
|
||||||
|
/* lhs: *const T = */ lhs_p.as_ptr(),
|
||||||
|
/* lhs_cs: isize = */ lhs_cs as isize,
|
||||||
|
/* lhs_rs: isize = */ lhs_rs as isize,
|
||||||
|
/* rhs: *const T = */ rhs_p.as_ptr(),
|
||||||
|
/* rhs_cs: isize = */ rhs_cs as isize,
|
||||||
|
/* rhs_rs: isize = */ rhs_rs as isize,
|
||||||
|
/* alpha: T = */ T::zero(),
|
||||||
|
/* beta: T = */ T::one(),
|
||||||
|
/* conj_dst: bool = */ false,
|
||||||
|
/* conj_lhs: bool = */ false,
|
||||||
|
/* conj_rhs: bool = */ false,
|
||||||
|
Parallelism::Rayon(crate::utils::get_num_threads()),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
|
||||||
impl CpuStorage {
|
impl CpuStorage {
|
||||||
pub fn dtype(&self) -> DType {
|
pub fn dtype(&self) -> DType {
|
||||||
match self {
|
match self {
|
||||||
@ -593,199 +658,28 @@ impl CpuStorage {
|
|||||||
pub(crate) fn matmul_impl(
|
pub(crate) fn matmul_impl(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
(b, m, n, k): (usize, usize, usize, usize),
|
bmnk: (usize, usize, usize, usize),
|
||||||
lhs_stride: &[usize],
|
lhs_stride: &[usize],
|
||||||
rhs_stride: &[usize],
|
rhs_stride: &[usize],
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let a_skip: usize = m * k;
|
|
||||||
let b_skip: usize = n * k;
|
|
||||||
let c_skip: usize = m * n;
|
|
||||||
|
|
||||||
let rank = lhs_stride.len();
|
|
||||||
let lhs_cs = lhs_stride[rank - 1];
|
|
||||||
let lhs_rs = lhs_stride[rank - 2];
|
|
||||||
|
|
||||||
let rhs_cs = rhs_stride[rank - 1];
|
|
||||||
let rhs_rs = rhs_stride[rank - 2];
|
|
||||||
|
|
||||||
if lhs_stride.len() > 2 {
|
|
||||||
let lhs_batch_stride = &lhs_stride[..rank - 2];
|
|
||||||
let rhs_batch_stride = &rhs_stride[..rank - 2];
|
|
||||||
|
|
||||||
if lhs_batch_stride != [a_skip] || rhs_batch_stride != [b_skip] {
|
|
||||||
// Temporary error before we support abitrary striding.
|
|
||||||
return Err(Error::UnexpectedStriding);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let dst_shape: Shape = (m, n).into();
|
|
||||||
let dst_strides = dst_shape.stride_contiguous();
|
|
||||||
let dst_rs = dst_strides[0];
|
|
||||||
let dst_cs = dst_strides[1];
|
|
||||||
|
|
||||||
match (self, rhs) {
|
match (self, rhs) {
|
||||||
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
(CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => {
|
||||||
let mut dst = vec![f16::ZERO; b * m * n];
|
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||||
for step in 0..b {
|
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
|
||||||
unsafe {
|
|
||||||
gemm(
|
|
||||||
// m: usize,
|
|
||||||
m,
|
|
||||||
// n: usize,
|
|
||||||
n,
|
|
||||||
// k: usize,
|
|
||||||
k,
|
|
||||||
// dst: *mut T,
|
|
||||||
dst_p.as_mut_ptr(),
|
|
||||||
// dst_cs: isize,
|
|
||||||
dst_cs as isize,
|
|
||||||
// dst_rs: isize,
|
|
||||||
dst_rs as isize,
|
|
||||||
// read_dst: bool,
|
|
||||||
false,
|
|
||||||
// lhs: *const T,
|
|
||||||
lhs_p.as_ptr(),
|
|
||||||
// lhs_cs: isize,
|
|
||||||
lhs_cs as isize,
|
|
||||||
// lhs_rs: isize,
|
|
||||||
lhs_rs as isize,
|
|
||||||
// rhs: *const T,
|
|
||||||
rhs_p.as_ptr(),
|
|
||||||
// rhs_cs: isize,
|
|
||||||
rhs_cs as isize,
|
|
||||||
// rhs_rs: isize,
|
|
||||||
rhs_rs as isize,
|
|
||||||
// alpha: T,
|
|
||||||
f16::ONE,
|
|
||||||
// beta: T,
|
|
||||||
f16::ONE,
|
|
||||||
// conj_dst: bool,
|
|
||||||
false,
|
|
||||||
// conj_lhs: bool,
|
|
||||||
false,
|
|
||||||
// conj_rhs: bool,
|
|
||||||
true,
|
|
||||||
// parallelism: Parallelism
|
|
||||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(Self::F16(dst))
|
Ok(Self::F16(dst))
|
||||||
}
|
}
|
||||||
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
(CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
|
||||||
let mut dst = vec![0f32; b * m * n];
|
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||||
for step in 0..b {
|
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
|
||||||
unsafe {
|
|
||||||
gemm(
|
|
||||||
// m: usize,
|
|
||||||
m,
|
|
||||||
// n: usize,
|
|
||||||
n,
|
|
||||||
// k: usize,
|
|
||||||
k,
|
|
||||||
// dst: *mut T,
|
|
||||||
dst_p.as_mut_ptr(),
|
|
||||||
// dst_cs: isize,
|
|
||||||
dst_cs as isize,
|
|
||||||
// dst_rs: isize,
|
|
||||||
dst_rs as isize,
|
|
||||||
// read_dst: bool,
|
|
||||||
false,
|
|
||||||
// lhs: *const T,
|
|
||||||
lhs_p.as_ptr(),
|
|
||||||
// lhs_cs: isize,
|
|
||||||
lhs_cs as isize,
|
|
||||||
// lhs_rs: isize,
|
|
||||||
lhs_rs as isize,
|
|
||||||
// rhs: *const T,
|
|
||||||
rhs_p.as_ptr(),
|
|
||||||
// rhs_cs: isize,
|
|
||||||
rhs_cs as isize,
|
|
||||||
// rhs_rs: isize,
|
|
||||||
rhs_rs as isize,
|
|
||||||
// alpha: T,
|
|
||||||
1f32,
|
|
||||||
// beta: T,
|
|
||||||
1f32,
|
|
||||||
// conj_dst: bool,
|
|
||||||
false,
|
|
||||||
// conj_lhs: bool,
|
|
||||||
false,
|
|
||||||
// conj_rhs: bool,
|
|
||||||
true,
|
|
||||||
// parallelism: Parallelism
|
|
||||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Self::F32(dst))
|
Ok(Self::F32(dst))
|
||||||
}
|
}
|
||||||
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
|
||||||
let mut dst = vec![0f64; b * m * n];
|
let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?;
|
||||||
for step in 0..b {
|
|
||||||
let lhs_p = &lhs[step * a_skip..];
|
|
||||||
let rhs_p = &rhs[step * b_skip..];
|
|
||||||
let dst_p = &mut dst[step * c_skip..];
|
|
||||||
unsafe {
|
|
||||||
gemm(
|
|
||||||
// m: usize,
|
|
||||||
m,
|
|
||||||
// n: usize,
|
|
||||||
n,
|
|
||||||
// k: usize,
|
|
||||||
k,
|
|
||||||
// dst: *mut T,
|
|
||||||
dst_p.as_mut_ptr(),
|
|
||||||
// dst_cs: isize,
|
|
||||||
dst_cs as isize,
|
|
||||||
// dst_rs: isize,
|
|
||||||
dst_rs as isize,
|
|
||||||
// read_dst: bool,
|
|
||||||
false,
|
|
||||||
// lhs: *const T,
|
|
||||||
lhs_p.as_ptr(),
|
|
||||||
// lhs_cs: isize,
|
|
||||||
lhs_cs as isize,
|
|
||||||
// lhs_rs: isize,
|
|
||||||
lhs_rs as isize,
|
|
||||||
// rhs: *const T,
|
|
||||||
rhs_p.as_ptr(),
|
|
||||||
// rhs_cs: isize,
|
|
||||||
rhs_cs as isize,
|
|
||||||
// rhs_rs: isize,
|
|
||||||
rhs_rs as isize,
|
|
||||||
// alpha: T,
|
|
||||||
1f64,
|
|
||||||
// beta: T,
|
|
||||||
1f64,
|
|
||||||
// conj_dst: bool,
|
|
||||||
false,
|
|
||||||
// conj_lhs: bool,
|
|
||||||
false,
|
|
||||||
// conj_rhs: bool,
|
|
||||||
true,
|
|
||||||
// parallelism: Parallelism
|
|
||||||
Parallelism::Rayon(crate::utils::get_num_threads()),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(Self::F64(dst))
|
Ok(Self::F64(dst))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => Err(Error::DTypeMismatchBinaryOp {
|
||||||
// This should be covered by the dtype check above.
|
lhs: self.dtype(),
|
||||||
Err(Error::DTypeMismatchBinaryOp {
|
rhs: rhs.dtype(),
|
||||||
lhs: self.dtype(),
|
op: "matmul",
|
||||||
rhs: rhs.dtype(),
|
}),
|
||||||
op: "matmul",
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -841,45 +735,3 @@ impl CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use crate::{Device, Tensor};
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn simple_matmul() -> Result<()> {
|
|
||||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
|
||||||
let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
|
|
||||||
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
|
||||||
let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?;
|
|
||||||
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]);
|
|
||||||
|
|
||||||
let data = vec![1.0f32, 2.0];
|
|
||||||
let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?;
|
|
||||||
let data = vec![3.0f32, 4.0];
|
|
||||||
let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?;
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
|
||||||
|
|
||||||
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
|
||||||
let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?;
|
|
||||||
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
|
||||||
let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?;
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
|
||||||
|
|
||||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
|
||||||
let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?;
|
|
||||||
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
|
||||||
let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?;
|
|
||||||
let c = a.matmul(&b)?;
|
|
||||||
assert_eq!(
|
|
||||||
c.to_vec3::<f32>()?,
|
|
||||||
&[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]]
|
|
||||||
);
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -219,6 +219,63 @@ fn cat(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn embeddings(device: &Device) -> Result<()> {
|
||||||
|
let ids = Tensor::new(&[0u32, 2u32, 1u32], device)?;
|
||||||
|
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
||||||
|
let hs = Tensor::embedding(&ids, &t)?;
|
||||||
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn matmul(device: &Device) -> Result<()> {
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
let data = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (2, 2), device)?;
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
|
||||||
|
|
||||||
|
let data = vec![1.0f32, 2.0];
|
||||||
|
let a = Tensor::from_slice(&data, (2, 1), device)?;
|
||||||
|
let data = vec![3.0f32, 4.0];
|
||||||
|
let b = Tensor::from_slice(&data, (1, 2), device)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..6).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 3), device)?;
|
||||||
|
let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (3, 2), device)?;
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
|
||||||
|
|
||||||
|
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||||
|
let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
|
||||||
|
let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
|
||||||
|
let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
|
||||||
|
let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
|
||||||
|
|
||||||
|
let c = a.matmul(&b)?;
|
||||||
|
assert_eq!(c.to_vec3::<f32>()?, &expected);
|
||||||
|
|
||||||
|
// Also perform the matmul on contiguous transposed versions.
|
||||||
|
let a_tt = a.t()?.contiguous()?.t()?;
|
||||||
|
assert!(!a_tt.is_contiguous());
|
||||||
|
assert_eq!(a.dims(), a_tt.dims());
|
||||||
|
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
||||||
|
|
||||||
|
let b_tt = b.t()?.contiguous()?.t()?;
|
||||||
|
assert!(!b_tt.is_contiguous());
|
||||||
|
assert_eq!(b.dims(), b_tt.dims());
|
||||||
|
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
||||||
|
|
||||||
|
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
||||||
|
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||||
|
assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
test_device!(zeros, zeros_cpu, zeros_gpu);
|
test_device!(zeros, zeros_cpu, zeros_gpu);
|
||||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
|
||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
|
||||||
@ -229,3 +286,5 @@ test_device!(sum, sum_cpu, sum_gpu);
|
|||||||
test_device!(transpose, transpose_cpu, transpose_gpu);
|
test_device!(transpose, transpose_cpu, transpose_gpu);
|
||||||
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
test_device!(binary_op, binary_op_cpu, binary_op_gpu);
|
||||||
test_device!(softmax, softmax_cpu, softmax_gpu);
|
test_device!(softmax, softmax_cpu, softmax_gpu);
|
||||||
|
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
|
||||||
|
test_device!(matmul, matmul_cpu, matmul_gpu);
|
||||||
|
Reference in New Issue
Block a user