PySpark实战:mapPartitionsWithIndex操作

来自CloudWiki
跳转至: 导航搜索

介绍

mapPartitionsWithIndex操作是一个变换算子,

它的作用是对RDD每个分区上应用函数func,同时跟踪原始分区的索引

并返回新的RDD

代码


import findspark
findspark.init()
##############################################
from pyspark.sql import SparkSession
spark = SparkSession.builder \
        .master("local[1]") \
        .appName("RDD Demo") \
        .getOrCreate();
sc = spark.sparkContext
#############################################
rdd = sc.parallelize([1, 2, 3, 4 ,5 ,6], 3)
def f(index, iter): 
        #分区索引 0,1,2
        print(index)
        for x in iter:
                #1,2;3,4;5,6
                print(x)
        yield index
ret = rdd.mapPartitionsWithIndex(f).sum()
#3=0+1+2
print(ret)
##############################################
sc.stop()
  • rdd.mapPartitionsWithIndex按照分区进行遍历,这个过程中创建一个[0,1,2]的新RDD对象。
  • 然后在RDD对象上调用sum(),将元素值进行累加,值为0+1+2 =3

输出

0
1
2
1
3
4
2
5
6
3