PySpark实战:决策树原理及应用

来自CloudWiki
跳转至:导航、搜索

介绍

决策树是一种基于树形结构的监督学习算法:从根节点开始,依据某个特征的取值(如“年龄 <= 3.5”)把样本逐层划分到子节点,直到叶节点给出预测结果(分类任务为类别,回归任务为数值)。

代码

# Optional: uncomment if pyspark is not importable from the current Python
# environment (e.g. a plain Jupyter kernel); findspark adds it to sys.path.
#import findspark
#findspark.init()
##############################################
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.context import SQLContext
from pyspark.ml.feature import StringIndexer, VectorAssembler
spark = SparkSession.builder \
        .master("local[*]") \
        .appName("PySpark ML") \
        .getOrCreate()
sc = spark.sparkContext
#############################################
# Load the Titanic training data. header=True takes column names from the
# first row; inferSchema=True lets Spark parse numeric columns as numbers.
df_train = spark.read.csv('./data/titanic-train.csv', header=True, inferSchema=True) \
  .cache()

# Impute missing ages with the rounded mean age computed from the data itself
# (previously a hard-coded 29.699), so the script stays correct if the CSV changes.
mean_age = df_train.agg(F.mean('Age')).first()[0]
df_train = df_train.fillna({'Age': round(mean_age, 0)})
# Impute missing embarkation ports with 'S'
# (presumably the most common port in this data set -- verify if the data changes).
df_train = df_train.fillna({'Embarked': 'S'})

# Encode the categorical string columns as numeric indices. StringIndexer's
# default ordering is by label frequency, most frequent label -> 0.0.
for in_col, out_col in [("Embarked", "iEmbarked"), ("Sex", "iSex")]:
    indexer_model = StringIndexer(inputCol=in_col, outputCol=out_col).fit(df_train)
    df_train = indexer_model.transform(df_train)

# Single source of truth for the model inputs. The position of each name here
# is the "feature N" index printed in the tree's debug string below.
feature_cols = ['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked']
label_col = 'Survived'
train_features = df_train.select(feature_cols + [label_col])
df_assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
train = df_assembler.transform(train_features)

from pyspark.ml.classification import DecisionTreeClassifier
# Decision tree classifier with default hyper-parameters.
dtree = DecisionTreeClassifier(labelCol=label_col, featuresCol="features")
treeModel = dtree.fit(train)
# Print the learned tree structure.
print(treeModel.toDebugString)
# Predict on the training data itself: the test CSV carries no Survived labels,
# so it cannot be used to compare predictions with actual outcomes.
dt_predictions = treeModel.transform(train)
dt_predictions.select("prediction", label_col, "features").show()

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
multi_evaluator = MulticlassClassificationEvaluator(labelCol=label_col, metricName='accuracy')
# NOTE: this is *training* accuracy, an optimistic estimate of how the model
# would perform on unseen passengers.
print('Decision Tree Accu:', multi_evaluator.evaluate(dt_predictions))
#############################################
# Stop the SparkSession created above (this also stops its SparkContext).
spark.stop()

注意:由于测试集中没有“是否幸存”(Survived)标签,为了能够对比实际值和预测值,这里直接在训练集上评估决策树。因此下面得到的是训练准确率,会高估模型在新数据上的表现;若需要更可靠的估计,可以用 randomSplit 从训练集中划出一部分数据作为验证集。

输出

DecisionTreeClassificationModel (uid=DecisionTreeClassifier_07a879378454) of depth 5 with 35 nodes
  If (feature 1 in {0.0})
   If (feature 2 <= 3.5)
    If (feature 3 <= 2.5)
     Predict: 1.0
    Else (feature 3 > 2.5)
     If (feature 4 <= 1.5)
      Predict: 0.0
     Else (feature 4 > 1.5)
      If (feature 3 <= 4.5)
       Predict: 1.0
      Else (feature 3 > 4.5)
       Predict: 0.0
   Else (feature 2 > 3.5)
    If (feature 0 <= 1.5)
     If (feature 5 <= 26.125)
      Predict: 0.0
     Else (feature 5 > 26.125)
      If (feature 5 <= 26.46875)
       Predict: 1.0
      Else (feature 5 > 26.46875)
       Predict: 0.0
    Else (feature 0 > 1.5)
     If (feature 2 <= 15.5)
      If (feature 3 <= 1.5)
       Predict: 1.0
      Else (feature 3 > 1.5)
       Predict: 0.0
     Else (feature 2 > 15.5)
      Predict: 0.0
  Else (feature 1 not in {0.0})
   If (feature 0 <= 2.5)
    If (feature 2 <= 3.5)
     If (feature 0 <= 1.5)
      Predict: 0.0
     Else (feature 0 > 1.5)
      Predict: 1.0
    Else (feature 2 > 3.5)
     Predict: 1.0
   Else (feature 0 > 2.5)
    If (feature 5 <= 24.808349999999997)
     If (feature 6 in {1.0,2.0})
      If (feature 2 <= 30.25)
       Predict: 1.0
      Else (feature 2 > 30.25)
       Predict: 0.0
     Else (feature 6 not in {1.0,2.0})
      If (feature 5 <= 21.0375)
       Predict: 1.0
      Else (feature 5 > 21.0375)
       Predict: 0.0
    Else (feature 5 > 24.808349999999997)
     Predict: 0.0

+----------+--------+--------------------+
|prediction|Survived|            features|
+----------+--------+--------------------+
|       0.0|       0|[3.0,0.0,22.0,1.0...|
|       1.0|       1|[1.0,1.0,38.0,1.0...|
|       1.0|       1|[3.0,1.0,26.0,0.0...|
|       1.0|       1|[1.0,1.0,35.0,1.0...|
|       0.0|       0|(7,[0,2,5],[3.0,3...|
|       0.0|       0|[3.0,0.0,30.0,0.0...|
|       0.0|       0|(7,[0,2,5],[1.0,5...|
|       0.0|       0|[3.0,0.0,2.0,3.0,...|
|       1.0|       1|[3.0,1.0,27.0,0.0...|
|       1.0|       1|[2.0,1.0,14.0,1.0...|
|       1.0|       1|[3.0,1.0,4.0,1.0,...|
|       1.0|       1|[1.0,1.0,58.0,0.0...|
|       0.0|       0|(7,[0,2,5],[3.0,2...|
|       0.0|       0|[3.0,0.0,39.0,1.0...|
|       1.0|       0|[3.0,1.0,14.0,0.0...|
|       1.0|       1|[2.0,1.0,55.0,0.0...|
|       0.0|       0|[3.0,0.0,2.0,4.0,...|
|       0.0|       1|(7,[0,2,5],[2.0,3...|
|       1.0|       0|[3.0,1.0,31.0,1.0...|
|       1.0|       1|[3.0,1.0,30.0,0.0...|
+----------+--------+--------------------+
only showing top 20 rows

Decision Tree Accu: 0.8417508417508418