Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Add clone to various nn layers. (#910)
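The diff below is mechanical: each layer struct previously derived only Debug and now also derives Clone, which compiles because every field involved (Tensor, Option<Tensor>, usize, ...) already implements Clone. A candle Tensor is a reference-counted handle, so cloning a layer is cheap and the copy shares the underlying weight storage. As a minimal sketch of what this enables (not part of the commit; the (4, 3) weight shape and the use of Linear::new are illustrative, assuming the candle / candle-nn APIs of this period):

use candle::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Build a Linear layer by hand: weight is (out_features, in_features).
    let weight = Tensor::randn(0f32, 1f32, (4, 3), &dev)?;
    let layer = Linear::new(weight, None);

    // With this commit, layers can be duplicated directly. The clone is
    // shallow: both handles point at the same underlying weight data.
    let layer_copy = layer.clone();

    let xs = Tensor::randn(0f32, 1f32, (2, 3), &dev)?;
    let ys = layer_copy.forward(&xs)?; // shape (2, 4)
    println!("{:?}", ys.shape());
    Ok(())
}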
candle-nn/src/batch_norm.rs
@@ -38,7 +38,7 @@ impl From<f64> for BatchNormConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct BatchNorm {
     running_mean: Tensor,
     running_var: Tensor,
candle-nn/src/conv.rs
@@ -20,7 +20,7 @@ impl Default for Conv1dConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Conv1d {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -88,7 +88,7 @@ impl Default for Conv2dConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Conv2d {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -157,7 +157,7 @@ impl Default for ConvTranspose2dConfig {
     }
 }

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ConvTranspose2d {
     weight: Tensor,
     bias: Option<Tensor>,
candle-nn/src/embedding.rs
@@ -1,7 +1,7 @@
 //! Embedding Layer.
 use candle::{Result, Tensor};

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Embedding {
     embeddings: Tensor,
     hidden_size: usize,
candle-nn/src/group_norm.rs
@@ -4,7 +4,7 @@
 use candle::{DType, Result, Tensor};

 // This group norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct GroupNorm {
     weight: Tensor,
     bias: Tensor,
candle-nn/src/layer_norm.rs
@@ -60,7 +60,7 @@ impl From<f64> for LayerNormConfig {
 }

 // This layer norm version handles both weight and bias so removes the mean.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct LayerNorm {
     weight: Tensor,
     bias: Option<Tensor>,
@@ -143,7 +143,7 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
 }

 /// RmsNorm is a specialized version of the LayerNorm module.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct RmsNorm(LayerNorm);

 impl RmsNorm {
candle-nn/src/linear.rs
@@ -19,7 +19,7 @@
 //! ```
 use candle::{Result, Tensor};

-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Linear {
     weight: Tensor,
     bias: Option<Tensor>,
candle-nn/src/rnn.rs
@@ -85,7 +85,7 @@ impl LSTMConfig {
 ///
 /// <https://en.wikipedia.org/wiki/Long_short-term_memory>
 #[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct LSTM {
     w_ih: Tensor,
     w_hh: Tensor,
@@ -235,7 +235,7 @@ impl GRUConfig {
 ///
 /// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
 #[allow(clippy::upper_case_acronyms, unused)]
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct GRU {
     w_ih: Tensor,
     w_hh: Tensor,
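For reference, what #[derive(Clone)] generates is just a member-wise clone. A hand-written equivalent for Linear would look like the sketch below (hypothetical code, not part of the commit; it would have to live inside candle-nn's linear module since the fields are private):

// Hand-written equivalent of #[derive(Clone)] for Linear: each field is
// cloned in turn. Tensor's clone bumps a reference count rather than
// copying the tensor data, so this is cheap.
impl Clone for Linear {
    fn clone(&self) -> Self {
        Self {
            weight: self.weight.clone(),
            bias: self.bias.clone(),
        }
    }
}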