Stochastic gradient descent (SGD) in R

Recently I’ve been learning about optimization with gradient descent, and more specifically, stochastic gradient descent (SGD). I found it to be a little confusing at first so I created a Shiny app to help myself figure it out.

Anyway, the gist of gradient descent is that you can take advantage of the property that the sum of derivatives is equal to the derivative of a sum, and thereby avoid computationally intensive matrix operations with large data sets. For example, suppose we have the linear model:

$$Y = Xw$$

where $Y$ is a vector of observations, i.e. $Y \in \mathbb{R}^{N}$, $X$ is a matrix of input data s.t. $X \in \mathbb{R}^{N \times D}$, and $w$ is a vector of parameters, $w \in \mathbb{R}^{D}$, with $N$ being the number of observations and $D$ the number of parameters.

To solve the linear model, we could use the least-squares loss function to find the $w$ that minimizes the following:

$$L(e) = \|e\|^2 \\ e(w) = Y - Xw$$

where $\|e\|$ represents the Euclidean norm of the error function $e$.

Or, to think of it in non-matrix terms, find $w$ to minimize

$$\sum_{n=1}^N (y_n - f(x_n;w))^2$$

where $y_n$ is the $n^{th}$ observation of $y$ and $f()$ is a function applied to the $n^{th}$ observation of $x$ using the parameters $w$ s.t.

$$f(x_n;w) = w_0 + w_1x_1 + \dots + w_Dx_D$$

for each observation.
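To make this concrete, the per-observation sum above can be written as a short R function; a minimal sketch in which the names (`sq_loss`, `X`, `y`, `w`) are my own, not from the app's code:

```r
# Sum-of-squared-errors loss for a linear model; a minimal sketch.
sq_loss <- function(X, y, w) {
  preds <- X %*% w     # f(x_n; w) for every observation at once
  sum((y - preds)^2)   # sum over n of (y_n - f(x_n; w))^2
}

# Tiny example: two observations, an intercept plus one slope
X <- cbind(1, c(1, 2))  # leading column of 1s carries the intercept w_0
y <- c(3, 5)
sq_loss(X, y, c(1, 2))  # w = (1, 2) fits exactly, so the loss is 0
sq_loss(X, y, c(0, 0))  # all-zero weights give 3^2 + 5^2 = 34
```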

Using the chain rule and the desired derivatives (not going through the derivation here, but one great resource for the underlying math is Deisenroth, Faisal, and Ong (2020)), we can solve for the derivative $\frac{dL}{dw}$:

$$\frac{dL}{dw} = \frac{dL}{de}\frac{de}{dw} = -2X^T(Y - Xw)$$

Setting the derivative equal to zero and solving for $w$ produces

$$w = (X^TX)^{-1}X^TY$$

This makes it apparent that, when these matrices are large, the computation can become expensive. Now we get to stochastic gradient descent.
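As a sanity check, the closed-form solution is a one-liner in R. This is my own illustration on simulated data (the true coefficients of 2 and 3 are made up), not code from the app:

```r
# Closed-form least squares: w = (X^T X)^{-1} X^T Y on simulated data
set.seed(42)
n <- 100
x <- rnorm(n)
y <- 2 + 3 * x + rnorm(n, sd = 0.5)    # true intercept 2, slope 3

X <- cbind(1, x)                       # design matrix with an intercept column
w_hat <- solve(t(X) %*% X, t(X) %*% y) # solve() avoids forming the inverse explicitly

drop(w_hat)        # close to c(2, 3)
coef(lm(y ~ x))    # lm() agrees with the closed-form answer
```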

Instead of using all the observations, we can update the vector of parameters $w$ incrementally using batches of samples, or even just one sample at a time. Here I'll show updating by one sample at a time over each iteration $i$:

$$w_{i+1} = w_i - \gamma_i \nabla L(w_i) \\ w_{i+1} = w_i + \gamma_i (y_n - f(x_n; w_i)) x_n$$

using a step-size parameter $\gamma_i$ (the constant factor of 2 from the derivative gets absorbed into $\gamma_i$). The step-size is chosen to be not so small that convergence takes too long, but not so large that it jumps over the optimal parameter values. As the error decreases, $w_i$ converges on the true value of $w$. That's the beauty of gradient descent. It was also the crux of my confusion, and why I made this Shiny app!
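A bare-bones version of this update loop might look like the following. It's a sketch under my own assumptions — simulated data, a fixed step size `gamma` rather than a decaying schedule $\gamma_i$ — and is not the app's exact code:

```r
# One-sample-at-a-time SGD for a simple linear model
set.seed(1)
n <- 200
x <- rnorm(n)
y <- 2 + 3 * x + rnorm(n, sd = 0.5)  # true w = (2, 3), made up for illustration
X <- cbind(1, x)

w <- c(0, 0)   # initial guess for (intercept, slope)
gamma <- 0.01  # fixed step size
epochs <- 50

for (epoch in seq_len(epochs)) {
  for (j in sample(n)) {            # visit the samples in a random order
    err <- y[j] - sum(X[j, ] * w)   # y_n - f(x_n; w)
    w <- w + gamma * err * X[j, ]   # the update rule above, one sample at a time
  }
}
w  # close to the true c(2, 3)
```

Shuffling the sample order each epoch (via `sample(n)`) is the usual way to keep the one-at-a-time updates from cycling through the data in a fixed pattern.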

In the app, I code a simple implementation of SGD for a linear regression model. The parameters specifying the observed data can be manipulated, as can the learning rate and number of epochs for SGD. After running the model, plots show the root mean square error and the weights for the coefficients of the model (intercept and slope). It also plots the predicted versus observed y values for a given epoch, so the crappy performance at earlier epochs can be observed.

See the links at the top of the page for accessing code and the app. A great video to get some intuition of the math behind gradient descent is from 3Blue1Brown.

References

Deisenroth, Marc Peter, A. Aldo Faisal, and Cheng Soon Ong. 2020. Mathematics for Machine Learning. 1st ed. Cambridge University Press. https://mml-book.com/.
Posted on:
February 8, 2022