데이터 분류 및 클러스터링 시각화

Author

Gabriel Yang


데이터 과학에서 데이터 분류와 클러스터링은 중요한 분석 기법입니다. 이러한 기법들은 데이터를 이해하고 패턴을 발견하는 데 도움을 줍니다. Plotly는 대화형 시각화를 통해 데이터 분류와 클러스터링의 결과를 시각적으로 명확히 전달할 수 있는 강력한 도구입니다. 이 글에서는 Plotly를 사용하여 데이터 분류와 클러스터링을 시각화하는 방법을 단계별로 설명하겠습니다.

데이터 분류는 주어진 데이터가 어떤 클래스에 속하는지를 예측하는 작업입니다. 분류 결과를 시각화하면 각 클래스의 데이터 포인트를 쉽게 식별할 수 있습니다. Plotly를 사용하여 데이터 분류를 시각화하는 방법을 살펴보겠습니다.

1. 데이터 준비

먼저, 예제 데이터를 준비합니다. Iris 데이터셋을 사용하여 품종에 따른 분류를 시각화하겠습니다.

import plotly.express as px
import pandas as pd
from sklearn.datasets import load_iris
import plotly.graph_objs as go

fig = go.Figure()

# Iris 데이터셋 로드
iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['species'] = iris.target_names[iris.target]

# 데이터 확인
df.head()
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa

2. 2D 산점도로 분류 시각화

두 개의 주요 특징을 선택하여 데이터 포인트를 2D 산점도로 시각화합니다.

fig = go.Figure()
# 산점도 생성
fig = px.scatter(df, x='sepal length (cm)', y='sepal width (cm)',
                 color='species', symbol='species',
                 title='Iris 데이터셋의 품종에 따른 분류')

# 그래프 표시
fig.show()

이 코드에서는 sepal lengthsepal width를 x축과 y축으로 설정하여 각 품종을 다른 색상과 기호로 표시합니다.

3. 3D 분류 시각화

세 개의 특징을 사용하여 3D 산점도로 시각화할 수 있습니다.

fig = go.Figure()
fig = px.scatter_3d(df, x='sepal length (cm)', y='sepal width (cm)',
                    z='petal length (cm)', color='species',
                    symbol='species',
                    title='Iris 데이터셋의 3D 분류')

# 그래프 표시
fig.show()

여기서는 sepal length, sepal width, petal length를 사용하여 3D 산점도를 생성합니다.

클러스터링은 데이터 포인트를 유사한 그룹으로 나누는 작업입니다. 데이터 클러스터링 결과를 시각화하면 각 클러스터의 분포와 경계를 쉽게 이해할 수 있습니다. 예를 들어, K-평균 클러스터링을 사용하여 클러스터링 결과를 시각화하는 방법을 설명합니다.

1. 데이터 준비 및 클러스터링

K-평균 클러스터링을 사용하여 데이터셋을 클러스터링합니다.

from sklearn.cluster import KMeans
import pandas as pd

# 클러스터링 수행
kmeans = KMeans(n_clusters=3, random_state=42, n_init=10)
df['cluster'] = kmeans.fit_predict(df[['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)']])

# 클러스터 중심
centers = kmeans.cluster_centers_
display(centers)
array([[6.83571429, 3.06428571, 5.6547619 ],
       [5.006     , 3.428     , 1.462     ],
       [5.84655172, 2.73275862, 4.3637931 ]])

코드 설명

from sklearn.cluster import KMeans

  • KMeansscikit-learn 라이브러리에서 제공하는 클러스터링 알고리즘입니다. KMeans는 주어진 데이터 포인트를 K개의 클러스터로 나누는 알고리즘입니다.

kmeans = KMeans(n_clusters=3, random_state=42, n_init=10)

  • KMeans 객체를 생성합니다.
  • n_clusters=3: 데이터 포인트를 3개의 클러스터로 나누겠다는 의미입니다.
  • random_state=42: 결과의 재현성을 위해 랜덤 시드를 설정합니다. 같은 데이터를 넣었을 때 같은 결과를 얻을 수 있도록 합니다.
  • n_init=10: 알고리즘을 여러 번 실행하여 가장 좋은 결과를 선택합니다. 기본값은 10입니다.

df['cluster'] = kmeans.fit_predict(df[['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)']])

  • fit_predict 메서드를 사용하여 클러스터링을 수행하고 각 데이터 포인트에 클러스터 레이블을 할당합니다.
  • df[['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)']]: 클러스터링에 사용할 데이터셋의 특정 열을 선택합니다. 여기서는 sepal length (cm), sepal width (cm), petal length (cm) 열을 사용합니다.
  • 이 메서드는 클러스터 레이블을 반환하며, 반환된 레이블을 df['cluster']에 저장합니다. 즉, 원래 데이터프레임 df에 새로운 열 cluster가 추가되어 각 데이터 포인트가 어떤 클러스터에 속하는지를 나타냅니다.

centers = kmeans.cluster_centers_

  • kmeans.cluster_centers_는 각 클러스터의 중심점을 반환합니다. 각 중심점은 클러스터에 속하는 데이터 포인트들의 평균 좌표입니다. 이 정보를 centers 변수에 저장합니다.

2. 클러스터링 결과 시각화

클러스터링 결과를 2D 산점도로 시각화합니다.

fig = go.Figure()

# 클러스터링 결과 시각화
fig = px.scatter(df, x='sepal length (cm)', y='sepal width (cm)',
                 color='cluster', hover_data='species',
                 title='Iris 데이터셋의 클러스터링 결과')

# 클러스터 중심 표시
fig.add_trace(go.Scatter(
                x=centers[:, 0], y=centers[:, 1],
                mode='markers',
                marker=dict(color='red', size=10))
)
fig.update_layout(legend=dict(visible=False))

# 그래프 표시
fig.show()

이 코드에서는 클러스터링된 데이터 포인트를 색상으로 구분하고, 클러스터 중심을 빨간색으로 표시합니다.

3. 3D 클러스터링 시각화

세 개의 특징을 사용하여 3D 클러스터링 결과를 시각화할 수 있습니다.

fig = go.Figure()
fig = px.scatter_3d(df, x='sepal length (cm)', y='sepal width (cm)',
                    z='petal length (cm)', color='cluster',
                    symbol='cluster',hover_data = 'species',
                    title='Iris 데이터셋의 3D 클러스터링 결과')

# 클러스터 중심 추가
fig.add_trace(
    go.Scatter3d(
        x=centers[:, 0],
        y=centers[:, 1],
        z=centers[:, 2],
        mode='markers',
        marker=dict(color='red', size=10),
        name='Cluster Centers')
)
fig.update_layout(legend=dict(visible=False))

# 그래프 표시
fig.show()

여기서는 sepal length, sepal width, petal length를 사용하여 3D 클러스터링 결과를 시각화합니다.

결론

Plotly는 데이터 분류 및 클러스터링 결과를 직관적으로 시각화할 수 있는 강력한 도구입니다. 2D 및 3D 산점도를 사용하여 데이터의 분포와 클러스터를 쉽게 시각화할 수 있으며, 이를 통해 데이터 분석과 인사이트를 보다 효과적으로 얻을 수 있습니다.