Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 73 additions & 63 deletions machine_learning/apriori_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
Examples: https://www.kaggle.com/code/earthian/apriori-association-rules-mining
"""

from collections import Counter
from collections import defaultdict
from itertools import combinations


Expand All @@ -25,78 +25,88 @@ def load_data() -> list[list[str]]:
return [["milk"], ["milk", "butter"], ["milk", "bread"], ["milk", "bread", "chips"]]


def prune(itemset: list, candidates: list, length: int) -> list:
"""
Prune candidate itemsets that are not frequent.
The goal of pruning is to filter out candidate itemsets that are not frequent. This
is done by checking if all the (k-1) subsets of a candidate itemset are present in
the frequent itemsets of the previous iteration (valid subsequences of the frequent
itemsets from the previous iteration).

Prunes candidate itemsets that are not frequent.

>>> itemset = ['X', 'Y', 'Z']
>>> candidates = [['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]
>>> prune(itemset, candidates, 2)
[['X', 'Y'], ['X', 'Z'], ['Y', 'Z']]

>>> itemset = ['1', '2', '3', '4']
>>> candidates = ['1', '2', '4']
>>> prune(itemset, candidates, 3)
[]
# ---------- Helpers ----------


def get_support(itemset: frozenset, transactions: list[set]) -> int:
    """
    Return the support count of ``itemset``: how many transactions
    contain every item in it.

    Args:
        itemset: The items whose joint support is measured (any set type
            with ``issubset`` works, e.g. ``set`` or ``frozenset``).
        transactions: The transactions, each represented as a set of items.

    Returns:
        The number of transactions that are supersets of ``itemset``.

    >>> get_support(frozenset(["milk"]), [{"milk"}, {"milk", "bread"}, {"bread"}])
    2
    >>> get_support(frozenset(), [])
    0
    """
    return sum(1 for transaction in transactions if itemset.issubset(transaction))


def generate_candidates(prev_frequent: set[frozenset], k: int) -> set[frozenset]:
    """
    Join step of Apriori: generate candidate itemsets of size ``k`` from
    the frequent itemsets of size ``k - 1``.

    Every pair of previous frequent itemsets is unioned; only unions with
    exactly ``k`` items are kept (i.e. pairs that share ``k - 2`` items).

    Args:
        prev_frequent: Frequent (k-1)-itemsets from the previous level.
        k: Target candidate size.

    Returns:
        The set of candidate ``k``-itemsets (each a ``frozenset``).

    >>> generate_candidates({frozenset(["a"]), frozenset(["b"])}, 2) == {
    ...     frozenset(["a", "b"])
    ... }
    True
    >>> generate_candidates(set(), 2)
    set()
    """
    candidates: set[frozenset] = set()
    # combinations() yields each unordered pair once, replacing the manual
    # i < j index loop over a materialized list.
    for first, second in combinations(prev_frequent, 2):
        union = first | second
        if len(union) == k:
            candidates.add(union)
    return candidates


def has_infrequent_subset(
    candidate: frozenset, prev_frequent: set[frozenset]
) -> bool:
    """
    Apriori prune test: return ``True`` if any (k-1)-subset of
    ``candidate`` is missing from ``prev_frequent``.

    By the downward-closure (anti-monotone) property, a ``k``-itemset can
    only be frequent if every one of its ``(k-1)``-subsets is frequent, so
    candidates failing this test can be discarded without counting them.

    >>> prev = {frozenset(["a"]), frozenset(["b"])}
    >>> has_infrequent_subset(frozenset(["a", "b"]), prev)
    False
    >>> has_infrequent_subset(frozenset(["a", "c"]), prev)
    True
    """
    return any(
        frozenset(subset) not in prev_frequent
        for subset in combinations(candidate, len(candidate) - 1)
    )


# ---------- Main Apriori ----------


def apriori(data: list[list[str]], min_support: int) -> list[tuple[list[str], int]]:
    """
    Find all frequent itemsets in ``data`` together with their support counts.

    Args:
        data: Transactions, each a list of item names.
        min_support: Minimum support count for an itemset to be frequent.

    Returns:
        ``(sorted_itemset, support_count)`` pairs, level by level
        (1-itemsets first, then 2-itemsets, ...), sorted within each
        level so the output order is deterministic.

    >>> apriori([["a", "b"], ["a", "b"], ["a"]], 2)
    [(['a'], 3), (['b'], 2), (['a', 'b'], 2)]
    >>> apriori([], 1)
    []
    """
    transactions = [set(transaction) for transaction in data]

    # Level 1: count support of each individual item.
    item_counts: defaultdict[frozenset, int] = defaultdict(int)
    for transaction in transactions:
        for item in transaction:
            item_counts[frozenset([item])] += 1

    frequent = {
        itemset for itemset, count in item_counts.items() if count >= min_support
    }

    # Emit every level as (sorted list, count) pairs so the result shape is
    # uniform: the previous code returned bare item strings for 1-itemsets
    # but lists for larger sets, breaking the declared return type.  Sorting
    # keeps output independent of set/hash iteration order.
    all_frequents: list[tuple[list[str], int]] = sorted(
        (sorted(itemset), count)
        for itemset, count in item_counts.items()
        if count >= min_support
    )

    k = 2
    while frequent:
        # Join step: build size-k candidates from frequent (k-1)-itemsets.
        candidates = generate_candidates(frequent, k)

        # Prune step: drop candidates with any infrequent (k-1)-subset.
        candidates = {c for c in candidates if not has_infrequent_subset(c, frequent)}

        # Count candidate support in a single pass over the transactions.
        candidate_counts: defaultdict[frozenset, int] = defaultdict(int)
        for transaction in transactions:
            for candidate in candidates:
                if candidate.issubset(transaction):
                    candidate_counts[candidate] += 1

        frequent = {
            candidate
            for candidate, count in candidate_counts.items()
            if count >= min_support
        }

        all_frequents.extend(
            sorted(
                (sorted(candidate), count)
                for candidate, count in candidate_counts.items()
                if count >= min_support
            )
        )

        k += 1

    return all_frequents


if __name__ == "__main__":
Expand Down