Commit: Metal operational.
Repository: https://github.com/huggingface/candle.git
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
@@ -58,6 +58,10 @@ impl MetalDevice {
         self.registry_id()
     }
 
+    pub fn metal_device(&self) -> &metal::Device {
+        &self.device
+    }
+
     pub fn command_queue(&self) -> &CommandQueue {
         &self.command_queue
    }
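Note: these accessors let code outside candle-core (notably the kernel launchers in candle-metal-kernels, used later in this diff) reach the raw metal-rs handles. A minimal caller sketch; `launch_noop` is a hypothetical illustration, not part of the commit:

```rust
use metal::MTLResourceOptions;

// Hypothetical caller: grab the raw device and the shared command
// queue through the new accessors, then encode and submit work.
fn launch_noop(device: &MetalDevice) {
    let raw: &metal::Device = device.metal_device();
    let queue = device.command_queue();
    let command_buffer = queue.new_command_buffer();
    // Allocate a scratch buffer straight from the raw device.
    let _scratch = raw.new_buffer(1024, MTLResourceOptions::StorageModeShared);
    // ... encode compute kernels here ...
    command_buffer.commit();
    command_buffer.wait_until_completed();
}
```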
@@ -215,8 +219,9 @@ impl BackendStorage for MetalStorage {
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         assert!(sum_dims.len() == 1);
         assert!(sum_dims[0] == layout.shape().rank() - 1);
-        assert!(layout.is_contiguous());
-        assert!(layout.start_offset() == 0);
+        assert!(layout.stride()[sum_dims[0]] == 1);
+        // assert!(layout.is_contiguous());
+        // assert!(layout.start_offset() == 0);
         let device = self.device.clone();
         let src_stride = layout.stride();
         let src_dims = layout.shape().dims();
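Note: the relaxed checks accept any layout whose reduced (last) dimension is contiguous, even when the tensor as a whole is a strided view with a nonzero start offset. An illustrative case (hypothetical snippet, not from the commit):

```rust
use candle::{DType, Device, Result, Tensor};

fn strided_view_example() -> Result<Tensor> {
    // A row slice of a contiguous matrix: the last dimension keeps
    // stride 1, so the new assert passes, but the view starts 16
    // elements into the buffer, which the old start_offset() == 0
    // assert rejected.
    let t = Tensor::zeros((4, 8), DType::F32, &Device::Cpu)?;
    t.narrow(0, 2, 1)
}
```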
@@ -251,6 +256,9 @@ impl BackendStorage for MetalStorage {
             Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
         }
         let dtype = if return_index { DType::U32 } else { self.dtype };
+        if dtype == DType::U32 {
+            todo!("Implement this");
+        }
         let mut buffer = device.new_buffer(dst_el, dtype);
         let command_buffer = self.device.command_buffer();
         candle_metal_kernels::call_reduce_contiguous(
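Note: `return_index` is true for the index-producing reductions, which is why the U32 output path is guarded with a todo!. A sketch of that predicate over candle's ReduceOp variants:

```rust
use candle::op::ReduceOp;

// ArgMin/ArgMax produce U32 index tensors; the Metal kernels for that
// path don't exist yet, hence the todo! above.
fn returns_index(op: &ReduceOp) -> bool {
    matches!(op, ReduceOp::ArgMin | ReduceOp::ArgMax)
}
```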
@@ -261,6 +269,7 @@ impl BackendStorage for MetalStorage {
             src_el,
             dst_el,
             &self.buffer,
+            layout.start_offset() * self.dtype.size_in_bytes(),
             &mut buffer,
         )
         .map_err(MetalError::from)?;
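Note: the added argument converts the layout's element offset into the byte offset that Metal buffer bindings expect. As a standalone helper (hypothetical name, same arithmetic):

```rust
use candle::{DType, Layout};

// start_offset() counts elements from the head of the backing buffer;
// Metal's set_buffer binds at a byte offset, hence the dtype scaling.
fn buffer_offset_in_bytes(layout: &Layout, dtype: DType) -> usize {
    layout.start_offset() * dtype.size_in_bytes()
}
```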
@@ -727,26 +736,26 @@ impl BackendStorage for MetalStorage {
                 mnk: (m, n, k),
             })?
         };
-        let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
-            [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => m * k,
-            _ => Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?,
-        } as u64;
-        let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
-            [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
-            [stride] => stride,
-            [] => n * k,
-            _ => Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?,
-        } as u64;
+        // let stride_left: u64 = match lhs_stride[..lhs_stride.len() - 2] {
+        //     [s1, stride] if s1 == stride * lhs_l.dims()[1] => stride,
+        //     [stride] => stride,
+        //     [] => m * k,
+        //     _ => Err(MetalError::MatMulNonContiguous {
+        //         lhs_stride: lhs_stride.to_vec(),
+        //         rhs_stride: rhs_stride.to_vec(),
+        //         mnk: (m, n, k),
+        //     })?,
+        // } as u64;
+        // let stride_right: u64 = match rhs_stride[..rhs_stride.len() - 2] {
+        //     [s1, stride] if s1 == stride * rhs_l.dims()[1] => stride,
+        //     [stride] => stride,
+        //     [] => n * k,
+        //     _ => Err(MetalError::MatMulNonContiguous {
+        //         lhs_stride: lhs_stride.to_vec(),
+        //         rhs_stride: rhs_stride.to_vec(),
+        //         mnk: (m, n, k),
+        //     })?,
+        // } as u64;
 
         let b = b as NSUInteger;
         let m = m as NSUInteger;
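Note: for reference, the disabled match derived the element stride between consecutive matrices of a possibly-batched operand, with `[] => m * k` covering the plain 2-D case. A standalone sketch of the same logic (hypothetical helper; `None` stands in for the MatMulNonContiguous error):

```rust
// Stride, in elements, between consecutive matrices of a batched
// operand whose layout is described by `stride` and `dims`.
fn batch_stride(stride: &[usize], dims: &[usize], m: usize, k: usize) -> Option<usize> {
    match stride[..stride.len() - 2] {
        // Batch dims collapsed into one contiguous run.
        [s1, s] if s1 == s * dims[1] => Some(s),
        // A single batch dimension.
        [s] => Some(s),
        // No batch dimension: a single m*k matrix.
        [] => Some(m * k),
        // Anything else is a non-contiguous batch layout.
        _ => None,
    }
}
```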
diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
@@ -57,6 +57,7 @@ flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
 mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
 nccl = ["cuda", "cudarc/nccl", "dep:half"]
 onnx = ["candle-onnx"]
+metal = ["candle/metal", "candle-nn/metal"]
 
 [[example]]
 name = "llama_multiprocess"
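Note: with the feature plumbed through candle and candle-nn, examples can opt into the Metal backend at build time, e.g. (example name illustrative):

```
cargo run --example llama --features metal
```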
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
@@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
     input: &Buffer,
     output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
-    // println!("Kernel {:?}", kernel_name.0);
-    // assert_eq!(input.length(), output.length());
-
     let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
@@ -484,6 +481,7 @@ pub fn call_reduce_contiguous(
     length: usize,
     out_length: usize,
     input: &Buffer,
+    input_offset: usize,
     output: &mut Buffer,
 ) -> Result<(), MetalKernelError> {
     let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
@@ -492,7 +490,7 @@ pub fn call_reduce_contiguous(
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
 
-    set_params!(encoder, (length, elements_to_sum, input, output));
+    set_params!(encoder, (length, elements_to_sum, (input, input_offset), output));
 
     let thread_group_count = MTLSize {
         width: out_length as u64,
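Note: the `(input, input_offset)` pair relies on set_params! handling a buffer-plus-byte-offset tuple. A sketch of how that case can be implemented against the metal-rs encoder API; the trait stand-in below is an assumption, the real one lives in this crate:

```rust
use metal::{Buffer, ComputeCommandEncoderRef};

// Minimal stand-in for the trait that set_params! expands to.
pub trait EncoderParam {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self);
}

// Assumption: a (&Buffer, usize) pair binds the buffer at a byte
// offset instead of offset 0, which is what lets the
// `(input, input_offset)` argument above pick up strided inputs.
impl<'a> EncoderParam for (&'a Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
        encoder.set_buffer(position, Some(data.0), data.1 as u64);
    }
}
```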
@@ -1228,6 +1226,7 @@ mod tests {
             v.len(),
             out_length,
             &input,
+            0,
             &mut output,
         )
         .unwrap();
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
@@ -16,7 +16,7 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }
 
-constant int THREADGROUP_SIZE = 256;
+constant int THREADGROUP_SIZE = 1024;
 
 #define REDUCE(FN, NAME, TYPENAME) \
 kernel void NAME( \
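Note: 1024 is the maximum threads per threadgroup on current Apple GPUs, so a single threadgroup can now reduce rows of up to 1024 elements in one pass. Host-side, a robust pattern is to clamp against the pipeline's reported limit rather than assume the constant; a sketch using metal-rs:

```rust
use metal::{ComputePipelineState, MTLSize};

// Derive a 1-D threadgroup width that respects the device limit
// instead of assuming the shader's THREADGROUP_SIZE is attainable.
fn threadgroup_size(pipeline: &ComputePipelineState) -> MTLSize {
    let width = pipeline.max_total_threads_per_threadgroup().min(1024);
    MTLSize { width, height: 1, depth: 1 }
}
```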
diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml
@@ -19,6 +19,7 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -29,3 +30,4 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels"]
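Note: the `dep:` prefix ties the optional candle-metal-kernels dependency to the feature, keeping it out of non-Metal builds. Enabling it looks like:

```
cargo build -p candle-nn --features metal
```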
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
@@ -201,6 +201,37 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::backend::BackendStorage;
+        let device = storage.device();
+        let command_buffer = device.command_buffer();
+        let kernels = device.kernels();
+        let name = "softmax_float";
+        assert!(layout.is_contiguous());
+        assert!(layout.start_offset() == 0);
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let mut output = device.new_buffer(elem_count, storage.dtype());
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            &kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            &mut output,
+        )
+        .unwrap();
+        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        Ok((newstorage, layout.shape().clone()))
+    }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
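Note: a usage sketch of the new path (assumes a Metal-capable machine; with the `metal` features enabled, softmax_last_dim dispatches to the metal_fwd above):

```rust
use candle::{Device, Tensor};

fn main() -> candle::Result<()> {
    let device = Device::new_metal(0)?;
    let xs = Tensor::randn(0f32, 1f32, (4, 128), &device)?;
    // Runs the softmax_float Metal kernel over the last dimension.
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;
    println!("{ys}");
    Ok(())
}
```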