Gradient Descent

Gradient descent: chipping away at a parameter little by little until the slope (the derivative) we want is found.

Prerequisites:
Why the mean is used when inferring the standard deviation: an experimental and mathematical understanding (실험적_수학적_이해)
Derivation of a and b in a simple regression

The documents above obtained the values of a and b directly with calculus. That is not easy to do on a computer. Is there, then, a way to arrive at the same values by repeated calculation? That is gradient descent.
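In other words: start from an arbitrary guess, compute the derivative of the error at that guess, and take a small step in the opposite direction; repeat until the derivative is close to zero. A minimal sketch of just this idea, using a made-up one-parameter function f(a) = (a - 3)^2 rather than the regression problem below:

# toy illustration only (not the regression example that follows)
# minimize f(a) = (a - 3)^2, whose derivative is f'(a) = 2 * (a - 3)
a  <- 0       # arbitrary starting value
lr <- 0.1     # learning rate (step size)
for (i in 1:100) {
  a <- a - lr * 2 * (a - 3)   # move against the derivative
}
a             # ends up very close to 3, where the derivative is zero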

The second document above explained the search for the values of a and b that minimize SS; the same reasoning holds if SS is replaced by MS (the mean of the squared errors).

\begin{eqnarray*} \text{MS} & = & \frac {\text{SS}}{n} \end{eqnarray*}

\begin{eqnarray*}
\text{for } a \text{ (the intercept):} \\
\dfrac{\partial}{\partial a} \text{MSE (Mean Square Error)} & = & \dfrac{\partial}{\partial a} \frac {\sum{(Y_i - (a + bX_i))^2}} {N} \\
& = & \sum \dfrac{\partial}{\partial a} \frac{(Y_i - (a + bX_i))^2} {N} \\
& = & \sum{\frac{2}{N} (Y_i - (a + bX_i))} \cdot (-1) \;\;\;\; \because \; \dfrac{\partial}{\partial a} (Y_i - (a+bX_i)) = -1 \\
& = & -\frac{2}{N} \sum{(Y_i - (a + bX_i))}
\end{eqnarray*}
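The same differentiation with respect to the slope b (not spelled out in the original text, but it is what db1 in the gradient() function below computes) gives:

\begin{eqnarray*}
\dfrac{\partial}{\partial b} \text{MSE} & = & \sum{\frac{2}{N} (Y_i - (a + bX_i))} \cdot (-X_i) \\
& = & -\frac{2}{N} \sum{X_i (Y_i - (a + bX_i))}
\end{eqnarray*}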

library(tidyverse)

# a simple example from the StatQuest explanation
# (discarded by rm(list = ls()) below; a larger simulated data set is used instead)
x <- c(0.5, 2.3, 2.9)
y <- c(1.4, 1.9, 3.2)

rm(list = ls())
# set.seed(191)
n <- 500
x <- rnorm(n, 5, 1.2)           # predictor
y <- 2.14 * x + rnorm(n, 0, 4)  # true slope 2.14, intercept 0, plus noise

# data <- data.frame(x, y)
data <- tibble(x = x, y = y)
data

mo <- lm(y~x)
summary(mo)

# set.seed(191)
# Initialize random betas
b1 = rnorm(1)
b0 = rnorm(1)

# Predict function:
predict <- function(x, b0, b1){
  return (b0 + b1 * x)
}

# Residuals and the loss (MSE) functions:
residuals <- function(predictions, y) {
  return(y - predictions)
}

loss_mse <- function(predictions, y){
  residuals = y - predictions
  return(mean(residuals ^ 2))
}

predictions <- predict(x, b0, b1)
residuals <- residuals(predictions, y)
loss = loss_mse(predictions, y)

temp.sum <- data.frame(x, y, b0, b1,predictions, residuals)
temp.sum

print(paste0("Loss is: ", round(loss)))

gradient <- function(x, y, predictions){
  dinputs = y - predictions
  db1 = -2 * mean(x * dinputs)
  db0 = -2 * mean(dinputs)
  
  return(list("db1" = db1, "db0" = db0))
}
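
# Optional sanity check (not part of the original script): db0 from gradient()
# should agree with a finite-difference approximation of the MSE in b0.
eps <- 1e-6
num_db0 <- (loss_mse(predict(x, b0 + eps, b1), y) -
            loss_mse(predict(x, b0 - eps, b1), y)) / (2 * eps)
c(analytic = gradient(x, y, predictions)$db0, numeric = num_db0)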

gradients <- gradient(x, y, predictions)
print(gradients)

# Train the model with scaled features
x_scaled <- (x - mean(x)) / sd(x)
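# Note: standardizing x puts the two gradient components (db0, db1) on a
# similar scale, so a single fixed learning rate such as 0.1 converges in a
# few dozen epochs; with the raw x the same rate can overshoot and oscillate.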

learning_rate = 1e-1

# Record Loss for each epoch:
logs = list()
bs=list()
b0s = c()
b1s = c()
msr = c()

nlen <- 80
for (epoch in 1:nlen){
  # Predict all y values:
  predictions = predict(x_scaled, b0, b1)
  loss = loss_mse(predictions, y)
  msr = append(msr, loss)
  
  logs = append(logs, loss)
  
  if (epoch %% 10 == 0){
    print(paste0("Epoch: ",epoch, ", Loss: ", round(loss, 5)))
  }
  
  gradients <- gradient(x_scaled, y, predictions)
  db1 <- gradients$db1
  db0 <- gradients$db0
  
  b1 <- b1 - db1 * learning_rate
  b0 <- b0 - db0 * learning_rate
  b0s <- append(b0s, b0)
  b1s <- append(b1s, b1)
}
# Unscale the coefficients so they apply to the original (unstandardized) x
b0 =  b0 - (mean(x) / sd(x)) * b1
b1 = b1 / sd(x)

b0s <- b0s - (mean(x) /sd(x)) * b1s
b1s <- b1s / sd(x)
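
# Why this works (algebra not in the original): with x_scaled = (x - mean(x)) / sd(x),
#   y_hat = b0 + b1 * (x - mean(x)) / sd(x)
#         = (b0 - b1 * mean(x) / sd(x)) + (b1 / sd(x)) * x
# so the intercept on the original scale is b0 - b1 * mean(x) / sd(x) and the
# slope is b1 / sd(x), which is exactly what the lines above compute.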

parameters <- tibble(data.frame(b0s, b1s, msr))

cat(paste0("Inclination: ", b1, ", \n", "Intercept: ", b0, "\n"))
summary(lm(y~x))$coefficients

ggplot(data, aes(x = x, y = y)) + 
  geom_point(size = 2) + 
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters, linewidth = 0.5, color = 'red') + 
  theme_classic() +
  geom_abline(aes(intercept = b0s, slope = b1s), 
              data = parameters %>% slice_head(), 
              linewidth = 0.5, color = 'blue') + 
  geom_abline(aes(intercept = b0s, slope = b1s), 
              data = parameters %>% slice_tail(), 
              linewidth = 1, color = 'green') +
  labs(title = 'Gradient descent: blue: start, green: end')
data
parameters
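
As an optional extra (not in the original script), the per-epoch MSE recorded in parameters$msr can be plotted to see how quickly the descent settles:

# optional: visualize the loss per epoch (uses only objects created above)
parameters %>%
  mutate(epoch = row_number()) %>%
  ggplot(aes(x = epoch, y = msr)) +
  geom_line(color = 'red') +
  labs(x = 'epoch', y = 'MSE', title = 'Loss per epoch') +
  theme_classic()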
Console session from running the script above:

> # statquest explanation
> x <- c(0.5, 2.3, 2.9)
> y <- c(1.4, 1.9, 3.2)
> 
> rm(list=ls())
> # set.seed(191)
> n <- 500
> x <- rnorm(n, 5, 1.2)
> y <- 2.14 * x + rnorm(n, 0, 4)
> 
> # data <- data.frame(x, y)
> data <- tibble(x = x, y = y)
> data
# A tibble: 500 × 2
       x     y
   <dbl> <dbl>
 1  6.78 10.6 
 2  7.17 17.2 
 3  4.63  5.80
 4  3.12 10.5 
 5  5.65  9.68
 6  5.12 10.8 
 7  4.05 16.8 
 8  7.27 16.5 
 9  4.13  3.96
10  5.27 13.9 
# ℹ 490 more rows
# ℹ Use `print(n = ...)` to see more rows
> 
> mo <- lm(y~x)
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
    Min      1Q  Median      3Q     Max 
-10.474  -2.999   0.095   2.591  11.868 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   0.4613     0.7339   0.629     0.53    
x             1.9828     0.1427  13.890   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 3.929 on 498 degrees of freedom
Multiple R-squared:  0.2792,	Adjusted R-squared:  0.2778 
F-statistic: 192.9 on 1 and 498 DF,  p-value: < 2.2e-16

> 
> # set.seed(191)
> # Initialize random betas
> b1 = rnorm(1)
> b0 = rnorm(1)
> 
> # Predict function:
> predict <- function(x, b0, b1){
+   return (b0 + b1 * x)
+ }
> 
> # And loss function is:
> residuals <- function(predictions, y) {
+   return(y - predictions)
+ }
> 
> loss_mse <- function(predictions, y){
+   residuals = y - predictions
+   return(mean(residuals ^ 2))
+ }
> 
> predictions <- predict(x, b0, b1)
> residuals <- residuals(predictions, y)
> loss = loss_mse(predictions, y)
> 
> temp.sum <- data.frame(x, y, b0, b1,predictions, residuals)
> temp.sum
            x          y        b0         b1 predictions  residuals
1   6.7787032 10.5653561 -1.264277 -0.5262407   -4.831506 15.3968625
2   7.1666745 17.2133136 -1.264277 -0.5262407   -5.035673 22.2489863
3   4.6255226  5.7995419 -1.264277 -0.5262407   -3.698415  9.4979572
4   3.1188584 10.4865173 -1.264277 -0.5262407   -2.905547 13.3920646
5   5.6541258  9.6787933 -1.264277 -0.5262407   -4.239708 13.9185013
6   5.1232990 10.8138626 -1.264277 -0.5262407   -3.960365 14.7742280
7   4.0510211 16.7630100 -1.264277 -0.5262407   -3.396089 20.1590992
8   7.2687149 16.4785689 -1.264277 -0.5262407   -5.089370 21.5679393
9   4.1333064  3.9624439 -1.264277 -0.5262407   -3.439391  7.4018350
10  5.2672548 13.9111428 -1.264277 -0.5262407   -4.036121 17.9472636
 [ remaining rows omitted ]
> 
> print(paste0("Loss is: ", round(loss)))
[1] "Loss is: 228"
> 
> gradient <- function(x, y, predictions){
+   dinputs = y - predictions
+   db1 = -2 * mean(x * dinputs)
+   db0 = -2 * mean(dinputs)
+   
+   return(list("db1" = db1, "db0" = db0))
+ }
> 
> gradients <- gradient(x, y, predictions)
> print(gradients)
$db1
[1] -149.8879

$db0
[1] -28.50182

> 
> # Train the model with scaled features
> x_scaled <- (x - mean(x)) / sd(x)
> 
> learning_rate = 1e-1
> 
> # Record Loss for each epoch:
> logs = list()
> bs=list()
> b0s = c()
> b1s = c()
> msr = c()
> 
> nlen <- 80
> for (epoch in 1:nlen){
+   # Predict all y values:
+   predictions = predict(x_scaled, b0, b1)
+   loss = loss_mse(predictions, y)
+   msr = append(msr, loss)
+   
+   logs = append(logs, loss)
+   
+   if (epoch %% 10 == 0){
+     print(paste0("Epoch: ",epoch, ", Loss: ", round(loss, 5)))
+   }
+   
+   gradients <- gradient(x_scaled, y, predictions)
+   db1 <- gradients$db1
+   db0 <- gradients$db0
+   
+   b1 <- b1 - db1 * learning_rate
+   b0 <- b0 - db0 * learning_rate
+   b0s <- append(b0s, b0)
+   b1s <- append(b1s, b1)
+ }
[1] "Epoch: 10, Loss: 17.96644"
[1] "Epoch: 20, Loss: 15.40245"
[1] "Epoch: 30, Loss: 15.37287"
[1] "Epoch: 40, Loss: 15.37253"
[1] "Epoch: 50, Loss: 15.37253"
[1] "Epoch: 60, Loss: 15.37253"
[1] "Epoch: 70, Loss: 15.37253"
[1] "Epoch: 80, Loss: 15.37253"
> # I must unscale coefficients to make them comprehensible
> b0 =  b0 - (mean(x) / sd(x)) * b1
> b1 = b1 / sd(x)
> 
> b0s <- b0s - (mean(x) /sd(x)) * b1s
> b1s <- b1s / sd(x)
> 
> parameters <- tibble(data.frame(b0s, b1s, msr))
> 
> cat(paste0("Inclination: ", b1, ", \n", "Intercept: ", b0, "\n"))
Inclination: 1.98275951151325, 
Intercept: 0.461293071825862
> summary(lm(y~x))$coefficients
             Estimate Std. Error    t value     Pr(>|t|)
(Intercept) 0.4612931  0.7339392  0.6285167 5.299537e-01
x           1.9827596  0.1427436 13.8903548 2.612507e-37
> 
> ggplot(data, aes(x = x, y = y)) + 
+   geom_point(size = 2) + 
+   geom_abline(aes(intercept = b0s, slope = b1s),
+               data = parameters, linewidth = 0.5, color = 'red') + 
+   theme_classic() +
+   geom_abline(aes(intercept = b0s, slope = b1s), 
+               data = parameters %>% slice_head(), 
+               linewidth = 0.5, color = 'blue') + 
+   geom_abline(aes(intercept = b0s, slope = b1s), 
+               data = parameters %>% slice_tail(), 
+               linewidth = 1, color = 'green') +
+   labs(title = 'Gradient descent: blue: start, green: end')
> data
# A tibble: 500 × 2
       x     y
   <dbl> <dbl>
 1  6.78 10.6 
 2  7.17 17.2 
 3  4.63  5.80
 4  3.12 10.5 
 5  5.65  9.68
 6  5.12 10.8 
 7  4.05 16.8 
 8  7.27 16.5 
 9  4.13  3.96
10  5.27 13.9 
# ℹ 490 more rows
# ℹ Use `print(n = ...)` to see more rows
> parameters
# A tibble: 80 × 3
     b0s    b1s   msr
   <dbl>  <dbl> <dbl>
 1 0.791 0.0539 159. 
 2 0.729 0.439  107. 
 3 0.679 0.747   74.3
 4 0.638 0.994   53.1
 5 0.604 1.19    39.5
 6 0.577 1.35    30.8
 7 0.555 1.48    25.3
 8 0.538 1.58    21.7
 9 0.523 1.66    19.4
10 0.511 1.72    18.0
# ℹ 70 more rows
# ℹ Use `print(n = ...)` to see more rows
> 
> 