Spark SQL 中 UDF 和 UDAF 的使用

Spark SQL 支持 Hive 的 UDF(User defined functions) 和 UDAF(User defined aggregation functions)

UDF 传入参数只能是表中的 1 行数据(可以是多列字段),传出参数也是 1 行,具体使用如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
/**
* 拼接一行中两列字段,数据类型一个为长整型,一个为字符串
* Created by zhulei on 2017/6/20.
*/
public class ConcatLongStringUDF implements UDF3<Long, String, String, String> {
@Override
public String call(Long v1, String v2, String split) throws Exception {
return String.valueOf(v1) + split + v2;
}
}

//然后在 main 方法中注册
sqlContext.udf().register("concat_long_string", new ConcatLongStringUDF(), DataTypes.StringType);

UDAF 传入参数是多行的数据,然后通过聚合运算输出一行数据, 具体使用如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
/**
* <p>
* 组内拼接去重函数
* 多行输入,聚合成一行输出
* Created by zhulei on 2017/6/20.
*/
public class GroupConcatDistinctUDAF extends UserDefinedAggregateFunction {

/**
* 定义输入数据的 schema
* 比如你要将多行多列的数据合并,可以理解成输入多行多列的数据所对应的 schema
* 这里输入的只有一列数据,所以 schema 也就只有一个字段
*/
@Override
public StructType inputSchema() {
return DataTypes.createStructType(Collections.singletonList(
DataTypes.createStructField("cityInfo", DataTypes.StringType, true)));
}

/**
* 定义用来存储中间计算结果的 buffer 对应的 schema
* 这个值是根据你的计算过程来定的
*/
@Override
public StructType bufferSchema() {
return DataTypes.createStructType(Collections.singletonList(
DataTypes.createStructField("bufferCityInfo", DataTypes.StringType, true)
));
}

/**
* 输出值的数据类型
*/
@Override
public DataType dataType() {
return DataTypes.StringType;
}

/**
* 输入值和输出值是不是确定的
*/
@Override
public boolean deterministic() {
return true;
}

/**
* 初始化中间计算结果变量
*/
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, "");
}

/**
* 更新计算结果
* 不断的将每个输入值通过你的计算方法去计算
*/
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
String bufferCityIno = buffer.getString(0);
String inputCityInfo = input.getString(0);

if (!bufferCityIno.contains(inputCityInfo)) {
if ("".equals(bufferCityIno)) {
bufferCityIno += inputCityInfo;
} else {
bufferCityIno += "," + inputCityInfo;
}
buffer.update(0, bufferCityIno);
}
}

/**
* update 操作是某个节点上的计算
* merge 是将多个节点的结果进行合并
*/
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
String aggBuffer1 = buffer1.getString(0);
String aggBuffer2 = buffer2.getString(0);

for (String ele : aggBuffer2.split(",")) {
if (!aggBuffer1.contains(ele)) {
if ("".equals(aggBuffer1)) {
aggBuffer1 += ele;
} else {
aggBuffer1 += "," + ele;
}
}
}

buffer1.update(0, aggBuffer1);
}

/**
* 输出最终计算结果
*/
@Override
public Object evaluate(Row buffer) {
return buffer.getString(0);
}

}

//然后在 main 方法中注册
sqlContext.udf().register("group_concat_distinct", new GroupConcatDistinctUDAF());