Implementing the K-Means Clustering Algorithm in Java
Published: 2019-06-19


There are plenty of introductions to K-Means; if you are not yet familiar with it, look up some reference material first.

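I summarize its implementation in four steps:

1. Choose a value for k and randomly generate k initial centroids.
2. Compute the distance between every point and each of the k centroids, and assign each point to the nearest centroid.
3. The point set is now divided into k classes; compute the new center (the mean) of each class.
4. Repeat steps 2 and 3 to reclassify all points. Once the centroids of all classes stop changing, the algorithm has converged.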
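
The four steps above can be sketched compactly on plain arrays. This is only an illustration, not the article's implementation: the class name `KMeansStepsSketch` and its method names are made up here, step 1 (choosing the initial centroids) is left to the caller, squared Euclidean distance is assumed, and the loop stops when no point changes its class, which implies that the centroids no longer move.

```java
import java.util.Arrays;

/** Minimal sketch of steps 2-4 on 2-D points stored as double[]{x, y}. */
class KMeansStepsSketch {

    /** Returns the class index of each point; centroids are updated in place. */
    static int[] cluster(double[][] points, double[][] centroids, int maxIterations) {
        int[] labels = new int[points.length];
        Arrays.fill(labels, -1); // -1 means "not assigned yet"
        for (int iter = 0; iter < maxIterations; iter++) {
            boolean changed = false;
            // Step 2: assign every point to its nearest centroid.
            for (int i = 0; i < points.length; i++) {
                int best = 0;
                for (int c = 1; c < centroids.length; c++) {
                    if (squaredDistance(points[i], centroids[c])
                            < squaredDistance(points[i], centroids[best])) {
                        best = c;
                    }
                }
                if (labels[i] != best) {
                    labels[i] = best;
                    changed = true;
                }
            }
            // Step 4: stop once no assignment changes.
            if (!changed) {
                break;
            }
            // Step 3: move each centroid to the mean of its points.
            for (int c = 0; c < centroids.length; c++) {
                double sumX = 0, sumY = 0;
                int count = 0;
                for (int i = 0; i < points.length; i++) {
                    if (labels[i] == c) {
                        sumX += points[i][0];
                        sumY += points[i][1];
                        count++;
                    }
                }
                if (count > 0) { // an empty class keeps its old centroid
                    centroids[c][0] = sumX / count;
                    centroids[c][1] = sumY / count;
                }
            }
        }
        return labels;
    }

    static double squaredDistance(double[] a, double[] b) {
        double dx = a[0] - b[0], dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }
}
```

For example, with the points (0,0), (0,1), (10,10), (10,11) and initial centroids (0,0) and (10,10), the first two points end up in class 0 and the last two in class 1.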

 

The code follows.

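1. The entry class. It reads the data source, runs the training, and writes the output. The data file and the source code are linked further down.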

package com.hyr.kmeans;

import au.com.bytecode.opencsv.CSVReader;

import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class KmeansMain {

    public static void main(String[] args) throws IOException {
        // Read the data source file
        CSVReader reader = new CSVReader(new FileReader("src/main/resources/data.csv")); // data source
        FileWriter writer = new FileWriter("src/main/resources/out.csv");
        List<String[]> myEntries = reader.readAll(); // one row per point, e.g. 6.8, 12.6
        reader.close();

        // Convert the rows into the point set
        List<Point> points = new ArrayList<Point>(); // data point set
        for (String[] entry : myEntries) {
            points.add(new Point(Float.parseFloat(entry[0]), Float.parseFloat(entry[1])));
        }

        int k = 6;    // K value
        int type = 1; // distance formula, see Point.calculateDistance
        KmeansModel model = Kmeans.run(points, k, type);

        writer.write("====================   K is " + model.getK() + " ,  Object Funcion Value is " + model.getOfv() + " ,  calc_distance_type is " + model.getCalc_distance_type() + "   ====================\n");
        int i = 0;
        for (Cluster cluster : model.getClusters()) {
            i++;
            writer.write("====================   classification " + i + "   ====================\n");
            for (Point point : cluster.getPoints()) {
                writer.write(point.toString() + "\n");
            }
            writer.write("\n");
            writer.write("centroid is " + cluster.getCentroid().toString());
            writer.write("\n\n");
        }
        writer.close();
    }
}
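
The entry class relies on opencsv to read the rows. As a hedged alternative for readers who do not want that dependency, the sketch below parses the same kind of rows with the standard library only. It assumes each row holds two comma-separated numbers such as `6.8, 12.6` (the format suggested by the comment in the entry class); the class name `CsvPointsSketch` is hypothetical and not part of the article's project.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch: turn text rows like "6.8, 12.6" into {x, y} pairs without opencsv. */
class CsvPointsSketch {

    /** Parses each non-blank line into a float pair; fails loudly on short rows. */
    static List<float[]> parse(List<String> lines) {
        List<float[]> points = new ArrayList<>();
        for (String line : lines) {
            if (line.trim().isEmpty()) {
                continue; // skip blank lines
            }
            String[] cols = line.split(",");
            if (cols.length < 2) {
                throw new IllegalArgumentException("Expected two columns: " + line);
            }
            float x = Float.parseFloat(cols[0].trim());
            float y = Float.parseFloat(cols[1].trim());
            points.add(new float[]{x, y});
        }
        return points;
    }
}
```

The lines themselves could come from `java.nio.file.Files.readAllLines(Paths.get("src/main/resources/data.csv"))`. This simple split does not handle quoted fields, which is one reason to keep a CSV library for real data.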

 

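2. The generated model class, i.e. the final trained result. It holds the clusters, the K value, the distance type used for the calculation, and the objective function value.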

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;

public class KmeansModel {

    private List<Cluster> clusters = new ArrayList<Cluster>();
    private Double ofv;             // objective function value
    private int k;                  // K value
    private int calc_distance_type; // distance formula used

    public KmeansModel(List<Cluster> clusters, Double ofv, int k, int calc_distance_type) {
        this.clusters = clusters;
        this.ofv = ofv;
        this.k = k;
        this.calc_distance_type = calc_distance_type;
    }

    public List<Cluster> getClusters() {
        return clusters;
    }

    public Double getOfv() {
        return ofv;
    }

    public int getK() {
        return k;
    }

    public int getCalc_distance_type() {
        return calc_distance_type;
    }
}

 

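3. The data point class. It holds the point's coordinates (the code only covers two dimensions, the x and y axes) and the distance calculation. The distance formula is selected by type, and several common formulas are provided.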

package com.hyr.kmeans;

public class Point {
    private Float x; // x axis
    private Float y; // y axis

    public Point(Float x, Float y) {
        this.x = x;
        this.y = y;
    }

    public Float getX() {
        return x;
    }

    public void setX(Float x) {
        this.x = x;
    }

    public Float getY() {
        return y;
    }

    public void setY(Float y) {
        this.y = y;
    }

    @Override
    public String toString() {
        return "Point{" +
                "x=" + x +
                ", y=" + y +
                '}';
    }

    /**
     * Compute the distance to a centroid.
     *
     * @param centroid the centroid point
     * @param type     distance formula: 1 = L1, 2 = Canberra, 3 = Euclidean
     * @return the distance
     */
    public Double calculateDistance(Point centroid, int type) {
        switch (type) {
            case 1:
                return calcL1Distance(centroid);
            case 2:
                return calcCanberraDistance(centroid);
            case 3:
                return calcEuclidianDistance(centroid);
            default:
                // Previously an unknown type returned null and failed later with a NullPointerException
                throw new IllegalArgumentException("Unknown distance type: " + type);
        }
    }

    /*
        Distance formulas
     */

    // L1 distance divided by the number of dimensions (2), i.e. the mean absolute difference
    private Double calcL1Distance(Point centroid) {
        double res = Math.abs(getX() - centroid.getX()) + Math.abs(getY() - centroid.getY());
        return res / (double) 2;
    }

    private double calcEuclidianDistance(Point centroid) {
        return Math.sqrt(Math.pow((centroid.getX() - getX()), 2) + Math.pow((centroid.getY() - getY()), 2));
    }

    // Canberra distance, also divided by the number of dimensions (2)
    private double calcCanberraDistance(Point centroid) {
        double res = canberraTerm(getX(), centroid.getX()) + canberraTerm(getY(), centroid.getY());
        return res / (double) 2;
    }

    // One term |a - b| / (|a| + |b|); taken as 0 when both values are 0 to avoid 0/0 = NaN
    private static double canberraTerm(float a, float b) {
        float denominator = Math.abs(a) + Math.abs(b);
        return denominator == 0 ? 0 : Math.abs(a - b) / denominator;
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof Point)) {
            return false; // also covers null
        }
        Point other = (Point) obj;
        return getX().equals(other.getX()) && getY().equals(other.getY());
    }

    @Override
    public int hashCode() {
        // Kept consistent with equals
        return 31 * getX().hashCode() + getY().hashCode();
    }
}
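
Since the point class above is limited to two dimensions, the following sketch shows how the same three formulas might look for points of any dimension stored as arrays. It is an illustration under assumptions: the class `DistanceSketch` and its method names are hypothetical, both arrays are assumed to have the same length, and the values are the plain textbook distances, without the division by 2 that the article's L1 and Canberra methods apply.

```java
/** Sketch: L1, Euclidean and Canberra distances for n-dimensional points. */
class DistanceSketch {

    /** L1 (Manhattan) distance: sum of absolute coordinate differences. */
    static double manhattan(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            sum += Math.abs(a[i] - b[i]);
        }
        return sum;
    }

    /** Euclidean distance: square root of the sum of squared differences. */
    static double euclidean(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return Math.sqrt(sum);
    }

    /** Canberra distance: sum of |a_i - b_i| / (|a_i| + |b_i|). */
    static double canberra(double[] a, double[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double denominator = Math.abs(a[i]) + Math.abs(b[i]);
            if (denominator != 0) { // a 0/0 term is counted as 0
                sum += Math.abs(a[i] - b[i]) / denominator;
            }
        }
        return sum;
    }
}
```

For the points (0, 0) and (3, 4), the L1 distance is 7, the Euclidean distance is 5, and the Canberra distance is 2. Dividing by a constant, as the article's code does, does not change which centroid is nearest to a point.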

 

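4. The cluster obtained after training. It holds the cluster's centroid, the set of points belonging to the cluster, and whether the cluster has converged.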

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;

public class Cluster {
    private List<Point> points = new ArrayList<Point>(); // points belonging to this cluster
    private Point centroid;                              // centroid of this cluster
    private boolean isConvergence = false;

    public Point getCentroid() {
        return centroid;
    }

    public void setCentroid(Point centroid) {
        this.centroid = centroid;
    }

    @Override
    public String toString() {
        return centroid.toString();
    }

    public List<Point> getPoints() {
        return points;
    }

    public void setPoints(List<Point> points) {
        this.points = points;
    }

    public void initPoint() {
        points.clear();
    }

    public boolean isConvergence() {
        return isConvergence;
    }

    public void setConvergence(boolean convergence) {
        isConvergence = convergence;
    }
}

 

5. The K-Means training class. It keeps repeating the four steps described above.

package com.hyr.kmeans;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Kmeans {

    /**
     * kmeans
     *
     * @param points data set
     * @param k      K value
     * @param type   distance formula to use
     */
    public static KmeansModel run(List<Point> points, int k, int type) {
        // Initialize the centroids
        List<Cluster> clusters = initCentroides(points, k);

        while (!checkConvergence(clusters)) { // have all clusters converged?
            // 1. compute distances and classify every point
            // 2. check whether each centroid changed; if not, that cluster has converged
            // 3. regenerate the centroids
            initClusters(clusters);                // reset the points of every cluster
            classifyPoint(points, clusters, type); // classify by distance
            recalcularCentroides(clusters);        // recompute the centroids
        }

        // Compute the objective function value
        Double ofv = calcularObjetiFuncionValue(clusters);

        return new KmeansModel(clusters, ofv, k, type);
    }

    /**
     * Initialize k centroids.
     *
     * @param points point set
     * @param k      K value
     * @return list of clusters
     */
    private static List<Cluster> initCentroides(List<Point> points, Integer k) {
        List<Cluster> centroides = new ArrayList<Cluster>();

        // Find the range of the data set (minimum and maximum x and y over all points)
        Float max_X = Float.NEGATIVE_INFINITY;
        Float max_Y = Float.NEGATIVE_INFINITY;
        Float min_X = Float.POSITIVE_INFINITY;
        Float min_Y = Float.POSITIVE_INFINITY;
        for (Point point : points) {
            max_X = max_X < point.getX() ? point.getX() : max_X;
            max_Y = max_Y < point.getY() ? point.getY() : max_Y;
            min_X = min_X > point.getX() ? point.getX() : min_X;
            min_Y = min_Y > point.getY() ? point.getY() : min_Y;
        }
        System.out.println("min_X:" + min_X + ",max_X:" + max_X + ",min_Y:" + min_Y + ",max_Y:" + max_Y);

        // Randomly initialize k centroids inside that range
        Random random = new Random();
        for (int i = 0; i < k; i++) {
            float x = random.nextFloat() * (max_X - min_X) + min_X;
            float y = random.nextFloat() * (max_Y - min_Y) + min_Y; // fixed: the offset was min_X
            Cluster c = new Cluster();
            Point centroide = new Point(x, y); // random initial centroid
            c.setCentroid(centroide);
            centroides.add(c);
        }

        return centroides;
    }

    /**
     * Recompute the centroids.
     *
     * @param clusters clusters
     */
    private static void recalcularCentroides(List<Cluster> clusters) {
        for (Cluster c : clusters) {
            if (c.getPoints().isEmpty()) {
                c.setConvergence(true); // an empty cluster is treated as converged
                continue;
            }

            // The mean becomes the new centroid
            Float sum_x = 0f;
            Float sum_y = 0f;
            for (Point point : c.getPoints()) {
                sum_x += point.getX();
                sum_y += point.getY();
            }
            Float x = sum_x / c.getPoints().size();
            Float y = sum_y / c.getPoints().size();
            Point nuevoCentroide = new Point(x, y); // new centroid

            if (nuevoCentroide.equals(c.getCentroid())) {
                // The centroid did not change, so this cluster has converged.
                // Note: the flag is never reset once it has been set.
                c.setConvergence(true);
            } else {
                c.setCentroid(nuevoCentroide);
            }
        }
    }

    /**
     * Compute distances and classify the point set.
     *
     * @param points   point set
     * @param clusters clusters
     * @param type     distance formula to use
     */
    private static void classifyPoint(List<Point> points, List<Cluster> clusters, int type) {
        for (Point point : points) {
            Cluster masCercano = clusters.get(0);  // cluster the point ends up in
            Double minDistancia = Double.MAX_VALUE; // smallest distance so far
            for (Cluster cluster : clusters) {
                // Distance between the point and this cluster's centroid
                Double distancia = point.calculateDistance(cluster.getCentroid(), type);
                if (minDistancia > distancia) { // keep the smallest of the k distances
                    minDistancia = distancia;
                    masCercano = cluster;
                }
            }
            masCercano.getPoints().add(point); // add the point to the nearest cluster
        }
    }

    private static void initClusters(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            cluster.initPoint();
        }
    }

    /**
     * Check convergence.
     *
     * @param clusters clusters
     * @return true if every cluster is flagged as converged
     */
    private static boolean checkConvergence(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            if (!cluster.isConvergence()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compute the objective function value: the sum of the distances
     * between every point and the centroid of its cluster.
     *
     * @param clusters clusters
     * @return objective function value
     */
    private static Double calcularObjetiFuncionValue(List<Cluster> clusters) {
        Double ofv = 0d;
        for (Cluster cluster : clusters) {
            for (Point point : cluster.getPoints()) {
                int type = 1; // always distance type 1, regardless of the type used for clustering
                ofv += point.calculateDistance(cluster.getCentroid(), type);
            }
        }
        return ofv;
    }
}
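
The training class draws its initial centroids at random inside the bounding box of the data, which can leave a cluster without any points (the code handles that case by marking the empty cluster as converged). A common alternative, often called Forgy initialization, picks k entries of the data set itself as the starting centroids. The sketch below is a swapped-in technique and not the article's method; the class `ForgyInitSketch`, the `seed` parameter and the `float[]{x, y}` point representation are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Sketch: choose k entries of the data set as initial centroids. */
class ForgyInitSketch {

    /** Returns copies of k randomly chosen points; the same seed gives the same choice. */
    static List<float[]> pickInitialCentroids(List<float[]> points, int k, long seed) {
        if (k < 1 || k > points.size()) {
            throw new IllegalArgumentException("k must be between 1 and " + points.size());
        }
        // Shuffle a copy so the caller's list keeps its order.
        List<float[]> shuffled = new ArrayList<>(points);
        Collections.shuffle(shuffled, new Random(seed));

        List<float[]> centroids = new ArrayList<>();
        for (float[] p : shuffled.subList(0, k)) {
            centroids.add(p.clone()); // copy, so moving a centroid never alters the data
        }
        return centroids;
    }
}
```

Because every starting centroid coincides with a data point, each cluster usually receives at least that point in the first assignment pass, unless the data contains duplicate points. A fixed seed also makes runs reproducible, which the unseeded `new Random()` in the training class does not.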

 

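Final training result (the header line reports calc_distance_type 3, so this run used the Euclidean distance rather than the type = 1 set in the entry class):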

====================   K is 6 ,  Object Funcion Value is 21.82857036590576 ,  calc_distance_type is 3   ====================
====================   classification 1   ====================
Point{x=3.5, y=12.5}

centroid is Point{x=3.5, y=12.5}

====================   classification 2   ====================
Point{x=6.8, y=12.6}
Point{x=7.8, y=12.2}
Point{x=8.2, y=11.1}
Point{x=9.6, y=11.1}

centroid is Point{x=8.1, y=11.75}

====================   classification 3   ====================
Point{x=4.4, y=6.5}
Point{x=4.8, y=1.1}
Point{x=5.3, y=6.4}
Point{x=6.6, y=7.7}
Point{x=8.2, y=4.5}
Point{x=8.4, y=6.9}
Point{x=9.0, y=3.4}

centroid is Point{x=6.671428, y=5.2142863}

====================   classification 4   ====================
Point{x=6.0, y=19.9}
Point{x=6.2, y=18.5}
Point{x=5.3, y=19.4}
Point{x=7.6, y=17.4}

centroid is Point{x=6.275, y=18.800001}

====================   classification 5   ====================
Point{x=0.8, y=9.8}
Point{x=1.2, y=11.6}
Point{x=2.8, y=9.6}
Point{x=3.8, y=9.9}

centroid is Point{x=2.15, y=10.225}

====================   classification 6   ====================
Point{x=6.1, y=14.3}

centroid is Point{x=6.1, y=14.3}
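
As a quick sanity check on the result, each reported centroid should be the mean of the points listed in its class. The sketch below recomputes the centroid of classification 2 from the four points printed above; the class `CentroidCheckSketch` is a hypothetical helper written for this check, not part of the project.

```java
/** Sketch: recompute a centroid as the mean of the points in one class. */
class CentroidCheckSketch {

    /** Mean of 2-D points given as float[]{x, y}, using float sums like the article's code. */
    static float[] mean(float[][] points) {
        float sumX = 0f;
        float sumY = 0f;
        for (float[] p : points) {
            sumX += p[0];
            sumY += p[1];
        }
        return new float[]{sumX / points.length, sumY / points.length};
    }

    public static void main(String[] args) {
        // The four points of classification 2 in the output above
        float[][] classification2 = {
                {6.8f, 12.6f}, {7.8f, 12.2f}, {8.2f, 11.1f}, {9.6f, 11.1f}
        };
        float[] centroid = mean(classification2);
        System.out.println("centroid is x=" + centroid[0] + ", y=" + centroid[1]);
    }
}
```

By hand, x = (6.8 + 7.8 + 8.2 + 9.6) / 4 = 32.4 / 4 = 8.1 and y = (12.6 + 12.2 + 11.1 + 11.1) / 4 = 47.0 / 4 = 11.75, which agrees with the reported centroid. The printed float values may differ in the last digit because of rounding.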

 

Code download:

http://download.csdn.net/download/huangyueranbbc/10267041

GitHub:

https://github.com/huangyueranbbc/KmeansDemo 

 

Reposted from: https://my.oschina.net/u/4074730/blog/3007470
