mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Use arange in the examples. (#146)
This commit is contained in:
@ -217,9 +217,10 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let arange: Vec<_> = (0..length).map(|c| c as f32).collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::new(arange.as_slice(), &Device::Cpu)?.unsqueeze(1)?;
|
||||
let arange = Tensor::arange(0, length as u32, &Device::Cpu)?
|
||||
.to_dtype(candle::DType::F32)?
|
||||
.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
|
||||
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
|
||||
|
Reference in New Issue
Block a user