Machine Learning/Preprocessing

train_test_split(), fit(), predict()

진성01 2022. 5. 8. 19:03

 

사이킷런(scikit-learn): 파이썬 머신러닝 라이브러리 중 가장 많이 사용되는 라이브러리. 특히 데이터 분석에서 많이 사용된다.

 

머신러닝을 통한 예측 프로세스는 다음과 같다.

  1. 데이터 세트 분리
  2. 모델 학습
  3. 예측 수행
  4. 평가

이번 글에서는 다음 프로세스를 진행하기 위해 사이킷런(sklearn)에서 사용되는 메서드들에 대해 리뷰한다.

 

train_test_split() 메서드

학습을 진행하기 위해서는 학습에 이용할 데이터와 검증에 이용할 데이터를 분리해 주어야 한다. 사이킷런에서는 train_test_split()을 이용하여 작업을 간편하게 할 수 있다.

변수이름을 보면 X,y,train,test들로 구성되어 있다. 

 

-X는 사용할 변수들을 일컫는다. 붓꽃 품종을 예측하기 위한 모델을 만든다고 하면, 예측에 사용될 변수들인 잎의 길이, 꽃의 색깔 등이 여기에 해당된다.

-y는 예측할 값을 일컫는다. 붓꽃 품종 예측 모델을 예로 든다면, 붓꽃의 품종이 여기에 해당된다.

즉 예측에 이용할 변수들을 X, 예측할 변수를 y라고 한다(label, target이라고도 부른다).

 

-train은 학습에 이용할 데이터셋이다. 머신러닝 모델을 학습시키는데 이용한다. 수능 시험을 예시로 든다면 모의고사에 해당한다. 

-test는 검증에 이용할 데이터셋이다. train 데이터셋을 통해 모델을 학습시켰다면, test데이터셋을 이용하여 모델은 자신의 예측 정확도를 검증하게 된다. train이 모의고사를 통해서 학습하는 것이라면, test는 본 수능에 해당한다.

 

※train 데이터셋을 이용하여 파라미터를 업데이트(즉, 학습)하지만 test 데이터셋을 이용해서는 정확도를 검증만 할 뿐, 파라미터를 업데이트 하지는 않는다.

  • X_train: 학습에 사용될 변수들(DataFrame)
  • X_test: 검증에 사용될 변수들(DataFrame)
  • y_train: 학습에 사용될 label(Series)
  • y_test: 검증에 사용될 label(Series)

train_test_split()는 두 개의 파라미터를 필수적으로 입력받는다.

  1. 변수들로 구성된 DataFrame
  2. label로 구성된 Series

그리고 선택적으로 다음과 같은 파라미터를 입력받는다.

  • test_size: 전체 데이터 중 테스트 데이터로 샘플링할 크기(default: 0.25)
  • random_state: 분리하기 위하여 지정되는 난수값. 만약 같은 데이터 셋이고 random_state가 같다면 데이터가 완전히 똑같이 분리된다.

 

fit() 메서드

사이킷런에서는 ML모델 학습을 위해 fit()메서드를 제공하고, 학습된 모델의 예측을 위해 predict()메서드를 제공한다. fit()메서드는 사용할 분류 혹은 회귀 모델과, 학습에 이용할 데이터셋을 입력하면 학습을 수행한다.

X_train, X_test, y_train, y_test를 train_test_split()을 이용하여 할당해주고, Classifier로 DecisionTreeClassfier를 할당하였다(후술).

 

(Classifier).fit(X_train, y_train)

다음과 같은 형식으로 작성하면 해당 Classifier의 학습이 진행된다.

 

※후에 나올 정규화 에서도fit()이 사용된다. 모델 학습에서는 fit()이후 predict()를 사용하는 것이 일반적이나 전처리에서는 transform()을 사용하므로 주의가 필요하다.

 

predict()메서드

X_train을 이용하여 학습을 진행했으므로 예측하는 함수인 predict()를 사용할 경우 X_test를 이용해야 한다. 다음 라인을 진행하면 pred에는 X_test에 대응하는 y_test값을 예측하여 Series형태로 저장하게 된다. 이는 이후에 accuracy_score()와 같은 함수에서 정확도를 계산하는데 사용된다.

 

 

#이 글은 권철민 작가님의 [파이썬 머신러닝 완벽가이드]를 통해 공부한 내용을 바탕으로 정리하였습니다.