Fast scoring (prediction) for Additive Ensembles of Regression Trees models

2019/07/24 10:47

I. The paper "QuickScorer: A Fast Algorithm to Rank Documents with Additive Ensembles of Regression Trees" addresses the prediction stage of LTR models. When a model such as LambdaMART is trained with many trees and many leaf nodes, scoring a sample means traversing every tree, which is slow for online serving. The paper speeds up scoring with a bitvector technique. I searched for an open-source implementation for a long time without success, and later happened to find a modified version of the code in Solr ltr. Since that code is embedded in Solr, I post it here for now; later I plan to extract it and integrate it into RankLib so it is easier to use.

II. Walkthrough of the figures

1. The original scoring process of ensemble trees

 

For ensemble tree models such as GBDT, LambdaMART, XGBoost or LightGBM, prediction works like this: a sample arrives as a feature vector and is fed into every tree; in each tree an if/else-style traversal maps it to exactly one leaf, and that leaf's value is the tree's score for the sample. The final prediction is the sum over all trees of the leaf value multiplied by the learning rate (shrinkage).
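As a reference point, here is a minimal sketch of that per-tree traversal. The Node and NaiveEnsemble classes and their field names are my own illustration, not the Solr classes listed further below:

    // Minimal sketch of standard ensemble-tree scoring (illustrative only).
    class Node {
        int featureIndex;   // feature tested by this branch
        float threshold;    // split threshold
        float value;        // leaf value (used when left and right are null)
        Node left, right;
    }

    class NaiveEnsemble {
        Node[] roots;       // one root per tree
        float[] weights;    // per-tree learning rate (shrinkage)

        float score(float[] x) {
            float sum = 0f;
            for (int t = 0; t < roots.length; t++) {
                Node n = roots[t];
                while (n.left != null) {                     // walk down to a leaf
                    n = (x[n.featureIndex] <= n.threshold) ? n.left : n.right;
                }
                sum += weights[t] * n.value;                 // weighted leaf value
            }
            return sum;
        }
    }

With thousands of trees this root-to-leaf pointer chasing is exactly the cost that QuickScorer tries to avoid.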

2. The scoring process described in the paper

A. Label every branch of the regression tree as true or false

Take the sample X = [0.2, 1.1, 0.2] from the figure. Each branch of the regression tree tests X[0], X[1] or X[2] and is marked true or false. The root node in the figure tests X[1] <= 1.0, but X[1] = 1.1, so it is false (going left means true, going right means false). In this way every branch gets a true/false label (in practice only the false branches need to be marked; the true ones can be ignored), and the false branches are what the following steps use.

B. Assign a bitvector to every branch

 

The "0" bits of a branch's bitvector mark its "true leaves", i.e. the leaves of its left (true) subtree. For example, "001111" belongs to a branch whose true subtree contains the two leftmost of the 6 leaves, while "110011" belongs to a branch inside the right subtree whose true side covers only the two middle leaves. When such a branch evaluates to false, the zeroed leaves are ruled out as candidates for the exit leaf.
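As a sketch of what such a bitvector looks like (my own illustration, assuming a tree with fewer than 64 leaves and leaf i stored at bit i, so the figure's leftmost leaf is the lowest bit; the Solr code below records essentially ~leftBits for each branch during a depth-first visit):

    // Illustrative only: bitvector of a branch whose true (left) subtree covers
    // leaves [firstLeftLeaf, firstLeftLeaf + leftLeafCount) of a leafNum-leaf tree.
    static long branchBitvector(int firstLeftLeaf, int leftLeafCount, int leafNum) {
        long allLeaves = (1L << leafNum) - 1;                          // e.g. 0b111111 for 6 leaves
        long trueSide  = ((1L << leftLeafCount) - 1) << firstLeftLeaf; // leaves of the true subtree
        return allLeaves & ~trueSide;                                  // clear the true-side leaves
    }
    // branchBitvector(0, 2, 6) == 0b111100, i.e. "001111" in the figure's left-to-right notation
    // branchBitvector(2, 2, 6) == 0b110011, the "110011" of the figure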

C. Scoring

This is the final prediction step. Following the figures above, the bitvectors of all false branches are ANDed together bitwise, and the result tells us which leaf the sample falls into: in the figure the result is "001101", and the position of the leftmost 1 is the index of the exit leaf. Every regression tree is processed this way to get its leaf value, which is multiplied by the learning rate (shrinkage) and summed to produce the sample's final score.
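Concretely, for a single 6-leaf tree the AND step could look like the sketch below. The bitvectors and leaf values are hypothetical (chosen so the AND result matches the "001101" of the figure), and leaf i again sits at bit i, so the paper's "leftmost 1" becomes the lowest set bit here:

    // Illustrative only, not the Solr code below.
    float[] leafValues = {0.1f, 0.4f, 0.2f, 0.9f, 0.3f, 0.5f};  // hypothetical leaf predictions
    float shrinkage = 0.1f;                                      // hypothetical learning rate

    long result = (1L << 6) - 1;                        // all 6 leaves start as candidates: 111111
    long[] falseBitvectors = {0b111100L, 0b101111L};    // bitvectors of the false branches (hypothetical)
    for (long bv : falseBitvectors) {
        result &= bv;                                   // each false branch rules out its true-side leaves
    }
    // result == 0b101100, i.e. "001101" in the figure's left-to-right notation
    int exitLeaf = Long.numberOfTrailingZeros(result);  // lowest set bit = exit leaf index (here 2)
    float treeScore = shrinkage * leafValues[exitLeaf]; // weight the leaf value by the learning rate
    // Repeating this for every tree and summing treeScore gives the final prediction.

In the Solr-derived code below, findIndexOfRightMostBit plays the role of Long.numberOfTrailingZeros, implemented with a De Bruijn sequence.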

III. Code
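The first listing is the baseline MultipleAdditiveTreesModel (taken from the Solr ltr module), which scores a sample by walking every regression tree from the root down to a leaf.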

  1 import org.apache.lucene.index.LeafReaderContext;
  2 import org.apache.lucene.search.Explanation;
  3 import org.apache.solr.ltr.feature.Feature;
  4 import org.apache.solr.ltr.model.LTRScoringModel;
  5 import org.apache.solr.ltr.model.ModelException;
  6 import org.apache.solr.ltr.norm.Normalizer;
  7 import org.apache.solr.util.SolrPluginUtils;
  8 
  9 import java.util.*;
 10 
 11 public class MultipleAdditiveTreesModel extends LTRScoringModel {
 12 
 13     // feature name -> index (starting from 0)
 14     private final HashMap<String, Integer> fname2index = new HashMap();
 15     private List<RegressionTree> trees;
 16 
 17     private MultipleAdditiveTreesModel.RegressionTree createRegressionTree(Map<String, Object> map) {
 18         MultipleAdditiveTreesModel.RegressionTree rt = new MultipleAdditiveTreesModel.RegressionTree();
 19         if(map != null) {
 20             SolrPluginUtils.invokeSetters(rt, map.entrySet());
 21         }
 22 
 23         return rt;
 24     }
 25 
 26     private MultipleAdditiveTreesModel.RegressionTreeNode createRegressionTreeNode(Map<String, Object> map) {
 27         MultipleAdditiveTreesModel.RegressionTreeNode rtn = new MultipleAdditiveTreesModel.RegressionTreeNode();
 28         if(map != null) {
 29             SolrPluginUtils.invokeSetters(rtn, map.entrySet());
 30         }
 31 
 32         return rtn;
 33     }
 34 
 35     public void setTrees(Object trees) {
 36         this.trees = new ArrayList();
 37         Iterator var2 = ((List)trees).iterator();
 38 
 39         while(var2.hasNext()) {
 40             Object o = var2.next();
 41             MultipleAdditiveTreesModel.RegressionTree rt = this.createRegressionTree((Map)o);
 42             this.trees.add(rt);
 43         }
 44     }
 45 
 46     public void setTrees(List<RegressionTree> trees) {
 47         this.trees = trees;
 48     }
 49 
 50     public List<RegressionTree> getTrees() {
 51         return this.trees;
 52     }
 53 
 54     public MultipleAdditiveTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName, List<Feature> allFeatures, Map<String, Object> params) {
 55         super(name, features, norms, featureStoreName, allFeatures, params);
 56 
 57         for(int i = 0; i < features.size(); ++i) {
 58             String key = ((Feature)features.get(i)).getName();
 59             this.fname2index.put(key, Integer.valueOf(i)); // feature name -> index
 60         }
 61 
 62     }
 63 
 64     public void validate() throws ModelException {
 65         super.validate();
 66         if(this.trees == null) {
 67             throw new ModelException("no trees declared for model " + this.name);
 68         } else {
 69             Iterator var1 = this.trees.iterator();
 70 
 71             while(var1.hasNext()) {
 72                 MultipleAdditiveTreesModel.RegressionTree tree = (MultipleAdditiveTreesModel.RegressionTree)var1.next();
 73                 tree.validate();
 74             }
 75 
 76         }
 77     }
 78 
 79     public float score(float[] modelFeatureValuesNormalized) {
 80         float score = 0.0F;
 81 
 82         MultipleAdditiveTreesModel.RegressionTree t;
 83         for(Iterator var3 = this.trees.iterator(); var3.hasNext(); score += t.score(modelFeatureValuesNormalized)) {
 84             t = (MultipleAdditiveTreesModel.RegressionTree)var3.next();
 85         }
 86 
 87         return score;
 88     }
 89 
 90     public Explanation explain(LeafReaderContext context, int doc, float finalScore, List<Explanation> featureExplanations) {
 91         float[] fv = new float[featureExplanations.size()];
 92         int index = 0;
 93 
 94         for(Iterator details = featureExplanations.iterator(); details.hasNext(); ++index) {
 95             Explanation featureExplain = (Explanation)details.next();
 96             fv[index] = featureExplain.getValue();
 97         }
 98 
 99         ArrayList var12 = new ArrayList();
100         index = 0;
101 
102         for(Iterator var13 = this.trees.iterator(); var13.hasNext(); ++index) {
103             MultipleAdditiveTreesModel.RegressionTree t = (MultipleAdditiveTreesModel.RegressionTree)var13.next();
104             float score = t.score(fv);
105             Explanation p = Explanation.match(score, "tree " + index + " | " + t.explain(fv), new Explanation[0]);
106             var12.add(p);
107         }
108 
109         return Explanation.match(finalScore, this.toString() + " model applied to features, sum of:", var12);
110     }
111 
112     public String toString() {
113         StringBuilder sb = new StringBuilder(this.getClass().getSimpleName());
114         sb.append("(name=").append(this.getName());
115         sb.append(",trees=[");
116 
117         for(int ii = 0; ii < this.trees.size(); ++ii) {
118             if(ii > 0) {
119                 sb.append(',');
120             }
121 
122             sb.append(this.trees.get(ii));
123         }
124 
125         sb.append("])");
126         return sb.toString();
127     }
128 
129     public class RegressionTree {
130         private Float weight;
131         private MultipleAdditiveTreesModel.RegressionTreeNode root;
132 
133         public void setWeight(float weight) {
134             this.weight = new Float(weight);
135         }
136 
137         public void setWeight(String weight) {
138             this.weight = new Float(weight);
139         }
140 
141         public float getWeight() {
142             return this.weight;
143         }
144 
145         public void setRoot(Object root) {
146             this.root = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)root);
147         }
148 
149         public RegressionTreeNode getRoot() {
150             return this.root;
151         }
152 
153         public float score(float[] featureVector) {
154             return this.weight.floatValue() * this.root.score(featureVector);
155         }
156 
157         public String explain(float[] featureVector) {
158             return this.root.explain(featureVector);
159         }
160 
161         public String toString() {
162             StringBuilder sb = new StringBuilder();
163             sb.append("(weight=").append(this.weight);
164             sb.append(",root=").append(this.root);
165             sb.append(")");
166             return sb.toString();
167         }
168 
169         public RegressionTree() {
170         }
171 
172         public void validate() throws ModelException {
173             if(this.weight == null) {
174                 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a weight");
175             } else if(this.root == null) {
176                 throw new ModelException("MultipleAdditiveTreesModel tree doesn\'t contain a tree");
177             } else {
178                 this.root.validate();
179             }
180         }
181     }
182 
183     public class RegressionTreeNode {
184         private static final float NODE_SPLIT_SLACK = 1.0E-6F;
185         private float value = 0.0F;
186         private String feature;
187         private int featureIndex = -1;
188         private Float threshold;
189         private MultipleAdditiveTreesModel.RegressionTreeNode left;
190         private MultipleAdditiveTreesModel.RegressionTreeNode right;
191 
192         public void setValue(float value) {
193             this.value = value;
194         }
195 
196         public void setValue(String value) {
197             this.value = Float.parseFloat(value);
198         }
199 
200         public void setFeature(String feature) {
201             this.feature = feature;
202             Integer idx = (Integer)MultipleAdditiveTreesModel.this.fname2index.get(this.feature);
203             this.featureIndex = idx == null?-1:idx.intValue();
204         }
205 
206         public int getFeatureIndex() {
207             return this.featureIndex;
208         }
209 
210         public void setThreshold(float threshold) {
211             this.threshold = Float.valueOf(threshold + 1.0E-6F);
212         }
213 
214         public void setThreshold(String threshold) {
215             this.threshold = Float.valueOf(Float.parseFloat(threshold) + 1.0E-6F);
216         }
217 
218         public float getThreshold() {
219             return this.threshold;
220         }
221 
222         public void setLeft(Object left) {
223             this.left = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)left);
224         }
225 
226         public RegressionTreeNode getLeft() {
227             return this.left;
228         }
229 
230         public void setRight(Object right) {
231             this.right = MultipleAdditiveTreesModel.this.createRegressionTreeNode((Map)right);
232         }
233 
234         public RegressionTreeNode getRight() {
235             return this.right;
236         }
237 
238         public boolean isLeaf() {
239             return this.feature == null;
240         }
241 
242         public float score(float[] featureVector) {
243             return this.isLeaf()?this.value:(this.featureIndex >= 0 && this.featureIndex < featureVector.length?(featureVector[this.featureIndex] <= this.threshold.floatValue()?this.left.score(featureVector):this.right.score(featureVector)):0.0F);
244         }
245 
246         public String explain(float[] featureVector) {
247             if(this.isLeaf()) {
248                 return "val: " + this.value;
249             } else if(this.featureIndex >= 0 && this.featureIndex < featureVector.length) {
250                 String rval;
251                 if(featureVector[this.featureIndex] <= this.threshold.floatValue()) {
252                     rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " <= " + this.threshold + ", Go Left | ";
253                     return rval + this.left.explain(featureVector);
254                 } else {
255                     rval = "\'" + this.feature + "\':" + featureVector[this.featureIndex] + " > " + this.threshold + ", Go Right | ";
256                     return rval + this.right.explain(featureVector);
257                 }
258             } else {
259                 return "\'" + this.feature + "\' does not exist in FV, Return Zero";
260             }
261         }
262 
263         public String toString() {
264             StringBuilder sb = new StringBuilder();
265             if(this.isLeaf()) {
266                 sb.append(this.value);
267             } else {
268                 sb.append("(feature=").append(this.feature);
269                 sb.append(",threshold=").append(this.threshold.floatValue() - 1.0E-6F);
270                 sb.append(",left=").append(this.left);
271                 sb.append(",right=").append(this.right);
272                 sb.append(')');
273             }
274 
275             return sb.toString();
276         }
277 
278         public RegressionTreeNode() {
279         }
280 
281         public void validate() throws ModelException {
282             if(this.isLeaf()) {
283                 if(this.left != null || this.right != null) {
284                     throw new ModelException("MultipleAdditiveTreesModel tree node is leaf with left=" + this.left + " and right=" + this.right);
285                 }
286             } else if(null == this.threshold) {
287                 throw new ModelException("MultipleAdditiveTreesModel tree node is missing threshold");
288             } else if(null == this.left) {
289                 throw new ModelException("MultipleAdditiveTreesModel tree node is missing left");
290             } else {
291                 this.left.validate();
292                 if(null == this.right) {
293                     throw new ModelException("MultipleAdditiveTreesModel tree node is missing right");
294                 } else {
295                     this.right.validate();
296                 }
297             }
298         }
299     }
300 
301 }
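
The second listing is the QuickScorer implementation, QuickScorerTreesModel, which extends the model above: validate() builds the flattened arrays of leaves, sorted conditions and bitvectors, and score() runs the bitvector AND loop, with a forward and a backward variant chosen by how often the conditions tend to be false.
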
  1 import org.apache.commons.lang.ArrayUtils;
  2 import org.apache.lucene.util.CloseableThreadLocal;
  3 import org.apache.solr.ltr.feature.Feature;
  4 import org.apache.solr.ltr.model.ModelException;
  5 import org.apache.solr.ltr.norm.Normalizer;
  6 
  7 import java.util.*;
  8 
  9 public class QuickScorerTreesModel extends MultipleAdditiveTreesModel{
 10 
 11     private static final long MAX_BITS = 0xFFFFFFFFFFFFFFFFL;
 12 
 13     // 64bits De Bruijn Sequence
 14     // see: http://chessprogramming.wikispaces.com/DeBruijnsequence#Binary alphabet-B(2, 6)
 15     private static final long HASH_BITS = 0x022fdd63cc95386dL;
 16     private static final int[] hashTable = new int[64];
 17 
 18     static {
 19         long hash = HASH_BITS;
 20         for (int i = 0; i < 64; ++i) {
 21             hashTable[(int) (hash >>> 58)] = i;
 22             hash <<= 1;
 23         }
 24     }
 25 
 26     /**
 27      * Finds the index of rightmost bit with O(1) by using De Bruijn strategy.
 28      *
 29      * @param bits target bits (64bits)
 30      * @see <a href="http://supertech.csail.mit.edu/papers/debruijn.pdf">http://supertech.csail.mit.edu/papers/debruijn.pdf</a>
 31      */
 32     private static int findIndexOfRightMostBit(long bits) {
 33         return hashTable[(int) (((bits & -bits) * HASH_BITS) >>> 58)];
 34     }
 35 
 36     /**
 37      * The number of trees of this model.
 38      */
 39     private int treeNum;
 40 
 41     /**
 42      * Weights of each tree.
 43      */
 44     private float[] weights;
 45 
 46     /**
 47      * List of all leaves of this model.
 48      * We store tree nodes instead of plain leaf values so that wide trees (more than 64 leaves) can be handled.
 49      */
 50     private RegressionTreeNode[] leaves;
 51 
 52     /**
 53      * Offsets of each leaf block correspond to each tree.
 54      */
 55     private int[] leafOffsets;
 56 
 57     /**
 58      * The number of conditions of this model.
 59      */
 60     private int condNum;
 61 
 62     /**
 63      * Thresholds of each condition.
 64      * These thresholds are grouped by corresponding feature and each block is sorted by threshold values.
 65      */
 66     private float[] thresholds;
 67 
 68     /**
 69      * Corresponding featureIndex of each condition.
 70      */
 71     private int[] featureIndexes;
 72 
 73     /**
 74      * Offsets of each condition block correspond to each feature.
 75      */
 76     private int[] condOffsets;
 77 
 78     /**
 79      * Forward bitvectors of each condition which correspond to original additive trees.
 80      */
 81     private long[] forwardBitVectors;
 82 
 83     /**
 84      * Backward bitvectors of each condition which correspond to inverted additive trees.
 85      */
 86     private long[] backwardBitVectors;
 87 
 88     /**
 89      * Mapping from condition (threshold) index to tree index.
 90      */
 91     private int[] treeIds;
 92 
 93     /**
 94      * Bitvectors of each tree for calculating the score.
 95      * We reuse the bitvector array in each thread to avoid re-allocating it on every call.
 96      */
 97     private CloseableThreadLocal<long[]> threadLocalTreeBitvectors = null;
 98 
 99     /**
100      * Boolean statistical tendency of this model.
101      * If conditions of the model tend to be false, we use inverted bitvectors for speeding up.
102      */
103     private volatile float falseRatio = 0.5f;
104 
105     /**
106      * The decay factor for updating falseRatio in each evaluation step.
107      * This factor is used like "{@code ratio = preRatio * decay + observedRatio * (1 - decay)}".
108      */
109     private float falseRatioDecay = 0.99f;
110 
111     /**
112      * Comparable node cost for selecting leaf candidates.
113      */
114     private static class NodeCost implements Comparable<NodeCost> {
115         private final int id;
116         private final int cost;
117         private final int depth;
118         private final int left;
119         private final int right;
120 
121         private NodeCost(int id, int cost, int depth, int left, int right) {
122             this.id = id;
123             this.cost = cost;
124             this.depth = depth;
125             this.left = left;
126             this.right = right;
127         }
128 
129         public int getId() {
130             return id;
131         }
132 
133         public int getLeft() {
134             return left;
135         }
136 
137         public int getRight() {
138             return right;
139         }
140 
141         /**
142          * Sorts by cost and depth.
143          * We prefer cheaper cost and deeper one.
144          */
145         @Override
146         public int compareTo(NodeCost n) {
147             if (cost != n.cost) {
148                 return Integer.compare(cost, n.cost);
149             } else if (depth != n.depth) {
150                 return Integer.compare(n.depth, depth);  // reverse order (prefer deeper nodes)
151             } else {
152                 return Integer.compare(id, n.id);
153             }
154         }
155     }
156 
157     /**
158      * Comparable condition for constructing and sorting bitvectors.
159      */
160     private static class Condition implements Comparable<Condition> {
161         private final int featureIndex;
162         private final float threshold;
163         private final int treeId;
164         private final long forwardBitvector;
165         private final long backwardBitvector;
166 
167         private Condition(int featureIndex, float threshold, int treeId, long forwardBitvector, long backwardBitvector) {
168             this.featureIndex = featureIndex;
169             this.threshold = threshold;
170             this.treeId = treeId;
171             this.forwardBitvector = forwardBitvector;
172             this.backwardBitvector = backwardBitvector;
173         }
174 
175         int getFeatureIndex() {
176             return featureIndex;
177         }
178 
179         float getThreshold() {
180             return threshold;
181         }
182 
183         int getTreeId() {
184             return treeId;
185         }
186 
187         long getForwardBitvector() {
188             return forwardBitvector;
189         }
190 
191         long getBackwardBitvector() {
192             return backwardBitvector;
193         }
194 
195         /*
196          * Sorts by featureIndex and threshold in ascending order.
197          */
198         @Override
199         public int compareTo(Condition c) {
200             if (featureIndex != c.featureIndex) {
201                 return Integer.compare(featureIndex, c.featureIndex);
202             } else {
203                 return Float.compare(threshold, c.threshold);
204             }
205         }
206     }
207 
208     /**
209      * Base class for traversing node with depth first order.
210      */
211     private abstract static class Visitor {
212         private int nodeId = 0;
213 
214         int getNodeId() {
215             return nodeId;
216         }
217 
218         void visit(RegressionTree tree) {
219             nodeId = 0;
220             visit(tree.getRoot(), 0);
221         }
222 
223         private void visit(RegressionTreeNode node, int depth) {
224             if (node.isLeaf()) {
225                 doVisitLeaf(node, depth);
226             } else {
227                 // visit children first
228                 visit(node.getLeft(), depth + 1);
229                 visit(node.getRight(), depth + 1);
230 
231                 doVisitBranch(node, depth);
232             }
233             ++nodeId;
234         }
235 
236         protected abstract void doVisitLeaf(RegressionTreeNode node, int depth);
237 
238         protected abstract void doVisitBranch(RegressionTreeNode node, int depth);
239     }
240 
241     /**
242      * {@link Visitor} implementation for calculating the cost of each node.
243      */
244     private static class NodeCostVisitor extends Visitor {
245 
246         private final Stack<AbstractMap.SimpleEntry<Integer, Integer>> idCostStack = new Stack<>();
247         private final PriorityQueue<NodeCost> nodeCostQueue = new PriorityQueue<>();
248 
249         PriorityQueue<NodeCost> getNodeCostQueue() {
250             return nodeCostQueue;
251         }
252 
253         @Override
254         protected void doVisitLeaf(RegressionTreeNode node, int depth) {
255             nodeCostQueue.add(new NodeCost(getNodeId(), 0, depth, -1, -1));
256             idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), 1));
257         }
258 
259         @Override
260         protected void doVisitBranch(RegressionTreeNode node, int depth) {
261             // calculate the cost of this node from children costs
262             final AbstractMap.SimpleEntry<Integer, Integer> rightIdCost = idCostStack.pop();
263             final AbstractMap.SimpleEntry<Integer, Integer> leftIdCost = idCostStack.pop();
264             final int cost = Math.max(leftIdCost.getValue(), rightIdCost.getValue());
265 
266             nodeCostQueue.add(new NodeCost(getNodeId(), cost, depth, leftIdCost.getKey(), rightIdCost.getKey()));
267             idCostStack.push(new AbstractMap.SimpleEntry<>(getNodeId(), cost + 1));
268         }
269     }
270 
271     /**
272      * {@link Visitor} implementation for extracting leaves and bitvectors.
273      */
274     private static class QuickScorerVisitor extends Visitor {
275 
276         private final int treeId;
277         private final int leafNum;
278         private final Set<Integer> leafIdSet;
279         private final Set<Integer> skipIdSet;
280 
281         private final Stack<Long> bitsStack = new Stack<>();
282         private final List<RegressionTreeNode> leafList = new ArrayList<>();
283         private final List<Condition> conditionList = new ArrayList<>();
284 
285         private QuickScorerVisitor(int treeId, int leafNum, Set<Integer> leafIdSet, Set<Integer> skipIdSet) {
286             this.treeId = treeId;
287             this.leafNum = leafNum;
288             this.leafIdSet = leafIdSet;
289             this.skipIdSet = skipIdSet;
290         }
291 
292         List<RegressionTreeNode> getLeafList() {
293             return leafList;
294         }
295 
296         List<Condition> getConditionList() {
297             return conditionList;
298         }
299 
300         private long reverseBits(long bits) {
301             long revBits = 0L;
302             long mask = (1L << (leafNum - 1));
303             for (int i = 0; i < leafNum; ++i) {
304                 if ((bits & mask) != 0L) revBits |= (1L << i);
305                 mask >>>= 1;
306             }
307             return revBits;
308         }
309 
310         @Override
311         protected void doVisitLeaf(RegressionTreeNode node, int depth) {
312             if (skipIdSet.contains(getNodeId())) return;
313 
314             bitsStack.add(1L << leafList.size());  // we use rightmost bit for detecting leaf
315             leafList.add(node);
316         }
317 
318         @Override
319         protected void doVisitBranch(RegressionTreeNode node, int depth) {
320             if (skipIdSet.contains(getNodeId())) return;
321 
322             if (leafIdSet.contains(getNodeId())) {
323                 // an endpoint of QuickScorer
324                 doVisitLeaf(node, depth);
325                 return;
326             }
327 
328             final long rightBits = bitsStack.pop();  // bits of false branch
329             final long leftBits = bitsStack.pop();   // bits of true branch
330       /*
331        * NOTE:
332        *   forwardBitvector  = ~leftBits
333        *   backwardBitvector = ~(reverse(rightBits))
334        */
335             conditionList.add(
336                     new Condition(node.getFeatureIndex(), node.getThreshold(), treeId, ~leftBits, ~reverseBits(rightBits)));
337             bitsStack.add(leftBits | rightBits);
338         }
339     }
340 
341     public QuickScorerTreesModel(String name, List<Feature> features, List<Normalizer> norms, String featureStoreName,
342                                  List<Feature> allFeatures, Map<String, Object> params) {
343         super(name, features, norms, featureStoreName, allFeatures, params);
344     }
345 
346     /**
347      * Sets the falseRatioDecay parameter of this model.
348      *
349      * @param falseRatioDecay decay parameter for updating falseRatio
350      */
351     public void setFalseRatioDecay(float falseRatioDecay) {
352         this.falseRatioDecay = falseRatioDecay;
353     }
354 
355     /**
356      * @see #setFalseRatioDecay(float)
357      */
358     public void setFalseRatioDecay(String falseRatioDecay) {
359         this.falseRatioDecay = Float.parseFloat(falseRatioDecay);
360     }
361 
362     /**
363      * {@inheritDoc}
364      */
365     @Override
366     public void validate() throws ModelException {
367         // validate trees before initializing QuickScorer
368         super.validate();
369 
370         // initialize QuickScorer with validated trees
371         init(getTrees());
372     }
373 
374     /**
375      * Initializes quick scorer with given trees.
376      * Builds the QuickScorer data structures (leaves, conditions, bitvectors) from the given trees.
377      *
378      * @param trees base additive trees model
379      */
380     private void init(List<RegressionTree> trees) {
381         this.treeNum = trees.size();
382         this.weights = new float[trees.size()];
383         this.leafOffsets = new int[trees.size() + 1];
384         this.leafOffsets[0] = 0;
385 
386         // re-create tree bitvectors
387         if (this.threadLocalTreeBitvectors != null) this.threadLocalTreeBitvectors.close();
388         this.threadLocalTreeBitvectors = new CloseableThreadLocal<long[]>() {
389             @Override
390             protected long[] initialValue() {
391                 return new long[treeNum];
392             }
393         };
394 
395         int treeId = 0;
396         List<RegressionTreeNode> leafList = new ArrayList<>();
397         List<Condition> conditionList = new ArrayList<>();
398         for (RegressionTree tree : trees) {
399             // select up to 64 leaves from given tree
400             QuickScorerVisitor visitor = fitLeavesTo64bits(treeId, tree);
401 
402             // extract leaves and conditions with selected leaf candidates
403             visitor.visit(tree);
404             leafList.addAll(visitor.getLeafList());
405             conditionList.addAll(visitor.getConditionList());
406 
407             // update weight, offset and treeId
408             this.weights[treeId] = tree.getWeight();
409             this.leafOffsets[treeId + 1] = this.leafOffsets[treeId] + visitor.getLeafList().size();
410             ++treeId;
411         }
412 
413         // remap list to array for performance reason
414         this.leaves = leafList.toArray(new RegressionTreeNode[0]);
415 
416         // sort conditions in ascending order of featureIndex and threshold
417         Collections.sort(conditionList);
418 
419         // remap information of conditions
420         int idx = 0;
421         int preFeatureIndex = -1;
422         this.condNum = conditionList.size();
423         this.thresholds = new float[conditionList.size()];
424         this.forwardBitVectors = new long[conditionList.size()];
425         this.backwardBitVectors = new long[conditionList.size()];
426         this.treeIds = new int[conditionList.size()];
427         List<Integer> featureIndexList = new ArrayList<>();
428         List<Integer> condOffsetList = new ArrayList<>();
429         for (Condition condition : conditionList) {
430             this.thresholds[idx] = condition.threshold;
431             this.forwardBitVectors[idx] = condition.getForwardBitvector();
432             this.backwardBitVectors[idx] = condition.getBackwardBitvector();
433             this.treeIds[idx] = condition.getTreeId();
434 
435             if (preFeatureIndex != condition.getFeatureIndex()) {
436                 featureIndexList.add(condition.getFeatureIndex());
437                 condOffsetList.add(idx);
438                 preFeatureIndex = condition.getFeatureIndex();
439             }
440 
441             ++idx;
442         }
443         condOffsetList.add(conditionList.size()); // guard
444 
445         this.featureIndexes = ArrayUtils.toPrimitive(featureIndexList.toArray(new Integer[0]));
446         this.condOffsets = ArrayUtils.toPrimitive(condOffsetList.toArray(new Integer[0]));
447     }
448 
449     /**
450      * Checks costs of all nodes and select leaves up to 64.
451      *
452      * <p>NOTE:
453      * We can use {@link java.util.BitSet} instead of {@code long} to represent bitvectors longer than 64bits.
454      * However, this modification caused performance degradation in our experiments, and we decided to use this form.
455      *
456      * @param treeId index of given regression tree
457      * @param tree target regression tree
458      * @return QuickScorerVisitor with proper id sets
459      */
460     private QuickScorerVisitor fitLeavesTo64bits(int treeId, RegressionTree tree) {
461         // calculate costs of all nodes
462         NodeCostVisitor nodeCostVisitor = new NodeCostVisitor();
463         nodeCostVisitor.visit(tree);
464 
465         // poll zero cost nodes (i.e., real leaves)
466         Set<Integer> leafIdSet = new HashSet<>();
467         Set<Integer> skipIdSet = new HashSet<>();
468         while (!nodeCostVisitor.getNodeCostQueue().isEmpty()) {
469             if (nodeCostVisitor.getNodeCostQueue().peek().cost > 0) break;
470             NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
471             leafIdSet.add(nodeCost.id);
472         }
473 
474         // merge leaves until the number of leaves reaches 64
475         while (leafIdSet.size() > 64) {
476             final NodeCost nodeCost = nodeCostVisitor.getNodeCostQueue().poll();
477             assert nodeCost.left >= 0 && nodeCost.right >= 0;
478 
479             // update leaves
480             leafIdSet.remove(nodeCost.left);
481             leafIdSet.remove(nodeCost.right);
482             leafIdSet.add(nodeCost.id);
483 
484             // register previous leaves to skip ids
485             skipIdSet.add(nodeCost.left);
486             skipIdSet.add(nodeCost.right);
487         }
488 
489         return new QuickScorerVisitor(treeId, leafIdSet.size(), leafIdSet, skipIdSet);
490     }
491 
492     /**
493      * {@inheritDoc}
494      */
495     @Override
496     public float score(float[] modelFeatureValuesNormalized) {
497         assert threadLocalTreeBitvectors != null;
498         long[] treeBitvectors = threadLocalTreeBitvectors.get();
499         Arrays.fill(treeBitvectors, MAX_BITS);
500 
501         int falseNum = 0;
502         float score = 0.0f;
503         if (falseRatio <= 0.5) {
504             // use forward bitvectors
505             for (int i = 0; i < condOffsets.length - 1; ++i) {
506                 final int featureIndex = featureIndexes[i];
507                 for (int j = condOffsets[i]; j < condOffsets[i + 1]; ++j) {
508                     if (modelFeatureValuesNormalized[featureIndex] <= thresholds[j]) break;
509                     treeBitvectors[treeIds[j]] &= forwardBitVectors[j];
510                     ++falseNum;
511                 }
512             }
513 
514             for (int i = 0; i < leafOffsets.length - 1; ++i) {
515                 final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
516                 score += weights[i] * leaves[leafOffsets[i] + leafIdx].score(modelFeatureValuesNormalized);
517             }
518         } else {
519             // use backward bitvectors
520             falseNum = condNum;
521             for (int i = 0; i < condOffsets.length - 1; ++i) {
522                 final int featureIndex = featureIndexes[i];
523                 for (int j = condOffsets[i + 1] - 1; j >= condOffsets[i]; --j) {
524                     if (modelFeatureValuesNormalized[featureIndex] > thresholds[j]) break;
525                     treeBitvectors[treeIds[j]] &= backwardBitVectors[j];
526                     --falseNum;
527                 }
528             }
529 
530             for (int i = 0; i < leafOffsets.length - 1; ++i) {
531                 final int leafIdx = findIndexOfRightMostBit(treeBitvectors[i]);
532                 score += weights[i] * leaves[leafOffsets[i + 1] - 1 - leafIdx].score(modelFeatureValuesNormalized);
533             }
534         }
535 
536         // update false ratio
537         falseRatio = falseRatio * falseRatioDecay + (falseNum * 1.0f / condNum) * (1.0f - falseRatioDecay);
538         return score;
539     }
540 
541 }
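
The last listing is a JUnit test that builds random tree ensembles, checks that MultipleAdditiveTreesModel and QuickScorerTreesModel return identical scores, and compares their scoring time.
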
  1 import org.apache.lucene.search.IndexSearcher;
  2 import org.apache.lucene.search.Query;
  3 import org.apache.solr.ltr.feature.Feature;
  4 import org.apache.solr.ltr.feature.FeatureException;
  5 import org.apache.solr.ltr.norm.IdentityNormalizer;
  6 import org.apache.solr.ltr.norm.Normalizer;
  7 import org.apache.solr.request.SolrQueryRequest;
  8 import org.junit.Ignore;
  9 import org.junit.Test;
 10 
 11 import java.io.IOException;
 12 import java.util.ArrayList;
 13 import java.util.HashMap;
 14 import java.util.LinkedHashMap;
 15 import java.util.List;
 16 import java.util.Map;
 17 import java.util.Random;
 18 
 19 import static org.hamcrest.CoreMatchers.is;
 20 import static org.junit.Assert.assertThat;
 21 
 22 public class TestQuickScorerTreesModelBenchmark {
 23 
 24     /**
 25      * Creates dummy features.
 26      * @param featureNum number of features
 27      * @return list of dummy features
 28      */
 29     private List<Feature> createDummyFeatures(int featureNum) {
 30         List<Feature> features = new ArrayList<>();
 31         for (int i = 0; i < featureNum; ++i) {
 32             features.add(new Feature("fv_" + i, null) {
 33                 @Override
 34                 protected void validate() throws FeatureException { }
 35 
 36                 @Override
 37                 public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores, SolrQueryRequest request,
 38                                                   Query originalQuery, Map<String, String[]> efi) throws IOException {
 39                     return null;
 40                 }
 41 
 42                 @Override
 43                 public LinkedHashMap<String, Object> paramsToMap() {
 44                     return null;
 45                 }
 46             });
 47         }
 48         return features;
 49     }
 50 
 51     private List<Normalizer> createDummyNormalizer(int featureNum) {
 52         List<Normalizer> normalizers = new ArrayList<>();
 53         for (int i = 0; i < featureNum; ++i) {
 54             normalizers.add(new IdentityNormalizer());
 55         }
 56         return normalizers;
 57     }
 58 
 59     /**
 60      * Builds a single random tree.
 61      * Calls itself recursively on the left and right children.
 62      * @param leafNum number of leaves
 63      * @param features available features
 64      * @param rand random number generator
 65      * @return the tree as a nested map
 66      */
 67     private Map<String, Object> createRandomTree(int leafNum, List<Feature> features, Random rand) {
 68         Map<String, Object> node = new HashMap<>();
 69         if (leafNum == 1) {
 70             // leaf
 71             node.put("value", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
 72             return node;
 73         }
 74 
 75         // branch
 76         node.put("feature", features.get(rand.nextInt(features.size())).getName());
 77         node.put("threshold", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5)
 78         node.put("left", createRandomTree(leafNum / 2, features, rand));
 79         node.put("right", createRandomTree(leafNum - leafNum / 2, features, rand));
 80         return node;
 81     }
 82 
 83     /**
 84      * Randomly creates multiple trees to use as the test model.
 85      * @param treeNum number of trees
 86      * @param leafNum number of leaves per tree
 87      * @param features available features
 88      * @param rand random number generator
 89      * @return list of trees in the map-based model format
 90      */
 91     private List<Object> createRandomMultipleAdditiveTrees(int treeNum, int leafNum, List<Feature> features,
 92                                                            Random rand) {
 93         List<Object> trees = new ArrayList<>();
 94         for (int i = 0; i < treeNum; ++i) {
 95             Map<String, Object> tree = new HashMap<>();
 96             tree.put("weight", Float.toString(rand.nextFloat() - 0.5f)); // [-0.5, 0.5) 设置每棵树的学习率
 97             tree.put("root", createRandomTree(leafNum, features, rand));
 98             trees.add(tree);
 99         }
100         return trees;
101     }
102 
103     /**
104      * Checks that the two scoring models produce identical scores.
105      * @param featureNum number of features
106      * @param treeNum number of trees
107      * @param leafNum number of leaves per tree
108      * @param loopNum number of random models to test (100 samples each)
109      * @throws Exception if model initialization fails
110      */
111     private void compareScore(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
112         Random rand = new Random(0);
113 
114         List<Feature> features = createDummyFeatures(featureNum); //产生特征
115         List<Normalizer> norms = createDummyNormalizer(featureNum); //标准化
116 
117         for (int i = 0; i < loopNum; ++i) {
118             List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);
119 
120             MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
121                     "dummy", features, null);
122             matModel.setTrees(trees);
123             matModel.validate();
124 
125             QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
126                     null);
127             qstModel.setTrees(trees); // set the same trees on the QuickScorer model
128             qstModel.validate(); // validate the trees and build the QuickScorer structures
129 
130             float[] featureValues = new float[featureNum];
131             for (int j = 0; j < 100; ++j) {
132                 for (int k = 0; k < featureNum; ++k) featureValues[k] = rand.nextFloat() - 0.5f; // [-0.5, 0.5)
133 
134                 float expected = matModel.score(featureValues);
135                 float actual = qstModel.score(featureValues);
136                 assertThat(actual, is(expected));
137                 //System.out.println("expected: " + expected + " actual: " + actual);
138             }
139         }
140     }
141 
142     /**
143      * Checks whether the two models produce identical scores.
144      *
145      * @throws Exception thrown if testcase failed to initialize models
146      */
147     /*@Test
148     public void testAccuracy() throws Exception {
149         compareScore(25, 200, 32, 100);
150         //compareScore(19, 500, 31, 10000);
151     }*/
152 
153 
154     /**
155      * Compares the scoring time of the two models.
156      * @param featureNum number of features
157      * @param treeNum number of trees
158      * @param leafNum number of leaves per tree
159      * @param loopNum number of samples to score
160      * @throws Exception if model initialization fails
161      */
162     private void compareTime(int featureNum, int treeNum, int leafNum, int loopNum) throws Exception {
163         Random rand = new Random(0);
164 
165         // randomly generate dummy features
166         List<Feature> features = createDummyFeatures(featureNum);
167         // identity normalizers
168         List<Normalizer> norms = createDummyNormalizer(featureNum);
169         // randomly create the trees
170         List<Object> trees = createRandomMultipleAdditiveTrees(treeNum, leafNum, features, rand);
171 
172         // initialize the MultipleAdditiveTreesModel
173         MultipleAdditiveTreesModel matModel = new MultipleAdditiveTreesModel("multipleadditivetrees", features, norms,
174                 "dummy", features, null);
175         matModel.setTrees(trees);
176         matModel.validate();
177 
178         // initialize the QuickScorerTreesModel
179         QuickScorerTreesModel qstModel = new QuickScorerTreesModel("quickscorertrees", features, norms, "dummy", features,
180                 null);
181         qstModel.setTrees(trees);
182         qstModel.validate();
183 
184         // randomly generate samples: loopNum x featureNum
185         float[][] featureValues = new float[loopNum][featureNum];
186         for (int i = 0; i < loopNum; ++i) {
187             for (int k = 0; k < featureNum; ++k) {
188                 featureValues[i][k] = rand.nextFloat() * 2.0f - 1.0f; // [-1.0, 1.0)
189             }
190         }
191 
192         long start;
193         /*long matOpNsec = 0;
194         for (int i = 0; i < loopNum; ++i) {
195             start = System.nanoTime();
196             matModel.score(featureValues[i]);
197             matOpNsec += System.nanoTime() - start;
198         }
199         long qstOpNsec = 0;
200         for (int i = 0; i < loopNum; ++i) {
201             start = System.nanoTime();
202             qstModel.score(featureValues[i]);
203             qstOpNsec += System.nanoTime() - start;
204         }
205         System.out.println("MultipleAdditiveTreesModel : " + matOpNsec / 1000.0 / loopNum + " usec/op");
206         System.out.println("QuickScorerTreesModel      : " + qstOpNsec / 1000.0 / loopNum + " usec/op");*/
207 
208         long matOpNsec = 0;
209         start = System.currentTimeMillis();
210         for(int i = 0; i < loopNum; i++) {
211             matModel.score(featureValues[i]);
212         }
213         matOpNsec = System.currentTimeMillis() - start;
214 
215         long qstOpNsec = 0;
216         start = System.currentTimeMillis();
217         for(int i = 0; i < loopNum; i++) {
218             qstModel.score(featureValues[i]);
219         }
220         qstOpNsec = System.currentTimeMillis() - start;
221 
222         System.out.println("MultipleAdditiveTreesModel : " + matOpNsec);
223 
224         System.out.println("QuickScorerTreesModel : " + qstOpNsec);
225 
226         //assertThat(matOpNsec > qstOpNsec, is(true));
227     }
228 
229     /**
230      * Performance test.
231      * @throws Exception thrown if testcase failed to initialize models
232      */
233 
234     @Test
235     public void testPerformance() throws Exception {
236         // features, trees, leaves, samples
237         compareTime(20, 500, 61, 10000);
238     }
239 
240 }

 
