Fixing examples.

This commit is contained in:
Nicolas Patry
2023-08-01 15:04:41 +02:00
parent 82464166e4
commit 45642a8530

View File

@ -58,20 +58,20 @@ Now that we have our weights, we can use them in our bert architecture:
# extern crate candle_nn; # extern crate candle_nn;
# extern crate hf_hub; # extern crate hf_hub;
# use hf_hub::api::sync::Api; # use hf_hub::api::sync::Api;
# use candle::Device;
# #
# let api = Api::new().unwrap(); # let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string()); # let repo = api.model("bert-base-uncased".to_string());
# #
# let weights = repo.get("model.safetensors").unwrap(); # let weights = repo.get("model.safetensors").unwrap();
use candle::{Device, Tensor, DType};
use candle_nn::Linear; use candle_nn::Linear;
let weights = candle::safetensors::load(weights, &Device::Cpu); let weights = candle::safetensors::load(weights, &Device::Cpu).unwrap();
let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap(); let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap(); let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();
let linear = Linear::new(weight, Some(bias)); let linear = Linear::new(weight.clone(), Some(bias.clone()));
let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap(); let input_ids = Tensor::zeros((3, 7680), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids); let output = linear.forward(&input_ids);