Mirror of https://github.com/huggingface/candle.git
Starting to fix some tests.
Few fixes.
Going back on remote metal-rs.
Reusing a single buffer (for now) to speed things up.
Adding some half kernels.
All tests are panicking instead of random failure.
Putting back f16 index select.
Add erf.
Working version for llama2-c.
Fixes + cache compute_pipeline_state.
BF16 metal fix.
Remove some prints.
new_owned -> new()..to_owned().
Better batched matmul.
Metal operational.
Reuse buffers on our own reference counts.
Tmp gemm.
Revert "Tmp gemm."
This reverts commit c65f68e988.
Interleave committing.
Speeding up copies using blit.
Fmt.
Fmt.
Remove the assert!
Fmt all.
Fixes after big rebase.
Add softmax for half and bfloat + tests.
Fixing Llama example + accumulate softmax in float.
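The "accumulate softmax in float" part of the last commit refers to computing the row maximum and the normalizing sum in f32 even when the inputs and outputs are f16/bf16, so the reduction neither overflows nor loses precision in half precision. Below is a minimal CPU-side sketch of that idea, an illustration only and not the Metal kernel; it assumes the `half` crate.

use half::f16;

// Softmax over a slice of f16 values: the max and the normalizing sum are
// accumulated in f32, and only the final result is rounded back to f16.
fn softmax_f16_accumulate_f32(xs: &[f16]) -> Vec<f16> {
    let max = xs.iter().map(|x| x.to_f32()).fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|x| (x.to_f32() - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| f16::from_f32(e / sum)).collect()
}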
@@ -19,6 +19,7 @@ num-traits = { workspace = true }
 rayon = { workspace = true }
 safetensors = { workspace = true }
 serde = { workspace = true }
+candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 
 [dev-dependencies]
 anyhow = { workspace = true }
@@ -29,3 +30,4 @@ default = []
 accelerate = ["dep:accelerate-src", "candle/accelerate"]
 cuda = ["candle/cuda"]
 mkl = ["dep:intel-mkl-src", "candle/mkl"]
+metal = ["candle/metal", "dep:candle-metal-kernels"]
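The Cargo.toml hunk above pulls candle-metal-kernels in as an optional dependency and exposes it through a new `metal` cargo feature that also enables candle's own metal backend. Downstream code can then gate Metal-specific paths on that feature; the sketch below is a hedged illustration only (the `candle_core` crate name, the forwarded `metal` feature, and the `default_device` helper are assumptions, not part of this diff).

// Pick the Metal device when the `metal` feature is compiled in,
// otherwise fall back to the CPU backend.
#[cfg(feature = "metal")]
fn default_device() -> candle_core::Result<candle_core::Device> {
    candle_core::Device::new_metal(0)
}

#[cfg(not(feature = "metal"))]
fn default_device() -> candle_core::Result<candle_core::Device> {
    Ok(candle_core::Device::Cpu)
}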
@@ -201,6 +201,46 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         };
         Ok((dst, layout.shape().clone()))
     }
+
+    #[cfg(feature = "metal")]
+    fn metal_fwd(
+        &self,
+        storage: &candle::MetalStorage,
+        layout: &Layout,
+    ) -> Result<(candle::MetalStorage, Shape)> {
+        use candle::{backend::BackendStorage, DType};
+        let device = storage.device();
+        let command_buffer = device.command_buffer();
+        let kernels = device.kernels();
+        let name = match storage.dtype() {
+            DType::F32 => "softmax_float",
+            DType::F16 => "softmax_half",
+            DType::BF16 => "softmax_bfloat",
+            dtype => candle::bail!("softmax-last-dim is not implemented for {dtype:?}"),
+        };
+
+        let n = layout.stride().len();
+        if !(layout.stride()[n - 1] == 1 && layout.start_offset() == 0) {
+            candle::bail!("Non contiguous softmax-last-dim is not implemented");
+        }
+
+        let last_dim = layout.dims()[layout.shape().rank() - 1];
+        let elem_count = layout.shape().elem_count();
+        let mut output = device.new_buffer(elem_count, storage.dtype());
+        candle_metal_kernels::call_last_softmax(
+            device.metal_device(),
+            &command_buffer,
+            &kernels,
+            name,
+            elem_count,
+            last_dim,
+            storage.buffer(),
+            &mut output,
+        )
+        .unwrap();
+        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
+        Ok((newstorage, layout.shape().clone()))
+    }
 }
 
 pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
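With the hunk above, softmax_last_dim should dispatch to the Metal kernel whenever the input tensor lives on a Metal device and the `metal` feature is enabled, falling back to the CPU path otherwise. A rough usage sketch follows; the external crate names (`candle_core`, `candle_nn`) and the `Device::new_metal(0)` constructor are assumptions, not part of this diff.

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // Requires a build with the `metal` feature on an Apple GPU.
    let device = Device::new_metal(0)?;
    // Build a small f16 tensor on CPU, then move it to the Metal device;
    // the last dimension is the softmax axis.
    let xs = Tensor::randn(0f32, 1f32, (2, 3, 8), &Device::Cpu)?
        .to_dtype(DType::F16)?
        .to_device(&device)?;
    let ys = candle_nn::ops::softmax_last_dim(&xs)?;
    println!("{ys}");
    Ok(())
}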