pyspark编写kmeans代码,因为在迭代过程中使用了collect()算子导致速度被



# 这个函数的目的就是把读入的数据都转化为float类型的数据
def parseVector(line):
    return np.array([float(x) for x in line.split(' ')])


# 此函数的目的就是求取该点应该分到哪个点集中去,返回的是序号
def closestPoint(p, centers):
    bestIndex = 0
    closest = float("+inf")
    for i in range(len(centers)):
        tempDist = np.sum((p - centers[i]) ** 2)
        if tempDist < closest:
            closest = tempDist
            bestIndex = i
    return bestIndex


if __name__ == "__main__":

    if len(sys.argv) != 4:
        print("Usage: kmeans <file> <k> <convergeDist>", file=sys.stderr)
        exit(-1)

    print("""WARN: This is a naive implementation of KMeans Clustering and is given
       as an example! Please refer to examples/src/main/python/mllib/kmeans.py for an example on
       how to use MLlib's KMeans implementation.""", file=sys.stderr)

sc = SparkContext(appName="PythonKMeans")
lines = sc.textFile(sys.argv[1])
# 此处调用了rdd的map函数把所有的数据都转换为float类型
data = lines.map(parseVector).cache()
# 这里的K就是设置的中心点数
K = int(sys.argv[2])
# 设置的阈值,如果两次之间的距离小于该阈值的话则停止迭代
convergeDist = float(sys.argv[3])
# 从点集中用采样的方式来抽取K个值
kPoints = data.takeSample(False, K, 1)
# 中心点调整后的距离差
tempDist = 1.0

# 如果距离差大于阈值则执行
while tempDist > convergeDist:
    # 对所有数据执行map过程,最终生成的是(index, (point, 1))的rdd
    closest = data.map(
        lambda p: (closestPoint(p, kPoints), (p, 1)))
    # 执行reduce过程,该过程的目的是重新求取中心点,生成的也是rdd
    pointStats = closest.reduceByKey(
        lambda p1_c1, p2_c2: (p1_c1[0] + p2_c2[0], p1_c1[1] + p2_c2[1]))
    # 生成新的中心点!!这儿的collect导致运行的速度被大大大的拖慢!!
    newPoints = pointStats.map(
        lambda st: (st[0], st[1][0] / st[1][1])).collect()
    # 计算一下新旧中心点的距离差,这人collect()的使用导致速度被严重拖慢!!
    tempDist = sum(np.sum((kPoints[iK] - p) ** 2) for (iK, p) in newPoints)

    # 设置新的中心点
    for (iK, p) in newPoints:
        kPoints[iK] = p

print("Final centers: " + str(kPoints))

sc.stop()

#目前还没有想出什么好的办法来解决这个问题,不知道怎么办