mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
More shape fixes.
This commit is contained in:
@ -16,6 +16,7 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
const PRIOR_GUIDANCE_SCALE: f64 = 8.0;
|
||||||
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
const RESOLUTION_MULTIPLE: f64 = 42.67;
|
||||||
|
const LATENT_DIM_SCALE: f64 = 10.67;
|
||||||
const PRIOR_CIN: usize = 16;
|
const PRIOR_CIN: usize = 16;
|
||||||
const DECODER_CIN: usize = 4;
|
const DECODER_CIN: usize = 4;
|
||||||
|
|
||||||
@ -270,11 +271,11 @@ fn run(args: Args) -> Result<()> {
|
|||||||
println!("generated text embeddings {text_embeddings:?}");
|
println!("generated text embeddings {text_embeddings:?}");
|
||||||
|
|
||||||
println!("Building the prior.");
|
println!("Building the prior.");
|
||||||
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
|
||||||
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
|
||||||
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
|
||||||
let b_size = 1;
|
let b_size = 1;
|
||||||
let effnet = {
|
let image_embeddings = {
|
||||||
|
// https://huggingface.co/warp-ai/wuerstchen-prior/blob/main/prior/config.json
|
||||||
|
let latent_height = (height as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||||
|
let latent_width = (width as f64 / RESOLUTION_MULTIPLE).ceil() as usize;
|
||||||
let mut latents = Tensor::randn(
|
let mut latents = Tensor::randn(
|
||||||
0f32,
|
0f32,
|
||||||
1f32,
|
1f32,
|
||||||
@ -296,6 +297,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
let timesteps = prior_scheduler.timesteps();
|
let timesteps = prior_scheduler.timesteps();
|
||||||
println!("prior denoising");
|
println!("prior denoising");
|
||||||
for (index, &t) in timesteps.iter().enumerate() {
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
|
continue;
|
||||||
let start_time = std::time::Instant::now();
|
let start_time = std::time::Instant::now();
|
||||||
if index == timesteps.len() - 1 {
|
if index == timesteps.len() - 1 {
|
||||||
continue;
|
continue;
|
||||||
@ -343,6 +345,10 @@ fn run(args: Args) -> Result<()> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
for idx in 0..num_samples {
|
for idx in 0..num_samples {
|
||||||
|
// https://huggingface.co/warp-ai/wuerstchen/blob/main/model_index.json
|
||||||
|
let latent_height = (image_embeddings.dim(2)? as f64 * LATENT_DIM_SCALE) as usize;
|
||||||
|
let latent_width = (image_embeddings.dim(3)? as f64 * LATENT_DIM_SCALE) as usize;
|
||||||
|
|
||||||
let mut latents = Tensor::randn(
|
let mut latents = Tensor::randn(
|
||||||
0f32,
|
0f32,
|
||||||
1f32,
|
1f32,
|
||||||
@ -350,7 +356,7 @@ fn run(args: Args) -> Result<()> {
|
|||||||
&device,
|
&device,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
println!("diffusion process with prior {effnet:?}");
|
println!("diffusion process with prior {image_embeddings:?}");
|
||||||
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
let scheduler = wuerstchen::ddpm::DDPMWScheduler::new(60, Default::default())?;
|
||||||
let timesteps = scheduler.timesteps();
|
let timesteps = scheduler.timesteps();
|
||||||
for (index, &t) in timesteps.iter().enumerate() {
|
for (index, &t) in timesteps.iter().enumerate() {
|
||||||
@ -358,8 +364,9 @@ fn run(args: Args) -> Result<()> {
|
|||||||
if index == timesteps.len() - 1 {
|
if index == timesteps.len() - 1 {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let ratio = (Tensor::ones(2, DType::F32, &device)? * t)?;
|
let ratio = (Tensor::ones(1, DType::F32, &device)? * t)?;
|
||||||
let noise_pred = decoder.forward(&latents, &ratio, &effnet, Some(&text_embeddings))?;
|
let noise_pred =
|
||||||
|
decoder.forward(&latents, &ratio, &image_embeddings, Some(&text_embeddings))?;
|
||||||
latents = scheduler.step(&noise_pred, t, &latents)?;
|
latents = scheduler.step(&noise_pred, t, &latents)?;
|
||||||
let dt = start_time.elapsed().as_secs_f32();
|
let dt = start_time.elapsed().as_secs_f32();
|
||||||
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
println!("step {}/{} done, {:.2}s", index + 1, timesteps.len(), dt);
|
||||||
|
@ -335,6 +335,7 @@ impl WDiffNeXt {
|
|||||||
level_outputs.push(xs.clone())
|
level_outputs.push(xs.clone())
|
||||||
}
|
}
|
||||||
level_outputs.reverse();
|
level_outputs.reverse();
|
||||||
|
let mut xs = level_outputs[0].clone();
|
||||||
|
|
||||||
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
for (i, up_block) in self.up_blocks.iter().enumerate() {
|
||||||
let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
|
let effnet_c = match &self.effnet_mappers[self.down_blocks.len() + i] {
|
||||||
|
Reference in New Issue
Block a user