mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Tweak the basic example to show how to implement sort.
This commit is contained in:
@ -5,13 +5,43 @@ extern crate intel_mkl_src;
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
use candle::{CpuStorage, Device, Layout, Shape, Tensor};
|
||||
use candle_core as candle;
|
||||
|
||||
struct ArgSort;
|
||||
impl candle::CustomOp1 for ArgSort {
|
||||
fn name(&self) -> &'static str {
|
||||
"arg-sort"
|
||||
}
|
||||
|
||||
fn cpu_fwd(
|
||||
&self,
|
||||
storage: &CpuStorage,
|
||||
layout: &Layout,
|
||||
) -> candle::Result<(CpuStorage, Shape)> {
|
||||
if layout.shape().rank() != 1 {
|
||||
candle::bail!(
|
||||
"input should have a single dimension, got {:?}",
|
||||
layout.shape()
|
||||
)
|
||||
}
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let src = match layout.contiguous_offsets() {
|
||||
None => candle::bail!("input has to be contiguous"),
|
||||
Some((o1, o2)) => &slice[o1..o2],
|
||||
};
|
||||
let mut dst = (0..src.len() as u32).collect::<Vec<u32>>();
|
||||
dst.sort_by(|&i, &j| src[i as usize].total_cmp(&src[j as usize]));
|
||||
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||
Ok((storage, layout.shape().clone()))
|
||||
}
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let a = Tensor::new(&[[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]], &Device::Cpu)?;
|
||||
let b = Tensor::new(&[[88.0f32, 99.0]], &Device::Cpu)?;
|
||||
let new_a = a.slice_scatter(&b, 1, 2)?;
|
||||
assert_eq!(a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
assert_eq!(new_a.to_vec2::<f32>()?, [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let a = Tensor::new(&[0.0f32, 1.0, 3.0, 2.0, -12.0, 4.0, 3.5], &Device::Cpu)?;
|
||||
let indices = a.apply_op1(ArgSort)?;
|
||||
let a_sorted = a.gather(&indices, 0)?;
|
||||
println!("{indices}");
|
||||
println!("{a_sorted}");
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user