mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add clone to various nn layers. (#910)
This commit is contained in:
@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct BatchNorm {
|
pub struct BatchNorm {
|
||||||
running_mean: Tensor,
|
running_mean: Tensor,
|
||||||
running_var: Tensor,
|
running_var: Tensor,
|
||||||
|
@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Conv1d {
|
pub struct Conv1d {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
bias: Option<Tensor>,
|
bias: Option<Tensor>,
|
||||||
@ -88,7 +88,7 @@ impl Default for Conv2dConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Conv2d {
|
pub struct Conv2d {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
bias: Option<Tensor>,
|
bias: Option<Tensor>,
|
||||||
@ -157,7 +157,7 @@ impl Default for ConvTranspose2dConfig {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ConvTranspose2d {
|
pub struct ConvTranspose2d {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
bias: Option<Tensor>,
|
bias: Option<Tensor>,
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
//! Embedding Layer.
|
//! Embedding Layer.
|
||||||
use candle::{Result, Tensor};
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Embedding {
|
pub struct Embedding {
|
||||||
embeddings: Tensor,
|
embeddings: Tensor,
|
||||||
hidden_size: usize,
|
hidden_size: usize,
|
||||||
|
@ -4,7 +4,7 @@
|
|||||||
use candle::{DType, Result, Tensor};
|
use candle::{DType, Result, Tensor};
|
||||||
|
|
||||||
// This group norm version handles both weight and bias so removes the mean.
|
// This group norm version handles both weight and bias so removes the mean.
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct GroupNorm {
|
pub struct GroupNorm {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
bias: Tensor,
|
bias: Tensor,
|
||||||
|
@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// This layer norm version handles both weight and bias so removes the mean.
|
// This layer norm version handles both weight and bias so removes the mean.
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct LayerNorm {
|
pub struct LayerNorm {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
bias: Option<Tensor>,
|
bias: Option<Tensor>,
|
||||||
@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// RmsNorm is a specialized version of the LayerNorm module.
|
/// RmsNorm is a specialized version of the LayerNorm module.
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct RmsNorm(LayerNorm);
|
pub struct RmsNorm(LayerNorm);
|
||||||
|
|
||||||
impl RmsNorm {
|
impl RmsNorm {
|
||||||
|
@ -19,7 +19,7 @@
|
|||||||
//! ```
|
//! ```
|
||||||
use candle::{Result, Tensor};
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Linear {
|
pub struct Linear {
|
||||||
weight: Tensor,
|
weight: Tensor,
|
||||||
bias: Option<Tensor>,
|
bias: Option<Tensor>,
|
||||||
|
@ -85,7 +85,7 @@ impl LSTMConfig {
|
|||||||
///
|
///
|
||||||
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
|
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
|
||||||
#[allow(clippy::upper_case_acronyms, unused)]
|
#[allow(clippy::upper_case_acronyms, unused)]
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct LSTM {
|
pub struct LSTM {
|
||||||
w_ih: Tensor,
|
w_ih: Tensor,
|
||||||
w_hh: Tensor,
|
w_hh: Tensor,
|
||||||
@ -235,7 +235,7 @@ impl GRUConfig {
|
|||||||
///
|
///
|
||||||
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
|
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
|
||||||
#[allow(clippy::upper_case_acronyms, unused)]
|
#[allow(clippy::upper_case_acronyms, unused)]
|
||||||
#[derive(Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct GRU {
|
pub struct GRU {
|
||||||
w_ih: Tensor,
|
w_ih: Tensor,
|
||||||
w_hh: Tensor,
|
w_hh: Tensor,
|
||||||
|
Reference in New Issue
Block a user