Add LRS wrapper to optimizer #1170

Closed

python/paddle/trainer_config_helpers/optimizers.py (163 changes: 157 additions & 6 deletions)
@@ -16,12 +16,15 @@
    default_gradient_clipping_threshold, default_momentum

from .default_decorators import wrap_param_default
import collections
import cStringIO

__all__ = [
    'Optimizer', 'BaseSGDOptimizer', 'MomentumOptimizer', 'AdamaxOptimizer',
    'AdamOptimizer', 'AdaGradOptimizer', 'RMSPropOptimizer',
    'DecayedAdaGradOptimizer', 'AdaDeltaOptimizer', 'BaseRegularization',
    'L2Regularization', 'settings', 'ModelAverage'
    'L2Regularization', 'settings', 'ModelAverage', 'PolyLRS', 'ConstantLRS',
    'ExpLRS', 'DiscreteExpLRS', 'LinearLRS', 'ManualLRS', 'PassManualLRS'
]


@@ -351,15 +354,141 @@ def __extends__(dict1, dict2):
    return dict1


class BaseLRS(Optimizer):
    """
    Base class for learning rate schedulers configured by the two decay
    arguments a and b.
    """

    def __init__(self, a, b, scheduler_name):
        self.__a__ = float(a)
        self.__b__ = float(b)
        self.__scheduler_name__ = scheduler_name

    def to_setting_kwargs(self):
        return {
            'learning_rate_schedule': self.__scheduler_name__,
            'learning_rate_decay_a': self.__a__,
            'learning_rate_decay_b': self.__b__
        }


class PolyLRS(BaseLRS):
    """
    Poly Learning Rate Scheduler.

    lr = learning_rate * pow(1 + a * num_samples_processed, -b)
    """

    def __init__(self, a, b):
        super(PolyLRS, self).__init__(a=a, b=b, scheduler_name='poly')
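
As a quick sanity check of the formula in this docstring, a minimal standalone sketch; base_lr, a, b and the sample counts are made-up illustration values, not defaults from this PR.

```python
# Standalone sketch of the poly schedule formula above.
# base_lr, a, b and the sample counts are illustrative only.
base_lr = 0.1
a, b = 0.001, 0.75

for num_samples_processed in (0, 10000, 100000):
    lr = base_lr * pow(1 + a * num_samples_processed, -b)
    print(num_samples_processed, lr)
# The rate decays smoothly as more samples are processed.
```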


class ConstantLRS(Optimizer):
    """
    Constant Learning Rate Scheduler. Learning rate will not be changed.
    """

    def to_setting_kwargs(self):
        return {'learning_rate_schedule': 'constant'}


class ExpLRS(BaseLRS):
    """
    Exp Learning Rate Scheduler.

    lr = learning_rate * pow(a, num_samples_processed/b)
    """

    def __init__(self, a, b):
        super(ExpLRS, self).__init__(a=a, b=b, scheduler_name='exp')


class DiscreteExpLRS(BaseLRS):
    """
    Discrete Exp Learning Rate Scheduler.

    lr = learning_rate * pow(a, floor(num_samples_processed / b))
    """

    def __init__(self, a, b):
        super(DiscreteExpLRS, self).__init__(a=a, b=b, scheduler_name='discexp')
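
A small standalone sketch contrasting ExpLRS and DiscreteExpLRS as described above: the former decays with every processed sample, the latter only steps down after each full segment of b samples. All values are illustrative.

```python
import math

# base_lr, a, b and the sample counts are illustrative only.
base_lr = 0.1
a, b = 0.5, 50000.0

for n in (0, 25000, 50000, 75000, 100000):
    exp_lr = base_lr * pow(a, n / b)
    discexp_lr = base_lr * pow(a, math.floor(n / b))
    print(n, exp_lr, discexp_lr)
# exp halves the rate continuously; discexp halves it only once a full
# segment of b samples has been processed.
```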


class LinearLRS(BaseLRS):
    """
    Linear Learning Rate Scheduler.

    lr = max(learning_rate - a, b)
    """

    def __init__(self, a, b):
        super(LinearLRS, self).__init__(a=a, b=b, scheduler_name='linear')
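
A short sketch of the linear formula exactly as written above, showing the clamp at b; the values are illustrative.

```python
# base_lr and the (a, b) pairs are illustrative only.
base_lr = 0.1

for a, b in [(0.02, 0.05), (0.08, 0.05)]:
    lr = max(base_lr - a, b)
    print(a, b, lr)
# The second pair would drop below b, so the rate is clamped to 0.05.
```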


class ManualLRS(Optimizer):
    """
    Specify the learning rate by explicitly passing all learning rates.

    :param learning_rates: list of learning rates. Each item contains two
                           fields: the first is an int value used as the
                           segment boundary, the second is the learning rate.

    The real learning rate is:

    if seg_{i-1} <= num_samples_processed <= seg_i,
        return lr_{i}

    :type learning_rates: list of list. Each element should be (int, float)
    """

    def __init__(self, learning_rates):
        assert isinstance(learning_rates, collections.Sequence)
        # cStringIO objects do not implement the context-manager protocol in
        # Python 2, so manage the buffer explicitly instead of using `with`.
        buf = cStringIO.StringIO()
        for i, each in enumerate(learning_rates):
            assert isinstance(each, collections.Sequence)
            assert len(each) == 2
            buf.write("{0}:{1:.5f}".format(int(each[0]), float(each[1])))
            if i + 1 != len(learning_rates):  # not at end
                buf.write(",")
        self.__args__ = buf.getvalue()
        buf.close()

    def to_setting_kwargs(self):
        return {
            'learning_rate_schedule': 'manual',
            'learning_rate_args': self.__args__
        }
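
To make the serialization above concrete, a standalone sketch that mirrors the loop in ManualLRS.__init__ on a made-up schedule:

```python
# Rebuild the same "segment:rate" string the constructor above produces,
# for an illustrative schedule.
learning_rates = [(10000, 0.1), (20000, 0.01), (30000, 0.001)]
args = ",".join(
    "{0}:{1:.5f}".format(int(seg), float(lr)) for seg, lr in learning_rates)
print(args)  # 10000:0.10000,20000:0.01000,30000:0.00100
# to_setting_kwargs() then reports
# {'learning_rate_schedule': 'manual', 'learning_rate_args': args}
```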


class PassManualLRS(ManualLRS):
    """
    Pass Manual Learning Rate Scheduler.

    Basically the same as the manual learning rate scheduler, except that the
    pass number is used as the segment number.

    The real learning rate is:

    if seg_{i-1} <= pass_id <= seg_i:
        return lr_{i}
    """

    def __init__(self, learning_rates):
        super(PassManualLRS, self).__init__(learning_rates=learning_rates)

    def to_setting_kwargs(self):
        return {
            'learning_rate_schedule': 'pass_manual',
            'learning_rate_args': self.__args__
        }
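
A hypothetical usage sketch, assuming this branch is installed; the pass boundaries and rates are illustrative, not values from this PR:

```python
# Hypothetical usage, assuming this branch of paddle.trainer_config_helpers
# is importable. Segment boundaries are pass ids; all numbers are
# illustrative.
from paddle.trainer_config_helpers.optimizers import PassManualLRS, settings

settings(
    batch_size=128,
    learning_rate=0.1,
    learning_rate_schedule=PassManualLRS(
        learning_rates=[(10, 0.1), (20, 0.01), (30, 0.001)]))
```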


@wrap_param_default(
    ['learning_method'], default_factory=lambda _: MomentumOptimizer())
@wrap_param_default(
    ['regularization'], default_factory=lambda _: BaseRegularization())
@wrap_param_default(
    ['learning_rate_schedule'], default_factory=lambda _: ConstantLRS())
def settings(batch_size,
             learning_rate=1e-3,
             learning_rate_decay_a=0.,
             learning_rate_decay_b=0.,
             learning_rate_schedule='poly',
             learning_rate_schedule=None,
             learning_rate_args='',
             learning_method=None,
             regularization=None,
@@ -396,6 +525,19 @@ def settings(batch_size,
                                         value larger than some value, will be
                                         clipped.
    :type gradient_clipping_threshold: float

    :param learning_rate_schedule: A Learning Rate Scheduler (LRS) object or a
                                   string. It is recommended to pass an LRS
                                   object. If learning_rate_schedule is given
                                   as a string, you should manually set
                                   learning_rate_decay_a,
                                   learning_rate_decay_b and
                                   learning_rate_args.

                                   Check LRS.to_setting_kwargs to figure out
                                   how to set these arguments.
    :type learning_rate_schedule: basestring|Optimizer
    :param learning_rate_decay_a: See learning_rate_schedule.
    :param learning_rate_decay_b: See learning_rate_schedule.
    :param learning_rate_args: See learning_rate_schedule.
    """
    if isinstance(regularization, BaseRegularization):
        regularization = [regularization]
@@ -406,15 +548,24 @@
    else:
        algorithm = 'owlqn'

    args = [
        'batch_size', 'learning_rate', 'learning_rate_decay_a',
        'learning_rate_decay_b', 'learning_rate_schedule', 'learning_rate_args'
    ]
    args = ['batch_size', 'learning_rate']
    kwargs = dict()
    kwargs['algorithm'] = algorithm

    for arg in args:
        kwargs[arg] = locals()[arg]

    if isinstance(learning_rate_schedule, Optimizer):
        kwargs = __extends__(kwargs, learning_rate_schedule.to_setting_kwargs())
    elif isinstance(learning_rate_schedule, basestring):
        for arg in [
                'learning_rate_decay_a', 'learning_rate_decay_b',
                'learning_rate_schedule', 'learning_rate_args'
        ]:
            kwargs[arg] = locals()[arg]
    else:
        raise RuntimeWarning("Unexpected branch")

    kwargs = __extends__(kwargs, learning_method.to_setting_kwargs())
    learning_method.extra_settings()
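
To make the two isinstance branches above concrete, a sketch of the two equivalent ways to request the poly schedule through settings(); all values are illustrative and assume this branch:

```python
# 1) Recommended path: pass an LRS object. Its to_setting_kwargs() supplies
#    learning_rate_schedule, learning_rate_decay_a and learning_rate_decay_b.
settings(
    batch_size=128,
    learning_rate=0.1,
    learning_rate_schedule=PolyLRS(a=0.001, b=0.75))

# 2) String path: name the schedule and set its decay arguments manually.
settings(
    batch_size=128,
    learning_rate=0.1,
    learning_rate_schedule='poly',
    learning_rate_decay_a=0.001,
    learning_rate_decay_b=0.75)
```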

@@ -46,7 +46,7 @@ opt_config {
  ada_epsilon: 1e-06
  do_average_in_cpu: false
  ada_rou: 0.95
  learning_rate_schedule: "poly"
  learning_rate_schedule: "constant"
  delta_add_rate: 1.0
  shrink_parameter_value: 0
  adam_beta1: 0.9