mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 04:10:46 +00:00
Support for UG kernels. (#2579)
* Support for UG kernels. * Add a dedicated test.
This commit is contained in:
@ -143,3 +143,33 @@ fn inplace_op1() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(feature = "cuda")]
|
||||
#[allow(clippy::approx_constant)]
|
||||
#[test]
|
||||
fn ug_op() -> Result<()> {
|
||||
let kernel = {
|
||||
use ug::lang::op;
|
||||
|
||||
let layout = ug::Layout::from_shape(&[12]);
|
||||
let ptr = op::Arg::ptr(ug::DType::F32);
|
||||
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
|
||||
let src = op::unary(op::UnaryOp::Exp, src)?;
|
||||
let st = op::store(ptr.id(), layout, src)?;
|
||||
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
|
||||
let opts: ug::lower_op::Opts = Default::default();
|
||||
kernel.lower(&opts.with_global(0, 12))?
|
||||
};
|
||||
let device = Device::new_cuda(0)?;
|
||||
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
|
||||
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
|
||||
t.inplace_op1(&op)?;
|
||||
assert_eq!(
|
||||
to_vec1_round(&t, 4)?,
|
||||
&[
|
||||
1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
|
||||
8103.0806, 22026.469, 59874.133
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user