More flexible matmul contiguity checks. (#1949)

* More flexible matmul contiguity checks.

* Also relax the checks on the metal side.
This commit is contained in:
Laurent Mazare
2024-03-27 10:59:05 +01:00
committed by GitHub
parent 75b6d4b0da
commit a9abde5f93
4 changed files with 51 additions and 8 deletions

View File

@ -1135,6 +1135,30 @@ fn randn(device: &Device) -> Result<()> {
Ok(())
}
// https://github.com/huggingface/candle/issues/1948
fn squeeze_mm(device: &Device) -> Result<()> {
let seq_len = 8_usize;
let a = Tensor::zeros((1, seq_len, 16), DType::F32, device)?;
let x = a.i((.., seq_len - 1, ..))?;
println!(
"x shape:{:?}, stride:{:?}, is_contiguous:{}",
x.shape(),
x.stride(),
x.is_contiguous()
);
let w = Tensor::zeros((32, 16), DType::F32, device)?.t()?;
println!(
"w shape:{:?}, stride:{:?}, is_contiguous:{}",
w.shape(),
w.stride(),
w.is_contiguous()
);
let x = x.matmul(&w)?;
assert_eq!(x.dims(), &[1, 32]);
Ok(())
}
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
@ -1190,6 +1214,7 @@ test_device!(
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(squeeze_mm, squeeze_mm_cpu, squeeze_mm_gpu, squeeze_mm_metal);
// There was originally a bug on the CPU implementation for randn
// https://github.com/huggingface/candle/issues/381