8.1 累加器(Accumulator)

累加器用来对信息进行聚合,通常在向 Spark 传递函数时,比如使用 map() 函数或者用 filter() 传条件时,可以使用驱动器程序中定义的变量,但是集群中运行的每个任务都会得到这些变量的一份新的副本,所以更新这些副本的值不会影响驱动器中的对应变量。

如果我们想实现所有分片处理时更新共享变量的功能,那么累加器可以实现我们想要的效果。

累加器是一种变量, 仅仅支持"add", 支持并发. 累加器用于去实现计数器或者求和. Spark 内部已经支持数字类型的累加器, 开发者可以添加其他类型的支持.

内置累加器

需求:计算文件中空行的数量

package day04

import org.apache.spark.rdd.RDD
import org.apache.spark.util.LongAccumulator
import org.apache.spark.{SparkConf, SparkContext}

object AccDemo1 {
    def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("Practice").setMaster("local[2]")
        val sc = new SparkContext(conf)
        val rdd: RDD[String] = sc.textFile("file://" + ClassLoader.getSystemResource("words.txt").getPath)
        // 得到一个 Long 类型的累加器.  将从 0 开始累加
        val emptyLineCount: LongAccumulator = sc.longAccumulator
        rdd.foreach(s => if (s.trim.length == 0) emptyLineCount.add(1))
        println(emptyLineCount.value)
    }
}

说明:

  • 在驱动程序中通过sc.longAccumulator得到Long类型的累加器, 还有Double类型的

  • 可以通过value来访问累加器的值.(与sum等价). avg得到平均值

  • 只能通过add来添加值.

  • 累加器的更新操作最好放在action中, Spark 可以保证每个 task 只执行一次. 如果放在 transformations 操作中则不能保证只更新一次.有可能会被重复执行.


自定义累加器

通过继承类AccumulatorV2来自定义累加器.

下面这个累加器可以用于在程序运行过程中收集一些文本类信息,最终以List[String]的形式返回。

package day04

import java.util
import java.util.{ArrayList, Collections}

import org.apache.spark.util.AccumulatorV2

object MyAccDemo {
    def main(args: Array[String]): Unit = {

    }
}

class MyAcc extends AccumulatorV2[String, java.util.List[String]] {
    private val _list: java.util.List[String] = Collections.synchronizedList(new ArrayList[String]())
    override def isZero: Boolean = _list.isEmpty

    override def copy(): AccumulatorV2[String, util.List[String]] = {
        val newAcc = new MyAcc
        _list.synchronized {
            newAcc._list.addAll(_list)
        }
        newAcc
    }

    override def reset(): Unit = _list.clear()

    override def add(v: String): Unit = _list.add(v)

    override def merge(other: AccumulatorV2[String, util.List[String]]): Unit =other match {
        case o: MyAcc => _list.addAll(o.value)
        case _ => throw new UnsupportedOperationException(
            s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
    }

    override def value: util.List[String] = java.util.Collections.unmodifiableList(new util.ArrayList[String](_list))
}

测试:

object MyAccDemo {
    def main(args: Array[String]): Unit = {
        val pattern = """^\d+$"""
        val conf = new SparkConf().setAppName("Practice").setMaster("local[2]")
        val sc = new SparkContext(conf)
        // 统计出来非纯数字, 并计算纯数字元素的和
        val rdd1 = sc.parallelize(Array("abc", "a30b", "aaabb2", "60", "20"))

        val acc = new MyAcc
        sc.register(acc)
        val rdd2: RDD[Int] = rdd1.filter(x => {
            val flag: Boolean = x.matches(pattern)
            if (!flag) acc.add(x)
            flag
        }).map(_.toInt)
        println(rdd2.reduce(_ + _))
        println(acc.value)
    }
}

注意:

  • 在使用自定义累加器的不要忘记注册sc.register(acc)

Copyright © 尚硅谷大数据 2019 all right reserved,powered by Gitbook
该文件最后修订时间: 2019-08-09 00:21:43

results matching ""

    No results matching ""