mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Use broadcasted scalars for const tensors.
This commit is contained in:
@ -20,7 +20,6 @@ fn matmul_grad(device: &Device) -> Result<()> {
|
||||
let x = Tensor::var_from_slice(&data, (2, 2, 3), device)?;
|
||||
let data: Vec<_> = (0..12).map(|i| i as f32).collect();
|
||||
let y = Tensor::var_from_slice(&data, (2, 3, 2), device)?;
|
||||
|
||||
let c = x.matmul(&y)?;
|
||||
let grads = c.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
|
@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
|
||||
let a_tt = a.t()?.contiguous()?.t()?;
|
||||
assert!(!a_tt.is_contiguous());
|
||||
assert_eq!(a.dims(), a_tt.dims());
|
||||
assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
|
||||
assert_eq!(a_tt.stride(), &[6, 1, 2]);
|
||||
|
||||
let b_tt = b.t()?.contiguous()?.t()?;
|
||||
assert!(!b_tt.is_contiguous());
|
||||
assert_eq!(b.dims(), b_tt.dims());
|
||||
assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
|
||||
assert_eq!(b_tt.stride(), &[6, 1, 3]);
|
||||
|
||||
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
|
||||
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);
|
||||
|
Reference in New Issue
Block a user