Use broadcasted scalars for const tensors.

This commit is contained in:
laurent
2023-06-29 11:56:40 +01:00
parent 3872dc4751
commit 2741b39ad3
5 changed files with 12 additions and 14 deletions

View File

@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
let a_tt = a.t()?.contiguous()?.t()?;
assert!(!a_tt.is_contiguous());
assert_eq!(a.dims(), a_tt.dims());
assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
assert_eq!(a_tt.stride(), &[6, 1, 2]);
let b_tt = b.t()?.contiguous()?.t()?;
assert!(!b_tt.is_contiguous());
assert_eq!(b.dims(), b_tt.dims());
assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
assert_eq!(b_tt.stride(), &[6, 1, 3]);
assert_eq!(a_tt.matmul(&b)?.to_vec3::<f32>()?, &expected);
assert_eq!(a.matmul(&b_tt)?.to_vec3::<f32>()?, &expected);