Add clone to various nn layers. (#910)

Author: Laurent Mazare
Date: 2023-09-20 11:33:51 +01:00
Committed by: GitHub
Parent: f685b2231c
Commit: 7b1ddcff47

7 changed files with 11 additions and 11 deletions
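
Why this change is useful: with `Clone` derived, a layer value can be duplicated with a plain `.clone()` instead of being rebuilt from its underlying tensors. Below is a minimal sketch of the resulting usage, assuming the `candle`/`candle_nn` APIs of this release (`Linear::new`, `Module::forward`, `Tensor::randn`); the shapes and variable names are illustrative only and are not part of this diff.

use candle::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Build a small Linear layer from a random 4x2 weight matrix, no bias.
    let weight = Tensor::randn(0f32, 1f32, (4, 2), &dev)?;
    let layer = Linear::new(weight, None);

    // With #[derive(Clone)] in place, duplicating the layer is a one-liner.
    let layer_copy = layer.clone();

    // The copy behaves like the original.
    let xs = Tensor::randn(0f32, 1f32, (3, 2), &dev)?;
    let ys = layer_copy.forward(&xs)?;
    println!("{ys}");
    Ok(())
}

Since candle tensors are reference-counted internally, such a clone should be cheap, with the copy sharing the same underlying storage as the original (an assumption about the tensor implementation, not something this diff changes).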

@@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct BatchNorm {
     running_mean: Tensor,
     running_var: Tensor,

@@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Conv1d {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -88,7 +88,7 @@ impl Default for Conv2dConfig {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Conv2d {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -157,7 +157,7 @@ impl Default for ConvTranspose2dConfig {
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ConvTranspose2d {
     weight: Tensor,
     bias: Option<Tensor>,

@@ -1,7 +1,7 @@
 //! Embedding Layer.
 use candle::{Result, Tensor};
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Embedding {
     embeddings: Tensor,
     hidden_size: usize,

@@ -4,7 +4,7 @@
 use candle::{DType, Result, Tensor};
 
 // This group norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct GroupNorm {
     weight: Tensor,
     bias: Tensor,

@@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
 }
 
 // This layer norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct LayerNorm {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
 }
 
 /// RmsNorm is a specialized version of the LayerNorm module.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct RmsNorm(LayerNorm);
 
 impl RmsNorm {

@@ -19,7 +19,7 @@
 //! ```
 use candle::{Result, Tensor};
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Linear {
     weight: Tensor,
     bias: Option<Tensor>,

@@ -85,7 +85,7 @@ impl LSTMConfig {
 ///
 /// <https://en.wikipedia.org/wiki/Long_short-term_memory>
 #[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct LSTM {
     w_ih: Tensor,
     w_hh: Tensor,
@@ -235,7 +235,7 @@ impl GRUConfig {
 ///
 /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
 #[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct GRU {
     w_ih: Tensor,
     w_hh: Tensor,