Track the conv2d operations in stable-diffusion. (#431)

* Track the conv2d operations in stable-diffusion.

* Add more tracing to stable-diffusion.

* Also trace the resnet bits.

* Trace the attention blocks.

* Also trace the attention inner part.

* Small tweak.
This commit is contained in:
Laurent Mazare
2023-08-13 16:58:26 +02:00
committed by GitHub
parent b1ff78f762
commit 9af438ac1b
7 changed files with 146 additions and 25 deletions

View File

@ -40,6 +40,10 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// The height in pixels of the generated image.
#[arg(long)]
height: Option<usize>,
@ -183,6 +187,9 @@ fn output_filename(
}
fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let Args {
prompt,
uncond_prompt,
@ -198,8 +205,18 @@ fn run(args: Args) -> Result<()> {
clip_weights,
vae_weights,
unet_weights,
tracing,
..
} = args;
let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let sd_config = match sd_version {
StableDiffusionVersion::V1_5 => {
stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width)