Mirror of https://github.com/huggingface/candle.git
Add the broadcast operator.
@@ -22,6 +22,7 @@ pub(crate) enum Op {
         add: f64,
     },
     ToDType(Tensor),
+    Broadcast(Tensor),
     Exp(Tensor),
     Log(Tensor),
     Sin(Tensor),
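The new `Broadcast(Tensor)` variant follows the same pattern as the neighbouring unary ops (`ToDType`, `Exp`, `Log`, `Sin`): the variant keeps a handle on the source tensor so that a broadcast result can still reach the tensor it was expanded from when the graph is later traversed.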
@@ -94,6 +94,10 @@ impl Shape {
         self.0.len()
     }
 
+    pub fn into_dims(self) -> Vec<usize> {
+        self.0
+    }
+
     pub fn dims(&self) -> &[usize] {
         &self.0
     }
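`into_dims` consumes the `Shape` and hands back the owned `Vec<usize>`, whereas the existing `dims` only borrows it; the `broadcast` method below takes the owned vector so it can extend it in place with the original dims. A minimal sketch of that usage on a simplified stand-in `Shape` (not the actual candle type):

    // Simplified stand-in for candle's Shape, just to illustrate the
    // owned-vs-borrowed accessors used by `broadcast`.
    struct Shape(Vec<usize>);

    impl Shape {
        fn dims(&self) -> &[usize] {
            &self.0
        }
        fn into_dims(self) -> Vec<usize> {
            self.0
        }
    }

    fn main() {
        let left = Shape(vec![3, 1]);
        let right = Shape(vec![3]);
        // Take ownership of the left dims, then append the (borrowed) right dims.
        let mut dims = left.into_dims();
        dims.extend_from_slice(right.dims());
        assert_eq!(dims, vec![3, 1, 3]);
    }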
@@ -653,6 +653,30 @@ impl Tensor {
         }
     }
 
+    /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
+    /// on the left.
+    pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
+        let left_shape = left_shape.into();
+        let op = if self.track_op() {
+            Some(Op::Broadcast(self.clone()))
+        } else {
+            None
+        };
+        let mut stride = vec![0; left_shape.rank()];
+        stride.extend_from_slice(&self.stride);
+        let mut dims = left_shape.into_dims();
+        dims.extend(self.shape.dims());
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            shape: Shape::from(dims),
+            stride,
+            op,
+            is_variable: self.is_variable,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
         let shape = self.shape();
         let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
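The broadcast result is never materialized at its full size: the inserted left dimensions get a stride of 0, so every index along them resolves to the same storage offset as the original element. A small sketch of the offset arithmetic (plain Rust, independent of candle) shows why a zero stride repeats the data:

    // Flat storage offset of a multi-dimensional index, given per-dimension strides.
    fn offset(index: &[usize], stride: &[usize]) -> usize {
        index.iter().zip(stride.iter()).map(|(i, s)| i * s).sum()
    }

    fn main() {
        // A 1-D tensor of 3 elements has stride [1]. After `broadcast((3, 1))`
        // the shape is (3, 1, 3) and the stride becomes [0, 0, 1]: the two
        // inserted dimensions contribute nothing to the offset.
        let stride = [0usize, 0, 1];
        assert_eq!(offset(&[0, 0, 2], &stride), 2);
        assert_eq!(offset(&[2, 0, 2], &stride), 2); // same underlying element
    }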
@@ -849,6 +873,7 @@ impl Tensor {
                 }
             }
             Op::Reshape(node)
+            | Op::Broadcast(node)
             | Op::ToDType(node)
             | Op::ToDevice(node)
             | Op::Transpose(node, _, _)
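Adding `| Op::Broadcast(node)` to this alternation groups broadcast with the other single-argument ops, so the source tensor of a broadcast is visited like any other parent node in this traversal (the same arm that already handles `Reshape`, `ToDType`, `ToDevice` and `Transpose`, presumably when collecting nodes ahead of the backward pass).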
@@ -978,6 +1003,9 @@ impl Tensor {
                         start_idx += len;
                     }
                 }
+                Op::Broadcast(_arg) => {
+                    return Err(Error::BackwardNotSupported { op: "broadcast" })
+                }
                 Op::ToDType(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
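Backpropagation through broadcast is explicitly rejected for now (`BackwardNotSupported`). Conceptually, the gradient of a left-broadcast is the incoming gradient summed over the inserted dimensions; a rough illustration of that reduction on plain arrays, not the candle API and not part of this commit:

    // Illustration only: collapse a gradient of shape (3, 1, 3), as produced by
    // `broadcast((3, 1))` on a 3-element tensor, back to shape (3,) by summing
    // over the two inserted left dimensions.
    fn reduce_broadcast_grad(grad: &[[[f32; 3]; 1]; 3]) -> [f32; 3] {
        let mut out = [0f32; 3];
        for plane in grad {
            for row in plane {
                for (o, g) in out.iter_mut().zip(row) {
                    *o += g;
                }
            }
        }
        out
    }

    fn main() {
        let grad = [[[1f32, 2., 3.]], [[1., 2., 3.]], [[1., 2., 3.]]];
        assert_eq!(reduce_broadcast_grad(&grad), [3., 6., 9.]);
    }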
@@ -138,3 +138,14 @@ fn narrow() -> Result<()> {
     );
     Ok(())
 }
+
+#[test]
+fn broadcast() -> Result<()> {
+    let data = &[3f32, 1., 4.];
+    let tensor = Tensor::new(data, &Device::Cpu)?;
+    assert_eq!(
+        tensor.broadcast((3, 1))?.to_vec3::<f32>()?,
+        &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
+    );
+    Ok(())
+}