weka实战005:基于HashSet实现的apriori关联规则算法

原创
2017/01/17 09:47
阅读数 13

这个一个apriori算法的演示版本,所有的代码都在一个类。仅供研究算法参考


package test;

import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Vector;

//用set写的apriori算法
public class AprioriSetBasedDemo {

	class Transaction {
		/*
		 * 购物记录,用set保存多个货物名
		 */
		private HashSet<String> pnSet = new HashSet<String>();

		public Transaction() {
			pnSet.clear();
		}

		public Transaction(String[] names) {
			pnSet.clear();
			for (String s : names) {
				pnSet.add(s);
			}
		}

		public HashSet<String> getPnSet() {
			return pnSet;
		}

		public void addPname(String s) {
			pnSet.add(s);
		}

		public boolean containSubSet(HashSet<String> subSet) {
			return pnSet.containsAll(subSet);
		}

		@Override
		public String toString() {
			StringBuilder sb = new StringBuilder();
			Iterator<String> iter = pnSet.iterator();
			while (iter.hasNext()) {
				sb.append(iter.next() + ",");
			}
			return "Transaction = [" + sb.toString() + "]";
		}

	}

	class TransactionDB {
		// 记录所有的Transaction
		private Vector<Transaction> vt = new Vector<Transaction>();

		public TransactionDB() {
			vt.clear();
		}

		public int getSize() {
			return vt.size();
		}

		public void addTransaction(Transaction t) {
			vt.addElement(t);
		}

		public Transaction getTransaction(int idx) {
			return vt.elementAt(idx);
		}

	}

	public class AssoRule implements Comparable<AssoRule> {
		private String ruleContent;
		private double confidence;

		public void setRuleContent(String ruleContent) {
			this.ruleContent = ruleContent;
		}

		public void setConfidence(double confidence) {
			this.confidence = confidence;
		}

		public AssoRule(String ruleContent, double confidence) {
			this.ruleContent = ruleContent;
			this.confidence = confidence;
		}

		@Override
		public int compareTo(AssoRule o) {
			if (o.confidence > this.confidence) {
				return 1;
			} else if (o.confidence == this.confidence) {
				return 0;
			} else {
				return -1;
			}
		}

		@Override
		public String toString() {
			return ruleContent + ", confidence=" + confidence * 100 + "%";
		}

	}

	public static String getStringFromSet(HashSet<String> set) {
		StringBuilder sb = new StringBuilder();
		Iterator<String> iter = set.iterator();
		while (iter.hasNext()) {
			sb.append(iter.next() + ", ");
		}
		if (sb.length() > 2) {
			sb.delete(sb.length() - 2, sb.length() - 1);
		}
		return sb.toString();
	}

	// 计算具有最小支持度的一项频繁集 >= minSupport
	public static HashMap<String, Integer> buildMinSupportFrequenceSet(
			TransactionDB tdb, int minSupport) {
		HashMap<String, Integer> minSupportMap = new HashMap<String, Integer>();

		for (int i = 0; i < tdb.getSize(); i++) {
			Transaction t = tdb.getTransaction(i);
			Iterator<String> it = t.getPnSet().iterator();
			while (it.hasNext()) {
				String key = it.next();
				if (minSupportMap.containsKey(key)) {
					minSupportMap.put(key, minSupportMap.get(key) + 1);
				} else {
					minSupportMap.put(key, new Integer(1));
				}
			}
		}

		Iterator<String> iter = minSupportMap.keySet().iterator();
		Vector<String> toBeRemoved = new Vector<String>();
		while (iter.hasNext()) {
			String key = iter.next();
			if (minSupportMap.get(key) < minSupport) {
				toBeRemoved.add(key);
			}
		}

		for (int i = 0; i < toBeRemoved.size(); i++) {
			minSupportMap.remove(toBeRemoved.get(i));
		}

		return minSupportMap;
	}

	public void buildRules(TransactionDB tdb,
			HashMap<HashSet<String>, Integer> kItemFS, Vector<AssoRule> var,
			double ruleMinSupportPer) {

		// 如果kItemFS的成员数量不超过1不需要计算
		if (kItemFS.size() <= 1) {
			return;
		}

		// k+1项频项集
		HashMap<HashSet<String>, Integer> kNextItemFS = new HashMap<HashSet<String>, Integer>();

		// 获得第k项频项集
		@SuppressWarnings("unchecked")
		HashSet<String>[] kItemSets = new HashSet[kItemFS.size()];
		kItemFS.keySet().toArray(kItemSets);

		/*
		 * 根据k项频项集,用两重循环获得k+1项频项集 然后计算有多少个tranction包含这个k+1项频项集
		 * 然后支持比超过ruleMinSupportPer,就可以生成规则,放入规则向量
		 * 然后,将k+1项频项集及其支持度放入kNextItemFS,进入下一轮计算
		 */
		for (int i = 0; i < kItemSets.length - 1; i++) {
			HashSet<String> set_i = kItemSets[i];
			for (int j = i + 1; j < kItemSets.length; j++) {
				HashSet<String> set_j = kItemSets[j];
				// k+1 item set
				HashSet<String> kNextSet = new HashSet<String>();
				kNextSet.addAll(set_i);
				kNextSet.addAll(set_j);
				if (kNextSet.size() <= set_i.size()
						|| kNextSet.size() <= set_j.size()) {
					continue;
				}

				// 计算k+1 item set在所有transaction出现了几次
				int count = 0;
				for (int k = 0; k < tdb.getSize(); k++) {
					if (tdb.getTransaction(k).containSubSet(kNextSet)) {
						count++;
					}
				}
				if (count <= 0) {
					continue;
				}

				Integer n_i = kItemFS.get(set_i);
				double per = 1.0 * count / n_i.intValue();
				if (per >= ruleMinSupportPer) {
					kNextItemFS.put(kNextSet, new Integer(count));
					HashSet<String> tmp = new HashSet<String>();
					tmp.addAll(kNextSet);
					tmp.removeAll(set_i);
					String s1 = "{" + getStringFromSet(set_i) + "}" + "(" + n_i
							+ ")" + "==>" + getStringFromSet(tmp).toString()
							+ "(" + count + ")";
					var.addElement(new AssoRule(s1, per));
				}
			}
		}

		// 进入下一轮计算
		buildRules(tdb, kNextItemFS, var, ruleMinSupportPer);
	}

	public void test() {
		// Transaction数据集
		TransactionDB tdb = new TransactionDB();

		// 添加Transaction交易记录
		tdb.addTransaction(new Transaction(new String[] { "a", "b", "c", "d" }));
		tdb.addTransaction(new Transaction(new String[] { "a", "b" }));
		tdb.addTransaction(new Transaction(new String[] { "b", "c" }));
		tdb.addTransaction(new Transaction(new String[] { "b", "c", "d", "e" }));

		// 规则最小支持度
		double minRuleConfidence = 0.5;
		Vector<AssoRule> vr = computeAssociationRules(tdb, minRuleConfidence);
		// 输出规则
		int i = 0;
		for (AssoRule ar : vr) {
			System.out.println("rule[" + (i++) + "]: " + ar);
		}
	}

	public Vector<AssoRule> computeAssociationRules(TransactionDB tdb,
			double ruleMinSupportPer) {
		// 输出关联规则
		Vector<AssoRule> var = new Vector<AssoRule>();

		// 计算最小支持度频项
		HashMap<String, Integer> minSupportMap = buildMinSupportFrequenceSet(
				tdb, 2);

		// 计算一项频项集
		HashMap<HashSet<String>, Integer> oneItemFS = new HashMap<HashSet<String>, Integer>();
		for (String key : minSupportMap.keySet()) {
			HashSet<String> oneItemSet = new HashSet<String>();
			oneItemSet.add(key);
			oneItemFS.put(oneItemSet, minSupportMap.get(key));
		}

		// 根据一项频项集合,递归计算规则
		buildRules(tdb, oneItemFS, var, ruleMinSupportPer);
		// 将规则按照可信度排序
		Collections.sort(var);
		return var;
	}

	public static void main(String[] args) {
		AprioriSetBasedDemo asbd = new AprioriSetBasedDemo();
		asbd.test();
	}

}

运行结果如下:


rule[0]: {d }(2)==>b (2), confidence=100.0%
rule[1]: {d }(2)==>c (2), confidence=100.0%
rule[2]: {d, a }(1)==>c (1), confidence=100.0%
rule[3]: {d, a }(1)==>b (1), confidence=100.0%
rule[4]: {d, a }(1)==>b (1), confidence=100.0%
rule[5]: {d, c }(2)==>b (2), confidence=100.0%
rule[6]: {d, b, a }(1)==>c (1), confidence=100.0%
rule[7]: {d, b, a }(1)==>c (1), confidence=100.0%
rule[8]: {d, c, a }(1)==>b (1), confidence=100.0%
rule[9]: {b }(4)==>c (3), confidence=75.0%
rule[10]: {b, c }(3)==>d (2), confidence=66.66666666666666%
rule[11]: {b, c }(3)==>d (2), confidence=66.66666666666666%
rule[12]: {d }(2)==>a (1), confidence=50.0%
rule[13]: {b }(4)==>a (2), confidence=50.0%
rule[14]: {d, c }(2)==>b, a (1), confidence=50.0%
rule[15]: {d, b }(2)==>a (1), confidence=50.0%

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部