mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
VarBuilder cleanup (#627)
* VarBuilder cleanup. * Implement the basic varbuilders. * Add the sharded code. * Proper support for tensor sharding.
This commit is contained in:
@ -14,8 +14,8 @@ const IMAGE_DIM: usize = 784;
|
||||
const LABELS: usize = 10;
|
||||
|
||||
fn linear_z(in_dim: usize, out_dim: usize, vs: VarBuilder) -> Result<Linear> {
|
||||
let ws = vs.get_or_init((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
|
||||
let bs = vs.get_or_init(out_dim, "bias", candle_nn::init::ZERO)?;
|
||||
let ws = vs.get_with_hints((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
|
||||
let bs = vs.get_with_hints(out_dim, "bias", candle_nn::init::ZERO)?;
|
||||
Ok(Linear::new(ws, Some(bs)))
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user