Fixed matmul (display still broken without casting back to CPU first? )

This commit is contained in:
Nicolas Patry
2023-11-10 20:09:25 +01:00
committed by Nicolas Patry
parent d46670f7c0
commit 38de52bc4b
4 changed files with 127 additions and 111 deletions

View File

@ -61,7 +61,8 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../metal-rs", features = ["mps"] }
[profile.release-with-debug]
inherits = "release"

View File

@ -19,6 +19,13 @@ pub enum MetalError {
Message(String),
#[error(transparent)]
KernelError(#[from] candle_metal_kernels::MetalKernelError),
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
mnk: (usize, usize, usize),
},
}
impl From<String> for MetalError {
@ -53,7 +60,7 @@ impl MetalDevice {
// self.device.as_ref()
// }
pub fn id(&self) -> u64 {
pub fn id(&self) -> NSUInteger {
self.registry_id()
}
@ -70,7 +77,7 @@ impl MetalDevice {
}
pub fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64;
let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
// debug!("Allocate 1 - buffer size {size}");
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
@ -561,20 +568,116 @@ impl BackendStorage for MetalStorage {
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let transpose_left = false;
let transpose_right = !rhs_l.is_contiguous();
let alpha = 1.0;
let beta = 0.0;
self.matmul_generic(
rhs,
(b, m, n, k),
lhs_l,
rhs_l,
// Create descriptors
use metal::mps::matrix::*;
let type_id = metal::mps::MPS_FLOATBIT_ENCODING | 32;
let size = core::mem::size_of::<f32>() as NSUInteger;
let elem_count = b * m * n;
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, k, n (rhs)
let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
true
} else {
Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
true
} else {
Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
let left_descriptor = if transpose_left {
MatrixDescriptor::init_single(k, m, m * size, type_id)
} else {
MatrixDescriptor::init_single(m, k, k * size, type_id)
};
let right_descriptor = if transpose_right {
MatrixDescriptor::init_single(n, k, k * size, type_id)
} else {
MatrixDescriptor::init_single(k, n, n * size, type_id)
};
let result_descriptor = MatrixDescriptor::init_single(m, n, n * size, type_id);
// Create matrix objects
let left_matrix = Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let right_matrix = Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let out_buffer = self.device.new_buffer(elem_count, self.dtype);
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
matrix_multiplication.set_batch_size(b);
// Encode kernel to command buffer
let command_buffer = self.device.command_queue.new_command_buffer();
matrix_multiplication.encode_to_command_buffer(
command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
command_buffer.commit();
command_buffer.wait_until_completed();
// let left = self.buffer.read_to_vec::<f32>(10);
// let right = rhs.buffer.read_to_vec::<f32>(10);
// let out = out_buffer.read_to_vec::<f32>(40);
// todo!("Out {left:?} {right:?} {out:?}");
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
@ -583,18 +686,6 @@ impl BackendStorage for MetalStorage {
if el_count == 0 {
return Ok(());
}
// todo!("Copy strided {:?}", src_l.is_contiguous());
// if src_l.is_contiguous() {
// let command_buffer = self.device.command_queue.new_command_buffer();
// let blip = command_buffer.new_blit_command_encoder();
// blip.copy_from_buffer(
// &self.buffer,
// src_l.start_offset() as u64,
// &dst.buffer,
// dst_offset as u64,
// self.buffer.length(),
// );
// } else {
let command_buffer = self.device.command_queue.new_command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
@ -631,84 +722,6 @@ impl MetalStorage {
dtype,
}
}
pub(crate) fn matmul_generic(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
transpose_left: bool,
transpose_right: bool,
alpha: f64,
beta: f64,
) -> Result<Self> {
let elem_count = b * m * n;
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
// if b != 1 {
// // debug!("TODO implement batched matmul for B={b}");
// crate::bail!("Didn't implemented strided matmul yet");
// return Ok(Self {
// buffer: out_buffer,
// device: self.device.clone(),
// dtype: self.dtype(),
// });
//}
// if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
// // debug!(
// // "TODO non contiguous matmul yet {:?} {:?} - {:?} - {transpose_right}",
// // lhs_l.is_contiguous(),
// // rhs_l.is_contiguous(),
// // rhs_l
// // );
// crate::bail!("No not contiguous matmul");
// return Ok(Self {
// buffer: out_buffer,
// device: self.device.clone(),
// dtype: self.dtype(),
// });
// }
// debug!("TODO GEMM");
let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>(
&self.device,
&command_buffer,
transpose_left,
transpose_right,
&self.buffer,
&rhs.buffer,
&mut out_buffer,
m as NSUInteger,
n as NSUInteger,
k as NSUInteger,
alpha as f32,
beta as f32,
Some(b as NSUInteger),
)
.map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_completed();
// command_buffer.wait_until_scheduled();
//
let left = self.buffer.read_to_vec::<f32>(10);
let right = rhs.buffer.read_to_vec::<f32>(10);
let out = out_buffer.read_to_vec::<f32>(10);
println!("{b} {m} {n} {k} ");
println!("{left:?} {right:?} {out:?}");
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})
}
_ => todo!("Unimplemented matmul for this pair"),
}
}
pub fn buffer(&self) -> &Buffer {
&self.buffer
@ -774,37 +787,37 @@ impl BackendDevice for MetalDevice {
let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
(storage.len() * mem::size_of::<u8>()) as NSUInteger,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
(storage.len() * mem::size_of::<u32>()) as NSUInteger,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
(storage.len() * mem::size_of::<i64>()) as NSUInteger,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
(storage.len() * mem::size_of::<bf16>()) as NSUInteger,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
(storage.len() * mem::size_of::<f16>()) as NSUInteger,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
(storage.len() * mem::size_of::<f32>()) as NSUInteger,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
(storage.len() * mem::size_of::<f64>()) as NSUInteger,
option,
),
};

View File

@ -10,7 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"
[dependencies]
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
# metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
metal = { path = "../../metal-rs", features = ["mps"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

View File

@ -156,6 +156,7 @@ impl CausalSelfAttention {
let x = x.reshape((b_sz, seq_len, h, n_embd / 2, 2))?;
let x0 = x.narrow(D::Minus1, 0, 1)?;
let x1 = x.narrow(D::Minus1, 1, 1)?;
todo!("X {x1}");
let dst0 = (x0.broadcast_mul(&cos)? - x1.broadcast_mul(&sin)?)?;
let dst1 = (x0.broadcast_mul(&sin)? + x1.broadcast_mul(&cos)?)?;
let rope = Tensor::cat(&[&dst0, &dst1], D::Minus1)?.reshape((b_sz, seq_len, h, n_embd))?;
@ -165,7 +166,6 @@ impl CausalSelfAttention {
fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
let (b_sz, seq_len, n_embd) = x.dims3()?;
let q = self.q_proj.forward(x)?;
todo!("X {q}");
let k = self.k_proj.forward(x)?;
let v = self.v_proj.forward(x)?;
@ -174,6 +174,7 @@ impl CausalSelfAttention {
let mut v = v.reshape((b_sz, seq_len, self.n_key_value_head, self.head_dim))?;
let q = self.apply_rotary_emb(&q, index_pos)?;
todo!("X {q}");
let mut k = self.apply_rotary_emb(&k, index_pos)?;
if self.cache.use_kv_cache {