Gradient-based Hyperparameter Optimization with Reversible Learning

Dougal Maclaurin, David Duvenaud, Ryan Adams

Motivation

• Hyperparameters are everywhere
  • sometimes hidden!
• Gradient-free optimization is hard
• Validation loss is a function of hyperparameters
• Why not take gradients?

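Optimizing optimization
$x_{\text{final}} = \mathrm{SGD}(x_{\text{init}}, \text{learn rate}, \text{momentum}, \nabla \mathrm{Loss}(x, \text{reg}, \text{Data}))$

[Figure: learning trajectories in the Weight 1 / Weight 2 plane starting from the initial weights, for meta-iterations 1, 2 and 3]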


A pretty scary function to differentiate
$J = \mathrm{Loss}(D_{\text{val}}, \mathrm{SGD}(x_{\text{init}}, \alpha, \beta, \nabla \mathrm{Loss}(D_{\text{train}}, x, \text{reg})))$

Stochastic Gradient Descent
1: input: initial $x_1$, decay $\beta$, learning rate $\alpha$, regularization params $\theta$, loss function $L(x, \theta, t)$
2: initialize $v_1 = 0$
3: for $t = 1$ to $T$ do
4:   $g_t = \nabla_x L(x_t, \theta, t)$ ▷ evaluate gradient
5:   $v_{t+1} = \beta_t v_t - (1 - \beta_t) g_t$ ▷ update velocity
6:   $x_{t+1} = x_t + \alpha_t v_t$ ▷ update position
7: output trained parameters $x_T$
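The SGD loop above can be sketched in NumPy as follows, assuming a hypothetical quadratic training loss. This sketch updates the position with the freshly updated velocity (the usual momentum convention), so it does not follow the slide's subscripts literally; `sgd_momentum` and `grad_L` are illustrative names, not part of any library.

```python
import numpy as np

def sgd_momentum(x_init, alphas, betas, grad_L):
    """Sketch of SGD with momentum: v <- beta_t v - (1 - beta_t) g_t, then x <- x + alpha_t v."""
    x = np.array(x_init, dtype=float)
    v = np.zeros_like(x)                         # initialize v = 0
    for t, (alpha, beta) in enumerate(zip(alphas, betas)):
        g = grad_L(x, t)                         # evaluate gradient
        v = beta * v - (1.0 - beta) * g          # update velocity
        x = x + alpha * v                        # update position with the new velocity
    return x

# Hypothetical training loss L(x, t) = 0.5 * ||x - target||^2, so grad_L = x - target
target = np.array([1.0, -2.0])
T = 500
x_final = sgd_momentum([0.0, 0.0], alphas=[0.5] * T, betas=[0.9] * T,
                       grad_L=lambda x, t: x - target)
print(x_final)  # very close to target
```

Every one of these T steps calls `grad_L`, which in a real network is a full forward and backward pass.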

• Each gradient evaluation in SGD requires a forward pass and backprop through the model
• The entire learning procedure looks like a 1000-layer deep net

Autograd: Automatic Differentiation

• github.com/HIPS/autograd
• Works with (almost) arbitrary Python/NumPy code
• Can take gradients of gradients (of gradients...)

Autograd Example

import autograd.numpy as np
import matplotlib.pyplot as plt
from autograd import grad

# Taylor approximation to sin function
def fun(x):
    curr = x
    ans = curr
    for i in range(1000):
        curr = - curr * x**2 / ((2*i+3)*(2*i+2))
        ans = ans + curr
        if np.abs(curr) < 0.2:
            break
    return ans

d_fun = grad(fun)     # first derivative
dd_fun = grad(d_fun)  # second derivative: the gradient of a gradient

x = np.linspace(-10, 10, 100)
plt.plot(x, [fun(xi) for xi in x],
         x, [d_fun(xi) for xi in x],
         x, [dd_fun(xi) for xi in x])
plt.show()

Most NumPy functions implemented
Complex & Fourier: imag, conjugate, angle, real_if_close, real, fabs, fft, fftshift, fft2, ifftn, ifftshift, ifft2, ifft
Array: atleast_1d, atleast_2d, atleast_3d, full, repeat, split, concatenate, roll, transpose, reshape, squeeze, ravel, expand_dims
Misc: logsumexp, where, einsum, sort, partition, clip, outer, dot, tensordot, rot90
Linear Algebra: inv, norm, det, eigh, solve, trace, diag, tril, triu
Stats: std, mean, var, prod, sum, cumsum

Technical Challenge: Memory

• Reverse-mode differentiation needs access to the entire learning trajectory
• e.g. $10^7$ parameters × $10^5$ training iterations
• Only need access in reverse order...
• Could we recompute the learning trajectory backwards by running SGD in reverse?
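To see why storing the trajectory is impractical, here is the arithmetic for the sizes quoted above, assuming one 32-bit float is stored per parameter per iteration (positions only; storing velocities as well would double it).

```python
# Back-of-the-envelope storage for the full learning trajectory.
# Assumption: one 32-bit float per parameter per iteration (positions only).
n_params = 10**7
n_iters = 10**5
bytes_per_value = 4  # float32; float64 would need 8

trajectory_bytes = n_params * n_iters * bytes_per_value
print(trajectory_bytes / 1e12, "TB")  # 4.0 TB
```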

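SGD with momentum is reversible
Forward update rule:
$x_{t+1} \leftarrow x_t + \alpha v_t$
$v_{t+1} \leftarrow \beta v_t - \nabla L(x_{t+1})$
Reverse update rule:
$v_t \leftarrow \big(v_{t+1} + \nabla L(x_{t+1})\big) / \beta$
$x_t \leftarrow x_{t+1} - \alpha v_t$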
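With exact arithmetic the reverse rule undoes the forward rule step for step. A quick check using Python's `fractions.Fraction` and a hypothetical 1-D quadratic training loss (the loss and the function names are illustrative):

```python
from fractions import Fraction

def grad_L(x):
    # Hypothetical training loss L(x) = 0.5 * (x - 3)^2, so grad L(x) = x - 3
    return x - 3

def forward(x, v, alpha, beta, steps):
    for _ in range(steps):
        x = x + alpha * v              # x_{t+1} <- x_t + alpha v_t
        v = beta * v - grad_L(x)       # v_{t+1} <- beta v_t - grad L(x_{t+1})
    return x, v

def reverse(x, v, alpha, beta, steps):
    for _ in range(steps):
        v = (v + grad_L(x)) / beta     # v_t <- (v_{t+1} + grad L(x_{t+1})) / beta
        x = x - alpha * v              # x_t <- x_{t+1} - alpha v_t
    return x, v

alpha, beta = Fraction(1, 10), Fraction(9, 10)
x0, v0 = Fraction(0), Fraction(1)
xT, vT = forward(x0, v0, alpha, beta, steps=50)
print(reverse(xT, vT, alpha, beta, steps=50) == (x0, v0))  # True: exact rational arithmetic
```

Exact rationals are far too slow for real networks (the denominators grow every step); they only show that the two rules are inverses.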

Reverse-mode differentiation of SGD

Stochastic Gradient Descent
1: input: initial $x_1$, decays $\beta$, learning rates $\alpha$, loss function $L(x, \theta, t)$
2: initialize $v_1 = 0$
3: for $t = 1$ to $T$ do
4:   $g_t = \nabla_x L(x_t, \theta, t)$ ▷ evaluate gradient
5:   $v_{t+1} = \beta_t v_t - (1 - \beta_t) g_t$ ▷ update velocity
6:   $x_{t+1} = x_t + \alpha_t v_t$ ▷ update position
7: output trained parameters $x_T$

Reverse-Mode Gradient of SGD
1: input: $x_T$, $v_T$, $\beta$, $\alpha$, train loss $L(x, \theta, t)$, loss $f(x)$
2: initialize $dv = 0$, $d\theta = 0$, $d\alpha_t = 0$, $d\beta = 0$
3: initialize $dx = \nabla_x f(x_T)$
4: for $t = T$ counting down to 1 do
5:   $d\alpha_t = dx^\top v_t$
6:   $x_{t-1} = x_t - \alpha_t v_t$ ▷ downdate position
7:   $g_t = \nabla_x L(x_t, \theta, t)$ ▷ evaluate gradient
8:   $v_{t-1} = [v_t + (1 - \beta_t) g_t] / \beta_t$ ▷ downdate velocity
9:   $dv = dv + \alpha_t\, dx$
10:  $d\beta_t = dv^\top (v_t + g_t)$
11:  $dx = dx - (1 - \beta_t)\, dv\, \nabla_x \nabla_x L(x_t, \theta, t)$
12:  $d\theta = d\theta - (1 - \beta_t)\, dv\, \nabla_\theta \nabla_x L(x_t, \theta, t)$
13:  $dv = \beta_t\, dv$
14: output gradient of $f(x_T)$ w.r.t. $x_1$, $v_1$, $\beta$, $\alpha$ and $\theta$

• Outputs gradients with respect to all hyperparameters
• Reversing SGD avoids storing the learning trajectory
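To make the reverse pass concrete, here is a small NumPy sketch for a hypothetical quadratic training loss with per-weight L2 penalties $e^{\theta}$, where the Hessian-vector product and the $\nabla_\theta \nabla_x L$ product have closed forms. It uses its own 0-based indexing (velocity updated first, then position) rather than the slide's subscripts, and `hypergrad`, `hvp_x` and `vjp_theta` are illustrative names. In a real network those two products would come from automatic differentiation instead of formulas.

```python
import numpy as np

# Hypothetical problem: L(x, theta) = 0.5 x^T A x - b^T x + 0.5 * sum(exp(theta) * x**2),
# validation loss f(x) = 0.5 * ||x - c||^2.
A = np.diag([1.0, 3.0])
b = np.array([1.0, 0.5])
c = np.array([0.5, 0.5])

def grad_L(x, theta):
    return A @ x - b + np.exp(theta) * x

def hvp_x(x, theta, u):
    # (grad_x grad_x L) u; the Hessian A + diag(exp(theta)) is symmetric
    return A @ u + np.exp(theta) * u

def vjp_theta(x, theta, u):
    # u^T (grad_theta grad_x L): d(grad_x L)_i / d theta_i = exp(theta_i) * x_i
    return np.exp(theta) * x * u

def forward(x0, alphas, beta, theta):
    x, v = x0.copy(), np.zeros_like(x0)
    for alpha in alphas:
        v = beta * v - (1 - beta) * grad_L(x, theta)   # update velocity
        x = x + alpha * v                              # update position
    return x, v

def val_loss(x0, alphas, beta, theta):
    x, _ = forward(x0, alphas, beta, theta)
    return 0.5 * np.sum((x - c) ** 2)

def hypergrad(x0, alphas, beta, theta):
    """Reverse pass that recomputes the trajectory backwards instead of storing it."""
    x, v = forward(x0, alphas, beta, theta)
    dx, dv = x - c, np.zeros_like(x)                   # dx = grad f(x_T)
    dalphas, dbeta, dtheta = np.zeros(len(alphas)), 0.0, np.zeros_like(theta)
    for t in reversed(range(len(alphas))):
        dalphas[t] = dx @ v
        x = x - alphas[t] * v                          # downdate position
        g = grad_L(x, theta)                           # gradient at the recovered x
        v = (v + (1 - beta) * g) / beta                # downdate velocity
        dv = dv + alphas[t] * dx
        dbeta += dv @ (v + g)
        dx = dx - (1 - beta) * hvp_x(x, theta, dv)
        dtheta = dtheta - (1 - beta) * vjp_theta(x, theta, dv)
        dv = beta * dv
    return x, dx, dalphas, dbeta, dtheta               # x is the recovered x_0

x0 = np.array([2.0, -1.0])
alphas = np.full(30, 0.1)
theta = np.log(np.array([0.1, 0.2]))
x_rec, dx0, dalphas, dbeta, dtheta = hypergrad(x0, alphas, 0.9, theta)
print(np.allclose(x_rec, x0))  # True: the reverse pass recovers the initial weights
```

A central finite-difference check against `val_loss` is a cheap way to validate a reverse pass like this before trusting it on a real model.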

Naive reversal ... Fails!

[Figure: a learning trajectory and its naive reversal, plotted in the Weight 1 / Weight 2 plane]
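A sketch of why naive reversal fails, assuming a hypothetical 1-D quadratic loss and ordinary float64 arithmetic. Over a few steps the reverse rule recovers the starting point up to rounding error. Over a long run the forward dynamics shrink the distance to the minimum below float resolution, the discarded low-order bits cannot be reconstructed, and the reversed run ends somewhere other than the true starting point.

```python
import math

def grad_L(x):
    # Hypothetical 1-D training loss L(x) = 0.5 * (x - 1)^2
    return x - 1.0

def forward(x, v, alpha, beta, steps):
    for _ in range(steps):
        x = x + alpha * v
        v = beta * v - grad_L(x)
    return x, v

def naive_reverse(x, v, alpha, beta, steps):
    # The exact reverse rule, but executed in ordinary float64 arithmetic
    for _ in range(steps):
        v = (v + grad_L(x)) / beta
        x = x - alpha * v
    return x, v

alpha, beta, x0, v0 = 0.1, 0.9, 5.0, 0.0
for steps in (10, 2000):
    x_rec, _ = naive_reverse(*forward(x0, v0, alpha, beta, steps), alpha, beta, steps)
    print(steps, x_rec)  # close to 5.0 after 10 steps, far from it after 2000
```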

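A closer look at reverse SGD
Forward update rule:
$x_{t+1} \leftarrow x_t + \alpha v_t$
$v_{t+1} \leftarrow \beta v_t - \nabla L(x_{t+1})$
Destroys $\log_2(1/\beta)$ bits per parameter per iteration
Reverse update rule:
$v_t \leftarrow \big(v_{t+1} + \nabla L(x_{t+1})\big) / \beta$
$x_t \leftarrow x_{t+1} - \alpha v_t$
Needs $\log_2(1/\beta)$ bits per parameter per iteration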
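For β = 0.9 the per-step information loss works out as follows. The comparison against storing one 32-bit float per parameter per iteration is a hypothetical baseline, included only to show that the ratio is of the same order as the savings quoted on the next slide.

```python
import math

beta = 0.9
bits_lost = math.log2(1 / beta)   # bits destroyed per parameter per iteration
print(round(bits_lost, 3))        # 0.152
# Hypothetical baseline: one 32-bit float stored per parameter per iteration
print(round(32 / bits_lost))      # 211
```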

How to store the lost bits?
• Switch to fixed-point arithmetic for exact addition
• Express β as a rational number
• Push/pop remainders from an information buffer

def rational_multiply(x, n, d, bitbuffer):
    # x <- x * n / d on integers: the remainder x % d is pushed onto the buffer
    # instead of being discarded, and the new low-order digit is popped from it
    bitbuffer.push(x % d, d)
    x //= d
    x *= n
    x += bitbuffer.pop(n)
    return x

• 200X memory savings when β = 0.9
• Now we have scalable gradients of hypers, only twice as slow as the original!
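A self-contained sketch of the whole scheme, assuming a hypothetical 1-D quadratic loss: positions and velocities are held as fixed-point integers (the scale `RADIX` is an arbitrary choice here), β = 9/10 is applied with `rational_multiply`, and the remainders go into a buffer implemented as a single arbitrary-precision Python integer. `BitBuffer`, `rational_divide`, `to_fixed`, `to_float` and `grad_fixed` are illustrative names; only `rational_multiply` appears on the slide. Because every step is exactly invertible, the reverse run returns the starting state bit for bit, and the buffer grows by at most about log2(10/9) ≈ 0.152 bits per step.

```python
RADIX = 2 ** 52  # hypothetical fixed-point scale: integer k represents k / RADIX

class BitBuffer:
    """Hypothetical information buffer: one big integer used as a mixed-radix stack."""
    def __init__(self):
        self.store = 0

    def push(self, r, base):
        assert 0 <= r < base
        self.store = self.store * base + r

    def pop(self, base):
        self.store, r = divmod(self.store, base)
        return r

def rational_multiply(x, n, d, bitbuffer):
    # x <- x * n / d, keeping the remainder x % d in the buffer
    bitbuffer.push(x % d, d)
    x //= d
    x *= n
    x += bitbuffer.pop(n)
    return x

def rational_divide(x, n, d, bitbuffer):
    # Exact inverse of rational_multiply (hypothetical helper)
    bitbuffer.push(x % n, n)
    x //= n
    x *= d
    x += bitbuffer.pop(d)
    return x

def to_fixed(a):
    return round(a * RADIX)

def to_float(k):
    return k / RADIX

def grad_fixed(x):
    # Hypothetical loss L(x) = 0.5 * (x - 1)^2; gradient computed in float, then rounded
    return to_fixed(to_float(x) - 1.0)

def forward(x, v, alpha, n, d, buf, steps):
    for _ in range(steps):
        x += to_fixed(alpha * to_float(v))   # exact integer addition
        v = rational_multiply(v, n, d, buf)  # v <- beta * v, remainder kept in buf
        v -= grad_fixed(x)                   # exact integer subtraction
    return x, v

def reverse(x, v, alpha, n, d, buf, steps):
    for _ in range(steps):                   # undo each forward operation in reverse order
        v += grad_fixed(x)
        v = rational_divide(v, n, d, buf)
        x -= to_fixed(alpha * to_float(v))
    return x, v

buf = BitBuffer()
x0, v0 = to_fixed(5.0), 0
xT, vT = forward(x0, v0, 0.1, 9, 10, buf, steps=2000)
print(buf.store.bit_length())  # roughly 300 bits for the whole 2000-step run
print(reverse(xT, vT, 0.1, 9, 10, buf, steps=2000) == (x0, v0))  # True
```

Python's floor division and modulo keep `x % d` in `[0, d)` even for negative `x`, which is what makes the push/pop pairs line up for negative velocities.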

Part 2: A Garden of Delights

Learning rate gradients

[Figure: two panels plotted against schedule index (0 to 100), one curve per layer (Layers 1 to 4): the learning rate, and the learning-rate gradient $\partial \mathrm{Loss}(D_{\text{val}}, x_{\text{init}}, \alpha, \beta, D_{\text{train}}, \text{reg}) / \partial \alpha$]

Optimized learning rates
• Used SGD to optimize SGD
• 4-layer NN on MNIST
• Top layer learns early on; slowdown at end

[Figure: optimized learning rate versus schedule index (0 to 100), one curve per layer (Layers 1 to 4)]
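The "SGD to optimize SGD" loop can be sketched as follows: compute the validation loss and its gradient with respect to every learning rate in the schedule by a reversed training run, then take a small gradient step on the schedule. Everything here (toy quadratic losses, `val_loss_and_grad`, `meta_lr`, a 20-step schedule) is a hypothetical stand-in for the MNIST setup, and plain gradient descent stands in for whatever meta-optimizer was actually used.

```python
import numpy as np

# Hypothetical toy setup: training loss 0.5 x^T A x - b^T x, validation loss 0.5 ||x - c||^2
A = np.diag([1.0, 3.0])
b = np.array([1.0, 0.5])
c = np.array([0.8, 0.2])
x0 = np.array([2.0, -1.0])
beta, T = 0.9, 20

def val_loss_and_grad(alphas):
    """Validation loss after T momentum-SGD steps and its gradient w.r.t. each alpha_t."""
    x, v = x0.copy(), np.zeros_like(x0)
    for alpha in alphas:                               # forward training run
        v = beta * v - (1 - beta) * (A @ x - b)
        x = x + alpha * v
    loss = 0.5 * np.sum((x - c) ** 2)
    dx, dv, dalphas = x - c, np.zeros_like(x), np.zeros(T)
    for t in reversed(range(T)):                       # reverse run, no stored trajectory
        dalphas[t] = dx @ v
        x = x - alphas[t] * v                          # downdate position
        v = (v + (1 - beta) * (A @ x - b)) / beta      # downdate velocity
        dv = dv + alphas[t] * dx
        dx = dx - (1 - beta) * (A @ dv)                # Hessian of the training loss is A
        dv = beta * dv
    return loss, dalphas

# Meta-optimization: small gradient steps on the whole learning-rate schedule
alphas = np.full(T, 0.05)
meta_lr = 0.001
initial_loss, _ = val_loss_and_grad(alphas)
for meta_iter in range(100):
    _, dalphas = val_loss_and_grad(alphas)
    alphas = alphas - meta_lr * dalphas
final_loss, _ = val_loss_and_grad(alphas)
print(initial_loss, final_loss)  # the validation loss should go down
```

Each meta-iteration costs one forward and one reverse training run, which is where the "only twice as slow" figure comes from in spirit; the schedule itself can have one entry per iteration (and, as on the slide, per layer) at no extra cost.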

Optimizing initialization scales
$\partial \mathrm{Loss}(D_{\text{val}}, x_{\text{init}}, \alpha, \beta, D_{\text{train}}, \text{reg}) / \partial x_{\text{init}}$

[Figure: initial scale versus meta iteration (0 to 50) for the biases and the weights of Layers 1 to 4; the weights panel marks the reference scales $1/\sqrt{50}$ and $1/\sqrt{784}$]

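Optimizing regularization
$\partial \mathrm{Loss}(D_{\text{val}}, x_{\text{init}}, \alpha, \beta, D_{\text{train}}, \text{reg}) / \partial\, \text{reg}$
Optimized L2 hypers for each weight in logistic regression:

[Figure: per-weight L2 regularization strengths, color scale 0 to 0.03]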

Optimizing architecture
• Architecture = tying weights or setting them to zero
• e.g. convnets, recurrent nets, multi-task learning
• Tying can be enforced by L2 regularization

[Figure panels: Original, Rotated]

• L2 regularization is differentiable
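One way to express tying as a differentiable L2 penalty, sketched under assumptions: K task-specific weight vectors `W[k]` and a hypothetical non-negative sharing matrix `S`, with penalty ½ Σ_ij S_ij ‖w_i − w_j‖². S = 0 gives separate networks, a large S_ij pulls w_i and w_j together (tied weights), and because the penalty is smooth in S its entries can be tuned by hypergradients like any other regularization hyperparameter. The slides do not give the exact parameterization used in the experiments.

```python
import numpy as np

def sharing_penalty(W, S):
    """0.5 * sum_ij S[i, j] * ||W[i] - W[j]||^2 and its gradients w.r.t. W and S.
    W: (K, D) weights, one row per task; S: (K, K) hypothetical sharing strengths."""
    diffs = W[:, None, :] - W[None, :, :]             # diffs[i, j] = w_i - w_j
    sq = np.sum(diffs ** 2, axis=-1)                  # squared distances, shape (K, K)
    penalty = 0.5 * np.sum(S * sq)
    dS = 0.5 * sq                                     # the penalty is linear in S
    dW = np.sum((S + S.T)[:, :, None] * diffs, axis=1)
    return penalty, dW, dS

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))
S = np.array([[0.0, 5.0, 0.0],
              [5.0, 0.0, 0.0],
              [0.0, 0.0, 0.0]])  # tasks 0 and 1 strongly tied, task 2 separate
penalty, dW, dS = sharing_penalty(W, S)
print(penalty, dW.sum(axis=0))  # gradient over tasks sums to zero
```

The gradient with respect to `W` sums to zero over tasks: the penalty only depends on differences between task weights, not on their common value.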

Optimizing regularization
Matrices enforce weight sharing between tasks

[Figure: input, middle and output weight-sharing matrices for separate networks, tied weights and learned sharing]

Method              Train error   Test error
Separate networks   0.61          1.34
Tied weights        0.90          1.25
Learned sharing     0.60          1.13

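Optimizing training data
$\partial \mathrm{Loss}(D_{\text{val}}, x_{\text{init}}, \alpha, \beta, D_{\text{train}}, \text{reg}) / \partial D_{\text{train}}$
• Training set of size 10 with fixed labels on MNIST
• Started from blank images

[Figure: learned training images, color scale −0.30 to 0.33]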

Limitations: Chaotic learning dynamics

[Figure: training loss and gradient, each plotted against learning rate (ticks at 1.0, 1.5, 2.0)]

Summary

• Can compute gradients of learning procedures
• Reversing learning saves memory
• Can optimize thousands of hyperparameters


Thanks!