From 03ad494fcdb4c628dc59fa5886619f3d6defc697 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 30 Nov 2023 08:01:42 +0000 Subject: [PATCH] Tweak the basic example to show how to implement sort. --- candle-core/examples/basics.rs | 42 +++++++++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs index fe15187b..ca06eebe 100644 --- a/candle-core/examples/basics.rs +++ b/candle-core/examples/basics.rs @@ -5,13 +5,43 @@ extern crate intel_mkl_src; extern crate accelerate_src; use anyhow::Result; -use candle_core::{Device, Tensor}; +use candle::{CpuStorage, Device, Layout, Shape, Tensor}; +use candle_core as candle; + +struct ArgSort; +impl candle::CustomOp1 for ArgSort { + fn name(&self) -> &'static str { + "arg-sort" + } + + fn cpu_fwd( + &self, + storage: &CpuStorage, + layout: &Layout, + ) -> candle::Result<(CpuStorage, Shape)> { + if layout.shape().rank() != 1 { + candle::bail!( + "input should have a single dimension, got {:?}", + layout.shape() + ) + } + let slice = storage.as_slice::()?; + let src = match layout.contiguous_offsets() { + None => candle::bail!("input has to be contiguous"), + Some((o1, o2)) => &slice[o1..o2], + }; + let mut dst = (0..src.len() as u32).collect::>(); + dst.sort_by(|&i, &j| src[i as usize].total_cmp(&src[j as usize])); + let storage = candle::WithDType::to_cpu_storage_owned(dst); + Ok((storage, layout.shape().clone())) + } +} fn main() -> Result<()> { - let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?; - let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?; - let new_a = a.slice_scatter(&b, 1, 2)?; - assert_eq!(a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); - assert_eq!(new_a.to_vec2::()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let a = Tensor::new(&[0.0f32, 1.0, 3.0, 2.0, -12.0, 4.0, 3.5], &Device::Cpu)?; + let indices = a.apply_op1(ArgSort)?; + let a_sorted = a.gather(&indices, 0)?; + println!("{indices}"); + println!("{a_sorted}"); Ok(()) }