#!/usr/bin/env python3

import argparse
import sys

def _load_source(source):
    """Return the lines of ``source`` as a list of strings.

    ``source`` is either a file path or the literal ``'-'``, which means
    "read standard input". Line terminators are removed by
    ``str.splitlines``; no other whitespace is touched.
    """
    if source != '-':
        with open(source, 'r') as handle:
            text = handle.read()
    else:
        text = sys.stdin.read()
    return text.splitlines()

def _combine_sets(sets, operation):
    """Combine ``sets`` according to ``operation`` and return a new set.

    Args:
        sets: Non-empty list of ``set`` objects. None of them is modified.
        operation: One of ``'-'``/``'difference'``, ``'+'``/``'union'``,
            ``'x'``/``'intersection'`` or ``'d'``/``'unique'``.

    Returns:
        A new set holding the combined members.

    Raises:
        ValueError: If ``operation`` is not one of the names above.
    """
    first, rest = sets[0], sets[1:]
    if operation in ('-', 'difference'):
        return first.difference(*rest)
    if operation in ('+', 'union'):
        return first.union(*rest)
    if operation in ('x', 'intersection'):
        return first.intersection(*rest)
    if operation in ('d', 'unique'):
        # Keep the members that occur in exactly one input set: track
        # everything seen so far and everything seen more than once.
        seen = set()
        repeated = set()
        for members in sets:
            repeated |= seen & members
            seen |= members
        return seen - repeated
    raise ValueError('Unsupported operation "{}"'.format(operation))


def cli(argv=None):
    """Run the set-operation command line tool.

    Each positional argument names a source (a file path, or ``-`` for
    standard input) whose lines are the members of one set. The sets are
    combined with the operation chosen by ``-o/--operation`` (union by
    default) and the resulting members are printed to standard output in
    sorted order, one per line.

    Args:
        argv: Optional list of argument strings. When ``None`` (the
            default) argparse reads the arguments from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('sets', nargs="+", help='Each file is a set and each line in the file is a member of the set')
    parser.add_argument('-n', '--non-empty', action='store_true', default=False, help='non-empty values only')
    parser.add_argument('-s', '--strip', action='store_true', default=False, help='strip surrounding whitespace')
    parser.add_argument('-f', '--filter', action='store_true', default=False, help='strip and filter to non-empty')
    parser.add_argument('-o', '--operation',
        choices=['+', '-', 'x', 'd', 'union', 'difference', 'intersection', 'unique'],
        default='+',
        help="""
            Operation to perform on the sets
            [-] Subtract sets 1...N from set 0
            [+] Get the union of sets 0...N
            [x] Get the intersection of sets 0...N
            [d] Symmetric difference (disjunctive union). Elements from all sets which are not in any others.
        """
    )
    args = parser.parse_args(argv)

    # --filter is shorthand for --strip combined with --non-empty.
    should_strip = args.filter or args.strip
    should_drop_empty = args.filter or args.non_empty

    sets = []
    for source in args.sets:
        lines = _load_source(source)
        if should_strip:
            lines = [line.strip() for line in lines]
        if should_drop_empty:
            lines = [line for line in lines if line]
        sets.append(set(lines))

    try:
        result = _combine_sets(sets, args.operation)
    except ValueError as err:
        # Not reachable while argparse `choices` matches _combine_sets, but
        # exit with a usage error instead of crashing if they ever diverge.
        parser.error(str(err))

    # Print nothing for an empty result; an unconditional print would emit a
    # blank line, which looks the same as a result holding only the empty
    # string.
    if result:
        print('\n'.join(sorted(result)))

# Run the command line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    cli()
