====== Gradient Descent ======
====== Explanation ======
Gradient (gradual) descent = finding the desired slope (the derivative) by shaving the estimate down a little at a time.
prerequisite:
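Gradient descent repeatedly nudges each coefficient a small step against the gradient of the loss. For reference, the generic update rule is sketched below (the learning-rate symbol $\alpha$ is my own notation here, not something defined elsewhere on this page):

\begin{eqnarray*}
a_{new} & = & a_{old} - \alpha \; \dfrac{\text{d}}{\text{da}} \text{MSE} \\
b_{new} & = & b_{old} - \alpha \; \dfrac{\text{d}}{\text{db}} \text{MSE} \\
\end{eqnarray*}

The derivation below works out these derivatives for the MSE (mean squared residual) of a simple regression line $\hat{Y} = a + bX$.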
\begin{eqnarray*}
\dfrac{\text{d}}{\text{dv for a}} \frac{1}{N} \sum{(Y_i - (a + bX_i))^2}
& = & \sum{2 \frac{1}{N} (Y_i - (a + bX_i))} * (-1) \\
& \because & \dfrac{\text{d}}{\text{dv for a}} (Y_i - (a+bX_i)) = -1 \\
& = & -2 \frac{\sum{(Y_i - (a + bX_i))}}{N} \\
& = & -2 * \text{mean of residuals} \\
\end{eqnarray*}
See the gradient function in the R code below.
</
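As a quick numeric check of the last line above, here is a small hedged R sketch (the x, y values are the commented-out toy data used later in the R code; a and b are arbitrary current guesses):

<code>
# d(MSE)/da should equal -2 * mean(residuals)
x <- c(0.5, 2.3, 2.9)
y <- c(1.4, 1.9, 3.2)
a <- 0.5; b <- 1              # arbitrary current guesses for intercept and slope
res <- y - (a + b * x)        # residuals at the current guesses
-2 * mean(res)                # gradient of the MSE with respect to a
</code>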
(Assuming you understand differentiation,) the expression above tells you how the MSR (mean squared residual) changes as the value of b changes: it is the sum of the residuals with respect to b multiplied by (-2/N).
</
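A matching hedged sketch for the slope b (note that for b each residual is weighted by its x value; the toy numbers are again only for illustration):

<code>
# d(MSE)/db should equal -2 * mean(x * residuals)
x <- c(0.5, 2.3, 2.9)
y <- c(1.4, 1.9, 3.2)
a <- 0.5; b <- 1
res <- y - (a + b * x)
-2 * mean(x * res)                          # analytic gradient with respect to b
mse <- function(a, b) mean((y - (a + b * x))^2)
h <- 1e-6
(mse(a, b + h) - mse(a, b - h)) / (2 * h)   # finite-difference check; should match
</code>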
====== R code ======
<code>
library(tidyverse)

# statquest explanation (toy example, kept commented out)
# x <- c(0.5, 2.3, 2.9)
# y <- c(1.4, 1.9, 3.2)

rm(list=ls())
# set.seed(191)
n <- 300
x <- rnorm(n, 5, 1.2)
y <- 2.14 * x + rnorm(n, 0, 4)
# data <- data.frame(x, y)
data <- tibble(x = x, y = y)

mo <- lm(y~x)
summary(mo)
# ...
b1 = rnorm(1)
b0 = rnorm(1)

b1.init <- b1
b0.init <- b0

# Predict function:
# ...
loss = loss_mse(predictions, y)
data <- tibble(data.frame(x,
print(paste0("Loss is: ", round(loss)))
# ...
# Record Loss for each epoch:
# logs = list()
# bs = list()
b0s = c()
b1s = c()
# ... (inside the epoch loop)
loss = loss_mse(predictions, y)
mse = append(mse, loss)
# logs = append(logs,

if (epoch %% 10 == 0){
# ...
b1s <- append(b1s, b1)
}

# unscale coefficients to make them comprehensible
b0 = b0 - (mean(x) / sd(x)) * b1
b1 = b1 / sd(x)
# changes of estimators
b0s <- b0s - (mean(x) / sd(x)) * b1s
b1s <- b1s / sd(x)

parameters <- tibble(data.frame(b0s, b1s, mse))
cat(paste0("
summary(lm(y~x))$coefficients

ggplot(data, aes(x = x, y = y)) +
  geom_point(size = 2) +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters, linewidth = 0.5,
              color = 'green') +
  theme_classic() +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters %>% slice_head(),
              linewidth = 1, color = '
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters %>% slice_tail(),
              linewidth = 1, color = 'red') +
  labs(title = '

b0.init
b1.init
data
parameters
</code>
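A note on the "unscale coefficients" step in the code above: the unscale formulas imply that x was standardized ((x - mean(x)) / sd(x)) earlier in the script, so the learned coefficients have to be mapped back to the original x scale before they can be compared with lm(). A small hedged sketch of why those two lines work (toy numbers only):

<code>
# If y_hat = b0z + b1z * z with z = (x - mean(x)) / sd(x), substituting z gives
#   y_hat = (b0z - b1z * mean(x) / sd(x)) + (b1z / sd(x)) * x
# which is exactly what the "unscale coefficients" lines compute.
x <- c(2, 4, 6, 8)                      # toy data, for illustration only
z <- (x - mean(x)) / sd(x)
b0z <- 1.5; b1z <- 0.7                  # pretend these were learned on standardized x
b0 <- b0z - (mean(x) / sd(x)) * b1z     # unscaled intercept
b1 <- b1z / sd(x)                       # unscaled slope
all.equal(b0z + b1z * z, b0 + b1 * x)   # TRUE: identical predictions on both scales
</code>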
====== R output ======
<code>
> rm(list=ls())
> # set.seed(191)
> n <- 300
> x <- rnorm(n, 5, 1.2)
> y <- 2.14 * x + rnorm(n, 0, 4)
> # data <- data.frame(x, y)
> data <- tibble(x = x, y = y)
>
> mo <- lm(y~x)
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
   Min     1Q Median     3Q    Max 
-9.754 -2.729 -0.135

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)
x             2.2692     0.1793  12.658   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 3.951 on 298 degrees of freedom
Multiple R-squared:
F-statistic:
>
> b1 = rnorm(1)
> b0 = rnorm(1)
>
> b1.init <- b1
> b0.init <- b0
>
> # Predict function:
> loss = loss_mse(predictions, y)
>
> data <- tibble(data.frame(x,
>
> print(paste0("Loss is: ", round(loss)))
[1] "Loss is: 393"
>
> gradient <- function(x, y, predictions){
> print(gradients)
$db1
[1] -200.6834

$db0
[1] -37.76994

>
>
> # Record Loss for each epoch:
> # logs = list()
> # bs = list()
> b0s = c()
> b1s = c()
+ loss = loss_mse(predictions, y)
+ mse = append(mse, loss)
+ # logs = append(logs,
+
+ if (epoch %% 10 == 0){
+ b1s <- append(b1s, b1)
+ }
(epoch progress messages printed every 10 epochs; values truncated)
>
> # unscale coefficients to make them comprehensible
> b0 = b0 - (mean(x) / sd(x)) * b1
> b1 = b1 / sd(x)
>
> # changes of estimators
> b0s <- b0s - (mean(x) / sd(x)) * b1s
> b1s <- b1s / sd(x)
> parameters <- tibble(data.frame(b0s, b1s, mse))
>
> cat(paste0("
Slope: 2.26922511738252,
Intercept:
> summary(lm(y~x))$coefficients
            Estimate Std. Error  t value Pr(>|t|)
(Intercept)
x            2.2692252
>
> ggplot(data, aes(x = x, y = y)) +
+   geom_point(size = 2) +
+   geom_abline(aes(intercept = b0s, slope = b1s),
+               data = parameters, linewidth = 0.5,
+               color = 'green') +
+   theme_classic() +
+   geom_abline(aes(intercept = b0s, slope = b1s),
+               data = parameters %>% slice_head(),
+               linewidth = 1, color = '
+   geom_abline(aes(intercept = b0s, slope = b1s),
+               data = parameters %>% slice_tail(),
+               linewidth = 1, color = 'red') +
+   labs(title = '
>
> b0.init
[1] -1.67967
> b1.init
[1] -1.323992
>
> data
# A tibble:
...
10  3.33  3.80
# ℹ 290 more rows
# ℹ Use `print(n = ...)` to see more rows
> parameters
# A tibble: 80 × 3
       b0s   b1s   mse
     <dbl> <dbl> <dbl>
...
 8 -0.0397  1.71  22.9
 9 -0.186   1.82  20.2
10 -0.303   1.91  18.5
# ℹ 70 more rows
# ℹ Use `print(n = ...)` to see more rows
</code>