6
\$\begingroup\$

This script takes a dataset, a minimum support value and a minimum confidence value as its options, and returns the association rules.

I'm looking for pointers towards better optimization, documentation and code quality.

"""
Description     : A Python implementation of the Apriori Algorithm

Usage:
    $python apriori.py -f DATASET.csv -s minSupport  -c minConfidence

    $python apriori.py -f DATASET.csv -s 0.15 -c 0.6
"""

import sys
from collections import defaultdict
from itertools import chain, combinations, combinations_with_replacement
from optparse import OptionParser


def subsets(arr):
    """Return an iterator over every non-empty subset of ``arr`` as tuples.

    Subsets come out in order of increasing size: all 1-element tuples, then
    all 2-element tuples, and so on. Within one size they follow the order
    ``itertools.combinations`` produces. The last element yielded is the
    whole of ``arr``.

    ``arr`` is copied into a tuple once, so one-shot iterables (generators,
    iterators) are handled correctly, not just re-iterable containers.
    """
    pool = tuple(arr)
    # chain.from_iterable flattens the per-size combination iterators lazily,
    # without first building a list of them.
    return chain.from_iterable(combinations(pool, size)
                               for size in range(1, len(pool) + 1))


def returnItemsWithMinSupport(itemSet, transactionList, minSupport, freqSet):
    """Return the members of ``itemSet`` whose support is at least ``minSupport``.

    The support of a candidate is the fraction of transactions in
    ``transactionList`` that contain it (``item.issubset(transaction)``).

    Side effect: ``freqSet[item]`` is incremented once for every transaction
    that contains ``item``. Across calls, ``freqSet`` therefore accumulates raw
    occurrence counts, so it must default missing keys to 0
    (e.g. ``defaultdict(int)``).

    Args:
        itemSet: candidate itemsets (frozensets).
        transactionList: list of transactions (frozensets).
        minSupport: threshold compared against the support fraction.
        freqSet: mutable count mapping, updated in place.

    Returns:
        A new set with the candidates meeting ``minSupport``. Candidates that
        occur in no transaction are never returned.
    """
    localSet = defaultdict(int)
    for item in itemSet:
        for transaction in transactionList:
            if item.issubset(transaction):
                freqSet[item] += 1
                localSet[item] += 1

    # Hoisted out of the filter. localSet is only non-empty when some
    # transaction matched, so transactionList is non-empty whenever this
    # divisor is used.
    numTransactions = float(len(transactionList))
    return set(item for item, count in localSet.items()
               if count / numTransactions >= minSupport)


def joinSet(itemSet, length):
    """Join ``itemSet`` with itself and return the unions of size ``length``.

    Every unordered pair of members is unioned, including a member paired
    with itself. Only unions with exactly ``length`` elements are kept.
    runApriori calls this with the frequent (k-1)-itemsets and ``length == k``
    to build the candidate k-itemsets.
    """
    # Union is symmetric, so visiting each unordered pair once (rather than
    # both (i, j) and (j, i)) gives the same result with half the work. Each
    # union is also computed once per pair instead of twice.
    unions = (i.union(j)
              for i, j in combinations_with_replacement(itemSet, 2))
    return set(u for u in unions if len(u) == length)


def getItemSetTransactionList(data_iterator):
    """Materialise the records into transactions and collect the 1-itemsets.

    Every record from ``data_iterator`` (for instance the generator returned
    by dataFromFile()) becomes one frozenset transaction. Every distinct item
    seen in any transaction becomes a single-element frozenset candidate.

    Returns:
        ``(itemSet, transactionList)``, for example
        ``({frozenset(['apple']), frozenset(['beer']), ...},
           [frozenset(['beer', 'rice', 'apple', 'chicken']), ...])``.
    """
    transactionList = [frozenset(record) for record in data_iterator]
    # The single-item candidates seed the first level of the Apriori search.
    itemSet = set(frozenset([item])
                  for transaction in transactionList
                  for item in transaction)
    return itemSet, transactionList


def runApriori(data_iter, minSupport, minConfidence):
    """Run the Apriori algorithm over the records produced by ``data_iter``.

    Each record is any iterable of hashable items. It is converted to a
    frozenset, so a duplicate item within one record counts once.

    Args:
        data_iter: iterable of records (e.g. the generator from dataFromFile()).
        minSupport: minimum fraction of transactions an itemset must appear in.
        minConfidence: minimum confidence an association rule must reach.

    Returns:
        A pair ``(items, rules)``:
          - items: list of ``(itemset_tuple, support)`` for every frequent
            itemset;
          - rules: list of ``((antecedent_tuple, consequent_tuple),
            confidence)`` with
            ``confidence = support(antecedent | consequent) / support(antecedent)``.
    """
    itemSet, transactionList = getItemSetTransactionList(data_iter)

    # Raw occurrence counts of every candidate that was tested.
    freqSet = defaultdict(int)
    # Maps k -> set of frequent k-itemsets (those meeting minSupport).
    largeSet = dict()

    currentLSet = returnItemsWithMinSupport(itemSet, transactionList,
                                            minSupport, freqSet)
    k = 2
    while currentLSet:
        largeSet[k - 1] = currentLSet
        # Apriori pruning: a k-itemset can only be frequent if all of its
        # (k-1)-subsets are frequent. Dropping the other candidates here
        # avoids scanning every transaction for itemsets that cannot pass.
        candidates = set(
            c for c in joinSet(currentLSet, k)
            if all(frozenset(sub) in currentLSet
                   for sub in combinations(c, k - 1)))
        currentLSet = returnItemsWithMinSupport(candidates, transactionList,
                                                minSupport, freqSet)
        k += 1

    def getSupport(item):
        """Return the fraction of transactions that contain ``item``."""
        return float(freqSet[item]) / len(transactionList)

    # Walk the levels in ascending k explicitly: dict views cannot be sliced
    # on Python 3, and dict order is not guaranteed on older Pythons.
    levels = sorted(largeSet.items())

    toRetItems = [(tuple(item), getSupport(item))
                  for _, itemsets in levels
                  for item in itemsets]

    toRetRules = []
    for level, itemsets in levels:
        if level < 2:
            # A 1-itemset has no non-empty proper subset, hence no rules.
            continue
        for item in itemsets:
            for element in map(frozenset, subsets(item)):
                remain = item.difference(element)
                if remain:
                    # element is a subset of a frequent itemset, so it is
                    # frequent too and its support is non-zero.
                    confidence = getSupport(item) / getSupport(element)
                    if confidence >= minConfidence:
                        toRetRules.append(((tuple(element), tuple(remain)),
                                           confidence))
    return toRetItems, toRetRules


def printResults(items, rules):
    """Print itemsets by ascending support, then rules by ascending confidence.

    Args:
        items: iterable of ``(itemset_tuple, support)`` pairs.
        rules: iterable of ``((pre_tuple, post_tuple), confidence)`` pairs.
    """
    # ``key=lambda pair: pair[1]`` replaces tuple-parameter unpacking
    # (``lambda (item, support): ...``), which is a SyntaxError on Python 3.
    # ``print(...)`` with a single argument prints the same on Python 2 and 3.
    for item, support in sorted(items, key=lambda pair: pair[1]):
        print("item: %s , %.3f" % (str(item), support))
    print("\n------------------------ RULES:")
    for (pre, post), confidence in sorted(rules, key=lambda pair: pair[1]):
        print("Rule: %s ==> %s , %.3f" % (str(pre), str(post), confidence))


def dataFromFile(fname):
    """Yield one frozenset of items per non-blank line of the CSV file ``fname``.

    Each line is split on commas. Surrounding whitespace is stripped from
    every field, and empty fields (e.g. from a trailing comma or ``a,,b``) are
    dropped. Duplicates such as ``beer,beer`` collapse to a single item.

    For the line ``apple,beer,beer,rice,chicken`` this yields
    ``frozenset(['apple', 'beer', 'rice', 'chicken'])``.

    The file is opened on first iteration. It is closed when the generator
    is exhausted or closed.
    """
    # Mode 'rU' is deprecated and was removed in Python 3.11. Stripping each
    # field already discards a stray '\r' from Windows line endings. The
    # with-block makes sure the handle is closed instead of leaking.
    with open(fname, 'r') as file_iter:
        for line in file_iter:
            record = frozenset(field.strip() for field in line.split(',')
                               if field.strip())
            if record:  # skip blank lines rather than yielding {''}
                yield record


if __name__ == "__main__":

    optparser = OptionParser()
    optparser.add_option('-f', '--inputFile',
                         dest='input',
                         help='filename containing csv',
                         default=None)
    optparser.add_option('-s', '--minSupport',
                         dest='minS',
                         help='minimum support value',
                         default=0.15,
                         type='float')
    optparser.add_option('-c', '--minConfidence',
                         dest='minC',
                         help='minimum confidence value',
                         default=0.6,
                         type='float')

    (options, args) = optparser.parse_args()

    if options.input is None:
        # No -f given: read CSV records from standard input, one per line.
        # Each line must be split into items here. Passing sys.stdin straight
        # to runApriori would turn every line (a string) into a set of its
        # characters.
        parsed = (frozenset(field.strip() for field in line.split(',')
                            if field.strip())
                  for line in sys.stdin)
        inFile = (record for record in parsed if record)  # skip blank lines
    else:
        inFile = dataFromFile(options.input)

    minSupport = options.minS
    minConfidence = options.minC

    items, rules = runApriori(inFile, minSupport, minConfidence)

    printResults(items, rules)

The sample data is the following CSV file:

apple,beer,beer,rice,chicken
apple,beer,beer,rice
apple,beer,beer
apple,mango
milk,beer,beer,rice,chicken
milk,beer,rice
milk,beer
milk,mango
\$\endgroup\$
6
  • 3
    \$\begingroup\$ Brace yourselves, antiCamelCaseRecommendationsAreComing :) \$\endgroup\$
    – BusyAnt
    Commented Jul 28, 2016 at 11:05
  • \$\begingroup\$ It would be really useful if you provided a sample of the data you are reading from. \$\endgroup\$ Commented Jul 28, 2016 at 13:04
  • \$\begingroup\$ @OscarSmith Added the sample data (csv file in the question) :) \$\endgroup\$
    – Dawny33
    Commented Jul 28, 2016 at 14:33
  • 1
    \$\begingroup\$ @MathiasEttinger No. Thanks for pointing it out. Would use argparse then :) \$\endgroup\$
    – Dawny33
    Commented Jul 28, 2016 at 16:46
  • 1
    \$\begingroup\$ @Dawny33 As regard to your module docstring, I can also suggest docopt \$\endgroup\$ Commented Jul 28, 2016 at 16:48

1 Answer 1

5
\$\begingroup\$

My biggest piece of advice would be to replace freqSet = defaultdict(int) with a Counter. Counters are a datatype designed to do exactly what you are doing with defaultdicts, and they have some specialized methods.

for item in itemSet:
            for transaction in transactionList:
                    if item.issubset(transaction):
                        freqSet[item] += 1

Could be replaced with

freqSet.update(item for item in itemSet for transaction in transactionList if item.issubset(transaction))

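This should be a pretty big speed increase. Note that this only works once freqSet is a Counter: Counter.update() counts an iterable of keys, whereas defaultdict.update() expects key/value pairs and raises an error. Also, set([i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length]) could be written as a set comprehension, {i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length}. This avoids building the intermediate list, which lowers memory usage and increases speed.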

\$\endgroup\$
2
  • \$\begingroup\$ I really should learn how to spell comprehension. \$\endgroup\$ Commented Jul 28, 2016 at 14:41
  • \$\begingroup\$ Possible to explain how do I write the line set([i.union(j) for i in itemSet for j in itemSet if len(i.union(j)) == length]) as a set comprehension? And the freqSet.update() is returning some weird error :( \$\endgroup\$
    – Dawny33
    Commented Aug 1, 2016 at 8:37

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.