Gradient Descent
explanation
Why normalize (scale or z-score) $x_i$
The measurement unit of the x variable determines the magnitude of b, so b can fall anywhere in a wide and varying range. If x were monthly income, for example, the b value we have to estimate (track) could be in the millions. Tracking such a value by gradient steps could require an enormous number of iterations, and whenever the variable changes, the search range for b changes dramatically as well. If we instead use standardized x scores (z-scores), a fixed learning rate and a fixed number of iterations are enough to track the correct a and b.
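A minimal sketch of this point (made-up income numbers; the magnitudes are illustrative only):

set.seed(1)
income <- rnorm(100, 3000000, 500000)    # monthly income in won (toy values)
y      <- 0.000002 * income + rnorm(100)

coef(lm(y ~ income))[2]                  # slope on the raw scale: about 2e-06
z <- (income - mean(income)) / sd(income)
coef(lm(y ~ z))[2]                       # slope on the z scale: about 1

The unit of x sets the scale of b: with raw income the slope is tiny (with other units it could be huge), while z-scoring brings it into an sd-sized range that a fixed learning rate can handle.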
How to unnormalize (unscale) a and b
\begin{eqnarray*}
y & = & a + b * x \\
& & \text{We use } z \text{ instead of } x \text{, where} \\
z & = & \frac{(x - \mu)}{\sigma} \\
& & \text{Suppose the result after fitting on } z \text{ is} \\
y & = & k + m * z \\
& = & k + m * \frac{(x - \mu)}{\sigma} \\
& = & k + \frac{m * x}{\sigma} - \frac{m * \mu}{\sigma} \\
& = & k - \frac{\mu}{\sigma} * m + \frac{m}{\sigma} * x \\
& & \text{Therefore, the } a \text{ and } b \text{ we want are} \\
a & = & k - \frac{\mu}{\sigma} * m \\
b & = & \frac{m}{\sigma} \\
\end{eqnarray*}
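A quick numerical check of this algebra (a sketch; the data here are simulated, not from the examples below):

# Fit on z, then unscale k and m back to a and b; compare with lm on x.
set.seed(7)
x <- rnorm(50, 5, 1.2)
y <- 2 + 3 * x + rnorm(50)
z <- (x - mean(x)) / sd(x)

fit.z <- lm(y ~ z)                    # y = k + m * z
k <- unname(coef(fit.z)[1])
m <- unname(coef(fit.z)[2])

a <- k - (mean(x) / sd(x)) * m        # a = k - (mu / sigma) * m
b <- m / sd(x)                        # b = m / sigma
c(a, b)
coef(lm(y ~ x))                       # should match a and b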
R code: Idea
library(tidyverse)
library(data.table)

rm(list = ls())
# set.seed(191)
n <- 5
x <- rnorm(n, 5, 1.2)
y <- 3.14 * x + rnorm(n, 0, 1)

# data <- data.frame(x, y)
data <- tibble(x = x, y = y)

mo <- lm(y ~ x)
summary(mo)

# set.seed(191)
# Initialize betas: fix b at the lm estimate
# and vary only a, to build intuition
b <- summary(mo)$coefficients[2]
a <- 0

b.init <- b
a.init <- a

# Predict function:
predict <- function(x, a, b){
  return(a + b * x)
}

# And the residual function is:
residuals <- function(predictions, y) {
  return(y - predictions)
}

# We use the sum of squared errors, which can become large
mseloss <- function(predictions, y) {
  residuals <- (y - predictions)
  return(sum(residuals^2))
}

mses <- c()
j <- 0
as <- c()

# Grid search: try every a from -30 to 30 in steps of 0.01
for (i in seq(from = -30, to = 30, by = 0.01)) {
  pred <- predict(x, i, b)
  res  <- residuals(pred, y)
  mse  <- mseloss(pred, y)
  mses <- append(mses, mse)
  as   <- append(as, i)
}
mses
as

min(mses)
min.pos.mses <- which(mses == min(mses))
print(as[min.pos.mses])
summary(mo)
plot(seq(1, length(mses)), mses)
output
> rm(list=ls())
> # set.seed(191)
> n <- 5
> x <- rnorm(n, 5, 1.2)
> y <- 3.14 * x + rnorm(n,0,1)
> data <- tibble(x = x, y = y)
> mo <- lm(y~x)
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
      1       2       3       4       5 
-1.7472  1.7379  1.1598 -0.7010 -0.4496 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)  
(Intercept)  -2.2923     4.6038  -0.498   0.6528  
x             3.4510     0.8899   3.878   0.0304 *
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 1.644 on 3 degrees of freedom
Multiple R-squared:  0.8337,	Adjusted R-squared:  0.7783 
F-statistic: 15.04 on 1 and 3 DF,  p-value: 0.03036

> (function definitions and the grid-search loop echoed as in the code above)
> mses
   [1] 3846.695 3843.925 3841.156 3838.387 3835.620 3832.854 3830.088 3827.324
   ... (long print truncated)
[ reached getOption("max.print") -- omitted 5001 entries ]
> as
   [1] -30.00 -29.99 -29.98 -29.97 -29.96 -29.95 -29.94 -29.93 -29.92 -29.91
   ... (grid values in steps of 0.01; long print truncated)
[ reached getOption("max.print") -- omitted 5001 entries ]
> min(mses)
[1] 8.111863
> min.pos.mses <- which(mses == min(mses))
> print(as[min.pos.mses])
[1] -2.29
> summary(mo)
(same summary as above: Intercept = -2.2923, x = 3.4510)
> plot(seq(1, length(mses)), mses)
The method above is dumb (brute force):
Not knowing in advance where a might lie, we set its range from -30 to 30.
We stepped through that interval in 0.01 increments, treated each value as a, computed the MSE, and stored it.
Among the stored MSE values, the a giving the minimum was taken as the regression estimate of a.
Is there a way to find it without doing all of this?
gradient descent
Gradient descent
Above, we checked at which a value the MSE becomes minimal and took that a as the a of the actual $y = a + bx$. To estimate a, we used
seq(-30, 30, 0.01)
as the range and increment, substituting every value one by one and computing the MSE.
Instead of marching up in fixed 0.01 increments, it should be more efficient to start at one number, move to the next, and progressively shrink the gap between successive guesses. To determine that step size, we use differentiation.
Gradient descent = shaving away a little at a time to find the slope (derivative value) we want.
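As a minimal sketch of one such "shaving" step (toy numbers; the starting point, learning rate, and finite-difference width are all assumptions, and the names are kept separate from the x and y used elsewhere on this page):

# One step: estimate the slope of the loss at the current a, move a against it.
x.toy <- c(0.5, 2.3, 2.9); y.toy <- c(1.4, 1.9, 3.2)
b.toy <- 0.64                      # hold b fixed, as in the idea code above
loss  <- function(a) sum((y.toy - (a + b.toy * x.toy))^2)

a.now <- 0                         # arbitrary starting point
h     <- 1e-6
slope <- (loss(a.now + h) - loss(a.now - h)) / (2 * h)  # numeric derivative
lr    <- 0.1                       # assumed learning rate
a.new <- a.now - lr * slope        # step downhill; repeat until the slope is ~0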
prerequisites:
Why we use the mean when inferring the standard deviation: an experimental and mathematical understanding
derivation of a and b in a simple regression
The document above obtains the values of a and b directly by differentiation. That is not easy to do on a computer. Is there, then, a way to extract these values through repeated calculation? gradient descent
\begin{eqnarray*}
\text{for a (the constant)} \\
\\
\text{MSE} & = & \frac{\sum{(Y_i - (a + bX_i))^2}}{N} \\
\text{Residual}_i & = & Y_i - (a + bX_i) \\
\\
\frac{\text{dMSE}}{\text{d}a}
& = & \frac{\text{dResidual}^2}{\text{dResidual}} * \frac{\text{dResidual}}{\text{d}a} \\
& = & \frac{1}{N} \sum{2 \, (Y_i - (a+bX_i))} * \dfrac{\text{d}}{\text{d}a} (Y_i - (a+bX_i)) \\
& \because & \dfrac{\text{d}}{\text{d}a} (Y_i - (a+bX_i)) = -1 \\
& = & -2 * \frac{\sum{(Y_i - (a + bX_i))}}{N} \\
& = & -2 * \text{mean of residuals} \\
\end{eqnarray*}
See the gradient function in the R code below.
gradient <- function(x, y, predictions){
  error = y - predictions
  db = -2 * mean(x * error)   # dMSE/db: -2 * mean of (X * residual)
  da = -2 * mean(error)       # dMSE/da: -2 * mean of residuals
  return(list("b" = db, "a" = da))
}
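As a quick sanity check, the analytic da can be compared with a finite-difference derivative of the MSE. This is a sketch with made-up toy objects (xx, yy, aa, bb, mse.fn, h are all illustrative names, kept separate from the page's x, y, a, b):

set.seed(3)
xx <- rnorm(20); yy <- 1 + 2 * xx + rnorm(20)
aa <- 0; bb <- 0

mse.fn <- function(aa, bb) mean((yy - (aa + bb * xx))^2)
gradient(xx, yy, aa + bb * xx)$a                 # analytic: -2 * mean(error)

h <- 1e-6
(mse.fn(aa + h, bb) - mse.fn(aa - h, bb)) / (2 * h)   # numeric; should agree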
\begin{eqnarray*}
\text{for b (the coefficient)} \\
\\
\dfrac{\text{d}}{\text{d}b} \frac{\sum{(Y_i - (a + bX_i))^2}}{N} & = & \frac{1}{N} \sum{\dfrac{\text{d}}{\text{d}b} (Y_i - (a + bX_i))^2} \\
& = & \frac{1}{N} \sum{2 \, (Y_i - (a + bX_i)) * (-X_i)} \\
& \because & \dfrac{\text{d}}{\text{d}b} (Y_i - (a+bX_i)) = -X_i \\
& = & -2 * \frac{\sum{X_i (Y_i - (a + bX_i))}}{N} \\
& = & -2 * \text{mean of } (X_i * \text{residuals}) \\
\\
\end{eqnarray*}
(Assuming the differentiation is understood:) the expression above tells us how the MSR (mean squared residual) changes as b changes, namely (-2/N) times the sum of each residual multiplied by its X_i, that is, -2 times the mean of X_i * residual.
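The same finite-difference check for the b gradient, reusing the toy objects from the sketch above:

gradient(xx, yy, aa + bb * xx)$b                      # analytic: -2 * mean(x * error)
(mse.fn(aa, bb + h) - mse.fn(aa, bb - h)) / (2 * h)   # numeric; should agree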
The da and db values obtained from this function are used to update the initially set a and b; the updated a and b go back into the gradient function to produce new db and da, which update a and b again, and so on . . . .
We repeat the above. However, rather than applying db and da at full size, we first multiply them by a learning rate set at the start (e.g., 0.01) and subtract that scaled step from the current values, so that a and b move against the gradient. This is the code below.
# (continues the session from the idea code above:
#  x, y, a, b, and predict() already exist)
gradient <- function(x, y, predictions){
  error = y - predictions
  db = -2 * mean(x * error)
  da = -2 * mean(error)
  return(list("b" = db, "a" = da))
}

mseloss <- function(predictions, y) {
  residuals <- (y - predictions)
  return(mean(residuals^2))
}

# Train the model with scaled features
learning_rate = 1e-2
lr = 1e-1          # the rate actually used below

# Record loss for each epoch:
logs = list()
as = c()
bs = c()
mse = c()
sse = c()

x.ori <- x
zx <- (x - mean(x)) / sd(x)

nlen <- 50
for (epoch in 1:nlen) {
  predictions <- predict(zx, a, b)
  loss <- mseloss(predictions, y)
  mse <- append(mse, loss)
  grad <- gradient(zx, y, predictions)
  step.b <- grad$b * lr
  step.a <- grad$a * lr
  b <- b - step.b   # step against the gradient
  a <- a - step.a
  as <- append(as, a)
  bs <- append(bs, b)
}
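A quick way to see that the loop converged is to plot the recorded per-epoch loss, and to unscale the final a and b as derived earlier. A short sketch using the objects defined in the loop above:

# MSE should drop steeply, then flatten once a and b stop moving.
plot(seq_along(mse), mse, type = "b",
     xlab = "epoch", ylab = "MSE", main = "Loss per epoch")

# Map the z-scale estimates back to the original x scale:
a.unscaled <- a - (mean(x.ori) / sd(x.ori)) * b
b.unscaled <- b / sd(x.ori)
c(a.unscaled, b.unscaled)      # compare with summary(mo)$coefficients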
R code
# cf. the statquest explanation
# x <- c(0.5, 2.3, 2.9)
# y <- c(1.4, 1.9, 3.2)

library(tidyverse)   # for tibble, ggplot2, %>% (needed after rm())

rm(list=ls())
# set.seed(191)
n <- 300
x <- rnorm(n, 5, 1.2)
y <- 2.14 * x + rnorm(n, 0, 4)

# data <- data.frame(x, y)
data <- tibble(x = x, y = y)

mo <- lm(y~x)
summary(mo)

# set.seed(191)
# Initialize random betas
b1 = rnorm(1)
b0 = rnorm(1)

b1.init <- b1
b0.init <- b0

# Predict function:
predict <- function(x, b0, b1){
  return(b0 + b1 * x)
}

# And loss function is:
residuals <- function(predictions, y) {
  return(y - predictions)
}

loss_mse <- function(predictions, y){
  residuals = y - predictions
  return(mean(residuals ^ 2))
}

predictions <- predict(x, b0, b1)
residuals <- residuals(predictions, y)
loss = loss_mse(predictions, y)

data <- tibble(data.frame(x, y, predictions, residuals))

print(paste0("Loss is: ", round(loss)))

gradient <- function(x, y, predictions){
  dinputs = y - predictions
  db1 = -2 * mean(x * dinputs)
  db0 = -2 * mean(dinputs)
  return(list("db1" = db1, "db0" = db0))
}

gradients <- gradient(x, y, predictions)
print(gradients)

# Train the model with scaled features
x_scaled <- (x - mean(x)) / sd(x)

learning_rate = 1e-1

# Record loss for each epoch:
b0s = c()
b1s = c()
mse = c()

nlen <- 80
for (epoch in 1:nlen){
  # Predict all y values:
  predictions = predict(x_scaled, b0, b1)
  loss = loss_mse(predictions, y)
  mse = append(mse, loss)

  if (epoch %% 10 == 0){
    print(paste0("Epoch: ", epoch, ", Loss: ", round(loss, 5)))
  }

  gradients <- gradient(x_scaled, y, predictions)
  db1 <- gradients$db1
  db0 <- gradients$db0

  b1 <- b1 - db1 * learning_rate
  b0 <- b0 - db0 * learning_rate
  b0s <- append(b0s, b0)
  b1s <- append(b1s, b1)
}

# Unscale coefficients to make them comprehensible
# (b0 first, while b1 is still on the z scale)
b0 = b0 - (mean(x) / sd(x)) * b1
b1 = b1 / sd(x)

# Changes of the estimators, unscaled the same way
b0s <- b0s - (mean(x) / sd(x)) * b1s
b1s <- b1s / sd(x)

parameters <- tibble(data.frame(b0s, b1s, mse))

cat(paste0("Slope: ", b1, ", \n", "Intercept: ", b0, "\n"))
summary(lm(y~x))$coefficients

ggplot(data, aes(x = x, y = y)) +
  geom_point(size = 2) +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters, linewidth = 0.5,
              color = 'green') +
  theme_classic() +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters %>% slice_head(),
              linewidth = 1, color = 'blue') +
  geom_abline(aes(intercept = b0s, slope = b1s),
              data = parameters %>% slice_tail(),
              linewidth = 1, color = 'red') +
  labs(title = 'Gradient descent. blue: start, red: end, green: gradients')

b0.init
b1.init
data
parameters
R output
> rm(list=ls())
> n <- 300
> x <- rnorm(n, 5, 1.2)
> y <- 2.14 * x + rnorm(n, 0, 4)
> data <- tibble(x = x, y = y)
> mo <- lm(y~x)
> summary(mo)

Call:
lm(formula = y ~ x)

Residuals:
   Min     1Q Median     3Q    Max 
-9.754 -2.729 -0.135  2.415 10.750 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  -0.7794     0.9258  -0.842    0.401    
x             2.2692     0.1793  12.658   <2e-16 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 3.951 on 298 degrees of freedom
Multiple R-squared:  0.3497,	Adjusted R-squared:  0.3475 
F-statistic: 160.2 on 1 and 298 DF,  p-value: < 2.2e-16

> (initialization and the predict/residuals/loss/gradient definitions echoed as in the code above)
> print(paste0("Loss is: ", round(loss)))
[1] "Loss is: 393"
> gradients <- gradient(x, y, predictions)
> print(gradients)
$db1
[1] -200.6834

$db0
[1] -37.76994

> (training loop echoed as in the code above)
[1] "Epoch: 10, Loss: 18.5393"
[1] "Epoch: 20, Loss: 15.54339"
[1] "Epoch: 30, Loss: 15.50879"
[1] "Epoch: 40, Loss: 15.50839"
[1] "Epoch: 50, Loss: 15.50839"
[1] "Epoch: 60, Loss: 15.50839"
[1] "Epoch: 70, Loss: 15.50839"
[1] "Epoch: 80, Loss: 15.50839"
> cat(paste0("Slope: ", b1, ", \n", "Intercept: ", b0, "\n"))
Slope: 2.26922511738252, 
Intercept: -0.779435058320381
> summary(lm(y~x))$coefficients
              Estimate Std. Error    t value     Pr(>|t|)
(Intercept) -0.7794352  0.9258064 -0.8418986 4.005198e-01
x            2.2692252  0.1792660 12.6584242 1.111614e-29
> (ggplot call echoed as above; it draws the gradient-descent plot: blue = start, red = end, green = intermediate fits)
> b0.init
[1] -1.67967
> b1.init
[1] -1.323992
> data
# A tibble: 300 × 4
       x     y predictions residuals
   <dbl> <dbl>       <dbl>     <dbl>
 1  4.13  6.74       -7.14     13.9 
 2  7.25 14.0       -11.3      25.3 
 3  6.09 13.5        -9.74     23.3 
 4  6.29 15.1       -10.0      25.1 
 5  4.40  3.81       -7.51     11.3 
 6  6.03 13.9        -9.67     23.5 
 7  6.97 12.1       -10.9      23.0 
 8  4.84 12.8        -8.09     20.9 
 9  6.85 17.2       -10.7      28.0 
10  3.33  3.80       -6.08      9.88
# ℹ 290 more rows
# ℹ Use `print(n = ...)` to see more rows
> parameters
# A tibble: 80 × 3
       b0s    b1s   mse
     <dbl>  <dbl> <dbl>
 1  2.67   -0.379 183. 
 2  1.99    0.149 123. 
 3  1.44    0.571  84.3
 4  1.00    0.910  59.6
 5  0.652   1.18   43.7
 6  0.369   1.40   33.6
 7  0.142   1.57   27.1
 8 -0.0397  1.71   22.9
 9 -0.186   1.82   20.2
10 -0.303   1.91   18.5
# ℹ 70 more rows