Stochastic gradient descent

From Infogalactic: the planetary knowledge core

Stochastic gradient descent is a gradient descent optimization method for minimizing an objective function that is written as a sum of differentiable functions.

Background

Both statistical estimation and machine learning consider the problem of minimizing an objective function that has the form of a sum:

Q(w) = \sum_{i=1}^n Q_i(w),

where the parameter w^* which minimizes Q(w) is to be estimated. Each summand function Q_i is typically associated with the i-th observation in the data set (used for training).

In classical statistics, sum-minimization problems arise in least squares and in maximum-likelihood estimation (for independent observations). The general class of estimators that arise as minimizers of sums is called M-estimators. However, in statistics, it has long been recognized that requiring even local minimization is too restrictive for some problems of maximum-likelihood estimation, as shown by Thomas Ferguson's example.[1] Therefore, contemporary statistical theorists often consider stationary points of the likelihood function (or zeros of its derivative, the score function, and other estimating equations).

The sum-minimization problem also arises for empirical risk minimization: in this case, Q_i(w) is the value of the loss function at the i-th example, and Q(w) is the empirical risk.

When used to minimize the above function, a standard (or "batch") gradient descent method would perform the following iterations:

w := w - \eta \nabla Q(w) = w - \eta \sum_{i=1}^n \nabla Q_i(w),

where \eta is a step size (sometimes called the learning rate in machine learning).
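As a minimal sketch of the batch update above, consider a toy objective with Q_i(w) = (w - a_i)^2 for a scalar w, whose minimizer is the mean of the a_i. The data `a`, the step size, the fixed iteration count, and the function name `batch_gradient_descent` are illustrative assumptions, not part of the original text.

```python
def batch_gradient_descent(grad_qi, n, w, eta, steps):
    """Repeat w := w - eta * sum_i grad Q_i(w) for a scalar parameter w."""
    for _ in range(steps):
        # Every iteration evaluates the gradients of all n summands.
        full_gradient = sum(grad_qi(i, w) for i in range(n))
        w = w - eta * full_gradient
    return w


# Toy objective (an assumption): Q_i(w) = (w - a[i])**2, minimized at mean(a) = 3.
a = [1.0, 2.0, 3.0, 6.0]


def grad(i, w):
    return 2.0 * (w - a[i])


w_hat = batch_gradient_descent(grad, n=len(a), w=0.0, eta=0.05, steps=200)
```

Each pass through the loop costs n gradient evaluations, which is the cost that becomes prohibitive for very large training sets.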

In many cases, the summand functions have a simple form that enables inexpensive evaluations of the sum-function and the sum gradient. For example, in statistics, one-parameter exponential families allow economical function-evaluations and gradient-evaluations.

However, in other cases, evaluating the sum-gradient may require expensive evaluations of the gradients from all summand functions. When the training set is enormous and no simple formulas exist, each iteration therefore becomes very expensive. To economize on the computational cost at every iteration, stochastic gradient descent samples a subset of summand functions at every step. This is very effective in the case of large-scale machine learning problems.[2]

Iterative method

Figure: Fluctuations in the total objective function as gradient steps are taken with respect to mini-batches.

In stochastic (or "on-line") gradient descent, the true gradient of Q(w) is approximated by a gradient at a single example:

w := w - \eta \nabla Q_i(w).

As the algorithm sweeps through the training set, it performs the above update for each training example. Several passes can be made over the training set until the algorithm converges. If this is done, the data can be shuffled for each pass to prevent cycles. Typical implementations may use an adaptive learning rate so that the algorithm converges.

In pseudocode, stochastic gradient descent can be presented as follows:

  • Choose an initial vector of parameters w and learning rate \eta.
  • Repeat until an approximate minimum is obtained:
    • Randomly shuffle examples in the training set.
    • For i = 1, 2, ..., n, do:
      • w := w - \eta \nabla Q_i(w).
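The pseudocode above can be sketched in Python roughly as follows. This is an illustration under stated assumptions: the parameter is a scalar, the stopping rule "until an approximate minimum is obtained" is replaced by a fixed number of passes (`epochs`), and the toy objective Q_i(w) = (w - a_i)^2 and all names are hypothetical.

```python
import random


def sgd(grad_qi, n, w, eta, epochs, seed=0):
    """Single-example SGD for a scalar parameter w."""
    rng = random.Random(seed)
    order = list(range(n))
    for _ in range(epochs):      # fixed pass count stands in for a convergence test
        rng.shuffle(order)       # randomly shuffle the examples for this pass
        for i in order:
            w = w - eta * grad_qi(i, w)   # w := w - eta * grad Q_i(w)
    return w


# Toy objective (an assumption): Q_i(w) = (w - a[i])**2, minimized at mean(a) = 3.
a = [1.0, 2.0, 3.0, 6.0]
w_sgd = sgd(lambda i, w: 2.0 * (w - a[i]), n=len(a), w=0.0, eta=0.01, epochs=500)
```

With a constant step size the iterate does not settle exactly at the minimizer; it keeps fluctuating in a small neighbourhood of it, which is one motivation for decreasing step sizes.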

A compromise between computing the true gradient and the gradient at a single example is to compute the gradient against more than one training example (called a "mini-batch") at each step. This can perform significantly better than true stochastic gradient descent because the code can make use of vectorization libraries rather than computing each step separately. It may also result in smoother convergence, as the gradient computed at each step uses more training examples.
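A mini-batch variant might look like the sketch below. It assumes a scalar parameter and averages the per-example gradients within each batch; averaging rather than summing is a scaling convention chosen here, not something the text prescribes. Plain Python is used only to keep the example self-contained, so it does not show the vectorization benefit; a practical implementation would compute the batch gradient with an array library.

```python
import random


def minibatch_sgd(grad_qi, n, w, eta, batch_size, epochs, seed=0):
    """Mini-batch SGD for a scalar parameter w."""
    rng = random.Random(seed)
    order = list(range(n))
    for _ in range(epochs):
        rng.shuffle(order)
        for start in range(0, n, batch_size):
            batch = order[start:start + batch_size]
            # Assumption: average (not sum) the gradients over the mini-batch.
            g = sum(grad_qi(i, w) for i in batch) / len(batch)
            w = w - eta * g
    return w


# Toy objective (an assumption): Q_i(w) = (w - a[i])**2, minimized at mean(a) = 3.
a = [1.0, 2.0, 3.0, 6.0, 0.0, 4.0, 5.0, 3.0]
w_mb = minibatch_sgd(lambda i, w: 2.0 * (w - a[i]), n=len(a), w=0.0,
                     eta=0.05, batch_size=4, epochs=300)
```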

The convergence of stochastic gradient descent has been analyzed using the theories of convex minimization and of stochastic approximation. Briefly, when the learning rates \eta decrease at an appropriate rate, and subject to relatively mild assumptions, stochastic gradient descent converges almost surely to a global minimum when the objective function is convex or pseudoconvex, and otherwise converges almost surely to a local minimum.[3][4] This is in fact a consequence of the Robbins-Siegmund theorem.[5]

Example

Suppose we want to fit a straight line y = w_1 + w_2 x to a training set of two-dimensional points (x_1, y_1), \ldots, (x_n, y_n) using least squares. The objective function to be minimized is:

Q(w) = \sum_{i=1}^n Q_i(w) = \sum_{i=1}^n \left(w_1 + w_2 x_i - y_i\right)^2.

For this specific problem, the last line of the above pseudocode becomes:

\begin{bmatrix} w_1 \\ w_2 \end{bmatrix} :=
    \begin{bmatrix} w_1 \\ w_2 \end{bmatrix}
    -  \eta  \begin{bmatrix} 2 (w_1 + w_2 x_i - y_i) \\ 2 x_i(w_1 + w_2 x_i - y_i) \end{bmatrix}.
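The update for the line-fitting example can be tried out with a short sketch like the one below. The data set (four noise-free points on y = 1 + 2x), the step size, the number of passes, and the name `fit_line_sgd` are assumptions made for illustration.

```python
import random


def fit_line_sgd(points, eta=0.01, epochs=2000, seed=0):
    """Fit y = w1 + w2 * x by SGD on the squared residuals."""
    rng = random.Random(seed)
    points = list(points)
    w1, w2 = 0.0, 0.0
    for _ in range(epochs):
        rng.shuffle(points)
        for x, y in points:
            residual = w1 + w2 * x - y
            # Both components are updated from the same residual (simultaneous update).
            w1, w2 = (w1 - eta * 2.0 * residual,
                      w2 - eta * 2.0 * x * residual)
    return w1, w2


# Hypothetical noise-free data lying exactly on y = 1 + 2x.
data = [(x, 1.0 + 2.0 * x) for x in (0.0, 1.0, 2.0, 3.0)]
w1, w2 = fit_line_sgd(data)
```

Because these points lie exactly on a line, every Q_i is minimized by the same parameters, so a constant step size suffices here; with noisy data the iterates would instead hover around the least-squares solution.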

Applications

Stochastic gradient descent is a popular algorithm for training a wide range of models in machine learning, including (linear) support vector machines, logistic regression (see, e.g., Vowpal Wabbit) and graphical models.[6] When combined with the backpropagation algorithm, it is the de facto standard algorithm for training artificial neural networks.[7]

SGD competes with the L-BFGS algorithm,[citation needed] which is also widely used. SGD has been used since at least 1960 for training linear regression models, originally under the name ADALINE.[8]

Another popular stochastic gradient descent algorithm is the least mean squares (LMS) adaptive filter.

Extensions and variants

Many improvements on the basic SGD algorithm have been proposed and used. In particular, in machine learning, the need to set a learning rate (step size) has been recognized as problematic. Setting this parameter too high can cause the algorithm to diverge; setting it too low makes it slow to converge. A conceptually simple extension of SGD makes the learning rate a decreasing function \eta_t of the iteration number t, giving a learning rate schedule, so that the first iterations cause large changes in the parameters, while the later ones do only fine-tuning. Such schedules have been known since the work of MacQueen on k-means clustering.[9]
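One possible learning rate schedule is sketched below. The specific form \eta_t = \eta_0 / (1 + k t) is an assumption chosen for illustration (the text does not prescribe a particular schedule), as are the names `eta0` and `decay` and the toy objective Q_i(w) = (w - a_i)^2.

```python
import random


def sgd_with_schedule(grad_qi, n, w, eta0, decay, epochs, seed=0):
    """SGD for a scalar w with the assumed schedule eta_t = eta0 / (1 + decay * t)."""
    rng = random.Random(seed)
    order = list(range(n))
    t = 0  # global iteration counter
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            eta_t = eta0 / (1.0 + decay * t)  # large early steps, small later steps
            w = w - eta_t * grad_qi(i, w)
            t += 1
    return w


# Toy objective (an assumption): Q_i(w) = (w - a[i])**2, minimized at mean(a) = 3.
a = [1.0, 2.0, 3.0, 6.0]
w_decay = sgd_with_schedule(lambda i, w: 2.0 * (w - a[i]), n=len(a), w=0.0,
                            eta0=0.1, decay=0.01, epochs=500)
```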

Momentum

Further proposals include the momentum method, which appeared in Rumelhart, Hinton and Williams' seminal paper on backpropagation learning.[10] SGD with momentum remembers the update \Delta w at each iteration, and determines the next update as a linear combination of the gradient and the previous update:

\Delta w := \eta \nabla Q_i(w) + \alpha \Delta w
w := w - \Delta w

The name momentum stems from an analogy to momentum in physics: the weight vector, thought of as a particle traveling through parameter space,[10] incurs acceleration from the gradient of the loss ("force"). Unlike in classical SGD, it tends to keep traveling in the same direction, preventing oscillations. Momentum has been used successfully for several decades.[11]
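A minimal sketch of SGD with momentum for a scalar parameter follows. It applies the step size once, when forming the update \Delta w from the current gradient and \alpha times the previous update, and then subtracts \Delta w from w. The values of `eta` and `alpha`, the toy objective, and the function name are assumptions for illustration.

```python
import random


def sgd_momentum(grad_qi, n, w, eta, alpha, epochs, seed=0):
    """SGD with momentum for a scalar parameter w."""
    rng = random.Random(seed)
    order = list(range(n))
    delta_w = 0.0  # remembered update; zero before the first step
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            delta_w = eta * grad_qi(i, w) + alpha * delta_w
            w = w - delta_w
    return w


# Toy objective (an assumption): Q_i(w) = (w - a[i])**2, minimized at mean(a) = 3.
a = [1.0, 2.0, 3.0, 6.0]
w_mom = sgd_momentum(lambda i, w: 2.0 * (w - a[i]), n=len(a), w=0.0,
                     eta=0.005, alpha=0.9, epochs=500)
```

Since past gradients keep contributing through \Delta w, the effective step along a persistent direction is larger than \eta alone, which is why a smaller `eta` is used here than in a plain SGD run on the same toy problem.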

Averaging

Averaged SGD, invented independently by Ruppert and Polyak in the late 1980s, is ordinary SGD that records an average of its parameter vector over time. That is, the update is the same as for ordinary SGD, but the algorithm also keeps track of[12]

\bar{w} = \frac{1}{t} \sum_{i=0}^{t-1} w_i.

When optimization is done, this averaged parameter vector takes the place of w.
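Averaged SGD can be sketched by maintaining a running mean alongside the ordinary iterate, which avoids storing all past parameter vectors. The scalar parameter, the constant step size, the toy objective, and the function name are assumptions made for illustration.

```python
import random


def averaged_sgd(grad_qi, n, w, eta, epochs, seed=0):
    """Ordinary SGD on a scalar w that also tracks the average of its iterates."""
    rng = random.Random(seed)
    order = list(range(n))
    w_bar = 0.0  # running mean of the iterates seen so far
    t = 0        # number of iterates included in w_bar
    for _ in range(epochs):
        rng.shuffle(order)
        for i in order:
            t += 1
            w_bar += (w - w_bar) / t      # now the mean of w_0, ..., w_{t-1}
            w = w - eta * grad_qi(i, w)   # the SGD update itself is unchanged
    return w, w_bar


# Toy objective (an assumption): Q_i(w) = (w - a[i])**2, minimized at mean(a) = 3.
a = [1.0, 2.0, 3.0, 6.0]
w_last, w_avg = averaged_sgd(lambda i, w: 2.0 * (w - a[i]), n=len(a), w=0.0,
                             eta=0.01, epochs=500)
```

The average includes the early iterates, so a starting point far from the minimizer biases `w_avg` slightly until enough iterations have accumulated.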

AdaGrad

AdaGrad (for adaptive gradient algorithm) is an enhanced SGD that automatically determines a per-parameter learning rate.[13][14] It still has a base learning rate \eta, but this is scaled using the elements of a vector \{G_{j,j}\}, which is the diagonal of the outer product matrix

G = \sum_{\tau=1}^t g_\tau g_\tau^\mathsf{T}

where g_\tau = \nabla Q_i(w), the gradient, at iteration τ. The diagonal is given by

G_{j,j} = \sum_{\tau=1}^t g_{\tau,j}^2.

This vector is updated after every iteration. The formula for an update is now

w := w - \eta\, \mathrm{diag}(G)^{-\frac{1}{2}} \circ g

where \circ denotes the element-wise product, or, written as per-parameter updates,

w_j := w_j - \frac{\eta}{\sqrt{G_{j,j}}} g_j.

Each G_{j,j} gives rise to a scaling factor for the learning rate that applies to a single parameter w_j. Since the denominator in this factor, \sqrt{G_{j,j}} = \sqrt{\sum_{\tau=1}^t g_{\tau,j}^2}, is the ℓ2 norm of the previous derivatives, extreme parameter updates get dampened, while parameters that get few or small updates receive higher learning rates.[11]

While designed for convex problems, AdaGrad has been successfully applied to non-convex optimization.[15]
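The per-parameter AdaGrad update can be sketched on a two-parameter least-squares line fit y = w_1 + w_2 x. The data, the base learning rate, the function name, and the small constant `eps` are assumptions; in particular, `eps` is not part of the formulas above and is added only to avoid division by zero before a parameter has received any gradient.

```python
import math
import random


def adagrad_fit_line(points, eta=0.5, epochs=2000, eps=1e-8, seed=0):
    """Fit y = w[0] + w[1] * x with per-parameter AdaGrad step sizes."""
    rng = random.Random(seed)
    points = list(points)
    w = [0.0, 0.0]
    G = [0.0, 0.0]  # diagonal of G: running sums of squared gradient components
    for _ in range(epochs):
        rng.shuffle(points)
        for x, y in points:
            residual = w[0] + w[1] * x - y
            g = [2.0 * residual, 2.0 * x * residual]
            for j in range(2):
                G[j] += g[j] ** 2
                # w_j := w_j - eta / sqrt(G_jj) * g_j  (eps is an added assumption)
                w[j] -= eta / (math.sqrt(G[j]) + eps) * g[j]
    return w


# Hypothetical noise-free data lying exactly on y = 1 + 2x.
data = [(x, 1.0 + 2.0 * x) for x in (0.0, 1.0, 2.0, 3.0)]
w_ada = adagrad_fit_line(data)
```

The slope parameter sees larger gradients than the intercept (they are multiplied by x), so its accumulated G_{j,j} grows faster and its effective step size shrinks more.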

References

  1. (citation text missing)
  2. (citation text missing)
  3. (citation text missing)
  4. (citation text missing)
  5. (citation text missing)
  6. Jenny Rose Finkel, Alex Kleeman, Christopher D. Manning (2008). Efficient, Feature-based, Conditional Random Field Parsing. Proc. Annual Meeting of the ACL.
  7. Yann LeCun et al. Efficient BackProp. In: Neural Networks: Tricks of the Trade.
  8. (citation text missing)
  9. Cited by (citation text missing)
  10. (citation text missing)
  11. (citation text missing)
  12. (citation text missing)
  13. (citation text missing)
  14. (citation text missing)
  15. (citation text missing)
