์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
- ๋์ํ๋ณธ
- pandas
- ํฌ๋กค๋ง
- ADsP
- PCA
- datascience
- dataframe
- ์ค๋ฒ์ํ๋ง
- ๋ ๋ฆฝํ๋ณธ
- ์๋ํด๋ผ์ฐ๋
- Python
- ๋ฐ์ดํฐ๋ถ์
- iloc
- ๋น ๋ฐ์ดํฐ
- opencv
- ๋ฐ์ดํฐ๋ถ์์ค์ ๋ฌธ๊ฐ
- ์ฃผ์ฑ๋ถ๋ถ์
- ๋น ๋ฐ์ดํฐ๋ถ์๊ธฐ์ฌ
- ๊ตฐ์งํ
- t-test
- LDA
- ๋ฐ์ดํฐ๋ถ๊ท ํ
- ๋ฐ์ดํฐ๋ถ์์ ๋ฌธ๊ฐ
- ํ์ด์ฌ
- ํ ์คํธ๋ถ์
- Lambda
- numpy
- ADP
- ์ธ๋์ํ๋ง
- DBSCAN
Data Science LAB
[Python] ํ๊ท ์ด๋ ๋ณธ๋ฌธ
Mean Shift
ํ๊ท ์ด๋(Mean Shift)์ KMeans์ ์ ์ฌํ๊ฒ ์ค์ฌ์ ๊ตฐ์ง์ ์ค์ฌ์ผ๋ก ์ง์์ ์ผ๋ก ์์ง์ด๋ฉด์ ๊ตฐ์งํ๋ฅผ ์ํํ๋ค.
KMeans๋ ์ค์ฌ์ ์์๋ ๋ฐ์ดํฐ์ ํ๊ท ๊ฑฐ๋ฆฌ ์ค์ฌ์ผ๋ก ์ด๋ํ์ง๋ง, ํ๊ท ์ด๋์ ๋ฐ์ดํฐ๊ฐ ๋ชจ์ฌ ์๋ ๋ฐ๋๊ฐ ๊ฐ์ฅ ๋์ ๊ณณ์ผ๋ก ์ด๋์ํจ๋ค.
ํ๊ท ์ด๋ ๊ตฐ์งํ๋ ๋ฐ์ดํฐ์ ๋ถํฌ๋๋ฅผ ์ด์ฉํ์ฌ ๊ตฐ์ง์ ์ค์ฌ์ ์ ์ฐพ๋๋ค. ๊ตฐ์ง ์ค์ฌ์ ์ ๋ฐ์ดํฐ ํฌ์ธํธ๊ฐ ๋ชจ์ฌ ์๋ ๊ณณ์ด๋ผ๋ ์๊ฐ์์ ์ฐฉ์ํ ๊ฒ์ด๋ฉฐ ์ด๋ฅผ ์ํด ํ๋ฅ ๋ฐ๋ ํจ์๋ฅผ ์ด์ฉํ๋ค. ์ผ๋ฐ์ ์ผ๋ก ์ฃผ์ด์ง ๋ชจ๋ธ์ ํ๋ฅ ๋ฐ๋ ํจ์๋ฅผ ์ฐพ๊ธฐ ์ํด KDE(Kernel Density Estimation)๋ฅผ ์ด์ฉํ๋ค. ํน์ ๋ฐ์ดํฐ ๋ฐ๊ฒฝ ๋ด์ ๋ฐ์ดํฐ ๋ถํฌ ํ๋ฅ ๋ฐ๋๊ฐ ๊ฐ์ฅ ๋์ ๊ณณ์ผ๋ก ์ด๋ํ๊ธฐ ์ํด์ ์ฃผ๋ณ ๋ฐ์ดํฐ์์ ๊ฑฐ๋ฆฌ ๊ฐ์ KDE ํจ์์ ์ ๋ ฅ ๊ฐ์ผ๋ก ์ ๋ ฅํ ๋ค ๋ฐํ ๊ฐ์ ํ์ฌ ์์น์์ ์ ๋ฐ์ดํธํ๋ฉด์ ์ด๋ํ๋ค. ์ด๋ฌํ ๋ฐฉ์์ ์ ์ฒด ๋ฐ์ดํฐ์ ๋ฐ๋ณต์ ์ผ๋ก ์ ์ฉํ์ฌ ๋ฐ์ดํฐ ๊ตฐ์ง์ ์ค์ฌ์ ์ ์ฐพ์๋ธ๋ค.
1. ๊ฐ๋ณ ๋ฐ์ดํฐ์ ํน์ ๋ฐ๊ฒฝ ๋ด์ ์ฃผ๋ณ ๋ฐ์ดํฐ๋ฅผ ํฌํจํ ๋ฐ์ดํฐ ๋ถํฌ๋๋ฅผ KDE ๊ธฐ๋ฐ์ MeanShift ์๊ณ ๋ฆฌ์ฆ์ผ๋ก ๊ณ์ฐ
2. KDE๋ก ๊ณ์ฐ๋ ๋ฐ์ดํฐ ๋ถํฌ๋๊ฐ ๋์ ๋ฐฉํฅ์ผ๋ก ๋ฐ์ดํฐ ์ด๋
3. ๋ชจ๋ ๋ฐ์ดํฐ๋ค์ด 1~2๊น์ง ์ํํ๋ฉด์ ๋ฐ์ดํฐ ์ด๋. ๊ฐ๋ณ ๋ฐ์ดํฐ๋ค์ด ๊ตฐ์ง ์ค์ฌ์ ์ผ๋ก ๋ชจ์
4. ์ง์ ๋ ๋ฐ๋ณต ํ์๋งํผ ์ ์ฒด ๋ฐ์ดํฐ์ ๋ํด KDE ๊ธฐ๋ฐ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ด๋์ํค๋ฉด์ ๊ตฐ์งํ ์ํ
5. ๊ฐ๋ณ ๋ฐ์ดํฐ๋ค์ด ๋ชจ์ธ ์ค์ฌ์ ์ ๊ตฐ์ง ์ค์ฌ์ ์ผ๋ก ์ค์
KDE๋ ์ปค๋ ํจ์๋ฅผ ํตํด ์ด๋ค ๋ณ์์ ํ๋ฅ ๋ฐ๋ ํจ์๋ฅผ ์ถ์ ํ๋ ๋ํ์ ๋ฐฉ๋ฒ์ผ๋ก, ๊ด์ธก๋ ๋ฐ์ดํฐ ๊ฐ๊ฐ์ ์ปค๋ ํจ์๋ฅผ ์ ์ฉํ ๊ฐ์ ๋ชจ๋ ๋ํ ๋ค ๋ฐ์ดํฐ ๊ฑด์๋ก ๋๋ ํ๋ฅ ๋ฐ๋ ํจ์ ์ถ์ ํ๋ค. ๋ํ์ ์ธ ์ปค๋ ํจ์๋ก ๊ฐ์ฐ์์ ๋ถํฌ ํจ์๊ฐ ์ฌ์ฉ๋๋ค. ๋ํ, ๋์ญํญ h์ ๋ฐ๋ผ ํ๋ฅ ๋ฐ๋ ์ถ์ ์ฑ๋ฅ์ ํฌ๊ฒ ์ข์ฐํ ์ ์๋ค. ์์ h๊ฐ์ ์ข๊ณ ๋พฐ์กฑํ KDE๋ฅผ ๊ฐ์ง๊ฒ ๋๋ฉฐ, ์ด๊ฒ์ ๋ณ๋์ฑ์ด ๋งค์ฐ ํฌ๋ฉฐ ๊ณผ์ ํฉ ํ๊ธฐ ์ฝ๋ค. ๋๋ฌด ํฐ h๊ฐ์ ๊ณผ๋ํ๊ฒ ํํํ๋ KDE๋ก ์ธํ์ฌ ๊ณผ์์ ํฉํ๊ธฐ ์ฝ๋ค. ๋ฐ๋ผ์ ์ ์ ํ KDE ๋์ญํญ h๋ฅผ ๊ฒฐ์ ํ์ฌ์ผ ํ๋ค.
์๊ณ ๋ฆฌ์ฆ ์ ์ฉ ์์
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import MeanShift
X,y = make_blobs(n_samples=200,n_features=2,centers=3,cluster_std=0.7,random_state=0)
meanshift = MeanShift(bandwidth = 0.8)
cluster_labels = meanshift.fit_predict(X)
print("cluster ์ ํ : ",np.unique(cluster_labels))
ํ์คํธ์ฐจ๋ฅผ 0.7๋ก ์ค์ ํ 3๊ฐ์ ๊ตฐ์ง ๋ฐ์ดํฐ์ ๋ํด bandwidth๋ฅผ 0.8๋ก ์ค์ ํ ํ๊ท ์ด๋ ๊ตฐ์งํ ์๊ณ ๋ฆฌ์ฆ์ ์ ์ฉํ ์์ ๋ฅผ ์์ฑํ์๋ค.
์์ฑ๋ ๊ตฐ์ง์ 0-5๋ก 6๊ฐ์ด๋ฉฐ ์ง๋์น๊ฒ ์ธ๋ถํ๋์ด ๊ตฐ์งํ๋์๋ค.
bandwidth๋ฅผ ์๊ฒํ ์๋ก ๊ตฐ์ง ๊ฐ์๊ฐ ๋ง์์ง
=> bandwidth๋ฅผ ๋์ฌ์ผ ํจ
bandwidth๋ฅผ 1๋ก ์กฐ์
meanshift = MeanShift(bandwidth = 1)
cluster_labels = meanshift.fit_predict(X)
print("cluster ์ ํ : ",np.unique(cluster_labels))
3๊ฐ์ ๊ตฐ์ง์ผ๋ก ์ ๋ถ๋ฅ๋ ๊ฒ์ ํ์ธํ์๋ค.
์ต์ ํ๋ bandwidth์ฐพ๊ธฐ
from sklearn.cluster import estimate_bandwidth
bandwidth = estimate_bandwidth(X)
print('bandwidth ๊ฐ : ',round(bandwidth,3))
์ฌ์ดํท๋ฐ์์ ์ ๊ณตํ๋ estimate_bandwidth()๋ฅผ ์ฌ์ฉํ๋ฉด ์ต์ ์ bandwidth๊ฐ์ ์ฐพ์ ์ ์๋ค.
์ต์ ํ๋ bandwidth๊ฐ ์ ์ฉ
import pandas as pd
clusterDF = pd.DataFrame(data=X,columns = ['ftr1','ftr2'])
clusterDF['target'] = y
#estimate_bandwidth()๋ก ์ต์ ์ bandwidth ๊ณ์ฐ
best_bandwidth = estimate_bandwidth(X)
meanshift = MeanShift(bandwidth = best_bandwidth)
cluster_labels = meanshift.fit_predict(X)
print("cluster labels ์ ํ : ",np.unique(cluster_labels))
์ต์ ์ bandwidth๋ฅผ ์ ์ฉํ์ฌ ๊ตฐ์งํ ๋ชจ๋ธ ์์ฑ ๊ฒฐ๊ณผ 3๊ฐ์ ๊ตฐ์ง์ผ๋ก ๋ถ๋ฅ๋์๋ค.
๊ตฐ์ง ์๊ฐํ
import matplotlib.pyplot as plt
%matplotlib inline
clusterDF['meanshift_label'] = cluster_labels
centers = meanshift.cluster_centers_
unique_lables = np.unique(cluster_labels)
markers = ['o','s','^','x','*']
for label in unique_lables:
label_cluster = clusterDF[clusterDF['meanshift_label'] == label]
center_x_y = centers[label]
#๊ตฐ์ง๋ณ๋ก ๋ค๋ฅธ ๋ง์ปค๋ก ์ฐ์ ๋ ์ ์ฉ
plt.scatter(x=label_cluster['ftr1'],y=label_cluster['ftr2'],edgecolor='k',marker=markers[label])
#๊ตฐ์ง๋ณ ์ค์ฌํํ
plt.scatter(x=center_x_y[0],y = center_x_y[1], s = 200, color = 'gray',alpha=0.9,marker=markers[label])
plt.scatter(x=center_x_y[0],y=center_x_y[1],s=70,color='k',marker='$%d$' % label)
plt.show()
target๊ฐ๊ณผ ๊ตฐ์ง label ๋น๊ต
print(clusterDF.groupby('target')['meanshift_label'].value_counts())
target๊ฐ๊ณผ label๊ฐ์ด ์ ๋งค์นญ๋ ๊ฒ์ ํ์ธ
ํ๊ท ์ด๋์ ๋ฐ์ดํฐ์ ์ ํํ๋ฅผ ํน์ ํํ๋ก ๊ฐ์ ํ๊ฑฐ๋ ํน์ ๋ถํฌ๋ ๊ธฐ๋ฐ์ ๋ชจ๋ธ๋ก ๊ฐ์ ํ์ง ์์ ์ ์ฐํ ๊ตฐ์งํ๊ฐ ๊ฐ๋ฅํ๋ค. ๋ํ ์ด์์น์ ์ํฅ๋ ฅ๋ ํฌ์ง ์์ผ๋ฉฐ ๋ฏธ๋ฆฌ ๊ตฐ์ง์ ์๋ฅผ ์ง์ ํ์ง ์์๋ ๋๊ธฐ ๋๋ฌธ์ ์ด๋ฏธ์ง๋ ์์ ๋ฐ์ดํฐ์์ ํน์ ๊ฐ์ฒด๋ฅผ ๊ตฌ๋ถํ๊ฑฐ๋ ์์ง์์ ์ถ์ ํ๋ ๋ฐ์ ๋ฐ์ด๋ ์ญํ ์ ์ํํ๋ค.
###์ฐธ๊ณ
http://www.chioka.in/meanshift-algorithm-for-the-rest-of-us-python/
Meanshift Algorithm for the Rest of Us (Python)
What is Meanshift? Meanshift is a clustering algorithm that assigns the datapoints to the clusters iteratively by shifting points towards the mode. The mode can be understood as the highest density of datapoints (in the region, in the context of the Meansh
www.chioka.in
'๐ Machine Learning > Clustering' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Python] DBSCAN (0) | 2022.03.04 |
---|---|
[python] GMM(Gaussian Mixture Model) (0) | 2022.03.03 |
[Python] ๊ตฐ์ง ํ๊ฐ(์ค๋ฃจ์ฃ ๊ณ์) (0) | 2022.03.01 |
[Python] KMeans Clustering(K-ํ๊ท ๊ตฐ์งํ) (0) | 2022.02.28 |