10.7.2 用户自定会聚合函数
强类型的Dataset
和弱类型的DataFrame
都提供了相关的聚合函数, 如 count(),countDistinct(),avg(),max(),min()
。除此之外,用户可以设定自己的自定义聚合函数
继承UserDefinedAggregateFunction
package day05
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
object UDFDemo1 {
def main(args: Array[String]): Unit = {
// 测试自定义的聚合函数
val spark: SparkSession = SparkSession
.builder()
.master("local[*]")
.appName("UDFDemo1")
.getOrCreate()
// 注册自定义函数
spark.udf.register("myAvg", new MyAvg)
val df = spark.read.json("file://" + ClassLoader.getSystemResource("user.json").getPath)
df.createTempView("user")
spark.sql("select myAvg(age) age_avg from user").show
}
}
object MyAvg extends UserDefinedAggregateFunction {
/**
* 返回聚合函数输入参数的数据类型
*
* @return
*/
override def inputSchema: StructType = {
StructType(StructField("inputColumn", DoubleType) :: Nil)
}
/**
* 聚合缓冲区中值的类型
*
* @return
*/
override def bufferSchema: StructType = {
StructType(StructField("sum", DoubleType) :: StructField("count", LongType) :: Nil)
}
/**
* 最终的返回值的类型
*
* @return
*/
override def dataType: DataType = DoubleType
/**
* 确定性: 比如同样的输入是否返回同样的输出
*
* @return
*/
override def deterministic: Boolean = true
/**
* 初始化
*
* @param buffer
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
// 存数据的总和
buffer(0) = 0d
// 储存数据的个数
buffer(1) = 0L
}
/**
* 相同 Executor间的合并
*
* @param buffer
* @param input
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
if (!input.isNullAt(0)) {
buffer(0) = buffer.getDouble(0) + input.getDouble(0)
buffer(1) = buffer.getLong(1) + 1
}
}
/**
* 不同 Executor间的合并
*
* @param buffer1
* @param buffer2
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
if (!buffer2.isNullAt(0)) {
buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
}
}
/**
* 计算最终的结果. 因为是聚合函数, 所以最后只有一行了
*
* @param buffer
* @return
*/
override def evaluate(buffer: Row): Double = {
println(buffer.getDouble(0), buffer.getLong(1))
buffer.getDouble(0) / buffer.getLong(1)
}
}