1. Machine() 정의
import numpy as np
def Machine(x, w, b):
    """Evaluate the linear model ``w * x + b``.

    The expression is applied element-wise, so ``x`` may be a scalar,
    a NumPy array, or a pandas column.

    Args:
        x: Input value(s).
        w: Weight (slope).
        b: Bias (intercept).

    Returns:
        The prediction(s) ``y_hat``.
    """
    return w * x + b
# Sample inputs plus the weight w and bias b to feed into Machine()
x = np.array([1, 3, 5, 7, 9])
w, b = 2, 1
Machine(x, w, b)
array([ 3, 7, 11, 15, 19])
2. Gradient() 정의
def Gradient(x, y, w, b):
    """Return the gradient of the mean-squared-error loss of Machine().

    With y_hat = Machine(x, w, b) and loss = mean((y - y_hat)**2):
        d loss / dw = mean((y - y_hat) * (-2 * x))
        d loss / db = mean((y - y_hat) * (-2))

    Args:
        x: Inputs.
        y: Targets, aligned with ``x``.
        w: Current weight.
        b: Current bias.

    Returns:
        Tuple ``(dw, db)``.
    """
    # The residual is shared by both partial derivatives
    residual = y - Machine(x, w, b)
    dw = np.mean(residual * (-2 * x))
    db = np.mean(residual * (-2))
    return dw, db
# Targets for the sample inputs; evaluate the gradient at the current w, b
y = np.array([2, 4, 6, 8, 10])
dw, db = Gradient(x, y, w, b)
for label, value in (('dw is ', dw), ('db is ', db)):
    print(label, value)
dw is 66.0
db is 10.0
3. Learning() 정의
def Learning(x, y, w, b, step):
    """Apply one gradient-descent update to (w, b).

    Args:
        x: Inputs.
        y: Targets.
        w: Current weight.
        b: Current bias.
        step: Learning rate that scales the gradient.

    Returns:
        Tuple ``(uw, ub)`` with the updated weight and bias.
    """
    dw, db = Gradient(x, y, w, b)
    # Move each parameter against its gradient, scaled by the learning rate
    return w - step * dw, b - step * db
# Take one update step with learning rate 0.05 and show the new parameters
step = 0.05
uw, ub = Learning(x, y, w, b, step)
print('Updated_w is ', f'{uw:.3f}')
print('Updated_b is ', f'{ub:.3f}')
Updated_w is -1.300
Updated_b is 0.500
4. testData.csv에 적용
import pandas as pd
import matplotlib.pyplot as plt
# Load the dataset. read_csv must actually be called; assigning the bare
# function object would make DATA.info() / DATA.inputs fail.
# NOTE(review): the file name is taken from the section title; adjust the
# path if testData.csv lives elsewhere.
DATA = pd.read_csv('testData.csv')
DATA.info()  # column dtypes and non-null counts
DATA.head()  # first rows (displayed when run in a notebook)
# Scatter plot of the raw data; the code below expects 'inputs' and
# 'outputs' columns
plt.scatter(DATA.inputs, DATA.outputs, s=0.5)
plt.show()
# Start from w=2, b=3 and run 1500 gradient-descent updates, each using the
# whole inputs/outputs columns
w, b = 2, 3
step = 0.05
for i in range(1500):
    uw, ub = Learning(DATA.inputs, DATA.outputs, w, b, step)
    w, b = uw, ub
print('Learned_w is ', f'{w:.3f}')
print('Learned_b is ', f'{b:.3f}')
Learned_w is 0.505
Learned_b is -0.170
# Draw the learned regression line over the data.
# Span the line across the observed input range rather than a hard-coded
# [0, 1], which silently assumed the inputs were normalized.
X = np.linspace(DATA.inputs.min(), DATA.inputs.max(), 100)
Y = (w * X) + b
plt.scatter(DATA.inputs, DATA.outputs, s=0.3)
plt.plot(X, Y, '-r', linewidth=1.5)
plt.show()
5. Loss Visualization
# Gradient() redefined so that it also returns the loss
def Gradient(x, y, w, b):
    """Return the MSE gradient of Machine() together with the MSE itself.

    Args:
        x: Inputs.
        y: Targets, aligned with ``x``.
        w: Current weight.
        b: Current bias.

    Returns:
        Tuple ``(dw, db, Loss)``:
        - ``dw`` = mean((y - y_hat) * (-2 * x))
        - ``db`` = mean((y - y_hat) * (-2))
        - ``Loss`` = mean((y - y_hat)**2), evaluated at the given w and b
    """
    # The residual is shared by both gradients and the loss
    residual = y - Machine(x, w, b)
    dw = np.mean(residual * (-2 * x))
    db = np.mean(residual * (-2))
    Loss = np.mean(residual**2)
    return dw, db, Loss
# Learning() redefined so that it also passes the loss through
def Learning(x, y, w, b, step):
    """Apply one gradient-descent update and report the loss.

    Args:
        x: Inputs.
        y: Targets.
        w: Current weight.
        b: Current bias.
        step: Learning rate that scales the gradient.

    Returns:
        Tuple ``(uw, ub, Loss)``: the updated weight and bias, plus the loss
        that Gradient() computed at the *incoming* (pre-update) w and b.
    """
    dw, db, Loss = Gradient(x, y, w, b)
    uw = w - step * dw
    ub = b - step * db
    return uw, ub, Loss
# Retrain from w=2, b=3 with a smaller learning rate, recording the loss at
# every step so its decrease can be plotted.
w = 2
b = 3
step = 0.001
epochs = 1500
Error = []
for i in range(epochs):
    uw, ub, Loss = Learning(DATA.inputs, DATA.outputs, w, b, step)
    w = uw
    b = ub
    Error.append(Loss)
Error[0:10]  # first losses (displayed in a notebook) - check they decrease

# Full loss curve
plt.plot(Error)
plt.show()

# First 50 steps
plt.plot(Error[:50], '.')
plt.show()

# Last 50 steps, plotted against their real step indices. The slice is
# derived from `epochs` so it stays correct if the step count changes.
tail_start = max(epochs - 50, 0)
plt.plot(range(tail_start, epochs), Error[tail_start:], '.')
plt.show()
회귀분석(Regression Analysis) 4 (0) | 2022.06.07 |
---|---|
회귀분석(Regression Analysis) 3 (1) | 2022.06.07 |
회귀분석(Regression Analysis) 2 (0) | 2022.06.07 |
회귀분석(Regression Analysis) 1 (1) | 2022.06.07 |
Model Validation (0) | 2022.06.06 |