Fixing matmul for convolutions.

This commit is contained in:
Nicolas Patry
2023-12-25 12:29:34 +01:00
parent 10d94659c3
commit 95e18ef675
2 changed files with 4 additions and 3 deletions

View File

@ -790,6 +790,7 @@ impl BackendStorage for MetalStorage {
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let strides = layout.stride();
let stride = params.stride;
let dilation = params.dilation;
@ -811,7 +812,7 @@ impl BackendStorage for MetalStorage {
&self.device.kernels,
name,
layout.shape().dims(),
layout.stride(),
strides,
(k_size, stride, padding, dilation),
&self.buffer,
layout.start_offset() * self.dtype.size_in_bytes(),