mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
Add a simple Module trait and implement it for the various nn layers (#500)
* Start adding the module trait. * Use the module trait. * Implement module for qmatmul.
This commit is contained in:
@ -35,8 +35,10 @@ impl Conv1d {
|
||||
pub fn config(&self) -> &Conv1dConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
impl crate::Module for Conv1d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
@ -84,8 +86,10 @@ impl Conv2d {
|
||||
pub fn config(&self) -> &Conv2dConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
impl crate::Module for Conv2d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
|
Reference in New Issue
Block a user