mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Add argsort. (#2132)
* Add the argsort cuda kernels. * CPU version of arg-sort. * Hook the cuda kernel + rework the cpu bits. * Add some dedicated test. * Working cuda kernel. * Metal kernel. * Metal adjustments. * Bugfix. * Use the fast rope in qwen. * Rework the expert selection in qwen.
This commit is contained in:
@ -96,6 +96,22 @@ fn clamp(device: &Device) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn asort(device: &Device) -> Result<()> {
|
||||
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
let indexes = tensor.arg_sort_last_dim(true)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
|
||||
);
|
||||
let indexes = tensor.arg_sort_last_dim(false)?;
|
||||
assert_eq!(
|
||||
indexes.to_vec2::<u32>()?,
|
||||
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn unary_op(device: &Device) -> Result<()> {
|
||||
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
|
||||
let tensor = Tensor::new(data, device)?;
|
||||
@ -1151,6 +1167,7 @@ test_device!(
|
||||
);
|
||||
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
|
||||
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
|
||||
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
|
||||
test_device!(var, var_cpu, var_gpu, var_metal);
|
||||
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);
|
||||
|
||||
|
Reference in New Issue
Block a user