mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Improve the error message on shape mismatch for cat. (#897)
* Improve the error message on shape mismatch for cat. * Cosmetic tweak.
This commit is contained in:
@ -1907,6 +1907,34 @@ impl Tensor {
|
|||||||
for arg in args {
|
for arg in args {
|
||||||
arg.as_ref().check_dim(dim, "cat")?;
|
arg.as_ref().check_dim(dim, "cat")?;
|
||||||
}
|
}
|
||||||
|
for (arg_idx, arg) in args.iter().enumerate() {
|
||||||
|
let arg = arg.as_ref();
|
||||||
|
if arg0.rank() != arg.rank() {
|
||||||
|
Err(Error::UnexpectedNumberOfDims {
|
||||||
|
expected: arg0.rank(),
|
||||||
|
got: arg.rank(),
|
||||||
|
shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
for (dim_idx, (v1, v2)) in arg0
|
||||||
|
.shape()
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.zip(arg.shape().dims().iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
if dim_idx != dim && v1 != v2 {
|
||||||
|
Err(Error::ShapeMismatchCat {
|
||||||
|
dim: dim_idx,
|
||||||
|
first_shape: arg0.shape().clone(),
|
||||||
|
n: arg_idx + 1,
|
||||||
|
nth_shape: arg.shape().clone(),
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if dim == 0 {
|
if dim == 0 {
|
||||||
Self::cat0(args)
|
Self::cat0(args)
|
||||||
} else {
|
} else {
|
||||||
|
@ -295,12 +295,10 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
let prior_scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||||
let timesteps = prior_scheduler.timesteps();
|
let timesteps = prior_scheduler.timesteps();
|
||||||
|
let timesteps = ×teps[..timesteps.len() - 1];
|
||||||
println!("prior denoising");
|
println!("prior denoising");
|
||||||
for (index, &t) in timesteps.iter().enumerate() {
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
if index == timesteps.len() - 1 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
let latent_model_input = Tensor::cat(&[&latents, &latents], 0)?;
|
||||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
||||||
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
let noise_pred = prior.forward(&latent_model_input, &ratio, &prior_text_embeddings)?;
|
||||||
@ -358,11 +356,9 @@ fn run(args: Args) -> Result<()> {
|
|||||||
println!("diffusion process with prior {image_embeddings:?}");
|
println!("diffusion process with prior {image_embeddings:?}");
|
||||||
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||||
let timesteps = scheduler.timesteps();
|
let timesteps = scheduler.timesteps();
|
||||||
|
let timesteps = ×teps[..timesteps.len() - 1];
|
||||||
for (index, &t) in timesteps.iter().enumerate() {
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
if index == timesteps.len() - 1 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
|
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
|
||||||
let noise_pred =
|
let noise_pred =
|
||||||
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
|
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
|
||||||
|
@ -41,6 +41,7 @@ impl ResBlockStageB {
|
|||||||
};
|
};
|
||||||
let xs = xs
|
let xs = xs
|
||||||
.permute((0, 2, 3, 1))?
|
.permute((0, 2, 3, 1))?
|
||||||
|
.contiguous()?
|
||||||
.apply(&self.channelwise_lin1)?
|
.apply(&self.channelwise_lin1)?
|
||||||
.gelu()?
|
.gelu()?
|
||||||
.apply(&self.channelwise_grn)?
|
.apply(&self.channelwise_grn)?
|
||||||
@ -374,7 +375,7 @@ impl WDiffNeXt {
|
|||||||
.apply(&self.clf_ln)?
|
.apply(&self.clf_ln)?
|
||||||
.apply(&self.clf_conv)?
|
.apply(&self.clf_conv)?
|
||||||
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
|
.apply(&|xs: &_| candle_nn::ops::pixel_shuffle(xs, self.patch_size))?
|
||||||
.chunk(1, 2)?;
|
.chunk(2, 1)?;
|
||||||
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
|
let b = ((candle_nn::ops::sigmoid(&ab[1])? * (1. - EPS * 2.))? + EPS)?;
|
||||||
(x_in - &ab[0])? / b
|
(x_in - &ab[0])? / b
|
||||||
}
|
}
|
||||||
|
@ -52,6 +52,7 @@ impl Module for MixingResidualBlock {
|
|||||||
.affine(1. + mods[3] as f64, mods[4] as f64)?;
|
.affine(1. + mods[3] as f64, mods[4] as f64)?;
|
||||||
let x_temp = x_temp
|
let x_temp = x_temp
|
||||||
.permute((0, 2, 3, 1))?
|
.permute((0, 2, 3, 1))?
|
||||||
|
.contiguous()?
|
||||||
.apply(&self.channelwise_lin1)?
|
.apply(&self.channelwise_lin1)?
|
||||||
.gelu()?
|
.gelu()?
|
||||||
.apply(&self.channelwise_lin2)?
|
.apply(&self.channelwise_lin2)?
|
||||||
|
Reference in New Issue
Block a user