User Tools

Site Tools


gradient_descent

This is an old revision of the document!


Gradient Descent

explanation

Why normalize (scale or make z-score) xi

x 변인의 측정단위에 따라 b 값이 결정되는데, 이 때의 b값은 상당히 넓고 다양한 범위를 가질 수 있다. 가령 월 수입(인컴)이 X 라고 한다면 우리가 추정해야 (추적해야) 할 b값은 수백만이 될 수도 있다. 이 값을 gradient로 추적하게 된다면 너무도 많은 iteration을 거쳐야 할 수 있다. 또한 변인이 바뀌면 이 b의 추적범위도 드라마틱하게 바뀌게 된다. 대신 x를 표준화한 점수(z 점수)를 사용하게 된다면 일정한 learning rate와 iteration만으로도 정확한 a와 b를 추적할 수 있게 된다.

How to unnormalize (unscale) a and b

\begin{eqnarray*} y & = & a + b * x \\ & & \text{we use z instead of x,} \\ & & \text{where } \\ & & z = \frac{(x - \mu)}{\sigma} \\ & & \text{suppose that the fitted result is } \\ y & = & k + m * z \\ & = & k + m * \frac{(x - \mu)}{\sigma} \\ & = & k + \frac{m * x}{\sigma} - \frac{m * \mu}{\sigma} \\ & = & k - \frac{m * \mu}{\sigma} + \frac{m * x}{\sigma} \\ & = & k - \frac{\mu}{\sigma} * m + \frac{m}{\sigma} * x \\ & & \text{therefore, the a and b that we try to get are } \\ a & = & k - \frac{\mu}{\sigma} * m \\ b & = & \frac{m}{\sigma} \\ \end{eqnarray*}

R code: Idea

library(ggplot2)
library(ggpmisc)

rm(list = ls())
# set.seed(191)

# Simulate the data: x ~ N(mx, sdx^2) and y = slp * x + N(0, (slp * sdx * 3)^2).
nx <- 200
mx <- 4.5
sdx <- 0.56 * mx
x <- rnorm(nx, mean = mx, sd = sdx)
slp <- 12
y <- slp * x + rnorm(nx, mean = 0, sd = slp * sdx * 3)

data <- data.frame(x = x, y = y)

# Reference OLS fit.
mo <- lm(y ~ x, data = data)
summary(mo)

# Scatter plot with the fitted line, its equation and R^2.
ggplot(data, aes(x = x, y = y)) +
  geom_point() +
  stat_poly_line() +
  stat_poly_eq(mapping = use_label(c("eq", "R2"))) +
  theme_classic()

# set.seed(191)
# Starting values: to build intuition, first keep the slope b fixed at the
# OLS estimate and vary only the intercept a.
b <- summary(mo)$coefficients["x", "Estimate"]
a <- 0

b.init <- b
a.init <- a

# Linear predictor: a + b * x, evaluated element-wise over x.
# NOTE(review): defining `predict` here masks stats::predict() in the
# global environment.
predict <- function(x, a, b){
  b * x + a
}

# Residuals: observed y minus the predictions, element-wise.
residuals <- function(predictions, y) {
  y - predictions
}

# Sum of squared residuals (SSR). Because it is a sum rather than a mean,
# it grows with the sample size and often becomes large.
ssrloss <- function(predictions, y) {
  sum((y - predictions)^2)
}

# Brute-force grid search for the intercept a, with the slope b held fixed.
# For every candidate a we record the sum of squared residuals (ssrs) and
# the plain sum of residuals (srs). The candidates are generated once and
# the losses are computed with vapply(), instead of growing vectors with
# append(), which copies the whole vector on every one of the ~10000 steps.
as <- seq(from = -50, to = 50, by = 0.01) # candidate intercepts
ssrs <- vapply(as, function(a_i) ssrloss(predict(x, a_i, b), y), numeric(1))
srs <- vapply(as, function(a_i) sum(residuals(predict(x, a_i, b), y)),
              numeric(1))
length(ssrs)
length(srs)
length(as)

min(ssrs)
# which.min() avoids an exact floating-point equality test on the losses.
min.pos.ssrs <- which.min(ssrs)
min.pos.ssrs
print(as[min.pos.ssrs]) # grid estimate of a; compare with summary(mo)
summary(mo)
# Plot against the candidate intercepts so the x-axis is interpretable.
# srs crosses zero where ssrs is minimal, since dSSR/da = -2 * sum(residuals).
plot(as, ssrs)
plot(as, srs)
tail(ssrs)
max(ssrs)
min(ssrs)
tail(srs)
max(srs)
min(srs)

output

> rm(list=ls())
> # set.seed(191)
> n <- 5
> x <- rnorm(n, 5, 1.2)
> y <- 3.14 * x + rnorm(n,0,1)
> 
> # data <- data.frame(x, y)
> data <- tibble(x = x, y = y)
> 
> mo <- lm(y~x)
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
      1       2       3       4       5 
-1.7472  1.7379  1.1598 -0.7010 -0.4496 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)  
(Intercept)  -2.2923     4.6038  -0.498   0.6528  
x             3.4510     0.8899   3.878   0.0304 *
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.644 on 3 degrees of freedom
Multiple R-squared:  0.8337,	Adjusted R-squared:  0.7783 
F-statistic: 15.04 on 1 and 3 DF,  p-value: 0.03036

> 
> # set.seed(191)
> # Initialize random betas
> # 우선 b를 고정하고 a만 
> # 변화시켜서 이해
> b <- summary(mo)$coefficients[2]
> a <- 0
> 
> b.init <- b
> a.init <- a
> 
> # Predict function:
> predict <- function(x, a, b){
+   return (a + b * x)
+ }
> 
> # And loss function is:
> residuals <- function(predictions, y) {
+   return(y - predictions)
+ }
> 
> # we use sum of square of error which oftentimes become big
> mseloss <- function(predictions, y) {
+   residuals <- (y - predictions)
+   return(sum(residuals^2))
+ }
> 
> mses <- c()
> j <- 0
> as <- c()
> 
> for (i in seq(from = -30, to = 30, by = 0.01)) {
+   pred <- predict(x, i, b)
+   res <- residuals(pred, y)
+   mse <- mseloss(pred, y)
+   mses <- append(mses, mse)
+   as <- append(as,i)
+ }
> mses
   [1] 3846.695 3843.925 3841.156 3838.387 3835.620 3832.854 3830.088 3827.324
   [9] 3824.561 3821.799 3819.037 3816.277 3813.518 3810.760 3808.002 3805.246
  [17] 3802.491 3799.737 3796.983 3794.231 3791.480 3788.729 3785.980 3783.232
  [25] 3780.485 3777.738 3774.993 3772.249 3769.506 3766.763 3764.022 3761.282
  [33] 3758.542 3755.804 3753.067 3750.331 3747.595 3744.861 3742.128 3739.396
  [41] 3736.664 3733.934 3731.205 3728.477 3725.749 3723.023 3720.298 3717.573
  [49] 3714.850 3712.128 3709.407 3706.686 3703.967 3701.249 3698.532 3695.815
  [57] 3693.100 3690.386 3687.672 3684.960 3682.249 3679.539 3676.829 3674.121
  [65] 3671.414 3668.708 3666.002 3663.298 3660.595 3657.892 3655.191 3652.491
  [73] 3649.792 3647.093 3644.396 3641.700 3639.005 3636.310 3633.617 3630.925
  [81] 3628.234 3625.543 3622.854 3620.166 3617.478 3614.792 3612.107 3609.423
  [89] 3606.739 3604.057 3601.376 3598.696 3596.016 3593.338 3590.661 3587.984
  [97] 3585.309 3582.635 3579.962 3577.289 3574.618 3571.948 3569.279 3566.610
 [105] 3563.943 3561.277 3558.611 3555.947 3553.284 3550.622 3547.960 3545.300
 [113] 3542.641 3539.983 3537.325 3534.669 3532.014 3529.360 3526.706 3524.054
 [121] 3521.403 3518.752 3516.103 3513.455 3510.808 3508.161 3505.516 3502.872
 [129] 3500.229 3497.586 3494.945 3492.305 3489.665 3487.027 3484.390 3481.754
 [137] 3479.118 3476.484 3473.851 3471.219 3468.587 3465.957 3463.328 3460.700
 [145] 3458.072 3455.446 3452.821 3450.196 3447.573 3444.951 3442.330 3439.709
 [153] 3437.090 3434.472 3431.855 3429.238 3426.623 3424.009 3421.395 3418.783
 [161] 3416.172 3413.562 3410.952 3408.344 3405.737 3403.131 3400.525 3397.921
 [169] 3395.318 3392.715 3390.114 3387.514 3384.915 3382.316 3379.719 3377.123
 [177] 3374.528 3371.933 3369.340 3366.748 3364.157 3361.566 3358.977 3356.389
 [185] 3353.801 3351.215 3348.630 3346.046 3343.462 3340.880 3338.299 3335.719
 [193] 3333.139 3330.561 3327.984 3325.407 3322.832 3320.258 3317.685 3315.112
 [201] 3312.541 3309.971 3307.402 3304.833 3302.266 3299.700 3297.134 3294.570
 [209] 3292.007 3289.445 3286.883 3284.323 3281.764 3279.206 3276.648 3274.092
 [217] 3271.537 3268.983 3266.429 3263.877 3261.326 3258.775 3256.226 3253.678
 [225] 3251.131 3248.584 3246.039 3243.495 3240.952 3238.409 3235.868 3233.328
 [233] 3230.788 3228.250 3225.713 3223.177 3220.641 3218.107 3215.574 3213.042
 [241] 3210.510 3207.980 3205.451 3202.923 3200.395 3197.869 3195.344 3192.819
 [249] 3190.296 3187.774 3185.253 3182.732 3180.213 3177.695 3175.178 3172.661
 [257] 3170.146 3167.632 3165.118 3162.606 3160.095 3157.585 3155.075 3152.567
 [265] 3150.060 3147.554 3145.048 3142.544 3140.041 3137.538 3135.037 3132.537
 [273] 3130.038 3127.539 3125.042 3122.546 3120.051 3117.556 3115.063 3112.571
 [281] 3110.080 3107.589 3105.100 3102.612 3100.124 3097.638 3095.153 3092.669
 [289] 3090.185 3087.703 3085.222 3082.742 3080.262 3077.784 3075.307 3072.830
 [297] 3070.355 3067.881 3065.408 3062.935 3060.464 3057.994 3055.525 3053.056
 [305] 3050.589 3048.123 3045.657 3043.193 3040.730 3038.268 3035.806 3033.346
 [313] 3030.887 3028.429 3025.971 3023.515 3021.060 3018.606 3016.152 3013.700
 [321] 3011.249 3008.798 3006.349 3003.901 3001.454 2999.007 2996.562 2994.118
 [329] 2991.675 2989.232 2986.791 2984.351 2981.911 2979.473 2977.036 2974.600
 [337] 2972.164 2969.730 2967.297 2964.865 2962.433 2960.003 2957.574 2955.145
 [345] 2952.718 2950.292 2947.867 2945.442 2943.019 2940.597 2938.176 2935.755
 [353] 2933.336 2930.918 2928.501 2926.084 2923.669 2921.255 2918.841 2916.429
 [361] 2914.018 2911.608 2909.198 2906.790 2904.383 2901.977 2899.571 2897.167
 [369] 2894.764 2892.361 2889.960 2887.560 2885.161 2882.762 2880.365 2877.969
 [377] 2875.574 2873.179 2870.786 2868.394 2866.003 2863.612 2861.223 2858.835
 [385] 2856.447 2854.061 2851.676 2849.292 2846.908 2844.526 2842.145 2839.765
 [393] 2837.385 2835.007 2832.630 2830.253 2827.878 2825.504 2823.131 2820.758
 [401] 2818.387 2816.017 2813.648 2811.279 2808.912 2806.546 2804.180 2801.816
 [409] 2799.453 2797.091 2794.729 2792.369 2790.010 2787.652 2785.294 2782.938
 [417] 2780.583 2778.229 2775.875 2773.523 2771.172 2768.821 2766.472 2764.124
 [425] 2761.777 2759.430 2757.085 2754.741 2752.398 2750.055 2747.714 2745.374
 [433] 2743.034 2740.696 2738.359 2736.023 2733.687 2731.353 2729.020 2726.688
 [441] 2724.356 2722.026 2719.697 2717.368 2715.041 2712.715 2710.390 2708.065
 [449] 2705.742 2703.420 2701.099 2698.778 2696.459 2694.141 2691.824 2689.507
 [457] 2687.192 2684.878 2682.564 2680.252 2677.941 2675.631 2673.321 2671.013
 [465] 2668.706 2666.400 2664.094 2661.790 2659.487 2657.184 2654.883 2652.583
 [473] 2650.284 2647.985 2645.688 2643.392 2641.097 2638.802 2636.509 2634.217
 [481] 2631.926 2629.635 2627.346 2625.058 2622.770 2620.484 2618.199 2615.915
 [489] 2613.631 2611.349 2609.068 2606.788 2604.508 2602.230 2599.953 2597.676
 [497] 2595.401 2593.127 2590.854 2588.581 2586.310 2584.040 2581.771 2579.502
 [505] 2577.235 2574.969 2572.703 2570.439 2568.176 2565.914 2563.652 2561.392
 [513] 2559.133 2556.875 2554.617 2552.361 2550.106 2547.852 2545.598 2543.346
 [521] 2541.095 2538.844 2536.595 2534.347 2532.100 2529.853 2527.608 2525.364
 [529] 2523.121 2520.878 2518.637 2516.397 2514.157 2511.919 2509.682 2507.446
 [537] 2505.210 2502.976 2500.743 2498.511 2496.279 2494.049 2491.820 2489.591
 [545] 2487.364 2485.138 2482.913 2480.688 2478.465 2476.243 2474.022 2471.801
 [553] 2469.582 2467.364 2465.147 2462.930 2460.715 2458.501 2456.287 2454.075
 [561] 2451.864 2449.654 2447.444 2445.236 2443.029 2440.823 2438.617 2436.413
 [569] 2434.210 2432.007 2429.806 2427.606 2425.407 2423.208 2421.011 2418.815
 [577] 2416.620 2414.425 2412.232 2410.040 2407.849 2405.658 2403.469 2401.281
 [585] 2399.093 2396.907 2394.722 2392.538 2390.354 2388.172 2385.991 2383.811
 [593] 2381.631 2379.453 2377.276 2375.099 2372.924 2370.750 2368.577 2366.404
 [601] 2364.233 2362.063 2359.894 2357.725 2355.558 2353.392 2351.226 2349.062
 [609] 2346.899 2344.737 2342.575 2340.415 2338.256 2336.098 2333.940 2331.784
 [617] 2329.629 2327.475 2325.321 2323.169 2321.018 2318.867 2316.718 2314.570
 [625] 2312.423 2310.276 2308.131 2305.987 2303.844 2301.701 2299.560 2297.420
 [633] 2295.280 2293.142 2291.005 2288.869 2286.733 2284.599 2282.466 2280.334
 [641] 2278.202 2276.072 2273.943 2271.814 2269.687 2267.561 2265.436 2263.311
 [649] 2261.188 2259.066 2256.945 2254.824 2252.705 2250.587 2248.470 2246.353
 [657] 2244.238 2242.124 2240.010 2237.898 2235.787 2233.677 2231.567 2229.459
 [665] 2227.352 2225.246 2223.140 2221.036 2218.933 2216.830 2214.729 2212.629
 [673] 2210.530 2208.431 2206.334 2204.238 2202.143 2200.048 2197.955 2195.863
 [681] 2193.771 2191.681 2189.592 2187.504 2185.416 2183.330 2181.245 2179.161
 [689] 2177.077 2174.995 2172.914 2170.834 2168.754 2166.676 2164.599 2162.522
 [697] 2160.447 2158.373 2156.300 2154.227 2152.156 2150.086 2148.017 2145.948
 [705] 2143.881 2141.815 2139.749 2137.685 2135.622 2133.560 2131.498 2129.438
 [713] 2127.379 2125.321 2123.263 2121.207 2119.152 2117.098 2115.044 2112.992
 [721] 2110.941 2108.890 2106.841 2104.793 2102.746 2100.699 2098.654 2096.610
 [729] 2094.567 2092.524 2090.483 2088.443 2086.403 2084.365 2082.328 2080.292
 [737] 2078.256 2076.222 2074.189 2072.157 2070.125 2068.095 2066.066 2064.037
 [745] 2062.010 2059.984 2057.959 2055.934 2053.911 2051.889 2049.868 2047.847
 [753] 2045.828 2043.810 2041.793 2039.776 2037.761 2035.747 2033.733 2031.721
 [761] 2029.710 2027.700 2025.690 2023.682 2021.675 2019.669 2017.663 2015.659
 [769] 2013.656 2011.653 2009.652 2007.652 2005.653 2003.654 2001.657 1999.661
 [777] 1997.666 1995.671 1993.678 1991.686 1989.694 1987.704 1985.715 1983.727
 [785] 1981.739 1979.753 1977.768 1975.784 1973.800 1971.818 1969.837 1967.857
 [793] 1965.877 1963.899 1961.922 1959.945 1957.970 1955.996 1954.023 1952.050
 [801] 1950.079 1948.109 1946.140 1944.171 1942.204 1940.238 1938.272 1936.308
 [809] 1934.345 1932.383 1930.421 1928.461 1926.502 1924.544 1922.586 1920.630
 [817] 1918.675 1916.721 1914.767 1912.815 1910.864 1908.913 1906.964 1905.016
 [825] 1903.069 1901.122 1899.177 1897.233 1895.290 1893.347 1891.406 1889.466
 [833] 1887.526 1885.588 1883.651 1881.715 1879.779 1877.845 1875.912 1873.980
 [841] 1872.048 1870.118 1868.189 1866.260 1864.333 1862.407 1860.482 1858.557
 [849] 1856.634 1854.712 1852.791 1850.870 1848.951 1847.033 1845.116 1843.199
 [857] 1841.284 1839.370 1837.456 1835.544 1833.633 1831.723 1829.813 1827.905
 [865] 1825.998 1824.092 1822.186 1820.282 1818.379 1816.476 1814.575 1812.675
 [873] 1810.776 1808.877 1806.980 1805.084 1803.189 1801.294 1799.401 1797.509
 [881] 1795.617 1793.727 1791.838 1789.950 1788.062 1786.176 1784.291 1782.407
 [889] 1780.523 1778.641 1776.760 1774.880 1773.000 1771.122 1769.245 1767.368
 [897] 1765.493 1763.619 1761.746 1759.873 1758.002 1756.132 1754.263 1752.394
 [905] 1750.527 1748.661 1746.795 1744.931 1743.068 1741.206 1739.344 1737.484
 [913] 1735.625 1733.767 1731.909 1730.053 1728.198 1726.344 1724.490 1722.638
 [921] 1720.787 1718.936 1717.087 1715.239 1713.392 1711.545 1709.700 1707.856
 [929] 1706.013 1704.170 1702.329 1700.489 1698.649 1696.811 1694.974 1693.138
 [937] 1691.302 1689.468 1687.635 1685.803 1683.971 1682.141 1680.312 1678.483
 [945] 1676.656 1674.830 1673.005 1671.180 1669.357 1667.535 1665.714 1663.893
 [953] 1662.074 1660.256 1658.439 1656.622 1654.807 1652.993 1651.179 1649.367
 [961] 1647.556 1645.746 1643.936 1642.128 1640.321 1638.515 1636.709 1634.905
 [969] 1633.102 1631.299 1629.498 1627.698 1625.899 1624.100 1622.303 1620.507
 [977] 1618.712 1616.917 1615.124 1613.332 1611.540 1609.750 1607.961 1606.173
 [985] 1604.385 1602.599 1600.814 1599.030 1597.246 1595.464 1593.683 1591.903
 [993] 1590.123 1588.345 1586.568 1584.791 1583.016 1581.242 1579.469 1577.696
 [ reached getOption("max.print") -- omitted 5001 entries ]
> as
   [1] -30.00 -29.99 -29.98 -29.97 -29.96 -29.95 -29.94 -29.93 -29.92 -29.91 -29.90
  [12] -29.89 -29.88 -29.87 -29.86 -29.85 -29.84 -29.83 -29.82 -29.81 -29.80 -29.79
  [23] -29.78 -29.77 -29.76 -29.75 -29.74 -29.73 -29.72 -29.71 -29.70 -29.69 -29.68
  [34] -29.67 -29.66 -29.65 -29.64 -29.63 -29.62 -29.61 -29.60 -29.59 -29.58 -29.57
  [45] -29.56 -29.55 -29.54 -29.53 -29.52 -29.51 -29.50 -29.49 -29.48 -29.47 -29.46
  [56] -29.45 -29.44 -29.43 -29.42 -29.41 -29.40 -29.39 -29.38 -29.37 -29.36 -29.35
  [67] -29.34 -29.33 -29.32 -29.31 -29.30 -29.29 -29.28 -29.27 -29.26 -29.25 -29.24
  [78] -29.23 -29.22 -29.21 -29.20 -29.19 -29.18 -29.17 -29.16 -29.15 -29.14 -29.13
  [89] -29.12 -29.11 -29.10 -29.09 -29.08 -29.07 -29.06 -29.05 -29.04 -29.03 -29.02
 [100] -29.01 -29.00 -28.99 -28.98 -28.97 -28.96 -28.95 -28.94 -28.93 -28.92 -28.91
 [111] -28.90 -28.89 -28.88 -28.87 -28.86 -28.85 -28.84 -28.83 -28.82 -28.81 -28.80
 [122] -28.79 -28.78 -28.77 -28.76 -28.75 -28.74 -28.73 -28.72 -28.71 -28.70 -28.69
 [133] -28.68 -28.67 -28.66 -28.65 -28.64 -28.63 -28.62 -28.61 -28.60 -28.59 -28.58
 [144] -28.57 -28.56 -28.55 -28.54 -28.53 -28.52 -28.51 -28.50 -28.49 -28.48 -28.47
 [155] -28.46 -28.45 -28.44 -28.43 -28.42 -28.41 -28.40 -28.39 -28.38 -28.37 -28.36
 [166] -28.35 -28.34 -28.33 -28.32 -28.31 -28.30 -28.29 -28.28 -28.27 -28.26 -28.25
 [177] -28.24 -28.23 -28.22 -28.21 -28.20 -28.19 -28.18 -28.17 -28.16 -28.15 -28.14
 [188] -28.13 -28.12 -28.11 -28.10 -28.09 -28.08 -28.07 -28.06 -28.05 -28.04 -28.03
 [199] -28.02 -28.01 -28.00 -27.99 -27.98 -27.97 -27.96 -27.95 -27.94 -27.93 -27.92
 [210] -27.91 -27.90 -27.89 -27.88 -27.87 -27.86 -27.85 -27.84 -27.83 -27.82 -27.81
 [221] -27.80 -27.79 -27.78 -27.77 -27.76 -27.75 -27.74 -27.73 -27.72 -27.71 -27.70
 [232] -27.69 -27.68 -27.67 -27.66 -27.65 -27.64 -27.63 -27.62 -27.61 -27.60 -27.59
 [243] -27.58 -27.57 -27.56 -27.55 -27.54 -27.53 -27.52 -27.51 -27.50 -27.49 -27.48
 [254] -27.47 -27.46 -27.45 -27.44 -27.43 -27.42 -27.41 -27.40 -27.39 -27.38 -27.37
 [265] -27.36 -27.35 -27.34 -27.33 -27.32 -27.31 -27.30 -27.29 -27.28 -27.27 -27.26
 [276] -27.25 -27.24 -27.23 -27.22 -27.21 -27.20 -27.19 -27.18 -27.17 -27.16 -27.15
 [287] -27.14 -27.13 -27.12 -27.11 -27.10 -27.09 -27.08 -27.07 -27.06 -27.05 -27.04
 [298] -27.03 -27.02 -27.01 -27.00 -26.99 -26.98 -26.97 -26.96 -26.95 -26.94 -26.93
 [309] -26.92 -26.91 -26.90 -26.89 -26.88 -26.87 -26.86 -26.85 -26.84 -26.83 -26.82
 [320] -26.81 -26.80 -26.79 -26.78 -26.77 -26.76 -26.75 -26.74 -26.73 -26.72 -26.71
 [331] -26.70 -26.69 -26.68 -26.67 -26.66 -26.65 -26.64 -26.63 -26.62 -26.61 -26.60
 [342] -26.59 -26.58 -26.57 -26.56 -26.55 -26.54 -26.53 -26.52 -26.51 -26.50 -26.49
 [353] -26.48 -26.47 -26.46 -26.45 -26.44 -26.43 -26.42 -26.41 -26.40 -26.39 -26.38
 [364] -26.37 -26.36 -26.35 -26.34 -26.33 -26.32 -26.31 -26.30 -26.29 -26.28 -26.27
 [375] -26.26 -26.25 -26.24 -26.23 -26.22 -26.21 -26.20 -26.19 -26.18 -26.17 -26.16
 [386] -26.15 -26.14 -26.13 -26.12 -26.11 -26.10 -26.09 -26.08 -26.07 -26.06 -26.05
 [397] -26.04 -26.03 -26.02 -26.01 -26.00 -25.99 -25.98 -25.97 -25.96 -25.95 -25.94
 [408] -25.93 -25.92 -25.91 -25.90 -25.89 -25.88 -25.87 -25.86 -25.85 -25.84 -25.83
 [419] -25.82 -25.81 -25.80 -25.79 -25.78 -25.77 -25.76 -25.75 -25.74 -25.73 -25.72
 [430] -25.71 -25.70 -25.69 -25.68 -25.67 -25.66 -25.65 -25.64 -25.63 -25.62 -25.61
 [441] -25.60 -25.59 -25.58 -25.57 -25.56 -25.55 -25.54 -25.53 -25.52 -25.51 -25.50
 [452] -25.49 -25.48 -25.47 -25.46 -25.45 -25.44 -25.43 -25.42 -25.41 -25.40 -25.39
 [463] -25.38 -25.37 -25.36 -25.35 -25.34 -25.33 -25.32 -25.31 -25.30 -25.29 -25.28
 [474] -25.27 -25.26 -25.25 -25.24 -25.23 -25.22 -25.21 -25.20 -25.19 -25.18 -25.17
 [485] -25.16 -25.15 -25.14 -25.13 -25.12 -25.11 -25.10 -25.09 -25.08 -25.07 -25.06
 [496] -25.05 -25.04 -25.03 -25.02 -25.01 -25.00 -24.99 -24.98 -24.97 -24.96 -24.95
 [507] -24.94 -24.93 -24.92 -24.91 -24.90 -24.89 -24.88 -24.87 -24.86 -24.85 -24.84
 [518] -24.83 -24.82 -24.81 -24.80 -24.79 -24.78 -24.77 -24.76 -24.75 -24.74 -24.73
 [529] -24.72 -24.71 -24.70 -24.69 -24.68 -24.67 -24.66 -24.65 -24.64 -24.63 -24.62
 [540] -24.61 -24.60 -24.59 -24.58 -24.57 -24.56 -24.55 -24.54 -24.53 -24.52 -24.51
 [551] -24.50 -24.49 -24.48 -24.47 -24.46 -24.45 -24.44 -24.43 -24.42 -24.41 -24.40
 [562] -24.39 -24.38 -24.37 -24.36 -24.35 -24.34 -24.33 -24.32 -24.31 -24.30 -24.29
 [573] -24.28 -24.27 -24.26 -24.25 -24.24 -24.23 -24.22 -24.21 -24.20 -24.19 -24.18
 [584] -24.17 -24.16 -24.15 -24.14 -24.13 -24.12 -24.11 -24.10 -24.09 -24.08 -24.07
 [595] -24.06 -24.05 -24.04 -24.03 -24.02 -24.01 -24.00 -23.99 -23.98 -23.97 -23.96
 [606] -23.95 -23.94 -23.93 -23.92 -23.91 -23.90 -23.89 -23.88 -23.87 -23.86 -23.85
 [617] -23.84 -23.83 -23.82 -23.81 -23.80 -23.79 -23.78 -23.77 -23.76 -23.75 -23.74
 [628] -23.73 -23.72 -23.71 -23.70 -23.69 -23.68 -23.67 -23.66 -23.65 -23.64 -23.63
 [639] -23.62 -23.61 -23.60 -23.59 -23.58 -23.57 -23.56 -23.55 -23.54 -23.53 -23.52
 [650] -23.51 -23.50 -23.49 -23.48 -23.47 -23.46 -23.45 -23.44 -23.43 -23.42 -23.41
 [661] -23.40 -23.39 -23.38 -23.37 -23.36 -23.35 -23.34 -23.33 -23.32 -23.31 -23.30
 [672] -23.29 -23.28 -23.27 -23.26 -23.25 -23.24 -23.23 -23.22 -23.21 -23.20 -23.19
 [683] -23.18 -23.17 -23.16 -23.15 -23.14 -23.13 -23.12 -23.11 -23.10 -23.09 -23.08
 [694] -23.07 -23.06 -23.05 -23.04 -23.03 -23.02 -23.01 -23.00 -22.99 -22.98 -22.97
 [705] -22.96 -22.95 -22.94 -22.93 -22.92 -22.91 -22.90 -22.89 -22.88 -22.87 -22.86
 [716] -22.85 -22.84 -22.83 -22.82 -22.81 -22.80 -22.79 -22.78 -22.77 -22.76 -22.75
 [727] -22.74 -22.73 -22.72 -22.71 -22.70 -22.69 -22.68 -22.67 -22.66 -22.65 -22.64
 [738] -22.63 -22.62 -22.61 -22.60 -22.59 -22.58 -22.57 -22.56 -22.55 -22.54 -22.53
 [749] -22.52 -22.51 -22.50 -22.49 -22.48 -22.47 -22.46 -22.45 -22.44 -22.43 -22.42
 [760] -22.41 -22.40 -22.39 -22.38 -22.37 -22.36 -22.35 -22.34 -22.33 -22.32 -22.31
 [771] -22.30 -22.29 -22.28 -22.27 -22.26 -22.25 -22.24 -22.23 -22.22 -22.21 -22.20
 [782] -22.19 -22.18 -22.17 -22.16 -22.15 -22.14 -22.13 -22.12 -22.11 -22.10 -22.09
 [793] -22.08 -22.07 -22.06 -22.05 -22.04 -22.03 -22.02 -22.01 -22.00 -21.99 -21.98
 [804] -21.97 -21.96 -21.95 -21.94 -21.93 -21.92 -21.91 -21.90 -21.89 -21.88 -21.87
 [815] -21.86 -21.85 -21.84 -21.83 -21.82 -21.81 -21.80 -21.79 -21.78 -21.77 -21.76
 [826] -21.75 -21.74 -21.73 -21.72 -21.71 -21.70 -21.69 -21.68 -21.67 -21.66 -21.65
 [837] -21.64 -21.63 -21.62 -21.61 -21.60 -21.59 -21.58 -21.57 -21.56 -21.55 -21.54
 [848] -21.53 -21.52 -21.51 -21.50 -21.49 -21.48 -21.47 -21.46 -21.45 -21.44 -21.43
 [859] -21.42 -21.41 -21.40 -21.39 -21.38 -21.37 -21.36 -21.35 -21.34 -21.33 -21.32
 [870] -21.31 -21.30 -21.29 -21.28 -21.27 -21.26 -21.25 -21.24 -21.23 -21.22 -21.21
 [881] -21.20 -21.19 -21.18 -21.17 -21.16 -21.15 -21.14 -21.13 -21.12 -21.11 -21.10
 [892] -21.09 -21.08 -21.07 -21.06 -21.05 -21.04 -21.03 -21.02 -21.01 -21.00 -20.99
 [903] -20.98 -20.97 -20.96 -20.95 -20.94 -20.93 -20.92 -20.91 -20.90 -20.89 -20.88
 [914] -20.87 -20.86 -20.85 -20.84 -20.83 -20.82 -20.81 -20.80 -20.79 -20.78 -20.77
 [925] -20.76 -20.75 -20.74 -20.73 -20.72 -20.71 -20.70 -20.69 -20.68 -20.67 -20.66
 [936] -20.65 -20.64 -20.63 -20.62 -20.61 -20.60 -20.59 -20.58 -20.57 -20.56 -20.55
 [947] -20.54 -20.53 -20.52 -20.51 -20.50 -20.49 -20.48 -20.47 -20.46 -20.45 -20.44
 [958] -20.43 -20.42 -20.41 -20.40 -20.39 -20.38 -20.37 -20.36 -20.35 -20.34 -20.33
 [969] -20.32 -20.31 -20.30 -20.29 -20.28 -20.27 -20.26 -20.25 -20.24 -20.23 -20.22
 [980] -20.21 -20.20 -20.19 -20.18 -20.17 -20.16 -20.15 -20.14 -20.13 -20.12 -20.11
 [991] -20.10 -20.09 -20.08 -20.07 -20.06 -20.05 -20.04 -20.03 -20.02 -20.01
 [ reached getOption("max.print") -- omitted 5001 entries ]
> 
> min(mses)
[1] 8.111863
> min.pos.mses <- which(mses == min(mses))
> print(as[min.pos.mses])
[1] -2.29
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
      1       2       3       4       5 
-1.7472  1.7379  1.1598 -0.7010 -0.4496 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)  
(Intercept)  -2.2923     4.6038  -0.498   0.6528  
x             3.4510     0.8899   3.878   0.0304 *
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.644 on 3 degrees of freedom
Multiple R-squared:  0.8337,	Adjusted R-squared:  0.7783 
F-statistic: 15.04 on 1 and 3 DF,  p-value: 0.03036

> plot(seq(1, length(mses)), mses)
> 
> 

위 방법은 dumb . . . . .
우선 a의 범위가 어디가 될지 몰라서 충분히 넓게 잡은 것 (위 출력에서는 -30에서 30까지, 위 코드에서는 -50에서 50까지)
이 두 지점 사이를 0.01 단위로 증가시킨 값을 a값이라고 할 때의 mse값을 구하여 저장한 것
이렇게 구한 mse 값들 중 최소값일 때의 a값을 regression의 a값으로 추정한 것.

이렇게 말고 구할 수 있는 방법은 없을까?
gradient descent

Gradient descent

위에서 a값이 무엇일 때 mse값이 최소가 될까를 봐서 이 때의 a값을 실제 $y = a + bx$ 의 a 값으로 삼았다. 이 때 a값의 추정을 위해서
seq(-30, 30, 0.01) (위 코드에서는 seq(-50, 50, 0.01)) 의 범위와 증가값을 가지고 일일이 대입하여 mse를 구하였다.

위에서처럼 0.01씩 증가시켜 대입하는 것이 아니라, 처음 한 숫자에서 시작한 후 그 다음 숫자를 정하고 점진적으로 그 숫자 간격을 줄여가면서 보면 일정하게 0.01씩 증가시키는 것보다 효율적일 것 같다. 이 증가분을 구하기 위해서 미분(기울기)을 사용한다. 최소점에 가까워질수록 기울기가 0에 가까워지므로, 기울기에 비례하는 증가분도 자연스럽게 줄어든다.

경사하강 (점차하강) = 조금씩 깎아 내려가면서 원하는 기울기 (미분값 = 0) 찾기
prerequisite:
표준편차 추론에서 평균을 사용하는 이유: 실험적_수학적_이해
derivation of a and b in a simple regression

위의 문서는 a, b에 대한 값을 미분법을 이용해서 직접 구하였다. 컴퓨터로는 이렇게 하기가 쉽지 않다. 그렇다면 이 값을 반복계산을 이용해서 추출하는 방법은 없을까? gradient descent

\begin{eqnarray*} \text{for a (constant)} \\ \\ \text{SSE} & = & \sum{\text{Residual}_i^2} \;\;\; \text{(Sum of Square Residuals)} \\ \text{Residual}_i & = & (Y_i - (a + bX_i)) \\ \\ \frac{\text{dSSE}}{\text{da}} & = & \sum{\frac{\text{d}\,\text{Residual}_i^2}{\text{d}\,\text{Residual}_i} * \frac{\text{d}\,\text{Residual}_i}{\text{da}}} \\ & = & \sum{2 * \text{Residual}_i * \dfrac{\text{d}}{\text{da}} (Y_i - (a+bX_i))} \\ & \because & \dfrac{\text{d}}{\text{da}} (Y_i - (a+bX_i)) = -1 \\ & = & \sum{2 * (Y_i - (a + bX_i)) * (-1)} \\ & = & -2 * \sum{\text{Residual}_i} \\ \end{eqnarray*}
아래 R code에서 gradient function을 참조.

\begin{eqnarray*} \text{for b, (coefficient)} \\ \\ \dfrac{\text{d}}{\text{db}} \sum{(Y_i - (a + bX_i))^2} & = & \sum{\dfrac{\text{d}\,\text{Residual}_i^2}{\text{db}}} \\ & = & \sum{\dfrac{\text{d}\,\text{Residual}_i^2}{\text{d}\,\text{Residual}_i}*\dfrac{\text{d}\,\text{Residual}_i}{\text{db}} } \\ & = & \sum{2*\text{Residual}_i * \dfrac{\text{d}\,\text{Residual}_i}{\text{db}} } \\ & = & \sum{2*\text{Residual}_i * (-X_i) } \;\;\;\; \\ & \because & \dfrac{\text{d}\,\text{Residual}_i}{\text{db}} = \dfrac{\text{d}}{\text{db}} (Y_i - (a+bX_i)) = -X_i \\ & = & -2 \sum{X_i (Y_i - (a + bX_i))} \\ & = & -2 * \sum{X_i * \text{Residual}_i} \\ \\ \end{eqnarray*}

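(미분을 이해한다는 것을 전제로) 위의 식은 b값이 변할 때 SSE 값이 어떻게 변하는가를 알려주는 것이다. 양변을 N으로 나누면 msr (mean square residual)의 기울기가 되는데, 그것은 각 X_i 와 Residual_i 를 곱한 값들의 평균에 -2를 곱한 값, 즉 (-2/N) * Σ X_i * Residual_i 이다 (X_i 는 합 바깥으로 꺼낼 수 없다). 아래 gradient 펑션의 db = -2 * mean(x * error), da = -2 * mean(error) 가 바로 이것이다.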

gradient <- function(x, y, predictions){
  # Gradients of the mean squared error with respect to the slope (b) and
  # the intercept (a): -2 * mean(x * residual) and -2 * mean(residual).
  resid <- y - predictions
  grad_b <- -2 * mean(x * resid)
  grad_a <- -2 * mean(resid)
  list(b = grad_b, a = grad_a)
}

위 펑션으로 얻은 da와 db값을 초기에 설정한 a, b 값에서 빼 준 값을 새 a, b값으로 하여 gradient 펑션을 통해서 다시 db, da값을 구하고, 이를 다시 이전 단계에서 구한 a, b값에서 빼 주어 그 값을 다시 a, b값으로 하여 . . . . (기울기는 loss가 커지는 방향을 가리키므로, loss를 줄이려면 기울기의 반대 방향으로, 즉 빼 주어야 한다.)

위를 반복한다. 단, db값과 da값을 그대로 빼기 보다는 초기에 설정한 learning rate 값(예를 들면 0.1; 아래 code의 lr)을 곱하여 구한 값을 빼게 된다. 이것이 아래의 code이다.

gradient <- function(x, y, predictions){
  # MSE gradient, with e = y - predictions:
  #   d/db = -2 * mean(x * e),  d/da = -2 * mean(e)
  e <- y - predictions
  list("b" = -2 * mean(x * e),
       "a" = -2 * mean(e))
}

# Mean squared error of the predictions against the observed y.
mseloss <- function(predictions, y) {
  mean((y - predictions)^2)
}

# Train the model on the standardized predictor zx = (x - mean(x)) / sd(x).
# Standardizing lets one learning rate and a modest number of epochs work
# regardless of the units of x.
lr <- 1e-1 # learning rate used for every update
nlen <- 50 # number of epochs

# Record the trajectory for each epoch (preallocated):
as <- numeric(nlen)  # intercept after each update (z scale)
bs <- numeric(nlen)  # slope after each update (z scale)
mse <- numeric(nlen) # loss measured before each update

zx <- (x - mean(x)) / sd(x)

for (epoch in seq_len(nlen)) {
  predictions <- predict(zx, a, b)
  loss <- mseloss(predictions, y)
  mse[epoch] <- loss

  grad <- gradient(zx, y, predictions)

  # Gradient descent: move *against* the gradient, scaled by lr.
  b <- b - grad$b * lr
  a <- a - grad$a * lr

  as[epoch] <- a
  bs[epoch] <- b
}

# a and b are the coefficients of y = a + b * zx (z scale). Convert them to
# the original scale of x with the formulas derived above:
#   a_x = a - (mean(x) / sd(x)) * b,   b_x = b / sd(x)
a.x <- a - (mean(x) / sd(x)) * b
b.x <- b / sd(x)
c(a = a.x, b = b.x)
coef(lm(y ~ x)) # should agree with a.x and b.x once converged

R code

# d statquest explanation
# x <- c(0.5, 2.3, 2.9)
# y <- c(1.4, 1.9, 3.2)

# Packages used in this section: ggplot2 for the plot, and dplyr for
# tibble(), %>% and slice_head()/slice_tail(). Nothing else in this section
# attaches them, so running it on its own would otherwise fail.
library(ggplot2)
library(dplyr)

rm(list = ls())
# set.seed(191)
# Simulate y = 2.14 * x + noise, with x ~ N(5, 1.2^2) and noise ~ N(0, 4^2).
n <- 300
x <- rnorm(n, 5, 1.2)
y <- 2.14 * x + rnorm(n, 0, 4)

# data <- data.frame(x, y)
data <- tibble(x = x, y = y)

# Reference OLS fit that gradient descent should reproduce.
mo <- lm(y ~ x)
summary(mo)

# set.seed(191)
# Random starting values for the coefficients. The slope is drawn before
# the intercept, so the random-number stream is consumed in the same order.
b1 <- rnorm(1)
b0 <- rnorm(1)

# Remember where the descent started.
b1.init <- b1
b0.init <- b0

# Linear predictor b0 + b1 * x (intercept b0, slope b1), element-wise in x.
# NOTE(review): masks stats::predict() in the global environment.
predict <- function(x, b0, b1){
  b1 * x + b0
}

# Residuals of the current fit: y minus predictions, element-wise.
residuals <- function(predictions, y) {
  y - predictions
}

loss_mse <- function(predictions, y){
  # Mean of the squared residuals.
  squared <- (y - predictions)^2
  mean(squared)
}

# Evaluate the random starting coefficients on the raw (unscaled) x.
predictions <- predict(x, b0, b1)
# Keep the residual vector under a separate name. Assigning it to
# `residuals` would overwrite the residuals() helper defined above; R would
# then skip that non-function binding and silently call stats::residuals().
resids <- residuals(predictions, y)
loss <- loss_mse(predictions, y)

data <- tibble(x, y, predictions, residuals = resids)

print(paste0("Loss is: ", round(loss)))

gradient <- function(x, y, predictions){
  # Partial derivatives of the MSE w.r.t. the slope (db1) and the
  # intercept (db0).
  dinputs <- y - predictions
  list(db1 = -2 * mean(x * dinputs),
       db0 = -2 * mean(dinputs))
}

# Gradient of the MSE at the random starting point, on the raw x scale.
gradients <- gradient(x, y, predictions)
print(gradients)

# Train the model with scaled features
x_scaled <- (x - mean(x)) / sd(x)

learning_rate <- 1e-1

# Record loss and coefficients for each epoch (preallocated).
nlen <- 80
b0s <- numeric(nlen)
b1s <- numeric(nlen)
mse <- numeric(nlen)

for (epoch in seq_len(nlen)) {
  # Predict all y values with the current coefficients:
  predictions <- predict(x_scaled, b0, b1)
  loss <- loss_mse(predictions, y)
  mse[epoch] <- loss

  if (epoch %% 10 == 0) {
    print(paste0("Epoch: ", epoch, ", Loss: ", round(loss, 5)))
  }

  gradients <- gradient(x_scaled, y, predictions)
  db1 <- gradients$db1
  db0 <- gradients$db0

  # Step against the gradient, scaled by the learning rate.
  b1 <- b1 - db1 * learning_rate
  b0 <- b0 - db0 * learning_rate
  b0s[epoch] <- b0
  b1s[epoch] <- b1
}

# Convert the coefficients fitted on x_scaled back to the scale of x:
#   intercept = b0 - (mean(x) / sd(x)) * b1,   slope = b1 / sd(x)
# The intercept is updated first because it needs the still-scaled slope.
b0 <- b0 - b1 * (mean(x) / sd(x))
b1 <- b1 / sd(x)

# Apply the same conversion to every recorded epoch.
b0s <- b0s - b1s * (mean(x) / sd(x))
b1s <- b1s / sd(x)

parameters <- tibble(b0s, b1s, mse)

cat(paste0("Slope: ", b1, ", \n", "Intercept: ", b0, "\n"))
summary(lm(y ~ x))$coefficients

# Every epoch's fitted line in green, the first epoch in blue, the last in red.
ggplot(data, aes(x = x, y = y)) +
  geom_point(size = 2) +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters, linewidth = 0.5,
              color = 'green') +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = slice_head(parameters),
              linewidth = 1, color = 'blue') +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = slice_tail(parameters),
              linewidth = 1, color = 'red') +
  theme_classic() +
  labs(title = 'Gradient descent. blue: start, red: end, green: gradients')

b0.init
b1.init

data
parameters

R output

> rm(list=ls())
> # set.seed(191)
> n <- 300
> x <- rnorm(n, 5, 1.2)
> y <- 2.14 * x + rnorm(n, 0, 4)
> 
> # data <- data.frame(x, y)
> data <- tibble(x = x, y = y)
> 
> mo <- lm(y~x)
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
   Min     1Q Median     3Q    Max 
-9.754 -2.729 -0.135  2.415 10.750 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  -0.7794     0.9258  -0.842    0.401    
x             2.2692     0.1793  12.658   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 3.951 on 298 degrees of freedom
Multiple R-squared:  0.3497,	Adjusted R-squared:  0.3475 
F-statistic: 160.2 on 1 and 298 DF,  p-value: < 2.2e-16

> 
> # set.seed(191)
> # Initialize random betas
> b1 = rnorm(1)
> b0 = rnorm(1)
> 
> b1.init <- b1
> b0.init <- b0
> 
> # Predict function:
> predict <- function(x, b0, b1){
+   return (b0 + b1 * x)
+ }
> 
> # And loss function is:
> residuals <- function(predictions, y) {
+   return(y - predictions)
+ }
> 
> loss_mse <- function(predictions, y){
+   residuals = y - predictions
+   return(mean(residuals ^ 2))
+ }
> 
> predictions <- predict(x, b0, b1)
> residuals <- residuals(predictions, y)
> loss = loss_mse(predictions, y)
> 
> data <- tibble(data.frame(x, y, predictions, residuals))
> 
> print(paste0("Loss is: ", round(loss)))
[1] "Loss is: 393"
> 
> gradient <- function(x, y, predictions){
+   dinputs = y - predictions
+   db1 = -2 * mean(x * dinputs)
+   db0 = -2 * mean(dinputs)
+   
+   return(list("db1" = db1, "db0" = db0))
+ }
> 
> gradients <- gradient(x, y, predictions)
> print(gradients)
$db1
[1] -200.6834

$db0
[1] -37.76994

> 
> # Train the model with scaled features
> x_scaled <- (x - mean(x)) / sd(x)
> 
> learning_rate = 1e-1
> 
> # Record Loss for each epoch:
> # logs = list()
> # bs=list()
> b0s = c()
> b1s = c()
> mse = c()
> 
> nlen <- 80
> for (epoch in 1:nlen){
+   # Predict all y values:
+   predictions = predict(x_scaled, b0, b1)
+   loss = loss_mse(predictions, y)
+   mse = append(mse, loss)
+   # logs = append(logs, loss)
+   
+   if (epoch %% 10 == 0){
+     print(paste0("Epoch: ",epoch, ", Loss: ", round(loss, 5)))
+   }
+   
+   gradients <- gradient(x_scaled, y, predictions)
+   db1 <- gradients$db1
+   db0 <- gradients$db0
+   
+   b1 <- b1 - db1 * learning_rate
+   b0 <- b0 - db0 * learning_rate
+   b0s <- append(b0s, b0)
+   b1s <- append(b1s, b1)
+ }
[1] "Epoch: 10, Loss: 18.5393"
[1] "Epoch: 20, Loss: 15.54339"
[1] "Epoch: 30, Loss: 15.50879"
[1] "Epoch: 40, Loss: 15.50839"
[1] "Epoch: 50, Loss: 15.50839"
[1] "Epoch: 60, Loss: 15.50839"
[1] "Epoch: 70, Loss: 15.50839"
[1] "Epoch: 80, Loss: 15.50839"
> 
> # unscale coefficients to make them comprehensible
> b0 =  b0 - (mean(x) / sd(x)) * b1
> b1 = b1 / sd(x)
> 
> # changes of estimators
> b0s <- b0s - (mean(x) /sd(x)) * b1s
> b1s <- b1s / sd(x)
> 
> parameters <- tibble(data.frame(b0s, b1s, mse))
> 
> cat(paste0("Slope: ", b1, ", \n", "Intercept: ", b0, "\n"))
Slope: 2.26922511738252, 
Intercept: -0.779435058320381
> summary(lm(y~x))$coefficients
              Estimate Std. Error    t value     Pr(>|t|)
(Intercept) -0.7794352  0.9258064 -0.8418986 4.005198e-01
x            2.2692252  0.1792660 12.6584242 1.111614e-29
> 
> ggplot(data, aes(x = x, y = y)) + 
+   geom_point(size = 2) + 
+   geom_abline(aes(intercept = b0s, slope = b1s),
+               data = parameters, linewidth = 0.5, 
+               color = 'green') + 
+   theme_classic() +
+   geom_abline(aes(intercept = b0s, slope = b1s), 
+               data = parameters %>% slice_head(), 
+               linewidth = 1, color = 'blue') + 
+   geom_abline(aes(intercept = b0s, slope = b1s), 
+               data = parameters %>% slice_tail(), 
+               linewidth = 1, color = 'red') +
+   labs(title = 'Gradient descent. blue: start, red: end, green: gradients')
> 
> b0.init
[1] -1.67967
> b1.init
[1] -1.323992
> 
> data
# A tibble: 300 × 4
       x     y predictions residuals
   <dbl> <dbl>       <dbl>     <dbl>
 1  4.13  6.74       -7.14     13.9 
 2  7.25 14.0       -11.3      25.3 
 3  6.09 13.5        -9.74     23.3 
 4  6.29 15.1       -10.0      25.1 
 5  4.40  3.81       -7.51     11.3 
 6  6.03 13.9        -9.67     23.5 
 7  6.97 12.1       -10.9      23.0 
 8  4.84 12.8        -8.09     20.9 
 9  6.85 17.2       -10.7      28.0 
10  3.33  3.80       -6.08      9.88
# ℹ 290 more rows
# ℹ Use `print(n = ...)` to see more rows
> parameters
# A tibble: 80 × 3
       b0s    b1s   mse
     <dbl>  <dbl> <dbl>
 1  2.67   -0.379 183. 
 2  1.99    0.149 123. 
 3  1.44    0.571  84.3
 4  1.00    0.910  59.6
 5  0.652   1.18   43.7
 6  0.369   1.40   33.6
 7  0.142   1.57   27.1
 8 -0.0397  1.71   22.9
 9 -0.186   1.82   20.2
10 -0.303   1.91   18.5
# ℹ 70 more rows
#

gradient_descent.1755745298.txt.gz · Last modified: 2025/08/21 12:01 by hkimscil

Donate Powered by PHP Valid HTML5 Valid CSS Driven by DokuWiki