mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Add the broadcast operator.
This commit is contained in:
@ -22,6 +22,7 @@ pub(crate) enum Op {
|
|||||||
add: f64,
|
add: f64,
|
||||||
},
|
},
|
||||||
ToDType(Tensor),
|
ToDType(Tensor),
|
||||||
|
Broadcast(Tensor),
|
||||||
Exp(Tensor),
|
Exp(Tensor),
|
||||||
Log(Tensor),
|
Log(Tensor),
|
||||||
Sin(Tensor),
|
Sin(Tensor),
|
||||||
|
@ -94,6 +94,10 @@ impl Shape {
|
|||||||
self.0.len()
|
self.0.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_dims(self) -> Vec<usize> {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dims(&self) -> &[usize] {
|
pub fn dims(&self) -> &[usize] {
|
||||||
&self.0
|
&self.0
|
||||||
}
|
}
|
||||||
|
@ -653,6 +653,30 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
|
||||||
|
/// on the left.
|
||||||
|
pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
|
||||||
|
let left_shape = left_shape.into();
|
||||||
|
let op = if self.track_op() {
|
||||||
|
Some(Op::Broadcast(self.clone()))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
let mut stride = vec![0; left_shape.rank()];
|
||||||
|
stride.extend_from_slice(&self.stride);
|
||||||
|
let mut dims = left_shape.into_dims();
|
||||||
|
dims.extend(self.shape.dims());
|
||||||
|
let tensor_ = Tensor_ {
|
||||||
|
id: TensorId::new(),
|
||||||
|
storage: self.storage.clone(),
|
||||||
|
shape: Shape::from(dims),
|
||||||
|
stride,
|
||||||
|
op,
|
||||||
|
is_variable: self.is_variable,
|
||||||
|
};
|
||||||
|
Ok(Tensor(Arc::new(tensor_)))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
|
pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
|
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
|
||||||
@ -849,6 +873,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
Op::Reshape(node)
|
Op::Reshape(node)
|
||||||
|
| Op::Broadcast(node)
|
||||||
| Op::ToDType(node)
|
| Op::ToDType(node)
|
||||||
| Op::ToDevice(node)
|
| Op::ToDevice(node)
|
||||||
| Op::Transpose(node, _, _)
|
| Op::Transpose(node, _, _)
|
||||||
@ -978,6 +1003,9 @@ impl Tensor {
|
|||||||
start_idx += len;
|
start_idx += len;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Op::Broadcast(_arg) => {
|
||||||
|
return Err(Error::BackwardNotSupported { op: "broadcast" })
|
||||||
|
}
|
||||||
Op::ToDType(arg) => {
|
Op::ToDType(arg) => {
|
||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
|
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
|
||||||
|
@ -138,3 +138,14 @@ fn narrow() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn broadcast() -> Result<()> {
|
||||||
|
let data = &[3f32, 1., 4.];
|
||||||
|
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||||
|
assert_eq!(
|
||||||
|
tensor.broadcast((3, 1))?.to_vec3::<f32>()?,
|
||||||
|
&[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user