mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Fixing matmul for convolutions.
This commit is contained in:
@ -790,6 +790,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
|
let strides = layout.stride();
|
||||||
|
|
||||||
let stride = params.stride;
|
let stride = params.stride;
|
||||||
let dilation = params.dilation;
|
let dilation = params.dilation;
|
||||||
@ -811,7 +812,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
&self.device.kernels,
|
&self.device.kernels,
|
||||||
name,
|
name,
|
||||||
layout.shape().dims(),
|
layout.shape().dims(),
|
||||||
layout.stride(),
|
strides,
|
||||||
(k_size, stride, padding, dilation),
|
(k_size, stride, padding, dilation),
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||||
|
@ -1301,7 +1301,7 @@ pub fn call_gemm(
|
|||||||
let fused_activation = false;
|
let fused_activation = false;
|
||||||
let fused_bias = false;
|
let fused_bias = false;
|
||||||
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
|
||||||
let m_simd = 16;
|
let m_simd = 8;
|
||||||
let n_simd = 8;
|
let n_simd = 8;
|
||||||
let k_simd = 64;
|
let k_simd = 64;
|
||||||
let m_splits = 1;
|
let m_splits = 1;
|
||||||
@ -1310,7 +1310,7 @@ pub fn call_gemm(
|
|||||||
} else {
|
} else {
|
||||||
let m_simd = 40;
|
let m_simd = 40;
|
||||||
let n_simd = 40;
|
let n_simd = 40;
|
||||||
let k_simd = 8;
|
let k_simd = 32;
|
||||||
let m_splits = 1;
|
let m_splits = 1;
|
||||||
let n_splits = 1;
|
let n_splits = 1;
|
||||||
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
(m_simd, n_simd, k_simd, m_splits, n_splits)
|
||||||
|
Reference in New Issue
Block a user