Mirror of https://github.com/huggingface/candle.git, synced 2025-06-19 19:58:35 +00:00.
Improve the handling of matmul with squeezed layouts. (#1998)
* Improve the handling of matmul with squeezed layouts.
* Fix for the cuda backend.
* Revert the temporary fix.
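For context, the sketch below is a minimal, hypothetical reproduction (assuming the candle_core crate, a CPU device, and illustrative variable names) of the squeezed-layout case this change targets, mirroring the squeeze_mm regression test that appears in the diff below and is tracked in https://github.com/huggingface/candle/issues/1948. Indexing the last sequence position of a (1, seq_len, 16) tensor yields a (1, 16) view that keeps the parent strides and a non-zero storage offset, and a matmul against a transposed weight matrix has to handle that layout.

use candle_core::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    // Hypothetical stand-alone reproduction of the squeezed-layout matmul
    // reported in https://github.com/huggingface/candle/issues/1948.
    let device = Device::Cpu;
    let seq_len = 8usize;

    // (1, seq_len, 16) activations; pick out the last sequence position.
    let a = Tensor::zeros((1, seq_len, 16), DType::F32, &device)?;
    // The resulting (1, 16) view keeps the parent strides and storage offset.
    let x = a.i((.., seq_len - 1, ..))?;
    println!(
        "x shape:{:?} stride:{:?} contiguous:{}",
        x.shape(),
        x.stride(),
        x.is_contiguous()
    );

    // Transposing the weights gives another non-trivially strided operand.
    let w = Tensor::zeros((32, 16), DType::F32, &device)?.t()?;

    // The matmul on these layouts is what this commit improves; it should
    // succeed and produce a (1, 32) result.
    let y = x.matmul(&w)?;
    assert_eq!(y.dims(), &[1, 32]);
    Ok(())
}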
@@ -938,74 +938,6 @@ fn gather(device: &Device) -> Result<()> {
     Ok(())
 }
 
-fn matmul(device: &Device) -> Result<()> {
-    let data = vec![1.0f32, 2.0, 3.0, 4.0];
-    let a = Tensor::from_slice(&data, (2, 2), device)?;
-    let data = vec![1.0f32, 2.0, 3.0, 4.0];
-    let b = Tensor::from_slice(&data, (2, 2), device)?;
-
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[[7.0f32, 10.0], [15.0, 22.0]]);
-
-    let data = vec![1.0f32, 2.0];
-    let a = Tensor::from_slice(&data, (2, 1), device)?;
-    let data = vec![3.0f32, 4.0];
-    let b = Tensor::from_slice(&data, (1, 2), device)?;
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[&[3.0, 4.0], &[6.0, 8.0]]);
-
-    let data: Vec<_> = (0..6).map(|i| i as f32).collect();
-    let a = Tensor::from_slice(&data, (2, 3), device)?;
-    let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect();
-    let b = Tensor::from_slice(&data, (3, 2), device)?;
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec2::<f32>()?, &[&[16., 19.], &[52., 64.]]);
-
-    let data: Vec<_> = (0..12).map(|i| i as f32).collect();
-    let a = Tensor::from_slice(&data, (2, 2, 3), device)?;
-    let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect();
-    let b = Tensor::from_slice(&data, (2, 3, 2), device)?;
-    let expected = [[[16., 19.], [52., 64.]], [[214., 235.], [304., 334.]]];
-
-    let c = a.matmul(&b)?;
-    assert_eq!(c.to_vec3::<f32>()?, &expected);
-
-    // Also perform the matmul on contiguous transposed versions.
-    let a_tt = a.t()?.contiguous()?.t()?;
-    assert!(!a_tt.is_contiguous());
-    assert_eq!(a.dims(), a_tt.dims());
-    assert_eq!(a_tt.stride(), &[6, 1, 2]);
-
-    let b_tt = b.t()?.contiguous()?.t()?;
-    assert!(!b_tt.is_contiguous());
-    assert_eq!(b.dims(), b_tt.dims());
-    assert_eq!(b_tt.stride(), &[6, 1, 3]);
-
-    assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
-    assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
-    assert_eq!(a_tt.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
-    Ok(())
-}
-
-fn broadcast_matmul(device: &Device) -> Result<()> {
-    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
-    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
-    let out = lhs.broadcast_matmul(&rhs)?;
-    assert_eq!(out.dims(), &[3, 6, 4, 2]);
-    for idx1 in 0..3 {
-        for idx2 in 0..6 {
-            let out = out.i((idx1, idx2))?;
-            let lhs = lhs.i((idx1, 0))?;
-            let rhs = rhs.i(idx2)?;
-            let out2 = lhs.matmul(&rhs);
-            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
-            // With cuda, we see errors of up to ~1e-12.
-            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
-        }
-    }
-    Ok(())
-}
-
 fn broadcasting(device: &Device) -> Result<()> {
     let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
     let t2 = Tensor::new(&[100f32, 200f32], device)?;
@@ -1140,30 +1072,6 @@ fn randn(device: &Device) -> Result<()> {
     Ok(())
 }
 
-// https://github.com/huggingface/candle/issues/1948
-fn squeeze_mm(device: &Device) -> Result<()> {
-    let seq_len = 8_usize;
-    let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
-    let x = a.i((.., seq_len - 1, ..))?;
-    println!(
-        "x shape:{:?}, stride:{:?}, is_contiguous:{}",
-        x.shape(),
-        x.stride(),
-        x.is_contiguous()
-    );
-
-    let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
-    println!(
-        "w shape:{:?}, stride:{:?}, is_contiguous:{}",
-        w.shape(),
-        w.stride(),
-        w.is_contiguous()
-    );
-    let x = x.matmul(&w)?;
-    assert_eq!(x.dims(), &[1, 32]);
-    Ok(())
-}
-
 test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
 test_device!(ones, ones_cpu, ones_gpu, ones_metal);
 test_device!(full, full_cpu, full_gpu, full_metal);
@@ -1183,13 +1091,6 @@ test_device!(unary_op, unary_op_cpu, unary_op_gpu, unary_op_metal);
 test_device!(binary_op, binary_op_cpu, binary_op_gpu, binary_op_metal);
 test_device!(embeddings, embeddings_cpu, embeddings_gpu, embeddings_metal);
 test_device!(cmp, cmp_cpu, cmp_gpu, cmp_metal);
-test_device!(matmul, matmul_cpu, matmul_gpu, matmul_metal);
-test_device!(
-    broadcast_matmul,
-    broadcast_matmul_cpu,
-    broadcast_matmul_gpu,
-    broadcast_matmul_metal
-);
 test_device!(
     broadcasting,
     broadcasting_cpu,
@@ -1219,7 +1120,6 @@ test_device!(
 test_device!(randn, randn_cpu, randn_gpu, randn_metal);
 test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
 test_device!(var, var_cpu, var_gpu, var_metal);
-test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
 
 // There was originally a bug on the CPU implementation for randn
 // https://github.com/huggingface/candle/issues/381