Metal operational.

Nicolas Patry
2023-11-18 00:52:38 +01:00
parent a0010898cc
commit 251c65f9f1
6 changed files with 69 additions and 27 deletions

View File

@@ -58,6 +58,10 @@ impl MetalDevice {
         self.registry_id()
     }

+    pub fn metal_device(&self) -> &metal::Device {
+        &self.device
+    }
+
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
     }
@@ -215,8 +219,9 @@ impl BackendStorage for MetalStorage {
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         assert!(sum_dims.len() == 1);
         assert!(sum_dims[0] == layout.shape().rank() - 1);
-        assert!(layout.is_contiguous());
-        assert!(layout.start_offset() == 0);
+        assert!(layout.stride()[sum_dims[0]] == 1);
+        // assert!(layout.is_contiguous());
+        // assert!(layout.start_offset() == 0);
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
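The relaxed assertions admit any layout whose reduced (last) dimension is stride-1, even when the tensor no longer starts at offset zero. A toy sketch of a layout that now passes, with illustrative numbers only (not from the diff):

    // A row-major [4, 8] tensor narrowed to start at row 1 keeps strides
    // [8, 1] but gains a start offset of 8 elements. The last-dim stride is
    // still 1, so the relaxed reduce_op asserts hold; the non-zero offset is
    // forwarded to the kernel call below instead of being rejected.
    fn main() {
        let (stride, start_offset) = (vec![8usize, 1], 8usize);
        assert_eq!(stride[stride.len() - 1], 1); // reduced dim is contiguous
        assert_ne!(start_offset, 0); // no longer a hard error
    }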
@@ -251,6 +256,9 @@ impl BackendStorage for MetalStorage {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
+        if dtype == DType::U32 {
+            todo!("Implement this");
+        }
         let mut buffer = device.new_buffer(dst_el, dtype);
         let command_buffer = self.device.command_buffer();
         candle_metal_kernels::call_reduce_contiguous(
@@ -261,6 +269,7 @@ impl BackendStorage for MetalStorage {
             src_el,
             dst_el,
             &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
             &mut buffer,
         )
         .map_err(MetalError::from)?;
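The new argument converts the layout's start offset, counted in elements, into the byte offset that Metal buffer bindings expect. A minimal sketch of that conversion (helper name illustrative, not part of the diff):

    // An offset in elements scales by the dtype's byte width to become a
    // byte offset into the underlying Metal buffer.
    fn byte_offset(start_offset_elems: usize, dtype_size_in_bytes: usize) -> usize {
        start_offset_elems * dtype_size_in_bytes
    }
    // e.g. an f32 tensor starting at element 10 binds at byte 10 * 4 = 40.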
@@ -727,26 +736,26 @@ impl BackendStorage for MetalStorage {
                 mnk: (m, n, k),
             })?
         };
-        let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?,
-        } as u64;
-        let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?,
-        } as u64;
+        // let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
+        //     [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+        //     [stride] => stride,
+        //     [] => m * k,
+        //     _ => Err(MetalError::MatMulNonContiguous {
+        //         lhs_stride: lhs_stride.to_vec(),
+        //         rhs_stride: rhs_stride.to_vec(),
+        //         mnk: (m, n, k),
+        //     })?,
+        // } as u64;
+        // let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
+        //     [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+        //     [stride] => stride,
+        //     [] => n * k,
+        //     _ => Err(MetalError::MatMulNonContiguous {
+        //         lhs_stride: lhs_stride.to_vec(),
+        //         rhs_stride: rhs_stride.to_vec(),
+        //         mnk: (m, n, k),
+        //     })?,
+        // } as u64;
         let b = b as NSUInteger;
         let m = m as NSUInteger;
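The disabled block inferred each operand's per-batch stride, so with it commented out the matmul path effectively assumes plainly batched, contiguous operands. For reference, a standalone sketch of the same inference (function and parameter names are illustrative, not from the commit):

    // Mirrors the commented-out match: derive a per-batch stride from the
    // strides preceding the trailing two (matrix) dimensions. `dim1` is the
    // operand's second dimension; `inner` is the m*k (or n*k) fallback for a
    // rank-2 operand; None stands in for the MatMulNonContiguous error.
    fn batch_stride(leading: &[usize], dim1: usize, inner: usize) -> Option<usize> {
        match leading {
            [s1, stride] if *s1 == stride * dim1 => Some(*stride),
            [stride] => Some(*stride),
            [] => Some(inner),
            _ => None,
        }
    }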

View File

@@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]

 [[example]]
 name = "llama_multiprocess"
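With the feature declared here, a Metal-enabled build of the examples should be reachable through Cargo's feature flags, along the lines of (example name illustrative):

    cargo run --example llama --features metal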

View File

@@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
     input: &Buffer,
     output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    // println!("Kernel {:?}", kernel_name.0);
-    // assert_eq!(input.length(), output.length());
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -484,6 +481,7 @@ pub fn call_reduce_contiguous(
     length: usize,
     out_length: usize,
     input: &Buffer,
+    input_offset: usize,
     output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -492,7 +490,7 @@ pub fn call_reduce_contiguous(
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
-    set_params!(encoder, (length, elements_to_sum, input, output));
+    set_params!(encoder, (length, elements_to_sum, (input, input_offset), output));

     let thread_group_count = MTLSize {
         width: out_length as u64,
@@ -1228,6 +1226,7 @@ mod tests {
             v.len(),
             out_length,
             &input,
+            0,
             &mut output,
         )
         .unwrap();
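The `(input, input_offset)` tuple pairs a buffer with a byte offset when `set_params!` encodes the kernel arguments. A hedged sketch of what such a pair plausibly lowers to with the `metal` crate (the buffer index 2 is illustrative, matching the parameter's position after `length` and `elements_to_sum`):

    use metal::{Buffer, ComputeCommandEncoderRef, NSUInteger};

    // Bind `buf` starting `offset_bytes` into its allocation, as the
    // (buffer, offset) form suggests.
    fn bind_with_offset(encoder: &ComputeCommandEncoderRef, buf: &Buffer, offset_bytes: usize) {
        encoder.set_buffer(2, Some(buf), offset_bytes as NSUInteger);
    }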

View File

@@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }

-constant int THREADGROUP_SIZE = 256;
+constant int THREADGROUP_SIZE = 1024;

 #define REDUCE(FN, NAME, TYPENAME) \
 kernel void NAME( \

View File

@@ -19,6 +19,7 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }

 [dev-dependencies]
 anyhow = { workspace = true }
@@ -29,3 +30,4 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels"]

View File

@@ -201,6 +201,37 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = storage.device();
+        let command_buffer = device.command_buffer();
+        let kernels = device.kernels();
+        let name = "softmax_float";
+        assert!(layout.is_contiguous());
+        assert!(layout.start_offset() == 0);
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let mut output = device.new_buffer(elem_count, storage.dtype());
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            &kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            &mut output,
+        )
+        .unwrap();
+        let new_storage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        Ok((new_storage, layout.shape().clone()))
+    }
 }

 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
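Taken together, these changes give the custom softmax a Metal dispatch path. A hedged end-to-end sketch, assuming a build with the new `metal` features enabled and that a Metal device can be constructed with `Device::new_metal` (treat that constructor as an assumption for this revision):

    use candle::{Device, Tensor};

    fn main() -> candle::Result<()> {
        // Assumed constructor; requires building with the `metal` feature.
        let device = Device::new_metal(0)?;
        let xs = Tensor::randn(0f32, 1f32, (4, 32), &device)?;
        // Exercises the metal_fwd implementation added above.
        let ys = candle_nn::ops::softmax_last_dim(&xs)?;
        println!("{ys}");
        Ok(())
    }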