Tweak the basic example to show how to implement sort.

This commit is contained in:
laurent
2023-11-30 08:01:42 +00:00
parent 7c3cfd1086
commit 03ad494fcd

View File

@ -5,13 +5,43 @@ extern crate intel_mkl_src;
extern crate accelerate_src; extern crate accelerate_src;
use anyhow::Result; use anyhow::Result;
use candle_core::{Device, Tensor}; use candle::{CpuStorage, Device, Layout, Shape, Tensor};
use candle_core as candle;
struct ArgSort;
impl candle::CustomOp1 for ArgSort {
fn name(&self) -> &'static str {
"arg-sort"
}
fn cpu_fwd(
&self,
storage: &CpuStorage,
layout: &Layout,
) -> candle::Result<(CpuStorage, Shape)> {
if layout.shape().rank() != 1 {
candle::bail!(
"input should have a single dimension, got {:?}",
layout.shape()
)
}
let slice = storage.as_slice::<f32>()?;
let src = match layout.contiguous_offsets() {
None => candle::bail!("input has to be contiguous"),
Some((o1, o2)) => &slice[o1..o2],
};
let mut dst = (0..src.len() as u32).collect::<Vec<u32>>();
dst.sort_by(|&i, &j| src[i as usize].total_cmp(&src[j as usize]));
let storage = candle::WithDType::to_cpu_storage_owned(dst);
Ok((storage, layout.shape().clone()))
}
}
fn main() -> Result<()> { fn main() -> Result<()> {
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?; let a = Tensor::new(&[0.0f32, 1.0, 3.0, 2.0, -12.0, 4.0, 3.5], &Device::Cpu)?;
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?; let indices = a.apply_op1(ArgSort)?;
let new_a = a.slice_scatter(&b, 1, 2)?; let a_sorted = a.gather(&indices, 0)?;
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); println!("{indices}");
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); println!("{a_sorted}");
Ok(()) Ok(())
} }