mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Use the module trait in stable-diffusion. (#817)
This commit is contained in:
@ -17,7 +17,7 @@ impl GeGlu {
|
||||
}
|
||||
}
|
||||
|
||||
impl GeGlu {
|
||||
impl Module for GeGlu {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
|
||||
@ -53,7 +53,7 @@ impl FeedForward {
|
||||
}
|
||||
}
|
||||
|
||||
impl FeedForward {
|
||||
impl Module for FeedForward {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let xs = self.project_in.forward(xs)?;
|
||||
@ -501,8 +501,10 @@ impl AttentionBlock {
|
||||
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
|
||||
.transpose(1, 2)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
impl Module for AttentionBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let in_dtype = xs.dtype();
|
||||
let residual = xs;
|
||||
|
Reference in New Issue
Block a user