gradient_descent
Differences
This shows you the differences between two versions of the page.
Both sides previous revisionPrevious revisionNext revision | Previous revision | ||
gradient_descent [2025/08/05 00:45] – [Gradient descend] hkimscil | gradient_descent [2025/08/21 16:24] (current) – [Gradient descend] hkimscil | ||
---|---|---|---|
Line 1: | Line 1: | ||
====== Gradient Descent ====== | ====== Gradient Descent ====== | ||
- | ====== explanation ====== | ||
- | |||
- | ====== Why normalize (scale or make z-score) xi ====== | ||
- | x 변인의 측정단위로 인해서 b 값이 결정되게 되는데 이 때의 b값은 상당하고 다양한 범위를 가질 수 있다. 가령 월 수입이 (인컴) X 라고 한다면 우리가 추정해야 (추적해야) 할 b값은 수백만이 될 수도 있다. 이 값을 gradient로 추적하게 된다면 너무도 많은 iteration을 거쳐야 할 수 있다. 변인이 바뀌면 이 b의 추적범위도 드라마틱하게 바뀌게 된다. 이를 표준화한 x 점수를 사용하게 된다면 일정한 learning rate와 iteration만으로도 정확한 a와 b를 추적할 수 있게 된다. | ||
- | |||
- | ====== How to unnormalize (unscale) a and b ====== | ||
- | \begin{eqnarray*} | ||
- | y & = & a + b * x \\ | ||
- | & & \text{we use z instead of x} \\ | ||
- | & & \text{and } \\ | ||
- | & & z = \frac{(x - \mu)}{\sigma} \\ | ||
- | & & \text{suppose that the result after calculation be } \\ | ||
- | y & = & k + m * z \\ | ||
- | & = & k + m * \frac{(x - \mu)}{\sigma} \\ | ||
- | & = & k + \frac{m * x}{\sigma} - \frac{m * \mu}{\sigma} \\ | ||
- | & = & k - \frac{m * \mu}{\sigma} + \frac{m * x}{\sigma} \\ | ||
- | & = & k - \frac{\mu}{\sigma} * m + \frac{m}{\sigma} * x \\ | ||
- | & & \text{therefore, } \\ | ||
- | a & = & k - \frac{\mu}{\sigma} * m \\ | ||
- | b & = & \frac{m}{\sigma} \\ | ||
- | \end{eqnarray*} | ||
- | |||
- | |||
====== R code: Idea ====== | ====== R code: Idea ====== | ||
< | < | ||
- | library(tidyverse) | + | |
- | library(data.table) | + | library(ggplot2) |
+ | library(ggpmisc) | ||
rm(list=ls()) | rm(list=ls()) | ||
# set.seed(191) | # set.seed(191) | ||
- | n <- 5 | + | nx <- 200 |
- | x <- rnorm(n, 5, 1.2) | + | mx <- 4.5 |
- | y <- 3.14 * x + rnorm(n,0,1) | + | sdx <- mx * 0.56 |
+ | x <- rnorm(nx, mx, sdx) | ||
+ | slp <- 12 | ||
+ | y <- | ||
- | # data <- data.frame(x, | + | data <- data.frame(x, |
- | data <- tibble(x = x, y = y) | + | |
- | mo <- lm(y~x) | + | mo <- lm(y ~ x, data = data) |
summary(mo) | summary(mo) | ||
+ | ggplot(data = data, aes(x = x, y = y)) + | ||
+ | geom_point() + | ||
+ | stat_poly_line() + | ||
+ | stat_poly_eq(use_label(c(" | ||
+ | theme_classic() | ||
# set.seed(191) | # set.seed(191) | ||
# Initialize random betas | # Initialize random betas | ||
Line 61: | Line 46: | ||
# we use sum of square of error which oftentimes become big | # we use sum of square of error which oftentimes become big | ||
- | mseloss | + | ssrloss |
residuals <- (y - predictions) | residuals <- (y - predictions) | ||
return(sum(residuals^2)) | return(sum(residuals^2)) | ||
} | } | ||
- | mses <- c() | + | ssrs <- c() # for sum of square residuals |
- | j <- 0 | + | srs <- c() # sum of residuals |
- | as <- c() | + | as <- c() # for as (intercepts) |
- | for (i in seq(from = -30, to = 30, by = 0.01)) { | + | for (i in seq(from = -50, to = 50, by = 0.01)) { |
pred <- predict(x, i, b) | pred <- predict(x, i, b) | ||
res <- residuals(pred, | res <- residuals(pred, | ||
- | | + | |
- | | + | |
- | as <- append(as, | + | srs <- append(srs, sum(res)) |
+ | as <- append(as, i) | ||
} | } | ||
- | mses | + | length(ssrs) |
- | as | + | length(srs) |
+ | length(as) | ||
- | min(mses) | + | min(ssrs) |
- | min.pos.mses <- which(mses == min(mses)) | + | min.pos.ssrs <- which(ssrs == min(ssrs)) |
- | print(as[min.pos.mses]) | + | min.pos.ssrs |
+ | print(as[min.pos.ssrs]) | ||
summary(mo) | summary(mo) | ||
- | plot(seq(1, length(mses)), mses) | + | plot(seq(1, length(ssrs)), ssrs) |
+ | plot(seq(1, length(ssrs)), | ||
+ | tail(ssrs) | ||
+ | max(ssrs) | ||
+ | min(ssrs) | ||
+ | tail(srs) | ||
+ | max(srs) | ||
+ | min(srs) | ||
</ | </ | ||
===== output ===== | ===== output ===== | ||
- | < | + | <code> |
+ | > library(ggplot2) | ||
+ | > library(ggpmisc) | ||
+ | > | ||
> rm(list=ls()) | > rm(list=ls()) | ||
> # set.seed(191) | > # set.seed(191) | ||
- | > n <- 5 | + | > nx <- 200 |
- | > x <- rnorm(n, 5, 1.2) | + | > mx <- 4.5 |
- | > y <- 3.14 * x + rnorm(n,0,1) | + | > sdx <- mx * 0.56 |
+ | > x <- rnorm(nx, mx, sdx) | ||
+ | > slp <- 12 | ||
+ | > y <- | ||
> | > | ||
- | > # data <- data.frame(x, | + | > data <- data.frame(x, |
- | > data <- tibble(x = x, y = y) | + | |
> | > | ||
- | > mo <- lm(y~x) | + | > mo <- lm(y ~ x, data = data) |
> summary(mo) | > summary(mo) | ||
Call: | Call: | ||
- | lm(formula = y ~ x) | + | lm(formula = y ~ x, data = data) |
Residuals: | Residuals: | ||
- | | + | |
- | -1.7472 1.7379 1.1598 -0.7010 -0.4496 | + | -259.314 -59.215 6.683 58.834 309.833 |
Coefficients: | Coefficients: | ||
- | Estimate Std. Error t value Pr(> | + | Estimate Std. Error t value Pr(> |
- | (Intercept) | + | (Intercept) |
- | x 3.4510 0.8899 | + | x 11.888 2.433 |
--- | --- | ||
Signif. codes: | Signif. codes: | ||
- | Residual standard error: | + | Residual standard error: |
- | Multiple R-squared: | + | Multiple R-squared: |
- | F-statistic: | + | F-statistic: |
> | > | ||
+ | > ggplot(data = data, aes(x = x, y = y)) + | ||
+ | + | ||
+ | + | ||
+ | + | ||
+ | + | ||
> # set.seed(191) | > # set.seed(191) | ||
> # Initialize random betas | > # Initialize random betas | ||
Line 141: | Line 146: | ||
> | > | ||
> # we use sum of square of error which oftentimes become big | > # we use sum of square of error which oftentimes become big | ||
- | > mseloss | + | > ssrloss |
+ | + | ||
+ | + | ||
+ } | + } | ||
> | > | ||
- | > mses <- c() | + | > ssrs <- c() # for sum of square residuals |
- | > j <- 0 | + | > srs <- c() # sum of residuals |
- | > as <- c() | + | > as <- c() # for as (intercepts) |
> | > | ||
- | > for (i in seq(from = -30, to = 30, by = 0.01)) { | + | > for (i in seq(from = -50, to = 50, by = 0.01)) { |
+ pred <- predict(x, i, b) | + pred <- predict(x, i, b) | ||
+ res <- residuals(pred, | + res <- residuals(pred, | ||
- | + mse <- mseloss(pred, y) | + | + ssr <- ssrloss(pred, y) |
- | + mses <- append(mses, mse) | + | + ssrs <- append(ssrs, ssr) |
- | + as <- append(as, | + | + srs <- append(srs, sum(res)) |
+ | + as <- append(as, i) | ||
+ } | + } | ||
- | > mses | + | > length(ssrs) |
- | | + | [1] 10001 |
- | [9] 3824.561 3821.799 3819.037 3816.277 3813.518 3810.760 3808.002 3805.246 | + | > length(srs) |
- | [17] 3802.491 3799.737 3796.983 3794.231 3791.480 3788.729 3785.980 3783.232 | + | [1] 10001 |
- | [25] 3780.485 3777.738 3774.993 3772.249 3769.506 3766.763 3764.022 3761.282 | + | > length(as) |
- | [33] 3758.542 3755.804 3753.067 3750.331 3747.595 3744.861 3742.128 3739.396 | + | [1] 10001 |
- | [41] 3736.664 3733.934 3731.205 3728.477 3725.749 3723.023 3720.298 3717.573 | + | |
- | [49] 3714.850 3712.128 3709.407 3706.686 3703.967 3701.249 3698.532 3695.815 | + | |
- | [57] 3693.100 3690.386 3687.672 3684.960 3682.249 3679.539 3676.829 3674.121 | + | |
- | [65] 3671.414 3668.708 3666.002 3663.298 3660.595 3657.892 3655.191 3652.491 | + | |
- | [73] 3649.792 3647.093 3644.396 3641.700 3639.005 3636.310 3633.617 3630.925 | + | |
- | [81] 3628.234 3625.543 3622.854 3620.166 3617.478 3614.792 3612.107 3609.423 | + | |
- | [89] 3606.739 3604.057 3601.376 3598.696 3596.016 3593.338 3590.661 3587.984 | + | |
- | [97] 3585.309 3582.635 3579.962 3577.289 3574.618 3571.948 3569.279 3566.610 | + | |
- | [105] 3563.943 3561.277 3558.611 3555.947 3553.284 3550.622 3547.960 3545.300 | + | |
- | [113] 3542.641 3539.983 3537.325 3534.669 3532.014 3529.360 3526.706 3524.054 | + | |
- | [121] 3521.403 3518.752 3516.103 3513.455 3510.808 3508.161 3505.516 3502.872 | + | |
- | [129] 3500.229 3497.586 3494.945 3492.305 3489.665 3487.027 3484.390 3481.754 | + | |
- | [137] 3479.118 3476.484 3473.851 3471.219 3468.587 3465.957 3463.328 3460.700 | + | |
- | [145] 3458.072 3455.446 3452.821 3450.196 3447.573 3444.951 3442.330 3439.709 | + | |
- | [153] 3437.090 3434.472 3431.855 3429.238 3426.623 3424.009 3421.395 3418.783 | + | |
- | [161] 3416.172 3413.562 3410.952 3408.344 3405.737 3403.131 3400.525 3397.921 | + | |
- | [169] 3395.318 3392.715 3390.114 3387.514 3384.915 3382.316 3379.719 3377.123 | + | |
- | [177] 3374.528 3371.933 3369.340 3366.748 3364.157 3361.566 3358.977 3356.389 | + | |
- | [185] 3353.801 3351.215 3348.630 3346.046 3343.462 3340.880 3338.299 3335.719 | + | |
- | [193] 3333.139 3330.561 3327.984 3325.407 3322.832 3320.258 3317.685 3315.112 | + | |
- | [201] 3312.541 3309.971 3307.402 3304.833 3302.266 3299.700 3297.134 3294.570 | + | |
- | [209] 3292.007 3289.445 3286.883 3284.323 3281.764 3279.206 3276.648 3274.092 | + | |
- | [217] 3271.537 3268.983 3266.429 3263.877 3261.326 3258.775 3256.226 3253.678 | + | |
- | [225] 3251.131 3248.584 3246.039 3243.495 3240.952 3238.409 3235.868 3233.328 | + | |
- | [233] 3230.788 3228.250 3225.713 3223.177 3220.641 3218.107 3215.574 3213.042 | + | |
- | [241] 3210.510 3207.980 3205.451 3202.923 3200.395 3197.869 3195.344 3192.819 | + | |
- | [249] 3190.296 3187.774 3185.253 3182.732 3180.213 3177.695 3175.178 3172.661 | + | |
- | [257] 3170.146 3167.632 3165.118 3162.606 3160.095 3157.585 3155.075 3152.567 | + | |
- | [265] 3150.060 3147.554 3145.048 3142.544 3140.041 3137.538 3135.037 3132.537 | + | |
- | [273] 3130.038 3127.539 3125.042 3122.546 3120.051 3117.556 3115.063 3112.571 | + | |
- | [281] 3110.080 3107.589 3105.100 3102.612 3100.124 3097.638 3095.153 3092.669 | + | |
- | [289] 3090.185 3087.703 3085.222 3082.742 3080.262 3077.784 3075.307 3072.830 | + | |
- | [297] 3070.355 3067.881 3065.408 3062.935 3060.464 3057.994 3055.525 3053.056 | + | |
- | [305] 3050.589 3048.123 3045.657 3043.193 3040.730 3038.268 3035.806 3033.346 | + | |
- | [313] 3030.887 3028.429 3025.971 3023.515 3021.060 3018.606 3016.152 3013.700 | + | |
- | [321] 3011.249 3008.798 3006.349 3003.901 3001.454 2999.007 2996.562 2994.118 | + | |
- | [329] 2991.675 2989.232 2986.791 2984.351 2981.911 2979.473 2977.036 2974.600 | + | |
- | [337] 2972.164 2969.730 2967.297 2964.865 2962.433 2960.003 2957.574 2955.145 | + | |
- | [345] 2952.718 2950.292 2947.867 2945.442 2943.019 2940.597 2938.176 2935.755 | + | |
- | [353] 2933.336 2930.918 2928.501 2926.084 2923.669 2921.255 2918.841 2916.429 | + | |
- | [361] 2914.018 2911.608 2909.198 2906.790 2904.383 2901.977 2899.571 2897.167 | + | |
- | [369] 2894.764 2892.361 2889.960 2887.560 2885.161 2882.762 2880.365 2877.969 | + | |
- | [377] 2875.574 2873.179 2870.786 2868.394 2866.003 2863.612 2861.223 2858.835 | + | |
- | [385] 2856.447 2854.061 2851.676 2849.292 2846.908 2844.526 2842.145 2839.765 | + | |
- | [393] 2837.385 2835.007 2832.630 2830.253 2827.878 2825.504 2823.131 2820.758 | + | |
- | [401] 2818.387 2816.017 2813.648 2811.279 2808.912 2806.546 2804.180 2801.816 | + | |
- | [409] 2799.453 2797.091 2794.729 2792.369 2790.010 2787.652 2785.294 2782.938 | + | |
- | [417] 2780.583 2778.229 2775.875 2773.523 2771.172 2768.821 2766.472 2764.124 | + | |
- | [425] 2761.777 2759.430 2757.085 2754.741 2752.398 2750.055 2747.714 2745.374 | + | |
- | [433] 2743.034 2740.696 2738.359 2736.023 2733.687 2731.353 2729.020 2726.688 | + | |
- | [441] 2724.356 2722.026 2719.697 2717.368 2715.041 2712.715 2710.390 2708.065 | + | |
- | [449] 2705.742 2703.420 2701.099 2698.778 2696.459 2694.141 2691.824 2689.507 | + | |
- | [457] 2687.192 2684.878 2682.564 2680.252 2677.941 2675.631 2673.321 2671.013 | + | |
- | [465] 2668.706 2666.400 2664.094 2661.790 2659.487 2657.184 2654.883 2652.583 | + | |
- | [473] 2650.284 2647.985 2645.688 2643.392 2641.097 2638.802 2636.509 2634.217 | + | |
- | [481] 2631.926 2629.635 2627.346 2625.058 2622.770 2620.484 2618.199 2615.915 | + | |
- | [489] 2613.631 2611.349 2609.068 2606.788 2604.508 2602.230 2599.953 2597.676 | + | |
- | [497] 2595.401 2593.127 2590.854 2588.581 2586.310 2584.040 2581.771 2579.502 | + | |
- | [505] 2577.235 2574.969 2572.703 2570.439 2568.176 2565.914 2563.652 2561.392 | + | |
- | [513] 2559.133 2556.875 2554.617 2552.361 2550.106 2547.852 2545.598 2543.346 | + | |
- | [521] 2541.095 2538.844 2536.595 2534.347 2532.100 2529.853 2527.608 2525.364 | + | |
- | [529] 2523.121 2520.878 2518.637 2516.397 2514.157 2511.919 2509.682 2507.446 | + | |
- | [537] 2505.210 2502.976 2500.743 2498.511 2496.279 2494.049 2491.820 2489.591 | + | |
- | [545] 2487.364 2485.138 2482.913 2480.688 2478.465 2476.243 2474.022 2471.801 | + | |
- | [553] 2469.582 2467.364 2465.147 2462.930 2460.715 2458.501 2456.287 2454.075 | + | |
- | [561] 2451.864 2449.654 2447.444 2445.236 2443.029 2440.823 2438.617 2436.413 | + | |
- | [569] 2434.210 2432.007 2429.806 2427.606 2425.407 2423.208 2421.011 2418.815 | + | |
- | [577] 2416.620 2414.425 2412.232 2410.040 2407.849 2405.658 2403.469 2401.281 | + | |
- | [585] 2399.093 2396.907 2394.722 2392.538 2390.354 2388.172 2385.991 2383.811 | + | |
- | [593] 2381.631 2379.453 2377.276 2375.099 2372.924 2370.750 2368.577 2366.404 | + | |
- | [601] 2364.233 2362.063 2359.894 2357.725 2355.558 2353.392 2351.226 2349.062 | + | |
- | [609] 2346.899 2344.737 2342.575 2340.415 2338.256 2336.098 2333.940 2331.784 | + | |
- | [617] 2329.629 2327.475 2325.321 2323.169 2321.018 2318.867 2316.718 2314.570 | + | |
- | [625] 2312.423 2310.276 2308.131 2305.987 2303.844 2301.701 2299.560 2297.420 | + | |
- | [633] 2295.280 2293.142 2291.005 2288.869 2286.733 2284.599 2282.466 2280.334 | + | |
- | [641] 2278.202 2276.072 2273.943 2271.814 2269.687 2267.561 2265.436 2263.311 | + | |
- | [649] 2261.188 2259.066 2256.945 2254.824 2252.705 2250.587 2248.470 2246.353 | + | |
- | [657] 2244.238 2242.124 2240.010 2237.898 2235.787 2233.677 2231.567 2229.459 | + | |
- | [665] 2227.352 2225.246 2223.140 2221.036 2218.933 2216.830 2214.729 2212.629 | + | |
- | [673] 2210.530 2208.431 2206.334 2204.238 2202.143 2200.048 2197.955 2195.863 | + | |
- | [681] 2193.771 2191.681 2189.592 2187.504 2185.416 2183.330 2181.245 2179.161 | + | |
- | [689] 2177.077 2174.995 2172.914 2170.834 2168.754 2166.676 2164.599 2162.522 | + | |
- | [697] 2160.447 2158.373 2156.300 2154.227 2152.156 2150.086 2148.017 2145.948 | + | |
- | [705] 2143.881 2141.815 2139.749 2137.685 2135.622 2133.560 2131.498 2129.438 | + | |
- | [713] 2127.379 2125.321 2123.263 2121.207 2119.152 2117.098 2115.044 2112.992 | + | |
- | [721] 2110.941 2108.890 2106.841 2104.793 2102.746 2100.699 2098.654 2096.610 | + | |
- | [729] 2094.567 2092.524 2090.483 2088.443 2086.403 2084.365 2082.328 2080.292 | + | |
- | [737] 2078.256 2076.222 2074.189 2072.157 2070.125 2068.095 2066.066 2064.037 | + | |
- | [745] 2062.010 2059.984 2057.959 2055.934 2053.911 2051.889 2049.868 2047.847 | + | |
- | [753] 2045.828 2043.810 2041.793 2039.776 2037.761 2035.747 2033.733 2031.721 | + | |
- | [761] 2029.710 2027.700 2025.690 2023.682 2021.675 2019.669 2017.663 2015.659 | + | |
- | [769] 2013.656 2011.653 2009.652 2007.652 2005.653 2003.654 2001.657 1999.661 | + | |
- | [777] 1997.666 1995.671 1993.678 1991.686 1989.694 1987.704 1985.715 1983.727 | + | |
- | [785] 1981.739 1979.753 1977.768 1975.784 1973.800 1971.818 1969.837 1967.857 | + | |
- | [793] 1965.877 1963.899 1961.922 1959.945 1957.970 1955.996 1954.023 1952.050 | + | |
- | [801] 1950.079 1948.109 1946.140 1944.171 1942.204 1940.238 1938.272 1936.308 | + | |
- | [809] 1934.345 1932.383 1930.421 1928.461 1926.502 1924.544 1922.586 1920.630 | + | |
- | [817] 1918.675 1916.721 1914.767 1912.815 1910.864 1908.913 1906.964 1905.016 | + | |
- | [825] 1903.069 1901.122 1899.177 1897.233 1895.290 1893.347 1891.406 1889.466 | + | |
- | [833] 1887.526 1885.588 1883.651 1881.715 1879.779 1877.845 1875.912 1873.980 | + | |
- | [841] 1872.048 1870.118 1868.189 1866.260 1864.333 1862.407 1860.482 1858.557 | + | |
- | [849] 1856.634 1854.712 1852.791 1850.870 1848.951 1847.033 1845.116 1843.199 | + | |
- | [857] 1841.284 1839.370 1837.456 1835.544 1833.633 1831.723 1829.813 1827.905 | + | |
- | [865] 1825.998 1824.092 1822.186 1820.282 1818.379 1816.476 1814.575 1812.675 | + | |
- | [873] 1810.776 1808.877 1806.980 1805.084 1803.189 1801.294 1799.401 1797.509 | + | |
- | [881] 1795.617 1793.727 1791.838 1789.950 1788.062 1786.176 1784.291 1782.407 | + | |
- | [889] 1780.523 1778.641 1776.760 1774.880 1773.000 1771.122 1769.245 1767.368 | + | |
- | [897] 1765.493 1763.619 1761.746 1759.873 1758.002 1756.132 1754.263 1752.394 | + | |
- | [905] 1750.527 1748.661 1746.795 1744.931 1743.068 1741.206 1739.344 1737.484 | + | |
- | [913] 1735.625 1733.767 1731.909 1730.053 1728.198 1726.344 1724.490 1722.638 | + | |
- | [921] 1720.787 1718.936 1717.087 1715.239 1713.392 1711.545 1709.700 1707.856 | + | |
- | [929] 1706.013 1704.170 1702.329 1700.489 1698.649 1696.811 1694.974 1693.138 | + | |
- | [937] 1691.302 1689.468 1687.635 1685.803 1683.971 1682.141 1680.312 1678.483 | + | |
- | [945] 1676.656 1674.830 1673.005 1671.180 1669.357 1667.535 1665.714 1663.893 | + | |
- | [953] 1662.074 1660.256 1658.439 1656.622 1654.807 1652.993 1651.179 1649.367 | + | |
- | [961] 1647.556 1645.746 1643.936 1642.128 1640.321 1638.515 1636.709 1634.905 | + | |
- | [969] 1633.102 1631.299 1629.498 1627.698 1625.899 1624.100 1622.303 1620.507 | + | |
- | [977] 1618.712 1616.917 1615.124 1613.332 1611.540 1609.750 1607.961 1606.173 | + | |
- | [985] 1604.385 1602.599 1600.814 1599.030 1597.246 1595.464 1593.683 1591.903 | + | |
- | [993] 1590.123 1588.345 1586.568 1584.791 1583.016 1581.242 1579.469 1577.696 | + | |
- | [ reached getOption(" | + | |
- | > as | + | |
- | [1] -30.00 -29.99 -29.98 -29.97 -29.96 -29.95 -29.94 -29.93 -29.92 -29.91 -29.90 | + | |
- | [12] -29.89 -29.88 -29.87 -29.86 -29.85 -29.84 -29.83 -29.82 -29.81 -29.80 -29.79 | + | |
- | [23] -29.78 -29.77 -29.76 -29.75 -29.74 -29.73 -29.72 -29.71 -29.70 -29.69 -29.68 | + | |
- | [34] -29.67 -29.66 -29.65 -29.64 -29.63 -29.62 -29.61 -29.60 -29.59 -29.58 -29.57 | + | |
- | [45] -29.56 -29.55 -29.54 -29.53 -29.52 -29.51 -29.50 -29.49 -29.48 -29.47 -29.46 | + | |
- | [56] -29.45 -29.44 -29.43 -29.42 -29.41 -29.40 -29.39 -29.38 -29.37 -29.36 -29.35 | + | |
- | [67] -29.34 -29.33 -29.32 -29.31 -29.30 -29.29 -29.28 -29.27 -29.26 -29.25 -29.24 | + | |
- | [78] -29.23 -29.22 -29.21 -29.20 -29.19 -29.18 -29.17 -29.16 -29.15 -29.14 -29.13 | + | |
- | [89] -29.12 -29.11 -29.10 -29.09 -29.08 -29.07 -29.06 -29.05 -29.04 -29.03 -29.02 | + | |
- | [100] -29.01 -29.00 -28.99 -28.98 -28.97 -28.96 -28.95 -28.94 -28.93 -28.92 -28.91 | + | |
- | [111] -28.90 -28.89 -28.88 -28.87 -28.86 -28.85 -28.84 -28.83 -28.82 -28.81 -28.80 | + | |
- | [122] -28.79 -28.78 -28.77 -28.76 -28.75 -28.74 -28.73 -28.72 -28.71 -28.70 -28.69 | + | |
- | [133] -28.68 -28.67 -28.66 -28.65 -28.64 -28.63 -28.62 -28.61 -28.60 -28.59 -28.58 | + | |
- | [144] -28.57 -28.56 -28.55 -28.54 -28.53 -28.52 -28.51 -28.50 -28.49 -28.48 -28.47 | + | |
- | [155] -28.46 -28.45 -28.44 -28.43 -28.42 -28.41 -28.40 -28.39 -28.38 -28.37 -28.36 | + | |
- | [166] -28.35 -28.34 -28.33 -28.32 -28.31 -28.30 -28.29 -28.28 -28.27 -28.26 -28.25 | + | |
- | [177] -28.24 -28.23 -28.22 -28.21 -28.20 -28.19 -28.18 -28.17 -28.16 -28.15 -28.14 | + | |
- | [188] -28.13 -28.12 -28.11 -28.10 -28.09 -28.08 -28.07 -28.06 -28.05 -28.04 -28.03 | + | |
- | [199] -28.02 -28.01 -28.00 -27.99 -27.98 -27.97 -27.96 -27.95 -27.94 -27.93 -27.92 | + | |
- | [210] -27.91 -27.90 -27.89 -27.88 -27.87 -27.86 -27.85 -27.84 -27.83 -27.82 -27.81 | + | |
- | [221] -27.80 -27.79 -27.78 -27.77 -27.76 -27.75 -27.74 -27.73 -27.72 -27.71 -27.70 | + | |
- | [232] -27.69 -27.68 -27.67 -27.66 -27.65 -27.64 -27.63 -27.62 -27.61 -27.60 -27.59 | + | |
- | [243] -27.58 -27.57 -27.56 -27.55 -27.54 -27.53 -27.52 -27.51 -27.50 -27.49 -27.48 | + | |
- | [254] -27.47 -27.46 -27.45 -27.44 -27.43 -27.42 -27.41 -27.40 -27.39 -27.38 -27.37 | + | |
- | [265] -27.36 -27.35 -27.34 -27.33 -27.32 -27.31 -27.30 -27.29 -27.28 -27.27 -27.26 | + | |
- | [276] -27.25 -27.24 -27.23 -27.22 -27.21 -27.20 -27.19 -27.18 -27.17 -27.16 -27.15 | + | |
- | [287] -27.14 -27.13 -27.12 -27.11 -27.10 -27.09 -27.08 -27.07 -27.06 -27.05 -27.04 | + | |
- | [298] -27.03 -27.02 -27.01 -27.00 -26.99 -26.98 -26.97 -26.96 -26.95 -26.94 -26.93 | + | |
- | [309] -26.92 -26.91 -26.90 -26.89 -26.88 -26.87 -26.86 -26.85 -26.84 -26.83 -26.82 | + | |
- | [320] -26.81 -26.80 -26.79 -26.78 -26.77 -26.76 -26.75 -26.74 -26.73 -26.72 -26.71 | + | |
- | [331] -26.70 -26.69 -26.68 -26.67 -26.66 -26.65 -26.64 -26.63 -26.62 -26.61 -26.60 | + | |
- | [342] -26.59 -26.58 -26.57 -26.56 -26.55 -26.54 -26.53 -26.52 -26.51 -26.50 -26.49 | + | |
- | [353] -26.48 -26.47 -26.46 -26.45 -26.44 -26.43 -26.42 -26.41 -26.40 -26.39 -26.38 | + | |
- | [364] -26.37 -26.36 -26.35 -26.34 -26.33 -26.32 -26.31 -26.30 -26.29 -26.28 -26.27 | + | |
- | [375] -26.26 -26.25 -26.24 -26.23 -26.22 -26.21 -26.20 -26.19 -26.18 -26.17 -26.16 | + | |
- | [386] -26.15 -26.14 -26.13 -26.12 -26.11 -26.10 -26.09 -26.08 -26.07 -26.06 -26.05 | + | |
- | [397] -26.04 -26.03 -26.02 -26.01 -26.00 -25.99 -25.98 -25.97 -25.96 -25.95 -25.94 | + | |
- | [408] -25.93 -25.92 -25.91 -25.90 -25.89 -25.88 -25.87 -25.86 -25.85 -25.84 -25.83 | + | |
- | [419] -25.82 -25.81 -25.80 -25.79 -25.78 -25.77 -25.76 -25.75 -25.74 -25.73 -25.72 | + | |
- | [430] -25.71 -25.70 -25.69 -25.68 -25.67 -25.66 -25.65 -25.64 -25.63 -25.62 -25.61 | + | |
- | [441] -25.60 -25.59 -25.58 -25.57 -25.56 -25.55 -25.54 -25.53 -25.52 -25.51 -25.50 | + | |
- | [452] -25.49 -25.48 -25.47 -25.46 -25.45 -25.44 -25.43 -25.42 -25.41 -25.40 -25.39 | + | |
- | [463] -25.38 -25.37 -25.36 -25.35 -25.34 -25.33 -25.32 -25.31 -25.30 -25.29 -25.28 | + | |
- | [474] -25.27 -25.26 -25.25 -25.24 -25.23 -25.22 -25.21 -25.20 -25.19 -25.18 -25.17 | + | |
- | [485] -25.16 -25.15 -25.14 -25.13 -25.12 -25.11 -25.10 -25.09 -25.08 -25.07 -25.06 | + | |
- | [496] -25.05 -25.04 -25.03 -25.02 -25.01 -25.00 -24.99 -24.98 -24.97 -24.96 -24.95 | + | |
- | [507] -24.94 -24.93 -24.92 -24.91 -24.90 -24.89 -24.88 -24.87 -24.86 -24.85 -24.84 | + | |
- | [518] -24.83 -24.82 -24.81 -24.80 -24.79 -24.78 -24.77 -24.76 -24.75 -24.74 -24.73 | + | |
- | [529] -24.72 -24.71 -24.70 -24.69 -24.68 -24.67 -24.66 -24.65 -24.64 -24.63 -24.62 | + | |
- | [540] -24.61 -24.60 -24.59 -24.58 -24.57 -24.56 -24.55 -24.54 -24.53 -24.52 -24.51 | + | |
- | [551] -24.50 -24.49 -24.48 -24.47 -24.46 -24.45 -24.44 -24.43 -24.42 -24.41 -24.40 | + | |
- | [562] -24.39 -24.38 -24.37 -24.36 -24.35 -24.34 -24.33 -24.32 -24.31 -24.30 -24.29 | + | |
- | [573] -24.28 -24.27 -24.26 -24.25 -24.24 -24.23 -24.22 -24.21 -24.20 -24.19 -24.18 | + | |
- | [584] -24.17 -24.16 -24.15 -24.14 -24.13 -24.12 -24.11 -24.10 -24.09 -24.08 -24.07 | + | |
- | [595] -24.06 -24.05 -24.04 -24.03 -24.02 -24.01 -24.00 -23.99 -23.98 -23.97 -23.96 | + | |
- | [606] -23.95 -23.94 -23.93 -23.92 -23.91 -23.90 -23.89 -23.88 -23.87 -23.86 -23.85 | + | |
- | [617] -23.84 -23.83 -23.82 -23.81 -23.80 -23.79 -23.78 -23.77 -23.76 -23.75 -23.74 | + | |
- | [628] -23.73 -23.72 -23.71 -23.70 -23.69 -23.68 -23.67 -23.66 -23.65 -23.64 -23.63 | + | |
- | [639] -23.62 -23.61 -23.60 -23.59 -23.58 -23.57 -23.56 -23.55 -23.54 -23.53 -23.52 | + | |
- | [650] -23.51 -23.50 -23.49 -23.48 -23.47 -23.46 -23.45 -23.44 -23.43 -23.42 -23.41 | + | |
- | [661] -23.40 -23.39 -23.38 -23.37 -23.36 -23.35 -23.34 -23.33 -23.32 -23.31 -23.30 | + | |
- | [672] -23.29 -23.28 -23.27 -23.26 -23.25 -23.24 -23.23 -23.22 -23.21 -23.20 -23.19 | + | |
- | [683] -23.18 -23.17 -23.16 -23.15 -23.14 -23.13 -23.12 -23.11 -23.10 -23.09 -23.08 | + | |
- | [694] -23.07 -23.06 -23.05 -23.04 -23.03 -23.02 -23.01 -23.00 -22.99 -22.98 -22.97 | + | |
- | [705] -22.96 -22.95 -22.94 -22.93 -22.92 -22.91 -22.90 -22.89 -22.88 -22.87 -22.86 | + | |
- | [716] -22.85 -22.84 -22.83 -22.82 -22.81 -22.80 -22.79 -22.78 -22.77 -22.76 -22.75 | + | |
- | [727] -22.74 -22.73 -22.72 -22.71 -22.70 -22.69 -22.68 -22.67 -22.66 -22.65 -22.64 | + | |
- | [738] -22.63 -22.62 -22.61 -22.60 -22.59 -22.58 -22.57 -22.56 -22.55 -22.54 -22.53 | + | |
- | [749] -22.52 -22.51 -22.50 -22.49 -22.48 -22.47 -22.46 -22.45 -22.44 -22.43 -22.42 | + | |
- | [760] -22.41 -22.40 -22.39 -22.38 -22.37 -22.36 -22.35 -22.34 -22.33 -22.32 -22.31 | + | |
- | [771] -22.30 -22.29 -22.28 -22.27 -22.26 -22.25 -22.24 -22.23 -22.22 -22.21 -22.20 | + | |
- | [782] -22.19 -22.18 -22.17 -22.16 -22.15 -22.14 -22.13 -22.12 -22.11 -22.10 -22.09 | + | |
- | [793] -22.08 -22.07 -22.06 -22.05 -22.04 -22.03 -22.02 -22.01 -22.00 -21.99 -21.98 | + | |
- | [804] -21.97 -21.96 -21.95 -21.94 -21.93 -21.92 -21.91 -21.90 -21.89 -21.88 -21.87 | + | |
- | [815] -21.86 -21.85 -21.84 -21.83 -21.82 -21.81 -21.80 -21.79 -21.78 -21.77 -21.76 | + | |
- | [826] -21.75 -21.74 -21.73 -21.72 -21.71 -21.70 -21.69 -21.68 -21.67 -21.66 -21.65 | + | |
- | [837] -21.64 -21.63 -21.62 -21.61 -21.60 -21.59 -21.58 -21.57 -21.56 -21.55 -21.54 | + | |
- | [848] -21.53 -21.52 -21.51 -21.50 -21.49 -21.48 -21.47 -21.46 -21.45 -21.44 -21.43 | + | |
- | [859] -21.42 -21.41 -21.40 -21.39 -21.38 -21.37 -21.36 -21.35 -21.34 -21.33 -21.32 | + | |
- | [870] -21.31 -21.30 -21.29 -21.28 -21.27 -21.26 -21.25 -21.24 -21.23 -21.22 -21.21 | + | |
- | [881] -21.20 -21.19 -21.18 -21.17 -21.16 -21.15 -21.14 -21.13 -21.12 -21.11 -21.10 | + | |
- | [892] -21.09 -21.08 -21.07 -21.06 -21.05 -21.04 -21.03 -21.02 -21.01 -21.00 -20.99 | + | |
- | [903] -20.98 -20.97 -20.96 -20.95 -20.94 -20.93 -20.92 -20.91 -20.90 -20.89 -20.88 | + | |
- | [914] -20.87 -20.86 -20.85 -20.84 -20.83 -20.82 -20.81 -20.80 -20.79 -20.78 -20.77 | + | |
- | [925] -20.76 -20.75 -20.74 -20.73 -20.72 -20.71 -20.70 -20.69 -20.68 -20.67 -20.66 | + | |
- | [936] -20.65 -20.64 -20.63 -20.62 -20.61 -20.60 -20.59 -20.58 -20.57 -20.56 -20.55 | + | |
- | [947] -20.54 -20.53 -20.52 -20.51 -20.50 -20.49 -20.48 -20.47 -20.46 -20.45 -20.44 | + | |
- | [958] -20.43 -20.42 -20.41 -20.40 -20.39 -20.38 -20.37 -20.36 -20.35 -20.34 -20.33 | + | |
- | [969] -20.32 -20.31 -20.30 -20.29 -20.28 -20.27 -20.26 -20.25 -20.24 -20.23 -20.22 | + | |
- | [980] -20.21 -20.20 -20.19 -20.18 -20.17 -20.16 -20.15 -20.14 -20.13 -20.12 -20.11 | + | |
- | [991] -20.10 -20.09 -20.08 -20.07 -20.06 -20.05 -20.04 -20.03 -20.02 -20.01 | + | |
- | [ reached getOption(" | + | |
> | > | ||
- | > min(mses) | + | > min(ssrs) |
- | [1] 8.111863 | + | [1] 1553336 |
- | > min.pos.mses <- which(mses == min(mses)) | + | > min.pos.ssrs <- which(ssrs == min(ssrs)) |
- | > print(as[min.pos.mses]) | + | > min.pos.ssrs |
- | [1] -2.29 | + | [1] 5828 |
+ | > print(as[min.pos.ssrs]) | ||
+ | [1] 8.27 | ||
> summary(mo) | > summary(mo) | ||
Call: | Call: | ||
- | lm(formula = y ~ x) | + | lm(formula = y ~ x, data = data) |
Residuals: | Residuals: | ||
- | | + | |
- | -1.7472 1.7379 1.1598 -0.7010 -0.4496 | + | -259.314 -59.215 6.683 58.834 309.833 |
Coefficients: | Coefficients: | ||
- | Estimate Std. Error t value Pr(> | + | Estimate Std. Error t value Pr(> |
- | (Intercept) | + | (Intercept) |
- | x 3.4510 0.8899 | + | x 11.888 2.433 |
--- | --- | ||
Signif. codes: | Signif. codes: | ||
- | Residual standard error: | + | Residual standard error: |
- | Multiple R-squared: | + | Multiple R-squared: |
- | F-statistic: | + | F-statistic: |
- | > plot(seq(1, length(mses)), mses) | + | > plot(seq(1, length(ssrs)), ssrs) |
+ | > plot(seq(1, length(ssrs)), | ||
+ | > tail(ssrs) | ||
+ | [1] 1900842 1901008 1901175 1901342 1901509 1901676 | ||
+ | > max(ssrs) | ||
+ | [1] 2232329 | ||
+ | > min(ssrs) | ||
+ | [1] 1553336 | ||
+ | > tail(srs) | ||
+ | [1] -8336.735 -8338.735 -8340.735 -8342.735 -8344.735 -8346.735 | ||
+ | > max(srs) | ||
+ | [1] 11653.26 | ||
+ | > min(srs) | ||
+ | [1] -8346.735 | ||
> | > | ||
> | > | ||
</ | </ | ||
- | {{:pasted:20250804-152826.png}} | + | {{:pasted:20250821-120357.png}} |
+ | {{: | ||
+ | {{: | ||
위 방법은 dumb . . . . . | 위 방법은 dumb . . . . . | ||
- | 우선 a의 범위가 어디가 될지 몰라서 -10 에서 | + | 우선 a의 범위가 어디가 될지 몰라서 -50 에서 |
- | 이 두 지점 사이를 0.01 단위로 증가시킨 값을 a값이라고 할 때의 | + | 이 두 지점 사이를 0.01 단위로 증가시킨 값을 a값이라고 할 때의 |
- | 이렇게 구한 | + | 이렇게 구한 |
- | 이렇게 | + | 이렇게 |
- | gradient descent | + | |
+ | ===== SSE 대신에 MSE를 쓰기 ===== | ||
+ | < | ||
+ | |||
+ | ##### | ||
+ | # with mean square error (mse) instead of sse | ||
+ | |||
+ | b <- summary(mo)$coefficients[2] | ||
+ | a <- 0 | ||
+ | |||
+ | # we use sum of square of error which oftentimes become big | ||
+ | msrloss <- function(predictions, | ||
+ | residuals <- (y - predictions) | ||
+ | return(mean(residuals^2)) | ||
+ | } | ||
+ | |||
+ | msrs <- c() # for sum of square residuals | ||
+ | srs <- c() # sum of residuals | ||
+ | as <- c() # for as (intercepts) | ||
+ | |||
+ | for (i in seq(from = -50, to = 50, by = 0.01)) { | ||
+ | pred <- predict(x, i, b) | ||
+ | res <- residuals(pred, | ||
+ | msr <- msrloss(pred, | ||
+ | msrs <- append(msrs, | ||
+ | srs <- append(srs, mean(res)) | ||
+ | as <- append(as, i) | ||
+ | } | ||
+ | length(msrs) | ||
+ | length(srs) | ||
+ | length(as) | ||
+ | |||
+ | min(msrs) | ||
+ | min.pos.msrs <- which(msrs == min(msrs)) | ||
+ | min.pos.msrs | ||
+ | print(as[min.pos.msrs]) | ||
+ | summary(mo) | ||
+ | plot(seq(1, length(msrs)), | ||
+ | plot(seq(1, length(srs)), | ||
+ | tail(msrs) | ||
+ | max(msrs) | ||
+ | min(msrs) | ||
+ | tail(srs) | ||
+ | max(srs) | ||
+ | min(srs) | ||
+ | </ | ||
+ | ===== output ===== | ||
+ | < | ||
+ | > ##### | ||
+ | > # with mean square error (mse) instead of sse | ||
+ | > | ||
+ | > b <- summary(mo)$coefficients[2] | ||
+ | > a <- 0 | ||
+ | > | ||
+ | > # we use sum of square of error which oftentimes become big | ||
+ | > msrloss <- function(predictions, | ||
+ | + | ||
+ | + | ||
+ | + } | ||
+ | > | ||
+ | > msrs <- c() # for sum of square residuals | ||
+ | > srs <- c() # sum of residuals | ||
+ | > as <- c() # for as (intercepts) | ||
+ | > | ||
+ | > for (i in seq(from = -50, to = 50, by = 0.01)) { | ||
+ | + pred <- predict(x, i, b) | ||
+ | + res <- residuals(pred, | ||
+ | + msr <- msrloss(pred, | ||
+ | + msrs <- append(msrs, | ||
+ | + srs <- append(srs, mean(res)) | ||
+ | + as <- append(as, i) | ||
+ | + } | ||
+ | > length(msrs) | ||
+ | [1] 10001 | ||
+ | > length(srs) | ||
+ | [1] 10001 | ||
+ | > length(as) | ||
+ | [1] 10001 | ||
+ | > | ||
+ | > min(msrs) | ||
+ | [1] 7766.679 | ||
+ | > min.pos.msrs <- which(msrs == min(msrs)) | ||
+ | > min.pos.msrs | ||
+ | [1] 5828 | ||
+ | > print(as[min.pos.msrs]) | ||
+ | [1] 8.27 | ||
+ | > summary(mo) | ||
+ | |||
+ | Call: | ||
+ | lm(formula = y ~ x, data = data) | ||
+ | |||
+ | Residuals: | ||
+ | | ||
+ | -259.314 | ||
+ | |||
+ | Coefficients: | ||
+ | Estimate Std. Error t value Pr(> | ||
+ | (Intercept) | ||
+ | x | ||
+ | --- | ||
+ | Signif. codes: | ||
+ | |||
+ | Residual standard error: 88.57 on 198 degrees of freedom | ||
+ | Multiple R-squared: | ||
+ | F-statistic: | ||
+ | |||
+ | > plot(seq(1, length(msrs)), | ||
+ | > plot(seq(1, length(srs)), | ||
+ | > tail(msrs) | ||
+ | [1] 9504.208 9505.041 9505.875 9506.710 9507.544 9508.379 | ||
+ | > max(msrs) | ||
+ | [1] 11161.64 | ||
+ | > min(msrs) | ||
+ | [1] 7766.679 | ||
+ | > tail(srs) | ||
+ | [1] -41.68368 -41.69368 -41.70368 -41.71368 -41.72368 -41.73368 | ||
+ | > max(srs) | ||
+ | [1] 58.26632 | ||
+ | > min(srs) | ||
+ | [1] -41.73368 | ||
+ | > | ||
+ | </ | ||
+ | {{: | ||
+ | {{: | ||
+ | |||
+ | ===== b값 구하기 ===== | ||
+ | 이제는 a값을 | ||
+ | < | ||
+ | ############################################## | ||
+ | # b값도 범위를 추측한 후에 0.01씩 증가시키면서 | ||
+ | # 각 b값에서 mse값을 구해본후 가장 작은 값을 | ||
+ | # 가질 때의 b값을 구하면 된다. | ||
+ | # 그러나 b값의 적절한 구간을 예측하는 것이 | ||
+ | # 불가능하다 (그냥 추측뿐) | ||
+ | # 위의 y 데이터에서 y = 314*x + rnorm(. . .) | ||
+ | # 이라면 -30-30 구간은 적절하지 않은 구간이 된다. | ||
+ | # 더우기 a값을 정확히 알아야 b값을 추출할 수 있다. | ||
+ | # 이는 적절한 방법이 아니다. | ||
+ | |||
+ | b <- 1 | ||
+ | a <- summary(mo)$coefficients[1] | ||
+ | |||
+ | b.init <- b | ||
+ | a.init <- a | ||
+ | |||
+ | # Predict function: | ||
+ | predict <- function(x, a, b){ | ||
+ | return (a + b * x) | ||
+ | } | ||
+ | |||
+ | # And loss function is: | ||
+ | residuals <- function(predictions, | ||
+ | return(y - predictions) | ||
+ | } | ||
+ | |||
+ | # we use sum of square of error which oftentimes become big | ||
+ | msrloss <- function(predictions, | ||
+ | residuals <- (y - predictions) | ||
+ | return(mean(residuals^2)) | ||
+ | } | ||
+ | |||
+ | msrs <- c() | ||
+ | mrs <- c() | ||
+ | as <- c() | ||
+ | |||
+ | for (i in seq(from = -50, to = 50, by = 0.01)) { | ||
+ | pred <- predict(x, a, i) | ||
+ | res <- residuals(pred, | ||
+ | msr <- msrloss(pred, | ||
+ | msrs <- append(msrs, | ||
+ | mrs <- append(mrs, mean(res)) | ||
+ | as <- append(as, | ||
+ | } | ||
+ | |||
+ | min(msrs) | ||
+ | min.pos.msrs <- which(msrs == min(msrs)) | ||
+ | print(as[min.pos.msrs]) | ||
+ | summary(mo) | ||
+ | plot(seq(1, length(msrs)), | ||
+ | plot(seq(1, length(mrs)), | ||
+ | min(msrs) | ||
+ | max(msrs) | ||
+ | min(mrs) | ||
+ | max(mrs) | ||
+ | |||
+ | </ | ||
+ | ===== output ===== | ||
+ | < | ||
+ | > | ||
+ | > ############################################## | ||
+ | > # b값도 범위를 추측한 후에 0.01씩 증가시키면서 | ||
+ | > # 각 b값에서 mse값을 구해본후 가장 작은 값을 | ||
+ | > # 가질 때의 b값을 구하면 된다. | ||
+ | > # 그러나 b값의 적절한 구간을 예측하는 것이 | ||
+ | > # 불가능하다 (그냥 추측뿐) | ||
+ | > # 위의 y 데이터에서 y = 314*x + rnorm(. . .) | ||
+ | > # 이라면 -30-30 구간은 적절하지 않은 구간이 된다. | ||
+ | > # 더우기 a값을 정확히 알아야 b값을 추출할 수 있다. | ||
+ | > # 이는 적절한 방법이 아니다. | ||
+ | > | ||
+ | > b <- 1 | ||
+ | > a <- summary(mo)$coefficients[1] | ||
+ | > | ||
+ | > b.init <- b | ||
+ | > a.init <- a | ||
+ | > | ||
+ | > # Predict function: | ||
+ | > predict <- function(x, a, b){ | ||
+ | + | ||
+ | + } | ||
+ | > | ||
+ | > # And loss function is: | ||
+ | > residuals <- function(predictions, | ||
+ | + | ||
+ | + } | ||
+ | > | ||
+ | > # we use sum of square of error which oftentimes become big | ||
+ | > msrloss <- function(predictions, | ||
+ | + | ||
+ | + | ||
+ | + } | ||
+ | > | ||
+ | > msrs <- c() | ||
+ | > mrs <- c() | ||
+ | > as <- c() | ||
+ | > | ||
+ | > for (i in seq(from = -50, to = 50, by = 0.01)) { | ||
+ | + pred <- predict(x, a, i) | ||
+ | + res <- residuals(pred, | ||
+ | + msr <- msrloss(pred, | ||
+ | + msrs <- append(msrs, | ||
+ | + mrs <- append(mrs, mean(res)) | ||
+ | + as <- append(as, | ||
+ | + } | ||
+ | > | ||
+ | > min(msrs) | ||
+ | [1] 7766.679 | ||
+ | > min.pos.msrs <- which(msrs == min(msrs)) | ||
+ | > print(as[min.pos.msrs]) | ||
+ | [1] 11.89 | ||
+ | > summary(mo) | ||
+ | |||
+ | Call: | ||
+ | lm(formula = y ~ x, data = data) | ||
+ | |||
+ | Residuals: | ||
+ | | ||
+ | -259.314 | ||
+ | |||
+ | Coefficients: | ||
+ | Estimate Std. Error t value Pr(> | ||
+ | (Intercept) | ||
+ | x | ||
+ | --- | ||
+ | Signif. codes: | ||
+ | |||
+ | Residual standard error: 88.57 on 198 degrees of freedom | ||
+ | Multiple R-squared: | ||
+ | F-statistic: | ||
+ | |||
+ | > plot(seq(1, length(msrs)), | ||
+ | > plot(seq(1, length(mrs)), | ||
+ | > min(msrs) | ||
+ | [1] 7766.679 | ||
+ | > max(msrs) | ||
+ | [1] 109640 | ||
+ | > min(mrs) | ||
+ | [1] -170.3106 | ||
+ | > max(mrs) | ||
+ | [1] 276.56 | ||
+ | > | ||
+ | </ | ||
+ | a와 b를 동시에 | ||
====== Gradient descend ====== | ====== Gradient descend ====== | ||
Line 434: | Line 515: | ||
\text{for a (constant)} \\ | \text{for a (constant)} \\ | ||
\\ | \\ | ||
- | \text{SSE} & = & \text{Mean Square Residuals} \\ | + | \text{SSE} & = & \text{Sum of Square Residuals} \\ |
- | \text{Residuals} & = & (Y_i - (a + bX_i)) \\ | + | \text{Residual} & = & (Y_i - (a + bX_i)) \\ |
\\ | \\ | ||
\frac{\text{dSSE}}{\text{da}} | \frac{\text{dSSE}}{\text{da}} | ||
- | & = & \frac{\text{dResiduals^2}}{\text{dResiduals}} * \frac{\text{dResiduals}}{\text{da}} \\ | + | & = & \frac{\text{dResidual^2}}{\text{dResidual}} * \frac{\text{dResidual}}{\text{da}} \\ |
- | & = & 2 * \text{Residuals} * \dfrac{\text{d}}{\text{da}} (Y_i - (a+bX_i)) \\ | + | & = & 2 * \text{Residual} * \dfrac{\text{d}}{\text{da}} (Y_i - (a+bX_i)) \\ |
& \because & \dfrac{\text{d}}{\text{da}} (Y_i - (a+bX_i)) = -1 \\ | & \because & \dfrac{\text{d}}{\text{da}} (Y_i - (a+bX_i)) = -1 \\ | ||
& = & 2 * \sum{(Y_i - (a + bX_i))} * -1 \\ | & = & 2 * \sum{(Y_i - (a + bX_i))} * -1 \\ | ||
- | & = & -2 * \text{Residuals} \\ | + | & = & -2 *\sum{\text{Residual}} \\ |
\end{eqnarray*} | \end{eqnarray*} | ||
아래 R code에서 gradient function을 참조. | 아래 R code에서 gradient function을 참조. | ||
+ | </ | ||
+ | <WRAP box> | ||
+ | \begin{eqnarray*} | ||
+ | \text{for b, (coefficient)} \\ | ||
+ | \\ | ||
+ | \dfrac{\text{d}}{\text{db}} \sum{(Y_i - (a + bX_i))^2} | ||
+ | & = & \sum{\dfrac{\text{dResidual}^2}{\text{db}}} \\ | ||
+ | & = & \sum{\dfrac{\text{dResidual}^2}{\text{dResidual}}*\dfrac{\text{dResidual}}{\text{db}} } \\ | ||
+ | & = & \sum{2*\text{Residual} * \dfrac{\text{dResidual}}{\text{db}} | ||
+ | & = & \sum{2*\text{Residual} * (-X_i) } \;\;\;\; \\ | ||
+ | & \because & \dfrac{\text{dResidual}}{\text{db}} = \dfrac{\text{d}}{\text{db}}(Y_i - (a+bX_i)) = -X_i \\
+ | & = & -2 \sum{X_i (Y_i - (a + bX_i))} \\
+ | & = & -2 * \sum{X_i * \text{residual}} \\
+ | & .. & -2 * \frac{\sum{X_i * \text{residual}}}{n} \\
+ | & = & -2 * \overline{X_i * \text{residual}} \\ | ||
+ | |||
+ | \end{eqnarray*} | ||
+ | |||
+ | 위의 설명은 Sum of Square값을 미분하는 것을 전제로 하였지만, | ||
+ | |||
< | < | ||
gradient <- function(x, y, predictions){ | gradient <- function(x, y, predictions){ | ||
Line 454: | Line 555: | ||
</ | </ | ||
- | |||
- | </ | ||
- | <WRAP box> | ||
- | \begin{eqnarray*} | ||
- | \text{for b, (coefficient)} \\ | ||
- | \\ | ||
- | \dfrac{\text{d}}{\text{dv}} \frac{\sum{(Y_i - (a + bX_i))^2}}{N} | ||
- | & = & \sum{2 \frac{1}{N} (Y_i - (a + bX_i))} * (-X_i) \;\;\;\; \\ | ||
- | & \because & \dfrac{\text{d}}{\text{dv for b}} (Y_i - (a+bX_i)) = -X_i \\ | ||
- | & = & -2 X_i \frac{\sum{(Y_i - (a + bX_i))}}{N} \\ | ||
- | & = & -2 * X_i * \text{mean of residuals} \\ | ||
- | \\ | ||
- | \end{eqnarray*} | ||
- | (미분을 이해한다는 것을 전제로) 위의 식은 b값이 변할 때 msr (mean square residual) 값이 어떻게 변하는가를 알려주는 것이다. 그리고 그것은 b값에 대한 residual의 총합에 (-2/ | ||
</ | </ | ||
Line 521: | Line 608: | ||
====== R code ====== | ====== R code ====== | ||
< | < | ||
- | # d statquest explanation | + | # the above no gradient |
- | # x <- c(0.5, 2.3, 2.9) | + | # mse 값으로 계산 rather than sse |
- | # y <- c(1.4, 1.9, 3.2) | + | # 후자는 값이 너무 커짐 |
- | rm(list=ls()) | + | a <- rnorm(1) |
- | # set.seed(191) | + | b <- rnorm(1) |
- | n <- 300 | + | a.start |
- | x <- rnorm(n, 5, 1.2) | + | b.start <- b |
- | y <- 2.14 * x + rnorm(n, 0, 4) | + | |
- | # data <- data.frame(x, y) | + | gradient |
- | data <- tibble(x | + | |
- | + | db = -2 * mean(x * error) | |
- | mo <- lm(y~x) | + | da = -2 * mean(error) |
- | summary(mo) | + | |
- | + | ||
- | # set.seed(191) | + | |
- | # Initialize random betas | + | |
- | b1 = rnorm(1) | + | |
- | b0 = rnorm(1) | + | |
- | + | ||
- | b1.init <- b1 | + | |
- | b0.init <- b0 | + | |
- | + | ||
- | # Predict function: | + | |
- | predict <- function(x, b0, b1){ | + | |
- | return (b0 + b1 * x) | + | |
} | } | ||
- | # And loss function is: | + | mseloss |
- | residuals | + | |
- | | + | return(mean(residuals^2)) |
} | } | ||
- | |||
- | loss_mse <- function(predictions, | ||
- | residuals = y - predictions | ||
- | return(mean(residuals ^ 2)) | ||
- | } | ||
- | |||
- | predictions <- predict(x, b0, b1) | ||
- | residuals <- residuals(predictions, | ||
- | loss = loss_mse(predictions, | ||
- | |||
- | data <- tibble(data.frame(x, | ||
- | |||
- | print(paste0(" | ||
- | |||
- | gradient <- function(x, y, predictions){ | ||
- | dinputs = y - predictions | ||
- | db1 = -2 * mean(x * dinputs) | ||
- | db0 = -2 * mean(dinputs) | ||
- | | ||
- | return(list(" | ||
- | } | ||
- | |||
- | gradients <- gradient(x, y, predictions) | ||
- | print(gradients) | ||
# Train the model with scaled features | # Train the model with scaled features | ||
- | x_scaled <- (x - mean(x)) / sd(x) | + | learning.rate |
- | + | ||
- | learning_rate | + | |
# Record Loss for each epoch: | # Record Loss for each epoch: | ||
- | # logs = list() | + | as = c() |
- | # bs=list() | + | bs = c() |
- | b0s = c() | + | mses = c() |
- | b1s = c() | + | sses = c() |
- | mse = c() | + | mres = c() |
+ | zx <- (x-mean(x))/ | ||
- | nlen <- 80 | + | nlen <- 50 |
- | for (epoch in 1:nlen){ | + | for (epoch in 1:nlen) { |
- | # Predict all y values: | + | predictions |
- | predictions | + | residual <- residuals(predictions, |
- | loss = loss_mse(predictions, | + | loss <- mseloss(predictions, |
- | | + | |
- | | + | |
| | ||
- | | + | |
- | print(paste0(" | + | |
- | } | + | |
| | ||
- | | + | |
- | | + | |
- | | + | |
+ | a <- a-step.a | ||
| | ||
- | | + | |
- | b0 <- b0 - db0 * learning_rate | + | |
- | b0s <- append(b0s, b0) | + | |
- | | + | |
} | } | ||
+ | mses | ||
+ | mres | ||
+ | as | ||
+ | bs | ||
+ | |||
+ | # scaled | ||
+ | a | ||
+ | b | ||
# unscale coefficients to make them comprehensible | # unscale coefficients to make them comprehensible | ||
- | b0 = | + | # see http:// |
- | b1 = b1 / sd(x) | + | # and |
+ | # http:// | ||
+ | # | ||
+ | a = | ||
+ | b = | ||
+ | a | ||
+ | b | ||
# changes of estimators | # changes of estimators | ||
- | b0s <- b0s - (mean(x) /sd(x)) * b1s | + | as <- as - (mean(x) /sd(x)) * bs |
- | b1s <- b1s / sd(x) | + | bs <- bs / sd(x) |
- | parameters | + | as |
+ | bs | ||
+ | mres | ||
+ | mse.x <- mses | ||
- | cat(paste0(" | + | parameters <- data.frame(as, |
+ | |||
+ | cat(paste0(" | ||
summary(lm(y~x))$coefficients | summary(lm(y~x))$coefficients | ||
+ | mses <- data.frame(mses) | ||
+ | mses.log <- data.table(epoch = 1:nlen, mses) | ||
+ | ggplot(mses.log, | ||
+ | geom_line(color=" | ||
+ | theme_classic() | ||
+ | |||
+ | # mres <- data.frame(mres) | ||
+ | mres.log <- data.table(epoch = 1:nlen, mres) | ||
+ | ggplot(mres.log, | ||
+ | geom_line(color=" | ||
+ | theme_classic() | ||
+ | |||
+ | ch <- data.frame(mres, | ||
+ | ch | ||
+ | max(y) | ||
ggplot(data, | ggplot(data, | ||
geom_point(size = 2) + | geom_point(size = 2) + | ||
- | geom_abline(aes(intercept = b0s, slope = b1s), | + | geom_abline(aes(intercept = as, slope = bs), |
data = parameters, linewidth = 0.5, | data = parameters, linewidth = 0.5, | ||
color = ' | color = ' | ||
+ | stat_poly_line() + | ||
+ | stat_poly_eq(use_label(c(" | ||
theme_classic() + | theme_classic() + | ||
- | geom_abline(aes(intercept = b0s, slope = b1s), | + | geom_abline(aes(intercept = as, slope = bs), |
data = parameters %>% slice_head(), | data = parameters %>% slice_head(), | ||
linewidth = 1, color = ' | linewidth = 1, color = ' | ||
- | geom_abline(aes(intercept = b0s, slope = b1s), | + | geom_abline(aes(intercept = as, slope = bs), |
data = parameters %>% slice_tail(), | data = parameters %>% slice_tail(), | ||
linewidth = 1, color = ' | linewidth = 1, color = ' | ||
labs(title = ' | labs(title = ' | ||
- | + | summary(lm(y~x)) | |
- | b0.init | + | a.start |
- | b1.init | + | b.start |
- | + | a | |
- | data | + | b |
- | parameters | + | |
</ | </ | ||
====== R output ===== | ====== R output ===== | ||
< | < | ||
- | > rm(list=ls()) | ||
- | > # set.seed(191) | ||
- | > n <- 300 | ||
- | > x <- rnorm(n, 5, 1.2) | ||
- | > y <- 2.14 * x + rnorm(n, 0, 4) | ||
> | > | ||
- | > # data <- data.frame(x, | ||
- | > data <- tibble(x = x, y = y) | ||
> | > | ||
- | > mo <- lm(y~x) | + | > # the above no gradient |
- | > summary(mo) | + | > # mse 값으로 계산 rather than sse |
- | + | > # 후자는 값이 너무 커짐 | |
- | Call: | + | |
- | lm(formula = y ~ x) | + | |
- | + | ||
- | Residuals: | + | |
- | | + | |
- | -9.754 -2.729 -0.135 | + | |
- | + | ||
- | Coefficients: | + | |
- | Estimate Std. Error t value Pr(>|t|) | + | |
- | (Intercept) | + | |
- | x | + | |
- | --- | + | |
- | Signif. codes: | + | |
- | + | ||
- | Residual standard error: 3.951 on 298 degrees of freedom | + | |
- | Multiple R-squared: | + | |
- | F-statistic: | + | |
> | > | ||
- | > # set.seed(191) | + | > a <- rnorm(1) |
- | > # Initialize random betas | + | > b <- rnorm(1) |
- | > b1 = rnorm(1) | + | > a.start <- a |
- | > b0 = rnorm(1) | + | > b.start <- b |
> | > | ||
- | > b1.init <- b1 | + | > gradient |
- | > b0.init <- b0 | + | + error = y - predictions |
- | > | + | + db = -2 * mean(x * error) |
- | > # Predict function: | + | + da = -2 * mean(error) |
- | > predict | + | + |
- | + return (b0 + b1 * x) | + | |
+ } | + } | ||
> | > | ||
- | > # And loss function is: | + | > mseloss |
- | > residuals | + | + residuals <- (y - predictions) |
- | + return(y - predictions) | + | + |
+ } | + } | ||
- | > | ||
- | > loss_mse <- function(predictions, | ||
- | + | ||
- | + | ||
- | + } | ||
- | > | ||
- | > predictions <- predict(x, b0, b1) | ||
- | > residuals <- residuals(predictions, | ||
- | > loss = loss_mse(predictions, | ||
- | > | ||
- | > data <- tibble(data.frame(x, | ||
- | > | ||
- | > print(paste0(" | ||
- | [1] "Loss is: 393" | ||
- | > | ||
- | > gradient <- function(x, y, predictions){ | ||
- | + | ||
- | + db1 = -2 * mean(x * dinputs) | ||
- | + db0 = -2 * mean(dinputs) | ||
- | + | ||
- | + | ||
- | + } | ||
- | > | ||
- | > gradients <- gradient(x, y, predictions) | ||
- | > print(gradients) | ||
- | $db1 | ||
- | [1] -200.6834 | ||
- | |||
- | $db0 | ||
- | [1] -37.76994 | ||
- | |||
> | > | ||
> # Train the model with scaled features | > # Train the model with scaled features | ||
- | > x_scaled <- (x - mean(x)) / sd(x) | + | > learning.rate |
- | > | + | |
- | > learning_rate | + | |
> | > | ||
> # Record Loss for each epoch: | > # Record Loss for each epoch: | ||
- | > # logs = list() | + | > as = c() |
- | > # bs=list() | + | > bs = c() |
- | > b0s = c() | + | > mses = c() |
- | > b1s = c() | + | > sses = c() |
- | > mse = c() | + | > mres = c() |
+ | > zx <- (x-mean(x))/ | ||
> | > | ||
- | > nlen <- 80 | + | > nlen <- 50 |
- | > for (epoch in 1:nlen){ | + | > for (epoch in 1:nlen) { |
- | + # Predict all y values: | + | + |
- | + | + | + |
- | + | + | + |
- | + mse = append(mse, loss) | + | + mres <- append(mres, mean(residual)) |
- | + # logs = append(logs, loss) | + | + mses <- append(mses, loss) |
+ | + | ||
- | + if (epoch %% 10 == 0){ | + | + grad <- gradient(zx, y, predictions) |
- | + | + | |
- | + } | + | |
+ | + | ||
- | + gradients | + | + step.b |
- | + db1 <- gradients$db1 | + | + step.a |
- | + db0 <- gradients$db0 | + | + b <- b-step.b |
+ | + a <- a-step.a | ||
+ | + | ||
- | + b1 <- b1 - db1 * learning_rate | + | + as <- append(as, a) |
- | + b0 <- b0 - db0 * learning_rate | + | + bs <- append(bs, b) |
- | + | + | |
- | + b1s <- append(b1s, b1) | + | |
+ } | + } | ||
- | [1] " | + | > mses |
- | [1] " | + | [1] 12376.887 10718.824 |
- | [1] " | + | [9] |
- | [1] " | + | [17] 7770.364 |
- | [1] " | + | [25] 7766.783 |
- | [1] " | + | [33] 7766.682 |
- | [1] " | + | [41] 7766.679 |
- | [1] " | + | [49] 7766.679 |
+ | > mres | ||
+ | | ||
+ | [7] 15.735566811 12.588453449 10.070762759 | ||
+ | [13] 4.124984426 | ||
+ | [19] 1.081339917 | ||
+ | [25] 0.283466771 | ||
+ | [31] 0.074309113 | ||
+ | [37] 0.019479688 | ||
+ | [43] 0.005106483 | ||
+ | [49] 0.001338634 | ||
+ | > as | ||
+ | | ||
+ | [9] 53.33440 54.94572 56.23478 57.26602 58.09102 58.75102 59.27901 59.70141 | ||
+ | [17] 60.03933 60.30967 60.52593 60.69895 60.83736 60.94809 61.03667 61.10754 | ||
+ | [25] 61.16423 61.20959 61.24587 61.27490 61.29812 61.31670 61.33156 61.34345 | ||
+ | [33] 61.35296 61.36057 61.36666 61.37153 61.37542 61.37854 61.38103 61.38303 | ||
+ | [41] 61.38462 61.38590 61.38692 61.38774 61.38839 61.38891 61.38933 61.38967 | ||
+ | [49] 61.38993 61.39015 | ||
+ | > bs | ||
+ | | ||
+ | [9] 26.365585 27.224909 27.913227 28.464570 28.906196 29.259938 29.543285 29.770247 | ||
+ | [17] 29.952043 30.097661 30.214302 30.307731 30.382568 30.442512 30.490527 30.528987 | ||
+ | [25] 30.559794 30.584470 30.604236 30.620068 30.632750 30.642908 30.651044 30.657562 | ||
+ | [33] 30.662782 30.666964 30.670313 30.672996 30.675145 30.676866 30.678245 30.679349 | ||
+ | [41] 30.680234 30.680943 30.681510 30.681965 30.682329 30.682621 30.682854 30.683041 | ||
+ | [49] 30.683191 30.683311 | ||
+ | > | ||
+ | > # scaled | ||
+ | > a | ||
+ | [1] 61.39015 | ||
+ | > b | ||
+ | [1] 30.68331 | ||
> | > | ||
> # unscale coefficients to make them comprehensible | > # unscale coefficients to make them comprehensible | ||
- | > b0 = | + | > # see http:// |
- | > b1 = b1 / sd(x) | + | > # and |
+ | > # http:// | ||
+ | > # | ||
+ | > a = | ||
+ | > b = | ||
+ | > a | ||
+ | [1] 8.266303 | ||
+ | > b | ||
+ | [1] 11.88797 | ||
> | > | ||
> # changes of estimators | > # changes of estimators | ||
- | > b0s <- b0s - (mean(x) /sd(x)) * b1s | + | > as <- as - (mean(x) /sd(x)) * bs |
- | > b1s <- b1s / sd(x) | + | > bs <- bs / sd(x) |
> | > | ||
- | > parameters <- tibble(data.frame(b0s, b1s, mse)) | + | > as |
+ | [1] 4.364717 5.189158 5.839931 6.353516 6.758752 7.078428 7.330555 7.529361 | ||
+ | [9] 7.686087 7.809611 7.906942 7.983615 8.043999 8.091541 8.128963 8.158410 | ||
+ | [17] 8.181574 8.199791 8.214112 8.225367 8.234209 8.241154 8.246605 8.250884 | ||
+ | [25] 8.254239 8.256871 8.258933 8.260549 8.261814 8.262804 8.263579 8.264184 | ||
+ | [33] 8.264658 8.265027 8.265315 8.265540 8.265716 8.265852 8.265958 8.266041 | ||
+ | [41] 8.266105 8.266155 8.266193 8.266223 8.266246 8.266264 8.266278 8.266289 | ||
+ | [49] 8.266297 8.266303 | ||
+ | > bs | ||
+ | | ||
+ | [9] 10.215107 10.548045 10.814727 11.028340 11.199444 11.336498 11.446279 11.534213 | ||
+ | [17] 11.604648 11.661067 11.706258 11.742456 11.771451 11.794676 11.813279 11.828180 | ||
+ | [25] 11.840116 11.849676 11.857334 11.863469 11.868382 11.872317 11.875470 11.877995 | ||
+ | [33] 11.880018 11.881638 11.882935 11.883975 11.884807 11.885474 11.886009 11.886437 | ||
+ | [41] 11.886779 11.887054 11.887274 11.887450 11.887591 11.887704 11.887794 11.887867 | ||
+ | [49] 11.887925 11.887972 | ||
+ | > mres | ||
+ | [1] 60.026423686 48.021138949 38.416911159 30.733528927 24.586823142 19.669458513 | ||
+ | [7] 15.735566811 12.588453449 10.070762759 | ||
+ | [13] 4.124984426 | ||
+ | [19] 1.081339917 | ||
+ | [25] 0.283466771 | ||
+ | [31] 0.074309113 | ||
+ | [37] 0.019479688 | ||
+ | [43] 0.005106483 | ||
+ | [49] 0.001338634 | ||
+ | > mse.x <- mses | ||
> | > | ||
- | > cat(paste0(" | + | > parameters <- data.frame(as, |
- | Slope: 2.26922511738252, | + | > |
- | Intercept: -0.779435058320381 | + | > cat(paste0(" |
+ | Intercept: 8.26630323816515 | ||
+ | Slope: 11.8879715830899 | ||
> summary(lm(y~x))$coefficients | > summary(lm(y~x))$coefficients | ||
- | | + | Estimate Std. Error |
- | (Intercept) | + | (Intercept) |
- | x 2.2692252 | + | x 11.888159 |
+ | > | ||
+ | > mses <- data.frame(mses) | ||
+ | > mses.log <- data.table(epoch = 1:nlen, mses) | ||
+ | > ggplot(mses.log, aes(epoch, mses)) + | ||
+ | + | ||
+ | + | ||
+ | > | ||
+ | > # mres <- data.frame(mres) | ||
+ | > mres.log <- data.table(epoch = 1:nlen, mres) | ||
+ | > ggplot(mres.log, | ||
+ | + | ||
+ | + | ||
> | > | ||
+ | > ch <- data.frame(mres, | ||
+ | > ch | ||
+ | | ||
+ | 1 60.026423686 12376.887 | ||
+ | 2 48.021138949 10718.824 | ||
+ | 3 38.416911159 | ||
+ | 4 30.733528927 | ||
+ | 5 24.586823142 | ||
+ | 6 19.669458513 | ||
+ | 7 15.735566811 | ||
+ | 8 12.588453449 | ||
+ | 9 10.070762759 | ||
+ | 10 8.056610207 | ||
+ | 11 6.445288166 | ||
+ | 12 5.156230533 | ||
+ | 13 4.124984426 | ||
+ | 14 3.299987541 | ||
+ | 15 2.639990033 | ||
+ | 16 2.111992026 | ||
+ | 17 1.689593621 | ||
+ | 18 1.351674897 | ||
+ | 19 1.081339917 | ||
+ | 20 0.865071934 | ||
+ | 21 0.692057547 | ||
+ | 22 0.553646038 | ||
+ | 23 0.442916830 | ||
+ | 24 0.354333464 | ||
+ | 25 0.283466771 | ||
+ | 26 0.226773417 | ||
+ | 27 0.181418734 | ||
+ | 28 0.145134987 | ||
+ | 29 0.116107990 | ||
+ | 30 0.092886392 | ||
+ | 31 0.074309113 | ||
+ | 32 0.059447291 | ||
+ | 33 0.047557833 | ||
+ | 34 0.038046266 | ||
+ | 35 0.030437013 | ||
+ | 36 0.024349610 | ||
+ | 37 0.019479688 | ||
+ | 38 0.015583751 | ||
+ | 39 0.012467000 | ||
+ | 40 0.009973600 | ||
+ | 41 0.007978880 | ||
+ | 42 0.006383104 | ||
+ | 43 0.005106483 | ||
+ | 44 0.004085187 | ||
+ | 45 0.003268149 | ||
+ | 46 0.002614519 | ||
+ | 47 0.002091616 | ||
+ | 48 0.001673292 | ||
+ | 49 0.001338634 | ||
+ | 50 0.001070907 | ||
+ | > max(y) | ||
+ | [1] 383.1671 | ||
> ggplot(data, | > ggplot(data, | ||
+ | + | ||
- | + | + | + |
+ data = parameters, linewidth = 0.5, | + data = parameters, linewidth = 0.5, | ||
+ color = ' | + color = ' | ||
+ | + | ||
+ | + | ||
+ | + | ||
- | + | + | + |
+ data = parameters %>% slice_head(), | + data = parameters %>% slice_head(), | ||
+ | + | ||
- | + | + | + |
+ data = parameters %>% slice_tail(), | + data = parameters %>% slice_tail(), | ||
+ | + | ||
+ | + | ||
- | > | + | > summary(lm(y~x)) |
- | > b0.init | + | |
- | [1] -1.67967 | + | |
- | > b1.init | + | |
- | [1] -1.323992 | + | |
- | > | + | |
- | > data | + | |
- | # A tibble: 300 × 4 | + | |
- | | + | |
- | < | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | 10 3.33 3.80 | + | |
- | # ℹ 290 more rows | + | |
- | # ℹ Use `print(n = ...)` to see more rows | + | |
- | > parameters | + | |
- | # A tibble: 80 × 3 | + | |
- | | + | |
- | < | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | | + | |
- | 8 -0.0397 | + | |
- | 9 -0.186 | + | |
- | 10 -0.303 | + | |
- | # ℹ 70 more rows | + | |
- | # | + | |
+ | Call: | ||
+ | lm(formula = y ~ x) | ||
+ | |||
+ | Residuals: | ||
+ | | ||
+ | -259.314 | ||
+ | |||
+ | Coefficients: | ||
+ | Estimate Std. Error t value Pr(> | ||
+ | (Intercept) | ||
+ | x | ||
+ | --- | ||
+ | Signif. codes: | ||
+ | |||
+ | Residual standard error: 88.57 on 198 degrees of freedom | ||
+ | Multiple R-squared: | ||
+ | F-statistic: | ||
+ | |||
+ | > a.start | ||
+ | [1] 1.364582 | ||
+ | > b.start | ||
+ | [1] -1.12968 | ||
+ | > a | ||
+ | [1] 8.266303 | ||
+ | > b | ||
+ | [1] 11.88797 | ||
+ | > | ||
</ | </ | ||
+ | {{: | ||
+ | {{: | ||
+ | {{: | ||
+ | |||
+ | ====== Why normalize (scale or make z-score) xi ====== | ||
+ | x 변인의 측정단위로 인해서 b 값이 결정되게 되는데 이 때의 b값은 상당하고 다양한 범위를 가질 수 있다. 가령 월 수입이 (인컴) X 라고 한다면 우리가 추정해야 (추적해야) 할 b값은 수백만이 될 수도 있다.이 값을 gradient로 추적하게 된다면 너무도 많은 iteration을 거쳐야 할 수 있다. 변인이 바뀌면 이 b의 추적범위도 드라마틱하게 바뀌게 된다. 이를 표준화한 x 점수를 사용하게 된다면 일정한 learning rate와 iteration만으로도 정확한 a와 b를 추적할 수 있게 된다. | ||
+ | |||
+ | ====== How to unnormalize (unscale) a and b ====== | ||
+ | \begin{eqnarray*} | ||
+ | y & = & a + b * x \\ | ||
+ | & & \text{we use z instead of x} \\ | ||
+ | & & \text{and } \\ | ||
+ | & & z = \frac{(x - \mu)}{\sigma} \\ | ||
+ | & & \text{suppose that the result after calculation be } \\ | ||
+ | y & = & k + m * z \\ | ||
+ | & = & k + m * \frac{(x - \mu)}{\sigma} \\ | ||
+ | & = & k + \frac{m * x}{\sigma} - \frac{m * \mu}{\sigma} | ||
+ | & = & k - \frac{m * \mu}{\sigma} + \frac{m * x}{\sigma} | ||
+ | & = & \underbrace{k - \frac{\mu}{\sigma} * m}_\text{ 1 } + \underbrace{\frac{m}{\sigma}}_\text{ 2 } * x \\ | ||
+ | & & \text{therefore, | ||
+ | a & = & k - \frac{\mu}{\sigma} * m \\ | ||
+ | b & = & \frac{m}{\sigma} \\ | ||
+ | \end{eqnarray*} | ||
+ | |||
- | {{: | ||
gradient_descent.1754322327.txt.gz · Last modified: 2025/08/05 00:45 by hkimscil