Add weight, bias, and hidden_size methods (#816)

* Add weight, bias methods to Conv(1|2)

* Add hidden_size method to Embedding

* Expose hidden_size
This commit is contained in:
Eric Buehler
2023-09-11 11:01:11 -04:00
committed by GitHub
parent dbd4561416
commit 59e63d690c
2 changed files with 21 additions and 0 deletions

View File

@ -39,6 +39,14 @@ impl Conv1d {
pub fn config(&self) -> &Conv1dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::Module for Conv1d {
@ -99,6 +107,14 @@ impl Conv2d {
pub fn config(&self) -> &Conv2dConfig {
&self.config
}
pub fn weight(&self) -> &Tensor {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
}
impl crate::Module for Conv2d {

View File

@ -18,6 +18,11 @@ impl Embedding {
pub fn embeddings(&self) -> &Tensor {
&self.embeddings
}
/// Get the hidden size of the embedding matrix
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
}
impl crate::Module for Embedding {