From 0bd61bae293c0c3cbbf8040b802cd02440fd2c6d Mon Sep 17 00:00:00 2001 From: Laurent Date: Wed, 25 Sep 2024 11:15:37 +0200 Subject: [PATCH] More generic sampling. --- candle-transformers/src/models/flux/mod.rs | 16 ++++++++++++++++ candle-transformers/src/models/flux/model.rs | 4 +++- .../src/models/flux/quantized_model.rs | 4 +++- candle-transformers/src/models/flux/sampling.rs | 4 ++-- 4 files changed, 24 insertions(+), 4 deletions(-) diff --git a/candle-transformers/src/models/flux/mod.rs b/candle-transformers/src/models/flux/mod.rs index 340d23c2..b0c8a693 100644 --- a/candle-transformers/src/models/flux/mod.rs +++ b/candle-transformers/src/models/flux/mod.rs @@ -1,3 +1,19 @@ +use candle::{Result, Tensor}; + +pub trait WithForward { + #[allow(clippy::too_many_arguments)] + fn forward( + &self, + img: &Tensor, + img_ids: &Tensor, + txt: &Tensor, + txt_ids: &Tensor, + timesteps: &Tensor, + y: &Tensor, + guidance: Option<&Tensor>, + ) -> Result; +} + pub mod autoencoder; pub mod model; pub mod quantized_model; diff --git a/candle-transformers/src/models/flux/model.rs b/candle-transformers/src/models/flux/model.rs index 4e47873f..02835be5 100644 --- a/candle-transformers/src/models/flux/model.rs +++ b/candle-transformers/src/models/flux/model.rs @@ -575,9 +575,11 @@ impl Flux { final_layer, }) } +} +impl super::WithForward for Flux { #[allow(clippy::too_many_arguments)] - pub fn forward( + fn forward( &self, img: &Tensor, img_ids: &Tensor, diff --git a/candle-transformers/src/models/flux/quantized_model.rs b/candle-transformers/src/models/flux/quantized_model.rs index aba9ad13..366182eb 100644 --- a/candle-transformers/src/models/flux/quantized_model.rs +++ b/candle-transformers/src/models/flux/quantized_model.rs @@ -577,9 +577,11 @@ impl Flux { final_layer, }) } +} +impl super::WithForward for Flux { #[allow(clippy::too_many_arguments)] - pub fn forward( + fn forward( &self, img: &Tensor, img_ids: &Tensor, diff --git a/candle-transformers/src/models/flux/sampling.rs b/candle-transformers/src/models/flux/sampling.rs index 89b9a953..f3f0eafd 100644 --- a/candle-transformers/src/models/flux/sampling.rs +++ b/candle-transformers/src/models/flux/sampling.rs @@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result { } #[allow(clippy::too_many_arguments)] -pub fn denoise( - model: &super::model::Flux, +pub fn denoise( + model: &M, img: &Tensor, img_ids: &Tensor, txt: &Tensor,