Lots of updates including some stack of command buffers.

This commit is contained in:
nicolas
2023-12-12 17:41:56 +01:00
parent da0af3cb3e
commit 87dc559817
10 changed files with 537 additions and 117 deletions

View File

@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let mut output = device.new_buffer(elem_count, storage.dtype());
let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,
@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
&mut output,
)
.unwrap();
command_buffer.commit();
output.did_modify_range(metal::NSRange::new(0, output.length()));
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
}