User Tools

Site Tools


gradient_descent

Differences

This shows you the differences between two versions of the page.

Link to this comparison view

Both sides previous revisionPrevious revision
Next revision
Previous revision
gradient_descent [2025/08/01 13:43] hkimscilgradient_descent [2025/08/01 18:57] (current) – [R output] hkimscil
Line 1: Line 1:
 ====== Gradient Descent ====== ====== Gradient Descent ======
 +====== explanation ======
 점차하강 = 조금씩 깎아서 원하는 기울기 (미분값) 찾기 점차하강 = 조금씩 깎아서 원하는 기울기 (미분값) 찾기
 prerequisite:  prerequisite: 
Line 22: Line 23:
 & = & \sum{2 \frac{1}{N} (Y_i - (a + bX_i))} * (-1) \;\;\;\; \\ & = & \sum{2 \frac{1}{N} (Y_i - (a + bX_i))} * (-1) \;\;\;\; \\
 & \because & \dfrac{\text{d}}{\text{dv for a}} (Y_i - (a+bX_i)) = -1 \\ & \because & \dfrac{\text{d}}{\text{dv for a}} (Y_i - (a+bX_i)) = -1 \\
-& = & -2 \frac{1}{N} \sum{(Y_i - (a + bX_i))} \\ +& = & -2 \frac{\sum{(Y_i - (a + bX_i))}}{N} \\ 
 +& = & -2 * \text{mean of residuals} \\ 
 \end{eqnarray*}  \end{eqnarray*} 
 +아래 R code에서 gradient function을 참조.
  
 </WRAP> </WRAP>
 +<WRAP box>
 +\begin{eqnarray*}
 +\text{for b, (coefficient)} \\ 
 +\\
 +\dfrac{\text{d}}{\text{dv}} \frac{\sum{(Y_i - (a + bX_i))^2}}{N}  & = & \sum \dfrac{\text{d}}{\text{dv}} \frac{{(Y_i - (a + bX_i))^2}} {N} \\ 
 +& = & \sum{2 \frac{1}{N} (Y_i - (a + bX_i)) * (-X_i)} \;\;\;\; \\
 +& \because & \dfrac{\text{d}}{\text{dv for b}} (Y_i - (a+bX_i)) = -X_i \\
 +& = & -2 \frac{\sum{X_i (Y_i - (a + bX_i))}}{N} \\
 +& = & -2 * \text{mean of } (X_i * \text{residual}_i) \\
  
 +\\ 
 +\end{eqnarray*} 
 +(미분을 이해한다는 것을 전제로) 위의 식은 b값이 변할 때 msr (mean square residual) 값이 어떻게 변하는가를 알려주는 것이다. 그리고 그것은 각 X_i값과 그에 대응하는 residual을 곱한 값들의 총합에 (-2/N)을 곱한 값이다 (X_i는 관측치마다 다르므로 합 밖으로 꺼낼 수 없다).  
 +</WRAP> 
 +====== R code ======
 <code> <code>
-library(tidyverse) +statquest explanation 
-# a simple example +x <- c(0.5, 2.3, 2.9) 
-# statquest explanation +y <- c(1.4, 1.9, 3.2)
-x <- c(0.5, 2.3, 2.9) +
-y <- c(1.4, 1.9, 3.2)+
  
 rm(list=ls()) rm(list=ls())
 # set.seed(191) # set.seed(191)
-n <- 500+n <- 300
 x <- rnorm(n, 5, 1.2) x <- rnorm(n, 5, 1.2)
 y <- 2.14 * x + rnorm(n, 0, 4) y <- 2.14 * x + rnorm(n, 0, 4)
Line 43: Line 57:
 # data <- data.frame(x, y) # data <- data.frame(x, y)
 data <- tibble(x = x, y = y) data <- tibble(x = x, y = y)
-data 
  
 mo <- lm(y~x) mo <- lm(y~x)
Line 52: Line 65:
 b1 = rnorm(1) b1 = rnorm(1)
 b0 = rnorm(1) b0 = rnorm(1)
 +
 +b1.init <- b1
 +b0.init <- b0
  
 # Predict function: # Predict function:
Line 72: Line 88:
 loss = loss_mse(predictions, y) loss = loss_mse(predictions, y)
  
-temp.sum <- data.frame(x, y, b0, b1,predictions, residuals) +data <- tibble(data.frame(x, y, predictions, residuals))
-temp.sum+
  
 print(paste0("Loss is: ", round(loss))) print(paste0("Loss is: ", round(loss)))
Line 94: Line 109:
  
 # Record Loss for each epoch: # Record Loss for each epoch:
-logs = list() +logs = list() 
-bs=list()+bs=list()
 b0s = c() b0s = c()
 b1s = c() b1s = c()
-msr = c()+mse = c()
  
 nlen <- 80 nlen <- 80
Line 105: Line 120:
   predictions = predict(x_scaled, b0, b1)   predictions = predict(x_scaled, b0, b1)
   loss = loss_mse(predictions, y)   loss = loss_mse(predictions, y)
-  msr = append(msr, loss) +  mse = append(mse, loss) 
-   +  logs = append(logs, loss)
-  logs = append(logs, loss)+
      
   if (epoch %% 10 == 0){   if (epoch %% 10 == 0){
Line 122: Line 136:
   b1s <- append(b1s, b1)   b1s <- append(b1s, b1)
 } }
-I must unscale coefficients to make them comprehensible+ 
 +# unscale coefficients to make them comprehensible
 b0 =  b0 - (mean(x) / sd(x)) * b1 b0 =  b0 - (mean(x) / sd(x)) * b1
 b1 = b1 / sd(x) b1 = b1 / sd(x)
  
 +# changes of estimators
 b0s <- b0s - (mean(x) /sd(x)) * b1s b0s <- b0s - (mean(x) /sd(x)) * b1s
 b1s <- b1s / sd(x) b1s <- b1s / sd(x)
  
-parameters <- tibble(data.frame(b0s, b1s, msr))+parameters <- tibble(data.frame(b0s, b1s, mse))
  
-cat(paste0("Inclination: ", b1, ", \n", "Intercept: ", b0, "\n"))+cat(paste0("Slope: ", b1, ", \n", "Intercept: ", b0, "\n"))
 summary(lm(y~x))$coefficients summary(lm(y~x))$coefficients
  
Line 137: Line 153:
   geom_point(size = 2) +    geom_point(size = 2) + 
   geom_abline(aes(intercept = b0s, slope = b1s),   geom_abline(aes(intercept = b0s, slope = b1s),
-              data = parameters, linewidth = 0.5, color = 'red') + +              data = parameters, linewidth = 0.5,  
 +              color = 'green') + 
   theme_classic() +   theme_classic() +
   geom_abline(aes(intercept = b0s, slope = b1s),    geom_abline(aes(intercept = b0s, slope = b1s), 
               data = parameters %>% slice_head(),                data = parameters %>% slice_head(), 
-              linewidth = 0.5, color = 'blue') + +              linewidth = 1, color = 'blue') + 
   geom_abline(aes(intercept = b0s, slope = b1s),    geom_abline(aes(intercept = b0s, slope = b1s), 
               data = parameters %>% slice_tail(),                data = parameters %>% slice_tail(), 
-              linewidth = 1, color = 'green') + +              linewidth = 1, color = 'red') + 
-  labs(title = 'Gradient descentblue: start, green: end')+  labs(title = 'Gradient descentblue: start, red: end, green: gradients') 
 + 
 +b0.init 
 +b1.init 
 data data
 parameters parameters
 +
  
 </code> </code>
 +====== R output =====
 <code> <code>
-> # d statquest explanation 
-> x <- c(0.5, 2.3, 2.9) 
-> y <- c(1.4, 1.9, 3.2) 
- 
 > rm(list=ls()) > rm(list=ls())
 > # set.seed(191) > # set.seed(191)
-> n <- 500+> n <- 300
 > x <- rnorm(n, 5, 1.2) > x <- rnorm(n, 5, 1.2)
 > y <- 2.14 * x + rnorm(n, 0, 4) > y <- 2.14 * x + rnorm(n, 0, 4)
Line 163: Line 182:
 > # data <- data.frame(x, y) > # data <- data.frame(x, y)
 > data <- tibble(x = x, y = y) > data <- tibble(x = x, y = y)
-> data 
-# A tibble: 500 × 2 
-           y 
-   <dbl> <dbl> 
-  6.78 10.6  
-  7.17 17.2  
-  4.63  5.80 
-  3.12 10.5  
-  5.65  9.68 
-  5.12 10.8  
-  4.05 16.8  
-  7.27 16.5  
-  4.13  3.96 
-10  5.27 13.9  
-# ℹ 490 more rows 
-# ℹ Use `print(n = ...)` to see more rows 
  
 > mo <- lm(y~x) > mo <- lm(y~x)
Line 187: Line 190:
  
 Residuals: Residuals:
-    Min      1Q  Median      3Q     Max  +   Min     1Q Median     3Q    Max  
--10.474  -2.999   0.095   2.591  11.868 +-9.754 -2.729 -0.135  2.415 10.750 
  
 Coefficients: Coefficients:
             Estimate Std. Error t value Pr(>|t|)                 Estimate Std. Error t value Pr(>|t|)    
-(Intercept)   0.4613     0.7339   0.629     0.53     +(Intercept)  -0.7794     0.9258  -0.842    0.401     
-x             1.9828     0.1427  13.890   <2e-16 ***+x             2.2692     0.1793  12.658   <2e-16 ***
 --- ---
 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
  
-Residual standard error: 3.929 on 498 degrees of freedom +Residual standard error: 3.951 on 298 degrees of freedom 
-Multiple R-squared:  0.2792, Adjusted R-squared:  0.2778  +Multiple R-squared:  0.3497, Adjusted R-squared:  0.3475  
-F-statistic: 192.on 1 and 498 DF,  p-value: < 2.2e-16+F-statistic: 160.on 1 and 298 DF,  p-value: < 2.2e-16
  
  
Line 206: Line 209:
 > b1 = rnorm(1) > b1 = rnorm(1)
 > b0 = rnorm(1) > b0 = rnorm(1)
 +
 +> b1.init <- b1
 +> b0.init <- b0
  
 > # Predict function: > # Predict function:
Line 226: Line 232:
 > loss = loss_mse(predictions, y) > loss = loss_mse(predictions, y)
  
-temp.sum <- data.frame(x, y, b0, b1,predictions, residuals) +data <- tibble(data.frame(x, y, predictions, residuals))
-> temp.sum +
-            x          y        b0         b1 predictions  residuals +
-1   6.7787032 10.5653561 -1.264277 -0.5262407   -4.831506 15.3968625 +
-2   7.1666745 17.2133136 -1.264277 -0.5262407   -5.035673 22.2489863 +
-3   4.6255226  5.7995419 -1.264277 -0.5262407   -3.698415  9.4979572 +
-4   3.1188584 10.4865173 -1.264277 -0.5262407   -2.905547 13.3920646 +
-5   5.6541258  9.6787933 -1.264277 -0.5262407   -4.239708 13.9185013 +
-6   5.1232990 10.8138626 -1.264277 -0.5262407   -3.960365 14.7742280 +
-7   4.0510211 16.7630100 -1.264277 -0.5262407   -3.396089 20.1590992 +
-8   7.2687149 16.4785689 -1.264277 -0.5262407   -5.089370 21.5679393 +
-9   4.1333064  3.9624439 -1.264277 -0.5262407   -3.439391  7.4018350 +
-10  5.2672548 13.9111428 -1.264277 -0.5262407   -4.036121 17.9472636 +
-11  5.5156223 11.2320867 -1.264277 -0.5262407   -4.166822 15.3989085 +
-12  3.1799040  6.1446257 -1.264277 -0.5262407   -2.937672  9.0822976 +
-13  6.7547841 10.0729772 -1.264277 -0.5262407   -4.818919 14.8918964 +
-14  7.2250923  7.4743490 -1.264277 -0.5262407   -5.066414 12.5407635 +
-15  5.6642827  6.1278221 -1.264277 -0.5262407   -4.245053 10.3728751 +
-16  6.4099293 16.6274813 -1.264277 -0.5262407   -4.637443 21.2649238 +
-17  4.2914899 11.2173980 -1.264277 -0.5262407   -3.522634 14.7400316 +
-18  4.6413358  8.0794876 -1.264277 -0.5262407   -3.706737 11.7862243 +
-19  5.6684969  7.1059092 -1.264277 -0.5262407   -4.247271 11.3531799 +
-20  5.3803989 10.4308559 -1.264277 -0.5262407   -4.095662 14.5265177 +
-21  4.0240137  4.0767913 -1.264277 -0.5262407   -3.381877  7.4586681 +
-22  5.0816903  9.5443887 -1.264277 -0.5262407   -3.938469 13.4828579 +
-23  6.4856995  2.8470186 -1.264277 -0.5262407   -4.677316  7.5243345 +
-24  4.1315971 15.0691180 -1.264277 -0.5262407   -3.438492 18.5076095 +
-25  4.3975735 10.0977703 -1.264277 -0.5262407   -3.578459 13.6762295 +
-26  3.6800792 14.8057519 -1.264277 -0.5262407   -3.200884 18.0066364 +
-27  5.9063058 14.7436966 -1.264277 -0.5262407   -4.372415 19.1161120 +
-28  4.0259638  5.7319405 -1.264277 -0.5262407   -3.382903  9.1148435 +
-29  5.6966522  9.5224078 -1.264277 -0.5262407   -4.262087 13.7844950 +
-30  3.8406098 10.1716630 -1.264277 -0.5262407   -3.285362 13.4570252 +
-31  3.7738837 10.1961941 -1.264277 -0.5262407   -3.250248 13.4464424 +
-32  4.8977716 14.6487758 -1.264277 -0.5262407   -3.841684 18.4904595 +
-33  3.5182859 -0.6639891 -1.264277 -0.5262407   -3.115742  2.4517532 +
-34  6.7555953 18.3109425 -1.264277 -0.5262407   -4.819346 23.1302886 +
-35  5.3450800  7.5049835 -1.264277 -0.5262407   -4.077076 11.5820590 +
-36  4.6063392  3.7802993 -1.264277 -0.5262407   -3.688320  7.4686194 +
-37  3.3150851 10.6463845 -1.264277 -0.5262407   -3.008810 13.6551942 +
-38  2.1438875  5.5805865 -1.264277 -0.5262407   -2.392478  7.9730644 +
-39  6.7867082 14.4541469 -1.264277 -0.5262407   -4.835719 19.2898659 +
-40  7.1651170 16.0295592 -1.264277 -0.5262407   -5.034853 21.0644123 +
-41  6.0068801  4.2186189 -1.264277 -0.5262407   -4.425342  8.6439606 +
-42  3.8429583  6.8926969 -1.264277 -0.5262407   -3.286598 10.1792950 +
-43  6.6963199 14.9606533 -1.264277 -0.5262407   -4.788153 19.7488062 +
-44  5.9766522 17.5076500 -1.264277 -0.5262407   -4.409435 21.9170845 +
-45  3.4139235  3.5029295 -1.264277 -0.5262407   -3.060823  6.5637521 +
-46  7.1468347 15.4958432 -1.264277 -0.5262407   -5.025232 20.5210754 +
-47  2.9413548  1.8553437 -1.264277 -0.5262407   -2.812138  4.6674813 +
-48  2.7033044  4.8611663 -1.264277 -0.5262407   -2.686866  7.5480321 +
-49  5.6384145 15.7331403 -1.264277 -0.5262407   -4.231440 19.9645804 +
-50  3.6549025 13.8667532 -1.264277 -0.5262407   -3.187635 17.0543887 +
-51  2.4549093 -3.2303197 -1.264277 -0.5262407   -2.556150 -0.6741695 +
-52  3.4177686  5.1092427 -1.264277 -0.5262407   -3.062846  8.1720886 +
-53  8.5165648 25.6896235 -1.264277 -0.5262407   -5.746040 31.4356634 +
-54  5.6089046  8.3783769 -1.264277 -0.5262407   -4.215911 12.5942877 +
-55  3.3871150  8.5696011 -1.264277 -0.5262407   -3.046715 11.6163159 +
-56  4.4518647  6.5901035 -1.264277 -0.5262407   -3.607029 10.1971329 +
-57  5.5052891  7.0080255 -1.264277 -0.5262407   -4.161384 11.1694096 +
-58  4.6203097  9.7294702 -1.264277 -0.5262407   -3.695672 13.4251422 +
-59  5.4313621 11.1323523 -1.264277 -0.5262407   -4.122481 15.2548330 +
-60  4.6889464 15.0880915 -1.264277 -0.5262407   -3.731791 18.8198829 +
-61  6.0363099  9.4591404 -1.264277 -0.5262407   -4.440829 13.8999693 +
-62  5.9834349 15.5984185 -1.264277 -0.5262407   -4.413004 20.0114224 +
-63  6.1256065  8.8309730 -1.264277 -0.5262407   -4.487820 13.3187934 +
-64  4.7770683  8.5916899 -1.264277 -0.5262407   -3.778165 12.3698547 +
-65  3.0977662  6.5005729 -1.264277 -0.5262407   -2.894448  9.3950206 +
-66  7.2582566 10.1658760 -1.264277 -0.5262407   -5.083867 15.2497428 +
-67  4.4274063 10.5254164 -1.264277 -0.5262407   -3.594158 14.1195747 +
-68  4.4224093 15.3288304 -1.264277 -0.5262407   -3.591529 18.9203591 +
-69  7.9562345 15.8094793 -1.264277 -0.5262407   -5.451171 21.2606505 +
-70  6.1605765 17.6226869 -1.264277 -0.5262407   -4.506223 22.1289099 +
-71  4.8811933 11.4006073 -1.264277 -0.5262407   -3.832960 15.2335669 +
-72  5.8458839  8.8660352 -1.264277 -0.5262407   -4.340619 13.2066542 +
-73  5.4176401 17.9873603 -1.264277 -0.5262407   -4.115260 22.1026199 +
-74  5.1877292 10.0773562 -1.264277 -0.5262407   -3.994271 14.0716274 +
-75  4.4385788  8.1866806 -1.264277 -0.5262407   -3.600038 11.7867184 +
-76  6.5079686 16.8845586 -1.264277 -0.5262407   -4.689035 21.5735935 +
-77  5.8326041 12.8972544 -1.264277 -0.5262407   -4.333631 17.2308850 +
-78  4.7833622 10.1946848 -1.264277 -0.5262407   -3.781477 13.9761616 +
-79  5.9256779 16.6025871 -1.264277 -0.5262407   -4.382610 20.9851969 +
-80  4.9860307 12.9422586 -1.264277 -0.5262407   -3.888129 16.8303879 +
-81  5.2401527 17.2049456 -1.264277 -0.5262407   -4.021859 21.2268042 +
-82  6.7570239 13.8736266 -1.264277 -0.5262407   -4.820098 18.6937244 +
-83  5.4384563 14.1770698 -1.264277 -0.5262407   -4.126214 18.3032838 +
-84  5.4498287 12.0591785 -1.264277 -0.5262407   -4.132199 16.1913771 +
-85  5.9643832 13.8220440 -1.264277 -0.5262407   -4.402978 18.2250221 +
-86  4.4057387  8.0541411 -1.264277 -0.5262407   -3.582756 11.6368971 +
-87  3.6834860 15.4963951 -1.264277 -0.5262407   -3.202677 18.6990723 +
-88  6.4421163  9.7619937 -1.264277 -0.5262407   -4.654381 14.4163744 +
-89  6.5442359 13.3216053 -1.264277 -0.5262407   -4.708120 18.0297254 +
-90  5.9329826  8.1974096 -1.264277 -0.5262407   -4.386454 12.5838635 +
-91  6.3086002  9.3337688 -1.264277 -0.5262407   -4.584119 13.9178879 +
-92  6.6894199  9.9187932 -1.264277 -0.5262407   -4.784522 14.7033151 +
-93  4.3681693  6.9327113 -1.264277 -0.5262407   -3.562985 10.4956967 +
-94  6.6461515 12.5441888 -1.264277 -0.5262407   -4.761752 17.3059411 +
-95  5.1824709  7.4015164 -1.264277 -0.5262407   -3.991504 11.3930204 +
-96  2.1875381  1.6910838 -1.264277 -0.5262407   -2.415449  4.1065324 +
-97  5.6086270 10.8605436 -1.264277 -0.5262407   -4.215765 15.0763083 +
-98  5.4158031  7.6891953 -1.264277 -0.5262407   -4.114293 11.8034883 +
-99  0.8135196  7.8181045 -1.264277 -0.5262407   -1.692384  9.5104888 +
-100 3.7676048  9.5822003 -1.264277 -0.5262407   -3.246944 12.8291443 +
-101 3.7048040  5.3532350 -1.264277 -0.5262407   -3.213896  8.5671307 +
-102 4.5848668  8.3965243 -1.264277 -0.5262407   -3.677020 12.0735447 +
-103 4.6436020  5.5043385 -1.264277 -0.5262407   -3.707929  9.2122678 +
-104 5.5137839 19.8317442 -1.264277 -0.5262407   -4.165854 23.9975986 +
-105 5.5460450 11.7004649 -1.264277 -0.5262407   -4.182832 15.8832965 +
-106 2.6322855  8.5278252 -1.264277 -0.5262407   -2.649493 11.1773180 +
-107 4.8768933  9.2744021 -1.264277 -0.5262407   -3.830697 13.1050988 +
-108 4.2582660 15.6448394 -1.264277 -0.5262407   -3.505150 19.1499892 +
-109 5.3263384 11.2636384 -1.264277 -0.5262407   -4.067213 15.3308513 +
-110 4.7875378  7.8894584 -1.264277 -0.5262407   -3.783674 11.6731326 +
-111 3.6879691  6.7596765 -1.264277 -0.5262407   -3.205036  9.9647129 +
-112 4.2645818  9.7032671 -1.264277 -0.5262407   -3.508474 13.2117406 +
-113 5.7080723 11.4807855 -1.264277 -0.5262407   -4.268097 15.7488824 +
-114 4.4170445 11.9510702 -1.264277 -0.5262407   -3.588706 15.5397757 +
-115 4.7051114 15.0610402 -1.264277 -0.5262407   -3.740298 18.8013383 +
-116 3.2868115  2.0588466 -1.264277 -0.5262407   -2.993931  5.0527776 +
-117 4.9496883 11.1438762 -1.264277 -0.5262407   -3.869004 15.0128806 +
-118 5.3579592 10.4867567 -1.264277 -0.5262407   -4.083853 14.5706099 +
-119 3.4340821  4.9124528 -1.264277 -0.5262407   -3.071431  7.9838836 +
-120 7.5112322 15.4265323 -1.264277 -0.5262407   -5.216993 20.6435252 +
-121 7.4090157 17.2211544 -1.264277 -0.5262407   -5.163202 22.3843568 +
-122 5.3807003 12.0371465 -1.264277 -0.5262407   -4.095820 16.1329670 +
-123 5.9982219 12.2743501 -1.264277 -0.5262407   -4.420785 16.6951354 +
-124 3.0991733  0.5928870 -1.264277 -0.5262407   -2.895188  3.4880752 +
-125 4.9540596 12.7552638 -1.264277 -0.5262407   -3.871305 16.6265686 +
-126 5.4521347 13.9911543 -1.264277 -0.5262407   -4.133412 18.1245665 +
-127 4.2101665 11.0247402 -1.264277 -0.5262407   -3.479838 14.5045782 +
-128 2.1382146  6.3007539 -1.264277 -0.5262407   -2.389493  8.6902465 +
-129 4.7861186  8.4328932 -1.264277 -0.5262407   -3.782927 12.2158206 +
-130 4.6243029 12.1062901 -1.264277 -0.5262407   -3.697773 15.8040635 +
-131 4.7045999 21.6574011 -1.264277 -0.5262407   -3.740029 25.3974300 +
-132 5.0353365  9.1128945 -1.264277 -0.5262407   -3.914076 13.0269705 +
-133 5.7684327 14.9482380 -1.264277 -0.5262407   -4.299861 19.2480990 +
-134 5.4650414 11.1578444 -1.264277 -0.5262407   -4.140204 15.2980486 +
-135 6.3983266 12.7511754 -1.264277 -0.5262407   -4.631337 17.3825121 +
-136 5.1054629 12.2676585 -1.264277 -0.5262407   -3.950979 16.2186379 +
-137 5.4727620 15.7361092 -1.264277 -0.5262407   -4.144267 19.8803762 +
-138 4.7336163 10.1946145 -1.264277 -0.5262407   -3.755299 13.9499130 +
-139 5.2480093  6.9764202 -1.264277 -0.5262407   -4.025993 11.0024132 +
-140 3.0004377  2.7520590 -1.264277 -0.5262407   -2.843229  5.5952884 +
-141 3.6651683  7.2627301 -1.264277 -0.5262407   -3.193038 10.4557678 +
-142 4.5906320  3.7751924 -1.264277 -0.5262407   -3.680054  7.4552467 +
-143 5.8844053 19.4617765 -1.264277 -0.5262407   -4.360890 23.8226670 +
-144 3.5734761 14.9494035 -1.264277 -0.5262407   -3.144786 18.0941891 +
-145 4.1928815 11.5554015 -1.264277 -0.5262407   -3.470742 15.0261434 +
-146 6.8910946 12.6820411 -1.264277 -0.5262407   -4.890651 17.5726924 +
-147 5.7008138  7.6449870 -1.264277 -0.5262407   -4.264277 11.9092641 +
-148 3.6069183 14.5345051 -1.264277 -0.5262407   -3.162384 17.6968893 +
-149 5.4279511 12.5578155 -1.264277 -0.5262407   -4.120686 16.6785012 +
-150 5.6139766 11.8074150 -1.264277 -0.5262407   -4.218580 16.0259949 +
-151 5.6023394 11.5063542 -1.264277 -0.5262407   -4.212456 15.7188101 +
-152 6.2157176  8.5306942 -1.264277 -0.5262407   -4.535240 13.0659347 +
-153 5.6589077  8.3478592 -1.264277 -0.5262407   -4.242224 12.5900837 +
-154 2.8113763  1.9872925 -1.264277 -0.5262407   -2.743738  4.7310302 +
-155 4.5806760  4.8901946 -1.264277 -0.5262407   -3.674815  8.5650098 +
-156 5.6392284 13.7001285 -1.264277 -0.5262407   -4.231868 17.9319969 +
-157 5.2827326 14.4966648 -1.264277 -0.5262407   -4.044266 18.5409307 +
-158 4.4994363  4.2327446 -1.264277 -0.5262407   -3.632063  7.8648081 +
-159 5.2554305  7.3630872 -1.264277 -0.5262407   -4.029898 11.3929856 +
-160 5.5622164  9.0163544 -1.264277 -0.5262407   -4.191342 13.2076960 +
-161 7.0995953 12.5898387 -1.264277 -0.5262407   -5.000373 17.5902115 +
-162 5.6046834 12.2369786 -1.264277 -0.5262407   -4.213689 16.4506680 +
-163 3.9260205  5.3385293 -1.264277 -0.5262407   -3.330309  8.6688381 +
-164 6.6192734 16.2775842 -1.264277 -0.5262407   -4.747608 21.0251922 +
-165 5.2875993  8.5628853 -1.264277 -0.5262407   -4.046827 12.6097122 +
-166 4.3956646  8.8889738 -1.264277 -0.5262407   -3.577455 12.4664284 +
- [ reached 'max' / getOption("max.print"-- omitted 334 rows ]+
  
 > print(paste0("Loss is: ", round(loss))) > print(paste0("Loss is: ", round(loss)))
-[1] "Loss is: 228"+[1] "Loss is: 393"
  
 > gradient <- function(x, y, predictions){ > gradient <- function(x, y, predictions){
Line 411: Line 248:
 > print(gradients) > print(gradients)
 $db1 $db1
-[1] -149.8879+[1] -200.6834
  
 $db0 $db0
-[1] -28.50182+[1] -37.76994
  
  
Line 423: Line 260:
  
 > # Record Loss for each epoch: > # Record Loss for each epoch:
-> logs = list() +logs = list() 
-> bs=list()+bs=list()
 > b0s = c() > b0s = c()
 > b1s = c() > b1s = c()
-msr = c()+mse = c()
  
 > nlen <- 80 > nlen <- 80
Line 434: Line 271:
 +   predictions = predict(x_scaled, b0, b1) +   predictions = predict(x_scaled, b0, b1)
 +   loss = loss_mse(predictions, y) +   loss = loss_mse(predictions, y)
-+   msr = append(msr, loss) ++   mse = append(mse, loss) 
-+    ++   logs = append(logs, loss)
-+   logs = append(logs, loss)+
 +    +   
 +   if (epoch %% 10 == 0){ +   if (epoch %% 10 == 0){
Line 451: Line 287:
 +   b1s <- append(b1s, b1) +   b1s <- append(b1s, b1)
 + } + }
-[1] "Epoch: 10, Loss: 17.96644+[1] "Epoch: 10, Loss: 18.5393
-[1] "Epoch: 20, Loss: 15.40245+[1] "Epoch: 20, Loss: 15.54339
-[1] "Epoch: 30, Loss: 15.37287+[1] "Epoch: 30, Loss: 15.50879
-[1] "Epoch: 40, Loss: 15.37253+[1] "Epoch: 40, Loss: 15.50839
-[1] "Epoch: 50, Loss: 15.37253+[1] "Epoch: 50, Loss: 15.50839
-[1] "Epoch: 60, Loss: 15.37253+[1] "Epoch: 60, Loss: 15.50839
-[1] "Epoch: 70, Loss: 15.37253+[1] "Epoch: 70, Loss: 15.50839
-[1] "Epoch: 80, Loss: 15.37253+[1] "Epoch: 80, Loss: 15.50839" 
-> # I must unscale coefficients to make them comprehensible+ 
 +> # unscale coefficients to make them comprehensible
 > b0 =  b0 - (mean(x) / sd(x)) * b1 > b0 =  b0 - (mean(x) / sd(x)) * b1
 > b1 = b1 / sd(x) > b1 = b1 / sd(x)
  
 +> # changes of estimators
 > b0s <- b0s - (mean(x) /sd(x)) * b1s > b0s <- b0s - (mean(x) /sd(x)) * b1s
 > b1s <- b1s / sd(x) > b1s <- b1s / sd(x)
  
-> parameters <- tibble(data.frame(b0s, b1s, msr))+> parameters <- tibble(data.frame(b0s, b1s, mse))
  
-> cat(paste0("Inclination: ", b1, ", \n", "Intercept: ", b0, "\n")) +> cat(paste0("Slope: ", b1, ", \n", "Intercept: ", b0, "\n")) 
-Inclination1.98275951151325,  +Slope2.26922511738252,  
-Intercept: 0.461293071825862+Intercept: -0.779435058320381
 > summary(lm(y~x))$coefficients > summary(lm(y~x))$coefficients
-             Estimate Std. Error    t value     Pr(>|t|) +              Estimate Std. Error    t value     Pr(>|t|) 
-(Intercept) 0.4612931  0.7339392  0.6285167 5.299537e-01 +(Intercept) -0.7794352  0.9258064 -0.8418986 4.005198e-01 
-          1.9827596  0.1427436 13.8903548 2.612507e-37+           2.2692252  0.1792660 12.6584242 1.111614e-29
  
 > ggplot(data, aes(x = x, y = y)) +  > ggplot(data, aes(x = x, y = y)) + 
 +   geom_point(size = 2) +  +   geom_point(size = 2) + 
 +   geom_abline(aes(intercept = b0s, slope = b1s), +   geom_abline(aes(intercept = b0s, slope = b1s),
-+               data = parameters, linewidth = 0.5, color = 'red') + ++               data = parameters, linewidth = 0.5,  
 ++               color = 'green') + 
 +   theme_classic() + +   theme_classic() +
 +   geom_abline(aes(intercept = b0s, slope = b1s),  +   geom_abline(aes(intercept = b0s, slope = b1s), 
 +               data = parameters %>% slice_head(),  +               data = parameters %>% slice_head(), 
-+               linewidth = 0.5, color = 'blue') + ++               linewidth = 1, color = 'blue') + 
 +   geom_abline(aes(intercept = b0s, slope = b1s),  +   geom_abline(aes(intercept = b0s, slope = b1s), 
 +               data = parameters %>% slice_tail(),  +               data = parameters %>% slice_tail(), 
-+               linewidth = 1, color = 'green') + ++               linewidth = 1, color = 'red') + 
-+   labs(title = 'Gradient descentblue: start, green: end')++   labs(title = 'Gradient descentblue: start, red: end, green: gradients') 
 +>  
 +> b0.init 
 +[1] -1.67967 
 +> b1.init 
 +[1] -1.323992 
 +
 > data > data
-# A tibble: 500 × 2 +# A tibble: 300 × 4 
-           y +           predictions residuals 
-   <dbl> <dbl> +   <dbl> <dbl>       <dbl>     <dbl> 
-  6.78 10. +  4.13  6.74       -7.14     13.9  
-  7.17 17. +  7.25 14.0       -11.3      25.3  
- 3  4.63  5.80 + 3  6.09 13.5        -9.74     23.3  
- 4  3.12 10. + 4  6.29 15.1       -10.0      25.1  
- 5  5.65  9.68 + 5  4.40  3.81       -7.51     11.3  
-  5.12 10. + 6  6.03 13.9        -9.67     23. 
-  4.05 16.8  +  6.97 12.1       -10.9      23.0  
-  7.27 16. +  4.84 12.8        -8.09     20. 
- 9  4.13  3.96 + 9  6.85 17.2       -10.7      28. 
-10  5.27 13.9  +10  3.33  3.80       -6.08      9.88 
-# ℹ 490 more rows+# ℹ 290 more rows
 # ℹ Use `print(n = ...)` to see more rows # ℹ Use `print(n = ...)` to see more rows
 > parameters > parameters
 # A tibble: 80 × 3 # A tibble: 80 × 3
-     b0s    b1s   msr +       b0s    b1s   mse 
-   <dbl>  <dbl> <dbl> +     <dbl>  <dbl> <dbl> 
- 0.791 0.0539 159.  +  2.67   -0.379 183.  
- 0.729 0.439  107.  +  1.99    0.149 123.  
- 0.679 0.747   74.3 +  1.44    0.571  84.3 
- 0.638 0.994   53.1 +  1.00    0.910  59.6 
- 5 0.604 1.19    39.5 +  0.652   1.18   43.7 
- 6 0.577 1.35    30.8 +  0.369   1.40   33.6 
- 7 0.555 1.48    25.3 +  0.142   1.57   27.1 
- 8 0.538 1.58    21.7 + -0.0397  1.71   22.9 
- 9 0.523 1.66    19.4 + -0.186   1.82   20.2 
-10 0.511 1.72    18.0+10 -0.303   1.91   18.5
 # ℹ 70 more rows # ℹ 70 more rows
-ℹ Use `print(n = ...)` to see more rows +
->  +
-+
 </code> </code>
-{{:pasted:20250801-134352.png}}+ 
 +{{:pasted:20250801-185727.png}}
  
gradient_descent.1754023435.txt.gz · Last modified: 2025/08/01 13:43 by hkimscil

Donate Powered by PHP Valid HTML5 Valid CSS Driven by DokuWiki