본문 바로가기
Python

머신러닝 - 선형 회귀(Linear regression) 알고리즘 특징과 코드

by 올엠 2024. 6. 8.
반응형

선형 회귀와 다항 회귀는 머신러닝에서 K-최근접 이웃점 회귀과 함께 기본적으로 배우게 되는 머신러닝 모델이자, 성능이 좋은 모델이라고 할 수 있다. 특히  K-최근접 이웃점 회귀과 다르게 미래를 예측할 수 기능을 가지고 있다.

 K-최근접 이웃점 회귀 는 미래의 데이터를 예측하는데에는 사용하기 어렵다. 이유는 학습 데이터의 평균으로 예측을 하기 때문에, 만약 측정되지 않은 값이 들어온다면, 예측이 되지 힘들어진다.

이 부분에 선형 회귀를 이용해서 해결할 수 있는데, 이유는 학습 데이터를 통해 유의미한 연결 선을 생성하여 해당 선을 통해 측정되지 않은 값도 예측이 가능하게 된다.

아래 그림과 같이 파란 점의 학습 데이터가 있다고 하자. 여기에 가장 중심이 되는 선을 하나 긋고, 이를 통해서 학습하지 않은 영역에도 예측을 하게 된다.

코드를 통해서 살펴보자.

아래 코드는 가상의 아파트 가격을 통해 미래 아파트 가격을 예측하는 선형 회귀 모델이라고 할 수 있다.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# 데이터 생성
years = [2000, 2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009]
apt_prices = [320000, 330000, 340000, 350000, 360000, 370000, 380000, 390000, 400000, 410000]  # 아파트 크기를 150으로 가정

선형 회귀에서 사용할 수 있도록 데이터를 변환해주자. 그리고 fit() 함수를 이용해서 학습을 진행한다.

# 데이터 변환
years_reshaped = [[year] for year in years]
apt_prices = np.array(apt_prices)

# 선형 회귀 모델 학습
model = LinearRegression()
model.fit(years_reshaped, apt_prices)

학습된 내용을 통해서 2023년 가격을 예측해보도록 하자.

# 모델 예측
years_predict = np.arange(2022,2023).reshape(-1, 1)
predicted_prices = model.predict(years_predict)
print(predicted_prices)

# 평가
y_pred = model.predict(years_reshaped)
mse = mean_squared_error(apt_prices, y_pred)
print('Mean Squared Error:', mse)

예측된 가격은 540,000 으로 확인된다.

추가로 mean_squared_error은 귀 문제에서 모델의 예측값과 실제 타겟 값 간의 평균 제곱 오차(Mean Squared Error, MSE)를 계산하는 지표 상용할 수 있어서 평가에 유용하기 때문에 활용하면 좋다.

  • MSE=n1​∑i=1n​(yi​−y^​i​)2

n은 데이터 샘플의 개수

yi​는 실제 타겟 값

y^​i​는 모델의 예측 값

 

MSE 값이 0.0이므로 모델의 성능이 좋다고(학습한 결과 대로 예측이 잘 되었다) 할 수 있다.

마지막으로 그래프로 보면, 보다 직관적으로 알 수 있다.

# 그래프 그리기
plt.figure(figsize=(10, 6))
plt.scatter(years, apt_prices, color='blue', label='Actual Prices')
plt.plot(years_predict, predicted_prices, color='red', label='Predicted Prices')
plt.xlabel('Year')
plt.ylabel('Price')
plt.title('Apartment Price vs. Year')
plt.legend()
plt.grid(True)
plt.show()

 



반응형