Skip to content

Commit

Permalink
sgd cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
davidge807 committed Dec 23, 2022
1 parent 007bd62 commit e5f7fae
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 21 deletions.
19 changes: 1 addition & 18 deletions opennn/stochastic_gradient_descent.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,7 @@ void StochasticGradientDescent::update_parameters(LossIndexBackPropagation& back
}
else
{
optimization_data.nesterov_increment.device(*thread_pool_device)
= optimization_data.parameters_increment*momentum - back_propagation.gradient*learning_rate;

back_propagation.parameters.device(*thread_pool_device) += optimization_data.nesterov_increment;
back_propagation.parameters.device(*thread_pool_device) += optimization_data.parameters_increment*momentum - back_propagation.gradient*learning_rate;;
}
}
else
Expand All @@ -301,20 +298,6 @@ void StochasticGradientDescent::update_parameters(LossIndexBackPropagation& back

optimization_data.last_parameters_increment = optimization_data.parameters_increment;

/// @todo check if the following is equivalent
/*
if(momentum > type(0))
{
back_propagation.parameters.device(*thread_pool_device) += momentum*optimization_data.last_parameters_increment;
if(nesterov)
{
back_propagation.parameters.device(*thread_pool_device) += optimization_data.parameters_increment*momentum;
}
optimization_data.last_parameters_increment = optimization_data.parameters_increment;
}
*/
optimization_data.iteration++;

// Update parameters
Expand Down
3 changes: 0 additions & 3 deletions opennn/stochastic_gradient_descent.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,9 @@ struct StochasticGradientDescentData : public OptimizationAlgorithmData
const Index parameters_number = neural_network_pointer->get_parameters_number();

parameters_increment.resize(parameters_number);
nesterov_increment.resize(parameters_number);
last_parameters_increment.resize(parameters_number);

parameters_increment.setZero();
nesterov_increment.setZero();
last_parameters_increment.setZero();
}

Expand All @@ -196,7 +194,6 @@ struct StochasticGradientDescentData : public OptimizationAlgorithmData
Index iteration = 0;

Tensor<type, 1> parameters_increment;
Tensor<type, 1> nesterov_increment;
Tensor<type, 1> last_parameters_increment;
};

Expand Down

0 comments on commit e5f7fae

Please sign in to comment.