mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More general broadcast setup.
This commit is contained in:
@ -90,6 +90,9 @@ pub enum Error {
|
||||
/// I/O error.
|
||||
#[error(transparent)]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
|
||||
BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
@ -605,8 +605,8 @@ impl Tensor {
|
||||
storage: Arc::new(self.storage.try_clone()?),
|
||||
shape: self.shape.clone(),
|
||||
stride: self.stride.clone(),
|
||||
op: self.op.clone(),
|
||||
is_variable: self.is_variable,
|
||||
op: None, // TODO
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -654,7 +654,7 @@ impl Tensor {
|
||||
shape: self.shape.clone(),
|
||||
stride: self.stride.clone(),
|
||||
op,
|
||||
is_variable: self.is_variable,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
@ -662,28 +662,60 @@ impl Tensor {
|
||||
|
||||
/// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
|
||||
/// on the left.
|
||||
pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
|
||||
pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
|
||||
let left_shape = left_shape.into();
|
||||
let mut dims = left_shape.into_dims();
|
||||
dims.extend(self.shape.dims());
|
||||
self.broadcast_as(dims)
|
||||
}
|
||||
|
||||
pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
|
||||
let op = if self.track_op() {
|
||||
Some(Op::Broadcast(self.clone()))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let mut stride = vec![0; left_shape.rank()];
|
||||
stride.extend_from_slice(&self.stride);
|
||||
let mut dims = left_shape.into_dims();
|
||||
dims.extend(self.shape.dims());
|
||||
let shape = shape.into();
|
||||
if shape.rank() < self.rank() {
|
||||
return Err(Error::BroadcastIncompatibleShapes {
|
||||
src_shape: self.shape().clone(),
|
||||
dst_shape: shape,
|
||||
});
|
||||
}
|
||||
let added_dims = shape.rank() - self.rank();
|
||||
let mut stride = vec![0; added_dims];
|
||||
for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
|
||||
.iter()
|
||||
.zip(self.dims().iter().zip(self.stride()))
|
||||
{
|
||||
let s = if dst_dim == src_dim {
|
||||
src_stride
|
||||
} else if src_dim != 1 {
|
||||
return Err(Error::BroadcastIncompatibleShapes {
|
||||
src_shape: self.shape().clone(),
|
||||
dst_shape: shape,
|
||||
});
|
||||
} else {
|
||||
0
|
||||
};
|
||||
stride.push(s)
|
||||
}
|
||||
let tensor_ = Tensor_ {
|
||||
id: TensorId::new(),
|
||||
storage: self.storage.clone(),
|
||||
shape: Shape::from(dims),
|
||||
shape,
|
||||
stride,
|
||||
op,
|
||||
is_variable: self.is_variable,
|
||||
is_variable: false,
|
||||
};
|
||||
Ok(Tensor(Arc::new(tensor_)))
|
||||
}
|
||||
|
||||
/// An alias for broadcast_as.
|
||||
pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
|
||||
self.broadcast_as(shape)
|
||||
}
|
||||
|
||||
pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
|
||||
let shape = self.shape();
|
||||
let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
|
||||
@ -706,8 +738,8 @@ impl Tensor {
|
||||
Ok(from_storage(
|
||||
storage,
|
||||
shape.clone(),
|
||||
self.op.clone(),
|
||||
self.is_variable,
|
||||
None, // TODO
|
||||
false,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
@ -144,7 +144,7 @@ fn broadcast() -> Result<()> {
|
||||
let data = &[3f32, 1., 4.];
|
||||
let tensor = Tensor::new(data, &Device::Cpu)?;
|
||||
assert_eq!(
|
||||
tensor.broadcast((3, 1))?.to_vec3::<f32>()?,
|
||||
tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
|
||||
&[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
|
||||
);
|
||||
Ok(())
|
||||
|
Reference in New Issue
Block a user