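This post explains how to use Spark MLlib to build a machine learning model and then use it to predict new data. The overall workflow is shown in the figure below: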
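For loading and saving data, MLlib provides the MLUtils utility, whose documentation describes it as "Helper methods to load, save and pre-process data used in MLLib." The data used in this post is sample_libsvm_data.txt, which ships with Spark; it contains 100 samples, each with 692 features. The data format is shown in the figure below: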
Loading the LIBSVM file
```java
// Load the LIBSVM file as an RDD of LabeledPoint
JavaRDD<LabeledPoint> lpdata = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD();
```
The LabeledPoint type corresponds to one line of a LIBSVM file (`label index1:value1 index2:value2 ...`): a label (a double) and a feature vector (a Vector).
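To make the LIBSVM format concrete, here is a minimal plain-Java sketch (no Spark) of how a single line splits into a label and a sparse feature vector. The class `LibSvmLine` is hypothetical; `MLUtils.loadLibSVMFile` does the equivalent for a whole file, including converting LIBSVM's one-based indices to zero-based ones.

```java
import java.util.Arrays;

// Hypothetical helper (not part of Spark): parses one LIBSVM line such as
// "1 3:0.5 10:2.0" into a label plus a sparse vector (parallel index/value arrays).
class LibSvmLine {
    final double label;
    final int[] indices;   // zero-based feature indices
    final double[] values; // values stored at those indices

    LibSvmLine(double label, int[] indices, double[] values) {
        this.label = label;
        this.indices = indices;
        this.values = values;
    }

    static LibSvmLine parse(String line) {
        String[] tokens = line.trim().split("\\s+");
        double label = Double.parseDouble(tokens[0]);
        int n = tokens.length - 1;
        int[] indices = new int[n];
        double[] values = new double[n];
        for (int i = 0; i < n; i++) {
            String[] kv = tokens[i + 1].split(":");
            indices[i] = Integer.parseInt(kv[0]) - 1; // LIBSVM indices are one-based
            values[i] = Double.parseDouble(kv[1]);
        }
        return new LibSvmLine(label, indices, values);
    }

    public static void main(String[] args) {
        LibSvmLine p = parse("1 3:0.5 10:2.0");
        // prints: 1.0 [2, 9] [0.5, 2.0]
        System.out.println(p.label + " " + Arrays.toString(p.indices) + " " + Arrays.toString(p.values));
    }
}
```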
Converting to a DataFrame
```java
// Turn each LabeledPoint into a Row, then attach a schema with named columns
JavaRDD<Row> jrow = lpdata.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
        new StructField("features", new VectorUDT(), false, Metadata.empty())
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);
```
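DataFrame: a DataFrame is a distributed collection of data organized into named columns. Conceptually it is equivalent to a table in a relational database or a data frame in Python (or R), but with richer optimizations under the hood. A DataFrame can be constructed from structured data files, Hive tables, external databases, or existing RDDs.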
SQLContext: the entry point to all Spark SQL functionality is the SQLContext class or one of its subclasses. Creating a basic SQLContext requires a SparkContext.
Feature standardization
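```java
// Scale every feature to unit standard deviation (withMean stays at its default, false)
StandardScaler scaler = new StandardScaler()
        .setInputCol("features")
        .setOutputCol("normFeatures")
        .setWithStd(true);
DataFrame scalerDF = scaler.fit(df).transform(df);
// This saves the estimator's parameters, not the statistics computed by fit()
scaler.save(this.scalerModelPath);
```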
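The arithmetic behind `setWithStd(true)` with mean-centering left off can be sketched in plain Java: compute each column's standard deviation on the training data, then divide each value by it. The class `UnitStdScaler` is hypothetical and is not Spark's implementation; as far as I know Spark also uses the corrected (n - 1) sample standard deviation and maps zero-variance features to 0, but check the documentation for your version.

```java
import java.util.Arrays;

// Hypothetical sketch: per-column scaling to unit standard deviation, no mean-centering.
class UnitStdScaler {
    // Corrected (n - 1) sample standard deviation of each column.
    static double[] columnStd(double[][] rows) {
        int n = rows.length;
        int d = rows[0].length;
        double[] std = new double[d];
        for (int j = 0; j < d; j++) {
            double mean = 0.0;
            for (double[] r : rows) mean += r[j];
            mean /= n;
            double sq = 0.0;
            for (double[] r : rows) sq += (r[j] - mean) * (r[j] - mean);
            std[j] = n > 1 ? Math.sqrt(sq / (n - 1)) : 0.0;
        }
        return std;
    }

    // Divide each value by its column's std; zero-variance columns become 0.
    static double[] transform(double[] row, double[] std) {
        double[] out = new double[row.length];
        for (int j = 0; j < row.length; j++) {
            out[j] = std[j] != 0.0 ? row[j] / std[j] : 0.0;
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] train = {{0.0, 5.0}, {2.0, 5.0}, {4.0, 5.0}};
        double[] std = columnStd(train);
        System.out.println(Arrays.toString(std));                                    // [2.0, 0.0]
        System.out.println(Arrays.toString(transform(new double[]{4.0, 5.0}, std))); // [2.0, 0.0]
    }
}
```

The standard deviations come from the training rows only; applying the same `std` array to later rows is what keeps training and prediction features on the same scale.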
Feature selection with the chi-squared statistic
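```java
// Keep the 500 top-ranked features according to a chi-squared test against the label
ChiSqSelector selector = new ChiSqSelector()
        .setNumTopFeatures(500)
        .setFeaturesCol("normFeatures")
        .setLabelCol("label")
        .setOutputCol("selectedFeatures");
ChiSqSelectorModel chiModel = selector.fit(scalerDF);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
// Save the fitted selector so prediction keeps exactly the same features
chiModel.save(this.featureSelectedModelPath);
```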
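ChiSqSelector scores each feature with a chi-squared test of independence against the label, treating every distinct feature value as a category. The hypothetical `ChiSquaredRanking` sketch below computes Pearson's statistic, the sum over contingency-table cells of (O - E)^2 / E with E = rowTotal * colTotal / N, and keeps the k highest-scoring features. Depending on the Spark version, the selector may rank by p-value rather than the raw statistic, so treat the ranking rule here as an assumption.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch: Pearson chi-squared statistic between one feature and the label,
// plus selection of the k features with the largest statistic.
class ChiSquaredRanking {
    static double chiSquared(double[] feature, double[] label) {
        Map<Double, Map<Double, Integer>> table = new HashMap<>(); // feature value -> label -> count
        Map<Double, Integer> rowTotals = new HashMap<>();
        Map<Double, Integer> colTotals = new HashMap<>();
        int n = feature.length;
        for (int i = 0; i < n; i++) {
            table.computeIfAbsent(feature[i], k -> new HashMap<>()).merge(label[i], 1, Integer::sum);
            rowTotals.merge(feature[i], 1, Integer::sum);
            colTotals.merge(label[i], 1, Integer::sum);
        }
        double chi2 = 0.0;
        for (Map.Entry<Double, Integer> row : rowTotals.entrySet()) {
            for (Map.Entry<Double, Integer> col : colTotals.entrySet()) {
                double expected = (double) row.getValue() * col.getValue() / n;
                int observed = table.get(row.getKey()).getOrDefault(col.getKey(), 0);
                chi2 += (observed - expected) * (observed - expected) / expected;
            }
        }
        return chi2;
    }

    // columns[j] holds feature j for every sample; returns the indices of the top k features.
    static int[] topK(double[][] columns, double[] label, int k) {
        Integer[] order = new Integer[columns.length];
        double[] scores = new double[columns.length];
        for (int j = 0; j < columns.length; j++) {
            order[j] = j;
            scores[j] = chiSquared(columns[j], label);
        }
        Arrays.sort(order, (a, b) -> Double.compare(scores[b], scores[a])); // descending
        int[] result = new int[Math.min(k, columns.length)];
        for (int i = 0; i < result.length; i++) result[i] = order[i];
        return result;
    }

    public static void main(String[] args) {
        double[] label = {0, 0, 1, 1};
        double[] informative = {0, 0, 1, 1}; // mirrors the label
        double[] noise = {0, 1, 0, 1};       // independent of the label
        System.out.println(chiSquared(informative, label)); // 4.0
        System.out.println(chiSquared(noise, label));       // 0.0
        System.out.println(Arrays.toString(topK(new double[][]{noise, informative}, label, 1))); // [1]
    }
}
```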
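```java
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.SVMModel;
import org.apache.spark.mllib.classification.SVMWithSGD;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;

// Convert back to LabeledPoint for training
JavaRDD<Row> selectedrows = selectedDF.javaRDD();
JavaRDD<LabeledPoint> trainset = selectedrows.map(new RowToLabel());
// Train the SVM model and save it
int numIteration = 200;
SVMModel model = SVMWithSGD.train(trainset.rdd(), numIteration);
// Make predict() return the raw margin instead of a 0/1 class
model.clearThreshold();
model.save(sc, this.mlModelPath);

// Convert a LabeledPoint into a Row (label, features)
static class LabeledPointToRow implements Function<LabeledPoint, Row> {
    public Row call(LabeledPoint p) throws Exception {
        double label = p.label();
        Vector vector = p.features();
        return RowFactory.create(label, vector);
    }
}

// Convert a Row (label, selectedFeatures) back into a LabeledPoint
static class RowToLabel implements Function<Row, LabeledPoint> {
    public LabeledPoint call(Row r) throws Exception {
        Vector features = r.getAs(1);
        double label = r.getDouble(0);
        return new LabeledPoint(label, features);
    }
}
```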
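SVMWithSGD fits a linear SVM by gradient descent on the hinge loss. As a rough illustration only, the hypothetical `HingeSgdSketch` below runs full-batch subgradient descent on the L2-regularized hinge loss, maps 0/1 labels to -1/+1, uses no intercept, and decays the step size as stepSize / sqrt(t). The values stepSize = 1.0 and regParam = 0.01 in the usage are my assumption about `SVMWithSGD.train`'s defaults; Spark's distributed implementation differs in detail.

```java
import java.util.Arrays;

// Hypothetical sketch: linear SVM via subgradient descent on the L2-regularized hinge loss.
class HingeSgdSketch {
    static double[] train(double[][] x, double[] label, int numIterations,
                          double stepSize, double regParam) {
        int d = x[0].length;
        double[] w = new double[d];
        for (int t = 1; t <= numIterations; t++) {
            double[] grad = new double[d];
            for (int i = 0; i < x.length; i++) {
                double y = label[i] > 0.5 ? 1.0 : -1.0; // 0/1 label -> -1/+1
                if (y * dot(w, x[i]) < 1.0) {           // hinge loss is active for this point
                    for (int j = 0; j < d; j++) grad[j] -= y * x[i][j];
                }
            }
            double eta = stepSize / Math.sqrt(t);       // decaying step size
            for (int j = 0; j < d; j++) {
                w[j] -= eta * (grad[j] / x.length + regParam * w[j]);
            }
        }
        return w;
    }

    static double dot(double[] w, double[] v) {
        double s = 0.0;
        for (int j = 0; j < w.length; j++) s += w[j] * v[j];
        return s;
    }

    public static void main(String[] args) {
        double[][] x = {{1, 2}, {2, 1}, {-1, -2}, {-2, -1}};
        double[] label = {1, 1, 0, 0};
        double[] w = train(x, label, 200, 1.0, 0.01);
        System.out.println("weights: " + Arrays.toString(w));
        for (int i = 0; i < x.length; i++) {
            System.out.println("margin " + dot(w, x[i]) + ", label " + label[i]);
        }
    }
}
```

The margin printed for each point is the same kind of raw score that `predict()` returns once `clearThreshold()` has been called.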
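Before predicting on new samples, we must apply the same data conversion and feature selection to them. That is why, during training, we saved the intermediate feature-processing models in addition to the machine learning model itself. The code is as follows: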
```java
// Initialize Spark
SparkConf conf = new SparkConf().setAppName("SVM").setMaster("local");
conf.set("spark.testing.memory", "2147480000");
SparkContext sc = new SparkContext(conf);
// Load the test data
JavaRDD<LabeledPoint> testData = MLUtils.loadLibSVMFile(sc, this.predictDataPath).toJavaRDD();
// Convert to a DataFrame
JavaRDD<Row> jrow = testData.map(new LabeledPointToRow());
StructType schema = new StructType(new StructField[]{
        new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
        new StructField("features", new VectorUDT(), false, Metadata.empty())
});
SQLContext jsql = new SQLContext(sc);
DataFrame df = jsql.createDataFrame(jrow, schema);
// Standardize: the saved StandardScaler holds only its parameters, so fit it on the
// training data (assumes this.libsvmFile is reachable here), not on the test data
StandardScaler scaler = StandardScaler.load(this.scalerModelPath);
JavaRDD<Row> trainRows = MLUtils.loadLibSVMFile(sc, this.libsvmFile).toJavaRDD().map(new LabeledPointToRow());
DataFrame scalerDF = scaler.fit(jsql.createDataFrame(trainRows, schema)).transform(df);
// Feature selection with the selector fitted during training
ChiSqSelectorModel chiModel = ChiSqSelectorModel.load(this.featureSelectedModelPath);
DataFrame selectedDF = chiModel.transform(scalerDF).select("label", "selectedFeatures");
```
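The point of saving intermediate models is that prediction reuses values learned from the training data instead of recomputing them on the test data. The plain-Java sketch below shows that pattern in isolation: fitted values (here, per-feature standard deviations) are written at training time and read back at prediction time. The class `FittedParamsStore` and its comma-separated text format are hypothetical; Spark's own `save`/`load` methods use their own on-disk layout.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;

// Hypothetical sketch: persist fitted parameters once, reload them for prediction.
class FittedParamsStore {
    static void save(double[] params, Path path) throws IOException {
        StringBuilder sb = new StringBuilder();
        for (int j = 0; j < params.length; j++) {
            if (j > 0) sb.append(',');
            sb.append(params[j]); // Double.toString output parses back to the same double
        }
        Files.write(path, sb.toString().getBytes(StandardCharsets.UTF_8));
    }

    static double[] load(Path path) throws IOException {
        String text = new String(Files.readAllBytes(path), StandardCharsets.UTF_8).trim();
        if (text.isEmpty()) return new double[0];
        return Arrays.stream(text.split(",")).mapToDouble(Double::parseDouble).toArray();
    }

    public static void main(String[] args) throws IOException {
        Path path = Files.createTempFile("scaler-std", ".txt");
        save(new double[]{2.0, 0.0, 0.125}, path); // at training time
        double[] std = load(path);                  // at prediction time
        System.out.println(Arrays.toString(std));   // [2.0, 0.0, 0.125]
        Files.delete(path);
    }
}
```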
Predicting on the test set
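```java
// Load the saved SVM model (its threshold was cleared before saving, so predict() returns raw margins)
SVMModel svmmodel = SVMModel.load(sc, this.mlModelPath);
JavaRDD<LabeledPoint> testset = selectedDF.javaRDD().map(new RowToLabel());
JavaRDD<Tuple2<Double, Double>> predictResult = testset.map(new Prediction(svmmodel));
predictResult.collect();

// Map each test point to a (score, label) pair
static class Prediction implements Function<LabeledPoint, Tuple2<Double, Double>> {
    SVMModel model;

    public Prediction(SVMModel model) {
        this.model = model;
    }

    public Tuple2<Double, Double> call(LabeledPoint p) throws Exception {
        Double score = model.predict(p.features());
        return new Tuple2<>(score, p.label());
    }
}
```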
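A linear SVM model scores a point as w·x + b. After `clearThreshold()`, `SVMModel.predict` returns that raw margin; with a threshold set (0.0 by default), it returns 1.0 when the margin exceeds the threshold and 0.0 otherwise. The hypothetical `LinearMarginSketch` below reproduces that arithmetic for a sparse point in plain Java; it is an illustration, not Spark's code.

```java
// Hypothetical sketch of a linear model's score and thresholded class for a sparse point.
class LinearMarginSketch {
    // Raw margin w·x + b, with x given as parallel index/value arrays
    static double margin(double[] weights, double intercept, int[] indices, double[] values) {
        double m = intercept;
        for (int k = 0; k < indices.length; k++) {
            m += weights[indices[k]] * values[k];
        }
        return m;
    }

    // Class decision when a threshold is set: 1.0 if margin > threshold, else 0.0
    static double toClass(double margin, double threshold) {
        return margin > threshold ? 1.0 : 0.0;
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0, 2.0};
        double m1 = margin(w, 0.0, new int[]{0, 2}, new double[]{2.0, 1.0});
        double m2 = margin(w, 0.0, new int[]{1}, new double[]{3.0});
        System.out.println(m1 + " -> " + toClass(m1, 0.0)); // 3.0 -> 1.0
        System.out.println(m2 + " -> " + toClass(m2, 0.0)); // -3.0 -> 0.0
    }
}
```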
Computing accuracy
```java
// Fraction of test points whose predicted class matches the label
double accuracy = predictResult.filter(new PredictAndScore()).count() * 1.0 / predictResult.count();
System.out.println(accuracy);

static class PredictAndScore implements Function<Tuple2<Double, Double>, Boolean> {
    public Boolean call(Tuple2<Double, Double> t) throws Exception {
        double score = t._1();
        double label = t._2();
        System.out.println("score:" + score + ", label:" + label);
        // Labels here are 0/1 (as SVMWithSGD requires), so "label >= 0.0" would hold for
        // every sample; compare the sign of the margin with label == 1 instead
        boolean predictedPositive = score > 0.0;
        boolean actualPositive = label > 0.5;
        return predictedPositive == actualPositive;
    }
}
```
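Accuracy is (TP + TN) / N over the four cells of a confusion matrix. The hypothetical `ConfusionCounts` sketch below tallies those cells from raw scores and 0/1 labels, using the rule "margin above the threshold means class 1"; it is plain Java for illustration, not a Spark API.

```java
// Hypothetical sketch: confusion-matrix counts and accuracy for raw scores vs 0/1 labels.
class ConfusionCounts {
    int tp, fp, tn, fn;

    static ConfusionCounts of(double[] scores, double[] labels, double threshold) {
        ConfusionCounts c = new ConfusionCounts();
        for (int i = 0; i < scores.length; i++) {
            boolean predictedPositive = scores[i] > threshold;
            boolean actualPositive = labels[i] > 0.5;
            if (predictedPositive && actualPositive) c.tp++;
            else if (predictedPositive) c.fp++;
            else if (actualPositive) c.fn++;
            else c.tn++;
        }
        return c;
    }

    double accuracy() {
        int total = tp + fp + tn + fn;
        return total == 0 ? 0.0 : (double) (tp + tn) / total;
    }

    public static void main(String[] args) {
        double[] scores = {2.0, -1.0, 0.5, -0.2};
        double[] labels = {1.0, 0.0, 0.0, 1.0};
        ConfusionCounts c = of(scores, labels, 0.0);
        // prints: tp=1 fp=1 tn=1 fn=1 accuracy=0.5
        System.out.println("tp=" + c.tp + " fp=" + c.fp + " tn=" + c.tn + " fn=" + c.fn
                + " accuracy=" + c.accuracy());
    }
}
```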
The complete code is on my GitHub: https://github.com/Quincy1994/MachineLearning/
Source: http://blog.csdn.net/qq_30843221/article/details/70552775