Add clone to various nn layers. (#910)

This commit is contained in:
Laurent Mazare
2023-09-20 11:33:51 +01:00
committed by GitHub
parent f685b2231c
commit 7b1ddcff47
7 changed files with 11 additions and 11 deletions

View File

@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
} }
} }
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct BatchNorm { pub struct BatchNorm {
running_mean: Tensor, running_mean: Tensor,
running_var: Tensor, running_var: Tensor,

View File

@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
} }
} }
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct Conv1d { pub struct Conv1d {
weight: Tensor, weight: Tensor,
bias: Option<Tensor>, bias: Option<Tensor>,
@ -88,7 +88,7 @@ impl Default for Conv2dConfig {
} }
} }
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct Conv2d { pub struct Conv2d {
weight: Tensor, weight: Tensor,
bias: Option<Tensor>, bias: Option<Tensor>,
@ -157,7 +157,7 @@ impl Default for ConvTranspose2dConfig {
} }
} }
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct ConvTranspose2d { pub struct ConvTranspose2d {
weight: Tensor, weight: Tensor,
bias: Option<Tensor>, bias: Option<Tensor>,

View File

@ -1,7 +1,7 @@
//! Embedding Layer. //! Embedding Layer.
use candle::{Result, Tensor}; use candle::{Result, Tensor};
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct Embedding { pub struct Embedding {
embeddings: Tensor, embeddings: Tensor,
hidden_size: usize, hidden_size: usize,

View File

@ -4,7 +4,7 @@
use candle::{DType, Result, Tensor}; use candle::{DType, Result, Tensor};
// This group norm version handles both weight and bias so removes the mean. // This group norm version handles both weight and bias so removes the mean.
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct GroupNorm { pub struct GroupNorm {
weight: Tensor, weight: Tensor,
bias: Tensor, bias: Tensor,

View File

@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
} }
// This layer norm version handles both weight and bias so removes the mean. // This layer norm version handles both weight and bias so removes the mean.
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct LayerNorm { pub struct LayerNorm {
weight: Tensor, weight: Tensor,
bias: Option<Tensor>, bias: Option<Tensor>,
@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
} }
/// RmsNorm is a specialized version of the LayerNorm module. /// RmsNorm is a specialized version of the LayerNorm module.
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm); pub struct RmsNorm(LayerNorm);
impl RmsNorm { impl RmsNorm {

View File

@ -19,7 +19,7 @@
//! ``` //! ```
use candle::{Result, Tensor}; use candle::{Result, Tensor};
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct Linear { pub struct Linear {
weight: Tensor, weight: Tensor,
bias: Option<Tensor>, bias: Option<Tensor>,

View File

@ -85,7 +85,7 @@ impl LSTMConfig {
/// ///
/// <https://en.wikipedia.org/wiki/Long_short-term_memory> /// <https://en.wikipedia.org/wiki/Long_short-term_memory>
#[allow(clippy::upper_case_acronyms, unused)] #[allow(clippy::upper_case_acronyms, unused)]
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct LSTM { pub struct LSTM {
w_ih: Tensor, w_ih: Tensor,
w_hh: Tensor, w_hh: Tensor,
@ -235,7 +235,7 @@ impl GRUConfig {
/// ///
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit> /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
#[allow(clippy::upper_case_acronyms, unused)] #[allow(clippy::upper_case_acronyms, unused)]
#[derive(Debug)] #[derive(Clone, Debug)]
pub struct GRU { pub struct GRU {
w_ih: Tensor, w_ih: Tensor,
w_hh: Tensor, w_hh: Tensor,