Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 02:38:10 +00:00)
Relax the requirements on CustomOp. (#486)
* Relax the requirements on CustomOp.
* Simplify the custom-ops when no backward is required.
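In short: the `Tensor` entry points are renamed from `custom_op{1,2,3}` / `custom_op{1,2,3}_arc` to `apply_op{1,2,3}` / `apply_op{1,2,3}_arc`, the `Send + Sync` bound moves off the `CustomOp*` traits and onto the trait objects stored in the backprop graph, and new `apply_op{1,2,3}_no_bwd` helpers take the op by reference and skip gradient tracking. The sketch below is not part of the commit: it shows the simplified no-backward path with an illustrative `Identity` op, relies on the trait's default `cuda_fwd`/`bwd` implementations, and assumes the input tensor is contiguous since it clones the storage without consulting the layout's strides.

use candle_core::{CpuStorage, CustomOp1, DType, Device, Layout, Result, Shape, Tensor};

// Illustrative no-backward op (not from the commit): it passes its input
// through unchanged. Only `name` and `cpu_fwd` are provided.
struct Identity;

impl CustomOp1 for Identity {
    fn name(&self) -> &'static str {
        "identity"
    }

    // Assumption: the input is contiguous, so cloning the storage verbatim is
    // enough; a real op would apply the strides described by `layout`.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        Ok((storage.clone(), layout.shape().clone()))
    }
}

fn main() -> Result<()> {
    let dev = &Device::Cpu;
    let t = Tensor::arange(0u32, 12u32, dev)?.to_dtype(DType::F32)?;

    // Previously the trait itself required Send + Sync and every call went
    // through custom_op1, which recorded a backprop node. The no-backward
    // variant borrows the op and records no node.
    let out = t.apply_op1_no_bwd(&Identity)?;
    println!("{out}");
    Ok(())
}

Ops applied through `apply_op{1,2,3}` are still boxed into the graph and therefore keep the `Send + Sync` requirement; only the `_no_bwd` path (used by `QMatMul::forward` below) drops it.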
@@ -118,13 +118,22 @@ pub enum Op {
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
     Elu(Tensor, f64),
-    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1>>),
-    CustomOp2(Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp2>>),
-    CustomOp3(Tensor, Tensor, Tensor, std::sync::Arc<Box<dyn CustomOp3>>),
+    CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
+    CustomOp2(
+        Tensor,
+        Tensor,
+        std::sync::Arc<Box<dyn CustomOp2 + Send + Sync>>,
+    ),
+    CustomOp3(
+        Tensor,
+        Tensor,
+        Tensor,
+        std::sync::Arc<Box<dyn CustomOp3 + Send + Sync>>,
+    ),
 }
 
 /// Unary ops that can be defined in user-land.
-pub trait CustomOp1: Send + Sync {
+pub trait CustomOp1 {
     // Box<dyn> does not support const yet, so use a function to get the name.
     fn name(&self) -> &'static str;
 
@@ -148,7 +157,7 @@ pub trait CustomOp1: Send + Sync {
     }
 }
 
-pub trait CustomOp2: Send + Sync {
+pub trait CustomOp2 {
     fn name(&self) -> &'static str;
 
     /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -186,7 +195,7 @@ pub trait CustomOp2: Send + Sync {
     }
 }
 
-pub trait CustomOp3: Send + Sync {
+pub trait CustomOp3 {
     fn name(&self) -> &'static str;
 
     /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
@@ -147,11 +147,11 @@ impl QTensor {
     }
 }
 
-pub struct QMatMul(std::sync::Arc<Box<dyn crate::CustomOp1>>);
+pub struct QMatMul(QTensor);
 
 impl QMatMul {
     pub fn from_qtensor(qtensor: QTensor) -> Self {
-        Self(std::sync::Arc::new(Box::new(qtensor)))
+        Self(qtensor)
     }
 }
 
@@ -196,6 +196,6 @@ impl crate::CustomOp1 for QTensor {
 
 impl QMatMul {
     pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        xs.custom_op1_arc(self.0.clone())
+        xs.apply_op1_no_bwd(&self.0)
     }
 }
@@ -138,7 +138,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn custom_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
+    pub(crate) fn apply_op1(&self, l: &Layout, c: &dyn CustomOp1) -> Result<(Self, Shape)> {
         match self {
             Self::Cpu(storage) => {
                 let (storage, shape) = c.cpu_fwd(storage, l)?;
@@ -151,7 +151,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn custom_op2(
+    pub(crate) fn apply_op2(
         &self,
         l1: &Layout,
         t2: &Self,
@@ -172,7 +172,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn custom_op3(
+    pub(crate) fn apply_op3(
        &self,
        l1: &Layout,
        t2: &Self,
@@ -1870,22 +1870,53 @@ impl Tensor {
         std::ptr::eq(lhs, rhs)
     }
 
+    /// Applies a unary custom op without backward support
+    pub fn apply_op1_no_bwd<C: CustomOp1>(&self, c: &C) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op1(self.layout(), c)?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), false))
+    }
+
+    /// Applies a binary custom op without backward support
+    pub fn apply_op2_no_bwd<C: CustomOp2>(&self, rhs: &Self, c: &C) -> Result<Self> {
+        let (storage, shape) =
+            self.storage()
+                .apply_op2(self.layout(), &rhs.storage(), rhs.layout(), c)?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), false))
+    }
+
+    /// Applies a ternary custom op without backward support
+    pub fn apply_op3_no_bwd<C: CustomOp3>(&self, t2: &Self, t3: &Self, c: &C) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op3(
+            self.layout(),
+            &t2.storage(),
+            t2.layout(),
+            &t3.storage(),
+            t3.layout(),
+            c,
+        )?;
+        Ok(from_storage(storage, shape, BackpropOp::none(), false))
+    }
+
     /// Applies a unary custom op.
-    pub fn custom_op1_arc(&self, c: Arc<Box<dyn CustomOp1>>) -> Result<Self> {
+    pub fn apply_op1_arc(&self, c: Arc<Box<dyn CustomOp1 + Send + Sync>>) -> Result<Self> {
         let (storage, shape) = self
             .storage()
-            .custom_op1(self.layout(), c.as_ref().as_ref())?;
+            .apply_op1(self.layout(), c.as_ref().as_ref())?;
         let op = BackpropOp::new1(self, |s| Op::CustomOp1(s, c.clone()));
         Ok(from_storage(storage, shape, op, false))
     }
 
-    pub fn custom_op1<C: 'static + CustomOp1>(&self, c: C) -> Result<Self> {
-        self.custom_op1_arc(Arc::new(Box::new(c)))
+    pub fn apply_op1<C: 'static + CustomOp1 + Send + Sync>(&self, c: C) -> Result<Self> {
+        self.apply_op1_arc(Arc::new(Box::new(c)))
     }
 
     /// Applies a binary custom op.
-    pub fn custom_op2_arc(&self, rhs: &Self, c: Arc<Box<dyn CustomOp2>>) -> Result<Self> {
-        let (storage, shape) = self.storage().custom_op2(
+    pub fn apply_op2_arc(
+        &self,
+        rhs: &Self,
+        c: Arc<Box<dyn CustomOp2 + Send + Sync>>,
+    ) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op2(
             self.layout(),
             &rhs.storage(),
             rhs.layout(),
@@ -1895,13 +1926,18 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }
 
-    pub fn custom_op2<C: 'static + CustomOp2>(&self, r: &Self, c: C) -> Result<Self> {
-        self.custom_op2_arc(r, Arc::new(Box::new(c)))
+    pub fn apply_op2<C: 'static + CustomOp2 + Send + Sync>(&self, r: &Self, c: C) -> Result<Self> {
+        self.apply_op2_arc(r, Arc::new(Box::new(c)))
     }
 
     /// Applies a ternary custom op.
-    pub fn custom_op3_arc(&self, t2: &Self, t3: &Self, c: Arc<Box<dyn CustomOp3>>) -> Result<Self> {
-        let (storage, shape) = self.storage().custom_op3(
+    pub fn apply_op3_arc(
+        &self,
+        t2: &Self,
+        t3: &Self,
+        c: Arc<Box<dyn CustomOp3 + Send + Sync>>,
+    ) -> Result<Self> {
+        let (storage, shape) = self.storage().apply_op3(
             self.layout(),
             &t2.storage(),
             t2.layout(),
@@ -1915,8 +1951,13 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }
 
-    pub fn custom_op3<C: 'static + CustomOp3>(&self, t2: &Self, t3: &Self, c: C) -> Result<Self> {
-        self.custom_op3_arc(t2, t3, Arc::new(Box::new(c)))
+    pub fn apply_op3<C: 'static + CustomOp3 + Send + Sync>(
+        &self,
+        t2: &Self,
+        t3: &Self,
+        c: C,
+    ) -> Result<Self> {
+        self.apply_op3_arc(t2, t3, Arc::new(Box::new(c)))
     }
 }
 
@@ -39,7 +39,7 @@ fn custom_op1_no_backward() -> Result<()> {
     let cpu = &Device::Cpu;
     let t = Tensor::arange(0u32, 12u32, cpu)?.to_dtype(DType::F32)?;
     let t = (t - 5.)?;
-    let elu_t = t.custom_op1(Elu { alpha: 1. })?;
+    let elu_t = t.apply_op1_no_bwd(&Elu { alpha: 1. })?;
     assert_eq!(
         to_vec1_round(&elu_t, 4)?,
         &[-0.9933, -0.9817, -0.9502, -0.8647, -0.6321, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
@@ -96,7 +96,7 @@ impl CustomOp1 for EluWithBackward {
 
     fn bwd(&self, arg: &Tensor, _res: &Tensor, grad_res: &Tensor) -> Result<Option<Tensor>> {
         let alpha = self.0.alpha;
-        let bwd = arg.custom_op1(EluBackward { alpha })?;
+        let bwd = arg.apply_op1(EluBackward { alpha })?;
         Ok(Some(grad_res.mul(&bwd)?))
     }
 }
@@ -105,7 +105,7 @@ impl CustomOp1 for EluWithBackward {
 fn custom_op1_with_backward() -> Result<()> {
     let cpu = &Device::Cpu;
     let t = candle_core::Var::new(&[-2f32, 0f32, 2f32], cpu)?;
-    let elu_t = t.custom_op1(EluWithBackward::new(2.))?;
+    let elu_t = t.apply_op1(EluWithBackward::new(2.))?;
     assert_eq!(to_vec1_round(&elu_t, 4)?, &[-1.7293, 0.0, 2.0]);
 
     let grads = elu_t.backward()?;
@@ -89,7 +89,7 @@ fn main() -> anyhow::Result<()> {
     let device = candle_examples::device(args.cpu)?;
     let t = Tensor::arange(0f32, 14f32, &device)?.reshape((2, 7))?;
     println!("{t}");
-    let t = t.custom_op1(LayerNorm { eps: 1e-5 })?;
+    let t = t.apply_op1(LayerNorm { eps: 1e-5 })?;
     println!("{t}");
     Ok(())
 }
@@ -68,7 +68,7 @@ impl CustomOp1 for AllReduce {
 }
 
 fn all_reduce_sum(x: &Tensor, comm: &Rc<Comm>) -> Result<Tensor> {
-    x.custom_op1(AllReduce { comm: comm.clone() })
+    x.apply_op1(AllReduce { comm: comm.clone() })
 }
 
 impl TensorParallelRowLinear {
@@ -178,7 +178,7 @@ pub fn flash_attn(
         softmax_scale,
         causal,
     };
-    q.custom_op3(k, v, op)
+    q.apply_op3(k, v, op)
 }
 
 struct FlashAttnVarLen {
@@ -402,5 +402,5 @@ pub fn flash_attn_varlen(
         seqlens_q: seqlens_q.clone(),
         seqlens_k: seqlens_k.clone(),
     };
-    q.custom_op3(k, v, op)
+    q.apply_op3(k, v, op)
 }