diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index cfe86bfa..309a5f37 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -39,6 +39,14 @@ impl Conv1d { pub fn config(&self) -> &Conv1dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for Conv1d { @@ -99,6 +107,14 @@ impl Conv2d { pub fn config(&self) -> &Conv2dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for Conv2d { diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index d84f9f53..fccc8a17 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -18,6 +18,11 @@ impl Embedding { pub fn embeddings(&self) -> &Tensor { &self.embeddings } + + /// Get the hidden size of the embedding matrix + pub fn hidden_size(&self) -> usize { + self.hidden_size + } } impl crate::Module for Embedding {