mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 10:38:54 +00:00)
Add the permute op (similar to pytorch). (#504)
* Add the permute op (similar to pytorch).
* Add the backprop for dimension permutation.
@@ -96,6 +96,7 @@ impl Tensor {
                 | Op::ToDType(node)
                 | Op::ToDevice(node)
                 | Op::Transpose(node, _, _)
+                | Op::Permute(node, _)
                 | Op::Narrow(node, _, _, _)
                 | Op::Unary(node, _)
                 | Op::Elu(node, _)
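This first hunk registers `Permute` alongside the other single-input ops, so that its argument node is visited when the backward pass builds its traversal of the graph; the actual gradient computation is added in the next hunk.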
@@ -403,6 +404,15 @@ impl Tensor {
                         let sum_grad = grads.or_insert(arg)?;
                         *sum_grad = sum_grad.add(&arg_grad)?
                     }
+                    Op::Permute(arg, dims) => {
+                        let mut inv_dims = vec![0; dims.len()];
+                        for (i, &dim_idx) in dims.iter().enumerate() {
+                            inv_dims[dim_idx] = i
+                        }
+                        let arg_grad = grad.permute(inv_dims)?;
+                        let sum_grad = grads.or_insert(arg)?;
+                        *sum_grad = sum_grad.add(&arg_grad)?
+                    }
                 };
             }
         }
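The gradient of a permutation is the incoming gradient permuted by the inverse mapping, which is what the `inv_dims` loop above builds. A minimal standalone sketch of that identity (plain Rust, no candle types):

```rust
// If `dims` maps output axis i to input axis dims[i], the gradient flowing
// back through the permute must be permuted by the inverse mapping.
fn inverse_permutation(dims: &[usize]) -> Vec<usize> {
    let mut inv = vec![0; dims.len()];
    for (i, &d) in dims.iter().enumerate() {
        inv[d] = i;
    }
    inv
}

fn main() {
    let dims = vec![2, 3, 1, 0];
    let inv = inverse_permutation(&dims);
    assert_eq!(inv, vec![3, 2, 0, 1]);
    // Composing `dims` with its inverse restores the original axis order.
    let restored: Vec<usize> = inv.iter().map(|&i| dims[i]).collect();
    assert_eq!(restored, vec![0, 1, 2, 3]);
}
```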
@@ -112,6 +112,31 @@ impl Layout {
         })
     }

+    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
+        let is_permutation =
+            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
+        if !is_permutation {
+            crate::bail!(
+                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
+                self.dims(),
+                idxs
+            )
+        }
+        let stride = self.stride();
+        let dims = self.shape().dims();
+        let mut perm_stride = stride.to_vec();
+        let mut perm_dims = dims.to_vec();
+        for (i, &idx) in idxs.iter().enumerate() {
+            perm_stride[i] = stride[idx];
+            perm_dims[i] = dims[idx];
+        }
+        Ok(Self {
+            shape: Shape::from(perm_dims),
+            stride: perm_stride,
+            start_offset: self.start_offset,
+        })
+    }
+
     pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
         let shape = shape.into();
         if shape.rank() < self.shape().rank() {
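Note that `Layout::permute` never touches the storage: reordering `dims` and `stride` in lockstep is enough, because the element at logical index (i0, i1, …) lives at offset Σ iₖ · strideₖ. A self-contained illustration of this (the `offset` helper is hypothetical, not part of candle):

```rust
// Byte-free offset arithmetic for a strided view: the same storage
// location is reached whether we index the original or the permuted view.
fn offset(index: &[usize], stride: &[usize]) -> usize {
    index.iter().zip(stride).map(|(i, s)| i * s).sum()
}

fn main() {
    // A contiguous (2, 3) layout has strides (3, 1).
    let (dims, stride) = (vec![2usize, 3], vec![3usize, 1]);
    // Permuting with idxs = [1, 0] (a transpose) gives dims (3, 2), strides (1, 3).
    let idxs = [1, 0];
    let perm_dims: Vec<_> = idxs.iter().map(|&i| dims[i]).collect();
    let perm_stride: Vec<_> = idxs.iter().map(|&i| stride[i]).collect();
    assert_eq!((perm_dims, perm_stride.clone()), (vec![3, 2], vec![1, 3]));
    // Element (r, c) of the original is element (c, r) of the view.
    for r in 0..dims[0] {
        for c in 0..dims[1] {
            assert_eq!(offset(&[r, c], &stride), offset(&[c, r], &perm_stride));
        }
    }
}
```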
@@ -117,6 +117,7 @@ pub enum Op {
     Reshape(Tensor),
     ToDevice(Tensor),
     Transpose(Tensor, usize, usize),
+    Permute(Tensor, Vec<usize>),
     Elu(Tensor, f64),
     CustomOp1(Tensor, std::sync::Arc<Box<dyn CustomOp1 + Send + Sync>>),
     CustomOp2(
@@ -345,6 +345,16 @@ impl<D1: Dim, D2: Dim, D3: Dim> Dims for (D1, D2, D3) {
     }
 }

+impl<D1: Dim, D2: Dim, D3: Dim, D4: Dim> Dims for (D1, D2, D3, D4) {
+    fn to_indexes_internal(self, shape: &Shape, op: &'static str) -> Result<Vec<usize>> {
+        let d0 = self.0.to_index(shape, op)?;
+        let d1 = self.1.to_index(shape, op)?;
+        let d2 = self.2.to_index(shape, op)?;
+        let d3 = self.3.to_index(shape, op)?;
+        Ok(vec![d0, d1, d2, d3])
+    }
+}
+
 extract_dims!(dims0, 0, |_: &[usize]| (), ());
 extract_dims!(dims1, 1, |d: &[usize]| d[0], usize);
 extract_dims!(dims2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
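This new 4-tuple impl is what allows calls like `tensor.permute((2, 3, 1, 0))` in the doc-test below. A stripped-down sketch of the same tuple-to-indices pattern (a simplified trait, not candle's actual `Dims`, which also validates each index against the shape):

```rust
// Each tuple element resolves independently to a dimension index,
// so heterogeneous index types can be mixed in one call.
trait Dims {
    fn to_indexes(self) -> Vec<usize>;
}

impl Dims for (usize, usize, usize, usize) {
    fn to_indexes(self) -> Vec<usize> {
        vec![self.0, self.1, self.2, self.3]
    }
}

fn main() {
    assert_eq!((2usize, 3usize, 1usize, 0usize).to_indexes(), vec![2, 3, 1, 0]);
}
```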
@@ -1459,6 +1459,42 @@ impl Tensor {
         Ok(Tensor(Arc::new(tensor_)))
     }

+    /// Returns a tensor with the same data as the input where the dimensions have been permuted.
+    /// dims must be a permutation, i.e. include each dimension index exactly once.
+    ///
+    /// ```rust
+    /// use candle_core::{Tensor, Device};
+    /// let tensor = Tensor::arange(0u32, 120u32, &Device::Cpu)?.reshape((2, 3, 4, 5))?;
+    /// assert_eq!(tensor.dims(), &[2, 3, 4, 5]);
+    /// let tensor = tensor.permute((2, 3, 1, 0))?;
+    /// assert_eq!(tensor.dims(), &[4, 5, 3, 2]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn permute<D: Dims>(&self, dims: D) -> Result<Tensor> {
+        let dims = dims.to_indexes(self.shape(), "permute")?;
+        // O(n^2) permutation check but these arrays are small.
+        let is_permutation =
+            dims.len() == self.rank() && (0..dims.len()).all(|i| dims.contains(&i));
+        if !is_permutation {
+            crate::bail!(
+                "dimension mismatch in permute, tensor {:?}, dims: {:?}",
+                self.dims(),
+                dims
+            )
+        }
+        let op = BackpropOp::new1(self, |t| Op::Permute(t, dims.clone()));
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            layout: self.layout.permute(&dims)?,
+            op,
+            is_variable: false,
+            dtype: self.dtype,
+            device: self.device.clone(),
+        };
+        Ok(Tensor(Arc::new(tensor_)))
+    }
+
     /// Returns true if the data is stored in a C contiguous (aka row major) way.
     pub fn is_contiguous(&self) -> bool {
         self.layout.is_contiguous()
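A quick end-to-end sketch combining the new op with the backward pass above, assuming the crate's `Var`/`backward` API as of this commit (treat the exact calls as illustrative):

```rust
use candle_core::{DType, Device, Result, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Var::ones((2, 3, 4), DType::F32, &dev)?;
    // Permute to shape (4, 2, 3); the data is shared, only the layout changes.
    let y = x.permute((2, 0, 1))?;
    let loss = y.sum_all()?;
    let grads = loss.backward()?;
    // The gradient is permuted back by the inverse mapping, so it
    // matches the shape of `x`.
    let gx = grads.get(&x).expect("no grad for x");
    assert_eq!(gx.dims(), &[2, 3, 4]);
    Ok(())
}
```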
@@ -306,8 +306,7 @@ pub fn main() -> anyhow::Result<()> {

     let device = candle_examples::device(args.cpu)?;

-    // TODO: apply imagenet normalization.
-    let image = candle_examples::load_image(args.image)?;
+    let image = candle_examples::load_image224(args.image)?;
     println!("loaded image {image:?}");

     let weights = unsafe { candle::safetensors::MmapedFile::new(args.model)? };
@@ -13,8 +13,8 @@ pub fn device(cpu: bool) -> Result<Device> {
 }

 /// Loads an image from disk using the image crate, this returns a tensor with shape
-/// (3, 224, 224). imagenet normaliation is applied.
-pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
+/// (3, 224, 224). imagenet normalization is applied.
+pub fn load_image224<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
     let img = image::io::Reader::open(p)?
         .decode()
         .map_err(candle::Error::wrap)?