Beyond Gradient Descent

For finding local minima of differentiable functions, the work-horse algorithm of machine learning is gradient descent.

Gradient descent is really a special case of a wider family of non-linear optimization algorithms, and recognizing this fact allows us to address constraints, consider alternative metrics, and be more explicit about the approximations we use and the sequence of subproblems we solve.

The most basic form of gradient descent, where we seek to minimize an objective function $f \colon \mathcal{X} \to \mathbb{R}$ over decision variable $x \in \mathcal{X}$, is

\begin{equation} \label{eqn-1}\tag{1} x_{t+1} = x_{t} - \eta \nabla f(x_{t}), \end{equation}

where $\eta > 0$ is some step-size that we must choose.

Eq. (1) is the solution that minimizes a local approximation of $f$ plus a regularization term $D$ that penalizes choices of $x$ that result in large updates. That is, gradient descent generates a sequence of candidate solutions $x_{t}$ by solving a sequence of subproblems of the form

\begin{equation} \label{eqn-2}\tag{2} x_{t+1} = \underset{x}{\rm argmin}\bigg[\hat{f}_{t}(x) + D_{t}(x) \bigg]. \end{equation}

Specifically, Eq. (1) solves

\begin{equation*} x_{t+1} = \underset{x}{\rm argmin}\bigg[\underbrace{f(x_{t}) + \big\langle \nabla f(x_{t}), x - x_{t} \big\rangle \vphantom{\frac{1}{\eta}}}_{\hat{f}_{t}(x)} + \underbrace{\frac{1}{2 \eta} \big\| x - x_{t}\big\|^{2}_{2}}_{D_{t}(x)} \bigg]. \end{equation*}
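As a minimal numerical check of this equivalence, the sketch below assumes a hypothetical quadratic objective (the names `A`, `b`, `grad_f`, `gd_step`, and `subproblem_grad` are illustrative only). It takes one step of Eq. (1) and confirms that the gradient of $\hat{f}_{t} + D_{t}$ vanishes at $x_{t+1}$; because the subproblem is strictly convex, this means $x_{t+1}$ is its unique minimizer.

```python
import numpy as np

# Hypothetical objective f(x) = 0.5 x^T A x - b^T x, used only for illustration.
A = np.array([[3.0, 1.0], [1.0, 2.0]])
b = np.array([1.0, -1.0])


def grad_f(x):
    return A @ x - b


def gd_step(x_t, eta):
    """Eq. (1): one vanilla gradient-descent step."""
    return x_t - eta * grad_f(x_t)


def subproblem_grad(x, x_t, eta):
    """Gradient of f_hat_t(x) + D_t(x) in Eq. (2), with the linear model
    f_hat_t and D_t(x) = ||x - x_t||^2 / (2 eta)."""
    return grad_f(x_t) + (x - x_t) / eta


x_t = np.array([1.0, 1.0])
eta = 0.1
x_next = gd_step(x_t, eta)
print(subproblem_grad(x_next, x_t, eta))  # approximately [0, 0]
```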

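There are a number of generalizations we can make to Eq. (2). For example, we may choose a different regularization term $D_{t}$ (i.e., a different geometry for measuring the size of an update), let $D_{t}$ depend on more than the most recent iterate, construct the local approximation $\hat{f}_{t}$ about a point other than $x_{t}$, or introduce constraints on $x$.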

In the following sections, I give examples of each of these generalizations.

Pseudo-Riemannian Geometry

Let us consider linear approximations of the form

\begin{equation*} \hat{f}_{t}(x) = f(x_{t}) + \big\langle \nabla f(x_{t}), x - x_{t} \big\rangle, \end{equation*}

and distances $D_{t}$ of the form

\begin{equation} \label{eqn-3}\tag{3} D_{t}(x) = \frac{1}{2} (x - x_{t})^{\top} G_{t} (x - x_{t}), \end{equation}

where $G_{t}$ is a (pseudo)metric on the tangent space of $x$.

Under these conditions, it follows that

\begin{equation} \label{eqn-4}\tag{4} x_{t+1} = x_{t} - G^{\dagger}_{t} \nabla f(x_{t}), \end{equation}

where $G^{\dagger}_{t}$ is the Moore-Penrose (pseudo)inverse of $G_{t}$; this follows from setting $\nabla(\hat{f}_{t} + D_{t}) = \nabla f(x_{t}) + G_{t}(x - x_{t}) = 0$. This update rule is commonly known as “preconditioned” gradient descent.
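The following sketch of Eq. (4) uses hypothetical values for $x_{t}$, $\nabla f(x_{t})$, and a diagonal metric $G_{t}$. It checks two things: that $G_{t} = I/\eta$ recovers Eq. (1), and that the preconditioned step satisfies the stationarity condition $\nabla f(x_{t}) + G_{t}(x_{t+1} - x_{t}) = 0$. The helper name `preconditioned_step` is illustrative.

```python
import numpy as np


def preconditioned_step(x_t, grad, G):
    """Eq. (4): x_{t+1} = x_t - pinv(G_t) @ grad f(x_t)."""
    return x_t - np.linalg.pinv(G) @ grad


# Hypothetical example values.
x_t = np.array([1.0, -2.0])
grad = np.array([0.5, 4.0])
eta = 0.1

# With G_t = I / eta, Eq. (3) becomes ||x - x_t||^2 / (2 eta), i.e. vanilla GD.
x_vanilla = preconditioned_step(x_t, grad, np.eye(2) / eta)

# A diagonal metric rescales each coordinate independently.
G = np.diag([1.0, 8.0])
x_next = preconditioned_step(x_t, grad, G)

# Stationarity of f_hat_t + D_t at the new iterate: grad + G (x_next - x_t) = 0.
residual = grad + G @ (x_next - x_t)
```

Note how the large entry of `G` damps the step along the second coordinate, so both coordinates move by the same amount here even though their gradient components differ by a factor of 8.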

Importantly, these conditions correspond to the limit of updates in continuous time for the true objective $f$ and any $D_{t}$ with positive-semidefinite Hessian, wherein

\begin{equation*} \nabla \bigg[f(x) + \dot{D}_{t}(x) \bigg]_{(x=x_{t})} = 0 \end{equation*}

(Amid and Warmuth (2020)). To see this, let

\begin{equation*} G_{t} = \frac{\partial^{2}}{\partial x^{2}} D_{t}(x) \quad \implies \quad \nabla \dot{D}_{t}(x) = \frac{\partial^{2}}{\partial x^{2}} D_{t}(x) \, \dot{x} = G_{t} \dot{x}. \end{equation*}

It follows that

\begin{align*} \dot{x}_{t} = - G_{t}^{\dagger} \nabla f(x_{t}). \end{align*}

Note that we use the dot (e.g., over $x_{t}$ and $D_{t}$) to represent a time-derivative.


If our objective function $f$ is convex (and twice differentiable), we are guaranteed that the Hessian of $f$, which we denote as $H$, is positive semi-definite. That is,

\begin{equation*} H_{t} = \frac{\partial^{2}}{\partial x^{2}} f(x_{t}) \succeq 0. \end{equation*}

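When we substitute $H_{t}$ for $G_{t}$ in Eq. (3), the resulting subproblem (Eq. (2)) is simply a minimization over a 2nd-order Taylor approximation of $f$ at $x_{t}$, and the resulting update is Newton's method. For least-squares objectives $f(x) = \frac{1}{2}\|r(x)\|_{2}^{2}$, a common alternative replaces $H_{t}$ with $J_{t}^{\top} J_{t}$, where $J_{t}$ is the Jacobian of the residuals $r$ at $x_{t}$; this matrix is positive semi-definite even when $f$ is not convex, and the resulting update is known as Gauss-Newton.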
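The sketch below illustrates two curvature-based choices of $G_{t}$ in Eq. (4), under stated assumptions. The first is Newton's method ($G_{t} = H_{t}$) on a hypothetical convex objective $f(x) = \sum_{i} (e^{x_{i}} - c_{i} x_{i})$, whose minimizer is $x = \log c$. The second is Gauss-Newton on a hypothetical noise-free curve fit $y = e^{k t}$; for a least-squares objective $f(x) = \frac{1}{2}\|r(x)\|_{2}^{2}$, Gauss-Newton uses $G_{t} = J_{t}^{\top} J_{t}$, with $J_{t}$ the Jacobian of the residuals. All function names are illustrative.

```python
import numpy as np


def newton(grad, hess, x, steps=20):
    """Eq. (4) with G_t = H_t, the Hessian of f at x_t."""
    for _ in range(steps):
        x = x - np.linalg.pinv(hess(x)) @ grad(x)
    return x


# Hypothetical convex objective f(x) = sum(exp(x) - c * x), minimized at log(c).
c = np.array([2.0, 0.5])
x_newton = newton(lambda x: np.exp(x) - c,
                  lambda x: np.diag(np.exp(x)),
                  np.zeros(2))


def gauss_newton(residual, jacobian, x, steps=50):
    """Eq. (4) for f(x) = 0.5 ||r(x)||^2 with G_t = J_t^T J_t (always PSD).
    The gradient of f is J_t^T r(x_t)."""
    for _ in range(steps):
        J = jacobian(x)
        x = x - np.linalg.pinv(J.T @ J) @ (J.T @ residual(x))
    return x


# Hypothetical curve fit: recover k = 0.5 in y = exp(k t) from noise-free data.
t = np.linspace(0.0, 2.0, 5)
y = np.exp(0.5 * t)
k_fit = gauss_newton(lambda k: np.exp(k[0] * t) - y,
                     lambda k: (t * np.exp(k[0] * t))[:, None],  # shape (5, 1)
                     np.array([0.3]))
```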


We may also choose values of $G$ that derive from spaces other than $\mathcal{X}$. For example, when $x$ parameterizes a probability distribution $\rho(y ; x)$, from which the objective function $f$ derives, we may measure the magnitude of an update from $x_{t}$ to $x$ by the relative entropy from $\rho_{t} = \rho(y ; x_{t})$ to $\rho = \rho(y ; x)$.

\begin{equation*} D_{t}(x) = \frac{1}{\eta} D_{\rm KL}\Big(\rho(y ; x) \parallel \rho(y ; x_{t})\Big) = \frac{1}{\eta} \int_{\mathcal{Y}} \rho(y ; x) \log \frac{\rho(y; x)}{\rho(y ; x_{t})} {\rm d}y. \end{equation*}

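To second order about $x_{t}$, this penalty takes the form of Eq. (3) with $G_{t} = F_{t} / \eta$, and this choice yields the update rule for Fisher-Rao natural gradient descent: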

\begin{equation} \label{eqn-5}\tag{5} x_{t+1} = x_{t} - \eta F^{\dagger}_{t} \nabla f(x_{t}), \end{equation}

where $F_{t}$ is the Fisher metric tensor

\begin{equation*} F_{t} = -\int_{\mathcal{Y}} \rho(y ; x_{t}) \frac{\partial^{2}}{\partial x^{2}} \log \rho(y ; x) \Big|_{x = x_{t}} {\rm d}y. \end{equation*}

One way to understand the effect of the metric $G = F$ is that it (un)warps the space of marginal updates ${\rm d}x$, such that updates of the same induced magnitude contain the same amount of marginal information about $\rho(x_{t} + {\rm d}x)$.
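To make the Fisher metric concrete, the sketch below assumes a one-dimensional Gaussian $\rho(y; x) = \mathcal{N}(y; m, e^{2s})$ with parameters $x = (m, s)$, for which $F = \mathrm{diag}(e^{-2s}, 2)$. It first checks numerically that the KL divergence between nearby parameters is approximately $\frac{1}{2}\,{\rm d}x^{\top} F\, {\rm d}x$. It then runs Eq. (5) on a hypothetical objective: the expected negative log-likelihood of data with assumed mean `data_mean` and variance `data_var`. With $\eta = 1$, the mean is corrected in a single step.

```python
import numpy as np


def kl_gauss(m1, s1, m0, s0):
    """KL( N(m1, e^{2 s1}) || N(m0, e^{2 s0}) ) in closed form."""
    return s0 - s1 + (np.exp(2 * s1) + (m1 - m0) ** 2) / (2 * np.exp(2 * s0)) - 0.5


def fisher(m, s):
    """Fisher metric of the Gaussian in (mean, log-std) coordinates."""
    return np.diag([np.exp(-2 * s), 2.0])


# Second-order check: KL ~ 0.5 dx^T F dx for a small update dx.
m0, s0 = 0.3, -0.2
dx = np.array([1e-3, -2e-3])
kl = kl_gauss(m0 + dx[0], s0 + dx[1], m0, s0)
quad = 0.5 * dx @ fisher(m0, s0) @ dx

# Hypothetical data statistics defining f(m, s) = E[-log rho(y; m, s)] + const.
data_mean, data_var = 3.0, 4.0


def grad_f(m, s):
    inv_var = np.exp(-2 * s)
    return np.array([(m - data_mean) * inv_var,
                     1.0 - (data_var + (m - data_mean) ** 2) * inv_var])


def natural_gradient_descent(m, s, eta=1.0, steps=100):
    """Eq. (5): x_{t+1} = x_t - eta * pinv(F_t) @ grad f(x_t)."""
    x = np.array([m, s])
    for _ in range(steps):
        x = x - eta * np.linalg.pinv(fisher(*x)) @ grad_f(*x)
    return x


x_opt = natural_gradient_descent(0.0, 0.0)  # approaches (data_mean, log(sqrt(data_var)))
```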

In continuous time, Fisher-Rao natural gradient descent optimally approximates replicator dynamics (used to model evolution by natural selection) and continuous Bayesian inference (Raab et al. (2022)). We may also choose different divergences in $\rho$ space, yielding other forms of “natural gradient descent” (Nurbekyan et al. (2022)).

Projective Geometry

Vanilla gradient descent does not require a Euclidean metric.

As a counterexample, let $G = \frac{\nabla f ~ \nabla^{\top} f}{\nabla^{\top} f \cdot \nabla f}$, where the numerator is an outer product and the denominator an inner product of $\nabla f$ with itself.

Since $G$ is the orthogonal projector onto the span of $\nabla f$, it follows that $G^{\dagger} = G$ and $G^{\dagger} \nabla f = \nabla f$. Therefore, taking $D_{t}(x) = \frac{1}{2\eta}(x - x_{t})^{\top} G (x - x_{t})$, the minimum-norm update that solves $\nabla(\hat{f}_{t} + D_{t}) = 0$ (Eq. (2)) is vanilla gradient descent: $x_{t+1} - x_{t} = -\eta G^{\dagger} \nabla f(x_{t}) = -\eta \nabla f(x_{t})$.
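A short numerical sketch of this construction, assuming a hypothetical gradient vector `grad` at $x_{t}$: it forms the rank-one matrix $G$, checks that its pseudoinverse equals itself and leaves $\nabla f$ unchanged, and confirms that the resulting step coincides with vanilla gradient descent. An explicit `rcond` cutoff is passed to `np.linalg.pinv` so that round-off in the two zero singular values of $G$ is discarded.

```python
import numpy as np

grad = np.array([3.0, -1.0, 2.0])          # hypothetical gradient at x_t
G = np.outer(grad, grad) / (grad @ grad)   # rank-one projector onto span(grad)
G_pinv = np.linalg.pinv(G, rcond=1e-10)    # discard round-off singular values

eta = 0.1
x_t = np.zeros(3)
# Step from Eq. (2) with D_t(x) = (x - x_t)^T G (x - x_t) / (2 eta).
x_next = x_t - eta * G_pinv @ grad
```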

Multi-iterate Methods

So far, we have only discussed updates that rely on at most one previous iterate $x_{t}$ (or, in the constrained setting discussed below, one previous iterate of the dual variables $\lambda_{t}$ or $\mu_{t}$). We may also choose update penalties $D_{t}$ that rely on additional history of the solution candidates $(x_{t}, x_{t-1}, \dots)$.


An example of a multi-iterate approach is given by the classical “Momentum” variant of gradient descent. Consider, for example, a penalty term $D_{t}$ that quadratically penalizes step sizes and includes a linear term that encourages steps in the same direction as the previous step, as in

\begin{equation*} D_{t}(x) = \frac{1}{\eta} \bigg( \frac{1}{2} \big\| x - x_{t} \big\|_{2}^{2} - \xi \big\langle x - x_{t}, x_{t} - x_{t-1} \big\rangle \bigg), \end{equation*}

where $\xi \in (0, 1)$. Alternatively, consider a quadratic penalty for the difference between the proposed update and the previous update with decay $\xi$, as in

\begin{equation*} D_{t}(x) = \frac{1}{2 \eta} \big\| (x - x_{t}) - \xi (x_{t} - x_{t-1}) \big\|_{2}^{2}. \end{equation*}

In either case,

\begin{equation*} \nabla D_{t}(x) = \frac{1}{\eta} \bigg((x - x_{t}) - \xi (x_{t} - x_{t-1})\bigg). \end{equation*}

If we solve Eq. (2) by setting $\nabla (\hat{f}_{t} + D_{t}) = 0$ for either choice of $D_{t}$, the resulting update rule is

\begin{equation*} x_{t+1} = x_{t} + \xi(x_{t} - x_{t-1}) - \eta \nabla \hat{f}_{t}, \end{equation*}

which can be decomposed, with an additional state variable $v_{t} = x_{t} - x_{t-1}$, into

\begin{equation} \label{eqn-6}\tag{6} \begin{aligned} v_{t+1} &= \xi v_{t} - \eta \nabla \hat{f}_{t}, \\ x_{t+1} &= x_{t} + v_{t+1}. \end{aligned} \end{equation}

This update rule is the classical “Momentum” algorithm (Botev et al. (2016)) with decay parameterized by $\xi$.
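As a sanity check that the two-variable form of Eq. (6) is the same algorithm as the single-variable recurrence $x_{t+1} = x_{t} + \xi(x_{t} - x_{t-1}) - \eta \nabla f(x_{t})$, the sketch below runs both on a hypothetical ill-conditioned quadratic and compares their trajectories. Starting from $v_{0} = 0$ corresponds to $x_{-1} = x_{0}$. The function names and constants are illustrative.

```python
import numpy as np


def momentum(grad_f, x0, eta, xi, steps):
    """Eq. (6): v_{t+1} = xi v_t - eta grad f(x_t);  x_{t+1} = x_t + v_{t+1}."""
    x, v = x0.copy(), np.zeros_like(x0)
    path = [x.copy()]
    for _ in range(steps):
        v = xi * v - eta * grad_f(x)
        x = x + v
        path.append(x.copy())
    return np.array(path)


def heavy_ball(grad_f, x0, eta, xi, steps):
    """Single-variable form: x_{t+1} = x_t + xi (x_t - x_{t-1}) - eta grad f(x_t)."""
    x_prev, x = x0.copy(), x0.copy()
    path = [x.copy()]
    for _ in range(steps):
        x_prev, x = x, x + xi * (x - x_prev) - eta * grad_f(x)
        path.append(x.copy())
    return np.array(path)


# Hypothetical quadratic f(x) = 0.5 x^T A x with condition number 10.
A = np.diag([1.0, 10.0])
grad_f = lambda x: A @ x
x0 = np.array([1.0, 1.0])
p1 = momentum(grad_f, x0, eta=0.05, xi=0.9, steps=300)
p2 = heavy_ball(grad_f, x0, eta=0.05, xi=0.9, steps=300)
```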

Nth-Order Regularization

There is yet another update penalty $D_{t}$ that may be used to derive the Momentum update rule, up to a reparameterization of constants $(\xi_{1}, \xi_{2}) \mapsto \xi$. Specifically, let

\begin{equation*} D_{t}(x) = \frac{\xi_{1}}{2\eta} \big\| x - x_{t} \big\|^{2} + \frac{\xi_{2}}{2\eta} \big\| \underbrace{(x - x_{t})}_{v} - \underbrace{(x_{t} - x_{t-1})}_{v_{t}} \big\|^{2}. \end{equation*}
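Setting $\nabla(\hat{f}_{t} + D_{t}) = 0$ with the linear model $\hat{f}_{t}$ makes this reparameterization explicit. The worked derivation below (a sketch, writing $v = x - x_{t}$ and $v_{t} = x_{t} - x_{t-1}$) recovers Eq. (6) with decay $\xi = \xi_{2}/(\xi_{1} + \xi_{2})$ and effective step size $\eta/(\xi_{1} + \xi_{2})$.

```latex
\begin{aligned}
\nabla\big(\hat{f}_{t} + D_{t}\big)(x)
  &= \nabla f(x_{t}) + \frac{\xi_{1}}{\eta}\, v + \frac{\xi_{2}}{\eta}\big(v - v_{t}\big) = 0 \\
\implies\quad v_{t+1}
  &= \frac{\xi_{2}}{\xi_{1} + \xi_{2}}\, v_{t} - \frac{\eta}{\xi_{1} + \xi_{2}}\, \nabla f(x_{t}),
  \qquad x_{t+1} = x_{t} + v_{t+1}.
\end{aligned}
```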

We can naturally extend this distance penalty to account for higher-order derivatives as

\begin{equation*} D_{t}(x) = \frac{\xi_{1}}{2\eta} \big\|x - x_{t} \big\|^{2} + \frac{\xi_{2}}{2\eta} \big\| v - v_{t} \big\|^{2} + \dots + \frac{\xi_{n}}{2\eta} \big\| x^{(n)} \big\|^{2}, \end{equation*}

where we use the state variable $x^{(n)}$ to represent the $n$th time-derivative of $x$, defined empirically (with $x^{(0)} = x$, so that $x^{(1)} = v$) by the recursive formula

\begin{equation*} x^{(n)} = x^{(n-1)} - x_{t}^{(n-1)}. \end{equation*}

By this choice of $D_{t}$, if we solve Eq. (2) by setting $\nabla (\hat{f}_{t} + D_{t}) = 0$, we obtain the update rules

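\begin{equation} \label{eqn-7}\tag{7} \begin{aligned} x^{(n-1)}_{t+1} &= \frac{1}{\xi_{n} + \xi_{n-1}} \bigg( \xi_{n} x^{(n-1)}_{t} - \sum_{k=1}^{n-2} \xi_{k} x^{(k)} - \eta \nabla \hat{f}_{t} \bigg), \\ x^{(n-2)}_{t+1} &= x^{(n-2)}_{t} + x^{(n-1)}_{t+1}, \\ &\vdots \\ x_{t+1} &= x_{t} + v_{t+1}. \end{aligned} \end{equation}

Here the $x^{(k)}$ inside the sum are the new values $x^{(k)}_{t+1}$, which are themselves determined by $x^{(n-1)}_{t+1}$ through the subsequent lines, so the first line is implicit in $x^{(n-1)}_{t+1}$. For $n = 2$ the sum is empty, and Eq. (7) reduces to the Momentum update (Eq. (6)) up to the reparameterization of constants noted above.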
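Eq. (7) can also be solved explicitly. Since each $x^{(k)}$ depends on $x$ with identity Jacobian, stationarity reads $\sum_{k=1}^{n} \xi_{k} x^{(k)} = -\eta \nabla f(x_{t})$; substituting $x^{(k)}_{t+1} = x^{(n-1)}_{t+1} + \sum_{j=k}^{n-2} x^{(j)}_{t}$ (which follows from the recursion) leaves one linear equation in $x^{(n-1)}_{t+1}$. The sketch below implements this for $n \ge 2$; the function name `nth_order_step` and its argument layout are hypothetical.

```python
import numpy as np


def nth_order_step(x, diffs, grad, xis, eta):
    """One step of the Nth-order-regularized update (Eq. (7)), solved explicitly.

    x     : current iterate x_t
    diffs : [x_t^{(1)}, ..., x_t^{(n-1)}], the stored finite differences
    grad  : gradient of the linear model, i.e. grad f(x_t)
    xis   : weights [xi_1, ..., xi_n]
    Returns (x_{t+1}, [x_{t+1}^{(1)}, ..., x_{t+1}^{(n-1)}]).
    """
    n = len(xis)
    if n < 2 or len(diffs) != n - 1:
        raise ValueError("need n >= 2 weights and n - 1 stored differences")
    # c[k-1] = sum_{j=k}^{n-2} x_t^{(j)}, so that x_{t+1}^{(k)} = y + c[k-1],
    # where y = x_{t+1}^{(n-1)} is the unknown (c is zero for k = n-1).
    c = [sum((diffs[j - 1] for j in range(k, n - 1)), np.zeros_like(x))
         for k in range(1, n)]
    # Stationarity: sum_{k<n} xi_k (y + c_k) + xi_n (y - x_t^{(n-1)}) = -eta grad.
    rhs = xis[-1] * diffs[-1] - eta * grad
    for k in range(1, n):
        rhs = rhs - xis[k - 1] * c[k - 1]
    y = rhs / sum(xis)
    new_diffs = [y + c[k - 1] for k in range(1, n)]
    return x + new_diffs[0], new_diffs
```

For $n = 2$ this returns $x^{(1)}_{t+1} = \frac{\xi_{2}}{\xi_{1}+\xi_{2}} v_{t} - \frac{\eta}{\xi_{1}+\xi_{2}} \nabla f(x_{t})$, the Momentum step.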

Surrogate Approximation

As $x$ varies from $x_{t}$ to $x_{t+1}$, the average error of a local approximation of $f(x)$ along this path can be reduced by choosing an intermediate point $\tilde{x}_{t}$, somewhere between $x_{t}$ and $x_{t+1}$, instead of $x_{t}$ as the point about which to approximate the local behavior of $f$.

Nesterov Acceleration

When using Eq. (6), one “zero-cost” choice of $\tilde{x}_{t}$ (zero-cost in the sense that it requires no additional gradient evaluations, provided we decouple the candidate solutions $x_{t}$ from the points at which we query $\hat{f}$) is given by

\begin{equation*} \tilde{x}_{t} = x_{t} + \xi v_{t}, \end{equation*}

such that

\begin{equation} \label{eqn-8}\tag{8} \hat{f}_{t}(x) = f(\tilde{x}_{t}) + \big\langle \nabla f(\tilde{x}_{t}), x - \tilde{x}_{t} \big\rangle. \end{equation}

Substituting Eq. (8) into Eq. (6), we obtain Nesterov’s variant of gradient descent with momentum:

\begin{equation*} \begin{aligned} v_{t+1} &= \xi v_{t} - \eta \nabla f(x_{t} + \xi v_{t}), \\ x_{t+1} &= x_{t} + v_{t+1}. \end{aligned} \end{equation*}
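A minimal sketch of Nesterov's variant, differing from the Momentum loop only in where the gradient is queried. The quadratic objective, step size, and decay are hypothetical choices for illustration; with $\xi = 0$ a single step reduces to Eq. (1).

```python
import numpy as np


def nesterov(grad_f, x0, eta, xi, steps):
    """Eq. (6) with the linear model taken at x~_t = x_t + xi v_t (Eq. (8))."""
    x, v = x0.copy(), np.zeros_like(x0)
    for _ in range(steps):
        x_tilde = x + xi * v                  # look-ahead point x~_t
        v = xi * v - eta * grad_f(x_tilde)    # gradient queried at x~_t, not x_t
        x = x + v
    return x


# Hypothetical quadratic f(x) = 0.5 x^T A x, minimized at the origin.
A = np.diag([1.0, 10.0])
grad_f = lambda x: A @ x
x0 = np.array([1.0, 1.0])
x_final = nesterov(grad_f, x0, eta=0.1, xi=0.9, steps=300)
```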

Dealing with Constraints

Consider the basic (constrained) optimization problem

\begin{equation} \begin{array}{lll} {\rm minimize} &f(x) \\ {\rm subject~to} &g(x) \preceq 0, \\ &h(x) = 0, \end{array} \end{equation}

for vector-valued gg and hh.

In terms of the corresponding Lagrangian $\mathcal{L}$, the primal problem may be written

\begin{equation} \label{eqn-10}\tag{10} \begin{array}{lll} \underset{x}{\rm minimize} \bigg( \underset{\lambda, \mu}{\rm max}& \mathcal{L}(x, \lambda, \mu) \bigg) \\ {\rm subject~to} & \lambda \succeq 0, \end{array} \end{equation}


where

\begin{equation} \mathcal{L}(x, \lambda, \mu) = f(x) + \lambda^{\top} g(x) + \mu^{\top} h(x). \end{equation}

Intuitively, given that $\lambda$ and $\mu$ are chosen adversarially (i.e., after $x$ is fixed), the problem is to choose $x$ such that the constraints on $g(x)$ and $h(x)$ are satisfied (otherwise, the objective can be made unboundedly positive by an adversary’s choice of $\lambda, \mu$). Within this “feasible set” of $x$ values, $f$ should be minimized.

Augmented Lagrangian

We may iteratively approximate Eq. (10) with a sequence of subproblems in the form of Eq. (2):

\begin{equation*} x_{t+1} = \underset{x}{\rm argmin}\bigg[ \underset{\lambda \succeq 0, \mu}{\rm max} ~ \bigg( \hat{\mathcal{L}}_{t}(x, \lambda, \mu) + D_{t}(x) - R_{t}(\lambda) - S_{t}(\mu) \bigg) \bigg], \end{equation*}

introducing update penalties for the state variables $x$, $\lambda$, and $\mu$.

This approach is known as an “Augmented Lagrangian Method” (Nocedal & Wright (2006); ch. 17).


For the simple choices

\begin{align*} D_{t}(x) &= \frac{1}{2\eta}\big\|x - x_{t}\big\|_{2}^{2}, \quad \eta > 0, \\ R_{t}(\lambda) &= \frac{1}{2\alpha}\big\|\lambda - \lambda_{t}\big\|_{2}^{2}, \quad \alpha > 0, \\ S_{t}(\mu) &= \frac{1}{2\beta}\big\|\mu - \mu_{t}\big\|_{2}^{2}, \quad \beta > 0, \end{align*}

and linear local approximations

\begin{align*} \hat{f}_{t}(x) &= f(x_{t}) + \big\langle \nabla f(x_{t}), x - x_{t} \big\rangle, \\ \hat{g}_{t}(x) &= g(x_{t}) + \big\langle \nabla g(x_{t}), x - x_{t} \big\rangle, \\ \hat{h}_{t}(x) &= h(x_{t}) + \big\langle \nabla h(x_{t}), x - x_{t} \big\rangle, \\ \hat{\mathcal{L}}_{t}(x, \lambda, \mu) &= \hat{f}_{t}(x) + \lambda^{\top} \hat{g}_{t}(x) + \mu^{\top} \hat{h}_{t}(x), \end{align*}

we have the update rules

\begin{equation} \label{eqn-12}\tag{12} \begin{aligned} x_{t+1} &= x_{t} - \eta \bigg( \nabla f(x_{t}) + \lambda_{t+1}^{\top} \nabla g(x_{t}) + \mu^{\top}_{t+1} \nabla h(x_{t})\bigg), \\ \lambda_{t+1} &= \max\bigg( \lambda_{t} + \alpha \hat{g}_{t}(x_{t+1}), 0\bigg), \\ \mu_{t+1} &= \mu_{t} + \beta \hat{h}_{t}(x_{t+1}), \end{aligned} \end{equation}

where $\max$ is taken element-wise and the gradients are taken with respect to $x$ (and these gradient components remain uncontracted with the dimensions of $\lambda$ or $\mu$).

The solution $(x_{t+1}, \lambda_{t+1}, \mu_{t+1})$ to Eq. (12) is a fixed point of a recursive map $x_{t+1} \mapsto (\lambda_{t+1}, \mu_{t+1}) \mapsto x_{t+1}$. Denoting iterations of this map by the upper index $k$, by change of variables $v_{t+1} = x_{t+1} - x_{t}$, we define the piece-wise linear map

\begin{equation} \label{eqn-13}\tag{13} v^{k+1}_{t+1} \leftarrow -\eta \left( \begin{aligned} \nabla f(x_{t}) &~+ \\ \max\Big(0, \lambda_{t} + \alpha \big( g(x_{t}) + \langle \nabla g(x_{t}), v^{k}_{t+1} \rangle \big)\Big) \nabla g(x_{t}) &~+ \\ \Big( \mu_{t} + \beta \big( h(x_{t}) + \langle \nabla h(x_{t}), v^{k}_{t+1} \rangle \big)\Big) \nabla h(x_{t})& \end{aligned} \right) \end{equation}

with fixed point $v_{t+1}^{\star}$, such that

\begin{align*} \lambda_{t+1} &= \max\Big( \lambda_{t} + \alpha \big( g(x_{t}) + \langle \nabla g(x_{t}), v_{t+1}^{\star} \rangle \big), 0 \Big); \\ \mu_{t+1} &= \mu_{t} + \beta \big( h(x_{t}) + \langle \nabla h(x_{t}), v_{t+1}^{\star} \rangle \big). \end{align*}

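Claim: The recursive mapping Eq. (13) converges, upon repeated iteration, when $\alpha$ and $\beta$ are bounded as

\begin{equation*} \eta \alpha \big\|\nabla g(x_{t})\big\|^{2} + \eta \beta \big\|\nabla h(x_{t})\big\|^{2} < 1, \end{equation*}

where, for vector-valued constraints, $\|\cdot\|$ denotes the spectral norm of the Jacobian. Since $\max(0, \cdot)$ is 1-Lipschitz, the left-hand side bounds the Lipschitz constant of the map, which is then a contraction with a unique fixed point.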

The updates of Eq. (12) comprise a primal-dual algorithm, which maintains iterates for the dual variables $\lambda$ and $\mu$ in addition to the primal variable $x$. This basic approach may be extended by substituting different regularization terms or local function approximations.
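To see Eq. (12) and Eq. (13) in action, the sketch below runs the primal-dual updates on a hypothetical problem: minimize $\|x - a\|_{2}^{2}$ with $a = (2, 0)$, subject to $g(x) = -x_{2} \le 0$ and $h(x) = x_{1} + x_{2} - 1 = 0$. Its KKT point is $x^{\star} = (1, 0)$ with $\lambda^{\star} = \mu^{\star} = 2$. Each outer step iterates the map of Eq. (13) from $v = 0$ to its fixed point. The constants $\eta = 0.1$ and $\alpha = \beta = 1$ give $\eta\alpha\|\nabla g\|^{2} + \eta\beta\|\nabla h\|^{2} = 0.3 < 1$, so the inner map is a contraction because $\max(0, \cdot)$ is 1-Lipschitz. Since $g$ and $h$ are linear here, their local approximations are exact. All names are illustrative.

```python
import numpy as np


def fixed_point_step(grad_f, g_val, jac_g, h_val, jac_h, lam, mu,
                     eta, alpha, beta, tol=1e-14, max_iter=500):
    """Iterate the map of Eq. (13) from v = 0; return (v*, lambda_{t+1}, mu_{t+1}).

    g_val, h_val are constraint values at x_t; jac_g, jac_h are their Jacobians
    (one row per constraint), so jac_g @ v plays the role of <grad g(x_t), v>.
    """
    def duals(v):
        lam_new = np.maximum(0.0, lam + alpha * (g_val + jac_g @ v))
        mu_new = mu + beta * (h_val + jac_h @ v)
        return lam_new, mu_new

    v = np.zeros_like(grad_f)
    for _ in range(max_iter):
        lam_new, mu_new = duals(v)
        v_next = -eta * (grad_f + jac_g.T @ lam_new + jac_h.T @ mu_new)
        done = np.linalg.norm(v_next - v) < tol
        v = v_next
        if done:
            break
    lam_new, mu_new = duals(v)  # multipliers consistent with the fixed point
    return v, lam_new, mu_new


# Hypothetical problem: minimize ||x - a||^2  s.t.  -x_2 <= 0,  x_1 + x_2 - 1 = 0.
a = np.array([2.0, 0.0])
jac_g = np.array([[0.0, -1.0]])
jac_h = np.array([[1.0, 1.0]])


def grad_f(x):
    return 2.0 * (x - a)


def g(x):
    return jac_g @ x            # g(x) = -x_2


def h(x):
    return jac_h @ x - 1.0      # h(x) = x_1 + x_2 - 1


def primal_dual(x, lam, mu, eta=0.1, alpha=1.0, beta=1.0, steps=1000):
    """Outer loop of Eq. (12): x_{t+1} = x_t + v*, with updated multipliers."""
    for _ in range(steps):
        v, lam, mu = fixed_point_step(grad_f(x), g(x), jac_g, h(x), jac_h,
                                      lam, mu, eta, alpha, beta)
        x = x + v
    return x, lam, mu


x_star, lam_star, mu_star = primal_dual(np.zeros(2), np.zeros(1), np.zeros(1))
```

At a fixed point of the outer loop, $v^{\star} = 0$ forces $h(x) = 0$, complementary slackness for $\lambda$, and $\nabla f + \nabla g^{\top}\lambda + \nabla h^{\top}\mu = 0$, i.e., the KKT conditions.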