mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More generic sampling.
This commit is contained in:
@ -1,3 +1,19 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
pub trait WithForward {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
txt: &Tensor,
|
||||
txt_ids: &Tensor,
|
||||
timesteps: &Tensor,
|
||||
y: &Tensor,
|
||||
guidance: Option<&Tensor>,
|
||||
) -> Result<Tensor>;
|
||||
}
|
||||
|
||||
pub mod autoencoder;
|
||||
pub mod model;
|
||||
pub mod quantized_model;
|
||||
|
@ -575,9 +575,11 @@ impl Flux {
|
||||
final_layer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl super::WithForward for Flux {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
|
@ -577,9 +577,11 @@ impl Flux {
|
||||
final_layer,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl super::WithForward for Flux {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn forward(
|
||||
fn forward(
|
||||
&self,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
|
@ -92,8 +92,8 @@ pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn denoise(
|
||||
model: &super::model::Flux,
|
||||
pub fn denoise<M: super::WithForward>(
|
||||
model: &M,
|
||||
img: &Tensor,
|
||||
img_ids: &Tensor,
|
||||
txt: &Tensor,
|
||||
|
Reference in New Issue
Block a user