A UDAF is a user-defined aggregate function in Hive: it takes multiple rows as input and produces a single row of output. Built-in Hive UDAFs include sum() and count(). A UDAF can be implemented in either a simple or a generic style; the simple style relies on Java reflection, which costs performance and rules out some features, and has been deprecated. This article focuses on the generic style: GenericUDAF.

0x00 Writing a custom GenericUDAF

Writing a GenericUDAF takes two steps:

  • Extend org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver and override the getEvaluator method;
  • Based on getEvaluator's return value, write an inner static class that extends org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator and overrides its seven methods.

This involves two key abstractions, the resolver and the evaluator. For the resolver, the official recommendation is to extend AbstractGenericUDAFResolver rather than implement the GenericUDAFResolver2 interface directly. The evaluator extends GenericUDAFEvaluator and must be a public static inner class; the seven methods to override are init, getNewAggregationBuffer, iterate, terminatePartial, merge, terminate, and reset. Before diving into the evaluator you need to understand the ObjectInspector interface and the Mode inner enum; all of a UDAF's aggregation logic lives in the evaluator.

The role of ObjectInspector is to decouple how data is used from how it is stored, so the same data stream can switch between different input/output formats at its endpoints, and different Operators can work with different formats.

Mode represents the MapReduce stages a UDAF passes through:

public static enum Mode {
    /**
     * PARTIAL1: the MapReduce map phase: from raw data to partial aggregation.
     * Calls iterate() and terminatePartial().
     */
    PARTIAL1,
    /**
     * PARTIAL2: the map-side Combiner phase, which merges map output on the map side:
     * from partial aggregation to partial aggregation.
     * Calls merge() and terminatePartial().
     */
    PARTIAL2,
    /**
     * FINAL: the MapReduce reduce phase: from partial aggregations to full aggregation.
     * Calls merge() and terminate().
     */
    FINAL,
    /**
     * COMPLETE: present when the job has only a map phase and no reduce phase, so the
     * map side emits the final result directly: from raw data straight to full aggregation.
     * Calls iterate() and terminate().
     */
    COMPLETE
};

In the common case, a complete UDAF run is one MapReduce job. With a mapper and a reducer, it goes through PARTIAL1 (mapper) and FINAL (reducer); with a combiner as well, it goes through PARTIAL1 (mapper), PARTIAL2 (combiner), and FINAL (reducer). Some jobs have only a mapper and no reducer, in which case only the COMPLETE stage runs: raw data goes in and the final result comes straight out of the map side.
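These stage combinations can be illustrated with a plain-Java sketch (no Hive dependencies; the method names only loosely echo the evaluator's iterate() and merge()): aggregating through the map/combine/reduce path must produce the same result as the map-only COMPLETE path.

```java
// Plain-Java illustration of the Mode paths for a sum-style UDAF.
// iterate() folds one raw row into a partial value; merge() combines
// two partial values produced by earlier stages.
public class ModeFlowSketch {

    public static double iterate(double partial, double row) {
        return partial + row; // PARTIAL1 / COMPLETE: consume a raw row
    }

    public static double merge(double a, double b) {
        return a + b;         // PARTIAL2 / FINAL: combine partials
    }

    public static void main(String[] args) {
        double[] rows = {1.0, 2.0, 3.0, 4.0};

        // COMPLETE: map-only job, iterate() over all rows, then terminate()
        double complete = 0.0;
        for (double r : rows) {
            complete = iterate(complete, r);
        }

        // PARTIAL1 on two mappers, then FINAL (or PARTIAL2) merging them
        double mapper1 = iterate(iterate(0.0, rows[0]), rows[1]);
        double mapper2 = iterate(iterate(0.0, rows[2]), rows[3]);
        double finalSum = merge(mapper1, mapper2);

        System.out.println(complete == finalSum); // prints "true"
    }
}
```

Both paths yield 10.0: splitting the rows across mappers does not change the result. It is this associativity of the partial aggregation that lets Hive insert a combiner freely.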

(Figure: Evaluator methods invoked in each Mode stage)

0x01 Example: a divergence UDAF, DivergenceGenericUDAF

The divergence here is the Kullback-Leibler divergence between distributions P and Q:

D(P ∥ Q) = Σᵢ pᵢ · ln(pᵢ / qᵢ)

Reference code implementing the formula above:

package com.data.hive;

import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.util.JavaDataModel;

@Description(
    name = "divergence",
    value = "_FUNC_(px, qy) - Calculate divergence from px, qy")
public class DivergenceGenericUDAF extends AbstractGenericUDAFResolver {

  @Override
  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
    if (parameters.length != 2) {
      throw new UDFArgumentTypeException(parameters.length - 1,
          "Exactly two arguments are expected.");
    }

    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(0,
          "Only primitive type arguments are accepted but "
              + parameters[0].getTypeName() + " is passed.");
    }

    if (parameters[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentTypeException(1,
          "Only primitive type arguments are accepted but "
              + parameters[1].getTypeName() + " is passed.");
    }

    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
      case BYTE:
      case SHORT:
      case INT:
      case LONG:
      case FLOAT:
      case DOUBLE:
      case TIMESTAMP:
      case DECIMAL:
        switch (((PrimitiveTypeInfo) parameters[1]).getPrimitiveCategory()) {
          case BYTE:
          case SHORT:
          case INT:
          case LONG:
          case FLOAT:
          case DOUBLE:
          case TIMESTAMP:
          case DECIMAL:
            return new GenericUDAFDivergenceEvaluator();
          case STRING:
          case BOOLEAN:
          case DATE:
          default:
            throw new UDFArgumentTypeException(1,
                "Only numeric type arguments are accepted but "
                    + parameters[1].getTypeName() + " is passed.");
        }
      case STRING:
      case BOOLEAN:
      case DATE:
      default:
        throw new UDFArgumentTypeException(0,
            "Only numeric type arguments are accepted but "
                + parameters[0].getTypeName() + " is passed.");
    }
  }

  public static class GenericUDAFDivergenceEvaluator extends GenericUDAFEvaluator {

    // ObjectInspectors for the raw input columns (PARTIAL1 / COMPLETE)
    private PrimitiveObjectInspector pInputOI;
    private PrimitiveObjectInspector qInputOI;

    // ObjectInspector for the partial aggregation (PARTIAL2 / FINAL)
    private PrimitiveObjectInspector partialOI;

    private DoubleWritable partialResult;

    private DoubleWritable result;

    /*
     * PARTIAL1 (input, partial)
     * PARTIAL2 (partial, partial)
     * FINAL    (partial, output)
     * COMPLETE (input, output)
     */
    @Override
    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
      super.init(m, parameters);
      // initialize input inspectors
      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
        // raw input rows: (p, q)
        assert (parameters.length == 2);
        pInputOI = (PrimitiveObjectInspector) parameters[0];
        qInputOI = (PrimitiveObjectInspector) parameters[1];
      } else {
        // partial aggregation produced by an earlier stage
        partialOI = (PrimitiveObjectInspector) parameters[0];
      }
      // initialize output: both the partial and the final result are doubles
      partialResult = new DoubleWritable(0);
      result = new DoubleWritable(0);
      return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
    }

    static class SumDoubleAgg extends AbstractAggregationBuffer {
      boolean empty;
      double sum;

      @Override
      public int estimate() {
        return JavaDataModel.PRIMITIVES1 + JavaDataModel.PRIMITIVES2;
      }
    }

    @Override
    public void reset(AggregationBuffer agg) throws HiveException {
      SumDoubleAgg divergenceAgg = (SumDoubleAgg) agg;
      divergenceAgg.empty = true;
      divergenceAgg.sum = 0.0;
    }

    @Override
    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
      SumDoubleAgg divergenceAgg = new SumDoubleAgg();
      reset(divergenceAgg);
      return divergenceAgg;
    }

    @Override
    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
      assert (parameters.length == 2);
      if (parameters[0] == null || parameters[1] == null) {
        return; // skip rows containing NULLs
      }
      SumDoubleAgg divergenceAgg = (SumDoubleAgg) agg;

      double vp = PrimitiveObjectInspectorUtils.getDouble(parameters[0], pInputOI);
      double vq = PrimitiveObjectInspectorUtils.getDouble(parameters[1], qInputOI);

      divergenceAgg.sum += vp * Math.log(vp / vq);
      divergenceAgg.empty = false;
    }

    @Override
    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
      SumDoubleAgg divergenceAgg = (SumDoubleAgg) agg;
      partialResult.set(divergenceAgg.sum);
      return partialResult;
    }

    @Override
    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
      if (partial != null) {
        SumDoubleAgg divergenceAgg = (SumDoubleAgg) agg;

        double subSum = PrimitiveObjectInspectorUtils.getDouble(partial, partialOI);
        divergenceAgg.sum += subSum;

        divergenceAgg.empty = false;
      }
    }

    @Override
    public Object terminate(AggregationBuffer agg) throws HiveException {
      SumDoubleAgg divergenceAgg = (SumDoubleAgg) agg;
      if (divergenceAgg.empty) {
        return null; // no rows aggregated for this group
      }
      result.set(divergenceAgg.sum);
      return result;
    }
  }
}

0x02 Code walkthrough

  1. getEvaluator() is called
    1.1 check the number of input parameters;
    1.2 check the types of the input parameters;
    1.3 return the GenericUDAFEvaluator matching those parameter types;

  2. Methods of GenericUDAFDivergenceEvaluator
    2.1 init(): defines the inputs and outputs for each MapReduce stage;
    2.2 getNewAggregationBuffer(): allocates a fresh intermediate buffer;
    2.3 iterate(): reads p and q from an input row and accumulates them into the buffer;
    2.4 terminatePartial(): emits the intermediate result;
    2.5 merge(): combines intermediate results;
    2.6 terminate(): emits the final aggregated result;
    2.7 reset(): resets the intermediate buffer;

  3. Method call sequence (see the figure above showing the Evaluator methods invoked in each Mode stage)
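As a hand check of the accumulation done in iterate(), the divergence can also be computed in plain Java, independent of Hive (the class name DivergenceCheck is just for illustration):

```java
// Computes D(P||Q) = sum_i p_i * ln(p_i / q_i), the same per-row
// accumulation that iterate() performs on the aggregation buffer.
public class DivergenceCheck {

    public static double divergence(double[] p, double[] q) {
        double sum = 0.0;                        // the aggregation buffer
        for (int i = 0; i < p.length; i++) {
            sum += p[i] * Math.log(p[i] / q[i]); // one iterate() step per row
        }
        return sum;                              // terminate()
    }

    public static void main(String[] args) {
        // identical distributions: every ln(p_i/q_i) is 0, so the divergence is 0
        double[] p = {0.5, 0.5};
        double[] q = {0.5, 0.5};
        System.out.println(divergence(p, q)); // prints "0.0"
    }
}
```

Because the accumulator is a plain running sum, splitting the rows across mappers and adding the partial sums in merge() gives the same total, which is exactly why the UDAF works across PARTIAL1/PARTIAL2/FINAL.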

References

GenericUDAFCaseStudy
Hive UDAF 开发详解 (Hive UDAF development in detail)