mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Lots of updates including some stack of command buffers.
This commit is contained in:
@ -19,6 +19,7 @@ num-traits = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
metal = { workspace = true, optional = true }
|
||||
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
@ -30,4 +31,4 @@ default = []
|
||||
accelerate = ["dep:accelerate-src", "candle/accelerate"]
|
||||
cuda = ["candle/cuda"]
|
||||
mkl = ["dep:intel-mkl-src", "candle/mkl"]
|
||||
metal = ["candle/metal", "dep:candle-metal-kernels"]
|
||||
metal = ["candle/metal", "dep:candle-metal-kernels", "dep:metal"]
|
||||
|
@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
|
||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
||||
let elem_count = layout.shape().elem_count();
|
||||
let mut output = device.new_buffer(elem_count, storage.dtype());
|
||||
let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
|
||||
candle_metal_kernels::call_last_softmax(
|
||||
device.metal_device(),
|
||||
&command_buffer,
|
||||
@ -238,6 +238,8 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
||||
&mut output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
output.did_modify_range(metal::NSRange::new(0, output.length()));
|
||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
||||
Ok((newstorage, layout.shape().clone()))
|
||||
}
|
||||
|
Reference in New Issue
Block a user