Add a simple Module trait and implement it for the various nn layers (#500)

* Start adding the module trait.

* Use the module trait.

* Implement module for qmatmul.
This commit is contained in:
Laurent Mazare
2023-08-18 09:38:22 +01:00
committed by GitHub
parent 13401df4d1
commit c78ce76501
33 changed files with 70 additions and 28 deletions

View File

@ -35,8 +35,10 @@ impl Conv1d {
pub fn config(&self) -> &Conv1dConfig {
&self.config
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),
@ -84,8 +86,10 @@ impl Conv2d {
pub fn config(&self) -> &Conv2dConfig {
&self.config
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
match &self.bias {
None => Ok(x),