Beyond Gradient Descent
For finding local minima of differentiable functions, the workhorse algorithm of machine learning is gradient descent.
Gradient descent is really a special case of a wider family of non-linear optimization algorithms. Recognizing this fact allows us to address constraints, consider alternative metrics, and be more explicit about the approximations we use and the sequence of subproblems we solve.
The most basic form of gradient descent, where we seek to minimize an objective function $f(x)$ over decision variable $x$, is

$$x_{t+1} = x_t - \eta\,\nabla f(x_t), \tag{1}$$

where $\eta > 0$ is some step-size that we must choose.

This equation is the solution that minimizes a local approximation of $f$ plus a regularization term that penalizes us for choosing an $x_{t+1}$ that results in large updates. That is, gradient descent generates a sequence of candidate solutions $x_t$ by solving a sequence of subproblems of the form

$$x_{t+1} = \operatorname*{arg\,min}_{x}\; \left[\, \hat{f}_t(x) + \frac{1}{\eta}\, D(x, x_t) \,\right], \tag{2}$$

where $\hat{f}_t$ is a local approximation of $f$ at $x_t$ and $D$ penalizes large updates. Specifically, Eq. (1) solves Eq. (2) for the linear approximation $\hat{f}_t(x) = f(x_t) + \nabla f(x_t)^{\top}(x - x_t)$ and the squared Euclidean penalty $D(x, x_t) = \frac{1}{2}\,\|x - x_t\|_2^2$.
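As a quick numerical sanity check, the following sketch confirms that the gradient step of Eq. (1) minimizes the subproblem of Eq. (2); the test objective and all constants here are illustrative assumptions:

```python
import numpy as np

# A minimal sketch: verify that the gradient step (Eq. 1) minimizes the
# linearized subproblem (Eq. 2). The objective and constants are illustrative.

rng = np.random.default_rng(0)
eta = 0.1
x_t = rng.normal(size=3)

def grad_f(x):
    # Gradient of an arbitrary smooth test objective.
    return x - np.sin(x)

def subproblem(x):
    # Linear model of f plus quadratic update penalty (constant f(x_t) omitted).
    return grad_f(x_t) @ (x - x_t) + (x - x_t) @ (x - x_t) / (2 * eta)

x_step = x_t - eta * grad_f(x_t)  # Eq. (1)

# Any perturbation of the gradient step should increase the subproblem value.
for _ in range(100):
    assert subproblem(x_step) <= subproblem(x_step + 0.1 * rng.normal(size=3))
```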
There are a number of generalizations we can make to Eq. (2). For example,
- We may choose non-Euclidean update metrics $D(x, x_t)$.
- We may choose $D$ based on multiple past iterates $x_t, x_{t-1}, \dots$.
- We may choose surrogate estimates $\hat{f}_t$ of the objective $f$.
- We may replace $f$ with a Lagrangian to address constraints.
In the following sections, I give examples of each of these generalizations.
Pseudo-Riemannian Geometry
Let us consider linear approximations of the form

$$\hat{f}_t(x) = f(x_t) + \nabla f(x_t)^{\top}(x - x_t)$$

and distances of the form

$$D(x, x_t) = \frac{1}{2}\,(x - x_t)^{\top} M(x_t)\,(x - x_t), \tag{3}$$

where $M(x_t)$ is a (pseudo)metric on the tangent space at $x_t$.

Under these conditions, it follows that

$$x_{t+1} = x_t - \eta\, M(x_t)^{+}\,\nabla f(x_t),$$

where $M(x_t)^{+}$ is the Moore-Penrose (pseudo)inverse of $M(x_t)$. This update rule is commonly known as “preconditioned” gradient descent.
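In code, preconditioned gradient descent is a one-line change from the vanilla update. A sketch on an ill-conditioned quadratic (the objective and choice of metric are illustrative assumptions):

```python
import numpy as np

# Preconditioned gradient descent: x <- x - eta * pinv(M(x)) @ grad_f(x).
# For the quadratic f(x) = 0.5 * x^T A x with metric M = A (its Hessian),
# a single step with eta = 1 lands exactly on the minimizer.

A = np.diag([1.0, 100.0])           # ill-conditioned quadratic
grad_f = lambda x: A @ x
M = lambda x: A                     # illustrative choice of metric

eta = 1.0
x = np.array([1.0, 1.0])
for _ in range(3):
    x = x - eta * np.linalg.pinv(M(x)) @ grad_f(x)
print(x)  # [0. 0.]
```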
Importantly, these conditions correspond to the limit of updates $\eta \to 0$ in continuous time for the true objective ($\hat{f}_t = f$) and any metric $M = H_F$ given by the positive-semidefinite Hessian of a potential $F$, wherein

$$\dot{x} = -\,H_F(x)^{+}\,\nabla f(x)$$

(Amid and Warmuth (2020)). To see this, let

$$q = \nabla F(x).$$

It follows that

$$\dot{q} = H_F(x)\,\dot{x} = -\,H_F(x)\,H_F(x)^{+}\,\nabla f(x) = -\,\nabla f(x)$$

whenever $\nabla f(x)$ lies in the range of $H_F(x)$; that is, the dual variable $q$ follows the Euclidean gradient flow of $f$. Note that we use the dot (e.g., over $x$ and $q$) to represent a time-derivative.
Gauss-Newton
If our objective function $f$ is convex, we are guaranteed that the Hessian of $f$, which we denote as $H_f$, is positive semi-definite. That is,

$$H_f(x) = \nabla^2 f(x) \succeq 0 \quad \text{for all } x.$$

When we substitute $H_f(x_t)$ for $M(x_t)$ in Eq. (3), the resulting subproblem (Eq. (2)) is simply a minimization over a 2nd-order Taylor approximation of $f$ at $x_t$, which yields a (damped) Newton update. The closely related Gauss-Newton update, frequently used for solving least-squares problems, instead preconditions with the positive semi-definite approximation $H_f \approx J_r^{\top} J_r$, where $J_r$ is the Jacobian of the residuals $r$ in the sum of squares $f(x) = \frac{1}{2}\|r(x)\|_2^2$.
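A Gauss-Newton sketch for a small nonlinear least-squares problem; the exponential residual model, synthetic data, and starting point are illustrative assumptions:

```python
import numpy as np

# Gauss-Newton for f(x) = 0.5 * ||r(x)||^2: precondition with the PSD
# approximation J^T J, i.e., solve (J^T J) dx = -J^T r at each step.
# Here we fit y = a * exp(b * t) to synthetic data with a = 2, b = -1.5.

t = np.linspace(0.0, 1.0, 20)
y = 2.0 * np.exp(-1.5 * t)

def residual(x):
    a, b = x
    return a * np.exp(b * t) - y

def jacobian(x):
    a, b = x
    return np.stack([np.exp(b * t), a * t * np.exp(b * t)], axis=1)

x = np.array([1.0, -1.0])
for _ in range(10):
    r, J = residual(x), jacobian(x)
    x = x + np.linalg.lstsq(J, -r, rcond=None)[0]  # Gauss-Newton step
print(x)  # approaches [2.0, -1.5]
```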
Fisher-Rao
We may also choose values of $M$ that derive from spaces other than that of $x$ itself. For example, when $x$ parameterizes a probability distribution $p_x$, from which the objective function derives, we may measure the magnitude of an update from $x_t$ to $x$ by the relative entropy from $p_{x_t}$ to $p_x$, which to second order is

$$D(x, x_t) = \mathrm{KL}\!\left(p_{x_t}\,\Vert\,p_x\right) \approx \frac{1}{2}\,(x - x_t)^{\top} F(x_t)\,(x - x_t).$$

This choice yields the update rule for Fisher-Rao natural gradient descent:

$$x_{t+1} = x_t - \eta\, F(x_t)^{+}\,\nabla f(x_t),$$

where $F(x_t)$ is the Fisher metric tensor

$$F(x_t) = \mathbb{E}_{z \sim p_{x_t}}\!\left[\nabla_x \log p_x(z)\,\big(\nabla_x \log p_x(z)\big)^{\top}\right]\Big|_{x = x_t}.$$

One way to understand the effect of the metric $F$ is that it (un)warps the space of marginal updates $x - x_t$, such that updates of the same induced magnitude contain the same amount of marginal information about the distribution $p_x$.
In continuous time, Fisher-Rao natural gradient descent optimally approximates replicator dynamics (used to model evolution by natural selection) and continuous Bayesian inference (Raab et al. (2022)). We may also choose different divergences on the space of distributions, yielding other forms of “natural gradient descent” (Nurbekyan et al. (2022)).
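For a concrete discrete-time sketch, consider a categorical distribution with logit parameters, for which the Fisher metric has the closed form $F = \mathrm{diag}(p) - p\,p^{\top}$; the target distribution and step-size below are illustrative assumptions:

```python
import numpy as np

# Fisher-Rao natural gradient descent on the cross-entropy between a fixed
# target distribution and a categorical distribution p = softmax(theta).

def softmax(theta):
    e = np.exp(theta - theta.max())
    return e / e.sum()

target = np.array([0.7, 0.2, 0.1])
theta = np.zeros(3)
eta = 0.5

for _ in range(50):
    p = softmax(theta)
    grad = p - target                    # gradient of cross-entropy in the logits
    F = np.diag(p) - np.outer(p, p)      # Fisher metric of the categorical
    theta -= eta * np.linalg.pinv(F) @ grad
print(softmax(theta))  # approaches the target
```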
Projective Geometry
Vanilla gradient descent does not require a Euclidean metric.
As a counterexample, let

$$M(x_t) = \frac{\nabla f(x_t)\,\nabla f(x_t)^{\top}}{\nabla f(x_t)^{\top}\,\nabla f(x_t)},$$

where the numerator is an outer product and the denominator an inner product of $\nabla f(x_t)$ with itself. It follows that $M(x_t)$ is a rank-one orthogonal projection, which is its own pseudoinverse, so $M(x_t)^{+}\,\nabla f(x_t) = \nabla f(x_t)$. Therefore, the update rule that solves the subproblem (Eq. (2)) is vanilla gradient descent:

$$x_{t+1} = x_t - \eta\,\nabla f(x_t).$$
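A quick numerical check of this identity (dimensions and the stand-in gradient are illustrative):

```python
import numpy as np

# Check that the rank-one projection M = g g^T / (g^T g) is its own
# Moore-Penrose pseudoinverse and maps g to itself, so the preconditioned
# update reduces to the vanilla gradient step.

rng = np.random.default_rng(1)
g = rng.normal(size=4)                  # stand-in for grad_f(x_t)
M = np.outer(g, g) / (g @ g)

assert np.allclose(np.linalg.pinv(M), M)
assert np.allclose(np.linalg.pinv(M) @ g, g)
print("pinv(M) @ g == g")
```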
Multi-iterate Methods
So far, we have only discussed updates that rely on at most one previous iterate of $x$ (or of $\nabla f$ or $M$ evaluated there). We may also choose update penalties $D(x, x_t, x_{t-1}, \dots)$ that rely on additional history of the solution candidates $x_{t'}$, $t' \le t$.
Momentum
An example of a multi-iterate approach is given by the classical “Momentum” variant of gradient descent. Consider, for example, a penalty term that quadratically penalizes step-sizes with a linear term that encourages steps in the same continued direction, as in

$$D(x, x_t, x_{t-1}) = \frac{1}{2}\,\|x - x_t\|_2^2 - \beta\,(x - x_t)^{\top}(x_t - x_{t-1}),$$

where $\beta \in [0, 1)$. Alternatively, consider a quadratic penalty for the difference between the proposed update and the previous update with decay $\beta$, as in

$$D(x, x_t, x_{t-1}) = \frac{1}{2}\,\big\|(x - x_t) - \beta\,(x_t - x_{t-1})\big\|_2^2.$$

In either case,

$$\nabla_x D = (x - x_t) - \beta\,(x_t - x_{t-1}).$$

If we solve Eq. (2) by setting $\nabla_x\big[\hat{f}_t(x) + \tfrac{1}{\eta}\,D\big] = 0$ for either choice of $D$, the resulting update rule is

$$x_{t+1} = x_t - \eta\,\nabla f(x_t) + \beta\,(x_t - x_{t-1}),$$

which can be decomposed, with an additional state variable $v_t = x_t - x_{t-1}$, to

$$v_{t+1} = \beta\, v_t - \eta\,\nabla f(x_t), \qquad x_{t+1} = x_t + v_{t+1}.$$
This update rule is the classical “Momentum” algorithm (Botev et al. (2016)) with decay parameterized by $\beta$.
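A sketch of this decomposed update on an ill-conditioned quadratic; the objective and constants are illustrative assumptions:

```python
import numpy as np

# Classical Momentum: v_{t+1} = beta * v_t - eta * grad_f(x_t);
#                     x_{t+1} = x_t + v_{t+1}.

A = np.diag([1.0, 10.0])
grad_f = lambda x: A @ x

eta, beta = 0.05, 0.9
x, v = np.array([1.0, 1.0]), np.zeros(2)
for _ in range(300):
    v = beta * v - eta * grad_f(x)
    x = x + v
print(x)  # approaches the minimizer at the origin
```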
Nth-Order Regularization
There is yet another update penalty that may be used to derive the Momentum update rule, up to a reparameterization of the constants $(\eta, \beta)$. Specifically, let

$$D(x, x_t, x_{t-1}) = \frac{\alpha_1}{2}\,\big\|x - x_t\big\|_2^2 + \frac{\alpha_2}{2}\,\big\|(x - x_t) - (x_t - x_{t-1})\big\|_2^2,$$

which penalizes the first and second discrete time-derivatives of $x$. We can naturally extend this distance penalty to account for higher-order derivatives as

$$D(x, x_t, x_{t-1}, \dots) = \sum_{n=1}^{N} \frac{\alpha_n}{2}\,\big\|v^{(n)}_{t+1}\big\|_2^2,$$

where we use the state variable $v^{(n)}_t$ to represent the $n$th time-derivative of $x$ at time $t$, defined empirically by the recursive formula

$$v^{(0)}_t = x_t, \qquad v^{(n)}_{t+1} = v^{(n-1)}_{t+1} - v^{(n-1)}_t$$

(so the proposed iterate enters the penalty through $v^{(0)}_{t+1} = x$). By this choice of $D$, if we solve Eq. (2) by setting $\nabla_x\big[\hat{f}_t(x) + \tfrac{1}{\eta}\,D\big] = 0$, we obtain the update rules

$$x_{t+1} = x_t - \frac{\eta}{\alpha}\,\nabla f(x_t) + \sum_{n=1}^{N-1} \beta_n\, v^{(n)}_t, \qquad v^{(n)}_{t+1} = v^{(n-1)}_{t+1} - v^{(n-1)}_t,$$

with $\alpha = \sum_{n=1}^{N} \alpha_n$ and $\beta_n = \frac{1}{\alpha} \sum_{m=n+1}^{N} \alpha_m$; for $N = 2$, this is exactly Momentum with effective step-size $\eta/\alpha$ and decay $\beta_1 = \alpha_2/\alpha$.
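A sketch of these update rules, tracking the stack of derivative estimates $v^{(n)}$; the penalty weights $\alpha_n$ and the quadratic test objective are illustrative assumptions, and with $N = 2$ the iteration reduces to classical Momentum:

```python
import numpy as np

# Nth-order regularized updates: penalize the first N discrete time-derivatives
# of x. Update x, then refresh the stack via v_new[n] = v_new[n-1] - v_old[n-1].

A = np.diag([1.0, 10.0])
grad_f = lambda x: A @ x

alpha = np.array([1.0, 9.0])         # weights on the 1st, ..., Nth derivatives
N = len(alpha)
eta = 0.5
beta = [alpha[n:].sum() / alpha.sum() for n in range(N)]  # beta[n] weights v^(n)

x = np.array([1.0, 1.0])
v = [x.copy()] + [np.zeros(2) for _ in range(N)]  # v[0] = x, v[n] = nth derivative

for _ in range(300):
    x_new = x - (eta / alpha.sum()) * grad_f(x) \
            + sum(beta[n] * v[n] for n in range(1, N))
    v_new = [x_new]
    for n in range(1, N + 1):
        v_new.append(v_new[n - 1] - v[n - 1])
    x, v = x_new, v_new
print(x)  # approaches the origin
```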
Surrogate Approximation
As $x$ varies from $x_t$ to $x_{t+1}$, the minimum average error of a local approximation of $f$ along this path can be reduced by choosing an intermediate point $\tilde{x}$, somewhere between $x_t$ and $x_{t+1}$, instead of $x_t$, about which to approximate the local behavior of $f$. Nesterov's accelerated gradient admits this reading: it evaluates the gradient at a point displaced from $x_t$ by the momentum term, partway toward $x_{t+1}$ (Botev et al. (2016)).
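A sketch of this idea in the momentum setting, where the lookahead point $x_t + \beta v_t$ serves as the intermediate approximation point; the objective and constants are illustrative assumptions:

```python
import numpy as np

# Nesterov-style lookahead: evaluate the gradient at x + beta * v, a point
# between the current iterate and the next one, rather than at x itself.

A = np.diag([1.0, 10.0])
grad_f = lambda x: A @ x

eta, beta = 0.05, 0.9
x, v = np.array([1.0, 1.0]), np.zeros(2)
for _ in range(300):
    v = beta * v - eta * grad_f(x + beta * v)  # gradient at the lookahead point
    x = x + v
print(x)  # approaches the minimizer at the origin
```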
Dealing with Constraints
Consider the basic (constrained) optimization problem

$$\min_{x}\; f(x) \quad \text{subject to} \quad g(x) \preceq 0, \;\; h(x) = 0,$$

for vector-valued constraint functions $g$ and $h$.

In terms of the corresponding Lagrangian

$$\mathcal{L}(x, \lambda, \nu) = f(x) + \lambda^{\top} g(x) + \nu^{\top} h(x),$$

the primal problem may be written

$$\min_{x}\; \max_{\lambda \succeq 0,\; \nu}\; \mathcal{L}(x, \lambda, \nu),$$

where $\lambda$ and $\nu$ are the dual variables (Lagrange multipliers) associated with $g$ and $h$, respectively.

Intuitively, given that $\lambda$ and $\nu$ are chosen adversarially (i.e., after $x$ is fixed), the problem is to choose $x$ such that the constraints on $g$ and $h$ are satisfied (otherwise, the objective can be made unboundedly positive by an adversary’s choice of $\lambda$ or $\nu$). Within this “feasible set” of $x$ values, $f$ should be minimized.
Augmented Lagrangian

One common extension (see Nocedal and Wright (2006)) augments the Lagrangian with a quadratic penalty on constraint violation, e.g., for the equality constraints,

$$\mathcal{L}_{\rho}(x, \nu) = f(x) + \nu^{\top} h(x) + \frac{\rho}{2}\,\|h(x)\|_2^2,$$

alternating between (approximately) minimizing $\mathcal{L}_{\rho}$ over $x$ and updating the multipliers as $\nu \leftarrow \nu + \rho\, h(x)$.
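A sketch of this scheme on a small equality-constrained quadratic; the problem data, penalty weight $\rho$, and inner-loop step-size are illustrative assumptions:

```python
import numpy as np

# Augmented-Lagrangian method: approximately minimize
#   L_rho(x, nu) = f(x) + nu^T h(x) + (rho / 2) * ||h(x)||^2
# in x by gradient steps, then update nu <- nu + rho * h(x).
# Illustrative problem: minimize x^T x subject to x[0] + x[1] = 1.

f_grad = lambda x: 2 * x
h = lambda x: np.array([x[0] + x[1] - 1.0])
Jh = lambda x: np.array([[1.0, 1.0]])    # Jacobian of h

rho, eta = 10.0, 0.05
x, nu = np.zeros(2), np.zeros(1)
for _ in range(20):                      # outer loop: multiplier updates
    for _ in range(100):                 # inner loop: minimize L_rho over x
        x = x - eta * (f_grad(x) + Jh(x).T @ (nu + rho * h(x)))
    nu = nu + rho * h(x)
print(x)  # approaches (0.5, 0.5)
```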
Primal-Dual
For the simple choices

$$D(x, x_t) = \tfrac{1}{2}\,\|x - x_t\|_2^2, \qquad D(\lambda, \lambda_t) = \tfrac{1}{2}\,\|\lambda - \lambda_t\|_2^2, \qquad D(\nu, \nu_t) = \tfrac{1}{2}\,\|\nu - \nu_t\|_2^2,$$

and linear local approximations

$$\hat{\mathcal{L}}_t = \mathcal{L}(x_t, \lambda_t, \nu_t) + \nabla\mathcal{L}(x_t, \lambda_t, \nu_t)^{\top}\big[(x, \lambda, \nu) - (x_t, \lambda_t, \nu_t)\big],$$

we have the update rules (descending in $x$ while ascending in $\lambda$ and $\nu$)

$$\begin{aligned}
x_{t+1} &= x_t - \eta\,\nabla_x \mathcal{L}(x_t, \lambda_t, \nu_t), \\
\lambda_{t+1} &= \big[\lambda_t + \eta\, g(x_t)\big]_{+}, \\
\nu_{t+1} &= \nu_t + \eta\, h(x_t),
\end{aligned} \tag{12}$$

where $[\,\cdot\,]_{+} = \max(\,\cdot\,, 0)$ is taken element-wise and the gradients are taken with respect to $x$, $\lambda$, or $\nu$ (and these gradient components remain uncontracted with the dimensions of $g$ or $h$).
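A sketch of Eq. (12) on a small convex problem; all problem data and constants are illustrative assumptions:

```python
import numpy as np

# Primal-dual updates of Eq. (12): descend in x, ascend in (lambda, nu),
# projecting lambda onto the non-negative orthant.
# Illustrative problem: minimize x^T x subject to 1 - x[0] <= 0 and
# x[1] - 2 = 0, whose solution is x = (1, 2).

f_grad = lambda x: 2 * x
g = lambda x: np.array([1.0 - x[0]])
h = lambda x: np.array([x[1] - 2.0])
Jg = lambda x: np.array([[-1.0, 0.0]])  # Jacobian of g
Jh = lambda x: np.array([[0.0, 1.0]])   # Jacobian of h

eta = 0.05
x, lam, nu = np.zeros(2), np.zeros(1), np.zeros(1)
for _ in range(3000):
    grad_x = f_grad(x) + Jg(x).T @ lam + Jh(x).T @ nu   # grad_x of the Lagrangian
    lam = np.maximum(lam + eta * g(x), 0.0)             # dual ascent on lambda
    nu = nu + eta * h(x)                                # dual ascent on nu
    x = x - eta * grad_x                                # primal descent
print(x)  # approaches (1, 2)
```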
The solution to Eq. (12) is a fixed point of a recursive map. Denoting iterations of this map by the upper index $i$, by the change of variables $z = (x, \lambda, \nu)$, we define the piece-wise linear map

$$z^{i+1} = T(z^{i}), \tag{13}$$

with fixed point $z^{*}$, such that

$$z^{*} = T(z^{*}).$$

Claim: The recursive mapping of Eq. (13) converges, upon repeated iteration, when $\eta$ and the Jacobian of $T$ are bounded such that $T$ is a contraction:

$$\sup_{z}\,\left\|\frac{\partial T}{\partial z}\right\| < 1.$$
The updates of Eq. (12) comprise a primal-dual algorithm, which maintains iterates for the dual variables $\lambda$ and $\nu$ in addition to the primal variable $x$. This basic approach may be extended by substituting different regularization terms or local function approximations.
References
- Conjugate Natural Selection. Raab et al. (2022)
- Efficient Natural Gradient Descent Methods for Large-Scale PDE-Based Optimization Problems. Nurbekyan et al. (2022)
- Reparameterizing Mirror Descent as Gradient Descent. Amid and Warmuth (2020)
- Advanced Algorithms (Lecture Notes). Anupam Gupta (2020)
- Nesterov’s accelerated gradient and momentum as approximations to regularised update descent. Botev et al. (2016)
- Convex Optimization. Boyd and Vandenberghe (2009)
- Numerical Optimization. Nocedal and Wright (2006)