我是 Rust 的新手,我正在研究使用Burn移植一些 Python/Torch 代码用于新的统计参数方法。
第一步:我想生成一个 (10, 1) 张量,其随机值来自具有已知参数的柯西分布。Burn 中的分布非常有限,因此我使用statrs。通过使用statrs
,我可以得到Vec<f64>
,然后我可以将其包装到TensorData
Burn 中的 中,从而生成Tensor
。
我添加了一些类型签名,但 Burn 有Float
而不是特定的f64
,我对此有点困惑。事实上,只是为了调试目的,我想从 Burn 张量中提取数据作为Vec<f64>
来查看它(我应该从中看到相同的值vec: Vec<f64>
),但我得到了运行时类型不兼容。
use rand::prelude::Distribution;
use statrs::distribution::Cauchy;
use rand_chacha::ChaCha8Rng;
use rand_core::SeedableRng;
use burn::tensor::{Tensor, TensorData, Float};
use burn::backend::Wgpu;
type Backend = Wgpu;
fn main() {
// some global refs
let device = Default::default();
let mut rng: ChaCha8Rng = ChaCha8Rng::seed_from_u64(2);
// create random vec using statrs, store in a Vec<f64>
let dist: Cauchy = Cauchy::new(5.0, 2.0).unwrap();
let vec: Vec<f64> = dist.sample_iter(&mut rng).take(10).collect();
// wrap this into a Burn tensor
let td: TensorData = TensorData::new(vec, [10, 1]);
let tensor: Tensor<Backend, 2, Float> = Tensor::<Backend, 2, Float>::from_data(td, &device);
print!("{:?}\n", tensor.to_data().to_vec::<f64>().unwrap());
}
当在上面运行时,我得到
thread 'main' panicked at src/main.rs:23:55:
called `Result::unwrap()` on an `Err` value: TypeMismatch("Invalid target element type
(expected F32, got F64)")
使用to_vec::<f32>
有效,但我希望 Burn 张量具有 f64 值(torch 有这个),因为错误似乎意味着我在某个时候失去了精度——不太好。
是否可以存储f64
在 Burn 张量中?