Então, meu problema é que tenho uma característica de camada com tipos de entrada e saída como segue:
pub trait Layer {
type Input: Dimension;
type Output: Dimension;
fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}
Com esta função de avanço:
impl<A: Activation> Layer for DenseLayer<A> {
type Input = Ix2;
type Output = Ix2;
fn forward(&mut self, input: &Array2<f32>) -> Array2<f32> {
assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
let z = input.dot(&self.weights) + &self.biases;
self.activation.activate(&z)
}
}
Eu tenho isso para que minhas funções forward ou backward possam receber, por exemplo, um array de 2 dimensões, mas ainda produzir um com apenas 1 dimensão. Então eu tenho uma implementação para um tipo de wrapper dessa característica de camada onde eu quero encaminhar por todas as camadas:
pub struct NeuralNetwork<'a, L>
where
L: Layer + 'a,
{
layers: Vec<L>,
loss_function: &'a dyn Cost,
}
impl<'a, L> NeuralNetwork<'a, L>
where
L: Layer + 'a,
{
pub fn new(layers: Vec<L>, loss_function: &'a dyn Cost) -> Self {
NeuralNetwork { layers, loss_function }
}
pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
let mut output = input.clone();
// todo fix the layer forward changing input to output
// causing mismatch in the input and output dimensions of forward
for layer in &mut self.layers {
output = layer.forward(&output);
}
output
}
}
Agora, porque no loop for eu primeiro insiro o tipo input, então recebo a saída de layer.forward. Na próxima iteração, ele pega o tipo output, mas o layer.forward aceita apenas o tipo input. Pelo menos é isso que eu acho que está acontecendo. Isso pode parecer um problema muito simples, mas estou genuinamente inseguro sobre como consertar isso.
Edição 1:
Exemplo reproduzível:
use ndarray::{Array, Array2, ArrayBase, Dimension, OwnedRepr};
pub trait Layer {
type Input: Dimension;
type Output: Dimension;
fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output>;
}
// A Dense Layer struct
pub struct DenseLayer {
weights: Array2<f32>,
biases: Array2<f32>,
}
impl DenseLayer {
pub fn new(input_size: usize, output_size: usize) -> Self {
let weights = Array::random((input_size, output_size), rand::distributions::Uniform::new(-0.5, 0.5));
let biases = Array::zeros((1, output_size));
DenseLayer { weights, biases }
}
}
impl Layer for DenseLayer {
type Input = ndarray::Ix2; // Two-dimensional input
type Output = ndarray::Ix2; // Two-dimensional output
fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, Self::Input>) -> ArrayBase<OwnedRepr<f32>, Self::Output> {
assert_eq!(input.shape()[1], self.weights.shape()[0], "Input width must match weight height.");
let z = input.dot(&self.weights) + &self.biases;
z // Return the output directly without activation
}
}
// Neural Network struct
pub struct NeuralNetwork<'a, L>
where
L: Layer + 'a,
{
layers: Vec<L>,
}
impl<'a, L> NeuralNetwork<'a, L>
where
L: Layer + 'a,
{
pub fn new(layers: Vec<L>) -> Self {
NeuralNetwork { layers }
}
pub fn forward(&mut self, input: &ArrayBase<OwnedRepr<f32>, L::Input>) -> ArrayBase<OwnedRepr<f32>, L::Output> {
let mut output = input.clone();
for layer in &mut self.layers {
output = layer.forward(&output);
}
output
}
}
fn main() {
// Create a neural network with one Dense Layer
let mut dense_layer = DenseLayer::new(3, 2);
let mut nn = NeuralNetwork::new(vec![dense_layer]);
// Create an example input (1 batch, 3 features)
let input = Array::from_shape_vec((1, 3), vec![1.0, 2.0, 3.0]).unwrap();
// Forward pass
let output = nn.forward(&input);
println!("Output: {:?}", output);
}
Há duas coisas que você precisa fazer
NeuralNetwork::forward
para compilar.Layer
limite para que os tiposInput
eOutput
associados sejam do mesmo tipo.Clone
de forma queinput.clone()
seja possível clonar o array subjacente em vez de clonar a referência.Esses limites comunicarão essas restrições ao compilador (observe a introdução de um novo parâmetro genérico
T
noimpl
bloco):Observe que você deve considerar mudar
NeuralNetwork::new
para umimpl
bloco com restrições mínimas, pois não há motivo para que essas restrições sejam aplicadas a ele.Há alguns outros erros de tempo de compilação, mas presumo que eles não estejam relacionados ao problema que você está tentando resolver. Em particular, não está claro para mim por que você tem um
'a
lifetime emNeuralNetwork
; você pode removê-lo completamente e o código ainda compila.