gradient_descent:output02
>
> library(ggplot2)
> library(ggpmisc)
> library(tidyverse)
> library(data.table)
>
> # clear the workspace before starting
> rm(list=ls())
>
> ss <- function(x) {
+ return(sum((x-mean(x))^2))
+ }
>
> # data preparation
> set.seed(101)
> nx <- 50 # variable x, sample size
> mx <- 4.5 # mean of x
> sdx <- mx * 0.56 # sd of x
> x <- rnorm(nx, mx, sdx) # generating x
> slp <- 4 # slope of x = coefficient, b
> # y variable
> y <- slp * x + rnorm(nx, 0, slp*3*sdx)
>
> data <- data.frame(x, y)
> head(data)
x y
1 3.678388 -20.070168
2 5.892204 15.268808
3 2.799142 28.672292
4 5.040186 -22.081593
5 5.283138 43.784059
6 7.458395 -1.954306
>
> # check with regression
> mo <- lm(y ~ x, data = data)
> summary(mo)
Call:
lm(formula = y ~ x, data = data)
Residuals:
Min 1Q Median 3Q Max
-58.703 -20.303 0.331 19.381 51.929
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) -2.708 8.313 -0.326 0.74601
x 5.005 1.736 2.884 0.00587 **
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared: 0.1477, Adjusted R-squared: 0.1299
F-statistic: 8.316 on 1 and 48 DF, p-value: 0.005867
>
> # graph
> ggplot(data = data, aes(x = x, y = y)) +
+ geom_point() +
+ stat_poly_line() +
+ stat_poly_eq(use_label(c("eq", "R2"))) +
+ theme_classic()
>
> # from what we know
> # get covariance value
> sp.yx <- sum((x-mean(x))*(y-mean(y)))
> df.yx <- length(y)-1
> sp.yx/df.yx
[1] 27.61592
> # check with cov function
> cov(x,y)
[1] 27.61592
> # correlation value
> cov(x,y)/(sd(x)*sd(y))
[1] 0.3842686
> cor(x,y)
[1] 0.3842686
>
> # regression by hand
> # b and a
> b <- sp.yx / ss(x) # b2 <- cov(x,y)/var(x)
> a <- mean(y) - b*(mean(x))
> a
[1] -2.708294
> b
[1] 5.004838
>
> # check a and b value from the lm
> summary(mo)$coefficient[1]
[1] -2.708294
> summary(mo)$coefficient[2]
[1] 5.004838
> summary(mo)
Call:
lm(formula = y ~ x, data = data)
Residuals:
Min 1Q Median 3Q Max
-58.703 -20.303 0.331 19.381 51.929
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) -2.708 8.313 -0.326 0.74601
x 5.005 1.736 2.884 0.00587 **
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared: 0.1477, Adjusted R-squared: 0.1299
F-statistic: 8.316 on 1 and 48 DF, p-value: 0.005867
>
> fit.yx <- a + b*x # predicted value of y from x data
> res <- y - fit.yx # error residuals
> reg <- fit.yx - mean(y) # regression component (fitted value minus mean of y)
> ss.res <- sum(res^2)
> ss.reg <- sum(reg^2)
> ss.res+ss.reg
[1] 45864.4
> ss.tot <- ss(y)
> ss.tot
[1] 45864.4
>
> plot(x,y)
> abline(a, b, col="red", lwd=2)
> plot(x, fit.yx)
> plot(x, res)
>
> df.y <- length(y)-1
> df.reg <- 2-1
> df.res <- df.y - df.reg
> df.res
[1] 48
>
> r.sq <- ss.reg / ss.tot
> r.sq
[1] 0.1476624
> summary(mo)$r.square
[1] 0.1476624
> ms.reg <- ss.reg / df.reg
> ms.res <- ss.res / df.res
>
>
> f.cal <- ms.reg / ms.res
> f.cal
[1] 8.315713
> pf(f.cal, df.reg, df.res,lower.tail = F)
[1] 0.005867079
> t.cal <- sqrt(f.cal)
> t.cal
[1] 2.883698
> se.b <- sqrt(ms.res/ss(x))
> se.b
[1] 1.735562
> t.cal <- (b-0)/se.b
> t.cal
[1] 2.883698
> pt(t.cal, df=df.res, lower.tail = F)*2
[1] 0.005867079
> summary(mo)
Call:
lm(formula = y ~ x, data = data)
Residuals:
Min 1Q Median 3Q Max
-58.703 -20.303 0.331 19.381 51.929
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) -2.708 8.313 -0.326 0.74601
x 5.005 1.736 2.884 0.00587 **
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared: 0.1477, Adjusted R-squared: 0.1299
F-statistic: 8.316 on 1 and 48 DF, p-value: 0.005867
>
>
> # getting a and b from
> # gradient descent
> a <- rnorm(1)
> b <- rnorm(1)
> a.start <- a
> b.start <- b
> a.start
[1] 0.2680658
> b.start
[1] -0.5922083
>
> # Predict function:
> predict <- function(x, a, b){
+ return (a + b * x)
+ }
>
> # And loss function is:
> residuals <- function(fit, y) {
+ return(y - fit)
+ }
>
> gradient <- function(x, res){
+ db = -2 * mean(x * res)
+ da = -2 * mean(res)
+ return(list("b" = db, "a" = da))
+ }
>
> # to check ms.residual
> msrloss <- function(fit, y) {
+ res <- residuals(fit, y)
+ return(mean(res^2))
+ }
>
> # Train the model with scaled features
> learning.rate = 1e-1 # 0.1
>
> # Record Loss for each epoch:
> as = c()
> bs = c()
> msrs = c()
> ssrs = c()
> mres = c()
> zx <- (x-mean(x))/sd(x)
>
> nlen <- 75
> for (epoch in 1:nlen) {
+ fit.val <- predict(zx, a, b)
+ residual <- residuals(fit.val, y)
+ loss <- msrloss(fit.val, y)
+ mres <- append(mres, mean(residual))
+ msrs <- append(msrs, loss)
+
+ grad <- gradient(zx, residual)
+ step.b <- grad$b * learning.rate
+ step.a <- grad$a * learning.rate
+ b <- b-step.b
+ a <- a-step.a
+
+ as <- append(as, a)
+ bs <- append(bs, b)
+ }
> msrs
[1] 1254.6253 1085.3811 976.7258 906.9672 862.1801 833.4247 814.9621 803.1078 795.4963 790.6089
[11] 787.4707 785.4556 784.1615 783.3306 782.7970 782.4543 782.2342 782.0929 782.0021 781.9438
[21] 781.9064 781.8823 781.8669 781.8569 781.8506 781.8465 781.8439 781.8422 781.8411 781.8404
[31] 781.8399 781.8396 781.8395 781.8393 781.8393 781.8392 781.8392 781.8392 781.8392 781.8391
[41] 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391
[51] 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391
[61] 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391
[71] 781.8391 781.8391 781.8391 781.8391 781.8391
> mres
[1] 1.798187e+01 1.438549e+01 1.150839e+01 9.206716e+00 7.365373e+00 5.892298e+00 4.713838e+00
[8] 3.771071e+00 3.016857e+00 2.413485e+00 1.930788e+00 1.544631e+00 1.235704e+00 9.885636e-01
[15] 7.908509e-01 6.326807e-01 5.061446e-01 4.049156e-01 3.239325e-01 2.591460e-01 2.073168e-01
[22] 1.658534e-01 1.326828e-01 1.061462e-01 8.491697e-02 6.793357e-02 5.434686e-02 4.347749e-02
[29] 3.478199e-02 2.782559e-02 2.226047e-02 1.780838e-02 1.424670e-02 1.139736e-02 9.117890e-03
[36] 7.294312e-03 5.835449e-03 4.668360e-03 3.734688e-03 2.987750e-03 2.390200e-03 1.912160e-03
[43] 1.529728e-03 1.223782e-03 9.790260e-04 7.832208e-04 6.265766e-04 5.012613e-04 4.010090e-04
[50] 3.208072e-04 2.566458e-04 2.053166e-04 1.642533e-04 1.314026e-04 1.051221e-04 8.409769e-05
[57] 6.727815e-05 5.382252e-05 4.305802e-05 3.444641e-05 2.755713e-05 2.204570e-05 1.763656e-05
[64] 1.410925e-05 1.128740e-05 9.029921e-06 7.223936e-06 5.779149e-06 4.623319e-06 3.698655e-06
[71] 2.958924e-06 2.367140e-06 1.893712e-06 1.514969e-06 1.211975e-06
> as
[1] 3.864439 6.741538 9.043217 10.884560 12.357635 13.536094 14.478862 15.233076 15.836447 16.319144
[11] 16.705302 17.014228 17.261369 17.459082 17.617252 17.743788 17.845017 17.926000 17.990787 18.042616
[21] 18.084079 18.117250 18.143786 18.165016 18.181999 18.195586 18.206455 18.215151 18.222107 18.227672
[31] 18.232124 18.235686 18.238535 18.240815 18.242638 18.244097 18.245264 18.246198 18.246945 18.247542
[41] 18.248021 18.248403 18.248709 18.248954 18.249149 18.249306 18.249431 18.249532 18.249612 18.249676
[51] 18.249727 18.249768 18.249801 18.249828 18.249849 18.249865 18.249879 18.249890 18.249898 18.249905
[61] 18.249911 18.249915 18.249919 18.249921 18.249924 18.249925 18.249927 18.249928 18.249929 18.249930
[71] 18.249930 18.249931 18.249931 18.249931 18.249932
> bs
[1] 1.828121 3.774066 5.338606 6.596496 7.607839 8.420960 9.074708 9.600322 10.022916 10.362681
[11] 10.635852 10.855482 11.032064 11.174036 11.288182 11.379955 11.453741 11.513064 11.560760 11.599108
[21] 11.629940 11.654728 11.674658 11.690682 11.703565 11.713923 11.722251 11.728946 11.734330 11.738658
[31] 11.742138 11.744935 11.747185 11.748993 11.750447 11.751616 11.752556 11.753312 11.753920 11.754408
[41] 11.754801 11.755117 11.755370 11.755575 11.755739 11.755871 11.755977 11.756062 11.756131 11.756186
[51] 11.756230 11.756266 11.756294 11.756317 11.756336 11.756351 11.756363 11.756372 11.756380 11.756386
[61] 11.756391 11.756395 11.756399 11.756401 11.756403 11.756405 11.756406 11.756407 11.756408 11.756409
[71] 11.756410 11.756410 11.756410 11.756411 11.756411
>
> # scaled
> a
[1] 18.24993
> b
[1] 11.75641
>
> # unscale coefficients to make them comprehensible
> # see http://commres.net/wiki/gradient_descent#why_normalize_scale_or_make_z-score_xi
> # and
> # http://commres.net/wiki/gradient_descent#how_to_unnormalize_unscale_a_and_b
> #
> a = a - (mean(x) / sd(x)) * b
> b = b / sd(x)
> a
[1] -2.708293
> b
[1] 5.004837
>
> # changes of estimators
> as <- as - (mean(x) /sd(x)) * bs
> bs <- bs / sd(x)
>
> as
[1] 0.60543638 0.01348719 -0.47394836 -0.87505325 -1.20490696 -1.47600164 -1.69867560 -1.88147654
[9] -2.03146535 -2.15446983 -2.25529623 -2.33790528 -2.40555867 -2.46094055 -2.50625843 -2.54332669
[17] -2.57363572 -2.59840909 -2.61865082 -2.63518431 -2.64868455 -2.65970460 -2.66869741 -2.67603377
[25] -2.68201712 -2.68689566 -2.69087236 -2.69411311 -2.69675345 -2.69890410 -2.70065549 -2.70208142
[33] -2.70324211 -2.70418670 -2.70495527 -2.70558050 -2.70608902 -2.70650253 -2.70683873 -2.70711203
[41] -2.70733415 -2.70751464 -2.70766129 -2.70778042 -2.70787718 -2.70795575 -2.70801956 -2.70807135
[49] -2.70811340 -2.70814753 -2.70817522 -2.70819769 -2.70821592 -2.70823071 -2.70824271 -2.70825244
[57] -2.70826033 -2.70826672 -2.70827191 -2.70827611 -2.70827952 -2.70828228 -2.70828452 -2.70828634
[65] -2.70828781 -2.70828900 -2.70828996 -2.70829074 -2.70829137 -2.70829189 -2.70829230 -2.70829264
[73] -2.70829291 -2.70829313 -2.70829331
> bs
[1] 0.7782519 1.6066627 2.2727050 2.8082030 3.2387434 3.5848979 3.8632061 4.0869659 4.2668688 4.4115107
[11] 4.5278028 4.6213016 4.6964747 4.7569138 4.8055069 4.8445757 4.8759871 4.9012418 4.9215466 4.9378716
[21] 4.9509970 4.9615498 4.9700342 4.9768557 4.9823401 4.9867497 4.9902949 4.9931453 4.9954370 4.9972795
[31] 4.9987609 4.9999520 5.0009096 5.0016795 5.0022985 5.0027962 5.0031963 5.0035180 5.0037767 5.0039846
[41] 5.0041518 5.0042863 5.0043943 5.0044812 5.0045511 5.0046073 5.0046524 5.0046887 5.0047179 5.0047414
[51] 5.0047603 5.0047754 5.0047876 5.0047974 5.0048053 5.0048117 5.0048168 5.0048209 5.0048242 5.0048268
[61] 5.0048289 5.0048307 5.0048320 5.0048331 5.0048340 5.0048347 5.0048353 5.0048358 5.0048362 5.0048365
[71] 5.0048367 5.0048369 5.0048370 5.0048372 5.0048373
> mres
[1] 1.798187e+01 1.438549e+01 1.150839e+01 9.206716e+00 7.365373e+00 5.892298e+00 4.713838e+00
[8] 3.771071e+00 3.016857e+00 2.413485e+00 1.930788e+00 1.544631e+00 1.235704e+00 9.885636e-01
[15] 7.908509e-01 6.326807e-01 5.061446e-01 4.049156e-01 3.239325e-01 2.591460e-01 2.073168e-01
[22] 1.658534e-01 1.326828e-01 1.061462e-01 8.491697e-02 6.793357e-02 5.434686e-02 4.347749e-02
[29] 3.478199e-02 2.782559e-02 2.226047e-02 1.780838e-02 1.424670e-02 1.139736e-02 9.117890e-03
[36] 7.294312e-03 5.835449e-03 4.668360e-03 3.734688e-03 2.987750e-03 2.390200e-03 1.912160e-03
[43] 1.529728e-03 1.223782e-03 9.790260e-04 7.832208e-04 6.265766e-04 5.012613e-04 4.010090e-04
[50] 3.208072e-04 2.566458e-04 2.053166e-04 1.642533e-04 1.314026e-04 1.051221e-04 8.409769e-05
[57] 6.727815e-05 5.382252e-05 4.305802e-05 3.444641e-05 2.755713e-05 2.204570e-05 1.763656e-05
[64] 1.410925e-05 1.128740e-05 9.029921e-06 7.223936e-06 5.779149e-06 4.623319e-06 3.698655e-06
[71] 2.958924e-06 2.367140e-06 1.893712e-06 1.514969e-06 1.211975e-06
> msrs
[1] 1254.6253 1085.3811 976.7258 906.9672 862.1801 833.4247 814.9621 803.1078 795.4963 790.6089
[11] 787.4707 785.4556 784.1615 783.3306 782.7970 782.4543 782.2342 782.0929 782.0021 781.9438
[21] 781.9064 781.8823 781.8669 781.8569 781.8506 781.8465 781.8439 781.8422 781.8411 781.8404
[31] 781.8399 781.8396 781.8395 781.8393 781.8393 781.8392 781.8392 781.8392 781.8392 781.8391
[41] 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391
[51] 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391
[61] 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391 781.8391
[71] 781.8391 781.8391 781.8391 781.8391 781.8391
>
> parameters <- data.frame(as, bs, mres, msrs)
>
> cat(paste0("Intercept: ", a, "\n", "Slope: ", b, "\n"))
Intercept: -2.7082933069293
Slope: 5.00483726695576
>
> summary(mo)$coefficients
Estimate Std. Error t value Pr(>|t|)
(Intercept) -2.708294 8.313223 -0.3257815 0.746005708
x 5.004838 1.735562 2.8836978 0.005867079
>
> msrs <- data.frame(msrs)
> msrs.log <- data.table(epoch = 1:nlen, msrs)
> ggplot(msrs.log, aes(epoch, msrs)) +
+ geom_line(color="blue") +
+ theme_classic()
>
> mres <- data.frame(mres)
> mres.log <- data.table(epoch = 1:nlen, mres)
> ggplot(mres.log, aes(epoch, mres)) +
+ geom_line(color="red") +
+ theme_classic()
>
> ch <- data.frame(mres, msrs)
> ch
mres msrs
1 1.798187e+01 1254.6253
2 1.438549e+01 1085.3811
3 1.150839e+01 976.7258
4 9.206716e+00 906.9672
5 7.365373e+00 862.1801
6 5.892298e+00 833.4247
7 4.713838e+00 814.9621
8 3.771071e+00 803.1078
9 3.016857e+00 795.4963
10 2.413485e+00 790.6089
11 1.930788e+00 787.4707
12 1.544631e+00 785.4556
13 1.235704e+00 784.1615
14 9.885636e-01 783.3306
15 7.908509e-01 782.7970
16 6.326807e-01 782.4543
17 5.061446e-01 782.2342
18 4.049156e-01 782.0929
19 3.239325e-01 782.0021
20 2.591460e-01 781.9438
21 2.073168e-01 781.9064
22 1.658534e-01 781.8823
23 1.326828e-01 781.8669
24 1.061462e-01 781.8569
25 8.491697e-02 781.8506
26 6.793357e-02 781.8465
27 5.434686e-02 781.8439
28 4.347749e-02 781.8422
29 3.478199e-02 781.8411
30 2.782559e-02 781.8404
31 2.226047e-02 781.8399
32 1.780838e-02 781.8396
33 1.424670e-02 781.8395
34 1.139736e-02 781.8393
35 9.117890e-03 781.8393
36 7.294312e-03 781.8392
37 5.835449e-03 781.8392
38 4.668360e-03 781.8392
39 3.734688e-03 781.8392
40 2.987750e-03 781.8391
41 2.390200e-03 781.8391
42 1.912160e-03 781.8391
43 1.529728e-03 781.8391
44 1.223782e-03 781.8391
45 9.790260e-04 781.8391
46 7.832208e-04 781.8391
47 6.265766e-04 781.8391
48 5.012613e-04 781.8391
49 4.010090e-04 781.8391
50 3.208072e-04 781.8391
51 2.566458e-04 781.8391
52 2.053166e-04 781.8391
53 1.642533e-04 781.8391
54 1.314026e-04 781.8391
55 1.051221e-04 781.8391
56 8.409769e-05 781.8391
57 6.727815e-05 781.8391
58 5.382252e-05 781.8391
59 4.305802e-05 781.8391
60 3.444641e-05 781.8391
61 2.755713e-05 781.8391
62 2.204570e-05 781.8391
63 1.763656e-05 781.8391
64 1.410925e-05 781.8391
65 1.128740e-05 781.8391
66 9.029921e-06 781.8391
67 7.223936e-06 781.8391
68 5.779149e-06 781.8391
69 4.623319e-06 781.8391
70 3.698655e-06 781.8391
71 2.958924e-06 781.8391
72 2.367140e-06 781.8391
73 1.893712e-06 781.8391
74 1.514969e-06 781.8391
75 1.211975e-06 781.8391
> max(y)
[1] 83.02991
> ggplot(data, aes(x = x, y = y)) +
+ geom_point(size = 2) +
+ geom_abline(aes(intercept = as, slope = bs),
+ data = parameters, linewidth = 0.5,
+ color = 'green') +
+ stat_poly_line() +
+ stat_poly_eq(use_label(c("eq", "R2"))) +
+ theme_classic() +
+ geom_abline(aes(intercept = as, slope = bs),
+ data = parameters %>% slice_head(),
+ linewidth = 1, color = 'blue') +
+ geom_abline(aes(intercept = as, slope = bs),
+ data = parameters %>% slice_tail(),
+ linewidth = 1, color = 'red') +
+ labs(title = 'Gradient descent. blue: start, red: end, green: gradients')
> summary(mo)
Call:
lm(formula = y ~ x, data = data)
Residuals:
Min 1Q Median 3Q Max
-58.703 -20.303 0.331 19.381 51.929
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) -2.708 8.313 -0.326 0.74601
x 5.005 1.736 2.884 0.00587 **
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared: 0.1477, Adjusted R-squared: 0.1299
F-statistic: 8.316 on 1 and 48 DF, p-value: 0.005867
> a.start
[1] 0.2680658
> b.start
[1] -0.5922083
> a
[1] -2.708293
> b
[1] 5.004837
> summary(mo)$coefficient[1]
[1] -2.708294
> summary(mo)$coefficient[2]
[1] 5.004838
>
gradient_descent/output02.txt · Last modified: by hkimscil
