์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
- ๋ ๋ฆฝํ๋ณธ
- ๋ฐ์ดํฐ๋ถ์์ค์ ๋ฌธ๊ฐ
- ๊ตฐ์งํ
- ๋ฐ์ดํฐ๋ถ๊ท ํ
- datascience
- t-test
- ๋ฐ์ดํฐ๋ถ์์ ๋ฌธ๊ฐ
- ๋น ๋ฐ์ดํฐ๋ถ์๊ธฐ์ฌ
- numpy
- ํ ์คํธ๋ถ์
- ํฌ๋กค๋ง
- PCA
- pandas
- ๋น ๋ฐ์ดํฐ
- opencv
- ์ธ๋์ํ๋ง
- DBSCAN
- Lambda
- ์ค๋ฒ์ํ๋ง
- ๋ฐ์ดํฐ๋ถ์
- dataframe
- ์ฃผ์ฑ๋ถ๋ถ์
- iloc
- ๋์ํ๋ณธ
- ์๋ํด๋ผ์ฐ๋
- ADsP
- LDA
- ํ์ด์ฌ
- ADP
- Python
Data Science LAB
[Python] ํ๊ท ์ด๋ ๋ณธ๋ฌธ
Mean Shift
ํ๊ท ์ด๋(Mean Shift)์ KMeans์ ์ ์ฌํ๊ฒ ์ค์ฌ์ ๊ตฐ์ง์ ์ค์ฌ์ผ๋ก ์ง์์ ์ผ๋ก ์์ง์ด๋ฉด์ ๊ตฐ์งํ๋ฅผ ์ํํ๋ค.
KMeans๋ ์ค์ฌ์ ์์๋ ๋ฐ์ดํฐ์ ํ๊ท ๊ฑฐ๋ฆฌ ์ค์ฌ์ผ๋ก ์ด๋ํ์ง๋ง, ํ๊ท ์ด๋์ ๋ฐ์ดํฐ๊ฐ ๋ชจ์ฌ ์๋ ๋ฐ๋๊ฐ ๊ฐ์ฅ ๋์ ๊ณณ์ผ๋ก ์ด๋์ํจ๋ค.
ํ๊ท ์ด๋ ๊ตฐ์งํ๋ ๋ฐ์ดํฐ์ ๋ถํฌ๋๋ฅผ ์ด์ฉํ์ฌ ๊ตฐ์ง์ ์ค์ฌ์ ์ ์ฐพ๋๋ค. ๊ตฐ์ง ์ค์ฌ์ ์ ๋ฐ์ดํฐ ํฌ์ธํธ๊ฐ ๋ชจ์ฌ ์๋ ๊ณณ์ด๋ผ๋ ์๊ฐ์์ ์ฐฉ์ํ ๊ฒ์ด๋ฉฐ ์ด๋ฅผ ์ํด ํ๋ฅ ๋ฐ๋ ํจ์๋ฅผ ์ด์ฉํ๋ค. ์ผ๋ฐ์ ์ผ๋ก ์ฃผ์ด์ง ๋ชจ๋ธ์ ํ๋ฅ ๋ฐ๋ ํจ์๋ฅผ ์ฐพ๊ธฐ ์ํด KDE(Kernel Density Estimation)๋ฅผ ์ด์ฉํ๋ค. ํน์ ๋ฐ์ดํฐ ๋ฐ๊ฒฝ ๋ด์ ๋ฐ์ดํฐ ๋ถํฌ ํ๋ฅ ๋ฐ๋๊ฐ ๊ฐ์ฅ ๋์ ๊ณณ์ผ๋ก ์ด๋ํ๊ธฐ ์ํด์ ์ฃผ๋ณ ๋ฐ์ดํฐ์์ ๊ฑฐ๋ฆฌ ๊ฐ์ KDE ํจ์์ ์ ๋ ฅ ๊ฐ์ผ๋ก ์ ๋ ฅํ ๋ค ๋ฐํ ๊ฐ์ ํ์ฌ ์์น์์ ์ ๋ฐ์ดํธํ๋ฉด์ ์ด๋ํ๋ค. ์ด๋ฌํ ๋ฐฉ์์ ์ ์ฒด ๋ฐ์ดํฐ์ ๋ฐ๋ณต์ ์ผ๋ก ์ ์ฉํ์ฌ ๋ฐ์ดํฐ ๊ตฐ์ง์ ์ค์ฌ์ ์ ์ฐพ์๋ธ๋ค.
1. ๊ฐ๋ณ ๋ฐ์ดํฐ์ ํน์ ๋ฐ๊ฒฝ ๋ด์ ์ฃผ๋ณ ๋ฐ์ดํฐ๋ฅผ ํฌํจํ ๋ฐ์ดํฐ ๋ถํฌ๋๋ฅผ KDE ๊ธฐ๋ฐ์ MeanShift ์๊ณ ๋ฆฌ์ฆ์ผ๋ก ๊ณ์ฐ
2. KDE๋ก ๊ณ์ฐ๋ ๋ฐ์ดํฐ ๋ถํฌ๋๊ฐ ๋์ ๋ฐฉํฅ์ผ๋ก ๋ฐ์ดํฐ ์ด๋
3. ๋ชจ๋ ๋ฐ์ดํฐ๋ค์ด 1~2๊น์ง ์ํํ๋ฉด์ ๋ฐ์ดํฐ ์ด๋. ๊ฐ๋ณ ๋ฐ์ดํฐ๋ค์ด ๊ตฐ์ง ์ค์ฌ์ ์ผ๋ก ๋ชจ์
4. ์ง์ ๋ ๋ฐ๋ณต ํ์๋งํผ ์ ์ฒด ๋ฐ์ดํฐ์ ๋ํด KDE ๊ธฐ๋ฐ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ด๋์ํค๋ฉด์ ๊ตฐ์งํ ์ํ
5. ๊ฐ๋ณ ๋ฐ์ดํฐ๋ค์ด ๋ชจ์ธ ์ค์ฌ์ ์ ๊ตฐ์ง ์ค์ฌ์ ์ผ๋ก ์ค์
KDE๋ ์ปค๋ ํจ์๋ฅผ ํตํด ์ด๋ค ๋ณ์์ ํ๋ฅ ๋ฐ๋ ํจ์๋ฅผ ์ถ์ ํ๋ ๋ํ์ ๋ฐฉ๋ฒ์ผ๋ก, ๊ด์ธก๋ ๋ฐ์ดํฐ ๊ฐ๊ฐ์ ์ปค๋ ํจ์๋ฅผ ์ ์ฉํ ๊ฐ์ ๋ชจ๋ ๋ํ ๋ค ๋ฐ์ดํฐ ๊ฑด์๋ก ๋๋ ํ๋ฅ ๋ฐ๋ ํจ์ ์ถ์ ํ๋ค. ๋ํ์ ์ธ ์ปค๋ ํจ์๋ก ๊ฐ์ฐ์์ ๋ถํฌ ํจ์๊ฐ ์ฌ์ฉ๋๋ค. ๋ํ, ๋์ญํญ h์ ๋ฐ๋ผ ํ๋ฅ ๋ฐ๋ ์ถ์ ์ฑ๋ฅ์ ํฌ๊ฒ ์ข์ฐํ ์ ์๋ค. ์์ h๊ฐ์ ์ข๊ณ ๋พฐ์กฑํ KDE๋ฅผ ๊ฐ์ง๊ฒ ๋๋ฉฐ, ์ด๊ฒ์ ๋ณ๋์ฑ์ด ๋งค์ฐ ํฌ๋ฉฐ ๊ณผ์ ํฉ ํ๊ธฐ ์ฝ๋ค. ๋๋ฌด ํฐ h๊ฐ์ ๊ณผ๋ํ๊ฒ ํํํ๋ KDE๋ก ์ธํ์ฌ ๊ณผ์์ ํฉํ๊ธฐ ์ฝ๋ค. ๋ฐ๋ผ์ ์ ์ ํ KDE ๋์ญํญ h๋ฅผ ๊ฒฐ์ ํ์ฌ์ผ ํ๋ค.
์๊ณ ๋ฆฌ์ฆ ์ ์ฉ ์์
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import MeanShift
X,y = make_blobs(n_samples=200,n_features=2,centers=3,cluster_std=0.7,random_state=0)
meanshift = MeanShift(bandwidth = 0.8)
cluster_labels = meanshift.fit_predict(X)
print("cluster ์ ํ : ",np.unique(cluster_labels))
ํ์คํธ์ฐจ๋ฅผ 0.7๋ก ์ค์ ํ 3๊ฐ์ ๊ตฐ์ง ๋ฐ์ดํฐ์ ๋ํด bandwidth๋ฅผ 0.8๋ก ์ค์ ํ ํ๊ท ์ด๋ ๊ตฐ์งํ ์๊ณ ๋ฆฌ์ฆ์ ์ ์ฉํ ์์ ๋ฅผ ์์ฑํ์๋ค.
์์ฑ๋ ๊ตฐ์ง์ 0-5๋ก 6๊ฐ์ด๋ฉฐ ์ง๋์น๊ฒ ์ธ๋ถํ๋์ด ๊ตฐ์งํ๋์๋ค.
bandwidth๋ฅผ ์๊ฒํ ์๋ก ๊ตฐ์ง ๊ฐ์๊ฐ ๋ง์์ง
=> bandwidth๋ฅผ ๋์ฌ์ผ ํจ
bandwidth๋ฅผ 1๋ก ์กฐ์
meanshift = MeanShift(bandwidth = 1)
cluster_labels = meanshift.fit_predict(X)
print("cluster ์ ํ : ",np.unique(cluster_labels))
3๊ฐ์ ๊ตฐ์ง์ผ๋ก ์ ๋ถ๋ฅ๋ ๊ฒ์ ํ์ธํ์๋ค.
์ต์ ํ๋ bandwidth์ฐพ๊ธฐ
from sklearn.cluster import estimate_bandwidth
bandwidth = estimate_bandwidth(X)
print('bandwidth ๊ฐ : ',round(bandwidth,3))
์ฌ์ดํท๋ฐ์์ ์ ๊ณตํ๋ estimate_bandwidth()๋ฅผ ์ฌ์ฉํ๋ฉด ์ต์ ์ bandwidth๊ฐ์ ์ฐพ์ ์ ์๋ค.
์ต์ ํ๋ bandwidth๊ฐ ์ ์ฉ
import pandas as pd
clusterDF = pd.DataFrame(data=X,columns = ['ftr1','ftr2'])
clusterDF['target'] = y
#estimate_bandwidth()๋ก ์ต์ ์ bandwidth ๊ณ์ฐ
best_bandwidth = estimate_bandwidth(X)
meanshift = MeanShift(bandwidth = best_bandwidth)
cluster_labels = meanshift.fit_predict(X)
print("cluster labels ์ ํ : ",np.unique(cluster_labels))
์ต์ ์ bandwidth๋ฅผ ์ ์ฉํ์ฌ ๊ตฐ์งํ ๋ชจ๋ธ ์์ฑ ๊ฒฐ๊ณผ 3๊ฐ์ ๊ตฐ์ง์ผ๋ก ๋ถ๋ฅ๋์๋ค.
๊ตฐ์ง ์๊ฐํ
import matplotlib.pyplot as plt
%matplotlib inline
clusterDF['meanshift_label'] = cluster_labels
centers = meanshift.cluster_centers_
unique_lables = np.unique(cluster_labels)
markers = ['o','s','^','x','*']
for label in unique_lables:
label_cluster = clusterDF[clusterDF['meanshift_label'] == label]
center_x_y = centers[label]
#๊ตฐ์ง๋ณ๋ก ๋ค๋ฅธ ๋ง์ปค๋ก ์ฐ์ ๋ ์ ์ฉ
plt.scatter(x=label_cluster['ftr1'],y=label_cluster['ftr2'],edgecolor='k',marker=markers[label])
#๊ตฐ์ง๋ณ ์ค์ฌํํ
plt.scatter(x=center_x_y[0],y = center_x_y[1], s = 200, color = 'gray',alpha=0.9,marker=markers[label])
plt.scatter(x=center_x_y[0],y=center_x_y[1],s=70,color='k',marker='$%d$' % label)
plt.show()
target๊ฐ๊ณผ ๊ตฐ์ง label ๋น๊ต
print(clusterDF.groupby('target')['meanshift_label'].value_counts())
target๊ฐ๊ณผ label๊ฐ์ด ์ ๋งค์นญ๋ ๊ฒ์ ํ์ธ
ํ๊ท ์ด๋์ ๋ฐ์ดํฐ์ ์ ํํ๋ฅผ ํน์ ํํ๋ก ๊ฐ์ ํ๊ฑฐ๋ ํน์ ๋ถํฌ๋ ๊ธฐ๋ฐ์ ๋ชจ๋ธ๋ก ๊ฐ์ ํ์ง ์์ ์ ์ฐํ ๊ตฐ์งํ๊ฐ ๊ฐ๋ฅํ๋ค. ๋ํ ์ด์์น์ ์ํฅ๋ ฅ๋ ํฌ์ง ์์ผ๋ฉฐ ๋ฏธ๋ฆฌ ๊ตฐ์ง์ ์๋ฅผ ์ง์ ํ์ง ์์๋ ๋๊ธฐ ๋๋ฌธ์ ์ด๋ฏธ์ง๋ ์์ ๋ฐ์ดํฐ์์ ํน์ ๊ฐ์ฒด๋ฅผ ๊ตฌ๋ถํ๊ฑฐ๋ ์์ง์์ ์ถ์ ํ๋ ๋ฐ์ ๋ฐ์ด๋ ์ญํ ์ ์ํํ๋ค.
###์ฐธ๊ณ
http://www.chioka.in/meanshift-algorithm-for-the-rest-of-us-python/
'๐ Machine Learning > Clustering' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[Python] DBSCAN (0) | 2022.03.04 |
---|---|
[python] GMM(Gaussian Mixture Model) (0) | 2022.03.03 |
[Python] ๊ตฐ์ง ํ๊ฐ(์ค๋ฃจ์ฃ ๊ณ์) (0) | 2022.03.01 |
[Python] KMeans Clustering(K-ํ๊ท ๊ตฐ์งํ) (0) | 2022.02.28 |