Add argsort. (#2132)

* Add the argsort cuda kernels.

* CPU version of arg-sort.

* Hook the cuda kernel + rework the cpu bits.

* Add some dedicated test.

* Working cuda kernel.

* Metal kernel.

* Metal adjustments.

* Bugfix.

* Use the fast rope in qwen.

* Rework the expert selection in qwen.
This commit is contained in:
Laurent Mazare
2024-04-27 20:17:35 +02:00
committed by GitHub
parent 6cf82fd7a3
commit 96a48e5cc4
11 changed files with 489 additions and 44 deletions

View File

@ -11,7 +11,7 @@ use std::sync::{Arc, Mutex, RwLock, TryLockError};
mod device;
pub use device::{DeviceId, MetalDevice};
fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
pub fn buffer_o<'a>(buffer: &'a Buffer, l: &Layout, dtype: DType) -> BufferOffset<'a> {
BufferOffset {
buffer,
offset_in_bytes: l.start_offset() * dtype.size_in_bytes(),