mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Fix the cat implementation + more testing.
This commit is contained in:
@ -817,11 +817,34 @@ impl Tensor {
|
|||||||
shape: args[0].shape().clone(),
|
shape: args[0].shape().clone(),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
if dim == 0 {
|
||||||
|
Self::cat0(args)
|
||||||
|
} else {
|
||||||
|
// TODO: Avoid these transpositions and have an implementation that works
|
||||||
|
// for dim != 0...
|
||||||
|
let args: Vec<Tensor> = args
|
||||||
|
.iter()
|
||||||
|
.map(|a| a.transpose(0, dim))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
let args: Vec<&Tensor> = args.iter().collect();
|
||||||
|
let cat = Self::cat0(&args)?;
|
||||||
|
cat.transpose(0, dim)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cat0(args: &[&Self]) -> Result<Self> {
|
||||||
|
if args.is_empty() {
|
||||||
|
return Err(Error::OpRequiresAtLeastOneTensor { op: "cat" });
|
||||||
|
}
|
||||||
|
if args.len() == 1 {
|
||||||
|
return Ok(args[0].clone());
|
||||||
|
}
|
||||||
|
let rank = args[0].rank();
|
||||||
let device = args[0].device();
|
let device = args[0].device();
|
||||||
let dtype = args[0].dtype();
|
let dtype = args[0].dtype();
|
||||||
let first_dims = args[0].shape().dims();
|
let first_dims = args[0].shape().dims();
|
||||||
let mut cat_dims = first_dims.to_vec();
|
let mut cat_dims = first_dims.to_vec();
|
||||||
cat_dims[dim] = 0;
|
cat_dims[0] = 0;
|
||||||
let mut offsets = vec![0usize];
|
let mut offsets = vec![0usize];
|
||||||
for (arg_idx, arg) in args.iter().enumerate() {
|
for (arg_idx, arg) in args.iter().enumerate() {
|
||||||
if arg.dtype() != dtype {
|
if arg.dtype() != dtype {
|
||||||
@ -848,10 +871,10 @@ impl Tensor {
|
|||||||
.zip(arg.shape().dims().iter())
|
.zip(arg.shape().dims().iter())
|
||||||
.enumerate()
|
.enumerate()
|
||||||
{
|
{
|
||||||
if dim == dim_idx {
|
if dim_idx == 0 {
|
||||||
cat_dims[dim] += v2;
|
cat_dims[0] += v2;
|
||||||
}
|
}
|
||||||
if dim != dim_idx && v1 != v2 {
|
if dim_idx != 0 && v1 != v2 {
|
||||||
// TODO: It would probably be good to have a nicer error message here, i.e.
|
// TODO: It would probably be good to have a nicer error message here, i.e.
|
||||||
// mention the problematic dimension and the values.
|
// mention the problematic dimension and the values.
|
||||||
mismatch = true;
|
mismatch = true;
|
||||||
@ -859,7 +882,7 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
if mismatch {
|
if mismatch {
|
||||||
return Err(Error::ShapeMismatchCat {
|
return Err(Error::ShapeMismatchCat {
|
||||||
dim,
|
dim: 0, // TODO: not the appropriate error message
|
||||||
first_shape: args[0].shape().clone(),
|
first_shape: args[0].shape().clone(),
|
||||||
n: arg_idx + 1,
|
n: arg_idx + 1,
|
||||||
nth_shape: arg.shape().clone(),
|
nth_shape: arg.shape().clone(),
|
||||||
@ -871,7 +894,7 @@ impl Tensor {
|
|||||||
let shape = Shape::from(cat_dims);
|
let shape = Shape::from(cat_dims);
|
||||||
let op = if args.iter().any(|arg| arg.track_op()) {
|
let op = if args.iter().any(|arg| arg.track_op()) {
|
||||||
let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
|
let args: Vec<Tensor> = args.iter().map(|&arg| arg.clone()).collect();
|
||||||
Some(Op::Cat(args, dim))
|
Some(Op::Cat(args, 0))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
@ -210,18 +210,18 @@ fn cat() -> Result<()> {
|
|||||||
.t()?
|
.t()?
|
||||||
.to_vec2::<f32>()?,
|
.to_vec2::<f32>()?,
|
||||||
[
|
[
|
||||||
[3.0, 4.0, 5.0, 5.0, 5.0],
|
[3.0, 1.0, 4.0, 1.0, 5.0],
|
||||||
[2.0, 1.0, 2.0, 7.0, 8.0],
|
[2.0, 7.0, 1.0, 8.0, 2.0],
|
||||||
[1.0, 1.0, 5.0, 5.0, 5.0],
|
[5.0, 5.0, 5.0, 5.0, 5.0],
|
||||||
[7.0, 8.0, 2.0, 1.0, 2.0]
|
[2.0, 7.0, 1.0, 8.0, 2.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
// TODO: This is not the expected answer, to be fixed!
|
// TODO: This is not the expected answer, to be fixed!
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
|
Tensor::cat(&[&t1, &t2], 1)?.to_vec2::<f32>()?,
|
||||||
[
|
[
|
||||||
[3.0, 1.0, 4.0, 1.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0],
|
[3.0, 1.0, 4.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0],
|
||||||
[5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
Reference in New Issue
Block a user