mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp. * Simplify the custom-ops when no backward is required.
This commit is contained in:
@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
|
||||
println!("{t}");
|
||||
let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
|
||||
let t = t.apply_op1(LayerNorm { eps: 1e-5 })?;
|
||||
println!("{t}");
|
||||
Ok(())
|
||||
}
|
||||
|
@ -68,7 +68,7 @@ impl CustomOp1 for AllReduce {
|
||||
}
|
||||
|
||||
fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
|
||||
x.custom_op1(AllReduce { comm: comm.clone() })
|
||||
x.apply_op1(AllReduce { comm: comm.clone() })
|
||||
}
|
||||
|
||||
impl TensorParallelRowLinear {
|
||||
|
Reference in New Issue
Block a user