User Tools

Site Tools


gradient_descent:output02
> 
> library(ggplot2)
> library(ggpmisc)
> library(tidyverse)
> library(data.table)
> 
> # clear the workspace
> rm(list=ls())
> 
> ss <- function(x) { 
+   return(sum((x-mean(x))^2))
+ }
> 
> # data preparation
> set.seed(101)
> nx <- 50 # variable x, sample size 
> mx <- 4.5 # mean of x
> sdx <- mx * 0.56  # sd of x 
> x <- rnorm(nx, mx, sdx) # generating x 
> slp <- 4 # slope of x = coefficient, b
> # y variable 
> y <-  slp * x + rnorm(nx, 0, slp*3*sdx)
> 
> data <- data.frame(x, y)
> head(data)
         x          y
1 3.678388 -20.070168
2 5.892204  15.268808
3 2.799142  28.672292
4 5.040186 -22.081593
5 5.283138  43.784059
6 7.458395  -1.954306
> 
> # check with regression
> mo <- lm(y ~ x, data = data)
> summary(mo)

Call:
lm(formula = y ~ x, data = data)

Residuals:
    Min      1Q  Median      3Q     Max 
-58.703 -20.303   0.331  19.381  51.929 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)   
(Intercept)   -2.708      8.313  -0.326  0.74601   
x              5.005      1.736   2.884  0.00587 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared:  0.1477,	Adjusted R-squared:  0.1299 
F-statistic: 8.316 on 1 and 48 DF,  p-value: 0.005867

> 
> # graph
> ggplot(data = data, aes(x = x, y = y)) + 
+   geom_point() +
+   stat_poly_line() +
+   stat_poly_eq(use_label(c("eq", "R2"))) +
+   theme_classic() 
> 
> # from what we know
> # get covariance value
> sp.yx <- sum((x-mean(x))*(y-mean(y)))
> df.yx <- length(y)-1
> sp.yx/df.yx 
[1] 27.61592
> # check with cov function
> cov(x,y)
[1] 27.61592
> # correlation value
> cov(x,y)/(sd(x)*sd(y))
[1] 0.3842686
> cor(x,y)
[1] 0.3842686
> 
> # regression by hand 
> # b and a 
> b <- sp.yx / ss(x) # b2 <- cov(x,y)/var(x)
> a <- mean(y) - b*(mean(x))
> a
[1] -2.708294
> b
[1] 5.004838
> 
> # check a and b value from the lm 
> summary(mo)$coefficient[1]
[1] -2.708294
> summary(mo)$coefficient[2]
[1] 5.004838
> summary(mo)

Call:
lm(formula = y ~ x, data = data)

Residuals:
    Min      1Q  Median      3Q     Max 
-58.703 -20.303   0.331  19.381  51.929 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)   
(Intercept)   -2.708      8.313  -0.326  0.74601   
x              5.005      1.736   2.884  0.00587 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared:  0.1477,	Adjusted R-squared:  0.1299 
F-statistic: 8.316 on 1 and 48 DF,  p-value: 0.005867

> 
> fit.yx <- a + b*x # predicted value of y from x data
> res <- y - fit.yx # residuals (prediction errors)
> reg <- fit.yx - mean(y) # regression deviations (part explained by x)
> ss.res <- sum(res^2)
> ss.reg <- sum(reg^2)
> ss.res+ss.reg
[1] 45864.4
> ss.tot <- ss(y)
> ss.tot
[1] 45864.4
> 
> plot(x,y)
> abline(a, b, col="red", lwd=2)
> plot(x, fit.yx)
> plot(x, res)
> 
> df.y <- length(y)-1
> df.reg <- 2-1
> df.res <- df.y - df.reg
> df.res
[1] 48
> 
> r.sq <- ss.reg / ss.tot
> r.sq
[1] 0.1476624
> summary(mo)$r.square
[1] 0.1476624
> ms.reg <- ss.reg / df.reg
> ms.res <- ss.res / df.res
> 
> 
> f.cal <- ms.reg / ms.res
> f.cal 
[1] 8.315713
> pf(f.cal, df.reg, df.res,lower.tail = F)
[1] 0.005867079
> t.cal <- sqrt(f.cal)
> t.cal
[1] 2.883698
> se.b <- sqrt(ms.res/ss(x))
> se.b
[1] 1.735562
> t.cal <- (b-0)/se.b
> t.cal
[1] 2.883698
> pt(t.cal, df=df.res, lower.tail = F)*2
[1] 0.005867079
> summary(mo)

Call:
lm(formula = y ~ x, data = data)

Residuals:
    Min      1Q  Median      3Q     Max 
-58.703 -20.303   0.331  19.381  51.929 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)   
(Intercept)   -2.708      8.313  -0.326  0.74601   
x              5.005      1.736   2.884  0.00587 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared:  0.1477,	Adjusted R-squared:  0.1299 
F-statistic: 8.316 on 1 and 48 DF,  p-value: 0.005867

> 
> 
> # getting a and b from
> # gradient descent
> a <- rnorm(1)
> b <- rnorm(1)
> a.start <- a
> b.start <- b
> a.start
[1] 0.2680658
> b.start
[1] -0.5922083
> 
> # Predict function:
> predict <- function(x, a, b){
+   return (a + b * x)
+ }
> 
> # Residuals: observed y minus fitted values
> residuals <- function(fit, y) {
+   return(y - fit)
+ }
> 
> gradient <- function(x, res){
+   db = -2 * mean(x * res)
+   da = -2 * mean(res)
+   return(list("b" = db, "a" = da))
+ }
> 
> # Loss function: mean squared residual (to check ms.residual)
> msrloss <- function(fit, y) {
+   res <- residuals(fit, y)
+   return(mean(res^2))
+ }
> 
> # Train the model with scaled features
> learning.rate = 1e-1 # 0.1
> 
> # Record Loss for each epoch:
> as = c()
> bs = c()
> msrs = c()
> ssrs = c()
> mres = c()
> zx <- (x-mean(x))/sd(x)
> 
> nlen <- 75
> for (epoch in 1:nlen) {
+   fit.val <- predict(zx, a, b)
+   residual <- residuals(fit.val, y)
+   loss <- msrloss(fit.val, y)
+   mres <- append(mres, mean(residual))
+   msrs <- append(msrs, loss)
+   
+   grad <- gradient(zx, residual)
+   step.b <- grad$b * learning.rate 
+   step.a <- grad$a * learning.rate
+   b <- b-step.b
+   a <- a-step.a
+   
+   as <- append(as, a)
+   bs <- append(bs, b)
+ }
> msrs
 [1] 1254.6253 1085.3811  976.7258  906.9672  862.1801  833.4247  814.9621  803.1078  795.4963  790.6089
[11]  787.4707  785.4556  784.1615  783.3306  782.7970  782.4543  782.2342  782.0929  782.0021  781.9438
[21]  781.9064  781.8823  781.8669  781.8569  781.8506  781.8465  781.8439  781.8422  781.8411  781.8404
[31]  781.8399  781.8396  781.8395  781.8393  781.8393  781.8392  781.8392  781.8392  781.8392  781.8391
[41]  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391
[51]  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391
[61]  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391
[71]  781.8391  781.8391  781.8391  781.8391  781.8391
> mres
 [1] 1.798187e+01 1.438549e+01 1.150839e+01 9.206716e+00 7.365373e+00 5.892298e+00 4.713838e+00
 [8] 3.771071e+00 3.016857e+00 2.413485e+00 1.930788e+00 1.544631e+00 1.235704e+00 9.885636e-01
[15] 7.908509e-01 6.326807e-01 5.061446e-01 4.049156e-01 3.239325e-01 2.591460e-01 2.073168e-01
[22] 1.658534e-01 1.326828e-01 1.061462e-01 8.491697e-02 6.793357e-02 5.434686e-02 4.347749e-02
[29] 3.478199e-02 2.782559e-02 2.226047e-02 1.780838e-02 1.424670e-02 1.139736e-02 9.117890e-03
[36] 7.294312e-03 5.835449e-03 4.668360e-03 3.734688e-03 2.987750e-03 2.390200e-03 1.912160e-03
[43] 1.529728e-03 1.223782e-03 9.790260e-04 7.832208e-04 6.265766e-04 5.012613e-04 4.010090e-04
[50] 3.208072e-04 2.566458e-04 2.053166e-04 1.642533e-04 1.314026e-04 1.051221e-04 8.409769e-05
[57] 6.727815e-05 5.382252e-05 4.305802e-05 3.444641e-05 2.755713e-05 2.204570e-05 1.763656e-05
[64] 1.410925e-05 1.128740e-05 9.029921e-06 7.223936e-06 5.779149e-06 4.623319e-06 3.698655e-06
[71] 2.958924e-06 2.367140e-06 1.893712e-06 1.514969e-06 1.211975e-06
> as
 [1]  3.864439  6.741538  9.043217 10.884560 12.357635 13.536094 14.478862 15.233076 15.836447 16.319144
[11] 16.705302 17.014228 17.261369 17.459082 17.617252 17.743788 17.845017 17.926000 17.990787 18.042616
[21] 18.084079 18.117250 18.143786 18.165016 18.181999 18.195586 18.206455 18.215151 18.222107 18.227672
[31] 18.232124 18.235686 18.238535 18.240815 18.242638 18.244097 18.245264 18.246198 18.246945 18.247542
[41] 18.248021 18.248403 18.248709 18.248954 18.249149 18.249306 18.249431 18.249532 18.249612 18.249676
[51] 18.249727 18.249768 18.249801 18.249828 18.249849 18.249865 18.249879 18.249890 18.249898 18.249905
[61] 18.249911 18.249915 18.249919 18.249921 18.249924 18.249925 18.249927 18.249928 18.249929 18.249930
[71] 18.249930 18.249931 18.249931 18.249931 18.249932
> bs
 [1]  1.828121  3.774066  5.338606  6.596496  7.607839  8.420960  9.074708  9.600322 10.022916 10.362681
[11] 10.635852 10.855482 11.032064 11.174036 11.288182 11.379955 11.453741 11.513064 11.560760 11.599108
[21] 11.629940 11.654728 11.674658 11.690682 11.703565 11.713923 11.722251 11.728946 11.734330 11.738658
[31] 11.742138 11.744935 11.747185 11.748993 11.750447 11.751616 11.752556 11.753312 11.753920 11.754408
[41] 11.754801 11.755117 11.755370 11.755575 11.755739 11.755871 11.755977 11.756062 11.756131 11.756186
[51] 11.756230 11.756266 11.756294 11.756317 11.756336 11.756351 11.756363 11.756372 11.756380 11.756386
[61] 11.756391 11.756395 11.756399 11.756401 11.756403 11.756405 11.756406 11.756407 11.756408 11.756409
[71] 11.756410 11.756410 11.756410 11.756411 11.756411
> 
> # coefficients for the scaled (z-score) x
> a
[1] 18.24993
> b
[1] 11.75641
> 
> # unscale coefficients so they are on the original scale of x
> # see http://commres.net/wiki/gradient_descent#why_normalize_scale_or_make_z-score_xi
> # and 
> # http://commres.net/wiki/gradient_descent#how_to_unnormalize_unscale_a_and_b
> #  
> a =  a - (mean(x) / sd(x)) * b
> b =  b / sd(x)
> a
[1] -2.708293
> b
[1] 5.004837
> 
> # convert the per-epoch estimates back to the original scale of x
> as <- as - (mean(x) /sd(x)) * bs
> bs <- bs / sd(x)
> 
> as
 [1]  0.60543638  0.01348719 -0.47394836 -0.87505325 -1.20490696 -1.47600164 -1.69867560 -1.88147654
 [9] -2.03146535 -2.15446983 -2.25529623 -2.33790528 -2.40555867 -2.46094055 -2.50625843 -2.54332669
[17] -2.57363572 -2.59840909 -2.61865082 -2.63518431 -2.64868455 -2.65970460 -2.66869741 -2.67603377
[25] -2.68201712 -2.68689566 -2.69087236 -2.69411311 -2.69675345 -2.69890410 -2.70065549 -2.70208142
[33] -2.70324211 -2.70418670 -2.70495527 -2.70558050 -2.70608902 -2.70650253 -2.70683873 -2.70711203
[41] -2.70733415 -2.70751464 -2.70766129 -2.70778042 -2.70787718 -2.70795575 -2.70801956 -2.70807135
[49] -2.70811340 -2.70814753 -2.70817522 -2.70819769 -2.70821592 -2.70823071 -2.70824271 -2.70825244
[57] -2.70826033 -2.70826672 -2.70827191 -2.70827611 -2.70827952 -2.70828228 -2.70828452 -2.70828634
[65] -2.70828781 -2.70828900 -2.70828996 -2.70829074 -2.70829137 -2.70829189 -2.70829230 -2.70829264
[73] -2.70829291 -2.70829313 -2.70829331
> bs
 [1] 0.7782519 1.6066627 2.2727050 2.8082030 3.2387434 3.5848979 3.8632061 4.0869659 4.2668688 4.4115107
[11] 4.5278028 4.6213016 4.6964747 4.7569138 4.8055069 4.8445757 4.8759871 4.9012418 4.9215466 4.9378716
[21] 4.9509970 4.9615498 4.9700342 4.9768557 4.9823401 4.9867497 4.9902949 4.9931453 4.9954370 4.9972795
[31] 4.9987609 4.9999520 5.0009096 5.0016795 5.0022985 5.0027962 5.0031963 5.0035180 5.0037767 5.0039846
[41] 5.0041518 5.0042863 5.0043943 5.0044812 5.0045511 5.0046073 5.0046524 5.0046887 5.0047179 5.0047414
[51] 5.0047603 5.0047754 5.0047876 5.0047974 5.0048053 5.0048117 5.0048168 5.0048209 5.0048242 5.0048268
[61] 5.0048289 5.0048307 5.0048320 5.0048331 5.0048340 5.0048347 5.0048353 5.0048358 5.0048362 5.0048365
[71] 5.0048367 5.0048369 5.0048370 5.0048372 5.0048373
> mres
 [1] 1.798187e+01 1.438549e+01 1.150839e+01 9.206716e+00 7.365373e+00 5.892298e+00 4.713838e+00
 [8] 3.771071e+00 3.016857e+00 2.413485e+00 1.930788e+00 1.544631e+00 1.235704e+00 9.885636e-01
[15] 7.908509e-01 6.326807e-01 5.061446e-01 4.049156e-01 3.239325e-01 2.591460e-01 2.073168e-01
[22] 1.658534e-01 1.326828e-01 1.061462e-01 8.491697e-02 6.793357e-02 5.434686e-02 4.347749e-02
[29] 3.478199e-02 2.782559e-02 2.226047e-02 1.780838e-02 1.424670e-02 1.139736e-02 9.117890e-03
[36] 7.294312e-03 5.835449e-03 4.668360e-03 3.734688e-03 2.987750e-03 2.390200e-03 1.912160e-03
[43] 1.529728e-03 1.223782e-03 9.790260e-04 7.832208e-04 6.265766e-04 5.012613e-04 4.010090e-04
[50] 3.208072e-04 2.566458e-04 2.053166e-04 1.642533e-04 1.314026e-04 1.051221e-04 8.409769e-05
[57] 6.727815e-05 5.382252e-05 4.305802e-05 3.444641e-05 2.755713e-05 2.204570e-05 1.763656e-05
[64] 1.410925e-05 1.128740e-05 9.029921e-06 7.223936e-06 5.779149e-06 4.623319e-06 3.698655e-06
[71] 2.958924e-06 2.367140e-06 1.893712e-06 1.514969e-06 1.211975e-06
> msrs
 [1] 1254.6253 1085.3811  976.7258  906.9672  862.1801  833.4247  814.9621  803.1078  795.4963  790.6089
[11]  787.4707  785.4556  784.1615  783.3306  782.7970  782.4543  782.2342  782.0929  782.0021  781.9438
[21]  781.9064  781.8823  781.8669  781.8569  781.8506  781.8465  781.8439  781.8422  781.8411  781.8404
[31]  781.8399  781.8396  781.8395  781.8393  781.8393  781.8392  781.8392  781.8392  781.8392  781.8391
[41]  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391
[51]  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391
[61]  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391  781.8391
[71]  781.8391  781.8391  781.8391  781.8391  781.8391
> 
> parameters <- data.frame(as, bs, mres, msrs)
> 
> cat(paste0("Intercept: ", a, "\n", "Slope: ", b, "\n"))
Intercept: -2.7082933069293
Slope: 5.00483726695576
> 
> summary(mo)$coefficients
             Estimate Std. Error    t value    Pr(>|t|)
(Intercept) -2.708294   8.313223 -0.3257815 0.746005708
x            5.004838   1.735562  2.8836978 0.005867079
> 
> msrs <- data.frame(msrs)
> msrs.log <- data.table(epoch = 1:nlen, msrs)
> ggplot(msrs.log, aes(epoch, msrs)) +
+   geom_line(color="blue") +
+   theme_classic()
> 
> mres <- data.frame(mres)
> mres.log <- data.table(epoch = 1:nlen, mres)
> ggplot(mres.log, aes(epoch, mres)) +
+   geom_line(color="red") +
+   theme_classic()
> 
> ch <- data.frame(mres, msrs)
> ch
           mres      msrs
1  1.798187e+01 1254.6253
2  1.438549e+01 1085.3811
3  1.150839e+01  976.7258
4  9.206716e+00  906.9672
5  7.365373e+00  862.1801
6  5.892298e+00  833.4247
7  4.713838e+00  814.9621
8  3.771071e+00  803.1078
9  3.016857e+00  795.4963
10 2.413485e+00  790.6089
11 1.930788e+00  787.4707
12 1.544631e+00  785.4556
13 1.235704e+00  784.1615
14 9.885636e-01  783.3306
15 7.908509e-01  782.7970
16 6.326807e-01  782.4543
17 5.061446e-01  782.2342
18 4.049156e-01  782.0929
19 3.239325e-01  782.0021
20 2.591460e-01  781.9438
21 2.073168e-01  781.9064
22 1.658534e-01  781.8823
23 1.326828e-01  781.8669
24 1.061462e-01  781.8569
25 8.491697e-02  781.8506
26 6.793357e-02  781.8465
27 5.434686e-02  781.8439
28 4.347749e-02  781.8422
29 3.478199e-02  781.8411
30 2.782559e-02  781.8404
31 2.226047e-02  781.8399
32 1.780838e-02  781.8396
33 1.424670e-02  781.8395
34 1.139736e-02  781.8393
35 9.117890e-03  781.8393
36 7.294312e-03  781.8392
37 5.835449e-03  781.8392
38 4.668360e-03  781.8392
39 3.734688e-03  781.8392
40 2.987750e-03  781.8391
41 2.390200e-03  781.8391
42 1.912160e-03  781.8391
43 1.529728e-03  781.8391
44 1.223782e-03  781.8391
45 9.790260e-04  781.8391
46 7.832208e-04  781.8391
47 6.265766e-04  781.8391
48 5.012613e-04  781.8391
49 4.010090e-04  781.8391
50 3.208072e-04  781.8391
51 2.566458e-04  781.8391
52 2.053166e-04  781.8391
53 1.642533e-04  781.8391
54 1.314026e-04  781.8391
55 1.051221e-04  781.8391
56 8.409769e-05  781.8391
57 6.727815e-05  781.8391
58 5.382252e-05  781.8391
59 4.305802e-05  781.8391
60 3.444641e-05  781.8391
61 2.755713e-05  781.8391
62 2.204570e-05  781.8391
63 1.763656e-05  781.8391
64 1.410925e-05  781.8391
65 1.128740e-05  781.8391
66 9.029921e-06  781.8391
67 7.223936e-06  781.8391
68 5.779149e-06  781.8391
69 4.623319e-06  781.8391
70 3.698655e-06  781.8391
71 2.958924e-06  781.8391
72 2.367140e-06  781.8391
73 1.893712e-06  781.8391
74 1.514969e-06  781.8391
75 1.211975e-06  781.8391
> max(y)
[1] 83.02991
> ggplot(data, aes(x = x, y = y)) + 
+   geom_point(size = 2) + 
+   geom_abline(aes(intercept = as, slope = bs),
+               data = parameters, linewidth = 0.5, 
+               color = 'green') + 
+   stat_poly_line() +
+   stat_poly_eq(use_label(c("eq", "R2"))) +
+   theme_classic() +
+   geom_abline(aes(intercept = as, slope = bs), 
+               data = parameters %>% slice_head(), 
+               linewidth = 1, color = 'blue') + 
+   geom_abline(aes(intercept = as, slope = bs), 
+               data = parameters %>% slice_tail(), 
+               linewidth = 1, color = 'red') +
+   labs(title = 'Gradient descent. blue: start, red: end, green: gradients')
> summary(mo)

Call:
lm(formula = y ~ x, data = data)

Residuals:
    Min      1Q  Median      3Q     Max 
-58.703 -20.303   0.331  19.381  51.929 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)   
(Intercept)   -2.708      8.313  -0.326  0.74601   
x              5.005      1.736   2.884  0.00587 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 28.54 on 48 degrees of freedom
Multiple R-squared:  0.1477,	Adjusted R-squared:  0.1299 
F-statistic: 8.316 on 1 and 48 DF,  p-value: 0.005867

> a.start
[1] 0.2680658
> b.start
[1] -0.5922083
> a
[1] -2.708293
> b
[1] 5.004837
> summary(mo)$coefficient[1]
[1] -2.708294
> summary(mo)$coefficient[2]
[1] 5.004838
> 
gradient_descent/output02.txt · Last modified: by hkimscil

Donate Powered by PHP Valid HTML5 Valid CSS Driven by DokuWiki