# 关于Additive Ensembles of Regression Trees模型的快速打分预测

2019/07/24 10:47

1. Ensemble trees原始打分过程

2. 论文中提到的打分过程

A.为回归树中的每个分枝打上true和false标签

B.为每个branch分配一个bitvector

C.打分阶段

  1 import org.apache.lucene.index.LeafReaderContext;
2 import org.apache.lucene.search.Explanation;
3 import org.apache.solr.ltr.feature.Feature;
4 import org.apache.solr.ltr.model.LTRScoringModel;
5 import org.apache.solr.ltr.model.ModelException;
6 import org.apache.solr.ltr.norm.Normalizer;
7 import org.apache.solr.util.SolrPluginUtils;
8
9 import java.util.*;
10
11 public class MultipleAdditiveTreesModel extends LTRScoringModel {
12
13     // 特征名：索引(从0开始)
14     private final HashMap<String, Integer> fname2index = new HashMap();
15     private List<RegressionTree> trees;
16
17     private MultipleAdditiveTreesModel.RegressionTree createRegressionTree(Map<String, Object> map) {
19         if(map != null) {
20             SolrPluginUtils.invokeSetters(rt, map.entrySet());
21         }
22
23         return rt;
24     }
25
26     private MultipleAdditiveTreesModel.RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
28         if(map != null) {
29             SolrPluginUtils.invokeSetters(rtn, map.entrySet());
30         }
31
32         return rtn;
33     }
34
35     public void setTrees(Object trees) {
36         this.trees = new ArrayList();
37         Iterator var2 = ((List)trees).iterator();
38
39         while(var2.hasNext()) {
40             Object o = var2.next();
43         }
44     }
45
46     public void setTrees(List<RegressionTree> trees) {
47         this.trees = trees;
48     }
49
50     public List<RegressionTree> getTrees() {
51         return this.trees;
52     }
53
54     public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
55         super(name, features, norms, featureStoreName, allFeatures, params);
56
57         for(int i = 0; i < features.size(); ++i) {
58             String key = ((Feature)features.get(i)).getName();
59             this.fname2index.put(key, Integer.valueOf(i));//特征名:索引
60         }
61
62     }
63
64     public void validate() throws ModelException {
65         super.validate();
66         if(this.trees == null) {
67             throw new ModelException("no trees declared for model " + this.name);
68         } else {
69             Iterator var1 = this.trees.iterator();
70
71             while(var1.hasNext()) {
73                 tree.validate();
74             }
75
76         }
77     }
78
79     public float score(float[] modelFeatureValuesNormalized) {
80         float score = 0.0F;
81
83         for(Iterator var3 = this.trees.iterator(); var3.hasNext(); score += t.score(modelFeatureValuesNormalized)) {
85         }
86
87         return score;
88     }
89
90     public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
91         float[] fv = new float[featureExplanations.size()];
92         int index = 0;
93
94         for(Iterator details = featureExplanations.iterator(); details.hasNext(); ++index) {
95             Explanation featureExplain = (Explanation)details.next();
96             fv[index] = featureExplain.getValue();
97         }
98
99         ArrayList var12 = new ArrayList();
100         index = 0;
101
102         for(Iterator var13 = this.trees.iterator(); var13.hasNext(); ++index) {
104             float score = t.score(fv);
105             Explanation p = Explanation.match(score, "tree " + index + " | " + t.explain(fv), new Explanation[0]);
107         }
108
109         return Explanation.match(finalScore, this.toString() + " model applied to features, sum of:", var12);
110     }
111
112     public String toString() {
113         StringBuilder sb = new StringBuilder(this.getClass().getSimpleName());
114         sb.append("(name=").append(this.getName());
115         sb.append(",trees=[");
116
117         for(int ii = 0; ii < this.trees.size(); ++ii) {
118             if(ii > 0) {
119                 sb.append(',');
120             }
121
122             sb.append(this.trees.get(ii));
123         }
124
125         sb.append("])");
126         return sb.toString();
127     }
128
129     public class RegressionTree {
130         private Float weight;
132
133         public void setWeight(float weight) {
134             this.weight = new Float(weight);
135         }
136
137         public void setWeight(String weight) {
138             this.weight = new Float(weight);
139         }
140
141         public float getWeight() {
142             return this.weight;
143         }
144
145         public void setRoot(Object root) {
147         }
148
149         public RegressionTreeNode getRoot() {
150             return this.root;
151         }
152
153         public float score(float[] featureVector) {
154             return this.weight.floatValue() * this.root.score(featureVector);
155         }
156
157         public String explain(float[] featureVector) {
158             return this.root.explain(featureVector);
159         }
160
161         public String toString() {
162             StringBuilder sb = new StringBuilder();
163             sb.append("(weight=").append(this.weight);
164             sb.append(",root=").append(this.root);
165             sb.append(")");
166             return sb.toString();
167         }
168
169         public RegressionTree() {
170         }
171
172         public void validate() throws ModelException {
173             if(this.weight == null) {
174                 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a weight");
175             } else if(this.root == null) {
176                 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a tree");
177             } else {
178                 this.root.validate();
179             }
180         }
181     }
182
183     public class RegressionTreeNode {
184         private static final float NODE_SPLIT_SLACK = 1.0E-6F;
185         private float value = 0.0F;
186         private String feature;
187         private int featureIndex = -1;
188         private Float threshold;
191
192         public void setValue(float value) {
193             this.value = value;
194         }
195
196         public void setValue(String value) {
197             this.value = Float.parseFloat(value);
198         }
199
200         public void setFeature(String feature) {
201             this.feature = feature;
203             this.featureIndex = idx == null?-1:idx.intValue();
204         }
205
206         public int getFeatureIndex() {
207             return this.featureIndex;
208         }
209
210         public void setThreshold(float threshold) {
211             this.threshold = Float.valueOf(threshold + 1.0E-6F);
212         }
213
214         public void setThreshold(String threshold) {
215             this.threshold = Float.valueOf(Float.parseFloat(threshold) + 1.0E-6F);
216         }
217
218         public float getThreshold() {
219             return this.threshold;
220         }
221
222         public void setLeft(Object left) {
224         }
225
226         public RegressionTreeNode getLeft() {
227             return this.left;
228         }
229
230         public void setRight(Object right) {
232         }
233
234         public RegressionTreeNode getRight() {
235             return this.right;
236         }
237
238         public boolean isLeaf() {
239             return this.feature == null;
240         }
241
242         public float score(float[] featureVector) {
243             return this.isLeaf()?this.value:(this.featureIndex >= 0 && this.featureIndex < featureVector.length?(featureVector[this.featureIndex] <= this.threshold.floatValue()?this.left.score(featureVector):this.right.score(featureVector)):0.0F);
244         }
245
246         public String explain(float[] featureVector) {
247             if(this.isLeaf()) {
248                 return "val: " + this.value;
249             } else if(this.featureIndex >= 0 && this.featureIndex < featureVector.length) {
250                 String rval;
251                 if(featureVector[this.featureIndex] <= this.threshold.floatValue()) {
252                     rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " <= " + this.threshold + ", Go Left | ";
253                     return rval + this.left.explain(featureVector);
254                 } else {
255                     rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " > " + this.threshold + ", Go Right | ";
256                     return rval + this.right.explain(featureVector);
257                 }
258             } else {
259                 return "\'" + this.feature + "\' does not exist in FV, Return Zero";
260             }
261         }
262
263         public String toString() {
264             StringBuilder sb = new StringBuilder();
265             if(this.isLeaf()) {
266                 sb.append(this.value);
267             } else {
268                 sb.append("(feature=").append(this.feature);
269                 sb.append(",threshold=").append(this.threshold.floatValue() - 1.0E-6F);
270                 sb.append(",left=").append(this.left);
271                 sb.append(",right=").append(this.right);
272                 sb.append(')');
273             }
274
275             return sb.toString();
276         }
277
278         public RegressionTreeNode() {
279         }
280
281         public void validate() throws ModelException {
282             if(this.isLeaf()) {
283                 if(this.left != null || this.right != null) {
284                     throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + this.left + " and right=" + this.right);
285                 }
286             } else if(null == this.threshold) {
287                 throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
288             } else if(null == this.left) {
289                 throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
290             } else {
291                 this.left.validate();
292                 if(null == this.right) {
293                     throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
294                 } else {
295                     this.right.validate();
296                 }
297             }
298         }
299     }
300
301 }
  1 import org.apache.commons.lang.ArrayUtils;
3 import org.apache.solr.ltr.feature.Feature;
4 import org.apache.solr.ltr.model.ModelException;
5 import org.apache.solr.ltr.norm.Normalizer;
6
7 import java.util.*;
8
9 public class QuickScorerTreesModel extends MultipleAdditiveTreesModel{
10
11     private static final long MAX_BITS = 0xFFFFFFFFFFFFFFFFL;
12
13     // 64bits De Bruijn Sequence
14     // see: http://chessprogramming.wikispaces.com/DeBruijnsequence#Binary alphabet-B(2, 6)
15     private static final long HASH_BITS = 0x022fdd63cc95386dL;
16     private static final int[] hashTable = new int[64];
17
18     static {
19         long hash = HASH_BITS;
20         for (int i = 0; i < 64; ++i) {
21             hashTable[(int) (hash >>> 58)] = i;
22             hash <<= 1;
23         }
24     }
25
26     /**
27      * Finds the index of rightmost bit with O(1) by using De Bruijn strategy.
28      *
29      * @param bits target bits (64bits)
30      * @see <a href="http://supertech.csail.mit.edu/papers/debruijn.pdf">http://supertech.csail.mit.edu/papers/debruijn.pdf</a>
31      */
32     private static int findIndexOfRightMostBit(long bits) {
33         return hashTable[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
34     }
35
36     /**
37      * The number of trees of this model.
38      */
39     private int treeNum;
40
41     /**
42      * Weights of each tree.
43      */
44     private float[] weights;
45
46     /**
47      * List of all leaves of this model.
48      * We use tree instead of value to manage wide (i.e., more than 64 leaves) trees.
49      */
50     private RegressionTreeNode[] leaves;
51
52     /**
53      * Offsets of each leaf block correspond to each tree.
54      */
55     private int[] leafOffsets;
56
57     /**
58      * The number of conditions of this model.
59      */
60     private int condNum;
61
62     /**
63      * Thresholds of each condition.
64      * These thresholds are grouped by corresponding feature and each block is sorted by threshold values.
65      */
66     private float[] thresholds;
67
68     /**
69      * Corresponding featureIndex of each condition.
70      */
71     private int[] featureIndexes;
72
73     /**
74      * Offsets of each condition block correspond to each feature.
75      */
76     private int[] condOffsets;
77
78     /**
79      * Forward bitvectors of each condition which correspond to original additive trees.
80      */
81     private long[] forwardBitVectors;
82
83     /**
84      * Backward bitvectors of each condition which correspond to inverted additive trees.
85      */
86     private long[] backwardBitVectors;
87
88     /**
89      * Mappings from threasholdes index to tree indexes.
90      */
91     private int[] treeIds;
92
93     /**
94      * Bitvectors of each tree for calculating the score.
95      * We reuse bitvectors instance in each thread to prevent from re-allocating arrays.
96      */
98
99     /**
100      * Boolean statistical tendency of this model.
101      * If conditions of the model tend to be false, we use inverted bitvectors for speeding up.
102      */
103     private volatile float falseRatio = 0.5f;
104
105     /**
106      * The decay factor for updating falseRatio in each evaluation step.
107      * This factor is used like "{@code ratio = preRatio * decay  ratio * (1 - decay)}".
108      */
109     private float falseRatioDecay = 0.99f;
110
111     /**
112      * Comparable node cost for selecting leaf candidates.
113      */
114     private static class NodeCost implements Comparable<NodeCost> {
115         private final int id;
116         private final int cost;
117         private final int depth;
118         private final int left;
119         private final int right;
120
121         private NodeCost(int id, int cost, int depth, int left, int right) {
122             this.id = id;
123             this.cost = cost;
124             this.depth = depth;
125             this.left = left;
126             this.right = right;
127         }
128
129         public int getId() {
130             return id;
131         }
132
133         public int getLeft() {
134             return left;
135         }
136
137         public int getRight() {
138             return right;
139         }
140
141         /**
142          * Sorts by cost and depth.
143          * We prefer cheaper cost and deeper one.
144          */
145         @Override
146         public int compareTo(NodeCost n) {
147             if (cost != n.cost) {
148                 return Integer.compare(cost, n.cost);
149             } else if (depth != n.depth) {
150                 return Integer.compare(n.depth, depth);  // revere order
151             } else {
152                 return Integer.compare(id, n.id);
153             }
154         }
155     }
156
157     /**
158      * Comparable condition for constructing and sorting bitvectors.
159      */
160     private static class Condition implements Comparable<Condition> {
161         private final int featureIndex;
162         private final float threshold;
163         private final int treeId;
164         private final long forwardBitvector;
165         private final long backwardBitvector;
166
167         private Condition(int featureIndex, float threshold, int treeId, long forwardBitvector, long backwardBitvector) {
168             this.featureIndex = featureIndex;
169             this.threshold = threshold;
170             this.treeId = treeId;
171             this.forwardBitvector = forwardBitvector;
172             this.backwardBitvector = backwardBitvector;
173         }
174
175         int getFeatureIndex() {
176             return featureIndex;
177         }
178
179         float getThreshold() {
180             return threshold;
181         }
182
183         int getTreeId() {
184             return treeId;
185         }
186
187         long getForwardBitvector() {
188             return forwardBitvector;
189         }
190
191         long getBackwardBitvector() {
192             return backwardBitvector;
193         }
194
195         /*
196          * Sort by featureIndex and threshold with ascent order.
197          */
198         @Override
199         public int compareTo(Condition c) {
200             if (featureIndex != c.featureIndex) {
201                 return Integer.compare(featureIndex, c.featureIndex);
202             } else {
203                 return Float.compare(threshold, c.threshold);
204             }
205         }
206     }
207
208     /**
209      * Base class for traversing node with depth first order.
210      */
211     private abstract static class Visitor {
212         private int nodeId = 0;
213
214         int getNodeId() {
215             return nodeId;
216         }
217
218         void visit(RegressionTree tree) {
219             nodeId = 0;
220             visit(tree.getRoot(), 0);
221         }
222
223         private void visit(RegressionTreeNode node, int depth) {
224             if (node.isLeaf()) {
225                 doVisitLeaf(node, depth);
226             } else {
227                 // visit children first
228                 visit(node.getLeft(), depth + 1);
229                 visit(node.getRight(), depth + 1);
230
231                 doVisitBranch(node, depth);
232             }
233             ++nodeId;
234         }
235
236         protected abstract void doVisitLeaf(RegressionTreeNode node, int depth);
237
238         protected abstract void doVisitBranch(RegressionTreeNode node, int depth);
239     }
240
241     /**
242      * {@link Visitor} implementation for calculating the cost of each node.
243      */
244     private static class NodeCostVisitor extends Visitor {
245
246         private final Stack<AbstractMap.SimpleEntry<Integer, Integer>> idCostStack = new Stack<>();
247         private final PriorityQueue<NodeCost> nodeCostQueue = new PriorityQueue<>();
248
249         PriorityQueue<NodeCost> getNodeCostQueue() {
250             return nodeCostQueue;
251         }
252
253         @Override
254         protected void doVisitLeaf(RegressionTreeNode node, int depth) {
255             nodeCostQueue.add(new NodeCost(getNodeId(), 0, depth, -1, -1));
256             idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), 1));
257         }
258
259         @Override
260         protected void doVisitBranch(RegressionTreeNode node, int depth) {
261             // calculate the cost of this node from children costs
262             final AbstractMap.SimpleEntry<Integer, Integer> rightIdCost = idCostStack.pop();
263             final AbstractMap.SimpleEntry<Integer, Integer> leftIdCost = idCostStack.pop();
264             final int cost = Math.max(leftIdCost.getValue(), rightIdCost.getValue());
265
266             nodeCostQueue.add(new NodeCost(getNodeId(), cost, depth, leftIdCost.getKey(), rightIdCost.getKey()));
267             idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), cost + 1));
268         }
269     }
270
271     /**
272      * {@link Visitor} implementation for extracting leaves and bitvectors.
273      */
274     private static class QuickScorerVisitor extends Visitor {
275
276         private final int treeId;
277         private final int leafNum;
278         private final Set<Integer> leafIdSet;
279         private final Set<Integer> skipIdSet;
280
281         private final Stack<Long> bitsStack = new Stack<>();
282         private final List<RegressionTreeNode> leafList = new ArrayList<>();
283         private final List<Condition> conditionList = new ArrayList<>();
284
285         private QuickScorerVisitor(int treeId, int leafNum, Set<Integer> leafIdSet, Set<Integer> skipIdSet) {
286             this.treeId = treeId;
287             this.leafNum = leafNum;
288             this.leafIdSet = leafIdSet;
289             this.skipIdSet = skipIdSet;
290         }
291
292         List<RegressionTreeNode> getLeafList() {
293             return leafList;
294         }
295
296         List<Condition> getConditionList() {
297             return conditionList;
298         }
299
300         private long reverseBits(long bits) {
301             long revBits = 0L;
302             long mask = (1L << (leafNum - 1));
303             for (int i = 0; i < leafNum; ++i) {
304                 if ((bits & mask) != 0L) revBits |= (1L << i);
306             }
307             return revBits;
308         }
309
310         @Override
311         protected void doVisitLeaf(RegressionTreeNode node, int depth) {
312             if (skipIdSet.contains(getNodeId())) return;
313
314             bitsStack.add(1L << leafList.size());  // we use rightmost bit for detecting leaf
316         }
317
318         @Override
319         protected void doVisitBranch(RegressionTreeNode node, int depth) {
320             if (skipIdSet.contains(getNodeId())) return;
321
322             if (leafIdSet.contains(getNodeId())) {
323                 // an endpoint of QuickScorer
324                 doVisitLeaf(node, depth);
325                 return;
326             }
327
328             final long rightBits = bitsStack.pop();  // bits of false branch
329             final long leftBits = bitsStack.pop();   // bits of true branch
330       /*
331        * NOTE:
332        *   forwardBitvector  = ~leftBits
333        *   backwardBitvector = ~(reverse(rightBits))
334        */
336                     new Condition(node.getFeatureIndex(), node.getThreshold(), treeId, ~leftBits, ~reverseBits(rightBits)));
338         }
339     }
340
341     public QuickScorerTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
342                                  List<Feature> allFeatures, Map<String, Object> params) {
343         super(name, features, norms, featureStoreName, allFeatures, params);
344     }
345
346     /**
347      * Set falseRadioDecay parameter of this model.
348      *
349      * @param falseRatioDecay decay parameter for updating falseRatio
350      */
351     public void setFalseRatioDecay(float falseRatioDecay) {
352         this.falseRatioDecay = falseRatioDecay;
353     }
354
355     /**
356      * @see #setFalseRatioDecay(float)
357      */
358     public void setFalseRatioDecay(String falseRatioDecay) {
359         this.falseRatioDecay = Float.parseFloat(falseRatioDecay);
360     }
361
362     /**
363      * {@inheritDoc}
364      */
365     @Override
366     public void validate() throws ModelException {
367         // validate trees before initializing QuickScorer
368         super.validate();
369
370         // initialize QuickScorer with validated trees
371         init(getTrees());
372     }
373
374     /**
375      * Initializes quick scorer with given trees.
376      * 利用给定的树集初始化快速打分模型
377      *
378      * @param trees base additive trees model
379      */
380     private void init(List<RegressionTree> trees) {
381         this.treeNum = trees.size();
382         this.weights = new float[trees.size()];
383         this.leafOffsets = new int[trees.size() + 1];
384         this.leafOffsets[0] = 0;
385
386         // re-create tree bitvectors
389             @Override
390             protected long[] initialValue() {
391                 return new long[treeNum];
392             }
393         };
394
395         int treeId = 0;
396         List<RegressionTreeNode> leafList = new ArrayList<>();
397         List<Condition> conditionList = new ArrayList<>();
398         for (RegressionTree tree : trees) {
399             // select up to 64 leaves from given tree
400             QuickScorerVisitor visitor = fitLeavesTo64bits(treeId, tree);
401
402             // extract leaves and conditions with selected leaf candidates
403             visitor.visit(tree);
406
407             // update weight, offset and treeId
408             this.weights[treeId] = tree.getWeight();
409             this.leafOffsets[treeId + 1] = this.leafOffsets[treeId] + visitor.getLeafList().size();
410             ++treeId;
411         }
412
413         // remap list to array for performance reason
414         this.leaves = leafList.toArray(new RegressionTreeNode[0]);
415
416         // sort conditions by ascent order of featureIndex and threshold
417         Collections.sort(conditionList);
418
419         // remap information of conditions
420         int idx = 0;
421         int preFeatureIndex = -1;
422         this.condNum = conditionList.size();
423         this.thresholds = new float[conditionList.size()];
424         this.forwardBitVectors = new long[conditionList.size()];
425         this.backwardBitVectors = new long[conditionList.size()];
426         this.treeIds = new int[conditionList.size()];
427         List<Integer> featureIndexList = new ArrayList<>();
428         List<Integer> condOffsetList = new ArrayList<>();
429         for (Condition condition : conditionList) {
430             this.thresholds[idx] = condition.threshold;
431             this.forwardBitVectors[idx] = condition.getForwardBitvector();
432             this.backwardBitVectors[idx] = condition.getBackwardBitvector();
433             this.treeIds[idx] = condition.getTreeId();
434
435             if (preFeatureIndex != condition.getFeatureIndex()) {
438                 preFeatureIndex = condition.getFeatureIndex();
439             }
440
441             ++idx;
442         }
444
445         this.featureIndexes = ArrayUtils.toPrimitive(featureIndexList.toArray(new Integer[0]));
446         this.condOffsets = ArrayUtils.toPrimitive(condOffsetList.toArray(new Integer[0]));
447     }
448
449     /**
450      * Checks costs of all nodes and select leaves up to 64.
451      *
452      * <p>NOTE:
453      * We can use {@link java.util.BitSet} instead of {@code long} to represent bitvectors longer than 64bits.
454      * However, this modification caused performance degradation in our experiments, and we decided to use this form.
455      *
456      * @param treeId index of given regression tree
457      * @param tree target regression tree
458      * @return QuickScorerVisitor with proper id sets
459      */
460     private QuickScorerVisitor fitLeavesTo64bits(int treeId, RegressionTree tree) {
461         // calculate costs of all nodes
462         NodeCostVisitor nodeCostVisitor = new NodeCostVisitor();
463         nodeCostVisitor.visit(tree);
464
465         // poll zero cost nodes (i.e., real leaves)
466         Set<Integer> leafIdSet = new HashSet<>();
467         Set<Integer> skipIdSet = new HashSet<>();
468         while (!nodeCostVisitor.getNodeCostQueue().isEmpty()) {
469             if (nodeCostVisitor.getNodeCostQueue().peek().cost > 0) break;
470             NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
472         }
473
474         // merge leaves until the number of leaves reaches 64
475         while (leafIdSet.size() > 64) {
476             final NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
477             assert nodeCost.left >= 0 && nodeCost.right >= 0;
478
479             // update leaves
480             leafIdSet.remove(nodeCost.left);
481             leafIdSet.remove(nodeCost.right);
483
484             // register previous leaves to skip ids
487         }
488
489         return new QuickScorerVisitor(treeId, leafIdSet.size(), leafIdSet, skipIdSet);
490     }
491
492     /**
493      * {@inheritDoc}
494      */
495     @Override
496     public float score(float[] modelFeatureValuesNormalized) {
499         Arrays.fill(treeBitvectors, MAX_BITS);
500
501         int falseNum = 0;
502         float score = 0.0f;
503         if (falseRatio <= 0.5) {
504             // use forward bitvectors
505             for (int i = 0; i < condOffsets.length - 1; ++i) {
506                 final int featureIndex = featureIndexes[i];
507                 for (int j = condOffsets[i]; j < condOffsets[i + 1]; ++j) {
508                     if (modelFeatureValuesNormalized[featureIndex] <= thresholds[j]) break;
509                     treeBitvectors[treeIds[j]] &= forwardBitVectors[j];
510                     ++falseNum;
511                 }
512             }
513
514             for (int i = 0; i < leafOffsets.length - 1; ++i) {
515                 final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
516                 score += weights[i] * leaves[leafOffsets[i] + leafIdx].score(modelFeatureValuesNormalized);
517             }
518         } else {
519             // use backward bitvectors
520             falseNum = condNum;
521             for (int i = 0; i < condOffsets.length - 1; ++i) {
522                 final int featureIndex = featureIndexes[i];
523                 for (int j = condOffsets[i + 1] - 1; j >= condOffsets[i]; --j) {
524                     if (modelFeatureValuesNormalized[featureIndex] > thresholds[j]) break;
525                     treeBitvectors[treeIds[j]] &= backwardBitVectors[j];
526                     --falseNum;
527                 }
528             }
529
530             for (int i = 0; i < leafOffsets.length - 1; ++i) {
531                 final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
532                 score += weights[i] * leaves[leafOffsets[i + 1] - 1 - leafIdx].score(modelFeatureValuesNormalized);
533             }
534         }
535
536         // update false ratio
537         falseRatio = falseRatio * falseRatioDecay + (falseNum * 1.0f / condNum) * (1.0f - falseRatioDecay);
538         return score;
539     }
540
541 }
  1 import org.apache.lucene.search.IndexSearcher;
2 import org.apache.lucene.search.Query;
3 import org.apache.solr.ltr.feature.Feature;
4 import org.apache.solr.ltr.feature.FeatureException;
5 import org.apache.solr.ltr.norm.IdentityNormalizer;
6 import org.apache.solr.ltr.norm.Normalizer;
7 import org.apache.solr.request.SolrQueryRequest;
8 import org.junit.Ignore;
9 import org.junit.Test;
10
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
18
19 import static org.hamcrest.CoreMatchers.is;
20 import static org.junit.Assert.assertThat;
21
22 public class TestQuickScorerTreesModelBenchmark {
23
24     /**
25      * 产生特征
26      * @param featureNum 特征个数
27      * @return
28      */
29     private List<Feature> createDummyFeatures(int featureNum) {
30         List<Feature> features = new ArrayList<>();
31         for (int i = 0; i < featureNum; ++i) {
32             features.add(new Feature("fv_" + i, null) {
33                 @Override
34                 protected void validate() throws FeatureException { }
35
36                 @Override
37                 public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, SolrQueryRequest request,
38                                                   Query originalQuery, Map<String, String[]> efi) throws IOException {
39                     return null;
40                 }
41
42                 @Override
43                 public LinkedHashMap<String, Object> paramsToMap() {
44                     return null;
45                 }
46             });
47         }
48         return features;
49     }
50
51     private List<Normalizer> createDummyNormalizer(int featureNum) {
52         List<Normalizer> normalizers = new ArrayList<>();
53         for (int i = 0; i < featureNum; ++i) {
55         }
56         return normalizers;
57     }
58
59     /**
60      * 创建单棵树
61      * 递归调用自己
62      * @param leafNum 叶子个数
63      * @param features 特征
64      * @param rand 产生随机数
65      * @return
66      */
67     private Map<String, Object> createRandomTree(int leafNum, List<Feature> features, Random rand) {
68         Map<String, Object> node = new HashMap<>();
69         if (leafNum == 1) {
70             // leaf
71             node.put("value", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
72             return node;
73         }
74
75         // branch
76         node.put("feature", features.get(rand.nextInt(features.size())).getName());
77         node.put("threshold", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
78         node.put("left", createRandomTree(leafNum / 2, features, rand));
79         node.put("right", createRandomTree(leafNum - leafNum / 2, features, rand));
80         return node;
81     }
82
83     /**
84      *　这里随机创建多棵树作为model测试
85      * @param treeNum 树的个数
86      * @param leafNum 叶子个数
87      * @param features 特征
88      * @param rand 产生随机数
89      * @return
90      */
91     private List<Object> createRandomMultipleAdditiveTrees(int treeNum, int leafNum, List<Feature> features,
92                                                            Random rand) {
93         List<Object> trees = new ArrayList<>();
94         for (int i = 0; i < treeNum; ++i) {
95             Map<String, Object> tree = new HashMap<>();
96             tree.put("weight", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5) 设置每棵树的学习率
97             tree.put("root", createRandomTree(leafNum, features, rand));
99         }
100         return trees;
101     }
102
103     /**
104      * 对比两个打分模型的分值是否一致
105      * @param featureNum 特征个数
106      * @param treeNum 树个数
107      * @param leafNum 叶子个数
108      * @param loopNum 样本个数
109      * @throws Exception
110      */
111     private void compareScore(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
112         Random rand = new Random(0);
113
114         List<Feature> features = createDummyFeatures(featureNum); //产生特征
115         List<Normalizer> norms = createDummyNormalizer(featureNum); //标准化
116
117         for (int i = 0; i < loopNum; ++i) {
118             List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);
119
121                     "dummy", features, null);
122             matModel.setTrees(trees);
123             matModel.validate();
124
125             QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
126                     null);
127             qstModel.setTrees(trees);//设置提供的树模型
128             qstModel.validate();//对提供的树结构进行验证
129
130             float[] featureValues = new float[featureNum];
131             for (int j = 0; j < 100; ++j) {
132                 for (int k = 0; k < featureNum; ++k) featureValues[k] = rand.nextFloat() - 0.5f; // [-0.5, 0.5)
133
134                 float expected = matModel.score(featureValues);
135                 float actual = qstModel.score(featureValues);
136                 assertThat(actual, is(expected));
137                 //System.out.println("expected: " + expected + " actual: " + actual);
138             }
139         }
140     }
141
142     /**
143      * 两个模型是否得分一致
144      *
145      * @throws Exception thrown if testcase failed to initialize models
146      */
147     /*@Test
148     public void testAccuracy() throws Exception {
149         compareScore(25, 200, 32, 100);
150         //compareScore(19, 500, 31, 10000);
151     }*/
152
153
154     /**
155      * 对比两个打分模型打分的时间消耗
156      * @param featureNum 特征个数
157      * @param treeNum 树个数
158      * @param leafNum 叶子个数
159      * @param loopNum 样本个数
160      * @throws Exception
161      */
162     private void compareTime(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
163         Random rand = new Random(0);
164
165         //随机产生features
166         List<Feature> features = createDummyFeatures(featureNum);
167         //随机产生normalizer
168         List<Normalizer> norms = createDummyNormalizer(featureNum);
169         //随机创建trees
170         List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);
171
174                 "dummy", features, null);
175         matModel.setTrees(trees);
176         matModel.validate();
177
178         //初始化quick scorer trees model
179         QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
180                 null);
181         qstModel.setTrees(trees);
182         qstModel.validate();
183
184         //随机产生样本, loopNum * featureNum
185         float[][] featureValues = new float[loopNum][featureNum];
186         for (int i = 0; i < loopNum; ++i) {
187             for (int k = 0; k < featureNum; ++k) {
188                 featureValues[i][k] = rand.nextFloat() * 2.0f - 1.0f; // [-1.0, 1.0)
189             }
190         }
191
192         long start;
193         /*long matOpNsec = 0;
194         for (int i = 0; i < loopNum; ++i) {
195             start = System.nanoTime();
196             matModel.score(featureValues[i]);
197             matOpNsec += System.nanoTime() - start;
198         }
199         long qstOpNsec = 0;
200         for (int i = 0; i < loopNum; ++i) {
201             start = System.nanoTime();
202             qstModel.score(featureValues[i]);
203             qstOpNsec += System.nanoTime() - start;
204         }
205         System.out.println("MultipleAdditiveTreesModel : " + matOpNsec / 1000.0 / loopNum + " usec/op");
206         System.out.println("QuickScorerTreesModel      : " + qstOpNsec / 1000.0 / loopNum + " usec/op");*/
207
208         long matOpNsec = 0;
209         start = System.currentTimeMillis();
210         for(int i = 0; i < loopNum; i++) {
211             matModel.score(featureValues[i]);
212         }
213         matOpNsec = System.currentTimeMillis() - start;
214
215         long qstOpNsec = 0;
216         start = System.currentTimeMillis();
217         for(int i = 0; i < loopNum; i++) {
218             qstModel.score(featureValues[i]);
219         }
220         qstOpNsec = System.currentTimeMillis() - start;
221
222         System.out.println("MultipleAdditiveTreesModel : " + matOpNsec);
223
224         System.out.println("QuickScorerTreesModel : " + qstOpNsec);
225
226         //assertThat(matOpNsec > qstOpNsec, is(true));
227     }
228
229     /**
230      * 测试性能
231      * @throws Exception thrown if testcase failed to initialize models
232      */
233
234     @Test
235     public void testPerformance() throws Exception {
236         //features,trees,leafs,samples
237         compareTime(20, 500, 61, 10000);
238     }
239
240 }

0
0 收藏

0 评论
0 收藏
0