Quantcast
Channel: CodeSection,代码区,Python开发技术文章_教程 - CodeSec
Viewing all articles
Browse latest Browse all 9596

使用Python实现Mean Shift算法

$
0
0

前文介绍的K-Means算法需要指定K值(分组数),本文实现的MeanShift聚类算法不需要预先知道聚类的分组数,对聚类的形状也没有限制。

为了更好地理解这个算法,本帖使用Python实现Mean Shift算法。

MeanShift算法详细介绍: https://en.wikipedia.org/wiki/Mean_shift

# MeanShift in scikit-learn
import numpy as np
from sklearn.cluster import MeanShift
from matplotlib import pyplot
# Imported for its side effect: older matplotlib releases only register the
# '3d' projection once mplot3d has been imported.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
# make_blobs is imported from sklearn.datasets; the old
# sklearn.datasets.samples_generator path is gone in newer scikit-learn.
from sklearn.datasets import make_blobs

fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')

# Generate three groups of sample points around known centers.  The third
# center is [10, 12, 9] so that it agrees with the printed result quoted
# below (about [10.14, 12.02, 9.03]) and with the hand-written demo.
centers = [[2, 1, 3], [6, 6, 6], [10, 12, 9]]
x, _ = make_blobs(n_samples=200, centers=centers, cluster_std=1)

# To look at the raw, unclustered samples first:
#     ax.scatter(x[:, 0], x[:, 1], x[:, 2])
使用Python实现Mean Shift算法
# Cluster the samples generated above.
clf = MeanShift()
clf.fit(x)
labels = clf.labels_  # cluster index assigned to each sample
cluster_centers = clf.cluster_centers_  # estimated "center" of each cluster
print(cluster_centers)

colors = ['r', 'g', 'b']
# Cycle through the palette: MeanShift chooses the number of clusters itself,
# so indexing colors directly would raise IndexError if it finds more than 3.
point_colors = [colors[label % len(colors)] for label in labels]
ax.scatter(x[:, 0], x[:, 1], x[:, 2], c=point_colors)

# Mark the estimated cluster centers with large black stars.
ax.scatter(cluster_centers[:, 0], cluster_centers[:, 1], cluster_centers[:, 2],
           marker='*', c='k', s=200, zorder=10)

pyplot.show()
使用Python实现Mean Shift算法

MeanShift把上面数据自动分为3组,计算出的三个组的“中心点”为:

# Cluster centers printed by the scikit-learn run above:
# [[ 1.97566619  1.04212548  3.02410725]
#  [ 6.01672157  6.18325271  5.96562957]
#  [10.14455378 12.02394435  9.03499578]]
# These are close to the generating centers [[2,1,3], [6,6,6], [10,12,9]];
# the more samples are generated, the closer the estimates get.
#
# Implementing the Mean Shift algorithm in Python
import numpy as np


class MeanShift(object):
    """Minimal flat-kernel Mean Shift clustering.

    Every sample starts as a candidate center.  Each candidate is repeatedly
    replaced by the mean of the samples closer to it than ``bandwidth``, and
    identical candidates are collapsed, until nothing changes.

    Parameters
    ----------
    bandwidth : float, default 4
        Radius of the neighbourhood used for each mean.  The smaller the
        bandwidth, the more clusters are found.

    Attributes
    ----------
    bandwidth_ : float
        The radius passed to the constructor.
    centers_ : dict
        Set by ``fit``; maps ``0 .. k-1`` to the final center arrays, in
        sorted order.
    """

    def __init__(self, bandwidth=4):
        self.bandwidth_ = bandwidth

    def fit(self, data):
        """Find the cluster centers of ``data`` and store them in ``centers_``.

        ``data`` is converted with ``np.asarray(data, dtype=float)`` and is
        iterated row by row, one sample per row.
        """
        data = np.asarray(data, dtype=float)

        # Start with every sample as its own candidate center.
        centers = dict(enumerate(data))

        while True:
            new_centers = []
            for i in centers:
                center = centers[i]
                # Samples strictly inside the bandwidth radius of this center.
                in_bandwidth = [
                    feature for feature in data
                    if np.linalg.norm(feature - center) < self.bandwidth_
                ]
                if in_bandwidth:
                    new_center = np.average(in_bandwidth, axis=0)
                else:
                    # No sample in range: keep the center where it is rather
                    # than averaging an empty list (which would yield NaN).
                    new_center = center
                new_centers.append(tuple(new_center))

            # Tuples are hashable, so a set collapses identical centers;
            # sorting makes the numbering deterministic.
            uniques = sorted(set(new_centers))

            prev_centers = centers
            centers = {i: np.array(c) for i, c in enumerate(uniques)}

            # Converged only when both the number of centers and every
            # individual center are unchanged from the previous pass.
            optimized = len(centers) == len(prev_centers) and all(
                np.array_equal(centers[i], prev_centers[i]) for i in centers
            )
            if optimized:
                break

        self.centers_ = centers


def _demo():
    """Cluster 18 random 3-D samples and plot the samples and centers."""
    # Plot-only dependencies are imported here so that the MeanShift class
    # above can be imported with nothing but numpy installed.
    from matplotlib import pyplot
    # Side-effect import: older matplotlib releases only register the '3d'
    # projection once mplot3d has been imported.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    # The old sklearn.datasets.samples_generator path is gone in newer
    # scikit-learn; make_blobs lives in sklearn.datasets.
    from sklearn.datasets import make_blobs

    fig = pyplot.figure()
    ax = fig.add_subplot(111, projection='3d')

    centers = [[2, 1, 3], [6, 6, 6], [10, 12, 9]]
    x, _ = make_blobs(n_samples=18, centers=centers, cluster_std=1)

    clf = MeanShift()
    clf.fit(x)
    print(clf.centers_)

    for center in clf.centers_.values():
        ax.scatter(center[0], center[1], center[2],
                   marker='*', c='k', s=200, zorder=10)

    # One scatter call per sample, as in the original demo.
    for point in x:
        ax.scatter(point[0], point[1], point[2])

    pyplot.show()


if __name__ == '__main__':
    _demo()

执行结果:


使用Python实现Mean Shift算法

bandwidth参数代表点的半径(radius)范围,bandwidth=20:


使用Python实现Mean Shift算法

bandwidth=2.5:


使用Python实现Mean Shift算法

这个bandwidth可以根据数据样本求出最适合的值。

class MeanShift(object):
    """Mean Shift clustering with distance-ring weights and optional auto bandwidth.

    Every sample starts as a candidate center.  Each candidate is repeatedly
    replaced by a weighted mean of *all* samples, where a sample's weight
    depends on how many bandwidths away it lies; candidates that end up
    within one bandwidth of each other are merged.

    Parameters
    ----------
    bandwidth : float or None, default None
        Ring width.  When ``None``, ``fit`` derives it as the norm of the
        data centroid divided by ``bandwidth_step``.
    bandwidth_step : int, default 100
        Number of distance rings.  Ring ``r`` (0 = nearest) has weight
        ``(bandwidth_step - 1 - r) ** 2``; anything farther than the last
        ring is clamped into it and therefore gets weight 0.

    Attributes
    ----------
    bandwidth_ : float or None
        The bandwidth in use (filled in by ``fit`` if it was ``None``).
    bandwidth_step_ : int
        The ``bandwidth_step`` passed to the constructor.
    centers_ : dict
        Set by ``fit``; maps ``0 .. k-1`` to the final center arrays.
    labels_ : dict
        Set by ``fit``; maps each center index to the list of samples
        assigned to that center.
    """

    def __init__(self, bandwidth=None, bandwidth_step=100):
        self.bandwidth_ = bandwidth
        self.bandwidth_step_ = bandwidth_step

    def fit(self, data):
        """Cluster ``data`` and populate ``centers_`` and ``labels_``.

        ``data`` is converted with ``np.asarray(data, dtype=float)``; one
        sample per row is expected.

        Raises
        ------
        ValueError
            If the (given or derived) bandwidth is not a positive number.
        """
        data = np.asarray(data, dtype=float)

        if self.bandwidth_ is None:
            all_data_center = np.average(data, axis=0)
            self.bandwidth_ = np.linalg.norm(all_data_center) / self.bandwidth_step_
            print(self.bandwidth_)
        if not self.bandwidth_ > 0:
            # Distances are divided by the bandwidth below, so zero, negative
            # or NaN values cannot work.
            raise ValueError("bandwidth must be a positive number")

        steps = self.bandwidth_step_
        # Squared weight of each ring, nearest ring first: (steps-1)**2 ... 0.
        ring_weights = np.arange(steps - 1, -1, -1) ** 2

        # Start with every sample as its own candidate center.
        centers = dict(enumerate(data))

        while True:
            new_centers = [
                tuple(self._shift(data, centers[i], ring_weights)) for i in centers
            ]
            uniques = self._merge_close(sorted(set(new_centers)))

            prev_centers = centers
            centers = {i: np.array(c) for i, c in enumerate(uniques)}

            # Converged only when both the number of centers and every
            # individual center are unchanged from the previous pass.
            optimized = len(centers) == len(prev_centers) and all(
                np.array_equal(centers[i], prev_centers[i]) for i in centers
            )
            if optimized:
                break

        self.centers_ = centers

        # Group the samples by their nearest final center.
        self.labels_ = {i: [] for i in range(len(centers))}
        for feature in data:
            self.labels_[self.predict(feature)].append(feature)

    def _shift(self, data, center, ring_weights):
        """Return the ring-weighted mean of ``data`` as seen from ``center``."""
        distances = np.linalg.norm(data - center, axis=1)
        # Ring index = whole bandwidths away, clamped to the outermost ring.
        rings = np.minimum(distances / self.bandwidth_, len(ring_weights) - 1).astype(int)
        weights = ring_weights[rings]
        if weights.sum() == 0:
            # Every sample is in the zero-weight outer ring; there is nothing
            # to average, so leave the center where it is.
            return center
        # Same value as averaging a list in which each sample is repeated
        # `weight` times, without building that list.
        return np.average(data, axis=0, weights=weights)

    def _merge_close(self, candidates):
        """Keep a candidate only if it is farther than one bandwidth from every kept one."""
        kept = []
        for candidate in candidates:
            point = np.array(candidate)
            # Greedy pass: of any group of mutually close candidates exactly
            # one survives, so a cluster can never lose all of its centers.
            if all(np.linalg.norm(point - np.array(other)) > self.bandwidth_
                   for other in kept):
                kept.append(candidate)
        return kept

    def predict(self, data):
        """Return the index (key of ``centers_``) of the center nearest to sample ``data``."""
        data = np.asarray(data, dtype=float)
        distances = [np.linalg.norm(data - self.centers_[center])
                     for center in self.centers_]
        return distances.index(min(distances))


# Applying the Mean Shift algorithm to real data

数据集:titanic.xls(泰坦尼克号遇难者/幸存者名单)。目的:对乘客进行分类,看看这几组人有什么共同特点。

import numpyas np from sklearn.clusterimport MeanShift from sklearnimport preprocessing import pandasas pd ''' 数据集:titanic.xls(泰坦尼克号遇难者/幸存者名单) <http://blog.topspeedsnail.com/wp-content/uploads/2016/11/titanic.xls> ***字段*** pclass: 社会阶层(1,精英;2,中层;3,船员/劳苦大众) survived: 是否幸存 name: 名字 sex: 性别 age: 年龄 sibsp: 哥哥姐姐个数 parch: 父母儿女个数 ticket: 船票号 fare: 船票价钱 cabin: 船舱 embarked boat body: 尸体 home.dest ****** 目的:使用除survived字段外的数据进行means shift分组,看看能分为几组,这几组人有什么共同特点 ''' # 加载数据 df = pd.read_excel('titanic.xls') #print(df.shape)(1309, 14) #print(df.head()) #print(df.tail()) """ pclasssurvivedname sex\ 0 1 1Allen, Miss. Elisabeth Waltonfemale 1 1 1 Allison, Master. Hudson Trevormale 2 1 0 Allison, Miss. Helen Lorainefemale 3 1 0 Allison, Mr. Hudson Joshua Creightonmale 4 1G 0Allison, Mrs. Hudson J C (Bessie Waldo Daniels)female ages

Viewing all articles
Browse latest Browse all 9596

Trending Articles