Data Science LAB

[Python] Mean Shift

ㅅ ㅜ ㅔ ㅇ 2022. 3. 2. 16:11

Mean Shift

Mean shift is similar to KMeans in that it performs clustering by repeatedly moving candidate centers toward the centers of the clusters.

The difference lies in where each center moves. KMeans moves a center to the mean of the data points assigned to it. Mean shift moves it toward the location where the data density is highest.
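
A minimal 1-D sketch of this difference follows. It uses made-up numbers and assumes that all six points are currently assigned to a single KMeans center. The mean shift step uses a flat kernel, i.e. a fixed window whose radius of 2.0 is an arbitrary choice for illustration:

```python
import numpy as np

# Hypothetical 1-D data: a dense group around 0 and one distant point at 6.0
points = np.array([-0.5, -0.25, 0.0, 0.25, 0.5, 6.0])
center = 1.0  # current position of a center

# KMeans-style update (assuming all six points are assigned to this center):
# the center moves to the plain mean of its assigned points
kmeans_center = points.mean()

# Mean-shift-style update with a flat kernel of radius 2.0:
# only points inside the window count, so the center moves toward the dense group
radius = 2.0
in_window = np.abs(points - center) <= radius
meanshift_center = points[in_window].mean()

print(f"KMeans update: {kmeans_center:.2f}, mean shift update: {meanshift_center:.2f}")
```

The distant point drags the KMeans mean to 1.0. The windowed mean ignores that point and moves to 0.0, the middle of the dense group.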

Mean shift clustering uses the distribution of the data to find cluster centers. It starts from the idea that a cluster center is a place where data points are concentrated, and it captures this with a probability density function. The density function of the data is usually estimated with KDE (Kernel Density Estimation). Each point moves toward the location of highest density within a given radius. To do this, the distances to the surrounding data points are fed into the KDE kernel, and the current position is updated with the result, which is a kernel-weighted mean of the neighbours. Applying this repeatedly to all data points finds the centers of the clusters.
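
The KDE and the position update described above can be written as follows. The notation is mine: n data points x_1, ..., x_n in d dimensions, a kernel K and a bandwidth h. The update uses a generic weighting kernel. With a Gaussian K it is a step uphill on the Gaussian KDE. With a flat K (1 when ||u|| <= 1, 0 otherwise) it reduces to the plain mean of the neighbours within distance h, which is the variant scikit-learn's MeanShift implements.

```latex
\hat{f}_h(x) = \frac{1}{n h^{d}} \sum_{i=1}^{n} K\!\left(\frac{x - x_i}{h}\right)

m(x) = \frac{\sum_{i=1}^{n} K\!\left(\frac{x - x_i}{h}\right) x_i}{\sum_{i=1}^{n} K\!\left(\frac{x - x_i}{h}\right)}, \qquad x \leftarrow m(x)
```

The difference m(x) - x is the "mean shift" vector that gives the method its name.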

1. For each data point, use the KDE-based MeanShift algorithm to estimate the density of the data around it, including the neighbouring points within a given radius.
2. Move the point in the direction in which the KDE-estimated density is higher.
3. Repeat steps 1 and 2 for every data point. The individual points gradually gather at the cluster centers.
4. Repeat this KDE-based shifting over the whole dataset for the specified number of iterations.
5. Take the locations where the points have gathered as the cluster centers.
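
The five steps can be sketched in NumPy as shown below. This is an illustrative implementation under my own assumptions, not scikit-learn's:

- It uses a Gaussian kernel, whereas scikit-learn's MeanShift uses a flat kernel.
- It starts a seed at every data point.
- It merges converged points that end up closer than half the bandwidth.

The function name `mean_shift` and the merge threshold are hypothetical choices. The code computes all pairwise distances, so it is only suitable for small datasets.

```python
import numpy as np


def mean_shift(X, bandwidth, n_iter=100, tol=1e-5):
    """Gaussian-kernel mean shift sketch: returns (centers, labels)."""
    X = np.asarray(X, dtype=float)
    points = X.copy()  # every data point starts as its own seed
    for _ in range(n_iter):  # step 4: repeat for a fixed number of iterations
        # steps 1 and 3: for every moving point, weight all data points with a Gaussian kernel
        d2 = ((points[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
        weights = np.exp(-d2 / (2 * bandwidth ** 2))
        # step 2: move each point to the kernel-weighted mean (uphill on the density)
        new_points = weights @ X / weights.sum(axis=1, keepdims=True)
        max_shift = np.abs(new_points - points).max()
        points = new_points
        if max_shift < tol:  # stop early once no point moves noticeably
            break
    # step 5: points that converged to (almost) the same place form one cluster
    centers, labels = [], np.empty(len(X), dtype=int)
    for i, p in enumerate(points):
        for k, c in enumerate(centers):
            if np.linalg.norm(p - c) < bandwidth / 2:
                labels[i] = k
                break
        else:  # no existing center nearby: this point defines a new center
            centers.append(p)
            labels[i] = len(centers) - 1
    return np.array(centers), labels


# usage on two synthetic, well-separated groups around (0, 0) and (5, 5)
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 0.3, (50, 2)), rng.normal(5.0, 0.3, (50, 2))])
centers, labels = mean_shift(X, bandwidth=1.0)
print("centers:\n", centers.round(2))
print("cluster sizes:", np.bincount(labels))
```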

KDE is a standard method for estimating the probability density function of a variable with a kernel function. It applies the kernel function to each observed data point, sums the results, and divides by the number of observations (and by the bandwidth, so that the estimate integrates to 1). The Gaussian function is the most common kernel. The bandwidth h largely determines how good the estimate is. A small h gives a narrow, spiky KDE that has high variance and tends to overfit. An overly large h over-smooths the KDE and tends to underfit. An appropriate bandwidth h therefore has to be chosen.
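
The KDE formula and the effect of h can be checked with a small 1-D sketch. The data is synthetic (two groups centred at -2 and +2) and the three bandwidths are arbitrary. The very small h should produce many spurious peaks, h=0.5 should recover the two groups, and h=3.0 should smooth them into a single bump. Because the sum is divided by n*h, each estimate integrates to about 1.

```python
import numpy as np


def gaussian_kde_1d(x_grid, data, h):
    """KDE f(x) = 1/(n*h) * sum_i K((x - x_i) / h) with a standard normal kernel K."""
    u = (x_grid[:, None] - data[None, :]) / h
    kernel = np.exp(-0.5 * u ** 2) / np.sqrt(2 * np.pi)
    return kernel.sum(axis=1) / (len(data) * h)


def count_peaks(f):
    """Number of strict local maxima of a sampled curve."""
    return int(((f[1:-1] > f[:-2]) & (f[1:-1] > f[2:])).sum())


rng = np.random.default_rng(0)
# synthetic data: two groups centred at -2 and +2
data = np.concatenate([rng.normal(-2, 0.5, 100), rng.normal(2, 0.5, 100)])
grid = np.linspace(-15, 15, 3001)
dx = grid[1] - grid[0]

results = {}
for h in (0.05, 0.5, 3.0):
    f = gaussian_kde_1d(grid, data, h)
    results[h] = (f.sum() * dx, count_peaks(f))  # (approximate area, number of peaks)
    print(f"h={h}: area={results[h][0]:.3f}, peaks={results[h][1]}")
```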

Example: applying the algorithm

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import MeanShift

# 200 two-dimensional samples in 3 clusters, each with standard deviation 0.7
X, y = make_blobs(n_samples=200, n_features=2, centers=3, cluster_std=0.7, random_state=0)

# mean shift with a bandwidth (kernel radius) of 0.8
meanshift = MeanShift(bandwidth=0.8)
cluster_labels = meanshift.fit_predict(X)
print("cluster labels:", np.unique(cluster_labels))

This example applies mean shift clustering with a bandwidth of 0.8 to three clusters of data generated with a standard deviation of 0.7.

The result is six clusters (labels 0 to 5). The data has been split into more clusters than it actually contains.

The smaller the bandwidth, the narrower the kernel. A narrower kernel produces more local density peaks, and therefore more clusters.

=> The bandwidth needs to be increased.
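
To see this relationship directly, the same data can be clustered with several bandwidths. The values in the loop are arbitrary. 50.0 is deliberately larger than the whole spread of the data, so every point falls into one window and a single cluster is expected.

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import MeanShift

X, y = make_blobs(n_samples=200, n_features=2, centers=3, cluster_std=0.7, random_state=0)

n_clusters = {}
for bw in (0.5, 0.8, 1.0, 1.5, 50.0):
    labels = MeanShift(bandwidth=bw).fit_predict(X)
    n_clusters[bw] = len(np.unique(labels))  # number of clusters found
    print(f"bandwidth={bw}: {n_clusters[bw]} clusters")
```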

Setting the bandwidth to 1

meanshift = MeanShift(bandwidth=1)
cluster_labels = meanshift.fit_predict(X)
print("cluster labels:", np.unique(cluster_labels))

With a bandwidth of 1, the data is correctly split into three clusters.

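Finding an optimal bandwidth

from sklearn.cluster import estimate_bandwidth

bandwidth = estimate_bandwidth(X)
print('bandwidth:', round(bandwidth, 3))

scikit-learn's estimate_bandwidth() estimates a suitable bandwidth directly from the data, so the bandwidth does not have to be tuned by hand. It is a heuristic based on nearest-neighbour distances, not a guaranteed optimum. Its quantile parameter (0.3 by default) controls how many neighbours are considered.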

Applying the estimated bandwidth

import pandas as pd

clusterDF = pd.DataFrame(data=X, columns=['ftr1', 'ftr2'])
clusterDF['target'] = y

# estimate the bandwidth with estimate_bandwidth()
best_bandwidth = estimate_bandwidth(X)

meanshift = MeanShift(bandwidth=best_bandwidth)
cluster_labels = meanshift.fit_predict(X)
print("cluster labels:", np.unique(cluster_labels))

With the estimated bandwidth, the model splits the data into three clusters.

Visualizing the clusters

import matplotlib.pyplot as plt
%matplotlib inline

clusterDF['meanshift_label'] = cluster_labels
centers = meanshift.cluster_centers_
unique_labels = np.unique(cluster_labels)
markers = ['o', 's', '^', 'x', '*']

for label in unique_labels:
    label_cluster = clusterDF[clusterDF['meanshift_label'] == label]
    center_x_y = centers[label]
    marker = markers[label % len(markers)]  # reuse markers if there are more than 5 clusters

    # scatter plot of each cluster with its own marker
    plt.scatter(x=label_cluster['ftr1'], y=label_cluster['ftr2'], edgecolor='k', marker=marker)

    # mark each cluster center with a large gray marker and its label number
    plt.scatter(x=center_x_y[0], y=center_x_y[1], s=200, color='gray', alpha=0.9, marker=marker)
    plt.scatter(x=center_x_y[0], y=center_x_y[1], s=70, color='k', marker='$%d$' % label)

plt.show()

Comparing target values with cluster labels

print(clusterDF.groupby('target')['meanshift_label'].value_counts())

The target values and the mean shift cluster labels correspond well.
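
The value_counts() table has to be read by eye because cluster ids are arbitrary. The adjusted Rand index (`sklearn.metrics.adjusted_rand_score`) summarizes the agreement in a single number that ignores how the labels are named. A value of 1.0 means both labelings define the same partition, and values near 0 mean chance-level agreement. The minimal sketch below regenerates the same data so that it runs on its own:

```python
import numpy as np
from sklearn.datasets import make_blobs
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.metrics import adjusted_rand_score

X, y = make_blobs(n_samples=200, n_features=2, centers=3, cluster_std=0.7, random_state=0)
labels = MeanShift(bandwidth=estimate_bandwidth(X)).fit_predict(X)

# agreement between the true targets and the mean shift clusters
ari = adjusted_rand_score(y, labels)
print("ARI:", round(ari, 3))

# renaming the labels (0->2, 1->0, 2->1) keeps the same partition, so the score stays 1.0
renamed = adjusted_rand_score(y, (y + 2) % 3)
print("ARI after renaming labels:", renamed)
```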

Mean shift does not assume that the clusters have a particular shape or that the data follows a specific distribution-based model, so it can cluster data flexibly. It is also relatively robust to outliers, and the number of clusters does not have to be specified in advance. For these reasons it is widely used in computer vision, for example to segment objects in images and to track moving objects in video.
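
As a toy illustration of the image use case (not taken from the post), the sketch below clusters the pixel colours of a small synthetic image. MeanShift is not told how many segments to find. The image, the noise level and the bandwidth of 0.2 are all made up for the example.

```python
import numpy as np
from sklearn.cluster import MeanShift

rng = np.random.default_rng(0)
# hypothetical 20x20 RGB image: dark left half, bright right half, plus mild noise
img = np.zeros((20, 20, 3))
img[:, :10] = [0.2, 0.2, 0.2]
img[:, 10:] = [0.8, 0.8, 0.8]
img += rng.normal(0, 0.02, img.shape)

pixels = img.reshape(-1, 3)  # one row per pixel, columns = R, G, B
ms = MeanShift(bandwidth=0.2)  # the number of segments is NOT given
segments = ms.fit_predict(pixels).reshape(20, 20)  # segment id per pixel
print("segments found:", len(ms.cluster_centers_))
```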
