diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6857341c..1813f276 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -790,6 +790,7 @@ impl BackendStorage for MetalStorage { let device = self.device().clone(); let shape = layout.shape(); let dims = shape.dims(); + let strides = layout.stride(); let stride = params.stride; let dilation = params.dilation; @@ -811,7 +812,7 @@ impl BackendStorage for MetalStorage { &self.device.kernels, name, layout.shape().dims(), - layout.stride(), + strides, (k_size, stride, padding, dilation), &self.buffer, layout.start_offset() * self.dtype.size_in_bytes(), diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 52cdce04..d126aa42 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1301,7 +1301,7 @@ pub fn call_gemm( let fused_activation = false; let fused_bias = false; let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 { - let m_simd = 16; + let m_simd = 8; let n_simd = 8; let k_simd = 64; let m_splits = 1; @@ -1310,7 +1310,7 @@ pub fn call_gemm( } else { let m_simd = 40; let n_simd = 40; - let k_simd = 8; + let k_simd = 32; let m_splits = 1; let n_splits = 1; (m_simd, n_simd, k_simd, m_splits, n_splits)