
Beyond Gradient Descent

For finding local minima of differentiable functions, the work-horse algorithm of machine learning is gradient descent.

Gradient descent is really a special case of a wider family of non-linear optimization algorithms, and recognizing this fact allows us to address constraints, consider alternative metrics, and be more explicit about the approximations we use and the sequence of subproblems we solve.

The most basic form of gradient descent, where we seek to minimize an objective function $f \colon \mathcal{X} \to \mathbb{R}$ over a decision variable $x \in \mathcal{X}$, is

\begin{equation} \label{eqn-1}\tag{1} x_{t+1} = x_{t} - \eta \nabla f(x_{t}), \end{equation}

where $\eta > 0$ is some step-size that we must choose.
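
For concreteness, here is a minimal sketch of Eq. (1) in NumPy; the quadratic objective is purely illustrative.

```python
import numpy as np

def gradient_descent(grad_f, x0, eta=0.1, num_steps=100):
    """Iterate Eq. (1): x_{t+1} = x_t - eta * grad f(x_t)."""
    x = np.asarray(x0, dtype=float)
    for _ in range(num_steps):
        x = x - eta * grad_f(x)
    return x

# Example: minimize f(x) = ||x - b||^2 / 2, whose gradient is (x - b).
b = np.array([1.0, -2.0])
x_star = gradient_descent(lambda x: x - b, x0=np.zeros(2))
```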

The update in Eq. (1) is the solution to a subproblem that minimizes a local approximation of $f$ plus a regularization term $D$ that penalizes large updates. That is, gradient descent generates a sequence of candidate solutions $x_{t}$ by solving a sequence of subproblems of the form

\begin{equation} \label{eqn-2}\tag{2} x_{t+1} = \underset{x}{\rm argmin}\bigg[\hat{f}_{t}(x) + D_{t}(x) \bigg]. \end{equation}

Specifically, Eq. (1) solves

\begin{equation} \label{eqn-3}\tag{3} x_{t+1} = \underset{x}{\rm argmin}\bigg[\underbrace{f(x_{t}) + \big\langle \nabla f(x_{t}), x - x_{t} \big\rangle \vphantom{\frac{1}{\eta}}}_{\hat{f}_{t}(x)} + \underbrace{\frac{1}{2 \eta} \big\| x - x_{t}\big\|^{2}_{2}}_{D_{t}(x)} \bigg]. \end{equation}

Exercise: Prove this by solving $\nabla \big(\hat{f}_{t}(x) + D_{t}(x)\big) = 0$.

There are a number of generalizations of Eq. (2) beyond Eq. (3). For example, we may choose non-Euclidean update penalties $D_{t}$ (possibly derived from spaces other than $\mathcal{X}$), penalize updates using more than one previous iterate, approximate $f$ about a point other than $x_{t}$, or incorporate constraints through a Lagrangian.

In the following sections, I give examples of each of these generalizations.

Pseudo-Riemannian Geometry

Let us consider linear approximations of the form

\begin{equation} \label{eqn-4}\tag{4} \hat{f}_{t}(x) = f(x_{t}) + \big\langle \nabla f(x_{t}), x - x_{t} \big\rangle, \end{equation}

and distances $D_{t}$ of the form

\begin{equation} \label{eqn-5}\tag{5} D_{t}(x) = \frac{1}{2} (x - x_{t})^{\top} G_{t} (x - x_{t}), \end{equation}

where $G_{t}$ is a (pseudo)metric on the tangent space of $x$.

Under these conditions, it follows that

\begin{equation} \label{eqn-6}\tag{6} x_{t+1} = x_{t} - G^{\dagger}_{t} \nabla f(x_{t}), \end{equation}

where $G^{\dagger}_{t}$ is the Moore-Penrose (pseudo)inverse of $G_{t}$. This update rule is commonly known as “preconditioned” gradient descent.
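
A minimal sketch of Eq. (6), assuming the metric $G_{t}$ is supplied as an explicit matrix:

```python
import numpy as np

def preconditioned_step(x, grad_f, G):
    """One step of Eq. (6): x <- x - pinv(G) @ grad f(x)."""
    # With G = (1 / eta) * np.eye(len(x)), this reduces to Eq. (1).
    return x - np.linalg.pinv(G) @ grad_f(x)
```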

Importantly, these conditions correspond to the limit of updates in continuous time for the true objective $f$ and any $D_{t}$ with positive-semidefinite Hessian, wherein

\begin{equation*} \nabla \bigg[f(x) + \dot{D}_{t}(x) \bigg]_{(x=x_{t})} = 0 \end{equation*}

(Amid and Warmuth (2020)). To see this, let

\begin{equation*} G_{t} = \frac{\partial^{2}}{\partial x^{2}} D_{t}(x) \quad \implies \quad \nabla \dot{D}_{t}(x) = \dot{x} \frac{\partial^{2}}{\partial x^{2}} D_{t}(x) = G_{t} \dot{x}. \end{equation*}

It follows that

\begin{align*} \dot{x}_{t} = - G_{t}^{\dagger} \nabla f(x_{t}). \end{align*}

Note that we use the dot (e.g., over $x_{t}$ and $D_{t}$) to represent a time-derivative.

Gauss-Newton

If our objective function $f$ is convex, we are guaranteed that the Hessian of $f$, which we denote as $H$, is positive semi-definite. That is,

\begin{equation*} H_{t} = \frac{\partial^{2}}{\partial x^{2}} f(x_{t}) \succeq 0. \end{equation*}

When we substitute $H_{t}$ for $G_{t}$ in Eq. (5) (i.e., when we choose $D_{t}$ based on the Hessian of $f$), then the resulting subproblem (Eq. (2)) is simply a minimization over a 2nd-order Taylor approximation of $f$ at $x_{t}$. The resulting update is known as Newton’s Method.
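
Explicitly, substituting $H_{t}$ for $G_{t}$ in Eq. (6), the Newton update reads

\begin{equation*} x_{t+1} = x_{t} - H^{\dagger}_{t} \nabla f(x_{t}). \end{equation*}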

When used for solving least-squares problems (for residuals $r_{i}(x)$), the Hessian may be calculated as

\begin{align*} f(x) &= \sum_{i} r_{i}(x) r_{i}(x) \\ \frac{\partial}{\partial x} f(x) &= 2 \sum_{i} r_{i}(x) \frac{\partial r_{i}(x)}{\partial x} \\ H_{t} = \frac{\partial^{2}}{\partial x^{2}} f(x) &= 2 \bigg( \sum_{i} \frac{\partial r_{i}(x)}{\partial x} \frac{\partial r_{i}(x)}{\partial x} + \sum_{i} r_{i}(x) \frac{\partial^{2} r_{i}(x)}{\partial x^{2}} \bigg) \end{align*}

Ignoring the second term on the last line and retaining only the first as an approximation for $H_{t}$ yields the Gauss-Newton algorithm. For linear models, the second term is identically zero.
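
As a sketch of the resulting step (the residual and Jacobian callables here are assumptions of the example, not part of the original text):

```python
import numpy as np

def gauss_newton_step(x, residuals, jacobian):
    """One Gauss-Newton step for f(x) = sum_i r_i(x)^2.

    The Hessian is approximated by 2 J^T J (dropping the second term above),
    and the resulting subproblem is solved via the pseudoinverse.
    """
    r = residuals(x)                  # shape (m,)
    J = jacobian(x)                   # shape (m, n); row i is dr_i/dx
    grad = 2.0 * J.T @ r              # gradient of f at x
    H_gn = 2.0 * J.T @ J              # Gauss-Newton approximation of H_t
    return x - np.linalg.pinv(H_gn) @ grad
```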

Fisher-Rao

We may also choose values of $G$ that derive from spaces other than $\mathcal{X}$. For example, when $x$ parameterizes a probability distribution $\rho(y ; x)$, from which the objective function $f$ derives, we may measure the magnitude of an update from $x_{t}$ to $x$ by the relative entropy from $\rho_{t}$ to $\rho$:

\begin{equation*} D_{t}(x) = \frac{1}{\eta} D_{\rm KL}\Big(\rho(y ; x) \parallel \rho(y ; x_{t})\Big) = \frac{1}{\eta} \int_{\mathcal{Y}} \rho(y ; x) \log \frac{\rho(y; x)}{\rho(y ; x_{t})} {\rm d}y. \end{equation*}

This choice yields the update rule for Fisher-Rao natural gradient descent:

\begin{equation} \label{eqn-7}\tag{7} x_{t+1} = x_{t} - \eta F^{\dagger}_{t} \nabla f(x_{t}), \end{equation}

where $F_{t}$ is the Fisher metric tensor

\begin{equation*} F_{t} = -\int_{\mathcal{Y}} \rho(y ; x_{t}) \frac{\partial^{2}}{\partial x^{2}} \log \rho(y ; x_{t}) \, {\rm d}y. \end{equation*}

One way to understand the effect of the metric $G = F$ is that it (un)warps the space of marginal updates ${\rm d}x$, such that updates of the same induced magnitude contain the same amount of marginal information about $\rho(x_{t} + {\rm d}x)$.
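
As an illustrative sketch of Eq. (7), consider logits $x$ parameterizing a categorical distribution $\rho(y; x) = {\rm softmax}(x)_{y}$, for which the Fisher matrix has the closed form ${\rm diag}(p) - p p^{\top}$ with $p = {\rm softmax}(x)$; this particular model is my own choice of example:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def natural_gradient_step(x, grad_f, eta=0.1):
    """One step of Eq. (7) for a categorical model rho(y; x) = softmax(x)."""
    p = softmax(x)
    F = np.diag(p) - np.outer(p, p)                 # Fisher metric at x_t
    return x - eta * np.linalg.pinv(F) @ grad_f(x)  # F is singular, so use the pseudoinverse
```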

In continuous time, Fisher-Rao natural gradient descent optimally approximates replicator dynamics (used to model evolution by natural selection) and continuous Bayesian inference (Raab et al. (2022)). We may also choose different divergences in $\rho$ space, yielding other forms of “natural gradient descent” (Nurbekyan et al. (2022)).

Projective Geometry

Vanilla gradient descent does not require a Euclidean metric.

As a counterexample, let $G = \frac{\nabla f ~ \nabla^{\top} f}{\nabla^{\top} f \cdot \nabla f}$, where the numerator is an outer product and the denominator an inner product of $\nabla f$ with itself.

It follows that $G^{\dagger} = G$ and $G^{\dagger} \nabla f = \nabla f$. Therefore, the update rule that solves $\nabla(\hat{f}_{t} + D_{t}) = 0$ (Eq. (2)) is vanilla gradient descent: $x_{t+1} - x_{t} = -\eta G^{\dagger} \nabla f(x_{t}) = -\eta \nabla f(x_{t})$.
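
A quick numerical check of these identities, using a random vector as a stand-in for $\nabla f$:

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal(5)                  # stands in for grad f(x_t)
G = np.outer(g, g) / (g @ g)                # rank-one projection onto g

assert np.allclose(np.linalg.pinv(G), G)    # G is its own pseudoinverse
assert np.allclose(G @ g, g)                # G leaves grad f unchanged
```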

Multi-iterate Methods

So far, we have only discussed updates that rely on at most one previous iterate of $x_{t}$ (or $\lambda_{t}$ or $\mu_{t}$) to regularize updates with a penalty term $D$, thus defining “trust-regions” in which the approximation $\hat{f} \approx f$ is assumed to be valid. We may also choose update penalties that rely on additional history of the solution candidates $(x_{t}, x_{t-1}, \ldots)$.

Momentum

An example of a multi-iterate approach is given by the classical “Momentum” variant of gradient descent. Consider, for example, a penalty term $D_{t}$ that quadratically penalizes step sizes, with a linear term that encourages steps in the same direction as the previous update, as in

\begin{equation*} D_{t}(x) = \frac{1}{\eta} \bigg( \frac{1}{2} \big\| x - x_{t} \big\|_{2}^{2} - \xi \big\langle x - x_{t}, x_{t} - x_{t-1} \big\rangle \bigg), \end{equation*}

where $\xi \in (0, 1)$. Alternatively, consider a quadratic penalty for the difference between the proposed update and the previous update with decay $\xi$, as in

\begin{equation*} D_{t}(x) = \frac{1}{2 \eta} \big\| (x - x_{t}) - \xi (x_{t} - x_{t-1}) \big\|_{2}^{2}. \end{equation*}

In either case,

\begin{equation*} \nabla D_{t}(x) = \frac{1}{\eta} \bigg((x - x_{t}) - \xi (x_{t} - x_{t-1})\bigg). \end{equation*}

If we solve Eq. (2) by setting $\nabla (\hat{f}_{t} + D_{t}) = 0$ for either choice of $D_{t}$, the resulting update rule is

\begin{equation*} x_{t+1} = x_{t} + \xi(x_{t} - x_{t-1}) - \eta \nabla \hat{f}_{t}, \end{equation*}

which can be decomposed, with an additional state variable $v$, to

\begin{equation} \label{eqn-8}\tag{8} \begin{aligned} v_{t+1} &= \xi v_{t} - \eta \nabla \hat{f}_{t}, \\ x_{t+1} &= x_{t} + v_{t+1}. \end{aligned} \end{equation}

This update rule is the classical “Momentum” algorithm (Botev et al. (2016)) with decay parameterized by $\xi$.
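
A minimal sketch of Eq. (8) in NumPy:

```python
import numpy as np

def momentum_descent(grad_f, x0, eta=0.1, xi=0.9, num_steps=100):
    """Iterate Eq. (8): v <- xi * v - eta * grad f(x); x <- x + v."""
    x = np.asarray(x0, dtype=float)
    v = np.zeros_like(x)
    for _ in range(num_steps):
        v = xi * v - eta * grad_f(x)
        x = x + v
    return x
```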

Nth-Order Regularization

There is yet another update penalty $D_{t}$ that may be used to derive the Momentum update rule. Specifically, let

\begin{equation*} D_{t}(x) = \frac{\xi_{1}}{2\eta} \big\| x - x_{t} \big\|^{2} + \frac{\xi_{2}}{2\eta} \big\| \underbrace{(x - x_{t})}_{v} - \underbrace{(x_{t} - x_{t-1})}_{v_{t}} \big\|^{2}. \end{equation*}

We can naturally extend this distance penalty to account for higher-order derivatives as

\begin{align*} D_{t}(x) &= \frac{\xi_{1}}{2\eta} \big\|x - x_{t} \big\|^{2} + \frac{\xi_{2}}{2\eta} \big\| v - v_{t} \big\|^{2} + \ldots + \frac{\xi_{n}}{2\eta} \big\| x^{(n-1)} - x^{(n-1)}_{t} \big\|^{2} \\ &= \frac{\xi_{1}}{2\eta} \big\|x^{(1)} \big\|^{2} + \frac{\xi_{2}}{2\eta} \big\| x^{(2)} \big\|^{2} + \ldots + \frac{\xi_{n}}{2\eta} \big\| x^{(n)} \big\|^{2}, \end{align*}

where we use the state variable $x^{(n)}$ to represent the $n$th time-derivative of $x$, defined empirically by the recursive formula

\begin{equation*} x^{(n)} = x^{(n-1)} - x_{t}^{(n-1)}. \end{equation*}

By this choice of $D_{t}$, if we solve Eq. (2) by setting $\nabla (\hat{f}_{t} + D_{t}) = 0$ at $x_{t+1}$, we obtain the equation

\begin{equation*} -\eta \nabla \hat{f}_{t} = \sum_{k=1}^{n-2} \xi_{k} x^{(k)}_{t+1} + \xi_{n-1} x^{(n-1)}_{t+1} + \xi_{n} \big(x^{(n-1)}_{t+1} - x^{(n-1)}_{t}\big) \end{equation*}

and therefore, the update rules

\begin{equation} \label{eqn-9}\tag{9} \begin{aligned} x^{(n-1)}_{t+1} &= \frac{1}{\xi_{n} + \xi_{n-1}} \bigg( \xi_{n} x^{(n-1)}_{t} - \sum_{k=1}^{n-2} \xi_{k} x^{(k)}_{t+1} - \eta \nabla \hat{f}_{t} \bigg), \\ x^{(n-2)}_{t+1} &= x^{(n-2)}_{t} + x^{(n-1)}_{t+1}, \\ &\vdots \\ x_{t+1} &= x_{t} + x^{(1)}_{t+1}. \end{aligned} \end{equation}

For $n = 2$, letting $\xi_{1} = (1 - \xi)$ and $\xi_{2} = \xi$, we recover the standard momentum update.
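
As a sketch, the coupled linear updates in Eq. (9) can be collapsed into a single closed-form step for $v = x_{t+1} - x_{t}$ by back-substituting the recursion $x^{(k)}_{t+1} = v - \sum_{j<k} x^{(j)}_{t}$; the code below is my own reworking under that substitution, not a literal transcription of Eq. (9).

```python
import numpy as np

def nth_order_step(x, history, grad_f_hat, xi, eta=0.1):
    """One step with the nth-order penalty D_t above.

    `history` holds [x^(1)_t, ..., x^(n-1)_t] and `xi` holds [xi_1, ..., xi_n].
    Back-substitution gives
        v = ( sum_j (xi_{j+1} + ... + xi_n) x^(j)_t - eta * grad ) / sum_k xi_k.
    """
    n = len(xi)
    v = -eta * grad_f_hat(x)
    for j, xj in enumerate(history, start=1):   # j = 1, ..., n-1
        v = v + sum(xi[j:]) * xj
    v = v / sum(xi)
    # Refresh the state variables: x^(1)_{t+1} = v, x^(k)_{t+1} = x^(k-1)_{t+1} - x^(k-1)_t.
    new_history = [v]
    for k in range(2, n):
        new_history.append(new_history[-1] - history[k - 2])
    return x + v, new_history
```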

Surrogate Approximation

As $x$ varies from $x_{t}$ to $x_{t+1}$, the minimum average error of a local approximation of $f(x)$ along this path can be reduced by choosing an intermediate point $\tilde{x}_{t}$, somewhere between $x_{t}$ and $x_{t+1}$, about which to approximate the local behavior of $f$, rather than $x_{t}$.

Nesterov Acceleration

When using Eq. (8), one “zero-cost” (i.e., if we decouple candidate solutions $x_{t}$ from the points at which we query $\hat{f}$) choice of $\tilde{x}_{t}$ is given by

\begin{equation*} \tilde{x}_{t} = x_{t} + \xi v_{t}, \end{equation*}

such that

\begin{equation} \label{eqn-10}\tag{10} \hat{f}_{t}(x) = f(\tilde{x}_{t}) + \big\langle \nabla f(\tilde{x}_{t}), x - \tilde{x}_{t} \big\rangle. \end{equation}

Substituting Eq. (10) into Eq. (8), we obtain Nesterov’s variant of gradient descent with momentum:

\begin{equation*} \begin{aligned} v_{t+1} &= \xi v_{t} - \eta \nabla f(x_{t} + \xi v_{t}), \\ x_{t+1} &= x_{t} + v_{t+1}. \end{aligned} \end{equation*}
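
The corresponding sketch, which queries the gradient at the look-ahead point $x_{t} + \xi v_{t}$:

```python
import numpy as np

def nesterov_descent(grad_f, x0, eta=0.1, xi=0.9, num_steps=100):
    """Nesterov momentum: query grad f at the look-ahead point x + xi * v."""
    x = np.asarray(x0, dtype=float)
    v = np.zeros_like(x)
    for _ in range(num_steps):
        v = xi * v - eta * grad_f(x + xi * v)
        x = x + v
    return x
```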

Dealing with Constraints

Consider the basic (constrained) optimization problem

\begin{equation} \label{eqn-11}\tag{11} \begin{array}{lll} {\rm minimize} &f(x) \\ {\rm subject~to} &g(x) \preceq 0, \\ &h(x) = 0, \end{array} \end{equation}

for vector-valued $g$ and $h$.

In terms of the corresponding Lagrangian $\mathcal{L}$, the primal problem may be written

\begin{equation} \label{eqn-12}\tag{12} \begin{array}{lll} \underset{x}{\rm minimize} \bigg( \underset{\lambda, \mu}{\rm max}& \mathcal{L}(x, \lambda, \mu) \bigg) \\ {\rm subject~to} & \lambda \succeq 0, \end{array} \end{equation}

where

\begin{equation} \label{eqn-13}\tag{13} \mathcal{L}(x, \lambda, \mu) = f(x) + \lambda^{\top} g(x) + \mu^{\top} h(x). \end{equation}

Intuitively, given that $\lambda$ and $\mu$ are chosen adversarially (i.e., after $x$ is fixed), the problem is to choose $x$ such that the constraints on $g(x)$ and $h(x)$ are satisfied (otherwise, the objective can be made unboundedly positive by an adversary’s choice of $\lambda, \mu$). Within this “feasible set” of $x$ values, $f$ should be minimized.

Primal-Dual

We may iteratively approximate Eq. (12) with a sequence of subproblems in the form of Eq. (2):

\begin{equation*} x_{t+1} = \underset{x}{\rm argmin}\bigg[ \underset{\lambda \succeq 0, \mu}{\rm max} ~ \bigg( \hat{\mathcal{L}}_{t}(x, \lambda, \mu) + D_{t}(x) - R_{t}(\lambda) - S_{t}(\mu) \bigg) \bigg], \end{equation*}

introducing update penalties for state variables $x$, $\lambda$, and $\mu$.

For the simple choices

\begin{align*} D_{t}(x) &= \frac{1}{2\eta}\big\|x - x_{t}\big\|_{2}^{2}, \quad \eta > 0, \\ R_{t}(\lambda) &= \frac{1}{2\alpha}\big\|\lambda - \lambda_{t}\big\|_{2}^{2}, \quad \alpha > 0, \\ S_{t}(\mu) &= \frac{1}{2\beta}\big\|\mu - \mu_{t}\big\|_{2}^{2}, \quad \beta > 0, \end{align*}

and local, linear approximations

\begin{align*} \hat{f}_{t}(x) &= f(x_{t}) + \big\langle \nabla f(x_{t}), x - x_{t} \big\rangle; \\ \hat{g}_{t}(x) &= g(x_{t}) + \big\langle \nabla g(x_{t}), x - x_{t} \big\rangle; \\ \hat{h}_{t}(x) &= h(x_{t}) + \big\langle \nabla h(x_{t}), x - x_{t} \big\rangle; \\ \hat{\mathcal{L}}_{t}(x, \lambda, \mu) &= \hat{f}_{t}(x) + \lambda^{\top} \hat{g}_{t}(x) + \mu^{\top} \hat{h}_{t}(x), \end{align*}

we have the update rules

\begin{equation} \label{eqn-14}\tag{14} \begin{aligned} x_{t+1} &= x_{t} - \eta \bigg( \nabla f(x_{t}) + \lambda_{t+1}^{\top} \nabla g(x_{t}) + \mu^{\top}_{t+1} \nabla h(x_{t})\bigg), \\ \lambda_{t+1} &= \max\bigg( \lambda_{t} + \alpha \hat{g}(x_{t+1}), 0\bigg), \\ \mu_{t+1} &= \mu_{t} + \beta \hat{h}(x_{t+1}), \end{aligned} \end{equation}

where $\max$ is taken element-wise and the gradients are taken with respect to $x$ (and these gradient components remain uncontracted with the dimensions of $\lambda$ or $\mu$).

In practice, it is common to eliminate the mutual dependence between variables in Eq. (14) by replacing $x_{t+1} \mapsto x_{t}$, $\lambda_{t+1} \mapsto \lambda_{t}$, and $\mu_{t+1} \mapsto \mu_{t}$ on the right-hand side of each equation, thus yielding the standard primal-dual algorithm, which maintains iterates for the dual variables $\lambda$ and $\mu$ in addition to the primal variable $x$. The error of this approximation has order $\mathcal{O}(\eta(\alpha + \beta))$, though the approximation itself is not strictly necessary.
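
A sketch of one such decoupled iteration (the Jacobian-returning callables are assumptions of the example):

```python
import numpy as np

def primal_dual_step(x, lam, mu, grad_f, g, jac_g, h, jac_h,
                     eta=0.01, alpha=0.01, beta=0.01):
    """One decoupled iteration of Eq. (14), using (x_t, lam_t, mu_t) on the
    right-hand sides in place of the mutually dependent (t+1) values."""
    Jg, Jh = jac_g(x), jac_h(x)                    # Jacobians: row i is the gradient of g_i or h_i
    x_new = x - eta * (grad_f(x) + Jg.T @ lam + Jh.T @ mu)
    lam_new = np.maximum(lam + alpha * g(x), 0.0)  # element-wise projection onto lam >= 0
    mu_new = mu + beta * h(x)
    return x_new, lam_new, mu_new
```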

When $\eta \alpha < \|\nabla g(x_{t})\|^{-2}$ and $\eta \beta < \|\nabla h(x_{t})\|^{-2}$, the true solution to Eq. (14) may be found by iterating the recursive map

\begin{equation*} v^{k+1}_{t+1} \leftarrow -\eta \left( \begin{aligned} \nabla f(x_{t}) &~+ \\ \max\Big(0, \lambda_{t} + \alpha \big( g(x_{t}) + \langle \nabla g(x_{t}), v^{k}_{t+1} \rangle \big)\Big) \nabla g(x_{t}) &~+ \\ \Big( \mu_{t} + \beta \big( h(x_{t}) + \langle \nabla h(x_{t}), v^{k}_{t+1} \rangle \big)\Big) \nabla h(x_{t})& \end{aligned} \right), \end{equation*}

such that $x_{t+1} = x_{t} + v_{t+1}^{\infty}$.

For convex functions $f$, $g$, and $h$, and sufficiently small $\eta, \alpha, \beta$, iterating subproblem Eq. (14) will cause $x$, $\lambda$, and $\mu$ to converge to finite values and solve the target constrained optimization problem Eq. (12).

Generalizing Fletcher’s Method

In this section, we show how a rather interesting penalty choice for the dual variable $\lambda$, as used above, completely eliminates the need to track $\lambda$ as an independent variable and provides a straightforward method for constrained optimization based on local gradients of the objective and the constraint.

Consider the problem

\begin{equation} \label{eqn-15}\tag{15} \begin{array}{lll} {\rm minimize} &f(x) \\ {\rm subject~to} &g(x) \leq 0, \end{array} \end{equation}

for vector $x$ and scalar-valued $f$ and $g$.

Given standard assumptions such as convexity, this problem may be solved by the sequence of subproblems

\begin{align*} x_{t+1} =&~\underset{x}{\rm argmin}\quad \hat{f}(x) + \lambda \hat{g}(x) + \frac{1}{2\eta} \| x - x_{t} \|^{2}, \\ \lambda \equiv&~ \underset{v \geq 0}{\rm argmax}\quad f(x_{t}) + v g(x_{t}) - \frac{1}{2} \Big\| \nabla f(x_{t}) + v \nabla g(x_{t}) \Big\|^{2}, \end{align*}

which, as in Eq. (4), we express using the local, linear approximations $\hat{f}$ and $\hat{g}$. Note that we have no need to track the value of $\lambda$ between iterates of the above subproblems. This is because, instead of penalizing the update magnitude $\|\lambda - \lambda_{t}\|$, we penalize the distance of $\lambda$ away from the value that renders $x_{t}$ a critical point of the Lagrangian. We have omitted an explicit time-index on $\lambda$ despite the fact that its value may change with each iteration.

Solving the above subproblem, we find that $x_{t+1}$ has the explicit solution

\begin{align*} x_{t+1} &= x_{t} - \eta \nabla \Lambda; \\ \nabla \Lambda &\equiv \nabla f(x_{t}) + \lambda \nabla g(x_{t}); \\ \lambda &\equiv \max\bigg[0, \frac{ g(x_{t}) - \big\langle \nabla f(x_{t}), \nabla g(x_{t}) \big\rangle}{\big\langle \nabla g(x_{t}), \nabla g(x_{t}) \big\rangle} \bigg], \end{align*}

where $\nabla \Lambda$ represents an estimate for the gradient of the Lagrangian at $x_{t}$.
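
A sketch of this update for scalar-valued $g$ (the callable names are illustrative):

```python
import numpy as np

def fletcher_step(x, grad_f, g, grad_g, eta=0.1):
    """One step x <- x - eta * grad Lambda, with the multiplier in closed form."""
    df, dg = grad_f(x), grad_g(x)                   # NumPy arrays
    lam = max(0.0, (g(x) - np.dot(df, dg)) / np.dot(dg, dg))
    return x - eta * (df + lam * dg)
```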

Manipulating the above equations, we see that

\begin{align*} \big\langle \nabla \Lambda, \nabla g(x_{t}) \big\rangle &= - \frac{1}{\eta} \big\langle x_{t+1} - x_{t}, \nabla g(x_{t}) \big\rangle. \\ \big\langle \nabla \Lambda, \nabla g(x_{t}) \big\rangle &= \big\langle \nabla f(x_{t}), \nabla g(x_{t}) \big\rangle + \lambda \big\langle \nabla g(x_{t}), \nabla g(x_{t}) \big\rangle \\ &= \big\langle \nabla f(x_{t}), \nabla g(x_{t}) \big\rangle + \max\bigg[ 0, g(x_{t}) - \big\langle \nabla f(x_{t}), \nabla g(x_{t}) \big\rangle \bigg] \\ &= \max\bigg[ \big\langle \nabla f(x_{t}), \nabla g(x_{t}) \big\rangle, g(x_{t}) \bigg]. \end{align*}

From this, it follows that

\begin{equation*} -\frac{1}{\eta} \big\langle x_{t+1} - x_{t}, \nabla g(x_{t}) \big\rangle \geq g(x_{t}). \end{equation*}

By similar logic,

\begin{equation*} -\frac{1}{\eta} \big\langle x_{t+1} - x_{t}, \nabla f(x_{t}) \big\rangle \geq \|\nabla f(x_{t}) \|^{2}. \end{equation*}

This suggests the following alternative expression for our subproblem:

\begin{align*} x_{t+1} =~\underset{x}{\rm argmin}&\quad \big\langle x, \nabla f(x_{t}) \big\rangle + \frac{1}{2\eta} \| x - x_{t} \|^{2}, \\ ~{\rm subject~to}&\quad \big\langle x - x_{t}, -\nabla g(x_{t}) \big\rangle \geq \eta g(x_{t}). \end{align*}

Intuitively, when $g(x_{t})$ is positive, in order to make progress towards the constraint $g \leq 0$, we ensure that the update $x_{t+1} - x_{t}$ is aligned with the negative gradient of $g$. When $g$ is negative and the constraint is already satisfied, the update cannot be too aligned with increasing $g$. Subject to these constraints, $x_{t+1}$ may be chosen to make progress towards decreasing the objective $f$.

References