mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Finished scaffolding, lots of TODOs
- Most kernels just copy themselfs to get the shapes correct - Matmul works only in 1 case and simply empty allocates otherwise - Logits and randomized to make the demo finish itself. Performance is quite bad (30ms/token), but lot's of prints and allocs and some actual sending to metal. Couln't get it super high by removing the obvious blockers (println + the actual running matmuls). Allocations takes between 1us and 100us and seems very stable, Maybe metal doesn't really have a smart allocator and we'll need to own it.
This commit is contained in:
@ -13,7 +13,7 @@ readme = "README.md"
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
byteorder = { workspace = true }
|
||||
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.0.1", optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
metal = { workspace = true, optional = true}
|
||||
cudarc = { workspace = true, optional = true }
|
||||
gemm = { workspace = true }
|
||||
|
@ -1,15 +1,15 @@
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
use crate::bail;
|
||||
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
|
||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||
use crate::{CpuStorage, DType, Layout, Result, Shape};
|
||||
pub use candle_metal_kernels;
|
||||
use candle_metal_kernels;
|
||||
use core::mem;
|
||||
use half::{bf16, f16};
|
||||
use metal;
|
||||
use metal::mps::matrix::{MatrixMultiplication, Matrix, MatrixDescriptor};
|
||||
use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
|
||||
use metal::mps::{Float32, MPSDataType};
|
||||
use metal::MTLResourceOptions;
|
||||
use crate::bail;
|
||||
|
||||
/// Metal related errors
|
||||
#[derive(thiserror::Error, Debug)]
|
||||
@ -72,11 +72,16 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||
todo!()
|
||||
match self.dtype{
|
||||
DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
|
||||
dtype => todo!("Unsupported dtype {dtype:?}")
|
||||
}
|
||||
}
|
||||
|
||||
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
|
||||
todo!()
|
||||
println!("TODO Affine");
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
}
|
||||
|
||||
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
|
||||
@ -88,7 +93,9 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
|
||||
todo!()
|
||||
println!("TODO reduce_op");
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
}
|
||||
|
||||
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
@ -100,15 +107,22 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
|
||||
todo!()
|
||||
// todo!()
|
||||
// TODO
|
||||
println!("TODO {:?}", B::NAME);
|
||||
Ok(self.clone())
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
|
||||
todo!()
|
||||
println!("TODO Binary {:?}", B::NAME);
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
}
|
||||
|
||||
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||
todo!()
|
||||
fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
|
||||
println!("TODO where_cond");
|
||||
Ok(rhs.clone())
|
||||
// todo!()
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
@ -174,7 +188,9 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
|
||||
todo!()
|
||||
println!("TODO Index select");
|
||||
Ok(self.clone())
|
||||
// todo!()
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
@ -195,21 +211,96 @@ impl BackendStorage for MetalStorage {
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let transpose_left = false;
|
||||
let transpose_right = false;
|
||||
let alpha = 1.0;
|
||||
let beta = 0.0;
|
||||
self.matmul_generic(
|
||||
rhs,
|
||||
(b, m, n, k),
|
||||
lhs_l,
|
||||
rhs_l,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||
println!("TODO Copy strided");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl MetalStorage {
|
||||
pub(crate) fn matmul_t(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
let transpose_left = false;
|
||||
let transpose_right = true;
|
||||
let alpha = 1.0;
|
||||
let beta = 0.0;
|
||||
self.matmul_generic(
|
||||
rhs,
|
||||
(b, m, n, k),
|
||||
lhs_l,
|
||||
rhs_l,
|
||||
transpose_left,
|
||||
transpose_right,
|
||||
alpha,
|
||||
beta,
|
||||
)
|
||||
}
|
||||
pub(crate) fn matmul_generic(
|
||||
&self,
|
||||
rhs: &Self,
|
||||
(b, m, n, k): (usize, usize, usize, usize),
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
transpose_left: bool,
|
||||
transpose_right: bool,
|
||||
alpha: f64,
|
||||
beta: f64,
|
||||
) -> Result<Self> {
|
||||
let elem_count = b * m * n;
|
||||
match (self.dtype, rhs.dtype) {
|
||||
(DType::F32, DType::F32) => {
|
||||
if b != 1 {
|
||||
bail!("Didn't implemented strided matmul yet");
|
||||
println!("TODO implement batched matmul for B={b}");
|
||||
// bail!("Didn't implemented strided matmul yet");
|
||||
let out_buffer = self.device.new_buffer(
|
||||
(elem_count * mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::empty(),
|
||||
);
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
}
|
||||
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
|
||||
bail!("Didn't implemented non contiguous matmul yet");
|
||||
println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
|
||||
let out_buffer = self.device.new_buffer(
|
||||
(elem_count * mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::empty(),
|
||||
);
|
||||
return Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
});
|
||||
}
|
||||
let out_buffer = self.device.new_buffer(
|
||||
(elem_count * mem::size_of::<f32>()) as u64,
|
||||
MTLResourceOptions::empty(),
|
||||
);
|
||||
let m : u64 = m.try_into().expect("usize should fit u64");
|
||||
let m: u64 = m.try_into().expect("usize should fit u64");
|
||||
let n: u64 = n.try_into().expect("usize should fit u64");
|
||||
let k: u64 = k.try_into().expect("usize should fit u64");
|
||||
// Create descriptors
|
||||
@ -220,6 +311,9 @@ impl BackendStorage for MetalStorage {
|
||||
let result_descriptor =
|
||||
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
|
||||
|
||||
println!("lhs {:?} {m} {k}", self.buffer.length());
|
||||
println!("rhs {:?} {k} {n}", rhs.buffer.length());
|
||||
println!("out {:?} {m} {n}", out_buffer.length());
|
||||
// Create matrix objects
|
||||
let left_matrix =
|
||||
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
|
||||
@ -232,11 +326,7 @@ impl BackendStorage for MetalStorage {
|
||||
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
|
||||
.expect("Failed to create left matrix");
|
||||
|
||||
let transpose_left = false;
|
||||
let transpose_right = false;
|
||||
let alpha = 1.0;
|
||||
let beta = 0.0;
|
||||
|
||||
println!("lhs {:?}", lhs_l.shape());
|
||||
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
@ -258,20 +348,15 @@ impl BackendStorage for MetalStorage {
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
Ok(Self{
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
|
||||
Ok(Self {
|
||||
buffer: out_buffer,
|
||||
device: self.device.clone(),
|
||||
dtype: self.dtype(),
|
||||
})
|
||||
}
|
||||
_ => todo!("Unimplemented matmul for this pair"),
|
||||
}
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl BackendDevice for MetalDevice {
|
||||
@ -281,7 +366,11 @@ impl BackendDevice for MetalDevice {
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
let _command_queue = device.new_command_queue();
|
||||
let command_buffer = _command_queue.new_owned_command_buffer();
|
||||
Ok(Self { device, _command_queue, command_buffer })
|
||||
Ok(Self {
|
||||
device,
|
||||
_command_queue,
|
||||
command_buffer,
|
||||
})
|
||||
}
|
||||
|
||||
fn set_seed(&self, _seed: u64) -> Result<()> {
|
||||
@ -296,12 +385,16 @@ impl BackendDevice for MetalDevice {
|
||||
self.device.registry_id() == rhs.device.registry_id()
|
||||
}
|
||||
|
||||
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<MetalStorage> {
|
||||
todo!()
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
// TODO Is there a faster way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||
todo!()
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
// TODO Is there a faster way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
@ -350,11 +443,15 @@ impl BackendDevice for MetalDevice {
|
||||
})
|
||||
}
|
||||
|
||||
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
todo!()
|
||||
fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
|
||||
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
|
||||
todo!()
|
||||
fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
|
||||
// TODO is there a better way ?
|
||||
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
|
||||
self.storage_from_cpu_storage(&cpu_storage)
|
||||
}
|
||||
}
|
||||
|
@ -182,7 +182,7 @@ pub trait CustomOp1 {
|
||||
_layout: &Layout,
|
||||
) -> Result<(MetalStorage, Shape)> {
|
||||
Err(crate::Error::Metal(
|
||||
format!("no cuda implementation for {}", self.name()).into(),
|
||||
format!("no metal implementation for {}", self.name()).into(),
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -315,6 +315,49 @@ impl crate::CustomOp1 for QTensor {
|
||||
)?;
|
||||
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
|
||||
}
|
||||
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &crate::MetalStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::MetalStorage, Shape)> {
|
||||
println!("TODO qmatmul");
|
||||
if !layout.is_contiguous() {
|
||||
crate::bail!("input tensor is not contiguous {layout:?}")
|
||||
}
|
||||
let src_shape = layout.shape();
|
||||
// self is transposed so n is first then k.
|
||||
let (n, k) = self.shape.dims2()?;
|
||||
if src_shape.rank() < 2 {
|
||||
crate::bail!("input tensor has only one dimension {layout:?}")
|
||||
}
|
||||
let mut dst_shape = src_shape.dims().to_vec();
|
||||
let last_k = dst_shape.pop().unwrap();
|
||||
if last_k != k {
|
||||
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
|
||||
}
|
||||
dst_shape.push(n);
|
||||
let dst_shape = Shape::from(dst_shape);
|
||||
// let storage = storage.as_slice::<f32>()?;
|
||||
// let storage =
|
||||
// &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
let dst_storage = vec![0f32; dst_shape.elem_count()];
|
||||
// self.matmul_t(
|
||||
// (dst_shape.elem_count() / n, k, n),
|
||||
// storage,
|
||||
// &mut dst_storage,
|
||||
// )?;
|
||||
let cpu_storage = crate::CpuStorage::F32(dst_storage);
|
||||
use crate::backend::{BackendDevice, BackendStorage};
|
||||
if let Device::Metal(device) = &self.device{
|
||||
Ok((
|
||||
device.storage_from_cpu_storage(&cpu_storage)?,
|
||||
dst_shape,
|
||||
))
|
||||
}else{
|
||||
crate::bail!("qtensor not on metal device")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QMatMul {
|
||||
|
@ -51,6 +51,7 @@ anyhow = { workspace = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
|
||||
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
|
||||
cudnn = ["candle/cudnn"]
|
||||
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
|
@ -9,7 +9,7 @@ use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use candle::quantized::{ggml_file, gguf_file};
|
||||
use candle::{Device, Tensor};
|
||||
use candle::{Tensor};
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
|
||||
use candle_transformers::models::quantized_llama as model;
|
||||
@ -367,9 +367,11 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
// TODO Remove this once implementation is finished.
|
||||
let logits = logits.ones_like()?;
|
||||
logits_processor.sample(&logits)?
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
@ -380,7 +382,7 @@ fn main() -> anyhow::Result<()> {
|
||||
|
||||
let start_post_prompt = std::time::Instant::now();
|
||||
for index in 0..to_sample {
|
||||
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
|
||||
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
@ -393,6 +395,8 @@ fn main() -> anyhow::Result<()> {
|
||||
&all_tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
// TODO Remove this once implementation is finished.
|
||||
let logits = logits.ones_like()?;
|
||||
next_token = logits_processor.sample(&logits)?;
|
||||
all_tokens.push(next_token);
|
||||
print_token(next_token, &tokenizer);
|
||||
|
@ -1,13 +1,12 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.0.1"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
repository = "https://github.com/huggingface/candle"
|
||||
keywords = ["blas", "tensor", "machine-learning"]
|
||||
categories = ["science"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
description.workspace = true
|
||||
repository.workspace = true
|
||||
keywords.workspace = true
|
||||
categories.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
metal = { workspace = true, optional = true}
|
||||
metal = { workspace = true }
|
||||
|
@ -28,4 +28,5 @@ clap = { workspace = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
metal = ["candle/metal"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
|
@ -191,6 +191,16 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
fn metal_fwd(
|
||||
&self,
|
||||
storage: &candle::MetalStorage,
|
||||
layout: &Layout,
|
||||
) -> Result<(candle::MetalStorage, Shape)> {
|
||||
println!("TODO softmax-last-dim");
|
||||
Ok((storage.clone(), layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||
|
@ -28,5 +28,6 @@ wav = { workspace = true }
|
||||
default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate"]
|
||||
cuda = ["candle/cuda", "candle-nn/cuda"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
flash-attn = ["cuda", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl"]
|
||||
|
@ -112,6 +112,7 @@ impl LayerWeights {
|
||||
let q = self.attention_wq.forward(x)?;
|
||||
let k = self.attention_wk.forward(x)?;
|
||||
let v = self.attention_wv.forward(x)?;
|
||||
// println!("Q {:?} K {:?} V {:?}", q.dtype(), k.dtype(), v.dtype());
|
||||
|
||||
let q = q
|
||||
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||
@ -145,9 +146,12 @@ impl LayerWeights {
|
||||
let v = self.repeat_kv(v)?;
|
||||
|
||||
let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
|
||||
// println!("att {:?}", att.dtype());
|
||||
let mask = mask.broadcast_as(att.shape())?;
|
||||
// println!("mask {:?}", mask.dtype());
|
||||
let att = masked_fill(&att, &mask, f32::NEG_INFINITY)?;
|
||||
let att = candle_nn::ops::softmax_last_dim(&att)?;
|
||||
// println!("att {:?} v {:?}", att.dtype(), v.dtype());
|
||||
// Convert to contiguous as matmul doesn't support strided vs for now.
|
||||
let y = att.matmul(&v.contiguous()?)?;
|
||||
let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
|
||||
|
Reference in New Issue
Block a user