数据挖掘作为当前比较热门的领域之一,已经有较为权威的书籍指导入门了,如陈封能著,范明译的《数据挖掘导论》以及力量黑书(机械工业出版社)出版的《数据挖掘概念与技术》。
在介绍完基本的数据挖掘概念之后,第一个动手写代码实现的算法就是Apriori
算法了。Apriori
算法是用于挖掘关联规则的频繁项集算法。
环境支持
Apriori中的剪枝步骤 在介绍伪代码前要介绍Apriori
算法中的剪枝
步骤。在产生K=1,...
频繁项集的过程中,一共有 个候选集,不用百万级的数据,光是n
大于1000的时候都可以产生组合爆炸,更别说对产生的组合进行统计。所以Apriori
算法在统计候选集之前先要把产生的K
候选集作一个剪枝,删除不频繁的K
候选集再开始统计。
具体的做法是检验该K
项集的K-1
子项集是否为频繁项集,如果该K
项集的所有子项集都是频繁项集,那么该K
项集才有可能是频繁项集。
假设有{A,B,C,D}
这个全集,在K=3
有{A,B,C}
和{A,B,D}
等组合(其它组合情况省略)。设频繁阈值为1,而K-1
时存在{\{A,B},{A,C\}}
项集而没有{A,D}
项集,那么{A,B,D}
肯定不是频繁项集,因而在下一轮统计前就可以删除了。
Apriori算法伪代码 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 Input: D: 事务数据库 min_sup: 最小支持度阈值 Ouput: L,D中的频繁项集 Procedure: L1=find_frequent_1_itemsets(D); for(k=2; Lk-1!=EmptySet; k++) { Ck=apriori_gen(Lk-1); for each transaction t in D // 扫描D,进行计数 { Ct=subset(Ck,t); // 得到t的子集,候选集 for each candidate c in Ct { c.count++; } } Lk={c(Ck|c.count>=min_sup)} } return L union_k Lk; Procedure apriori_gen(Lk-1: frequent(k-1)itemset) for each itemset l1 in Lk-1 { for each itemset l2 in Lk-1 { if(l1[1]=l2[1])and...and(l1[k-2]=l2[k-2])and(l1[k-1]<l2[k-1]) then { c=l1 × l2 (Cartesian product) if has_infrequent_subset(c, Lk-1) then delete c; // 删除非频繁的候选 else add c to Ck; } } } return Ck; Procedure has_infrequent_subset(c: candidate k itemset; Lk-1: frequent(k-1)itemset) for each (k-1) subset s of c { if s not in Lk-1 then return true; } return false;
简单介绍了基本概念以及算法之后,一般读者都会迷迷糊糊的,这很正常,毕竟这篇博客是用来讲实现了,伪代码也因为Markdown的代码块里面不允许加载特殊html语法写得难看。建议真的想搞懂的话,就去看上述的书,然后手解模拟一下过程,一般懂了之后,实现的问题就不大了。
如上述环境支持,在这里实现的时候用的是美国国会84年的投票记录,实际上是什么数据集并没有关系,无非就是数据预处理的过程不一样了而已。针对这个数据集,笔者实现了如下代码用于加载与处理一行记录:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 def translate_record (record ): items = [] if record[0 ] == 'republican' : items.append('rep' ) elif record[0 ] == 'democrat' : items.append('demo' ) if record[1 ] == 'y' : items.append('hci' ) elif record[1 ] == 'n' : items.append('hci-n' ) if record[2 ] == 'y' : items.append('wpcs' ) elif record[2 ] == 'n' : items.append('wpcs-n' ) if record[3 ] == 'y' : items.append('aotbr' ) elif record[3 ] == 'n' : items.append('aotbr-n' ) if record[4 ] == 'y' : items.append('pff' ) elif record[4 ] == 'n' : items.append('pff-n' ) if record[5 ] == 'y' : items.append('esa' ) elif record[5 ] == 'n' : items.append('esa-n' ) if record[6 ] == 'y' : items.append('rgis' ) elif record[6 ] == 'n' : items.append('rgis-n' ) if record[7 ] == 'y' : items.append('astb' ) elif record[7 ] == 'n' : items.append('astb-n' ) if record[8 ] == 'y' : items.append('atnc' ) elif record[8 ] == 'n' : items.append('atnc-n' ) if record[9 ] == 'y' : items.append('mxm' ) elif record[9 ] == 'n' : items.append('mxm-n' ) if record[10 ] == 'y' : items.append('imm' ) elif record[10 ] == 'n' : items.append('imm-n' ) if record[11 ] == 'y' : items.append('scc' ) elif record[11 ] == 'n' : items.append('scc-n' ) if record[12 ] == 'y' : items.append('es' ) elif record[12 ] == 'n' : items.append('es-n' ) if record[13 ] == 'y' : items.append('srts' ) elif record[13 ] == 'n' : items.append('srts-n' ) if record[14 ] == 'y' : items.append('cri' ) elif record[14 ] == 'n' : items.append('cri-n' ) if record[15 ] == 'y' : items.append('dfe' ) elif record[15 ] == 'n' : items.append('dfe-n' ) if record[16 ] == 'y' : items.append('eaasa' ) elif record[16 ] == 'n' : items.append('eaasa-n' ) return items def load_data (file_path ): src_data = open (file_path) votes_records = [] for line in src_data: votes_records.append(translate_record(line.strip('\n' ).split(',' ))) src_data.close() return votes_records
其中,load_data
用于读入数据集并返回一个处理好的数据列表。
然后就是伪代码第一行,就是产生1项集,并统计1频繁项集:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 def find_frequent_1_itemsets (vote_records, min_sup=0 ): itemsets = [] for record in vote_records: for element in record: if [element] not in itemsets: itemsets.append([element]) itemsets = map (frozenset , itemsets) itemsets_with_count = {} for candidate in itemsets: for record in vote_records: if candidate.issubset(record): if candidate in itemsets_with_count: itemsets_with_count[candidate] += 1 else : itemsets_with_count[candidate] = 1 qualified_itemsets = {} for key, val in itemsets_with_count.items(): if val >= min_sup: qualified_itemsets[key] = val return qualified_itemsets
注意,为了能够用dict
来统计频繁项集的个数,其中的key
的类型是一个集合,而一般的集合是不能来当作key
的,幸好python提供了frozenset
使得集合作为dict
的key
。不过这样的实现也有很大的性能缺陷,频繁地在普通的集合以及frozenset
中转化,深度复制,性能损失可想而知。
当时实现的时候一个优化的想法就是用一个4Byte的数据结构来表示一个集合,其中的0
和1
分别代表yes
和no
,然而数据还存在第三个状态,用扩展位来解决这个问题,可以通过拼凑的方法使得这样的数据结构可以兼容任意个选项。这样的方法通用性不好,加上这里是第一次实现,所以就还是采取frozenset
这个平凡的办法来处理了。
然后在K>=2
开始,需要用到的has_infrequent_set
,用于检测该候选项是否存在非频繁子项集:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 def has_infrequent_set (un_set, frozen_set ): unfrozen_un = set (un_set) for ele in unfrozen_un: sub_k_1_set = unfrozen_un.copy() sub_k_1_set.remove(ele) sub_k_1_set = frozenset (sub_k_1_set) if sub_k_1_set not in frozen_set: return True return False
这里就体现出了普通的集合和frozenset
的频繁转化了。
接下来是产生K(K>1)
项集的apriori_gen
:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 def apriori_gen (lk ): frozen_set = lk.keys() if len (frozen_set) <= 0 : return [] for ele in frozen_set: set_size = len (ele) + 1 ; break gen_set = [] for ele1 in frozen_set: for ele2 in frozen_set: if not (ele1.issubset(ele2) and ele1.issuperset(ele2)): un_set = ele1.union(ele2) if len (un_set) == set_size and un_set not in gen_set and not has_infrequent_set(un_set, frozen_set): gen_set.append(un_set) return gen_set
准备好Apriori
用到的主要过程之后,就是Apriori
算法的主体了:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 def apriori (record_set, min_sup ): l1 = find_frequent_1_itemsets(record_set, min_sup) L = [l1] k = 1 while len (L[k - 1 ]) > 0 : ck = apriori_gen(L[k - 1 ]) count_set = {} verified_count_set = {} for candidate in ck: count_set[candidate] = 0 for record in record_set: if candidate.issubset(record): count_set[candidate] += 1 for key, val in count_set.items(): if val >= min_sup: verified_count_set[key] = val L.append(verified_count_set.copy()) count_set.clear() verified_count_set.clear() k += 1 return L
经典的Apriori
算法的主体代码实现就到这里就结束了。
算法实现并不难,难得是弄懂这个过程,所以还是看书,自己手动模拟过程更快一些。