Use the module trait in stable-diffusion. (#817)

This commit is contained in:
Laurent Mazare
2023-09-11 20:40:07 +01:00
committed by GitHub
parent 59e63d690c
commit c5a058b169
7 changed files with 30 additions and 31 deletions

View File

@ -17,7 +17,7 @@ impl GeGlu {
}
}
impl GeGlu {
impl Module for GeGlu {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hidden_states_and_gate = self.proj.forward(xs)?.chunk(2, D::Minus1)?;
@ -53,7 +53,7 @@ impl FeedForward {
}
}
impl FeedForward {
impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.project_in.forward(xs)?;
@ -501,8 +501,10 @@ impl AttentionBlock {
xs.reshape((batch, t, self.num_heads, h_times_d / self.num_heads))?
.transpose(1, 2)
}
}
pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
impl Module for AttentionBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let in_dtype = xs.dtype();
let residual = xs;