User Tools

Site Tools


gradient_descent

Differences

This shows you the differences between two versions of the page.

Link to this comparison view

Both sides previous revisionPrevious revision
Next revision
Previous revision
gradient_descent [2026/03/14 06:32] – [R output] hkimscilgradient_descent [2026/03/14 11:21] (current) – [output] hkimscil
Line 17: Line 17:
 library(ggplot2) library(ggplot2)
 library(ggpmisc) library(ggpmisc)
 +
  
 rm(list=ls()) rm(list=ls())
-# set.seed(191)+ 
 +data 만들기 
 +set.seed(191)
 nx <- 200 nx <- 200
 mx <- 4.5 mx <- 4.5
Line 26: Line 29:
 slp <- 12 slp <- 12
 y <-  slp * x + rnorm(nx, 0, slp*sdx*3) y <-  slp * x + rnorm(nx, 0, slp*sdx*3)
- 
 data <- data.frame(x, y) data <- data.frame(x, y)
 +# data 변인 완성
  
 +# regression summary shows 
 +# a and b
 mo <- lm(y ~ x, data = data) mo <- lm(y ~ x, data = data)
 summary(mo) summary(mo)
Line 37: Line 42:
   stat_poly_eq(use_label(c("eq", "R2"))) +   stat_poly_eq(use_label(c("eq", "R2"))) +
   theme_classic()    theme_classic() 
-# set.seed(191) + 
-Initialize random betas +위에서 확인한 b값을 b로 고정하고 a만 변화시켜서 이해
-# 우선 b를 고정하고 a만  +
-변화시켜서 이해+
 b <- summary(mo)$coefficients[2] b <- summary(mo)$coefficients[2]
 a <- 0 a <- 0
Line 47: Line 50:
 a.init <- a a.init <- a
  
-# Predict function:+# Predict function: y hat 값
 predict <- function(x, a, b){ predict <- function(x, a, b){
   return (a + b * x)   return (a + b * x)
 } }
  
-# And loss function is:+# And loss function is: residual 혹은 error 값
 residuals <- function(predictions, y) { residuals <- function(predictions, y) {
   return(y - predictions)   return(y - predictions)
 } }
  
-# we use sum of square of error which oftentimes become big+# we use sum of square of error 
 ssrloss <- function(predictions, y) { ssrloss <- function(predictions, y) {
   residuals <- (y - predictions)   residuals <- (y - predictions)
Line 67: Line 70:
 as <- c() # for as (intercepts) as <- c() # for as (intercepts)
  
 +# a 값을 -50 에서 50 범위로 0.01씩 증가시켜서 
 +# for 문에 대입, i로 사용 (i를 a 값으로 사용)
 for (i in seq(from = -50, to = 50, by = 0.01)) { for (i in seq(from = -50, to = 50, by = 0.01)) {
   pred <- predict(x, i, b)   pred <- predict(x, i, b)
Line 73: Line 78:
   ssrs <- append(ssrs, ssr)   ssrs <- append(ssrs, ssr)
   srs <- append(srs, sum(res))   srs <- append(srs, sum(res))
-  as <- append(as, i)+  as <- append(as, i) # i 값을 a로 사용했기에 as 변인에 기록
 } }
 +# -50 에서 50까지 0.01 간격이면 100/0.01 = 10000 구간, 점의 개수는 10001
 length(ssrs) length(ssrs)
 length(srs) length(srs)
 length(as) length(as)
  
-min(ssrs) +min(ssrs) # sum of square error 값 중 최소값 
-min.pos.ssrs <- which(ssrs == min(ssrs))+min.pos.ssrs <- which(ssrs == min(ssrs)) # 그 값이 몇 번째에 있는지 구함
 min.pos.ssrs min.pos.ssrs
-print(as[min.pos.ssrs])+print(as[min.pos.ssrs]) # 그 몇번째에 해당하는 a값을 구함 
 +# 이 a값이 최소 ssr값을 갖도록 하는 a (소수점 2자리에서 구함)
 summary(mo) summary(mo)
-plot(seq(1, length(ssrs)), ssrs) + 
-plot(seq(1length(ssrs))srs)+k <- min(ssrs) 
 +j <- as[min.pos.ssrs] 
 +plot(seq(1, length(as)), ssrs, type="l"
 +text(4500, 2000000, paste("ssrs = ", k, "\n", "is minimum value when a = ", j))
 tail(ssrs) tail(ssrs)
 max(ssrs) max(ssrs)
Line 96: Line 107:
 <tabbox ro01> <tabbox ro01>
 <code> <code>
 +> # library(tidyverse)
 +> # library(data.table)
 > library(ggplot2) > library(ggplot2)
 > library(ggpmisc) > library(ggpmisc)
  
 > rm(list=ls()) > rm(list=ls())
-> # set.seed(191)+>  
 +> # data 만들기 
 +set.seed(191)
 > nx <- 200 > nx <- 200
 > mx <- 4.5 > mx <- 4.5
Line 107: Line 122:
 > slp <- 12 > slp <- 12
 > y <-  slp * x + rnorm(nx, 0, slp*sdx*3) > y <-  slp * x + rnorm(nx, 0, slp*sdx*3)
- 
 > data <- data.frame(x, y) > data <- data.frame(x, y)
 +> # data 변인 완성
  
 +> # regression summary shows 
 +> # a and b
 > mo <- lm(y ~ x, data = data) > mo <- lm(y ~ x, data = data)
 > summary(mo) > summary(mo)
Line 118: Line 135:
 Residuals: Residuals:
      Min       1Q   Median       3Q      Max       Min       1Q   Median       3Q      Max 
--259.314  -59.215    6.683   58.834  309.833 +-245.291  -67.967   -3.722   63.440  242.174 
  
 Coefficients: Coefficients:
             Estimate Std. Error t value Pr(>|t|)                 Estimate Std. Error t value Pr(>|t|)    
-(Intercept)    8.266     12.546   0.659    0.511     +(Intercept)   -1.047     14.289  -0.073    0.942     
-x             11.888      2.433   4.887 2.11e-06 ***+x             12.900      2.870   4.495 1.19e-05 ***
 --- ---
 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
  
-Residual standard error: 88.57 on 198 degrees of freedom +Residual standard error: 93.93 on 198 degrees of freedom 
-Multiple R-squared:  0.1076, Adjusted R-squared:  0.1031  +Multiple R-squared:  0.09258, Adjusted R-squared:  0.088  
-F-statistic: 23.88 on 1 and 198 DF,  p-value: 2.111e-06+F-statistic:  20.2 on 1 and 198 DF,  p-value: 1.185e-05
  
  
 > ggplot(data = data, aes(x = x, y = y)) +  > ggplot(data = data, aes(x = x, y = y)) + 
-  geom_point() + +    geom_point() + 
-  stat_poly_line() + +    stat_poly_line() + 
-  stat_poly_eq(use_label(c("eq", "R2"))) + +    stat_poly_eq(use_label(c("eq", "R2"))) + 
-  theme_classic()  +    theme_classic()  
-# set.seed(191) +>  
-> # Initialize random betas +> # 위에서 확인한 b값을 b로 고정하고 a만 변화시켜서 이해
-> # 우선 b를 고정하고 a만  +
-> # 변화시켜서 이해+
 > b <- summary(mo)$coefficients[2] > b <- summary(mo)$coefficients[2]
 > a <- 0 > a <- 0
Line 147: Line 162:
 > a.init <- a > a.init <- a
  
-> # Predict function:+> # Predict function: y hat 값
 > predict <- function(x, a, b){ > predict <- function(x, a, b){
-  return (a + b * x)+    return (a + b * x)
 + } + }
  
-> # And loss function is:+> # And loss function is: residual 혹은 error 값
 > residuals <- function(predictions, y) { > residuals <- function(predictions, y) {
-  return(y - predictions)+    return(y - predictions)
 + } + }
  
-> # we use sum of square of error which oftentimes become big+> # we use sum of square of error 
 > ssrloss <- function(predictions, y) { > ssrloss <- function(predictions, y) {
-  residuals <- (y - predictions) +    residuals <- (y - predictions) 
-  return(sum(residuals^2))+    return(sum(residuals^2))
 + } + }
  
Line 167: Line 182:
 > as <- c() # for as (intercepts) > as <- c() # for as (intercepts)
  
 +> # a 값을 -50 에서 50 범위로 0.01씩 증가시켜서 
 +> # for 문에 대입, i로 사용 (i를 a 값으로 사용)
 > for (i in seq(from = -50, to = 50, by = 0.01)) { > for (i in seq(from = -50, to = 50, by = 0.01)) {
-  pred <- predict(x, i, b) +    pred <- predict(x, i, b) 
-  res <- residuals(pred, y) +    res <- residuals(pred, y) 
-  ssr <- ssrloss(pred, y) +    ssr <- ssrloss(pred, y) 
-  ssrs <- append(ssrs, ssr) +    ssrs <- append(ssrs, ssr) 
-  srs <- append(srs, sum(res)) +    srs <- append(srs, sum(res)) 
-  as <- append(as, i)+    as <- append(as, i) # i 값을 a로 사용했기에 as 변인에 기록
 + } + }
 +> # -50 에서 50까지 0.01 간격이면 100/0.01 = 10000 구간, 점의 개수는 10001
 > length(ssrs) > length(ssrs)
 [1] 10001 [1] 10001
Line 182: Line 200:
 [1] 10001 [1] 10001
  
-> min(ssrs) +> min(ssrs) # sum of square error 값 중 최소값 
-[1] 1553336 +[1] 1747011 
-> min.pos.ssrs <- which(ssrs == min(ssrs))+> min.pos.ssrs <- which(ssrs == min(ssrs)) # 그 값이 몇 번째에 있는지 구함
 > min.pos.ssrs > min.pos.ssrs
-[1] 5828 +[1] 4896 
-> print(as[min.pos.ssrs]) +> print(as[min.pos.ssrs]) # 그 몇번째에 해당하는 a값을 구함 
-[1] 8.27+[1] -1.05 
 +> # 이 a값이 최소 ssr값을 갖도록 하는 a (소수점 2자리에서 구함)
 > summary(mo) > summary(mo)
  
Line 196: Line 215:
 Residuals: Residuals:
      Min       1Q   Median       3Q      Max       Min       1Q   Median       3Q      Max 
--259.314  -59.215    6.683   58.834  309.833 +-245.291  -67.967   -3.722   63.440  242.174 
  
 Coefficients: Coefficients:
             Estimate Std. Error t value Pr(>|t|)                 Estimate Std. Error t value Pr(>|t|)    
-(Intercept)    8.266     12.546   0.659    0.511     +(Intercept)   -1.047     14.289  -0.073    0.942     
-x             11.888      2.433   4.887 2.11e-06 ***+x             12.900      2.870   4.495 1.19e-05 ***
 --- ---
 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
  
-Residual standard error: 88.57 on 198 degrees of freedom +Residual standard error: 93.93 on 198 degrees of freedom 
-Multiple R-squared:  0.1076, Adjusted R-squared:  0.1031  +Multiple R-squared:  0.09258, Adjusted R-squared:  0.088  
-F-statistic: 23.88 on 1 and 198 DF,  p-value: 2.111e-06+F-statistic:  20.2 on 1 and 198 DF,  p-value: 1.185e-05
  
-> plot(seq(1, length(ssrs)), ssrs) +>  
-plot(seq(1length(ssrs))srs)+> k <- min(ssrs) 
 +> j <- as[min.pos.ssrs] 
 +> plot(seq(1, length(as)), ssrs, type="l"
 +text(4500, 2000000, paste("ssrs = ", k, "\n", "is minimum value when a = ", j))
 +>
 > tail(ssrs) > tail(ssrs)
-[1] 1900842 1901008 1901175 1901342 1901509 1901676+[1] 2267151 2267355 2267559 2267763 2267967 2268171
 > max(ssrs) > max(ssrs)
-[1] 2232329+[1] 2268171
 > min(ssrs) > min(ssrs)
-[1] 1553336+[1] 1747011
 > tail(srs) > tail(srs)
-[1] -8336.735 -8338.735 -8340.735 -8342.735 -8344.735 -8346.735+[1] -10199.41 -10201.41 -10203.41 -10205.41 -10207.41 -10209.41
 > max(srs) > max(srs)
-[1] 11653.26+[1] 9790.59
 > min(srs) > min(srs)
-[1] -8346.735 +[1] -10209.41
-+
  
 </code> </code>
 </tabbox> </tabbox>
-{{:pasted:20250821-120357.png}} +{{:pasted:20260314-104647.png}}
-{{:pasted:20250821-120416.png}} +
-{{:pasted:20250821-120455.png}}+
  
 위 방법은 dumb . . . . .  위 방법은 dumb . . . . . 
Line 293: Line 313:
 > # we use sum of square of error which oftentimes become big > # we use sum of square of error which oftentimes become big
 > msrloss <- function(predictions, y) { > msrloss <- function(predictions, y) {
-  residuals <- (y - predictions) +    residuals <- (y - predictions) 
-  return(mean(residuals^2))+    return(mean(residuals^2))
 + } + }
  
Line 302: Line 322:
  
 > for (i in seq(from = -50, to = 50, by = 0.01)) { > for (i in seq(from = -50, to = 50, by = 0.01)) {
-  pred <- predict(x, i, b) +    pred <- predict(x, i, b) 
-  res <- residuals(pred, y) +    res <- residuals(pred, y) 
-  msr <- msrloss(pred, y) +    msr <- msrloss(pred, y) 
-  msrs <- append(msrs, msr) +    msrs <- append(msrs, msr) 
-  srs <- append(srs, mean(res)) +    srs <- append(srs, mean(res)) 
-  as <- append(as, i)+    as <- append(as, i)
 + } + }
 > length(msrs) > length(msrs)
Line 317: Line 337:
  
 > min(msrs) > min(msrs)
-[1] 7766.679+[1] 8735.055
 > min.pos.msrs <- which(msrs == min(msrs)) > min.pos.msrs <- which(msrs == min(msrs))
 > min.pos.msrs > min.pos.msrs
-[1] 5828+[1] 4896
 > print(as[min.pos.msrs]) > print(as[min.pos.msrs])
-[1] 8.27+[1] -1.05
 > summary(mo) > summary(mo)
  
Line 330: Line 350:
 Residuals: Residuals:
      Min       1Q   Median       3Q      Max       Min       1Q   Median       3Q      Max 
--259.314  -59.215    6.683   58.834  309.833 +-245.291  -67.967   -3.722   63.440  242.174 
  
 Coefficients: Coefficients:
             Estimate Std. Error t value Pr(>|t|)                 Estimate Std. Error t value Pr(>|t|)    
-(Intercept)    8.266     12.546   0.659    0.511     +(Intercept)   -1.047     14.289  -0.073    0.942     
-x             11.888      2.433   4.887 2.11e-06 ***+x             12.900      2.870   4.495 1.19e-05 ***
 --- ---
 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
  
-Residual standard error: 88.57 on 198 degrees of freedom +Residual standard error: 93.93 on 198 degrees of freedom 
-Multiple R-squared:  0.1076, Adjusted R-squared:  0.1031  +Multiple R-squared:  0.09258, Adjusted R-squared:  0.088  
-F-statistic: 23.88 on 1 and 198 DF,  p-value: 2.111e-06+F-statistic:  20.2 on 1 and 198 DF,  p-value: 1.185e-05
  
 > plot(seq(1, length(msrs)), msrs) > plot(seq(1, length(msrs)), msrs)
 > plot(seq(1, length(srs)), srs) > plot(seq(1, length(srs)), srs)
 > tail(msrs) > tail(msrs)
-[1] 9504.208 9505.041 9505.875 9506.710 9507.544 9508.379+[1] 11335.75 11336.77 11337.79 11338.81 11339.84 11340.86
 > max(msrs) > max(msrs)
-[1] 11161.64+[1] 11340.86
 > min(msrs) > min(msrs)
-[1] 7766.679+[1] 8735.055
 > tail(srs) > tail(srs)
-[1] -41.68368 -41.69368 -41.70368 -41.71368 -41.72368 -41.73368+[1] -50.99705 -51.00705 -51.01705 -51.02705 -51.03705 -51.04705
 > max(srs) > max(srs)
-[1] 58.26632+[1] 48.95295
 > min(srs) > min(srs)
-[1] -41.73368 +[1] -51.04705 
->+ 
 </code> </code>
-{{:pasted:20250821-121009.png}} +{{:pasted:20260314-112129.png}}
-{{:pasted:20250821-121024.png}}+
  
 ===== b값 구하기 ===== ===== b값 구하기 =====
Line 747: Line 767:
 a a
 b b
- 
 </code> </code>
 <tabbox ro01> <tabbox ro01>
-</tablox>+ 
 +<code> 
 +>  
 +>  
 +> # the above no gradient 
 +> # mse 값으로 계산 rather than sse  
 +> # 후자는 값이 너무 커짐 
 +>  
 +> a <- rnorm(1) 
 +> b <- rnorm(1) 
 +> a.start <- a 
 +> b.start <- b 
 +>  
 +> gradient <- function(x, y, predictions){ 
 ++   error = y - predictions 
 ++   db = -2 * mean(x * error) 
 ++   da = -2 * mean(error) 
 ++   return(list("b" = db, "a" = da)) 
 ++ } 
 +>  
 +> mseloss <- function(predictions, y) { 
 ++   residuals <- (y - predictions) 
 ++   return(mean(residuals^2)) 
 ++ } 
 +>  
 +> # Train the model with scaled features 
 +> learning.rate = 1e-1 
 +>  
 +> # Record Loss for each epoch: 
 +> as = c() 
 +> bs = c() 
 +> mses = c() 
 +> sses = c() 
 +> mres = c() 
 +> zx <- (x-mean(x))/sd(x) 
 +>  
 +> nlen <- 50 
 +> for (epoch in 1:nlen) { 
 ++   predictions <- predict(zx, a, b) 
 ++   residual <- residuals(predictions, y) 
 ++   loss <- mseloss(predictions, y) 
 ++   mres <- append(mres, mean(residual)) 
 ++   mses <- append(mses, loss) 
 ++    
 ++   grad <- gradient(zx, y, predictions) 
 ++    
 ++   step.b <- grad$b * learning.rate  
 ++   step.a <- grad$a * learning.rate 
 ++   b <- b-step.b 
 ++   a <- a-step.a 
 ++    
 ++   as <- append(as, a) 
 ++   bs <- append(bs, b) 
 ++ } 
 +> mses 
 + [1] 12376.887 10718.824  9657.086  8977.203  8541.840  8263.055  8084.535  7970.219 
 + [9]  7897.017  7850.141  7820.125  7800.903  7788.595  7780.713  7775.666  7772.434 
 +[17]  7770.364  7769.039  7768.190  7767.646  7767.298  7767.076  7766.933  7766.841 
 +[25]  7766.783  7766.745  7766.721  7766.706  7766.696  7766.690  7766.686  7766.683 
 +[33]  7766.682  7766.681  7766.680  7766.680  7766.679  7766.679  7766.679  7766.679 
 +[41]  7766.679  7766.679  7766.679  7766.679  7766.679  7766.679  7766.679  7766.679 
 +[49]  7766.679  7766.679 
 +> mres 
 + [1] 60.026423686 48.021138949 38.416911159 30.733528927 24.586823142 19.669458513 
 + [7] 15.735566811 12.588453449 10.070762759  8.056610207  6.445288166  5.156230533 
 +[13]  4.124984426  3.299987541  2.639990033  2.111992026  1.689593621  1.351674897 
 +[19]  1.081339917  0.865071934  0.692057547  0.553646038  0.442916830  0.354333464 
 +[25]  0.283466771  0.226773417  0.181418734  0.145134987  0.116107990  0.092886392 
 +[31]  0.074309113  0.059447291  0.047557833  0.038046266  0.030437013  0.024349610 
 +[37]  0.019479688  0.015583751  0.012467000  0.009973600  0.007978880  0.006383104 
 +[43]  0.005106483  0.004085187  0.003268149  0.002614519  0.002091616  0.001673292 
 +[49]  0.001338634  0.001070907 
 +> as 
 + [1] 13.36987 22.97409 30.65748 36.80418 41.72155 45.65544 48.80255 51.32024 
 + [9] 53.33440 54.94572 56.23478 57.26602 58.09102 58.75102 59.27901 59.70141 
 +[17] 60.03933 60.30967 60.52593 60.69895 60.83736 60.94809 61.03667 61.10754 
 +[25] 61.16423 61.20959 61.24587 61.27490 61.29812 61.31670 61.33156 61.34345 
 +[33] 61.35296 61.36057 61.36666 61.37153 61.37542 61.37854 61.38103 61.38303 
 +[41] 61.38462 61.38590 61.38692 61.38774 61.38839 61.38891 61.38933 61.38967 
 +[49] 61.38993 61.39015 
 +> bs 
 + [1]  5.201201 10.272237 14.334137 17.587719 20.193838 22.281340 23.953428 25.292771 
 + [9] 26.365585 27.224909 27.913227 28.464570 28.906196 29.259938 29.543285 29.770247 
 +[17] 29.952043 30.097661 30.214302 30.307731 30.382568 30.442512 30.490527 30.528987 
 +[25] 30.559794 30.584470 30.604236 30.620068 30.632750 30.642908 30.651044 30.657562 
 +[33] 30.662782 30.666964 30.670313 30.672996 30.675145 30.676866 30.678245 30.679349 
 +[41] 30.680234 30.680943 30.681510 30.681965 30.682329 30.682621 30.682854 30.683041 
 +[49] 30.683191 30.683311 
 +>  
 +> # scaled 
 +> a 
 +[1] 61.39015 
 +> b 
 +[1] 30.68331 
 +>  
 +> # unscale coefficients to make them comprehensible 
 +> # see http://commres.net/wiki/gradient_descent#why_normalize_scale_or_make_z-score_xi 
 +> # and  
 +> # http://commres.net/wiki/gradient_descent#how_to_unnormalize_unscale_a_and_b 
 +> #   
 +> a =  a - (mean(x) / sd(x)) * b 
 +> b =  b / sd(x) 
 +> a 
 +[1] 8.266303 
 +> b 
 +[1] 11.88797 
 +>  
 +> # changes of estimators 
 +> as <- as - (mean(x) /sd(x)) * bs 
 +> bs <- bs / sd(x) 
 +>  
 +> as 
 + [1] 4.364717 5.189158 5.839931 6.353516 6.758752 7.078428 7.330555 7.529361 
 + [9] 7.686087 7.809611 7.906942 7.983615 8.043999 8.091541 8.128963 8.158410 
 +[17] 8.181574 8.199791 8.214112 8.225367 8.234209 8.241154 8.246605 8.250884 
 +[25] 8.254239 8.256871 8.258933 8.260549 8.261814 8.262804 8.263579 8.264184 
 +[33] 8.264658 8.265027 8.265315 8.265540 8.265716 8.265852 8.265958 8.266041 
 +[41] 8.266105 8.266155 8.266193 8.266223 8.266246 8.266264 8.266278 8.266289 
 +[49] 8.266297 8.266303 
 +> bs 
 + [1]  2.015158  3.979885  5.553632  6.814203  7.823920  8.632704  9.280539  9.799455 
 + [9] 10.215107 10.548045 10.814727 11.028340 11.199444 11.336498 11.446279 11.534213 
 +[17] 11.604648 11.661067 11.706258 11.742456 11.771451 11.794676 11.813279 11.828180 
 +[25] 11.840116 11.849676 11.857334 11.863469 11.868382 11.872317 11.875470 11.877995 
 +[33] 11.880018 11.881638 11.882935 11.883975 11.884807 11.885474 11.886009 11.886437 
 +[41] 11.886779 11.887054 11.887274 11.887450 11.887591 11.887704 11.887794 11.887867 
 +[49] 11.887925 11.887972 
 +> mres 
 + [1] 60.026423686 48.021138949 38.416911159 30.733528927 24.586823142 19.669458513 
 + [7] 15.735566811 12.588453449 10.070762759  8.056610207  6.445288166  5.156230533 
 +[13]  4.124984426  3.299987541  2.639990033  2.111992026  1.689593621  1.351674897 
 +[19]  1.081339917  0.865071934  0.692057547  0.553646038  0.442916830  0.354333464 
 +[25]  0.283466771  0.226773417  0.181418734  0.145134987  0.116107990  0.092886392 
 +[31]  0.074309113  0.059447291  0.047557833  0.038046266  0.030437013  0.024349610 
 +[37]  0.019479688  0.015583751  0.012467000  0.009973600  0.007978880  0.006383104 
 +[43]  0.005106483  0.004085187  0.003268149  0.002614519  0.002091616  0.001673292 
 +[49]  0.001338634  0.001070907 
 +> mse.x <- mses 
 +>  
 +> parameters <- data.frame(as, bs, mres, mses) 
 +>  
 +> cat(paste0("Intercept: ", a, "\n", "Slope: ", b, "\n")) 
 +Intercept: 8.26630323816515 
 +Slope: 11.8879715830899 
 +> summary(lm(y~x))$coefficients 
 +             Estimate Std. Error   t value     Pr(>|t|) 
 +(Intercept)  8.266323  12.545898 0.6588865 5.107342e-01 
 +x           11.888159   2.432647 4.8869234 2.110569e-06 
 +>  
 +> mses <- data.frame(mses) 
 +> mses.log <- data.table(epoch = 1:nlen, mses) 
 +> ggplot(mses.log, aes(epoch, mses)) + 
 ++   geom_line(color="blue") + 
 ++   theme_classic() 
 +>  
 +> # mres <- data.frame(mres) 
 +> mres.log <- data.table(epoch = 1:nlen, mres) 
 +> ggplot(mres.log, aes(epoch, mres)) + 
 ++   geom_line(color="red") + 
 ++   theme_classic() 
 +>  
 +> ch <- data.frame(mres, mses) 
 +> ch 
 +           mres      mses 
 +1  60.026423686 12376.887 
 +2  48.021138949 10718.824 
 +3  38.416911159  9657.086 
 +4  30.733528927  8977.203 
 +5  24.586823142  8541.840 
 +6  19.669458513  8263.055 
 +7  15.735566811  8084.535 
 +8  12.588453449  7970.219 
 +9  10.070762759  7897.017 
 +10  8.056610207  7850.141 
 +11  6.445288166  7820.125 
 +12  5.156230533  7800.903 
 +13  4.124984426  7788.595 
 +14  3.299987541  7780.713 
 +15  2.639990033  7775.666 
 +16  2.111992026  7772.434 
 +17  1.689593621  7770.364 
 +18  1.351674897  7769.039 
 +19  1.081339917  7768.190 
 +20  0.865071934  7767.646 
 +21  0.692057547  7767.298 
 +22  0.553646038  7767.076 
 +23  0.442916830  7766.933 
 +24  0.354333464  7766.841 
 +25  0.283466771  7766.783 
 +26  0.226773417  7766.745 
 +27  0.181418734  7766.721 
 +28  0.145134987  7766.706 
 +29  0.116107990  7766.696 
 +30  0.092886392  7766.690 
 +31  0.074309113  7766.686 
 +32  0.059447291  7766.683 
 +33  0.047557833  7766.682 
 +34  0.038046266  7766.681 
 +35  0.030437013  7766.680 
 +36  0.024349610  7766.680 
 +37  0.019479688  7766.679 
 +38  0.015583751  7766.679 
 +39  0.012467000  7766.679 
 +40  0.009973600  7766.679 
 +41  0.007978880  7766.679 
 +42  0.006383104  7766.679 
 +43  0.005106483  7766.679 
 +44  0.004085187  7766.679 
 +45  0.003268149  7766.679 
 +46  0.002614519  7766.679 
 +47  0.002091616  7766.679 
 +48  0.001673292  7766.679 
 +49  0.001338634  7766.679 
 +50  0.001070907  7766.679 
 +> max(y) 
 +[1] 383.1671 
 +> ggplot(data, aes(x = x, y = y)) +  
 ++   geom_point(size = 2) +  
 ++   geom_abline(aes(intercept = as, slope = bs), 
 ++               data = parameters, linewidth = 0.5,  
 ++               color = 'green') +  
 ++   stat_poly_line() + 
 ++   stat_poly_eq(use_label(c("eq", "R2"))) + 
 ++   theme_classic() + 
 ++   geom_abline(aes(intercept = as, slope = bs),  
 ++               data = parameters %>% slice_head(),  
 ++               linewidth = 1, color = 'blue') +  
 ++   geom_abline(aes(intercept = as, slope = bs),  
 ++               data = parameters %>% slice_tail(),  
 ++               linewidth = 1, color = 'red') + 
 ++   labs(title = 'Gradient descent. blue: start, red: end, green: gradients'
 +> summary(lm(y~x)) 
 + 
 +Call: 
 +lm(formula = y ~ x) 
 + 
 +Residuals: 
 +     Min       1Q   Median       3Q      Max  
 +-259.314  -59.215    6.683   58.834  309.833  
 + 
 +Coefficients: 
 +            Estimate Std. Error t value Pr(>|t|)     
 +(Intercept)    8.266     12.546   0.659    0.511     
 +x             11.888      2.433   4.887 2.11e-06 *** 
 +--- 
 +Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1 
 + 
 +Residual standard error: 88.57 on 198 degrees of freedom 
 +Multiple R-squared:  0.1076, Adjusted R-squared:  0.1031  
 +F-statistic: 23.88 on 1 and 198 DF,  p-value: 2.111e-06 
 + 
 +> a.start 
 +[1] 1.364582 
 +> b.start 
 +[1] -1.12968 
 +> a 
 +[1] 8.266303 
 +> b 
 +[1] 11.88797 
 +>  
 +</code> 
 +{{:pasted:20250821-121910.png}} 
 +{{:pasted:20250821-121924.png}} 
 +{{:pasted:20250821-121943.png}} 
 + 
 +</tabbox>
  
  
Line 772: Line 1056:
 b & = & \frac{m}{\sigma} \\ b & = & \frac{m}{\sigma} \\
 \end{eqnarray*} \end{eqnarray*}
- 
- 
- 
- 
-<tabbed> 
-  * :gradient descent:code01 
-  * :gradient descent:code02 
-  * *:gradient descent:output01 
-  * :gradient descent:output02 
-</tabbed> 
  
  
gradient_descent.1773469954.txt.gz · Last modified: by hkimscil

Donate Powered by PHP Valid HTML5 Valid CSS Driven by DokuWiki