====== gradient_descent ======
(revision 2026/03/14 11:21, by hkimscil)
library(ggplot2)
library(ggpmisc)

rm(list=ls())

# create the data
set.seed(191)
nx <- 200
mx <- 4.5
slp <- 12
y <- slp * x + rnorm(nx, 0, slp*sdx*3)
data <- data.frame(x, y)
# data variables complete

# regression summary shows 
# a and b
mo <- lm(y ~ x, data = data)
summary(mo)

ggplot(data = data, aes(x = x, y = y)) + 
  geom_point() + 
  stat_poly_line() + 
  stat_poly_eq(use_label(c("eq", "R2"))) + 
  theme_classic()
# fix b at the value estimated above and vary only a, to build intuition
b <- summary(mo)$coefficients[2]
a <- 0
a.init <- a

# predict function: returns the y-hat values
predict <- function(x, a, b){
  return (a + b * x)
}

# the loss function: residual (error) values
residuals <- function(predictions, y) {
  return(y - predictions)
}

# we use the sum of squared errors (ssr)
ssrloss <- function(predictions, y) {
  residuals <- (y - predictions)
  return(sum(residuals^2))
}

ssrs <- c() # for ssr values
srs <- c()  # for sums of residuals
as <- c()   # for as (intercepts)
  
# sweep candidate a values from -50 to 50 in steps of 0.01,
# feeding each into the for loop as i
for (i in seq(from = -50, to = 50, by = 0.01)) {
  pred <- predict(x, i, b)
  res <- residuals(pred, y)
  ssr <- ssrloss(pred, y)
  ssrs <- append(ssrs, ssr)
  srs <- append(srs, sum(res))
  as <- append(as, i) # record i in as, since i was used as a
}
# each unit holds 100 steps of 0.01 and -50..50 spans 100 units,
# plus the endpoint: 100 * 100 + 1 = 10001
length(ssrs)
length(srs)
length(as)
  
min(ssrs) # the minimum among the sum-of-squared-error values
min.pos.ssrs <- which(ssrs == min(ssrs)) # find its position
min.pos.ssrs
print(as[min.pos.ssrs]) # the a value at that position
# this a yields the minimum ssr (to two decimal places)
summary(mo)

k <- min(ssrs)
j <- as[min.pos.ssrs]
plot(seq(1, length(as)), ssrs, type="l")
text(4500, 2000000, paste("ssrs =", k, "\n", "is the minimum value when a =", j))
tail(ssrs)
max(ssrs)
<tabbox ro01>
<code>
> # library(tidyverse)
> # library(data.table)
> library(ggplot2)
> library(ggpmisc)

> rm(list=ls())
> 
> # create the data
> set.seed(191)
> nx <- 200
> mx <- 4.5
> slp <- 12
> y <-  slp * x + rnorm(nx, 0, slp*sdx*3)
> data <- data.frame(x, y)
> # data variables complete

> # regression summary shows 
> # a and b
> mo <- lm(y ~ x, data = data)
> summary(mo)

Call:
lm(formula = y ~ x, data = data)

Residuals:
     Min       1Q   Median       3Q      Max 
-245.291  -67.967   -3.722   63.440  242.174 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   -1.047     14.289  -0.073    0.942    
x             12.900      2.870   4.495 1.19e-05 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 93.93 on 198 degrees of freedom
Multiple R-squared:  0.09258, Adjusted R-squared:  0.088 
F-statistic:  20.2 on 1 and 198 DF,  p-value: 1.185e-05


> ggplot(data = data, aes(x = x, y = y)) + 
    geom_point() + 
    stat_poly_line() + 
    stat_poly_eq(use_label(c("eq", "R2"))) + 
    theme_classic() 
> 
> # fix b at the value estimated above and vary only a
> b <- summary(mo)$coefficients[2]
> a <- 0
> a.init <- a

> # predict function: returns the y-hat values
> predict <- function(x, a, b){
    return (a + b * x)
+ }

> # the loss function: residual (error) values
> residuals <- function(predictions, y) {
    return(y - predictions)
+ }

> # we use the sum of squared errors (ssr)
> ssrloss <- function(predictions, y) {
    residuals <- (y - predictions)
    return(sum(residuals^2))
+ }

> ssrs <- c() # for ssr values
> srs <- c()  # for sums of residuals
> as <- c() # for as (intercepts)
  
> # sweep candidate a values from -50 to 50 in steps of 0.01,
> # feeding each into the for loop as i
> for (i in seq(from = -50, to = 50, by = 0.01)) {
    pred <- predict(x, i, b)
    res <- residuals(pred, y)
    ssr <- ssrloss(pred, y)
    ssrs <- append(ssrs, ssr)
    srs <- append(srs, sum(res))
    as <- append(as, i) # record i in as, since i was used as a
+ }
> # each unit holds 100 steps of 0.01 and -50..50 spans 100 units,
> # plus the endpoint: 100 * 100 + 1 = 10001
> length(ssrs)
[1] 10001
> length(srs)
[1] 10001
> length(as)
[1] 10001
  
> min(ssrs) # the minimum among the sum-of-squared-error values
[1] 1747011
> min.pos.ssrs <- which(ssrs == min(ssrs)) # find its position
> min.pos.ssrs
[1] 4896
> print(as[min.pos.ssrs]) # the a value at that position
[1] -1.05
> # this a yields the minimum ssr (to two decimal places)
> summary(mo)

Call:
lm(formula = y ~ x, data = data)

Residuals:
     Min       1Q   Median       3Q      Max 
-245.291  -67.967   -3.722   63.440  242.174 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   -1.047     14.289  -0.073    0.942    
x             12.900      2.870   4.495 1.19e-05 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 93.93 on 198 degrees of freedom
Multiple R-squared:  0.09258, Adjusted R-squared:  0.088 
F-statistic:  20.2 on 1 and 198 DF,  p-value: 1.185e-05
  
> 
> k <- min(ssrs)
> j <- as[min.pos.ssrs]
> plot(seq(1, length(as)), ssrs, type="l")
> text(4500, 2000000, paste("ssrs =", k, "\n", "is the minimum value when a =", j))
> 
> tail(ssrs)
[1] 2267151 2267355 2267559 2267763 2267967 2268171
> max(ssrs)
[1] 2268171
> min(ssrs)
[1] 1747011
> tail(srs)
[1] -10199.41 -10201.41 -10203.41 -10205.41 -10207.41 -10209.41
> max(srs)
[1] 9790.59
> min(srs)
[1] -10209.41

</code>
</tabbox>
{{:pasted:20260314-104647.png}}
  
The method above is dumb (a brute-force search) . . . . . 
> # we use the mean of the squared errors, since the sum often becomes large
> msrloss <- function(predictions, y) {
    residuals <- (y - predictions)
    return(mean(residuals^2))
+ }
  

> for (i in seq(from = -50, to = 50, by = 0.01)) {
    pred <- predict(x, i, b)
    res <- residuals(pred, y)
    msr <- msrloss(pred, y)
    msrs <- append(msrs, msr)
    srs <- append(srs, mean(res))
    as <- append(as, i)
+ }
> length(msrs)
  
> min(msrs)
[1] 8735.055
> min.pos.msrs <- which(msrs == min(msrs))
> min.pos.msrs
[1] 4896
> print(as[min.pos.msrs])
[1] -1.05
> summary(mo)
  

Call:
lm(formula = y ~ x, data = data)

Residuals:
     Min       1Q   Median       3Q      Max 
-245.291  -67.967   -3.722   63.440  242.174 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   -1.047     14.289  -0.073    0.942    
x             12.900      2.870   4.495 1.19e-05 ***
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 93.93 on 198 degrees of freedom
Multiple R-squared:  0.09258, Adjusted R-squared:  0.088 
F-statistic:  20.2 on 1 and 198 DF,  p-value: 1.185e-05
  
> plot(seq(1, length(msrs)), msrs)
> plot(seq(1, length(srs)), srs)
> tail(msrs)
[1] 11335.75 11336.77 11337.79 11338.81 11339.84 11340.86
> max(msrs)
[1] 11340.86
> min(msrs)
[1] 8735.055
> tail(srs)
[1] -50.99705 -51.00705 -51.01705 -51.02705 -51.03705 -51.04705
> max(srs)
[1] 48.95295
> min(srs)
[1] -51.04705

</code>
{{:pasted:20260314-112129.png}}
  
===== Finding the b value =====
Now we can fix the a value and find b in the same way.
<tabbox rs01>
<code>
##############################################
  
</code>

<tabbox ro01>
<code>

</code>
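The same sweep, now holding a fixed and varying b, can be sketched as follows. This is a compact illustration with simulated data; ''a.fixed'' and the sweep range are illustrative choices, not taken from the page's abbreviated listing.

<code>
set.seed(191)
x <- rnorm(200, 4.5, 1.5)
y <- 12 * x + rnorm(200, 0, 30)
mo <- lm(y ~ x)
a.fixed <- summary(mo)$coefficients[1]   # hold a at the lm estimate
bs <- seq(from = -50, to = 50, by = 0.01)
# sum of squared residuals for each candidate slope
ssrs <- sapply(bs, function(i) sum((y - (a.fixed + i * x))^2))
bs[which.min(ssrs)]                      # close to summary(mo)$coefficients[2]
</code>

Since the SSR is quadratic in b, the grid minimizer lands within half a grid step (0.005) of the lm() slope.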
Is there a way to find a and b at the same time? Not with the method above. In general we can derive what a and b must be using differentiation. In R, rather than solving the derivatives analytically, we program an iterative procedure that approaches the solution, obtaining approximate values of a and b. This is called gradient descent.

</tabbox>

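The iterative idea can be sketched in a few lines of R. This is a minimal illustration, not the page's full listing: the simulated data, the learning rate ''lr'', and the iteration count are illustrative choices. Each step moves a and b a small distance against the gradient of the mean squared error.

<code>
# gradient descent for y = a + b*x, minimizing the mean squared error
set.seed(191)
x <- rnorm(200, 4.5, 1)
y <- 12 * x + rnorm(200, 0, 36)
a <- 0
b <- 0
lr <- 0.01                          # learning rate (illustrative)
for (i in 1:10000) {
  err <- (a + b * x) - y            # prediction errors
  a <- a - lr * 2 * mean(err)       # partial derivative of MSE w.r.t. a
  b <- b - lr * 2 * mean(err * x)   # partial derivative of MSE w.r.t. b
}
a
b
coef(lm(y ~ x))                     # compare with the analytic solution
</code>

After enough iterations the pair (a, b) should sit very close to the coefficients that lm() computes analytically.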
  
====== Gradient descent ======
  
  
====== Finding a and b simultaneously with gradient descent ======
<tabbox rs01>
<code>
# the above used no gradient
a
b
</code>
<tabbox ro01>
<code>
  
</code>
{{:pasted:20250821-121924.png}}
{{:pasted:20250821-121943.png}}

</tabbox>

  
====== Why normalize (scale or make z-score) xi ======
b & = & \frac{m}{\sigma} \\
\end{eqnarray*}
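As a quick numerical check of the identity above (simulated data; the names ''m'' and ''b.orig'' are illustrative): regress y on the z-scored x and divide the resulting slope m by sd(x); the result matches the slope b from the regression on the raw x.

<code>
set.seed(191)
x <- rnorm(200, 4.5, 1.5)
y <- 12 * x + rnorm(200, 0, 30)
b.orig <- unname(coef(lm(y ~ x))[2])      # slope b on the raw x
m <- unname(coef(lm(y ~ scale(x)))[2])    # slope m on the normalized x
m / sd(x)                                 # recovers b.orig, i.e. b = m / sigma
</code>

This works because scale(x) divides the centered x by its sample standard deviation, so the slope is multiplied by sd(x) and dividing it back out recovers b exactly.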
  
  