PySpark in Practice: Decision Tree Principles and Application
Introduction
A decision tree is a tree-structured supervised model: each internal node tests a single feature, each branch corresponds to one outcome of the test, and each leaf carries a predicted class. The tree is built top-down by greedily choosing, at each node, the feature split that best separates the classes.
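"Best" here means the largest reduction in impurity; for classification, Spark's DecisionTreeClassifier uses Gini impurity by default. The plain-Python sketch below is illustrative only (the labels are hypothetical, not taken from the Titanic data) and shows the quantity the learner maximizes when it scores a candidate split.

def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((labels.count(c) / n) ** 2 for c in set(labels))

def gini_gain(left, right):
    """Impurity reduction from splitting a parent node into left/right."""
    parent = left + right
    n = len(parent)
    weighted = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(parent) - weighted

# Hypothetical Survived labels for a candidate split on Sex
left = [1, 1, 1, 0]      # female passengers (hypothetical)
right = [0, 0, 1, 0, 0]  # male passengers (hypothetical)
print(gini_gain(left, right))  # larger gain = better split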
Code
#import findspark
#findspark.init()

##############################################
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, VectorAssembler

spark = SparkSession.builder \
    .master("local[*]") \
    .appName("PySpark ML") \
    .getOrCreate()
sc = spark.sparkContext

#############################################
# Load the Titanic training set and cache it
df_train = spark.read.csv('./data/titanic-train.csv', header=True, inferSchema=True) \
    .cache()

# Fill missing Age values with the rounded mean age (29.699)
# and missing Embarked values with the most common port, 'S'
df_train = df_train.fillna({'Age': round(29.699, 0)})
df_train = df_train.fillna({'Embarked': 'S'})

# Encode the categorical columns Embarked and Sex as numeric indices
labelIndexer = StringIndexer(inputCol="Embarked", outputCol="iEmbarked")
model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)

labelIndexer = StringIndexer(inputCol="Sex", outputCol="iSex")
model = labelIndexer.fit(df_train)
df_train = model.transform(df_train)

# Keep the seven feature columns plus the Survived label
features = ['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare',
            'iEmbarked', 'Survived']
train_features = df_train[features]

# Assemble the feature columns into a single vector column
df_assembler = VectorAssembler(
    inputCols=['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked'],
    outputCol="features")
train = df_assembler.transform(train_features)

from pyspark.ml.classification import DecisionTreeClassifier

# Fit the DecisionTree model
dtree = DecisionTreeClassifier(labelCol="Survived", featuresCol="features")
treeModel = dtree.fit(train)

# Print the learned tree structure
print(treeModel.toDebugString)

# Predict on the training data
dt_predictions = treeModel.transform(train)
dt_predictions.select("prediction", "Survived", "features").show()

from pyspark.ml.evaluation import MulticlassClassificationEvaluator

multi_evaluator = MulticlassClassificationEvaluator(labelCol='Survived',
                                                    metricName='accuracy')
print('Decision Tree Accu:', multi_evaluator.evaluate(dt_predictions))

#############################################
sc.stop()
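The 29.699 hardcoded above is the mean Age of the Titanic training set. If you prefer not to hardcode it, a minimal sketch (assuming the same df_train, placed before the fillna call) computes it from the data:

from pyspark.sql import functions as F

mean_age = df_train.agg(F.mean('Age')).first()[0]  # ~29.699 on this dataset
df_train = df_train.fillna({'Age': round(mean_age, 0)})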
Note: because the Titanic test set does not include the Survived label, the decision tree here is evaluated directly on the training data so that predicted and actual values can be compared.
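If a held-out estimate is wanted anyway, the usual approach is to carve a validation set out of the labeled training data with randomSplit. A sketch, assuming the train DataFrame, dtree, and multi_evaluator defined above:

train_set, valid_set = train.randomSplit([0.8, 0.2], seed=42)
valid_model = dtree.fit(train_set)
valid_predictions = valid_model.transform(valid_set)
print('Validation Accu:', multi_evaluator.evaluate(valid_predictions))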
Output
DecisionTreeClassificationModel (uid=DecisionTreeClassifier_07a879378454) of depth 5 with 35 nodes
  If (feature 1 in {0.0})
   If (feature 2 <= 3.5)
    If (feature 3 <= 2.5)
     Predict: 1.0
    Else (feature 3 > 2.5)
     If (feature 4 <= 1.5)
      Predict: 0.0
     Else (feature 4 > 1.5)
      If (feature 3 <= 4.5)
       Predict: 1.0
      Else (feature 3 > 4.5)
       Predict: 0.0
   Else (feature 2 > 3.5)
    If (feature 0 <= 1.5)
     If (feature 5 <= 26.125)
      Predict: 0.0
     Else (feature 5 > 26.125)
      If (feature 5 <= 26.46875)
       Predict: 1.0
      Else (feature 5 > 26.46875)
       Predict: 0.0
    Else (feature 0 > 1.5)
     If (feature 2 <= 15.5)
      If (feature 3 <= 1.5)
       Predict: 1.0
      Else (feature 3 > 1.5)
       Predict: 0.0
     Else (feature 2 > 15.5)
      Predict: 0.0
  Else (feature 1 not in {0.0})
   If (feature 0 <= 2.5)
    If (feature 2 <= 3.5)
     If (feature 0 <= 1.5)
      Predict: 0.0
     Else (feature 0 > 1.5)
      Predict: 1.0
    Else (feature 2 > 3.5)
     Predict: 1.0
   Else (feature 0 > 2.5)
    If (feature 5 <= 24.808349999999997)
     If (feature 6 in {1.0,2.0})
      If (feature 2 <= 30.25)
       Predict: 1.0
      Else (feature 2 > 30.25)
       Predict: 0.0
     Else (feature 6 not in {1.0,2.0})
      If (feature 5 <= 21.0375)
       Predict: 1.0
      Else (feature 5 > 21.0375)
       Predict: 0.0
    Else (feature 5 > 24.808349999999997)
     Predict: 0.0

+----------+--------+--------------------+
|prediction|Survived|            features|
+----------+--------+--------------------+
|       0.0|       0|[3.0,0.0,22.0,1.0...|
|       1.0|       1|[1.0,1.0,38.0,1.0...|
|       1.0|       1|[3.0,1.0,26.0,0.0...|
|       1.0|       1|[1.0,1.0,35.0,1.0...|
|       0.0|       0|(7,[0,2,5],[3.0,3...|
|       0.0|       0|[3.0,0.0,30.0,0.0...|
|       0.0|       0|(7,[0,2,5],[1.0,5...|
|       0.0|       0|[3.0,0.0,2.0,3.0,...|
|       1.0|       1|[3.0,1.0,27.0,0.0...|
|       1.0|       1|[2.0,1.0,14.0,1.0...|
|       1.0|       1|[3.0,1.0,4.0,1.0,...|
|       1.0|       1|[1.0,1.0,58.0,0.0...|
|       0.0|       0|(7,[0,2,5],[3.0,2...|
|       0.0|       0|[3.0,0.0,39.0,1.0...|
|       1.0|       0|[3.0,1.0,14.0,0.0...|
|       1.0|       1|[2.0,1.0,55.0,0.0...|
|       0.0|       0|[3.0,0.0,2.0,4.0,...|
|       0.0|       1|(7,[0,2,5],[2.0,3...|
|       1.0|       0|[3.0,1.0,31.0,1.0...|
|       1.0|       1|[3.0,1.0,30.0,0.0...|
+----------+--------+--------------------+
only showing top 20 rows

Decision Tree Accu: 0.8417508417508418
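The debug string refers to features by vector index ("feature 0", "feature 1", ...). The indices follow the order of the VectorAssembler inputCols, so the mapping can be printed directly (a sketch assuming the df_assembler defined above):

for i, col in enumerate(df_assembler.getInputCols()):
    print('feature %d -> %s' % (i, col))
# feature 0 -> Pclass, feature 1 -> iSex, feature 2 -> Age, ...

Read this way, the root split on feature 1 is the Sex split.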