More generic sampling.

This commit is contained in:
Laurent
2024-09-25 11:15:37 +02:00
parent fa1e0e438e
commit 0bd61bae29
4 changed files with 24 additions and 4 deletions

View File

@ -1,3 +1,19 @@
use candle::{Result, Tensor};
pub trait WithForward {
#[allow(clippy::too_many_arguments)]
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,
txt_ids: &Tensor,
timesteps: &Tensor,
y: &Tensor,
guidance: Option<&Tensor>,
) -> Result<Tensor>;
}
pub mod autoencoder;
pub mod model;
pub mod quantized_model;

View File

@ -575,9 +575,11 @@ impl Flux {
final_layer,
})
}
}
impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
pub fn forward(
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,

View File

@ -577,9 +577,11 @@ impl Flux {
final_layer,
})
}
}
impl super::WithForward for Flux {
#[allow(clippy::too_many_arguments)]
pub fn forward(
fn forward(
&self,
img: &Tensor,
img_ids: &Tensor,

View File

@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
}
#[allow(clippy::too_many_arguments)]
pub fn denoise(
model: &super::model::Flux,
pub fn denoise<M: super::WithForward>(
model: &M,
img: &Tensor,
img_ids: &Tensor,
txt: &Tensor,