From 59e63d690c95e58526bd41144823c6e63f9f2916 Mon Sep 17 00:00:00 2001 From: Eric Buehler <65165915+EricLBuehler@users.noreply.github.com> Date: Mon, 11 Sep 2023 11:01:11 -0400 Subject: [PATCH] Add weight, bias, and hidden_size methods (#816) * Add weight, bias methods to Conv(1|2) * Add hidden_size method to Embedding * Expose hidden_size --- candle-nn/src/conv.rs | 16 ++++++++++++++++ candle-nn/src/embedding.rs | 5 +++++ 2 files changed, 21 insertions(+) diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index cfe86bfa..309a5f37 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -39,6 +39,14 @@ impl Conv1d { pub fn config(&self) -> &Conv1dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for Conv1d { @@ -99,6 +107,14 @@ impl Conv2d { pub fn config(&self) -> &Conv2dConfig { &self.config } + + pub fn weight(&self) -> &Tensor { + &self.weight + } + + pub fn bias(&self) -> Option<&Tensor> { + self.bias.as_ref() + } } impl crate::Module for Conv2d { diff --git a/candle-nn/src/embedding.rs b/candle-nn/src/embedding.rs index d84f9f53..fccc8a17 100644 --- a/candle-nn/src/embedding.rs +++ b/candle-nn/src/embedding.rs @@ -18,6 +18,11 @@ impl Embedding { pub fn embeddings(&self) -> &Tensor { &self.embeddings } + + /// Get the hidden size of the embedding matrix + pub fn hidden_size(&self) -> usize { + self.hidden_size + } } impl crate::Module for Embedding {