Skip to content

Commit

Permalink
regression experiments with epochs and min train steps
Browse files Browse the repository at this point in the history
  • Loading branch information
sukhijab committed Jan 23, 2024
1 parent 21d75e1 commit e390fe3
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 38 deletions.
87 changes: 60 additions & 27 deletions experiments/lf_hf_transfer_exp/run_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
import os
import argparse
import jax
import jax.numpy as jnp
import sys
import copy
import jax.numpy as jnp

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, BASE_DIR)

import datetime
import wandb
from typing import Dict, List, Tuple, Union
from experiments.util import Logger, hash_dict, NumpyArrayEncoder

from typing import List, Union
from experiments.util import hash_dict, NumpyArrayEncoder
from experiments.data_provider import provide_data_and_sim, DATASET_CONFIGS
from sim_transfer.models import BNN_SVGD, BNN_FSVGD, BNN_FSVGD_SimPrior, BNN_MMD_SimPrior, BNN_SVGD_DistillPrior
from sim_transfer.models import (BNN_SVGD, BNN_FSVGD, BNN_FSVGD_SimPrior, BNN_MMD_SimPrior, BNN_SVGD_DistillPrior,
BNNGreyBox)

from sim_transfer.sims.simulators import AdditiveSim, GaussianProcessSim, PredictStateChangeWrapper

BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, BASE_DIR)

ACTIVATION_DICT = {
'relu': jax.nn.relu,
'leaky_relu': jax.nn.leaky_relu,
Expand All @@ -28,7 +30,8 @@
'swish': jax.nn.swish,
}

OUTPUTSCALES_RCCAR = [0.0075, 0.0075, 0.012, 0.012, 0.23, 0.23, 0.62]
OUTPUTSCALES_RCCAR = [0.008, 0.008, 0.01, 0.01, 0.08, 0.08, 0.5]


def regression_experiment(
# data parameters
Expand All @@ -45,15 +48,16 @@ def regression_experiment(
model_seed: int = 892616,
likelihood_std: Union[List[float], float] = 0.1,
data_batch_size: int = 8,
num_train_steps: int = 20000,
min_train_steps: int = 2500,
num_epochs: int = 60,
lr: float = 1e-3,
hidden_activation: str = 'leaky_relu',
num_layers: int = 3,
layer_size: int = 64,
normalize_likelihood_std: bool = False,
learn_likelihood_std: bool = False,
likelihood_exponent: float = 1.0,

likelihood_reg: float = 0.0,
# SVGD parameters
num_particles: int = 20,
bandwidth_svgd: float = 10.0,
Expand All @@ -66,16 +70,17 @@ def regression_experiment(

# FSVGD_Sim_Prior parameters
bandwidth_score_estim: float = None,
ssge_kernel_type: str = 'SE',
ssge_kernel_type: str = 'IMQ',
num_f_samples: int = 128,

switch_score_estimator_frac: float = 0.6667,
switch_score_estimator_frac: float = 0.75,
added_gp_lengthscale: float = 5.,
added_gp_outputscale: Union[List[float], float] = 0.05,

# BNN_SVGD_DistillPrior
num_distill_steps: int = 500000,
):
num_train_steps = num_epochs // data_batch_size * num_samples_train + min_train_steps
# provide data and sim
x_train, y_train, x_test, y_test, sim_lf = provide_data_and_sim(
data_source=data_source,
Expand All @@ -93,6 +98,8 @@ def regression_experiment(
no_added_gp = True
model = model.replace('_no_add_gp', '')
added_gp_outputscale = 0.
elif model in ['GreyBox', 'SysID']:
no_added_gp = True
else:
no_added_gp = False

Expand Down Expand Up @@ -130,12 +137,14 @@ def regression_experiment(
bandwidth_svgd=bandwidth_svgd,
weight_prior_std=weight_prior_std,
bias_prior_std=bias_prior_std,
likelihood_reg=likelihood_reg,
**standard_model_params)
elif model == 'BNN_FSVGD':
model = BNN_FSVGD(domain=sim.domain,
num_particles=num_particles,
bandwidth_svgd=bandwidth_svgd,
bandwidth_gp_prior=bandwidth_gp_prior,
likelihood_reg=likelihood_reg,
num_measurement_points=num_measurement_points,
**standard_model_params)
elif 'BNN_FSVGD_SimPrior' in model:
Expand All @@ -152,6 +161,20 @@ def regression_experiment(
score_estimator=score_estimator,
switch_score_estimator_frac=switch_score_estimator_frac,
**standard_model_params)
elif model in ['GreyBox', 'SysID']:
base_bnn = BNN_FSVGD(domain=sim.domain,
num_particles=num_particles,
bandwidth_svgd=bandwidth_svgd,
bandwidth_gp_prior=bandwidth_gp_prior,
likelihood_reg=likelihood_reg,
num_measurement_points=num_measurement_points,
**standard_model_params)
model = BNNGreyBox(
base_bnn=base_bnn,
sim=sim,
use_base_bnn=(model == 'GreyBox'),
num_sim_model_train_steps=5_000,
)
elif model == 'BNN_MMD_SimPrior':
model = BNN_MMD_SimPrior(domain=sim.domain,
function_sim=sim,
Expand All @@ -173,10 +196,10 @@ def regression_experiment(
raise NotImplementedError('Model {model} not implemented')

# train model
model.fit(x_train, y_train, x_test, y_test, log_to_wandb=use_wandb, log_period=1000)
model.fit_with_scan(x_train, y_train, x_test, y_test, log_to_wandb=use_wandb, log_period=1000)

# eval model
eval_metrics = model.eval(x_test, y_test)
eval_metrics = model.eval(x_test, y_test, per_dim_metrics=True)
return eval_metrics


Expand All @@ -194,6 +217,15 @@ def main(args):
os.makedirs(exp_result_folder, exist_ok=True)

# set likelihood_std to default value if not specified
if 'added_gp_outputscale' in exp_params:
if exp_params['added_gp_outputscale'] < 0:
if 'racecar' in exp_params['data_source']:
exp_params['added_gp_outputscale'] = OUTPUTSCALES_RCCAR
print(f"Setting added_gp_outputscale to data_source default value from DATASET_CONFIGS "
f"which is {exp_params['added_gp_outputscale']}")
else:
raise AssertionError('passed negative value for added_gp_outputscale')

if exp_params['likelihood_std'] is None:
likelihood_std = DATASET_CONFIGS[args.data_source]['likelihood_std']['value']
if 'no_angvel' in exp_params['data_source']:
Expand All @@ -205,7 +237,7 @@ def main(args):
f"which is {exp_params['likelihood_std']}")

# custom gp outputscale for racecar_hf
if 'racecar_hf' in exp_params['data_source']:
if 'real_racecar' in exp_params['data_source']:
outputscales_racecar = exp_params['added_gp_outputscale'] * jnp.array(OUTPUTSCALES_RCCAR)
if 'no_angvel' in exp_params['data_source']:
outputscales_racecar = outputscales_racecar[:-1]
Expand All @@ -215,7 +247,6 @@ def main(args):
print(f'For {exp_params["data_source"]}, multiplying likelihood_std by OUTPUTSCALES_RCCAR. '
f'Resulting added_gp_outputscale parameter: {exp_params["added_gp_outputscale"]}')


from pprint import pprint
print('\nExperiment parameters:')
pprint(exp_params)
Expand Down Expand Up @@ -255,10 +286,10 @@ def main(args):
from pprint import pprint
pprint(results_dict)
else:
exp_result_file = os.path.join(exp_result_folder, '%s.json'%exp_hash)
exp_result_file = os.path.join(exp_result_folder, f'{exp_hash}.json')
with open(exp_result_file, 'w') as f:
json.dump(results_dict, f, indent=4, cls=NumpyArrayEncoder)
print('Dumped results to %s'%exp_result_file)
print(f'Dumped results to {exp_result_file}')

if use_wandb:
wandb.finish()
Expand All @@ -274,18 +305,20 @@ def main(args):
parser.add_argument('--use_wandb', type=bool, default=False)

# data parameters
parser.add_argument('--data_source', type=str, default='pendulum_hf')
parser.add_argument('--num_samples_train', type=int, default=20)
parser.add_argument('--data_source', type=str, default='racecar_hf')
parser.add_argument('--pred_diff', type=int, default=1)
parser.add_argument('--num_samples_train', type=int, default=5000)
parser.add_argument('--data_seed', type=int, default=77698)
parser.add_argument('--pred_diff', type=int, default=0)

# standard BNN parameters
parser.add_argument('--model', type=str, default='BNN_SVGD')
parser.add_argument('--model', type=str, default='BNN_FSVGD')
parser.add_argument('--model_seed', type=int, default=892616)
parser.add_argument('--likelihood_std', type=float, default=None)
parser.add_argument('--learn_likelihood_std', type=int, default=0)
parser.add_argument('--data_batch_size', type=int, default=16)
parser.add_argument('--num_train_steps', type=int, default=20000)
parser.add_argument('--likelihood_reg', type=float, default=-1.0)
parser.add_argument('--data_batch_size', type=int, default=8)
parser.add_argument('--min_train_steps', type=int, default=2500)
parser.add_argument('--num_epochs', type=int, default=60)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--hidden_activation', type=str, default='leaky_relu')
parser.add_argument('--num_layers', type=int, default=3)
Expand All @@ -306,8 +339,8 @@ def main(args):
# FSVGD_SimPrior parameters
parser.add_argument('--bandwidth_score_estim', type=float, default=None)
parser.add_argument('--ssge_kernel_type', type=str, default='IMQ')
parser.add_argument('--num_f_samples', type=int, default=1024)
parser.add_argument('--switch_score_estimator_frac', type=float, default=0.6667)
parser.add_argument('--num_f_samples', type=int, default=128)
parser.add_argument('--switch_score_estimator_frac', type=float, default=0.75)

# Additive SimPrior GP parameters
parser.add_argument('--added_gp_lengthscale', type=float, default=5.)
Expand Down
31 changes: 20 additions & 11 deletions experiments/lf_hf_transfer_exp/sweep_regression_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
MODEL_SPECIFIC_CONFIG = {
'BNN_SVGD': {
'bandwidth_svgd': {'distribution': 'log_uniform', 'min': -1., 'max': 4.},
'num_train_steps': {'values': [5000, 10000, 20000]}
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
},
'BNN_FSVGD': {
'bandwidth_svgd': {'distribution': 'log_uniform_10', 'min': -1.0, 'max': 0.0},
'bandwidth_gp_prior': {'distribution': 'log_uniform', 'min': -2., 'max': 0.},
'num_train_steps': {'values': [5000, 10000, 20000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [16, 32, 64, 128]},
},
'BNN_FSVGD_SimPrior_gp': {
'bandwidth_svgd': {'distribution': 'log_uniform_10', 'min': -1.0, 'max': 0.0},
'num_train_steps': {'values': [40000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [32]},
'num_f_samples': {'values': [1024]},
'added_gp_lengthscale': {'distribution': 'uniform', 'min': 2., 'max': 10.}, # racecar: 4 - 8 # pendulum (diff): > 8 - 15.
Expand All @@ -30,7 +33,8 @@
},
'BNN_FSVGD_SimPrior_nu-method': {
'bandwidth_svgd': {'distribution': 'log_uniform_10', 'min': -1.0, 'max': 0.0},
'num_train_steps': {'values': [40000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [32]},
'num_f_samples': {'values': [512]},
'bandwidth_score_estim': {'distribution': 'uniform', 'min': 0.8, 'max': 2.0},
Expand All @@ -39,33 +43,38 @@
},
'BNN_FSVGD_SimPrior_ssge': {
'bandwidth_svgd': {'distribution': 'log_uniform_10', 'min': -1.0, 'max': 0.0},
'num_train_steps': {'values': [40000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [8, 16, 32]},
'num_f_samples': {'values': [512]},
'bandwidth_score_estim': {'distribution': 'log_uniform_10', 'min': -0.5, 'max': 1.},
},
'BNN_FSVGD_SimPrior_gp+nu-method': {
'bandwidth_svgd': {'distribution': 'log_uniform_10', 'min': -1.0, 'max': 0.0},
'num_train_steps': {'values': [40000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [16, 32]},
'num_f_samples': {'values': [512]},
'switch_score_estimator_frac': {'values': [0.6667]},
'bandwidth_score_estim': {'distribution': 'log_uniform_10', 'min': 0.0, 'max': 0.5},
},
'BNN_FSVGD_SimPrior_kde': {
'bandwidth_svgd': {'distribution': 'log_uniform', 'min': -2., 'max': 2.},
'num_train_steps': {'values': [40000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [16, 32]},
'num_f_samples': {'values': [512, 1024, 2056]},
},
'BNN_MMD_SimPrior': {
'num_train_steps': {'values': [20000, 40000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [8, 16, 32, 64]},
'num_f_samples': {'values': [64, 128, 256, 512]},
},
'BNN_SVGD_DistillPrior': {
'bandwidth_svgd': {'distribution': 'log_uniform', 'min': -2., 'max': 2.},
'num_train_steps': {'values': [20000, 40000]},
'min_train_steps': {'values': [2500]},
'num_epochs': {'values': [60]},
'num_measurement_points': {'values': [8, 16, 32]},
'num_f_samples': {'values': [64, 128, 256]},
'num_distill_steps': {'values': [30000, 60000]},
Expand Down Expand Up @@ -124,7 +133,7 @@ def main(args):
# sweep args
parser.add_argument('--num_hparam_samples', type=int, default=20)
parser.add_argument('--num_model_seeds', type=int, default=3, help='number of model seeds per hparam')
parser.add_argument('--num_data_seeds', type=int, default=4, help='number of model seeds per hparam')
parser.add_argument('--num_data_seeds', type=int, default=3, help='number of model seeds per hparam')
parser.add_argument('--num_cpus', type=int, default=1, help='number of cpus to use')
parser.add_argument('--run_mode', type=str, default='euler')

Expand All @@ -136,7 +145,7 @@ def main(args):

# data parameters
parser.add_argument('--data_source', type=str, default='pendulum_hf')
parser.add_argument('--pred_diff', type=int, default=0)
parser.add_argument('--pred_diff', type=int, default=1)

# # standard BNN parameters
parser.add_argument('--model', type=str, default='BNN_SVGD')
Expand Down
Loading

0 comments on commit e390fe3

Please sign in to comment.