본문 바로가기

머신러닝

머신러닝 4일차

728x90
결정트리

데이터에 있는 규칙을 학습을 통해 자동으로 찾아내 트리기반 분규규칙 만듬

규칙 많으면 -> 분류 결정 방식 복잡 -> 과적합 으로 이어짐

깊이가 깊어질 수록 예측성능 저하됨

적은 결정 노드 가지려면 데이터 분류시 생성되는 결정노드 규칙 정해줘야함

균일한 데이터 세트 구성하도록 분할! 

균일도 측정 방법 : 엔트로피 이용한 정보이득 지수 : 1- 엔트로피지수 

                                정보이득이 높은 속성 기준으로 분할

 

                               지니 계수 : 지니계수 낮을 수록 데이터 균일도 높음 , 낮은 속성 기준으로 분할

 

결정트리 파라미터

min_samples_split  : 노드 분할 위한 최소한의 샘플 데이터 수로 과적합제어

min_samplts_leaf : 리프노드가 되기 위한 최소한 샘플데이터 수

max_features  : 최대 피처개수 디폴트 는 None

max_depth : 트리 최대 깊이 

max_leaf_nodes : 리프노드 최대 갯수 

 

 

결정트리 시각화를 위해  graphviz 설치

##파일생성
export_graphviz(dt_clf,
                out_file='tree.dot',
                class_names=iris.target_names,
                feature_names=iris.feature_names,
                filled=True)

출력파일 tree.dot 파일을 graphviz 가 읽어서 시각화 

with open('tree.dot') as f:
    dot_graph = f.read()
graphviz.Source(dot_graph)
# Classifier의 Decision Boundary를 시각화 하는 함수
def visualize_boundary(model, X, y):
    fig,ax = plt.subplots()

    # 학습 데이타 scatter plot으로 나타내기
    ax.scatter(X[:, 0], X[:, 1], c=y, s=25, cmap='rainbow', edgecolor='k',
               clim=(y.min(), y.max()), zorder=3)
    ax.axis('tight')
    ax.axis('off')
    xlim_start , xlim_end = ax.get_xlim()
    ylim_start , ylim_end = ax.get_ylim()

    # 호출 파라미터로 들어온 training 데이타로 model 학습 . 
    model.fit(X, y)
    # meshgrid 형태인 모든 좌표값으로 예측 수행. 
    xx, yy = np.meshgrid(np.linspace(xlim_start,xlim_end, num=200),np.linspace(ylim_start,ylim_end, num=200))
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    # contourf() 를 이용하여 class boundary 를 visualization 수행. 
    n_classes = len(np.unique(y))
    contours = ax.contourf(xx, yy, Z, alpha=0.3,
                           levels=np.arange(n_classes + 1) - 0.5,
                           cmap='rainbow', clim=(y.min(), y.max()),
                           zorder=1)
#피처 중요도 가져오기 
dt_clf.feature_importances_
#피처별로 중요도값 매핑
for name, value in zip(iris.feature_names,dt_clf.feature_importances_):
    print(name,value)
##피처 중요도 컬럼별로 시각화 
sns.barplot(x=dt_clf.feature_importances_,y=iris.feature_names)
728x90

'머신러닝' 카테고리의 다른 글

머신러닝 6일차  (0) 2023.05.15
머신러닝 5일차  (1) 2023.05.15
머신러닝 3일차  (0) 2023.05.10
머신러닝 2일차  (0) 2023.05.09
ML 1일차  (2) 2023.05.08