학교/인공지능

결정 트리 코드

서윤-정 2024. 4. 19. 11:19

 

* Iris 데이터를 이용하여 의사결정트리(Decision Tree)를 구현하고

새로운 데이터를 예측한 뒤, 정확도를 평가하는 코드이다.

 

 

from sklearn import datasets
from sklearn import metrics
from sklearn.tree import DecisionTreeClassifier


d = datasets.load_iris()  # 데이터를 읽은 후에,   # 1 데이터 준비
# print(d.DESCR)    # 내용 출력


#for i in range(0, len(d.data)):   # 샘플 순서대로 출력
#  print (i+1, d.data[i], d.target[i])


dtree_model = DecisionTreeClassifier(max_depth = 4, min_samples_split = 3)  # 2 모델 생성
dtree_model.fit(d.data, d.target)   # 학습 --> fit                          # 3 학습

new_d = [[6.4, 3.2, 6.0, 2.5], [7.1, 3.1, 4.7, 1.35]]
# 101, 51번째 데이터를 변형해서 새로운 데이터 생성
res1 = dtree_model.predict(new_d)   # L:14에서 만든 데이터를 이용해서 예측(predict)


print('새로운 2개 샘플 부류는: ', res1)

res2 = dtree_model.predict(d.data)    # 4 예측
# res2와 원래 정답 d.target과의 비교
print('정확도: ', metrics.accuracy_score(res2, d.target))

 

 

 

 

 

 

 

 

 

1. 데이터를 불러온다.

d = datasets.load_iris()  # Iris 데이터를 불러온다.

 

: datasets.load_iris() 를 사용하여 Iris 데이터셋을 불러온다.

이 데이터셋은 Iris 꽃의 꽃잎과 꽃받침의 길이와 너비를 포함한 특성 데이터와

세 종류의 Iris 꽃(붓꽃)을 나타내는 클래스 레이블로 구성되어 있다.

 

 

 

 

2. 의사결정트리 모델을 생성한다.

dtree_model = DecisionTreeClassifier(max_depth = 4, min_samples_split = 3)  # 2 의사결정트리 모델을 생성한다.

 

: DecisionTreeClassifier 를 사용하여 의사결정 모델을 생성한다. 

이때 max_depth 는 트리의 최대 깊이를, 

min_samples_split 은 노드를 분할하기 위한 최소 샘플 수를 지정한다.

 

 

 

 

3. 모델을 학습시킨다.

dtree_model.fit(d.data, d.target)   # Iris 데이터로 의사결정트리 모델을 학습시킨다.

 

: fit() 메서드를 사용하여 의사결정트리 모델을 Iris 데이터로 학습시킨다.

특성 데이터(d.data)와 클래스 레이블(d.target)을 사용하여 모델을 학습시킨다.

 

 

 

 

4. 새로운 데이터로 예측을 수행한다.

new_d = [[6.4, 3.2, 6.0, 2.5], [7.1, 3.1, 4.7, 1.35]]    # 새로운 데이터를 정의한다. 

res1 = dtree_model.predict(new_d)   # 학습된 의사결정트리 모델을 사용하여 새로운 데이터의 클래스를 예측한다.

print('새로운 2개 샘플 부류는: ', res1)    # 예측 결과를 출력한다.

 

: 새로운 데이터 new_d 를 정의하고,

predict() 메서드를 사용하여 의사결정트리 모델을 사용하여 새로운 데이터의 클래스를 예측한다.

 

 

 

 

5. 학습된 모델로 전체 데이터를 예측하고 정확도를 평가한다.

res2 = dtree_model.predict(d.data)    # 학습된 의사결정트리 모델로 전체 데이터의 클래스를 예측한다.

print('정확도: ', metrics.accuracy_score(res2, d.target))    # 예측 결과와 실제 타겟을 비교하여 정확도를 계산하고 출력한다. 

 

: predict() 메서드를 사용하여 의사결정트리 모델을 사용하여 전체 데이터의 클래스를 예측한다.

metrics.accuracy_score() 를 사용하여 예측 결과와 실제 타겟을 비교하여 정확도를 계산한다.

'학교 > 인공지능' 카테고리의 다른 글

4_인공지능개론  (0) 2024.06.06
3_인공지능개론  (0) 2024.04.21
2_인공지능개론  (1) 2024.04.20
SVM 코드  (0) 2024.04.19
레모네이드 실습 코드  (0) 2024.04.19