mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Metal operational.
This commit is contained in:
@ -58,6 +58,10 @@ impl MetalDevice {
|
||||
self.registry_id()
|
||||
}
|
||||
|
||||
/// Returns a shared reference to the `metal::Device` held in this
/// wrapper's `device` field.
pub fn metal_device(&self) -> &metal::Device {
|
||||
&self.device
|
||||
}
|
||||
|
||||
/// Returns a shared reference to the `CommandQueue` held in this
/// wrapper's `command_queue` field.
pub fn command_queue(&self) -> &CommandQueue {
|
||||
&self.command_queue
|
||||
}
|
||||
@ -215,8 +219,9 @@ impl BackendStorage for MetalStorage {
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
assert!(sum_dims.len() == 1);
|
||||
assert!(sum_dims[0] == layout.shape().rank() - 1);
|
||||
assert!(layout.is_contiguous());
|
||||
assert!(layout.start_offset() == 0);
|
||||
assert!(layout.stride()[sum_dims[0]] == 1);
|
||||
// assert!(layout.is_contiguous());
|
||||
// assert!(layout.start_offset() == 0);
|
||||
let device = self.device.clone();
|
||||
let src_stride = layout.stride();
|
||||
let src_dims = layout.shape().dims();
|
||||
@ -251,6 +256,9 @@ impl BackendStorage for MetalStorage {
|
||||
Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
|
||||
}
|
||||
let dtype = if return_index { DType::U32 } else { self.dtype };
|
||||
if dtype == DType::U32{
|
||||
todo!("Implement this");
|
||||
}
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let command_buffer = self.device.command_buffer();
|
||||
candle_metal_kernels::call_reduce_contiguous(
|
||||
@ -261,6 +269,7 @@ impl BackendStorage for MetalStorage {
|
||||
src_el,
|
||||
dst_el,
|
||||
&self.buffer,
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
&mut buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
@ -727,26 +736,26 @@ impl BackendStorage for MetalStorage {
|
||||
mnk: (m, n, k),
|
||||
})?
|
||||
};
|
||||
let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => m * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
[s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
[stride] => stride,
|
||||
[] => n * k,
|
||||
_ => Err(MetalError::MatMulNonContiguous {
|
||||
lhs_stride: lhs_stride.to_vec(),
|
||||
rhs_stride: rhs_stride.to_vec(),
|
||||
mnk: (m, n, k),
|
||||
})?,
|
||||
} as u64;
|
||||
// let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
|
||||
// [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
|
||||
// [stride] => stride,
|
||||
// [] => m * k,
|
||||
// _ => Err(MetalError::MatMulNonContiguous {
|
||||
// lhs_stride: lhs_stride.to_vec(),
|
||||
// rhs_stride: rhs_stride.to_vec(),
|
||||
// mnk: (m, n, k),
|
||||
// })?,
|
||||
// } as u64;
|
||||
// let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
|
||||
// [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
|
||||
// [stride] => stride,
|
||||
// [] => n * k,
|
||||
// _ => Err(MetalError::MatMulNonContiguous {
|
||||
// lhs_stride: lhs_stride.to_vec(),
|
||||
// rhs_stride: rhs_stride.to_vec(),
|
||||
// mnk: (m, n, k),
|
||||
// })?,
|
||||
// } as u64;
|
||||
|
||||
let b = b as NSUInteger;
|
||||
let m = m as NSUInteger;
|
||||
|
@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
|
||||
nccl = ["cuda", "cudarc/nccl", "dep:half"]
|
||||
onnx = ["candle-onnx"]
|
||||
metal = ["candle/metal", "candle-nn/metal"]
|
||||
|
||||
[[example]]
|
||||
name = "llama_multiprocess"
|
||||
|
@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -484,6 +481,7 @@ pub fn call_reduce_contiguous(
|
||||
length: usize,
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
@ -492,7 +490,7 @@ pub fn call_reduce_contiguous(
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, elements_to_sum, input, output));
|
||||
set_params!(encoder, (length, elements_to_sum, (input,input_offset), output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
@ -1228,6 +1226,7 @@ mod tests {
|
||||
v.len(),
|
||||
out_length,
|
||||
&input,
|
||||
0,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
|
@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
|
||||
return strided_i;
|
||||
}
|
||||
|
||||
constant int THREADGROUP_SIZE = 256;
|
||||
constant int THREADGROUP_SIZE = 1024;
|
||||
|
||||
# define REDUCE(FN, NAME, TYPENAME) \
|
||||
kernel void NAME( \
|
||||
|
@ -19,6 +19,7 @@ num-traits = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
@ -29,3 +30,4 @@ default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
metal = ["candle/metal", "dep:candle-metal-kernels"]
|
||||
|
@ -201,6 +201,37 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
};
|
||||
Ok((dst, layout.shape().clone()))
|
||||
}
|
||||
|
||||
/// Metal implementation of the last-dimension softmax custom op.
///
/// Allocates an output buffer with the same element count and dtype as
/// `storage`, calls `candle_metal_kernels::call_last_softmax` with the
/// device's command buffer and kernel set, and wraps the output buffer in
/// a new `MetalStorage`. The returned shape is a clone of the input shape.
///
/// # Errors
/// Returns an error when `call_last_softmax` fails; the kernel error is
/// wrapped into a `candle::Error` rather than aborting the process.
///
/// # Panics
/// Panics if `layout` is not contiguous or its start offset is not 0.
#[cfg(feature = "metal")]
fn metal_fwd(
    &self,
    storage: &candle::MetalStorage,
    layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
    use candle::backend::BackendStorage;

    let device = storage.device();
    let command_buffer = device.command_buffer();
    let kernels = device.kernels();
    // NOTE(review): the kernel name is fixed while the output buffer uses
    // `storage.dtype()` — presumably only f32 inputs are expected here;
    // confirm before calling this with other dtypes.
    let name = "softmax_float";

    // Strided or offset inputs are rejected up front; the input buffer is
    // passed to the kernel without any offset argument.
    assert!(layout.is_contiguous());
    assert!(layout.start_offset() == 0);

    let last_dim = layout.dims()[layout.shape().rank() - 1];
    let elem_count = layout.shape().elem_count();
    let mut output = device.new_buffer(elem_count, storage.dtype());

    // Propagate kernel failures to the caller instead of unwrapping: this
    // function already returns `Result`, so a failed call should surface
    // as an error value.
    candle_metal_kernels::call_last_softmax(
        device.metal_device(),
        &command_buffer,
        &kernels,
        name,
        elem_count,
        last_dim,
        storage.buffer(),
        &mut output,
    )
    .map_err(candle::Error::wrap)?;

    let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
    Ok((newstorage, layout.shape().clone()))
}
|
||||
}
|
||||
|
||||
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
|
||||
|
Reference in New Issue
Block a user