Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)

Fixed matmul (display still broken without casting back to CPU first?)
@@ -60,7 +60,8 @@ tracing-subscriber = "0.3.7"
 wav = "1.0.0"
 yoke = { version = "0.7.2", features = ["derive"] }
 zip = { version = "0.6.6", default-features = false }
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../metal-rs", features = ["mps"] }
 
 [profile.release-with-debug]
 inherits = "release"
@@ -19,6 +19,13 @@ pub enum MetalError {
     Message(String),
     #[error(transparent)]
     KernelError(#[from] candle_metal_kernels::MetalKernelError),
+
+    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
+    MatMulNonContiguous {
+        lhs_stride: Vec<usize>,
+        rhs_stride: Vec<usize>,
+        mnk: (usize, usize, usize),
+    },
 }
 
 impl From<String> for MetalError {
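Note (not part of the diff): the new MatMulNonContiguous variant relies on thiserror's #[error(...)] attribute to build its Display message from the captured strides. A minimal standalone sketch, with made-up stride values, of what that message looks like:

use thiserror::Error;

// Sketch only: the real MetalError has more variants (Message, KernelError, ...).
#[derive(Debug, Error)]
enum MetalError {
    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Vec<usize>,
        rhs_stride: Vec<usize>,
        mnk: (usize, usize, usize),
    },
}

fn main() {
    // Hypothetical strides, chosen only to show the formatting.
    let err = MetalError::MatMulNonContiguous {
        lhs_stride: vec![12, 4, 1],
        rhs_stride: vec![1, 4],
        mnk: (3, 2, 4),
    };
    // The #[error(...)] format string becomes the Display implementation.
    println!("{err}");
}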
@@ -53,7 +60,7 @@ impl MetalDevice {
     // self.device.as_ref()
     // }
 
-    pub fn id(&self) -> u64 {
+    pub fn id(&self) -> NSUInteger {
         self.registry_id()
     }
 
@@ -70,7 +77,7 @@ impl MetalDevice {
     }
 
     pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
-        let size = (element_count * dtype.size_in_bytes()) as u64;
+        let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
         // debug!("Allocate 1 - buffer size {size}");
         self.device
             .new_buffer(size, MTLResourceOptions::StorageModeManaged)
@@ -520,6 +527,7 @@ impl BackendStorage for MetalStorage {
             (left, right) => todo!("index select metal {left:?} {right:?}"),
         };
         let command_buffer = self.device.command_queue.new_command_buffer();
+        // println!("INDEX SELECT");
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -561,20 +569,117 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let transpose_left = false;
-        let transpose_right = !rhs_l.is_contiguous();
-        let alpha = 1.0;
-        let beta = 0.0;
-        self.matmul_generic(
-            rhs,
-            (b, m, n, k),
-            lhs_l,
-            rhs_l,
+        // Create descriptors
+        use metal::mps::matrix::*;
+        let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
+        let size = core::mem::size_of::<f32>() as NSUInteger;
+
+        let elem_count = b * m * n;
+
+        let lhs_stride = lhs_l.stride();
+        let rhs_stride = rhs_l.stride();
+        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+        // The a tensor has dims batching, k, n (rhs)
+        let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
+            false
+        } else if lhs_m1 == m && lhs_m2 == 1 {
+            true
+        } else {
+            Err(MetalError::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?
+        };
+        let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
+            false
+        } else if rhs_m1 == k && rhs_m2 == 1 {
+            true
+        } else {
+            Err(MetalError::MatMulNonContiguous {
+                lhs_stride: lhs_stride.to_vec(),
+                rhs_stride: rhs_stride.to_vec(),
+                mnk: (m, n, k),
+            })?
+        };
+        // println!("{transpose_left} {transpose_right}");
+
+        let b = b as NSUInteger;
+        let m = m as NSUInteger;
+        let n = n as NSUInteger;
+        let k = k as NSUInteger;
+
+        let left_descriptor = if transpose_left {
+            MatrixDescriptor::init_single(k, m, m * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(m, k, k * size, type_id)
+        };
+        let right_descriptor = if transpose_right {
+            MatrixDescriptor::init_single(n, k, k * size, type_id)
+        } else {
+            MatrixDescriptor::init_single(k, n, n * size, type_id)
+        };
+        let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
+
+        // Create matrix objects
+        let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+        let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+
+        let out_buffer = self.device.new_buffer(elem_count, self.dtype);
+        let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
+            .ok_or_else(|| {
+                MetalError::from("Failed to create matrix multiplication kernel".to_string())
+            })?;
+
+        let alpha = 1.0f64;
+        let beta = 0.0f64;
+        // Create kernel
+        let matrix_multiplication = MatrixMultiplication::init(
+            &self.device,
             transpose_left,
             transpose_right,
+            m,
+            n,
+            k,
             alpha,
             beta,
         )
+        .ok_or_else(|| {
+            MetalError::from("Failed to create matrix multiplication kernel".to_string())
+        })?;
+
+        matrix_multiplication.set_batch_size(b);
+
+        // Encode kernel to command buffer
+        let command_buffer = self.device.command_queue.new_command_buffer();
+        matrix_multiplication.encode_to_command_buffer(
+            command_buffer,
+            &left_matrix,
+            &right_matrix,
+            &result_matrix,
+        );
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        // let left = self.buffer.read_to_vec::<f32>(10);
+        // let right = rhs.buffer.read_to_vec::<f32>(10);
+        // let out = out_buffer.read_to_vec::<f32>(40);
+        // todo!("Out {left:?} {right:?} {out:?}");
+
+        Ok(Self {
+            buffer: out_buffer,
+            device: self.device.clone(),
+            dtype: self.dtype(),
+        })
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
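Note (not part of the diff): the transpose flags above are derived only from the last two strides of each operand; anything that is neither row-major nor a plain transpose falls through to the new MatMulNonContiguous error. A standalone sketch of the same check (function name and API shape are illustrative, not from the commit):

/// For a logical (rows, cols) matrix, decide whether the underlying buffer is
/// laid out row-major (no transpose needed), is the transpose of a contiguous
/// buffer (ask MPS to transpose it), or neither (the real code raises
/// MetalError::MatMulNonContiguous in that case).
fn needs_transpose(stride: &[usize], rows: usize, cols: usize) -> Option<bool> {
    let m1 = stride[stride.len() - 1];
    let m2 = stride[stride.len() - 2];
    if m1 == 1 && m2 == cols {
        Some(false) // contiguous row-major (rows, cols)
    } else if m1 == rows && m2 == 1 {
        Some(true) // transposed layout, strides [.., 1, rows]
    } else {
        None // rejected as MatMulNonContiguous in the diff
    }
}

fn main() {
    // lhs of shape (m, k) = (2, 3): contiguous layout has strides [3, 1].
    assert_eq!(needs_transpose(&[3, 1], 2, 3), Some(false));
    // The same logical (2, 3) matrix obtained by transposing has strides [1, 2].
    assert_eq!(needs_transpose(&[1, 2], 2, 3), Some(true));
    // Anything else (e.g. a strided slice) is rejected.
    assert_eq!(needs_transpose(&[6, 2], 2, 3), None);
}

In the diff this corresponds to checking (lhs_m1, lhs_m2) against (m, k) for the left operand and (rhs_m1, rhs_m2) against (k, n) for the right one.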
@@ -583,18 +688,6 @@ impl BackendStorage for MetalStorage {
         if el_count == 0 {
             return Ok(());
         }
-        // todo!("Copy strided {:?}", src_l.is_contiguous());
-        // if src_l.is_contiguous() {
-        // let command_buffer = self.device.command_queue.new_command_buffer();
-        // let blip = command_buffer.new_blit_command_encoder();
-        // blip.copy_from_buffer(
-        // &self.buffer,
-        // src_l.start_offset() as u64,
-        // &dst.buffer,
-        // dst_offset as u64,
-        // self.buffer.length(),
-        // );
-        // } else {
         let command_buffer = self.device.command_queue.new_command_buffer();
         let kernel_name = match self.dtype {
             DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
@@ -631,84 +724,6 @@ impl MetalStorage {
             dtype,
         }
    }
-    pub(crate) fn matmul_generic(
-        &self,
-        rhs: &Self,
-        (b, m, n, k): (usize, usize, usize, usize),
-        lhs_l: &Layout,
-        rhs_l: &Layout,
-        transpose_left: bool,
-        transpose_right: bool,
-        alpha: f64,
-        beta: f64,
-    ) -> Result<Self> {
-        let elem_count = b * m * n;
-        match (self.dtype, rhs.dtype) {
-            (DType::F32, DType::F32) => {
-                let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
-                // if b != 1 {
-                // // debug!("TODO implement batched matmul for B={b}");
-                // crate::bail!("Didn't implemented strided matmul yet");
-                // return Ok(Self {
-                // buffer: out_buffer,
-                // device: self.device.clone(),
-                // dtype: self.dtype(),
-                // });
-                //}
-                // if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
-                // // debug!(
-                // // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
-                // // lhs_l.is_contiguous(),
-                // // rhs_l.is_contiguous(),
-                // // rhs_l
-                // // );
-                // crate::bail!("No not contiguous matmul");
-                // return Ok(Self {
-                // buffer: out_buffer,
-                // device: self.device.clone(),
-                // dtype: self.dtype(),
-                // });
-                // }
-
-                // debug!("TODO GEMM");
-                let command_buffer = self.device.command_queue.new_command_buffer();
-                encode_gemm::<Float32, Float32, Float32>(
-                    &self.device,
-                    &command_buffer,
-                    transpose_left,
-                    transpose_right,
-                    &self.buffer,
-                    &rhs.buffer,
-                    &mut out_buffer,
-                    m as NSUInteger,
-                    n as NSUInteger,
-                    k as NSUInteger,
-                    alpha as f32,
-                    beta as f32,
-                    Some(b as NSUInteger),
-                )
-                .map_err(MetalError::from)?;
-
-                command_buffer.commit();
-                command_buffer.wait_until_completed();
-                // command_buffer.wait_until_scheduled();
-                //
-                let left = self.buffer.read_to_vec::<f32>(10);
-                let right = rhs.buffer.read_to_vec::<f32>(10);
-                let out = out_buffer.read_to_vec::<f32>(10);
-
-                println!("{b} {m} {n} {k} ");
-                println!("{left:?} {right:?} {out:?}");
-
-                Ok(Self {
-                    buffer: out_buffer,
-                    device: self.device.clone(),
-                    dtype: self.dtype(),
-                })
-            }
-            _ => todo!("Unimplemented matmul for this pair"),
-        }
-    }
 
     pub fn buffer(&self) -> &Buffer {
         &self.buffer
@@ -774,37 +789,37 @@ impl BackendDevice for MetalDevice {
         let buffer = match storage {
             CpuStorage::U8(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u8>()) as u64,
+                (storage.len() * mem::size_of::<u8>()) as NSUInteger,
                 option,
             ),
             CpuStorage::U32(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<u32>()) as u64,
+                (storage.len() * mem::size_of::<u32>()) as NSUInteger,
                 option,
             ),
             CpuStorage::I64(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<i64>()) as u64,
+                (storage.len() * mem::size_of::<i64>()) as NSUInteger,
                 option,
             ),
             CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<bf16>()) as u64,
+                (storage.len() * mem::size_of::<bf16>()) as NSUInteger,
                 option,
            ),
             CpuStorage::F16(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f16>()) as u64,
+                (storage.len() * mem::size_of::<f16>()) as NSUInteger,
                 option,
             ),
             CpuStorage::F32(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f32>()) as u64,
+                (storage.len() * mem::size_of::<f32>()) as NSUInteger,
                 option,
             ),
             CpuStorage::F64(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,
-                (storage.len() * mem::size_of::<f64>()) as u64,
+                (storage.len() * mem::size_of::<f64>()) as NSUInteger,
                 option,
             ),
         };
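Note (not part of the diff): every arm in this hunk computes the buffer length the same way, element count times element size, cast to the NSUInteger type that metal-rs expects for buffer lengths. A minimal sketch of that computation (the helper name is hypothetical, and it assumes NSUInteger is the alias re-exported by the metal crate):

use core::mem;
use metal::NSUInteger;

// Byte length of a slice, widened to the unsigned integer type used by
// Device::new_buffer_with_data for its length argument.
fn byte_len<T>(storage: &[T]) -> NSUInteger {
    (storage.len() * mem::size_of::<T>()) as NSUInteger
}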
@@ -10,7 +10,8 @@ categories = ["science"]
 license = "MIT OR Apache-2.0"
 
 [dependencies]
-metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
+metal = { path = "../../metal-rs", features = ["mps"] }
 once_cell = "1.18.0"
 thiserror = "1"
 tracing = "0.1.37"
@@ -156,6 +156,7 @@ impl CausalSelfAttention {
         let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
         let x0 = x.narrow(D::Minus1, 0, 1)?;
         let x1 = x.narrow(D::Minus1, 1, 1)?;
+        todo!("X {x1}");
         let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
         let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
         let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@@ -165,7 +166,6 @@ impl CausalSelfAttention {
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
-        todo!("X {q}");
         let k = self.k_proj.forward(x)?;
         let v = self.v_proj.forward(x)?;
 
@@ -174,6 +174,7 @@ impl CausalSelfAttention {
         let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
 
         let q = self.apply_rotary_emb(&q, index_pos)?;
+        todo!("X {q}");
         let mut k = self.apply_rotary_emb(&k, index_pos)?;
 
         if self.cache.use_kv_cache {