Add argsort. (#2132)

* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
This commit is contained in:
Laurent Mazare
2024-04-27 20:17:35 +02:00
committed by GitHub
parent 6cf82fd7a3
commit 96a48e5cc4
11 changed files with 489 additions and 44 deletions

View File

@ -96,6 +96,22 @@ fn clamp(device: &Device) -> Result<()> {
Ok(())
}
fn asort(device: &Device) -> Result<()> {
let data = &[[3f32, 1., 4., 1.1, 5.], [2.1, 1., 7., 8., 2.]];
let tensor = Tensor::new(data, device)?;
let indexes = tensor.arg_sort_last_dim(true)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[1, 3, 0, 2, 4], [1, 4, 0, 2, 3]],
);
let indexes = tensor.arg_sort_last_dim(false)?;
assert_eq!(
indexes.to_vec2::<u32>()?,
[[4, 2, 0, 3, 1], [3, 2, 0, 4, 1]],
);
Ok(())
}
fn unary_op(device: &Device) -> Result<()> {
let data = &[[-3f32, 1., 4., -0.1, 0.5], [2.7, -1.8, -0.28, 1.8, 2.8]];
let tensor = Tensor::new(data, device)?;
@ -1151,6 +1167,7 @@ test_device!(
);
test_device!(randn, randn_cpu, randn_gpu, randn_metal);
test_device!(clamp, clamp_cpu, clamp_gpu, clamp_metal);
test_device!(asort, asort_cpu, asort_gpu, asort_metal);
test_device!(var, var_cpu, var_gpu, var_metal);
test_device!(zero_dim, zero_dim_cpu, zero_dim_gpu, zero_dim_metal);