mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add the sequential layer. (#1136)
This commit is contained in:
@ -11,6 +11,7 @@ pub mod loss;
|
|||||||
pub mod ops;
|
pub mod ops;
|
||||||
pub mod optim;
|
pub mod optim;
|
||||||
pub mod rnn;
|
pub mod rnn;
|
||||||
|
pub mod sequential;
|
||||||
pub mod var_builder;
|
pub mod var_builder;
|
||||||
pub mod var_map;
|
pub mod var_map;
|
||||||
|
|
||||||
@ -29,6 +30,7 @@ pub use linear::{linear, linear_no_bias, Linear};
|
|||||||
pub use ops::Dropout;
|
pub use ops::Dropout;
|
||||||
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
|
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
|
||||||
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
|
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
|
||||||
|
pub use sequential::{seq, Sequential};
|
||||||
pub use var_builder::VarBuilder;
|
pub use var_builder::VarBuilder;
|
||||||
pub use var_map::VarMap;
|
pub use var_map::VarMap;
|
||||||
|
|
||||||
|
62
candle-nn/src/sequential.rs
Normal file
62
candle-nn/src/sequential.rs
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
//! A sequential layer used to chain multiple layers and closures.
|
||||||
|
use candle::{Module, Result, Tensor};
|
||||||
|
|
||||||
|
/// A sequential layer combining multiple other layers.
///
/// Layers are type-erased behind `Box<dyn Module>` so heterogeneous layer
/// types (and closures added via `add_fn`) can be chained in a single list.
pub struct Sequential {
    // Sub-layers in application order: the output of layers[i] is fed to layers[i + 1].
    layers: Vec<Box<dyn Module>>,
}
||||||
|
|
||||||
|
/// Creates a new empty sequential layer.
|
||||||
|
pub fn seq() -> Sequential {
|
||||||
|
Sequential { layers: vec![] }
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sequential {
|
||||||
|
/// The number of sub-layers embedded in this layer.
|
||||||
|
pub fn len(&self) -> i64 {
|
||||||
|
self.layers.len() as i64
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this layer does not have any sub-layer.
|
||||||
|
pub fn is_empty(&self) -> bool {
|
||||||
|
self.layers.is_empty()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Sequential {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
xs = layer.forward(&xs)?
|
||||||
|
}
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Sequential {
|
||||||
|
/// Appends a layer after all the current layers.
|
||||||
|
#[allow(clippy::should_implement_trait)]
|
||||||
|
pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
|
||||||
|
self.layers.push(Box::new(layer));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Appends a closure after all the current layers.
|
||||||
|
pub fn add_fn<F>(self, f: F) -> Self
|
||||||
|
where
|
||||||
|
F: 'static + Fn(&Tensor) -> Result<Tensor> + Send,
|
||||||
|
{
|
||||||
|
self.add(super::func(f))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies the forward pass and returns the output for each layer.
|
||||||
|
pub fn forward_all(&self, xs: &Tensor) -> Result<Vec<Tensor>> {
|
||||||
|
let mut vec = Vec::with_capacity(self.layers.len());
|
||||||
|
let mut xs = xs.clone();
|
||||||
|
for layer in self.layers.iter() {
|
||||||
|
xs = layer.forward(&xs)?;
|
||||||
|
vec.push(xs.clone())
|
||||||
|
}
|
||||||
|
Ok(vec)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user