add total gradient norm to VI #2257

Merged · 6 commits · Jun 4, 2017
14 changes: 6 additions & 8 deletions pymc3/tests/test_variational_inference.py
@@ -70,7 +70,7 @@ class TestApproximates:
class Base(SeededTest):
inference = None
NITER = 12000
optimizer = pm.adagrad_window(learning_rate=0.01)
optimizer = pm.adagrad_window(learning_rate=0.01, n_win=50)
conv_cb = property(lambda self: [
pm.callbacks.CheckParametersConvergence(
every=500,
@@ -152,7 +152,6 @@ def test_optimizer_with_full_data(self):
mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
Normal('x', mu=mu_, sd=sd, observed=data)
inf = self.inference(start={})
inf.fit(10)
approx = inf.fit(self.NITER,
obj_optimizer=self.optimizer,
callbacks=self.conv_cb,)
@@ -295,11 +294,9 @@ class TestSVGD(TestApproximates.Base):


class TestASVGD(TestApproximates.Base):
NITER = 15000
inference = ASVGD
NITER = 5000
inference = functools.partial(ASVGD, temperature=1.5)
test_aevb = _test_aevb
optimizer = pm.adagrad_window(learning_rate=0.002)
conv_cb = []


class TestEmpirical(SeededTest):
@@ -366,12 +363,13 @@ def test_init_from_noize(self):
(_advi, dict(start={}), None),
(_fullrank_advi, dict(), None),
(_svgd, dict(), None),
('advi', dict(), None),
('advi', dict(total_grad_norm_constraint=10), None),
('advi->fullrank_advi', dict(frac=.1), None),
('advi->fullrank_advi', dict(frac=1), ValueError),
('fullrank_advi', dict(), None),
('svgd', dict(), None),
('svgd', dict(total_grad_norm_constraint=10), None),
('svgd', dict(start={}), None),
('asvgd', dict(start={}, total_grad_norm_constraint=10), None),
('svgd', dict(local_rv={_model.free_RVs[0]: (0, 1)}), ValueError)
]
)
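For context, the `total_grad_norm_constraint=10` cases added to the parametrized fit test above pass the new keyword straight through `pm.fit`. A minimal sketch of end-user usage, assuming this branch (the toy model and numbers are illustrative):

```python
import numpy as np
import pymc3 as pm

# Toy data: infer the mean of a normal distribution.
data = np.random.randn(100)

with pm.Model():
    mu = pm.Normal('mu', mu=0, sd=10)
    pm.Normal('x', mu=mu, sd=1, observed=data)
    # total_grad_norm_constraint (added in this PR) jointly rescales the
    # gradients so their combined L2 norm never exceeds 10.
    approx = pm.fit(
        n=10000,
        method='advi',
        obj_optimizer=pm.adagrad_window(learning_rate=0.01, n_win=50),
        total_grad_norm_constraint=10,
    )
    trace = approx.sample(500)
```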
33 changes: 27 additions & 6 deletions pymc3/variational/inference.py
@@ -41,11 +41,15 @@ class Inference(object):
See (AEVB; Kingma and Welling, 2014) for details
model : Model
PyMC3 Model
op_kwargs : dict
kwargs passed to :class:`Operator`
kwargs : kwargs
additional kwargs for :class:`Approximation`
"""

def __init__(self, op, approx, tf, local_rv=None, model=None, **kwargs):
def __init__(self, op, approx, tf, local_rv=None, model=None, op_kwargs=None, **kwargs):
if op_kwargs is None:
op_kwargs = dict()
self.hist = np.asarray(())
if isinstance(approx, type) and issubclass(approx, Approximation):
approx = approx(
@@ -56,7 +60,7 @@ def __init__(self, op, approx, tf, local_rv=None, model=None, **kwargs):
else: # pragma: no cover
raise TypeError(
'approx should be Approximation instance or Approximation subclass')
self.objective = op(approx)(tf)
self.objective = op(approx, **op_kwargs)(tf)

approx = property(lambda self: self.objective.approx)

@@ -146,7 +150,11 @@ def _iterate_without_loss(self, _, step_func, progress, callbacks):
def _iterate_with_loss(self, n, step_func, progress, callbacks):
def _infmean(input_array):
"""Return the mean of the finite values of the array"""
return np.mean(np.asarray(input_array)[np.isfinite(input_array)])
input_array = input_array[np.isfinite(input_array)].astype('float64')
if len(input_array) == 0:
return np.nan
else:
return np.mean(input_array)
scores = np.empty(n)
scores[:] = np.nan
i = 0
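The reworked `_infmean` above returns `nan` when no finite scores have been recorded yet, instead of taking the mean of an empty slice. A standalone NumPy sketch of that behaviour (illustrative inputs):

```python
import numpy as np

def _infmean(input_array):
    """Return the mean of the finite values of the array, or nan if there are none."""
    input_array = np.asarray(input_array)
    finite = input_array[np.isfinite(input_array)].astype('float64')
    if len(finite) == 0:
        return np.nan
    return np.mean(finite)

print(_infmean([1.0, np.inf, 3.0, np.nan]))  # 2.0 -- only finite entries are averaged
print(_infmean([np.inf, np.nan]))            # nan -- no finite entries, no empty-slice mean
```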
@@ -531,6 +539,8 @@ class SVGD(Inference):
PyMC3 model for inference
kernel : `callable`
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
temperature : float
parameter responsible for exploration; a higher temperature gives a broader posterior estimate
scale_cost_to_minibatch : bool, default False
Scale cost to minibatch instead of full dataset
start : `dict`
@@ -548,10 +558,14 @@
- Qiang Liu, Dilin Wang (2016)
Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm
arXiv:1608.04471

- Yang Liu, Prajit Ramachandran, Qiang Liu, Jian Peng (2017)
Stein Variational Policy Gradient
arXiv:1704.02399
"""

def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_functions.rbf,
scale_cost_to_minibatch=False, start=None, histogram=None,
temperature=1, scale_cost_to_minibatch=False, start=None, histogram=None,
random_seed=None, local_rv=None):
if histogram is None:
histogram = Empirical.from_noise(
Expand Down Expand Up @@ -593,6 +607,8 @@ class ASVGD(Inference):
See (AEVB; Kingma and Welling, 2014) for details
kernel : `callable`
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
temperature : float
parameter responsible for exploration; a higher temperature gives a broader posterior estimate
model : :class:`Model`
kwargs : kwargs for :class:`Approximation`

@@ -604,17 +620,22 @@

- Dilin Wang, Qiang Liu (2016)
Learning to Draw Samples: With Application to Amortized MLE for Generative Adversarial Learning
https://arxiv.org/abs/1611.01722
arXiv:1611.01722

- Yang Liu, Prajit Ramachandran, Qiang Liu, Jian Peng (2017)
Stein Variational Policy Gradient
arXiv:1704.02399
"""

def __init__(self, approx=FullRank, local_rv=None,
kernel=test_functions.rbf, model=None, **kwargs):
kernel=test_functions.rbf, temperature=1, model=None, **kwargs):
super(ASVGD, self).__init__(
op=AKSD,
approx=approx,
local_rv=local_rv,
tf=kernel,
model=model,
op_kwargs=dict(temperature=temperature),
**kwargs
)

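A usage sketch for the `temperature` argument now exposed on SVGD (ASVGD forwards its own `temperature` to the operator via `op_kwargs`, as shown above). The model and settings are illustrative, assuming this branch:

```python
import numpy as np
import pymc3 as pm

data = np.random.randn(200)

with pm.Model():
    mu = pm.Normal('mu', mu=0, sd=10)
    pm.Normal('x', mu=mu, sd=1, observed=data)
    # temperature > 1 trades some accuracy for extra exploration; the AKSD
    # warning below notes that temperature = 1 often underestimates posterior variance.
    svgd = pm.SVGD(n_particles=200, temperature=1.5)
    approx = svgd.fit(
        3000,
        obj_optimizer=pm.adagrad_window(learning_rate=0.002),
        total_grad_norm_constraint=10,
    )
```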
21 changes: 17 additions & 4 deletions pymc3/variational/operators.py
@@ -1,7 +1,7 @@
import warnings
from theano import theano, tensor as tt
from pymc3.variational.opvi import Operator, ObjectiveFunction, _warn_not_used
from pymc3.variational.stein import Stein
from pymc3.variational import updates
import pymc3 as pm

__all__ = [
@@ -63,7 +63,6 @@ def __call__(self, z, **kwargs):
grad *= pm.floatX(-1)
grad = theano.clone(grad, {op.input_matrix: z})
grad = tt.grad(None, params, known_grads={z: grad})
grad = updates.total_norm_constraint(grad, 10)
return grad


@@ -97,15 +96,29 @@ class KSD(Operator):
SUPPORT_AEVB = False
OBJECTIVE = KSDObjective

def __init__(self, approx):
def __init__(self, approx, temperature=1):
Operator.__init__(self, approx)
self.temperature = temperature
self.input_matrix = tt.matrix('KSD input matrix')

def apply(self, f):
# f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
stein = Stein(self.approx, f, self.input_matrix)
stein = Stein(
approx=self.approx,
kernel=f,
input_matrix=self.input_matrix,
temperature=self.temperature)
return pm.floatX(-1) * stein.grad


class AKSD(KSD):
def __init__(self, approx, temperature=1):
warnings.warn('You are using an experimental inference Operator. '
'It requires a careful choice of temperature; the default is 1, '
'which works well for low-dimensional problems and for a '
'significant `n_obj_mc`. Temperature > 1 gives the algorithm more '
'exploration power, while < 1 leads to undesirable results. Please '
'take this into account when interpreting the inference result: '
'posterior variance is often **underestimated** when using '
'temperature = 1.', stacklevel=2)
super(AKSD, self).__init__(approx, temperature)
SUPPORT_AEVB = True
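On the inference side, `ASVGD(temperature=...)` constructs the `AKSD` operator above through `op_kwargs`, which is what triggers this warning. A short illustrative sketch, assuming this branch:

```python
import numpy as np
import pymc3 as pm
from pymc3.variational.inference import ASVGD

data = np.random.randn(100)

with pm.Model():
    mu = pm.Normal('mu', mu=0, sd=10)
    pm.Normal('x', mu=mu, sd=1, observed=data)
    # Instantiating ASVGD builds AKSD(approx, temperature=1.5) and emits
    # the experimental-operator warning defined above.
    asvgd = ASVGD(temperature=1.5)
    approx = asvgd.fit(5000, total_grad_norm_constraint=10)
```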
51 changes: 35 additions & 16 deletions pymc3/variational/opvi.py
@@ -98,7 +98,8 @@ def random(self, size=None):
return self.op.approx.random(size)

def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
more_obj_params=None, more_tf_params=None, more_updates=None, more_replacements=None):
more_obj_params=None, more_tf_params=None, more_updates=None,
more_replacements=None, total_grad_norm_constraint=None):
"""Calculates gradients for objective function, test function and then
constructs updates for optimization step

@@ -120,27 +121,24 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
Add custom updates to resulting updates
more_replacements : `dict`
Apply custom replacements before calculating gradients
total_grad_norm_constraint : `float`
Bounds gradient norm, prevents exploding gradient problem

Returns
-------
:class:`ObjectiveUpdates`
"""
if more_obj_params is None:
more_obj_params = []
if more_tf_params is None:
more_tf_params = []
if more_updates is None:
more_updates = dict()
if more_replacements is None:
more_replacements = dict()
resulting_updates = ObjectiveUpdates()
if self.test_params:
self.add_test_updates(
resulting_updates,
tf_n_mc=tf_n_mc,
test_optimizer=test_optimizer,
more_tf_params=more_tf_params,
more_replacements=more_replacements
more_replacements=more_replacements,
total_grad_norm_constraint=total_grad_norm_constraint
)
else:
if tf_n_mc is not None:
@@ -152,30 +150,47 @@ def updates(self, obj_n_mc=None, tf_n_mc=None, obj_optimizer=adagrad_window, tes
obj_n_mc=obj_n_mc,
obj_optimizer=obj_optimizer,
more_obj_params=more_obj_params,
more_replacements=more_replacements
more_replacements=more_replacements,
total_grad_norm_constraint=total_grad_norm_constraint
)
resulting_updates.update(more_updates)
return resulting_updates

def add_test_updates(self, updates, tf_n_mc=None, test_optimizer=adagrad_window,
more_tf_params=None, more_replacements=None):
more_tf_params=None, more_replacements=None,
total_grad_norm_constraint=None):
if more_tf_params is None:
more_tf_params = []
if more_replacements is None:
more_replacements = dict()
tf_z = self.get_input(tf_n_mc)
tf_target = self(tf_z, more_tf_params=more_tf_params)
tf_target = theano.clone(tf_target, more_replacements, strict=False)
grads = pm.updates.get_or_compute_grads(tf_target, self.obj_params + more_tf_params)
if total_grad_norm_constraint is not None:
grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
updates.update(
test_optimizer(
tf_target,
grads,
self.test_params +
more_tf_params))

def add_obj_updates(self, updates, obj_n_mc=None, obj_optimizer=adagrad_window,
more_obj_params=None, more_replacements=None):
more_obj_params=None, more_replacements=None,
total_grad_norm_constraint=None):
if more_obj_params is None:
more_obj_params = []
if more_replacements is None:
more_replacements = dict()
obj_z = self.get_input(obj_n_mc)
obj_target = self(obj_z, more_obj_params=more_obj_params)
obj_target = theano.clone(obj_target, more_replacements, strict=False)
grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
if total_grad_norm_constraint is not None:
grads = pm.total_norm_constraint(grads, total_grad_norm_constraint)
updates.update(
obj_optimizer(
obj_target,
grads,
self.obj_params +
more_obj_params))
if self.op.RETURNS_LOSS:
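The clipping applied above via `pm.total_norm_constraint` is the usual Lasagne-style global-norm rescaling. A plain-NumPy sketch of the intended behaviour (not the symbolic Theano implementation):

```python
import numpy as np

def total_norm_constraint(grads, max_norm, eps=1e-7):
    """Rescale gradients so their joint L2 norm does not exceed max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(total_norm, max_norm) / (total_norm + eps)
    return [g * scale for g in grads]

grads = [np.array([30.0, 40.0]), np.array([0.0])]  # joint norm = 50
clipped = total_norm_constraint(grads, max_norm=10)
print(clipped[0])  # ~[6. 8.] -- rescaled so the joint norm is (just under) 10
```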
@@ -189,8 +204,9 @@ def get_input(self, n_mc):
def step_function(self, obj_n_mc=None, tf_n_mc=None,
obj_optimizer=adagrad_window, test_optimizer=adagrad_window,
more_obj_params=None, more_tf_params=None,
more_updates=None, more_replacements=None, score=False,
fn_kwargs=None):
more_updates=None, more_replacements=None,
total_grad_norm_constraint=None,
score=False, fn_kwargs=None):
R"""Step function that should be called on each optimization step.

Generally it solves the following problem:
@@ -215,6 +231,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
Add custom params for test function optimizer
more_updates : `dict`
Add custom updates to resulting updates
total_grad_norm_constraint : `float`
Bounds gradient norm, prevents exploding gradient problem
score : `bool`
calculate loss on each step? Defaults to False for speed
fn_kwargs : `dict`
@@ -236,7 +254,8 @@ def step_function(self, obj_n_mc=None, tf_n_mc=None,
more_obj_params=more_obj_params,
more_tf_params=more_tf_params,
more_updates=more_updates,
more_replacements=more_replacements)
more_replacements=more_replacements,
total_grad_norm_constraint=total_grad_norm_constraint)
if score:
step_fn = theano.function(
[], updates.loss, updates=updates, **fn_kwargs)
10 changes: 6 additions & 4 deletions pymc3/variational/stein.py
@@ -1,15 +1,16 @@
from theano import theano, tensor as tt
from pymc3.variational.test_functions import rbf
from pymc3.theanof import memoize
from pymc3.theanof import memoize, floatX

__all__ = [
'Stein'
]


class Stein(object):
def __init__(self, approx, kernel=rbf, input_matrix=None):
def __init__(self, approx, kernel=rbf, input_matrix=None, temperature=1):
self.approx = approx
self.temperature = floatX(temperature)
self._kernel_f = kernel
if input_matrix is None:
input_matrix = tt.matrix('stein_input_matrix')
@@ -22,8 +23,9 @@ def grad(self):
t = self.approx.normalizing_constant
Kxy, dxkxy = self.Kxy, self.dxkxy
dlogpdx = self.dlogp # Normalized
n = self.input_matrix.shape[0].astype('float32')
svgd_grad = (tt.dot(Kxy, dlogpdx) + dxkxy/t) / n
n = floatX(self.input_matrix.shape[0])
temperature = self.temperature
svgd_grad = (tt.dot(Kxy, dlogpdx)/temperature + dxkxy/t) / n
return svgd_grad

@property
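In symbols, the line `svgd_grad = (tt.dot(Kxy, dlogpdx)/temperature + dxkxy/t) / n` above computes, per particle x_i, roughly the following (a sketch: T is `temperature`, t is `approx.normalizing_constant`, and the log-density gradient is the already-normalized `dlogp` term; `KSD.apply` then flips the sign):

```latex
\phi(x_i) \;=\; \frac{1}{n}\left[
    \frac{1}{T}\sum_{j=1}^{n} k(x_j, x_i)\,\widetilde{\nabla}_{x_j}\log p(x_j)
    \;+\; \frac{1}{t}\sum_{j=1}^{n}\nabla_{x_j} k(x_j, x_i)
\right]
```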