一.论文《QuickScorer:a Fast Algorithm to Rank Documents with Additive Ensembles of Regression Trees》是为了解决LTR模型的预测问题,如果LTR中的LambdaMart在生成模型时产生的树数和叶结点过多,在对样本打分预测时会遍历每棵树,这样在线上使用时效率较慢,这篇文章主要就是利用了bitvector方法加速打分预测。代码我找了很久没找到开源的,后来无意中在Solr ltr中看到被改动过了的源码,不过这个源码集成在solr中,这里暂时贴出来,后期再剥离出,集成到ranklib中,以便使用。
二.图片解说
1. Ensemble trees原始打分过程
像gbdt,lambdamart,xgboost或lightgbm等这样的集成树模型在打分预测阶段,比如来了一个样本,这个样本是vector形式输入到每一棵树中,然后在每棵树中像if else这样的过程走到或映射到每棵树的一个节点中,这个节点就是每棵树的打分,然后将每棵树的打分乘上学习率(shrinkage)加和就是此样本的预测分。
2.论文中提到的打分过程
A.为回归树中的每个分枝打上true和false标签
比如图中样本X=[0.2,1.1,0.2],在回归树的branch中判断X[0],X[1],X[2]的true和false,比如图中根结点X[1]<=1.0,但样本X[1]=1.1,所以是false(走左边是true,右边是false),这样将所有branch打上true和false标签(可以直接打上false标志,不用考虑true),后面需要用到所有的false branch。
B.为每个branch分配一个bitvector
这个bitvector中的"0"表示当该branch判断为false时不再可达的叶结点(即被排除出候选),"1"表示仍然保留为候选的叶结点。比如"001111"表示6个叶结点中最左边两个叶结点被排除,右边四个仍是候选;"110011"表示中间两个叶结点被排除,其余四个作为候选结点。
C.打分阶段
此阶段是最后的打分预测阶段,根据前几个图的过程,将所有branch为false的bitvector按位与操作,就会得出样本落在哪个叶结点上。比如图中的结果是"001101",最左边为1的便是最终的叶结点的编号,每个回归树都会这样操作得到预测值,乘上学习率(shrinkage)然后加和就会得到一个样本的预测值。
三.代码
1 import org.apache.lucene.index.LeafReaderContext;
2 import org.apache.lucene.search.Explanation;
3 import org.apache.solr.ltr.feature.Feature;
4 import org.apache.solr.ltr.model.LTRScoringModel;
5 import org.apache.solr.ltr.model.ModelException;
6 import org.apache.solr.ltr.norm.Normalizer;
7 import org.apache.solr.util.SolrPluginUtils;
8
9 import java.util.*;
10
11 public class MultipleAdditiveTreesModel extends LTRScoringModel {
12
13 // 特征名:索引(从0开始)
14 private final HashMap<String, Integer> fname2index = new HashMap();
15 private List<RegressionTree> trees;
16
17 private MultipleAdditiveTreesModel.RegressionTree createRegressionTree(Map<String, Object> map) {
18 MultipleAdditiveTreesModel.RegressionTree rt = new MultipleAdditiveTreesModel.RegressionTree();
19 if(map != null) {
20 SolrPluginUtils.invokeSetters(rt, map.entrySet());
21 }
22
23 return rt;
24 }
25
26 private MultipleAdditiveTreesModel.RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
27 MultipleAdditiveTreesModel.RegressionTreeNode rtn = new MultipleAdditiveTreesModel.RegressionTreeNode();
28 if(map != null) {
29 SolrPluginUtils.invokeSetters(rtn, map.entrySet());
30 }
31
32 return rtn;
33 }
34
35 public void setTrees(Object trees) {
36 this.trees = new ArrayList();
37 Iterator var2 = ((List)trees).iterator();
38
39 while(var2.hasNext()) {
40 Object o = var2.next();
41 MultipleAdditiveTreesModel.RegressionTree rt = this.createRegressionTree((Map)o);
42 this.trees.add(rt);
43 }
44 }
45
46 public void setTrees(List<RegressionTree> trees) {
47 this.trees = trees;
48 }
49
50 public List<RegressionTree> getTrees() {
51 return this.trees;
52 }
53
54 public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
55 super(name, features, norms, featureStoreName, allFeatures, params);
56
57 for(int i = 0; i < features.size(); ++i) {
58 String key = ((Feature)features.get(i)).getName();
59 this.fname2index.put(key, Integer.valueOf(i));//特征名:索引
60 }
61
62 }
63
64 public void validate() throws ModelException {
65 super.validate();
66 if(this.trees == null) {
67 throw new ModelException("no trees declared for model " + this.name);
68 } else {
69 Iterator var1 = this.trees.iterator();
70
71 while(var1.hasNext()) {
72 MultipleAdditiveTreesModel.RegressionTree tree = (MultipleAdditiveTreesModel.RegressionTree)var1.next();
73 tree.validate();
74 }
75
76 }
77 }
78
79 public float score(float[] modelFeatureValuesNormalized) {
80 float score = 0.0F;
81
82 MultipleAdditiveTreesModel.RegressionTree t;
83 for(Iterator var3 = this.trees.iterator(); var3.hasNext(); score += t.score(modelFeatureValuesNormalized)) {
84 t = (MultipleAdditiveTreesModel.RegressionTree)var3.next();
85 }
86
87 return score;
88 }
89
90 public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
91 float[] fv = new float[featureExplanations.size()];
92 int index = 0;
93
94 for(Iterator details = featureExplanations.iterator(); details.hasNext(); ++index) {
95 Explanation featureExplain = (Explanation)details.next();
96 fv[index] = featureExplain.getValue();
97 }
98
99 ArrayList var12 = new ArrayList();
100 index = 0;
101
102 for(Iterator var13 = this.trees.iterator(); var13.hasNext(); ++index) {
103 MultipleAdditiveTreesModel.RegressionTree t = (MultipleAdditiveTreesModel.RegressionTree)var13.next();
104 float score = t.score(fv);
105 Explanation p = Explanation.match(score, "tree " + index + " | " + t.explain(fv), new Explanation[0]);
106 var12.add(p);
107 }
108
109 return Explanation.match(finalScore, this.toString() + " model applied to features, sum of:", var12);
110 }
111
112 public String toString() {
113 StringBuilder sb = new StringBuilder(this.getClass().getSimpleName());
114 sb.append("(name=").append(this.getName());
115 sb.append(",trees=[");
116
117 for(int ii = 0; ii < this.trees.size(); ++ii) {
118 if(ii > 0) {
119 sb.append(',');
120 }
121
122 sb.append(this.trees.get(ii));
123 }
124
125 sb.append("])");
126 return sb.toString();
127 }
128
129 public class RegressionTree {
130 private Float weight;
131 private MultipleAdditiveTreesModel.RegressionTreeNode root;
132
133 public void setWeight(float weight) {
134 this.weight = new Float(weight);
135 }
136
137 public void setWeight(String weight) {
138 this.weight = new Float(weight);
139 }
140
141 public float getWeight() {
142 return this.weight;
143 }
144
145 public void setRoot(Object root) {
146 this.root = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)root);
147 }
148
149 public RegressionTreeNode getRoot() {
150 return this.root;
151 }
152
153 public float score(float[] featureVector) {
154 return this.weight.floatValue() * this.root.score(featureVector);
155 }
156
157 public String explain(float[] featureVector) {
158 return this.root.explain(featureVector);
159 }
160
161 public String toString() {
162 StringBuilder sb = new StringBuilder();
163 sb.append("(weight=").append(this.weight);
164 sb.append(",root=").append(this.root);
165 sb.append(")");
166 return sb.toString();
167 }
168
169 public RegressionTree() {
170 }
171
172 public void validate() throws ModelException {
173 if(this.weight == null) {
174 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a weight");
175 } else if(this.root == null) {
176 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a tree");
177 } else {
178 this.root.validate();
179 }
180 }
181 }
182
183 public class RegressionTreeNode {
184 private static final float NODE_SPLIT_SLACK = 1.0E-6F;
185 private float value = 0.0F;
186 private String feature;
187 private int featureIndex = -1;
188 private Float threshold;
189 private MultipleAdditiveTreesModel.RegressionTreeNode left;
190 private MultipleAdditiveTreesModel.RegressionTreeNode right;
191
192 public void setValue(float value) {
193 this.value = value;
194 }
195
196 public void setValue(String value) {
197 this.value = Float.parseFloat(value);
198 }
199
200 public void setFeature(String feature) {
201 this.feature = feature;
202 Integer idx = (Integer)MultipleAdditiveTreesModel.this.fname2index.get(this.feature);
203 this.featureIndex = idx == null?-1:idx.intValue();
204 }
205
206 public int getFeatureIndex() {
207 return this.featureIndex;
208 }
209
210 public void setThreshold(float threshold) {
211 this.threshold = Float.valueOf(threshold + 1.0E-6F);
212 }
213
214 public void setThreshold(String threshold) {
215 this.threshold = Float.valueOf(Float.parseFloat(threshold) + 1.0E-6F);
216 }
217
218 public float getThreshold() {
219 return this.threshold;
220 }
221
222 public void setLeft(Object left) {
223 this.left = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)left);
224 }
225
226 public RegressionTreeNode getLeft() {
227 return this.left;
228 }
229
230 public void setRight(Object right) {
231 this.right = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)right);
232 }
233
234 public RegressionTreeNode getRight() {
235 return this.right;
236 }
237
238 public boolean isLeaf() {
239 return this.feature == null;
240 }
241
242 public float score(float[] featureVector) {
243 return this.isLeaf()?this.value:(this.featureIndex >= 0 && this.featureIndex < featureVector.length?(featureVector[this.featureIndex] <= this.threshold.floatValue()?this.left.score(featureVector):this.right.score(featureVector)):0.0F);
244 }
245
246 public String explain(float[] featureVector) {
247 if(this.isLeaf()) {
248 return "val: " + this.value;
249 } else if(this.featureIndex >= 0 && this.featureIndex < featureVector.length) {
250 String rval;
251 if(featureVector[this.featureIndex] <= this.threshold.floatValue()) {
252 rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " <= " + this.threshold + ", Go Left | ";
253 return rval + this.left.explain(featureVector);
254 } else {
255 rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " > " + this.threshold + ", Go Right | ";
256 return rval + this.right.explain(featureVector);
257 }
258 } else {
259 return "\'" + this.feature + "\' does not exist in FV, Return Zero";
260 }
261 }
262
263 public String toString() {
264 StringBuilder sb = new StringBuilder();
265 if(this.isLeaf()) {
266 sb.append(this.value);
267 } else {
268 sb.append("(feature=").append(this.feature);
269 sb.append(",threshold=").append(this.threshold.floatValue() - 1.0E-6F);
270 sb.append(",left=").append(this.left);
271 sb.append(",right=").append(this.right);
272 sb.append(')');
273 }
274
275 return sb.toString();
276 }
277
278 public RegressionTreeNode() {
279 }
280
281 public void validate() throws ModelException {
282 if(this.isLeaf()) {
283 if(this.left != null || this.right != null) {
284 throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + this.left + " and right=" + this.right);
285 }
286 } else if(null == this.threshold) {
287 throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
288 } else if(null == this.left) {
289 throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
290 } else {
291 this.left.validate();
292 if(null == this.right) {
293 throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
294 } else {
295 this.right.validate();
296 }
297 }
298 }
299 }
300
301 }
1 import org.apache.commons.lang.ArrayUtils;
2 import org.apache.lucene.util.CloseableThreadLocal;
3 import org.apache.solr.ltr.feature.Feature;
4 import org.apache.solr.ltr.model.ModelException;
5 import org.apache.solr.ltr.norm.Normalizer;
6
7 import java.util.*;
8
9 public class QuickScorerTreesModel extends MultipleAdditiveTreesModel{
10
11 private static final long MAX_BITS = 0xFFFFFFFFFFFFFFFFL;
12
13 // 64bits De Bruijn Sequence
14 // see: http://chessprogramming.wikispaces.com/DeBruijnsequence#Binary alphabet-B(2, 6)
15 private static final long HASH_BITS = 0x022fdd63cc95386dL;
16 private static final int[] hashTable = new int[64];
17
18 static {
19 long hash = HASH_BITS;
20 for (int i = 0; i < 64; ++i) {
21 hashTable[(int) (hash >>> 58)] = i;
22 hash <<= 1;
23 }
24 }
25
26 /**
27 * Finds the index of rightmost bit with O(1) by using De Bruijn strategy.
28 *
29 * @param bits target bits (64bits)
30 * @see <a href="http://supertech.csail.mit.edu/papers/debruijn.pdf">http://supertech.csail.mit.edu/papers/debruijn.pdf</a>
31 */
32 private static int findIndexOfRightMostBit(long bits) {
33 return hashTable[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
34 }
35
36 /**
37 * The number of trees of this model.
38 */
39 private int treeNum;
40
41 /**
42 * Weights of each tree.
43 */
44 private float[] weights;
45
46 /**
47 * List of all leaves of this model.
48 * We use tree instead of value to manage wide (i.e., more than 64 leaves) trees.
49 */
50 private RegressionTreeNode[] leaves;
51
52 /**
53 * Offsets of each leaf block correspond to each tree.
54 */
55 private int[] leafOffsets;
56
57 /**
58 * The number of conditions of this model.
59 */
60 private int condNum;
61
62 /**
63 * Thresholds of each condition.
64 * These thresholds are grouped by corresponding feature and each block is sorted by threshold values.
65 */
66 private float[] thresholds;
67
68 /**
69 * Corresponding featureIndex of each condition.
70 */
71 private int[] featureIndexes;
72
73 /**
74 * Offsets of each condition block correspond to each feature.
75 */
76 private int[] condOffsets;
77
78 /**
79 * Forward bitvectors of each condition which correspond to original additive trees.
80 */
81 private long[] forwardBitVectors;
82
83 /**
84 * Backward bitvectors of each condition which correspond to inverted additive trees.
85 */
86 private long[] backwardBitVectors;
87
88 /**
89 * Mappings from threasholdes index to tree indexes.
90 */
91 private int[] treeIds;
92
93 /**
94 * Bitvectors of each tree for calculating the score.
95 * We reuse bitvectors instance in each thread to prevent from re-allocating arrays.
96 */
97 private CloseableThreadLocal<long[]> threadLocalTreeBitvectors = null;
98
99 /**
100 * Boolean statistical tendency of this model.
101 * If conditions of the model tend to be false, we use inverted bitvectors for speeding up.
102 */
103 private volatile float falseRatio = 0.5f;
104
105 /**
106 * The decay factor for updating falseRatio in each evaluation step.
107 * This factor is used like "{@code ratio = preRatio * decay ratio * (1 - decay)}".
108 */
109 private float falseRatioDecay = 0.99f;
110
111 /**
112 * Comparable node cost for selecting leaf candidates.
113 */
114 private static class NodeCost implements Comparable<NodeCost> {
115 private final int id;
116 private final int cost;
117 private final int depth;
118 private final int left;
119 private final int right;
120
121 private NodeCost(int id, int cost, int depth, int left, int right) {
122 this.id = id;
123 this.cost = cost;
124 this.depth = depth;
125 this.left = left;
126 this.right = right;
127 }
128
129 public int getId() {
130 return id;
131 }
132
133 public int getLeft() {
134 return left;
135 }
136
137 public int getRight() {
138 return right;
139 }
140
141 /**
142 * Sorts by cost and depth.
143 * We prefer cheaper cost and deeper one.
144 */
145 @Override
146 public int compareTo(NodeCost n) {
147 if (cost != n.cost) {
148 return Integer.compare(cost, n.cost);
149 } else if (depth != n.depth) {
150 return Integer.compare(n.depth, depth); // revere order
151 } else {
152 return Integer.compare(id, n.id);
153 }
154 }
155 }
156
157 /**
158 * Comparable condition for constructing and sorting bitvectors.
159 */
160 private static class Condition implements Comparable<Condition> {
161 private final int featureIndex;
162 private final float threshold;
163 private final int treeId;
164 private final long forwardBitvector;
165 private final long backwardBitvector;
166
167 private Condition(int featureIndex, float threshold, int treeId, long forwardBitvector, long backwardBitvector) {
168 this.featureIndex = featureIndex;
169 this.threshold = threshold;
170 this.treeId = treeId;
171 this.forwardBitvector = forwardBitvector;
172 this.backwardBitvector = backwardBitvector;
173 }
174
175 int getFeatureIndex() {
176 return featureIndex;
177 }
178
179 float getThreshold() {
180 return threshold;
181 }
182
183 int getTreeId() {
184 return treeId;
185 }
186
187 long getForwardBitvector() {
188 return forwardBitvector;
189 }
190
191 long getBackwardBitvector() {
192 return backwardBitvector;
193 }
194
195 /*
196 * Sort by featureIndex and threshold with ascent order.
197 */
198 @Override
199 public int compareTo(Condition c) {
200 if (featureIndex != c.featureIndex) {
201 return Integer.compare(featureIndex, c.featureIndex);
202 } else {
203 return Float.compare(threshold, c.threshold);
204 }
205 }
206 }
207
208 /**
209 * Base class for traversing node with depth first order.
210 */
211 private abstract static class Visitor {
212 private int nodeId = 0;
213
214 int getNodeId() {
215 return nodeId;
216 }
217
218 void visit(RegressionTree tree) {
219 nodeId = 0;
220 visit(tree.getRoot(), 0);
221 }
222
223 private void visit(RegressionTreeNode node, int depth) {
224 if (node.isLeaf()) {
225 doVisitLeaf(node, depth);
226 } else {
227 // visit children first
228 visit(node.getLeft(), depth + 1);
229 visit(node.getRight(), depth + 1);
230
231 doVisitBranch(node, depth);
232 }
233 ++nodeId;
234 }
235
236 protected abstract void doVisitLeaf(RegressionTreeNode node, int depth);
237
238 protected abstract void doVisitBranch(RegressionTreeNode node, int depth);
239 }
240
241 /**
242 * {@link Visitor} implementation for calculating the cost of each node.
243 */
244 private static class NodeCostVisitor extends Visitor {
245
246 private final Stack<AbstractMap.SimpleEntry<Integer, Integer>> idCostStack = new Stack<>();
247 private final PriorityQueue<NodeCost> nodeCostQueue = new PriorityQueue<>();
248
249 PriorityQueue<NodeCost> getNodeCostQueue() {
250 return nodeCostQueue;
251 }
252
253 @Override
254 protected void doVisitLeaf(RegressionTreeNode node, int depth) {
255 nodeCostQueue.add(new NodeCost(getNodeId(), 0, depth, -1, -1));
256 idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), 1));
257 }
258
259 @Override
260 protected void doVisitBranch(RegressionTreeNode node, int depth) {
261 // calculate the cost of this node from children costs
262 final AbstractMap.SimpleEntry<Integer, Integer> rightIdCost = idCostStack.pop();
263 final AbstractMap.SimpleEntry<Integer, Integer> leftIdCost = idCostStack.pop();
264 final int cost = Math.max(leftIdCost.getValue(), rightIdCost.getValue());
265
266 nodeCostQueue.add(new NodeCost(getNodeId(), cost, depth, leftIdCost.getKey(), rightIdCost.getKey()));
267 idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), cost + 1));
268 }
269 }
270
271 /**
272 * {@link Visitor} implementation for extracting leaves and bitvectors.
273 */
274 private static class QuickScorerVisitor extends Visitor {
275
276 private final int treeId;
277 private final int leafNum;
278 private final Set<Integer> leafIdSet;
279 private final Set<Integer> skipIdSet;
280
281 private final Stack<Long> bitsStack = new Stack<>();
282 private final List<RegressionTreeNode> leafList = new ArrayList<>();
283 private final List<Condition> conditionList = new ArrayList<>();
284
285 private QuickScorerVisitor(int treeId, int leafNum, Set<Integer> leafIdSet, Set<Integer> skipIdSet) {
286 this.treeId = treeId;
287 this.leafNum = leafNum;
288 this.leafIdSet = leafIdSet;
289 this.skipIdSet = skipIdSet;
290 }
291
292 List<RegressionTreeNode> getLeafList() {
293 return leafList;
294 }
295
296 List<Condition> getConditionList() {
297 return conditionList;
298 }
299
300 private long reverseBits(long bits) {
301 long revBits = 0L;
302 long mask = (1L << (leafNum - 1));
303 for (int i = 0; i < leafNum; ++i) {
304 if ((bits & mask) != 0L) revBits |= (1L << i);
305 mask >>>= 1;
306 }
307 return revBits;
308 }
309
310 @Override
311 protected void doVisitLeaf(RegressionTreeNode node, int depth) {
312 if (skipIdSet.contains(getNodeId())) return;
313
314 bitsStack.add(1L << leafList.size()); // we use rightmost bit for detecting leaf
315 leafList.add(node);
316 }
317
318 @Override
319 protected void doVisitBranch(RegressionTreeNode node, int depth) {
320 if (skipIdSet.contains(getNodeId())) return;
321
322 if (leafIdSet.contains(getNodeId())) {
323 // an endpoint of QuickScorer
324 doVisitLeaf(node, depth);
325 return;
326 }
327
328 final long rightBits = bitsStack.pop(); // bits of false branch
329 final long leftBits = bitsStack.pop(); // bits of true branch
330 /*
331 * NOTE:
332 * forwardBitvector = ~leftBits
333 * backwardBitvector = ~(reverse(rightBits))
334 */
335 conditionList.add(
336 new Condition(node.getFeatureIndex(), node.getThreshold(), treeId, ~leftBits, ~reverseBits(rightBits)));
337 bitsStack.add(leftBits | rightBits);
338 }
339 }
340
341 public QuickScorerTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
342 List<Feature> allFeatures, Map<String, Object> params) {
343 super(name, features, norms, featureStoreName, allFeatures, params);
344 }
345
346 /**
347 * Set falseRadioDecay parameter of this model.
348 *
349 * @param falseRatioDecay decay parameter for updating falseRatio
350 */
351 public void setFalseRatioDecay(float falseRatioDecay) {
352 this.falseRatioDecay = falseRatioDecay;
353 }
354
355 /**
356 * @see #setFalseRatioDecay(float)
357 */
358 public void setFalseRatioDecay(String falseRatioDecay) {
359 this.falseRatioDecay = Float.parseFloat(falseRatioDecay);
360 }
361
362 /**
363 * {@inheritDoc}
364 */
365 @Override
366 public void validate() throws ModelException {
367 // validate trees before initializing QuickScorer
368 super.validate();
369
370 // initialize QuickScorer with validated trees
371 init(getTrees());
372 }
373
374 /**
375 * Initializes quick scorer with given trees.
376 * 利用给定的树集初始化快速打分模型
377 *
378 * @param trees base additive trees model
379 */
380 private void init(List<RegressionTree> trees) {
381 this.treeNum = trees.size();
382 this.weights = new float[trees.size()];
383 this.leafOffsets = new int[trees.size() + 1];
384 this.leafOffsets[0] = 0;
385
386 // re-create tree bitvectors
387 if (this.threadLocalTreeBitvectors != null) this.threadLocalTreeBitvectors.close();
388 this.threadLocalTreeBitvectors = new CloseableThreadLocal<long[]>() {
389 @Override
390 protected long[] initialValue() {
391 return new long[treeNum];
392 }
393 };
394
395 int treeId = 0;
396 List<RegressionTreeNode> leafList = new ArrayList<>();
397 List<Condition> conditionList = new ArrayList<>();
398 for (RegressionTree tree : trees) {
399 // select up to 64 leaves from given tree
400 QuickScorerVisitor visitor = fitLeavesTo64bits(treeId, tree);
401
402 // extract leaves and conditions with selected leaf candidates
403 visitor.visit(tree);
404 leafList.addAll(visitor.getLeafList());
405 conditionList.addAll(visitor.getConditionList());
406
407 // update weight, offset and treeId
408 this.weights[treeId] = tree.getWeight();
409 this.leafOffsets[treeId + 1] = this.leafOffsets[treeId] + visitor.getLeafList().size();
410 ++treeId;
411 }
412
413 // remap list to array for performance reason
414 this.leaves = leafList.toArray(new RegressionTreeNode[0]);
415
416 // sort conditions by ascent order of featureIndex and threshold
417 Collections.sort(conditionList);
418
419 // remap information of conditions
420 int idx = 0;
421 int preFeatureIndex = -1;
422 this.condNum = conditionList.size();
423 this.thresholds = new float[conditionList.size()];
424 this.forwardBitVectors = new long[conditionList.size()];
425 this.backwardBitVectors = new long[conditionList.size()];
426 this.treeIds = new int[conditionList.size()];
427 List<Integer> featureIndexList = new ArrayList<>();
428 List<Integer> condOffsetList = new ArrayList<>();
429 for (Condition condition : conditionList) {
430 this.thresholds[idx] = condition.threshold;
431 this.forwardBitVectors[idx] = condition.getForwardBitvector();
432 this.backwardBitVectors[idx] = condition.getBackwardBitvector();
433 this.treeIds[idx] = condition.getTreeId();
434
435 if (preFeatureIndex != condition.getFeatureIndex()) {
436 featureIndexList.add(condition.getFeatureIndex());
437 condOffsetList.add(idx);
438 preFeatureIndex = condition.getFeatureIndex();
439 }
440
441 ++idx;
442 }
443 condOffsetList.add(conditionList.size()); // guard
444
445 this.featureIndexes = ArrayUtils.toPrimitive(featureIndexList.toArray(new Integer[0]));
446 this.condOffsets = ArrayUtils.toPrimitive(condOffsetList.toArray(new Integer[0]));
447 }
448
449 /**
450 * Checks costs of all nodes and select leaves up to 64.
451 *
452 * <p>NOTE:
453 * We can use {@link java.util.BitSet} instead of {@code long} to represent bitvectors longer than 64bits.
454 * However, this modification caused performance degradation in our experiments, and we decided to use this form.
455 *
456 * @param treeId index of given regression tree
457 * @param tree target regression tree
458 * @return QuickScorerVisitor with proper id sets
459 */
460 private QuickScorerVisitor fitLeavesTo64bits(int treeId, RegressionTree tree) {
461 // calculate costs of all nodes
462 NodeCostVisitor nodeCostVisitor = new NodeCostVisitor();
463 nodeCostVisitor.visit(tree);
464
465 // poll zero cost nodes (i.e., real leaves)
466 Set<Integer> leafIdSet = new HashSet<>();
467 Set<Integer> skipIdSet = new HashSet<>();
468 while (!nodeCostVisitor.getNodeCostQueue().isEmpty()) {
469 if (nodeCostVisitor.getNodeCostQueue().peek().cost > 0) break;
470 NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
471 leafIdSet.add(nodeCost.id);
472 }
473
474 // merge leaves until the number of leaves reaches 64
475 while (leafIdSet.size() > 64) {
476 final NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
477 assert nodeCost.left >= 0 && nodeCost.right >= 0;
478
479 // update leaves
480 leafIdSet.remove(nodeCost.left);
481 leafIdSet.remove(nodeCost.right);
482 leafIdSet.add(nodeCost.id);
483
484 // register previous leaves to skip ids
485 skipIdSet.add(nodeCost.left);
486 skipIdSet.add(nodeCost.right);
487 }
488
489 return new QuickScorerVisitor(treeId, leafIdSet.size(), leafIdSet, skipIdSet);
490 }
491
492 /**
493 * {@inheritDoc}
494 */
495 @Override
496 public float score(float[] modelFeatureValuesNormalized) {
497 assert threadLocalTreeBitvectors != null;
498 long[] treeBitvectors = threadLocalTreeBitvectors.get();
499 Arrays.fill(treeBitvectors, MAX_BITS);
500
501 int falseNum = 0;
502 float score = 0.0f;
503 if (falseRatio <= 0.5) {
504 // use forward bitvectors
505 for (int i = 0; i < condOffsets.length - 1; ++i) {
506 final int featureIndex = featureIndexes[i];
507 for (int j = condOffsets[i]; j < condOffsets[i + 1]; ++j) {
508 if (modelFeatureValuesNormalized[featureIndex] <= thresholds[j]) break;
509 treeBitvectors[treeIds[j]] &= forwardBitVectors[j];
510 ++falseNum;
511 }
512 }
513
514 for (int i = 0; i < leafOffsets.length - 1; ++i) {
515 final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
516 score += weights[i] * leaves[leafOffsets[i] + leafIdx].score(modelFeatureValuesNormalized);
517 }
518 } else {
519 // use backward bitvectors
520 falseNum = condNum;
521 for (int i = 0; i < condOffsets.length - 1; ++i) {
522 final int featureIndex = featureIndexes[i];
523 for (int j = condOffsets[i + 1] - 1; j >= condOffsets[i]; --j) {
524 if (modelFeatureValuesNormalized[featureIndex] > thresholds[j]) break;
525 treeBitvectors[treeIds[j]] &= backwardBitVectors[j];
526 --falseNum;
527 }
528 }
529
530 for (int i = 0; i < leafOffsets.length - 1; ++i) {
531 final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
532 score += weights[i] * leaves[leafOffsets[i + 1] - 1 - leafIdx].score(modelFeatureValuesNormalized);
533 }
534 }
535
536 // update false ratio
537 falseRatio = falseRatio * falseRatioDecay + (falseNum * 1.0f / condNum) * (1.0f - falseRatioDecay);
538 return score;
539 }
540
541 }
1 import org.apache.lucene.search.IndexSearcher;
2 import org.apache.lucene.search.Query;
3 import org.apache.solr.ltr.feature.Feature;
4 import org.apache.solr.ltr.feature.FeatureException;
5 import org.apache.solr.ltr.norm.IdentityNormalizer;
6 import org.apache.solr.ltr.norm.Normalizer;
7 import org.apache.solr.request.SolrQueryRequest;
8 import org.junit.Ignore;
9 import org.junit.Test;
10
11 import java.io.IOException;
12 import java.util.ArrayList;
13 import java.util.HashMap;
14 import java.util.LinkedHashMap;
15 import java.util.List;
16 import java.util.Map;
17 import java.util.Random;
18
19 import static org.hamcrest.CoreMatchers.is;
20 import static org.junit.Assert.assertThat;
21
22 public class TestQuickScorerTreesModelBenchmark {
23
24 /**
25 * 产生特征
26 * @param featureNum 特征个数
27 * @return
28 */
29 private List<Feature> createDummyFeatures(int featureNum) {
30 List<Feature> features = new ArrayList<>();
31 for (int i = 0; i < featureNum; ++i) {
32 features.add(new Feature("fv_" + i, null) {
33 @Override
34 protected void validate() throws FeatureException { }
35
36 @Override
37 public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, SolrQueryRequest request,
38 Query originalQuery, Map<String, String[]> efi) throws IOException {
39 return null;
40 }
41
42 @Override
43 public LinkedHashMap<String, Object> paramsToMap() {
44 return null;
45 }
46 });
47 }
48 return features;
49 }
50
51 private List<Normalizer> createDummyNormalizer(int featureNum) {
52 List<Normalizer> normalizers = new ArrayList<>();
53 for (int i = 0; i < featureNum; ++i) {
54 normalizers.add(new IdentityNormalizer());
55 }
56 return normalizers;
57 }
58
59 /**
60 * 创建单棵树
61 * 递归调用自己
62 * @param leafNum 叶子个数
63 * @param features 特征
64 * @param rand 产生随机数
65 * @return
66 */
67 private Map<String, Object> createRandomTree(int leafNum, List<Feature> features, Random rand) {
68 Map<String, Object> node = new HashMap<>();
69 if (leafNum == 1) {
70 // leaf
71 node.put("value", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
72 return node;
73 }
74
75 // branch
76 node.put("feature", features.get(rand.nextInt(features.size())).getName());
77 node.put("threshold", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
78 node.put("left", createRandomTree(leafNum / 2, features, rand));
79 node.put("right", createRandomTree(leafNum - leafNum / 2, features, rand));
80 return node;
81 }
82
83 /**
84 * 这里随机创建多棵树作为model测试
85 * @param treeNum 树的个数
86 * @param leafNum 叶子个数
87 * @param features 特征
88 * @param rand 产生随机数
89 * @return
90 */
91 private List<Object> createRandomMultipleAdditiveTrees(int treeNum, int leafNum, List<Feature> features,
92 Random rand) {
93 List<Object> trees = new ArrayList<>();
94 for (int i = 0; i < treeNum; ++i) {
95 Map<String, Object> tree = new HashMap<>();
96 tree.put("weight", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5) 设置每棵树的学习率
97 tree.put("root", createRandomTree(leafNum, features, rand));
98 trees.add(tree);
99 }
100 return trees;
101 }
102
103 /**
104 * 对比两个打分模型的分值是否一致
105 * @param featureNum 特征个数
106 * @param treeNum 树个数
107 * @param leafNum 叶子个数
108 * @param loopNum 样本个数
109 * @throws Exception
110 */
111 private void compareScore(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
112 Random rand = new Random(0);
113
114 List<Feature> features = createDummyFeatures(featureNum); //产生特征
115 List<Normalizer> norms = createDummyNormalizer(featureNum); //标准化
116
117 for (int i = 0; i < loopNum; ++i) {
118 List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);
119
120 MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
121 "dummy", features, null);
122 matModel.setTrees(trees);
123 matModel.validate();
124
125 QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
126 null);
127 qstModel.setTrees(trees);//设置提供的树模型
128 qstModel.validate();//对提供的树结构进行验证
129
130 float[] featureValues = new float[featureNum];
131 for (int j = 0; j < 100; ++j) {
132 for (int k = 0; k < featureNum; ++k) featureValues[k] = rand.nextFloat() - 0.5f; // [-0.5, 0.5)
133
134 float expected = matModel.score(featureValues);
135 float actual = qstModel.score(featureValues);
136 assertThat(actual, is(expected));
137 //System.out.println("expected: " + expected + " actual: " + actual);
138 }
139 }
140 }
141
142 /**
143 * 两个模型是否得分一致
144 *
145 * @throws Exception thrown if testcase failed to initialize models
146 */
147 /*@Test
148 public void testAccuracy() throws Exception {
149 compareScore(25, 200, 32, 100);
150 //compareScore(19, 500, 31, 10000);
151 }*/
152
153
154 /**
155 * 对比两个打分模型打分的时间消耗
156 * @param featureNum 特征个数
157 * @param treeNum 树个数
158 * @param leafNum 叶子个数
159 * @param loopNum 样本个数
160 * @throws Exception
161 */
162 private void compareTime(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
163 Random rand = new Random(0);
164
165 //随机产生features
166 List<Feature> features = createDummyFeatures(featureNum);
167 //随机产生normalizer
168 List<Normalizer> norms = createDummyNormalizer(featureNum);
169 //随机创建trees
170 List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);
171
172 //初始化multiple additive trees model
173 MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
174 "dummy", features, null);
175 matModel.setTrees(trees);
176 matModel.validate();
177
178 //初始化quick scorer trees model
179 QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
180 null);
181 qstModel.setTrees(trees);
182 qstModel.validate();
183
184 //随机产生样本, loopNum * featureNum
185 float[][] featureValues = new float[loopNum][featureNum];
186 for (int i = 0; i < loopNum; ++i) {
187 for (int k = 0; k < featureNum; ++k) {
188 featureValues[i][k] = rand.nextFloat() * 2.0f - 1.0f; // [-1.0, 1.0)
189 }
190 }
191
192 long start;
193 /*long matOpNsec = 0;
194 for (int i = 0; i < loopNum; ++i) {
195 start = System.nanoTime();
196 matModel.score(featureValues[i]);
197 matOpNsec += System.nanoTime() - start;
198 }
199 long qstOpNsec = 0;
200 for (int i = 0; i < loopNum; ++i) {
201 start = System.nanoTime();
202 qstModel.score(featureValues[i]);
203 qstOpNsec += System.nanoTime() - start;
204 }
205 System.out.println("MultipleAdditiveTreesModel : " + matOpNsec / 1000.0 / loopNum + " usec/op");
206 System.out.println("QuickScorerTreesModel : " + qstOpNsec / 1000.0 / loopNum + " usec/op");*/
207
208 long matOpNsec = 0;
209 start = System.currentTimeMillis();
210 for(int i = 0; i < loopNum; i++) {
211 matModel.score(featureValues[i]);
212 }
213 matOpNsec = System.currentTimeMillis() - start;
214
215 long qstOpNsec = 0;
216 start = System.currentTimeMillis();
217 for(int i = 0; i < loopNum; i++) {
218 qstModel.score(featureValues[i]);
219 }
220 qstOpNsec = System.currentTimeMillis() - start;
221
222 System.out.println("MultipleAdditiveTreesModel : " + matOpNsec);
223
224 System.out.println("QuickScorerTreesModel : " + qstOpNsec);
225
226 //assertThat(matOpNsec > qstOpNsec, is(true));
227 }
228
229 /**
230 * 测试性能
231 * @throws Exception thrown if testcase failed to initialize models
232 */
233
234 @Test
235 public void testPerformance() throws Exception {
236 //features,trees,leafs,samples
237 compareTime(20, 500, 61, 10000);
238 }
239
240 }