PySpark实战:逻辑回归原理与应用
来自CloudWiki
逻辑回归含义
逻辑回归为概率型非线性回归方法,是研究二分类观察结果与一些影响因素(x1,x2,x3,...,xn)之间关系的一种多变量分析方法
Logistic回归模型是有适用条件的;
- 因变量为二分类的分类变量或某事件的发生率,并且是数值型变量
- 各观测对象之间相互独立
- 自变量和Logostic概率是线性关系
Titanic幸存者预测:模型训练
#import findspark #findspark.init() ############################################## from pyspark.sql import SparkSession from pyspark.sql.context import SQLContext from pyspark.ml.feature import StringIndexer, VectorAssembler spark = SparkSession.builder \ .master("local[*]") \ .appName("PySpark ML") \ .getOrCreate() sc = spark.sparkContext ############################################# df_train = spark.read.csv('./data/titanic-train.csv',header=True,inferSchema=True) \ .cache() df_train = df_train.fillna({'Age': round(29.699,0)}) df_train = df_train.fillna({'Embarked': 'S'}) df_train = df_train.drop("Cabin") df_train = df_train.drop("Ticket") labelIndexer = StringIndexer(inputCol="Embarked", outputCol="iEmbarked") model = labelIndexer.fit(df_train) df_train = model.transform(df_train) labelIndexer = StringIndexer(inputCol="Sex", outputCol="iSex") model = labelIndexer.fit(df_train) df_train = model.transform(df_train) features = ['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked','Survived'] train_features = df_train[features] df_assembler = VectorAssembler(inputCols=['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked'], outputCol="features") train = df_assembler.transform(train_features) from pyspark.ml.classification import LogisticRegression #LogisticRegression模型 lg = LogisticRegression(labelCol='Survived') lgModel = lg.fit(train) #保存模型 lgModel.write().overwrite().save("./model/logistic-titanic") print("save model to ./model/logistic-titanic") trainingSummary = lgModel.summary trainingSummary.roc.show() #ROC是Receiver Characterisitic Operator的缩写,它可以反映机器学习模型的预测效果 print("areaUnderROC: " + str(trainingSummary.areaUnderROC)) #一般来说,ROC曲线下面的面积越大,则模型对于训练集的评估效果越好,但是要注意过拟合的问题 #TPR = TP / (TP+FN); #表示当前分到正样本中真实的正样本所占所有正样本的比例; #FPR = FP / (FP + TN); #表示当前被错误分到正样本类别中真实的负样本所占所有负样本总数的比例 #ROC curve is a plot of FPR against TPR import matplotlib.pyplot as plt plt.figure(figsize=(5,5)) plt.plot([0, 1], [0, 1], 'r--') plt.plot(lgModel.summary.roc.select('FPR').collect(), lgModel.summary.roc.select('TPR').collect()) plt.xlabel('FPR') plt.ylabel('TPR') plt.show() ############################################# sc.stop()
输出
save model to ./model/logistic-titanic +--------------------+--------------------+ | FPR| TPR| +--------------------+--------------------+ | 0.0| 0.0| |0.001821493624772...|0.017543859649122806| |0.001821493624772...| 0.04093567251461988| |0.001821493624772...| 0.06140350877192982| |0.001821493624772...| 0.08187134502923976| |0.001821493624772...| 0.1023391812865497| |0.001821493624772...| 0.12280701754385964| |0.003642987249544...| 0.14035087719298245| |0.003642987249544...| 0.1608187134502924| |0.003642987249544...| 0.18128654970760233| |0.003642987249544...| 0.20175438596491227| |0.003642987249544...| 0.2222222222222222| | 0.00546448087431694| 0.23976608187134502| | 0.00546448087431694| 0.2631578947368421| | 0.00546448087431694| 0.28362573099415206| |0.007285974499089253| 0.30409356725146197| |0.007285974499089253| 0.32456140350877194| |0.007285974499089253| 0.347953216374269| |0.007285974499089253| 0.3684210526315789| |0.009107468123861567| 0.38596491228070173| +--------------------+--------------------+ only showing top 20 rows areaUnderROC: 0.8569355233864868
注:一般对于训练好的模型,为了评估效果,ROC曲线需要用测试集进行评估。这里用训练集绘制ROC,可能会出现过拟合的问题
Titanic幸存者预测:模型预测
一旦模型训练完毕,且通过效果评估,那么就可以利用模型对新的数据进行预测
#import findspark #findspark.init() ############################################## from pyspark.sql import SparkSession from pyspark.sql.context import SQLContext from pyspark.ml.feature import StringIndexer, VectorAssembler spark = SparkSession.builder \ .master("local[*]") \ .appName("PySpark ML") \ .getOrCreate() sc = spark.sparkContext ############################################# df_test = spark.read.csv('./data/titanic-test.csv',header=True,inferSchema=True) \ .cache() df_test = df_test.fillna({'Age': round(29.699,0)}) df_test = df_test.fillna({'Embarked': 'S'}) #有一个null df_test = df_test.fillna({'Fare': 36.0}) df_test = df_test.drop("Cabin") df_test = df_test.drop("Ticket") #新增Survived列,默认值为0 df_test = df_test.withColumn("Survived",0 * df_test["Age"]) labelIndexer = StringIndexer(inputCol="Embarked", outputCol="iEmbarked") model = labelIndexer.fit(df_test) df_test = model.transform(df_test) labelIndexer = StringIndexer(inputCol="Sex", outputCol="iSex") model = labelIndexer.fit(df_test) df_test = model.transform(df_test) features = ['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked','Survived'] test_features = df_test[features] df_assembler = VectorAssembler(inputCols=['Pclass', 'iSex', 'Age', 'SibSp', 'Parch', 'Fare', 'iEmbarked'], outputCol="features") test = df_assembler.transform(test_features) from pyspark.ml.classification import LogisticRegressionModel lgModel = LogisticRegressionModel.load("./model/logistic-titanic") testSummary =lgModel.evaluate(test) results=testSummary.predictions results["features","rawPrediction","probability","prediction"].show() ############################################# sc.stop()
输出:
D:\Tech\PySpark实战\PySpark源代码\ch06\test>spark-submit 05Predict.py +--------------------+--------------------+--------------------+----------+ | features| rawPrediction| probability|prediction| +--------------------+--------------------+--------------------+----------+ |[3.0,0.0,34.5,0.0...|[1.99328605097899...|[0.88009035220072...| 0.0| |[3.0,1.0,47.0,1.0...|[0.63374031844971...|[0.65333708360849...| 0.0| |[2.0,0.0,62.0,0.0...|[1.97058477648159...|[0.87767391006101...| 0.0| |(7,[0,2,5],[3.0,2...|[2.21170839644084...|[0.90129601257823...| 0.0| |[3.0,1.0,22.0,1.0...|[-0.2919725495559...|[0.42752102300610...| 1.0| |(7,[0,2,5],[3.0,1...|[1.68822917787023...|[0.84399113755992...| 0.0| |[3.0,1.0,30.0,0.0...|[-0.9032166903750...|[0.28838991532794...| 1.0| |[2.0,0.0,26.0,1.0...|[1.42490075002778...|[0.80610554993708...| 0.0| |[3.0,1.0,18.0,0.0...|[-1.1236436862496...|[0.24533604281752...| 1.0| |[3.0,0.0,21.0,2.0...|[2.59895227540995...|[0.93079411943702...| 0.0| |(7,[0,2,5],[3.0,3...|[2.33390585204715...|[0.91164644844255...| 0.0| |(7,[0,2,5],[1.0,4...|[0.69025711721974...|[0.66602412131662...| 0.0| |[1.0,1.0,23.0,1.0...|[-2.7419887292668...|[0.06054069440361...| 1.0| |[2.0,0.0,63.0,1.0...|[2.82767950026722...|[0.94415337330052...| 0.0| |[1.0,1.0,47.0,1.0...|[-1.7316563679495...|[0.15037583472736...| 1.0| |[2.0,1.0,24.0,1.0...|[-1.7197655536498...|[0.15190136429145...| 1.0| |[2.0,0.0,35.0,0.0...|[0.88008689342827...|[0.70684022722317...| 0.0| |[3.0,0.0,21.0,0.0...|[1.71304684487762...|[0.84723105652294...| 0.0| |[3.0,1.0,27.0,1.0...|[-0.1717428611873...|[0.45716950894284...| 1.0| |[3.0,1.0,45.0,0.0...|[-0.0389664987514...|[0.49025960775551...| 1.0| +--------------------+--------------------+--------------------+----------+ only showing top 20 rows