mirror of https://github.com/huggingface/candle.git
More general broadcast setup.
@@ -90,6 +90,9 @@ pub enum Error {
     /// I/O error.
     #[error(transparent)]
     Io(#[from] std::io::Error),
+
+    #[error("cannot broadcast {src_shape:?} to {dst_shape:?}")]
+    BroadcastIncompatibleShapes { src_shape: Shape, dst_shape: Shape },
 }
 
 pub type Result<T> = std::result::Result<T, Error>;

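For context, the new variant makes shape mismatches recoverable by callers rather than panics. A minimal sketch of catching it, assuming the crate is imported as candle and that Error, Tensor, and Device are re-exported from the crate root:

    use candle::{Device, Error, Tensor};

    fn main() -> candle::Result<()> {
        let t = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
        // A [3] tensor cannot broadcast to [2]: 3 != 2 and 3 != 1.
        match t.broadcast_as(vec![2]) {
            Err(Error::BroadcastIncompatibleShapes { src_shape, dst_shape }) => {
                println!("cannot broadcast {src_shape:?} to {dst_shape:?}")
            }
            _ => unreachable!("expected a broadcast error"),
        }
        Ok(())
    }
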
@@ -605,8 +605,8 @@ impl Tensor {
             storage: Arc::new(self.storage.try_clone()?),
             shape: self.shape.clone(),
             stride: self.stride.clone(),
-            op: self.op.clone(),
-            is_variable: self.is_variable,
+            op: None, // TODO
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }

@@ -654,7 +654,7 @@ impl Tensor {
             shape: self.shape.clone(),
             stride: self.stride.clone(),
             op,
-            is_variable: self.is_variable,
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }

@@ -662,28 +662,60 @@ impl Tensor {
 
     /// Returns a new tensor duplicating data from the original tensor. New dimensions are inserted
     /// on the left.
-    pub fn broadcast<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
+    pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
         let left_shape = left_shape.into();
+        let mut dims = left_shape.into_dims();
+        dims.extend(self.shape.dims());
+        self.broadcast_as(dims)
+    }
+
+    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let op = if self.track_op() {
             Some(Op::Broadcast(self.clone()))
         } else {
             None
         };
-        let mut stride = vec![0; left_shape.rank()];
-        stride.extend_from_slice(&self.stride);
-        let mut dims = left_shape.into_dims();
-        dims.extend(self.shape.dims());
+        let shape = shape.into();
+        if shape.rank() < self.rank() {
+            return Err(Error::BroadcastIncompatibleShapes {
+                src_shape: self.shape().clone(),
+                dst_shape: shape,
+            });
+        }
+        let added_dims = shape.rank() - self.rank();
+        let mut stride = vec![0; added_dims];
+        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
+            .iter()
+            .zip(self.dims().iter().zip(self.stride()))
+        {
+            let s = if dst_dim == src_dim {
+                src_stride
+            } else if src_dim != 1 {
+                return Err(Error::BroadcastIncompatibleShapes {
+                    src_shape: self.shape().clone(),
+                    dst_shape: shape.clone(),
+                });
+            } else {
+                0
+            };
+            stride.push(s)
+        }
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: Shape::from(dims),
+            shape,
             stride,
             op,
-            is_variable: self.is_variable,
+            is_variable: false,
         };
         Ok(Tensor(Arc::new(tensor_)))
     }
+
+    /// An alias for broadcast_as.
+    pub fn expand<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
+        self.broadcast_as(shape)
+    }
+
     pub fn to_dtype(&self, dtype: DType) -> Result<Self> {
         let shape = self.shape();
         let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;

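The loop above implements the usual broadcasting rule without copying any data: destination dimensions that match the source keep the source stride, dimensions expanded from size 1 (and the newly added leading dimensions) get stride 0, so every repeated element reads the same storage location. A standalone sketch of the same stride computation, using a hypothetical broadcast_stride helper over plain slices instead of candle's Shape type:

    // Sketch of the stride rule used above: align shapes on the right,
    // keep the stride where dimensions match, use 0 where the source
    // dimension is 1, and reject everything else.
    fn broadcast_stride(
        src_dims: &[usize],
        src_stride: &[usize],
        dst_dims: &[usize],
    ) -> Option<Vec<usize>> {
        if dst_dims.len() < src_dims.len() {
            return None;
        }
        let added = dst_dims.len() - src_dims.len();
        // Leading dimensions that only exist in the destination repeat
        // the whole tensor, hence stride 0.
        let mut stride = vec![0; added];
        for (&dst, (&src, &st)) in dst_dims[added..]
            .iter()
            .zip(src_dims.iter().zip(src_stride.iter()))
        {
            if dst == src {
                stride.push(st) // matching dimension: keep the source stride
            } else if src == 1 {
                stride.push(0) // stretched size-1 dimension: revisit the same element
            } else {
                return None; // incompatible shapes
            }
        }
        Some(stride)
    }

    fn main() {
        // A contiguous [3] tensor has stride [1]; broadcast to [2, 3] it
        // gets stride [0, 1], so both rows alias the same three values.
        assert_eq!(broadcast_stride(&[3], &[1], &[2, 3]), Some(vec![0, 1]));
        // [3] -> [2] is rejected: 3 != 2 and 3 != 1.
        assert_eq!(broadcast_stride(&[3], &[1], &[2]), None);
    }
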
@@ -706,8 +738,8 @@ impl Tensor {
         Ok(from_storage(
             storage,
             shape.clone(),
-            self.op.clone(),
-            self.is_variable,
+            None, // TODO
+            false,
         ))
     }
 }

@@ -144,7 +144,7 @@ fn broadcast() -> Result<()> {
     let data = &[3f32, 1., 4.];
     let tensor = Tensor::new(data, &Device::Cpu)?;
     assert_eq!(
-        tensor.broadcast((3, 1))?.to_vec3::<f32>()?,
+        tensor.broadcast_left((3, 1))?.to_vec3::<f32>()?,
         &[[[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]], [[3.0, 1.0, 4.0]]]
     );
     Ok(())
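Per the new definitions above, broadcast_left((3, 1)) on a [3] tensor is just sugar for broadcast_as with target shape [3, 1, 3], and expand is an alias for broadcast_as. A quick sketch of the equivalence, under the same import assumptions as the test:

    use candle::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let tensor = Tensor::new(&[3f32, 1., 4.], &Device::Cpu)?;
        // broadcast_left prepends the given dimensions to [3], targeting [3, 1, 3].
        let a = tensor.broadcast_left((3, 1))?;
        let b = tensor.broadcast_as(vec![3, 1, 3])?;
        let c = tensor.expand(vec![3, 1, 3])?; // alias for broadcast_as
        assert_eq!(a.to_vec3::<f32>()?, b.to_vec3::<f32>()?);
        assert_eq!(b.to_vec3::<f32>()?, c.to_vec3::<f32>()?);
        Ok(())
    }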