mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Faster matmul when we can fall back to gemv.
This commit is contained in:
@ -5,6 +5,8 @@ use anyhow::Result;
|
|||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
|
let mut file = std::fs::File::open("ggml.bin")?;
|
||||||
|
let data = candle_core::ggml::Content::read(&mut file, &Device::Cpu)?;
|
||||||
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
let a = Tensor::randn(0f32, 1., (2, 3), &Device::Cpu)?;
|
||||||
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
let b = Tensor::randn(0f32, 1., (3, 4), &Device::Cpu)?;
|
||||||
let c = a.matmul(&b)?;
|
let c = a.matmul(&b)?;
|
||||||
|
@ -1010,12 +1010,18 @@ impl Map2 for MatMul {
|
|||||||
};
|
};
|
||||||
let c_skip: usize = m * n;
|
let c_skip: usize = m * n;
|
||||||
|
|
||||||
let dst_shape: Shape = (m, n).into();
|
|
||||||
let dst_strides = dst_shape.stride_contiguous();
|
|
||||||
let dst_rs = dst_strides[0];
|
|
||||||
let dst_cs = dst_strides[1];
|
|
||||||
|
|
||||||
let mut dst = vec![T::zero(); b * m * n];
|
let mut dst = vec![T::zero(); b * m * n];
|
||||||
|
|
||||||
|
let (dst_rs, dst_cs) = if m == 1 {
|
||||||
|
(1, 1)
|
||||||
|
} else if n == 1 {
|
||||||
|
(1, 1)
|
||||||
|
} else {
|
||||||
|
let dst_shape: Shape = (m, n).into();
|
||||||
|
let dst_strides = dst_shape.stride_contiguous();
|
||||||
|
(dst_strides[0], dst_strides[1])
|
||||||
|
};
|
||||||
|
|
||||||
let num_threads = crate::utils::get_num_threads();
|
let num_threads = crate::utils::get_num_threads();
|
||||||
let parallelism = if num_threads > 1 {
|
let parallelism = if num_threads > 1 {
|
||||||
Parallelism::Rayon(num_threads)
|
Parallelism::Rayon(num_threads)
|
||||||
|
@ -111,6 +111,7 @@ impl TransformerWeights {
|
|||||||
// matrix column major rather than row major. This ends up speeding up text generation from
|
// matrix column major rather than row major. This ends up speeding up text generation from
|
||||||
// 120 token/s to 220 token/s on a Ryzen 2600X.
|
// 120 token/s to 220 token/s on a Ryzen 2600X.
|
||||||
let tr = device.is_cpu() && !candle::utils::has_mkl();
|
let tr = device.is_cpu() && !candle::utils::has_mkl();
|
||||||
|
let tr = false;
|
||||||
let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
|
let tr = |x: Tensor| if tr { x.t()?.contiguous()?.t() } else { Ok(x) };
|
||||||
let mut ws = std::collections::HashMap::new();
|
let mut ws = std::collections::HashMap::new();
|
||||||
let mut insert = |name: &str, t: Tensor| {
|
let mut insert = |name: &str, t: Tensor| {
|
||||||
|
Reference in New Issue
Block a user