mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Lots of updates including some stack of command buffers.
This commit is contained in:
@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
|
||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let mut output = device.new_buffer(elem_count, storage.dtype());
|
||||
let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
|
||||
candle_metal_kernels::call_last_softmax(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
output.did_modify_range(metal::NSRange::new(0, output.length()));
|
||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
||||
Ok((newstorage, layout.shape().clone()))
|
||||
}
|
||||
|
Reference in New Issue
Block a user