10.7.2 用户自定会聚合函数

强类型的Dataset和弱类型的DataFrame都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()。除此之外,用户可以设定自己的自定义聚合函数

继承UserDefinedAggregateFunction

package day05

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

object UDFDemo1 {
    def main(args: Array[String]): Unit = {
        // 测试自定义的聚合函数
        val spark: SparkSession = SparkSession
            .builder()
            .master("local[*]")
            .appName("UDFDemo1")
            .getOrCreate()
        // 注册自定义函数
        spark.udf.register("myAvg", new MyAvg)
        val df = spark.read.json("file://" + ClassLoader.getSystemResource("user.json").getPath)
        df.createTempView("user")
        spark.sql("select myAvg(age) age_avg from user").show

    }
}

object MyAvg extends UserDefinedAggregateFunction {
    /**
      * 返回聚合函数输入参数的数据类型
      *
      * @return
      */
    override def inputSchema: StructType = {
        StructType(StructField("inputColumn", DoubleType) :: Nil)
    }

    /**
      * 聚合缓冲区中值的类型
      *
      * @return
      */
    override def bufferSchema: StructType = {
        StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
    }

    /**
      * 最终的返回值的类型
      *
      * @return
      */
    override def dataType: DataType = DoubleType

    /**
      * 确定性: 比如同样的输入是否返回同样的输出
      *
      * @return
      */
    override def deterministic: Boolean = true

    /**
      * 初始化
      *
      * @param buffer
      */
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
        // 存数据的总和
        buffer(0) = 0d
        // 储存数据的个数
        buffer(1) = 0L
    }

    /**
      * 相同 Executor间的合并
      *
      * @param buffer
      * @param input
      */
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        if (!input.isNullAt(0)) {
            buffer(0) = buffer.getDouble(0) + input.getDouble(0)
            buffer(1) = buffer.getLong(1) + 1
        }
    }


    /**
      * 不同 Executor间的合并
      *
      * @param buffer1
      * @param buffer2
      */
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        if (!buffer2.isNullAt(0)) {
            buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
            buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
        }
    }

    /**
      * 计算最终的结果.  因为是聚合函数, 所以最后只有一行了
      *
      * @param buffer
      * @return
      */
    override def evaluate(buffer: Row): Double = {
        println(buffer.getDouble(0), buffer.getLong(1))
        buffer.getDouble(0) / buffer.getLong(1)
    }

}
Copyright © 尚硅谷大数据 2019 all right reserved,powered by Gitbook
该文件最后修订时间: 2019-08-09 00:21:43

results matching ""

    No results matching ""