Finished scaffolding, lots of TODOs

- Most kernels just copy themselves to get the shapes correct.
- Matmul works in only one case and simply allocates an empty buffer otherwise.
- Logits are randomized so the demo can run to completion.

Performance is quite bad (30 ms/token), but there are lots of prints and allocs, and some work is actually being sent to Metal.

Couldn't get it much faster even after removing the obvious blockers (the printlns and the actual matmul runs).

Allocations take between 1 µs and 100 µs and seem very stable. Maybe Metal doesn't really have a smart allocator and we'll need to own it.
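
If we do end up owning the allocator, one possible direction is a size-bucketed pool that reuses metal::Buffer objects instead of calling Device::new_buffer on every op. The sketch below is hypothetical and not part of this commit; the BufferPool type and its methods are made up for illustration, only Device::new_buffer, Buffer::length and MTLResourceOptions are the real metal-rs API already used in the backend.

// Hypothetical sketch, not part of this commit: reuse buffers by byte size
// instead of asking Metal for a fresh allocation on every call.
use std::collections::HashMap;

use metal::{Buffer, Device, MTLResourceOptions};

struct BufferPool {
    device: Device,
    // Free buffers, keyed by their length in bytes.
    free: HashMap<u64, Vec<Buffer>>,
}

impl BufferPool {
    fn new(device: Device) -> Self {
        Self { device, free: HashMap::new() }
    }

    /// Hand back a previously allocated buffer of exactly `size` bytes if one
    /// is free, otherwise fall back to a real Metal allocation.
    fn get(&mut self, size: u64) -> Buffer {
        if let Some(buffer) = self.free.get_mut(&size).and_then(|b| b.pop()) {
            return buffer;
        }
        self.device.new_buffer(size, MTLResourceOptions::StorageModeShared)
    }

    /// Return a buffer to the pool once the GPU work referencing it is done.
    fn put(&mut self, buffer: Buffer) {
        self.free.entry(buffer.length()).or_default().push(buffer);
    }
}

Reuse would only be safe after the command buffer that referenced the buffer has completed, so the pool would have to be driven by command-buffer completion rather than by tensor drops.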
Nicolas Patry
2023-11-02 15:32:28 +01:00
parent 82cce52e73
commit 7161002a34
11 changed files with 212 additions and 52 deletions


@@ -13,7 +13,7 @@ readme = "README.md"
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
 candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
-candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.0.1", optional = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 metal = { workspace = true, optional = true}
 cudarc = { workspace = true, optional = true }
 gemm = { workspace = true }


@@ -1,15 +1,15 @@
 use crate::backend::{BackendDevice, BackendStorage};
+use crate::bail;
 use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
-pub use candle_metal_kernels;
+use candle_metal_kernels;
 use core::mem;
 use half::{bf16, f16};
 use metal;
-use metal::mps::matrix::{MatrixMultiplication, Matrix, MatrixDescriptor};
+use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 use metal::mps::{Float32, MPSDataType};
 use metal::MTLResourceOptions;
-use crate::bail;

 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -72,11 +72,16 @@ impl BackendStorage for MetalStorage {
     }

     fn to_cpu_storage(&self) -> Result<CpuStorage> {
-        todo!()
+        match self.dtype {
+            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
+            dtype => todo!("Unsupported dtype {dtype:?}"),
+        }
     }

     fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
-        todo!()
+        println!("TODO Affine");
+        Ok(self.clone())
+        // todo!()
     }

     fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
@@ -88,7 +93,9 @@ impl BackendStorage for MetalStorage {
     }

     fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
-        todo!()
+        println!("TODO reduce_op");
+        Ok(self.clone())
+        // todo!()
     }

     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -100,15 +107,22 @@ impl BackendStorage for MetalStorage {
     }

     fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
-        todo!()
+        // todo!()
+        // TODO
+        println!("TODO {:?}", B::NAME);
+        Ok(self.clone())
     }

     fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
-        todo!()
+        println!("TODO Binary {:?}", B::NAME);
+        Ok(self.clone())
+        // todo!()
     }

-    fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        todo!()
+    fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
+        println!("TODO where_cond");
+        Ok(rhs.clone())
+        // todo!()
     }

     fn conv1d(
@@ -174,7 +188,9 @@ impl BackendStorage for MetalStorage {
     }

     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
-        todo!()
+        println!("TODO Index select");
+        Ok(self.clone())
+        // todo!()
     }

     fn index_add(
@@ -195,21 +211,96 @@ impl BackendStorage for MetalStorage {
         (b, m, n, k): (usize, usize, usize, usize),
         lhs_l: &Layout,
         rhs_l: &Layout,
+    ) -> Result<Self> {
+        let transpose_left = false;
+        let transpose_right = false;
+        let alpha = 1.0;
+        let beta = 0.0;
+        self.matmul_generic(
+            rhs,
+            (b, m, n, k),
+            lhs_l,
+            rhs_l,
+            transpose_left,
+            transpose_right,
+            alpha,
+            beta,
+        )
+    }
+
+    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
+        println!("TODO Copy strided");
+        Ok(())
+    }
+}
+
+impl MetalStorage {
+    pub(crate) fn matmul_t(
+        &self,
+        rhs: &Self,
+        (b, m, n, k): (usize, usize, usize, usize),
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+    ) -> Result<Self> {
+        let transpose_left = false;
+        let transpose_right = true;
+        let alpha = 1.0;
+        let beta = 0.0;
+        self.matmul_generic(
+            rhs,
+            (b, m, n, k),
+            lhs_l,
+            rhs_l,
+            transpose_left,
+            transpose_right,
+            alpha,
+            beta,
+        )
+    }
+
+    pub(crate) fn matmul_generic(
+        &self,
+        rhs: &Self,
+        (b, m, n, k): (usize, usize, usize, usize),
+        lhs_l: &Layout,
+        rhs_l: &Layout,
+        transpose_left: bool,
+        transpose_right: bool,
+        alpha: f64,
+        beta: f64,
     ) -> Result<Self> {
         let elem_count = b * m * n;
         match (self.dtype, rhs.dtype) {
             (DType::F32, DType::F32) => {
                 if b != 1 {
-                    bail!("Didn't implemented strided matmul yet");
+                    println!("TODO implement batched matmul for B={b}");
+                    // bail!("Didn't implemented strided matmul yet");
+                    let out_buffer = self.device.new_buffer(
+                        (elem_count * mem::size_of::<f32>()) as u64,
+                        MTLResourceOptions::empty(),
+                    );
+                    return Ok(Self {
+                        buffer: out_buffer,
+                        device: self.device.clone(),
+                        dtype: self.dtype(),
+                    });
                 }
                 if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
-                    bail!("Didn't implemented non contiguous matmul yet");
+                    println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
+                    let out_buffer = self.device.new_buffer(
+                        (elem_count * mem::size_of::<f32>()) as u64,
+                        MTLResourceOptions::empty(),
+                    );
+                    return Ok(Self {
+                        buffer: out_buffer,
+                        device: self.device.clone(),
+                        dtype: self.dtype(),
+                    });
                 }
                 let out_buffer = self.device.new_buffer(
                     (elem_count * mem::size_of::<f32>()) as u64,
                     MTLResourceOptions::empty(),
                 );
-                let m : u64 = m.try_into().expect("usize should fit u64");
+                let m: u64 = m.try_into().expect("usize should fit u64");
                 let n: u64 = n.try_into().expect("usize should fit u64");
                 let k: u64 = k.try_into().expect("usize should fit u64");
                 // Create descriptors
@@ -220,6 +311,9 @@ impl BackendStorage for MetalStorage {
                 let result_descriptor =
                     MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
+                println!("lhs {:?} {m} {k}", self.buffer.length());
+                println!("rhs {:?} {k} {n}", rhs.buffer.length());
+                println!("out {:?} {m} {n}", out_buffer.length());
                 // Create matrix objects
                 let left_matrix =
                     Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
@@ -232,11 +326,7 @@ impl BackendStorage for MetalStorage {
                     Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
                         .expect("Failed to create left matrix");

-                let transpose_left = false;
-                let transpose_right = false;
-                let alpha = 1.0;
-                let beta = 0.0;
+                println!("lhs {:?}", lhs_l.shape());

                 // Create kernel
                 let matrix_multiplication = MatrixMultiplication::init(
@@ -258,20 +348,15 @@ impl BackendStorage for MetalStorage {
                     &right_matrix,
                     &result_matrix,
                 );

-                Ok(Self{
+                Ok(Self {
                     buffer: out_buffer,
                     device: self.device.clone(),
                     dtype: self.dtype(),
                 })
             }
             _ => todo!("Unimplemented matmul for this pair"),
         }
     }
-
-    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
-        todo!()
-    }
 }

 impl BackendDevice for MetalDevice {
@@ -281,7 +366,11 @@ impl BackendDevice for MetalDevice {
         let device = metal::Device::all().swap_remove(ordinal);
         let _command_queue = device.new_command_queue();
         let command_buffer = _command_queue.new_owned_command_buffer();
-        Ok(Self { device, _command_queue, command_buffer })
+        Ok(Self {
+            device,
+            _command_queue,
+            command_buffer,
+        })
     }

     fn set_seed(&self, _seed: u64) -> Result<()> {
@@ -296,12 +385,16 @@ impl BackendDevice for MetalDevice {
         self.device.registry_id() == rhs.device.registry_id()
     }

-    fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<MetalStorage> {
-        todo!()
+    fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
+        // TODO Is there a faster way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
+        self.storage_from_cpu_storage(&cpu_storage)
     }

-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        todo!()
+    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
+        // TODO Is there a faster way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
+        self.storage_from_cpu_storage(&cpu_storage)
     }

     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
@@ -350,11 +443,15 @@ impl BackendDevice for MetalDevice {
         })
     }

-    fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
-        todo!()
+    fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
+        // TODO is there a better way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
+        self.storage_from_cpu_storage(&cpu_storage)
     }

-    fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
-        todo!()
+    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
+        // TODO is there a better way ?
+        let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
+        self.storage_from_cpu_storage(&cpu_storage)
     }
 }


@@ -182,7 +182,7 @@ pub trait CustomOp1 {
         _layout: &Layout,
     ) -> Result<(MetalStorage, Shape)> {
         Err(crate::Error::Metal(
-            format!("no cuda implementation for {}", self.name()).into(),
+            format!("no metal implementation for {}", self.name()).into(),
         ))
     }


@@ -315,6 +315,49 @@ impl crate::CustomOp1 for QTensor {
         )?;
         Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
     }
+
+    fn metal_fwd(
+        &self,
+        storage: &crate::MetalStorage,
+        layout: &crate::Layout,
+    ) -> Result<(crate::MetalStorage, Shape)> {
+        println!("TODO qmatmul");
+        if !layout.is_contiguous() {
+            crate::bail!("input tensor is not contiguous {layout:?}")
+        }
+        let src_shape = layout.shape();
+        // self is transposed so n is first then k.
+        let (n, k) = self.shape.dims2()?;
+        if src_shape.rank() < 2 {
+            crate::bail!("input tensor has only one dimension {layout:?}")
+        }
+        let mut dst_shape = src_shape.dims().to_vec();
+        let last_k = dst_shape.pop().unwrap();
+        if last_k != k {
+            crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
+        }
+        dst_shape.push(n);
+        let dst_shape = Shape::from(dst_shape);
+        // let storage = storage.as_slice::<f32>()?;
+        // let storage =
+        //     &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
+        let dst_storage = vec![0f32; dst_shape.elem_count()];
+        // self.matmul_t(
+        //     (dst_shape.elem_count() / n, k, n),
+        //     storage,
+        //     &mut dst_storage,
+        // )?;
+        let cpu_storage = crate::CpuStorage::F32(dst_storage);
+        use crate::backend::{BackendDevice, BackendStorage};
+        if let Device::Metal(device) = &self.device {
+            Ok((
+                device.storage_from_cpu_storage(&cpu_storage)?,
+                dst_shape,
+            ))
+        } else {
+            crate::bail!("qtensor not on metal device")
+        }
+    }
 }

 impl QMatMul {


@@ -51,6 +51,7 @@ anyhow = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
 cudnn = ["candle/cudnn"]
 flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]


@@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;

 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Device, Tensor};
+use candle::{Tensor};
 use candle_transformers::generation::LogitsProcessor;
 use candle_transformers::models::quantized_llama as model;
@@ -367,9 +367,11 @@ fn main() -> anyhow::Result<()> {
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
-        let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
+        let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, 0)?;
         let logits = logits.squeeze(0)?;
+        // TODO Remove this once implementation is finished.
+        let logits = logits.ones_like()?;
         logits_processor.sample(&logits)?
     };
     let prompt_dt = start_prompt_processing.elapsed();
@@ -380,7 +382,7 @@ fn main() -> anyhow::Result<()> {
     let start_post_prompt = std::time::Instant::now();
     for index in 0..to_sample {
-        let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
+        let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
         let logits = model.forward(&input, prompt_tokens.len() + index)?;
         let logits = logits.squeeze(0)?;
         let logits = if args.repeat_penalty == 1. {
@@ -393,6 +395,8 @@ fn main() -> anyhow::Result<()> {
                 &all_tokens[start_at..],
             )?
         };
+        // TODO Remove this once implementation is finished.
+        let logits = logits.ones_like()?;
         next_token = logits_processor.sample(&logits)?;
         all_tokens.push(next_token);
         print_token(next_token, &tokenizer);


@@ -1,13 +1,12 @@
 [package]
 name = "candle-metal-kernels"
-version = "0.0.1"
-edition = "2021"
-
-description = "Metal kernels for Candle"
-repository = "https://github.com/huggingface/candle"
-keywords = ["blas", "tensor", "machine-learning"]
-categories = ["science"]
-license = "MIT OR Apache-2.0"
+version.workspace = true
+edition.workspace = true
+description.workspace = true
+repository.workspace = true
+keywords.workspace = true
+categories.workspace = true
+license.workspace = true

 [dependencies]
-metal = { workspace = true, optional = true}
+metal = { workspace = true }


@@ -28,4 +28,5 @@ clap = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
+metal = ["candle/metal"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]


@@ -191,6 +191,16 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        println!("TODO softmax-last-dim");
+        Ok((storage.clone(), layout.shape().clone()))
+    }
 }

 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {


@@ -28,5 +28,6 @@ wav = { workspace = true }
 default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
 cuda = ["candle/cuda", "candle-nn/cuda"]
+metal = ["candle/metal", "candle-nn/metal"]
 flash-attn = ["cuda", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]


@@ -112,6 +112,7 @@ impl LayerWeights {
         let q = self.attention_wq.forward(x)?;
         let k = self.attention_wk.forward(x)?;
         let v = self.attention_wv.forward(x)?;
+        // println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype());

         let q = q
             .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
@@ -145,9 +146,12 @@ impl LayerWeights {
         let v = self.repeat_kv(v)?;

         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
+        // println!("att {:?}", att.dtype());
         let mask = mask.broadcast_as(att.shape())?;
+        // println!("mask {:?}", mask.dtype());
         let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
         let att = candle_nn::ops::softmax_last_dim(&att)?;
+        // println!("att {:?} v {:?}", att.dtype(), v.dtype());
         // Convert to contiguous as matmul doesn't support strided vs for now.
         let y = att.matmul(&v.contiguous()?)?;
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;