mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Metal operational.
This commit is contained in:
@ -300,9 +300,6 @@ pub fn call_unary_contiguous(
|
||||
input: &Buffer,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
// println!("Kernel {:?}", kernel_name.0);
|
||||
// assert_eq!(input.length(), output.length());
|
||||
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
@ -484,6 +481,7 @@ pub fn call_reduce_contiguous(
|
||||
length: usize,
|
||||
out_length: usize,
|
||||
input: &Buffer,
|
||||
input_offset: usize,
|
||||
output: &mut Buffer,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
|
||||
@ -492,7 +490,7 @@ pub fn call_reduce_contiguous(
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
|
||||
set_params!(encoder, (length, elements_to_sum, input, output));
|
||||
set_params!(encoder, (length, elements_to_sum, (input,input_offset), output));
|
||||
|
||||
let thread_group_count = MTLSize {
|
||||
width: out_length as u64,
|
||||
@ -1228,6 +1226,7 @@ mod tests {
|
||||
v.len(),
|
||||
out_length,
|
||||
&input,
|
||||
0,
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
|
Reference in New Issue
Block a user