Add the upblocks. (#853)

This commit is contained in:
Laurent Mazare
2023-09-14 23:24:56 +02:00
committed by GitHub
parent 91ec546feb
commit 130fe5a087
4 changed files with 63 additions and 5 deletions

View File

@ -110,7 +110,7 @@ impl ToUsize2 for (usize, usize) {
}
// A simple trait defining a module with forward method using a single argument.
pub trait Module: std::fmt::Debug {
pub trait Module {
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}
@ -119,3 +119,9 @@ impl Module for quantized::QMatMul {
self.forward(xs)
}
}
impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
self(xs)
}
}