Skip to content

Commit

Permalink
calculate real traj cost
Browse files Browse the repository at this point in the history
  • Loading branch information
acxz committed Dec 16, 2019
1 parent a4a2d0c commit 98bf3d1
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 12 deletions.
5 changes: 3 additions & 2 deletions examples/cart_pole/cartpole_main.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ctrl_noise_covar = [5e-1]; % ctrl_dim by ctrl_dim
learning_rate = 0.01;
per_ctrl_based_ctrl_noise = 0.999;
real_traj_cost = true;
plot_traj = true;
print_sim = true;
print_mppi = true;
Expand All @@ -28,8 +29,8 @@
@cartpole_state_est, @cartpole_apply_ctrl, @cartpole_g, @cartpole_F, ...
@cartpole_state_transform, @cartpole_control_transform, ...
@cartpole_filter_du, num_samples, learning_rate, init_state, init_ctrl_seq, ...
ctrl_noise_covar, time_horizon, per_ctrl_based_ctrl_noise, plot_traj, ...
print_sim, print_mppi, save_sampling, sampling_filename);
ctrl_noise_covar, time_horizon, per_ctrl_based_ctrl_noise, real_traj_cost, ...
plot_traj, print_sim, print_mppi, save_sampling, sampling_filename);

all_figures = findobj('type', 'figure');
num_figures = length(all_figures);
Expand Down
5 changes: 3 additions & 2 deletions examples/inv_pen/inv_pen_main.m
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ctrl_noise_covar = [5e-1]; % ctrl_dim by ctrl_dim
learning_rate = 0.01;
per_ctrl_based_ctrl_noise = 0.999;
real_traj_cost = true;
plot_traj = true;
print_sim = true;
print_mppi = true;
Expand All @@ -28,8 +29,8 @@
@inv_pen_state_est, @inv_pen_apply_ctrl, @inv_pen_g, @inv_pen_F, ...
@inv_pen_state_transform, @inv_pen_control_transform, @inv_pen_filter_du, ...
num_samples, learning_rate, init_state, init_ctrl_seq, ctrl_noise_covar, ...
time_horizon, per_ctrl_based_ctrl_noise, plot_traj, print_sim, ...
print_mppi, save_sampling, sampling_filename);
time_horizon, per_ctrl_based_ctrl_noise, real_traj_cost, plot_traj, ...
print_sim, print_mppi, save_sampling, sampling_filename);

all_figures = findobj('type', 'figure');
num_figures = length(all_figures);
Expand Down
33 changes: 27 additions & 6 deletions mppi.m
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
func_comp_weights, func_term_cost, func_run_cost, func_g, func_F, ...
func_state_transform, func_filter_du, num_samples, learning_rate, ...
init_state, init_ctrl_seq, ctrl_noise_covar, time_horizon, ...
per_ctrl_based_ctrl_noise, print_mppi, save_sampling, sampling_filename)
per_ctrl_based_ctrl_noise, real_traj_cost, print_mppi, save_sampling, ...
sampling_filename)

% time stuff
num_timesteps = size(init_ctrl_seq, 2);
Expand All @@ -15,6 +16,8 @@
sample_state_dim = size(sample_init_state,1);

% state trajectories
real_x_traj = zeros(sample_state_dim, num_timesteps + 1);
real_x_traj(:,1) = sample_init_state;
x_traj = zeros(sample_state_dim, num_samples, num_timesteps + 1);
x_traj(:,:,1) = repmat(sample_init_state,[1, num_samples]);

Expand All @@ -31,6 +34,8 @@
% Begin mppi
iteration = 1;
while(func_control_update_converged(du, iteration) == false)

last_sample_u_traj = sample_u_traj;

% Noise generation
flat_distribution = randn(control_dim, num_samples * num_timesteps);
Expand Down Expand Up @@ -82,11 +87,27 @@

end

% normalize weights, in case they are not normalized
normalized_w = w / sum(w);
if (real_traj_cost == true)
% Loop through the dynamics again to recalcuate traj_cost
rep_traj_cost = 0;

for timestep_num = 1:num_timesteps

% Forward propagation
real_x_traj(:,timestep_num+1) = func_F(real_x_traj(:,timestep_num),func_g(sample_u_traj(:,timestep_num)),dt);

% Compute the representative trajectory cost of what actually happens
% another way to think about this is weighted average of sample trajectory costs
rep_traj_cost = sum(normalized_w .* traj_cost);
rep_traj_cost = rep_traj_cost + func_run_cost(real_x_traj(:,timestep_num)) + learning_rate * sample_u_traj(:,timestep_num)' * inv(ctrl_noise_covar) * (last_sample_u_traj(:,timestep_num) - sample_u_traj(:,timestep_num));

end

rep_traj_cost = rep_traj_cost + func_term_cost(real_x_traj(:,timestep_num+1));
else
% normalize weights, in case they are not normalized
normalized_w = w / sum(w);

% Compute the representative trajectory cost of what actually happens
% another way to think about this is weighted average of sample trajectory costs
rep_traj_cost = sum(normalized_w .* traj_cost);
end

end
5 changes: 3 additions & 2 deletions mppisim.m
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
func_state_est, func_apply_ctrl, func_g, func_F, func_state_transform, ...
func_control_transform, func_filter_du, num_samples, learning_rate, ...
init_state, init_ctrl_seq, ctrl_noise_covar, time_horizon, ...
per_ctrl_based_ctrl_noise, plot_traj, print_sim, print_mppi, ...
per_ctrl_based_ctrl_noise, real_traj_cost, plot_traj, print_sim, print_mppi, ...
save_sampling, sampling_filename)

% time stuff
Expand Down Expand Up @@ -89,7 +89,8 @@
func_comp_weights, func_term_cost, func_run_cost, func_g, func_F, ...
func_state_transform, func_filter_du, num_samples, learning_rate, ...
curr_x, sample_u_traj, ctrl_noise_covar, time_horizon, ...
per_ctrl_based_ctrl_noise, print_mppi, save_sampling, sampling_filename);
per_ctrl_based_ctrl_noise, real_traj_cost, print_mppi, save_sampling, ...
sampling_filename);

% Transform from sample_u to u
u = func_control_transform(sample_x_hist(:,total_timestep_num), sample_u_traj(:,1), dt);
Expand Down

0 comments on commit 98bf3d1

Please sign in to comment.