# Source code for Goulib.itertools2

"""
additions to :mod:`itertools` standard library
"""
from itertools import repeat
__author__ = "Philippe Guglielmetti"
__copyright__ = "Copyright 2012, Philippe Guglielmetti"
__credits__ = ["functional toolset from http://pyeuler.wikidot.com/toolset",
               "algos from https://github.com/tokland/pyeuler/blob/master/pyeuler/toolset.py",
               "tools from http://docs.python.org/dev/py3k/library/html",
               "https://github.com/erikrose/more-itertools"
               ]
__license__ = "LGPL"

import collections
import collections.abc
import heapq
import itertools
import logging
import operator
import random

# pure logic


def identity(x):
    """Identity function: give back the argument as is."""
    result = x
    return result
def isiterable(obj):
    """
    :param obj: any object
    :result: bool True if obj is iterable (but not a string)
    """
    # http://stackoverflow.com/questions/1055360/how-to-tell-a-variable-is-iterable-but-not-a-string
    if isinstance(obj, str):
        return False
    # collections.Iterable was removed in Python 3.10: the ABC lives in collections.abc
    return isinstance(obj, collections.abc.Iterable)
def iscallable(f):
    """
    :param f: any object
    :result: bool True if f can be called
    """
    # collections.Callable was removed in Python 3.10; the builtin does the same check
    return callable(f)
def anyf(seq, pred=bool):
    """
    :param seq: iterable
    :param pred: predicate applied to each element, bool by default
    :result: bool True if pred(x) is truthy for at least one element in the iterable
    """
    # any() uses truthiness: the former `True in map(...)` compared with ==,
    # so a predicate returning e.g. 2 or 'yes' was treated as False
    return any(map(pred, seq))
def allf(seq, pred=bool):
    """
    :param seq: iterable
    :param pred: predicate applied to each element, bool by default
    :result: bool True if pred(x) is truthy for all elements in the iterable
    """
    # all() uses truthiness: the former `False not in map(...)` treated
    # falsy results such as None or '' as True
    return all(map(pred, seq))
def no(seq, pred=bool):
    """
    :param seq: iterable
    :param pred: predicate applied to each element, bool by default
    :result: bool True if pred(x) is falsy for every element in the iterable
    """
    # truthiness, not equality with True (a predicate returning 2 counts as true)
    return not any(map(pred, seq))
def index(value, iterable):
    """
    :param value: object to look for (compared with ==)
    :param iterable: any iterable, consumed up to the match
    :result: int position of the first element equal to value
    :raise IndexError: if value is not found
    """
    positions = (pos for pos, item in enumerate(iterable) if item == value)
    for pos in positions:
        return pos
    raise IndexError
# accessors
def ith(iterable, i):
    """
    :param iterable: any iterable, consumed up to position i
    :param i: int position of the wanted element
    :result: i-th element in the iterable
    :raise IndexError: if the iterable has fewer than i+1 elements
    """
    # walking the iterable works for every iterable, by definition
    hits = (item for pos, item in enumerate(iterable) if i == pos)
    for item in hits:
        return item
    raise IndexError
def first(iterable):
    """
    :param iterable: any iterable (one element is consumed from iterators)
    :result: first element in the iterable
    :raise IndexError: if the iterable is empty
    :raise TypeError: if the argument is not iterable
    """
    for item in iterable:
        return item
    raise IndexError
def last(iterable):
    """
    :param iterable: any finite iterable (it is consumed)
    :result: last element in the iterable
    :raise IndexError: if the iterable is empty
    """
    # explicit sentinel instead of catching the UnboundLocalError of an unset local
    missing = object()
    item = missing
    for item in iterable:
        pass
    if item is missing:
        raise IndexError('last() of an empty iterable')
    return item
def takeevery(n, iterable, start=0):
    """
    Take one element every n elements.

    :param n: int stride between two taken elements
    :param iterable: any iterable
    :param start: int index of the first element taken
    :result: iterator over the elements at indexes start, start+n, start+2n, ...
    """
    stride = n
    return itertools.islice(iterable, start, None, stride)


every = takeevery  # short alias
def take(n, iterable):
    """
    :param n: int number of items wanted
    :param iterable: any iterable
    :result: iterator over the first n items of iterable (fewer if it is shorter)
    """
    head = itertools.islice(iterable, n)
    return head
def drop(n, iterable):
    """
    :param n: int number of leading items to skip
    :param iterable: any iterable
    :result: iterator over the items of iterable after the first n
    """
    remainder = itertools.islice(iterable, n, None)
    return remainder
def enumerates(iterable):
    """
    generalizes enumerate to dicts

    :param iterable: dict or any iterable
    :result: (key, value) pairs for a dict, (index, item) pairs otherwise
    """
    return iterable.items() if isinstance(iterable, dict) else enumerate(iterable)
def ilen(it):
    """
    :param it: sized object or iterable
    :result: int length; iterators without len() are exhausted to count them
    """
    try:
        size = len(it)  # much faster when the object knows its length
    except TypeError:
        size = sum(1 for _ in it)  # otherwise count by consuming it
    return size
def irange(start_or_end, optional_end=None):
    """
    :param start_or_end: end if it is the only argument (start is then 0), else start
    :param optional_end: end, if given
    :result: iterable that counts from start to end (both included).
    """
    start, end = (0, start_or_end) if optional_end is None else (start_or_end, optional_end)
    count = max(end - start + 1, 0)  # empty when end < start
    return itertools.islice(itertools.count(start), count)
def arange(start, stop=None, step=1):
    """
    range for floats or other types (`numpy.arange` without numpy)

    :param start: optional number. Start of interval. The interval includes this value. The default start value is 0.
    :param stop: number. End of interval. The interval does not include this value.
    :param step: optional number. Spacing between values. Its sign is ignored:
      values decrease when stop < start. The default step size is 1.
    :result: iterator
    :raise ValueError: if step is zero (on first iteration)
    """
    if stop is None:
        start, stop = 0, start
    step = abs(step)  # direction is inferred from start and stop
    if not step:
        # a zero step would generate start forever
        raise ValueError('arange() step must not be zero')
    descending = stop < start
    value, i = start, 0
    while (value > stop) if descending else (value < stop):
        yield value
        i += 1
        # compute from start rather than accumulating step, so floating-point
        # round-off does not build up (arange(0, 1, .1) must give 10 values)
        offset = step * i
        value = start - offset if descending else start + offset
def linspace(start, end, n=100):
    """
    iterator over n values linearly interpolated between (and including) start and end
    `numpy.linspace` without numpy

    :param start: number, or iterable vector
    :param end: number, or iterable vector
    :param n: int number of interpolated values. n <= 0 gives no value, n == 1 gives start only.
    :result: iterator
    """
    if isiterable(start):
        # start and end are vectors of the same size: interpolate component-wise
        res = (linspace(s, e, n) for s, e in zip(start, end))
        return zip(*res)
    # like http://www.mathworks.com/help/matlab/ref/linspace.html
    # http://docs.scipy.org/doc/numpy/reference/generated/numpy.linspace.html has more options
    if n <= 0:
        return iter(())
    if n == 1 or start == end:
        # n times the same value; also avoids the division by n-1 == 0
        return itertools.repeat(start, n)
    last = n - 1
    span = end - start

    def _values():
        yield start
        for i in range(1, last):
            yield start + span * i / last
        yield end  # exactly end, whatever the round-off

    return _values()
def flatten(it, donotrecursein=str):
    """iterator to flatten (depth-first) structure

    :param it: iterable structure. For a dict, its values are flattened.
    :param donotrecursein: iterable types in which algo doesn't recurse
      string type by default
    """
    # http://stackoverflow.com/questions/2158395/flatten-an-irregular-list-of-lists-in-python
    if isinstance(it, dict):
        it = it.values()
    for el in it:
        # collections.Iterable was removed in Python 3.10: use collections.abc
        if isinstance(el, collections.abc.Iterable) and not isinstance(el, donotrecursein):
            yield from flatten(el, donotrecursein)
        else:
            yield el
def itemgetter(iterable, i):
    """
    :param iterable: iterable of indexable items (tuples, lists, dicts...)
    :param i: index or key to extract from each item
    :result: generator of item[i] for each item
    """
    yield from map(operator.itemgetter(i), iterable)
def recurse(f, x):
    """
    :param f: function of one argument
    :param x: initial value
    :result: infinite generator of x, f(x), f(f(x)), ...
    """
    current = x
    while True:
        yield current
        current = f(current)
def swap(iterable):
    """
    :param iterable: iterable of finite iterables (typically pairs)
    :result: generator of reversed iterators over each element, e.g. (a,b) -> b,a
    """
    yield from (reversed(list(pair)) for pair in iterable)
def tee(iterable, n=2, copy=None):
    """tee or copy depending on type and goal

    :param iterable: any iterable
    :param n: int number of tees/copies to return
    :param copy: optional copy function, for exemple copy.copy or copy.deepcopy
    :result: tuple of n items: itertools.tee copies if iterable is an iterator
      or generator; for list, tuple, set, dict or str either the same object
      n times (copy=None) or n results of copy(iterable)

    this function is useful to avoid side effects at a lower memory cost depending on the case
    """
    if not isinstance(iterable, (list, tuple, set, dict, str)):
        copies = itertools.tee(iterable, n)
        if isinstance(iterable, keep):
            # a keep-wrapped input gives keep-wrapped copies
            copies = tuple(iterable.__class__(c) for c in copies)
        return copies
    if copy is None:
        # re-iterable containers: hand out the very same object n times
        return tuple(iterable for _ in range(n))
    return tuple(copy(iterable) for _ in range(n))
def groups(iterable, n, step=None):
    """Make groups of 'n' elements from the iterable advancing 'step' elements on each iteration

    :param iterable: any iterable
    :param n: int size of each group
    :param step: int advance between groups, n (non-overlapping groups) by default
    :result: iterator over n-tuples
    """
    copies = tee(iterable, n=n, copy=None)
    # copy number k starts k items later, so zipping them gives sliding windows
    shifted = [drop(k, c) for k, c in enumerate(copies)]
    return every(step or n, zip(*shifted))
def pairwise(iterable, op=None, loop=False):
    """
    iterates through consecutive pairs

    :param iterable: input iterable s1,s2,s3, .... sn
    :param op: optional operator to apply to each pair
    :param loop: boolean True if last pair should be (sn,s1) to close the loop
    :result: pairs iterator (s1,s2), (s2,s3) ... (si,si+1), ... (sn-1,sn) + optional pair to close the loop
      if op is given, op(si+1, si) is generated instead of each pair (note the order)
    """
    it = iter(iterable)
    if loop:
        # read the head once and re-emit it: calling first() on an iterator
        # would consume s1 and lose the pair (s1,s2)
        try:
            head = next(it)
        except StopIteration:
            return
        it = itertools.chain([head], it, [head])
    left, right = itertools.tee(it)
    next(right, None)
    for a, b in zip(left, right):
        if op:
            yield op(b, a)  # reversed ! (for sub or div)
        else:
            yield a, b
def select(it1, it2, op):
    """
    :param it1: iterable of candidate values
    :param it2: iterable compared element-wise with it1
    :param op: function (a, b) -> bool
    :result: generator of the elements a of it1 for which op(a, b) is true, b being the matching element of it2
    """
    return (a for a, b in zip(it1, it2) if op(a, b))
def shape(iterable):
    """
    shape of a mutidimensional array, without numpy

    :param iterable: iterable of iterable ... of iterable or numpy arrays...
    :result: list of n ints corresponding to iterable's len of each dimension.
      Strings are treated as scalars; an empty dimension ends the shape (shape([]) == [0]).
    :warning: if iterable is not a (hyper) rect matrix, shape is evaluated from the [0,0,...0] element ...
    :see: http://docs.scipy.org/doc/numpy-1.10.1/reference/generated/numpy.ndarray.shape.html
    """
    res = []
    # a 1-char string's first element is itself: descending into strings never ends
    while not isinstance(iterable, str):
        try:
            res.append(ilen(iterable))
            iterable = first(iterable)
        except TypeError:  # reached a scalar
            break
        except IndexError:  # empty dimension: nothing deeper to inspect
            break
    return res
def ndim(iterable):
    """
    number of dimensions of a mutidimensional array, without numpy

    :param iterable: iterable of iterable ... of iterable or numpy arrays...
    :result: int number of dimensions, i.e. the length of shape(iterable)
    """
    dims = shape(iterable)
    return len(dims)
def reshape(data, dims):
    """
    :param data: (nested) iterable, flattened before reshaping
    :param dims: sequence of int sizes of each dimension, outermost first.
      A falsy size wraps the whole current level into one list.
    :result: data as a n-dim matrix (nested lists)
    """
    flat = list(flatten(data))
    # build from the innermost dimension outwards
    for size in reversed(dims):
        flat = [flat[k:k + size] for k in range(0, len(flat), size)] if size else [flat]
    return flat[0]
def compose(f, g):
    """Compose two functions -> compose(f, g)(x) -> f(g(x))

    :param f: outer function of one argument
    :param g: inner function, receives all positional and keyword arguments
    :result: function
    """
    def composed(*args, **kwargs):
        inner = g(*args, **kwargs)
        return f(inner)
    return composed
def iterate(func, arg):
    """After Haskell's iterate: apply function repeatedly.

    :param func: function of one argument
    :param arg: initial value
    :result: infinite generator of arg, func(arg), func(func(arg)), ...
    """
    # each step ignores the repeated element and applies func to the running value
    yield from itertools.accumulate(itertools.repeat(arg), lambda acc, _: func(acc))
def accumulate(iterable, func=operator.add, skip_first=False, modulo=0):
    """Return running totals.

    extends `python.accumulate`
    # accumulate([1,2,3,4,5]) --> 1 3 6 10 15
    # accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
    # accumulate([7,1], modulo=5) --> 2 3

    :param iterable: any iterable
    :param func: binary function combining the running total with the next item
    :param skip_first: bool if True the first total (the first item) is not generated
    :param modulo: optional int. If non-zero every total, the first one included,
      is reduced modulo this value before being generated and reused
    """
    for i, x in enumerate(iterable):
        total = x if i == 0 else func(total, x)
        if modulo:
            # not %=, which could modify a mutable first item (e.g. an array) in place
            total = total % modulo
        if i == 0 and skip_first:
            continue
        yield total
def record(iterable, it=None, max=0):
    """return the index and value of iterable which exceed previous max

    :param iterable: iterable of comparable values
    :param it: optional iterable of indexes to report, a fresh itertools.count() by default
    :param max: initial threshold, only values strictly greater are reported
    :result: generator of (index, value) pairs for each new record value
    """
    if it is None:
        # a default `itertools.count()` was created once and shared by all
        # calls, so indexes kept counting from one call to the next
        it = itertools.count()
    best = max
    for i, v in zip(it, iterable):
        if v > best:
            yield i, v
            best = v
def record_index(iterable, it=None, max=0):
    """
    :param iterable: iterable of comparable values
    :param it: optional iterable of indexes to report, a fresh itertools.count() by default
    :param max: initial threshold
    :result: iterator over the indexes of the values that exceed all previous ones
    """
    if it is None:
        it = itertools.count()  # fresh counter on each call (no shared default)
    return itemgetter(record(iterable, it, max), 0)
def record_value(iterable, it=None, max=0):
    """
    :param iterable: iterable of comparable values
    :param it: optional iterable of indexes, a fresh itertools.count() by default
    :param max: initial threshold
    :result: iterator over the values that exceed all previous ones
    """
    if it is None:
        it = itertools.count()  # fresh counter on each call (no shared default)
    return itemgetter(record(iterable, it, max), 1)
def tails(seq):
    """Get tails of a sequence

    tails([1,2,3]) -> [1,2,3], [2,3], [3], [].

    :param seq: sliceable sequence
    :result: generator of seq[k:] for k from 0 to len(seq)
    """
    yield from (seq[k:] for k in range(len(seq) + 1))
def ireduce(func, iterable, init=None):
    """Like `python.reduce` but using iterators (a.k.a scanl)

    :param func: binary function
    :param iterable: any iterable
    :param init: optional initial value. When given it is generated first;
      otherwise the first element seeds the reduction and is not generated
    :result: generator of the successive reduction values
    """
    # not functional
    if init is None:
        iterable = iter(iterable)
        try:
            curr = next(iterable)
        except StopIteration:
            # empty input: a StopIteration escaping a generator body
            # becomes RuntimeError since Python 3.7 (PEP 479)
            return
    else:
        curr = init
        yield init
    for x in iterable:
        curr = func(curr, x)
        yield curr
def occurences(iterable):
    """
    count number of occurences of each item in a finite iterable

    :param iterable: finite iterable of sortable, hashable items
    :return: sortedcontainers.SortedDict of int count indexed by item
    """
    from sortedcontainers import SortedDict
    counts = SortedDict()
    for item in iterable:
        if item in counts:
            counts[item] += 1
        else:
            counts[item] = 1
    return counts
def compress(iterable, key=identity, buffer=None):
    """
    generates (item,count) pairs by counting the number of consecutive items in iterable)

    :param iterable: iterable, possibly infinite
    :param key: optional function defining which elements are considered equal
    :param buffer: optional integer. if defined, iterable is sorted with this buffer
    :result: generator of (first item of a run, run length) pairs
    """
    key = key or identity
    if buffer:
        iterable = sorted_iterable(iterable, key, buffer)
    prev = prev_key = None
    count = 0
    for item in iterable:
        item_key = key(item)  # computed once per item
        if count and item_key == prev_key:
            count += 1
            continue
        # the run count tells whether a run is pending; testing `prev is not None`
        # silently dropped a leading run of None items
        if count:
            yield prev, count
        prev, prev_key, count = item, item_key, 1
    if count:
        yield prev, count
def decompress(iterable):
    """
    inverse of :func:`compress`

    :param iterable: iterable of (item, count) pairs
    :result: iterator repeating each item count times
    """
    # chain.from_iterable only concatenates the repetitions: the former
    # flatten() also split items that are themselves lists or tuples
    return itertools.chain.from_iterable(itertools.repeat(item, count) for item, count in iterable)
def unique(iterable, key=None, buffer=100):
    """generate unique elements.

    Items go through a sorting buffer (see :func:`sorted_iterable`) before
    consecutive duplicates are dropped by :func:`compress`, so:

    - the output is sorted within the buffer window, NOT in first-seen order,
    - duplicates further apart than ``buffer`` items may be generated again,
    - with ``buffer=None`` (or 0) no sorting is done and only consecutive
      duplicates are removed.

    The first element of each group of equal elements is generated.

    :param iterable: iterable, possibly infinite
    :param key: optional function defining which elements are considered equal
    :param buffer: optional integer size of the sorting buffer

    # unique('AAAABBBCCDAABBB') --> A B C D
    # unique('ABBCcAD', str.lower) --> A B C D
    """
    return itemgetter(compress(iterable, key, buffer), 0)
def count_unique(iterable, key=None):
    """Count unique elements

    :param iterable: finite iterable of hashable items (or of items with hashable key)
    :param key: optional function defining which elements are considered equal

    # count_unique('AAAABBBCCDAABBB') --> 4
    # count_unique('ABBCcAD', str.lower) --> 4
    """
    if key:
        distinct = {key(element) for element in iterable}
    else:
        distinct = set(iterable)
    return len(distinct)
def combinations_with_replacement(iterable, r):
    """combinations_with_replacement('ABC', 2) --> AA AB AC BB BC CC

    r-length combinations of elements of iterable, each element possibly
    repeated, in lexicographic order of positions. Unlike the cartesian
    product, no reordering of the same combination is generated.

    :param iterable: finite iterable
    :param r: int length of combinations
    :result: generator of r-tuples
    """
    # the standard library generates exactly these tuples directly, instead of
    # filtering the n**r tuples of a cartesian product (which also failed for r < 2)
    yield from itertools.combinations_with_replacement(iterable, r)
# my functions added
def takenth(n, iterable, default=None):
    """
    :param n: int 0-based position
    :param iterable: any iterable
    :param default: value returned when the iterable is too short
    :result: nth item of iterable
    """
    # https://docs.python.org/2/library/html#recipes
    tail = itertools.islice(iterable, n, n + 1)
    return next(tail, default)


nth = takenth  # alias
def icross(*sequences):
    """Cartesian product of sequences (recursive version)

    :param sequences: re-iterable sequences (all but the first are iterated several times)
    :result: generator of tuples, last sequence varying fastest
    """
    # http://stackoverflow.com/questions/15099647/cross-product-of-sets-using-recursion
    if not sequences:
        yield ()
        return
    head, rest = sequences[0], sequences[1:]
    for x in head:
        yield from ((x,) + tail for tail in icross(*rest))
def quantify(iterable, pred=bool):
    """
    :param iterable: any iterable
    :param pred: predicate, bool by default; its results are summed
    :result: int count how many times the predicate is true
    """
    total = 0
    for item in iterable:
        total = total + pred(item)
    return total
def interleave(l1, l2):
    """
    :param l1: list
    :param l2: list of same length, or 1 less than l1
    :result: new list interleaving elements from l1 and l2, starting by l1[0]
    :raise ValueError: if the lengths do not fit
    """
    # http://stackoverflow.com/questions/7946798/interleaving-two-lists-in-python-2-2
    merged = l1 + l2  # right total size; slots are overwritten just below
    merged[::2], merged[1::2] = l1, l2
    return merged
def shuffle(ary):
    """
    :param: array to shuffle by Fisher-Yates algorithm
    :result: shuffled array (IN PLACE!)
    :see: http://www.drgoulu.com/2013/01/19/comment-bien-brasser-les-cartes/
    """
    for i in reversed(range(1, len(ary))):
        j = random.randint(0, i)  # j may equal i: the element then stays put
        ary[i], ary[j] = ary[j], ary[i]
    return ary
def rand_seq(size):
    """
    :param size: int
    :result: list of range(size) shuffled
    """
    numbers = list(range(size))
    return shuffle(numbers)
def all_pairs(size):
    """generates all i,j pairs for i,j from 0-size, in random order

    a new random order of j is drawn for each i
    """
    for i in rand_seq(size):
        yield from ((i, j) for j in rand_seq(size))
def index_min(values, key=identity):
    """
    :param values: iterable or dict
    :param key: optional function applied to each value before comparison
    :result: min_index, min_value (for a dict: key and value of the smallest value)
    """
    def _rank(pair):
        return key(pair[1])
    return min(enumerates(values), key=_rank)
def index_max(values, key=identity):
    """
    :param values: iterable or dict
    :param key: optional function applied to each value before comparison
    :result: max_index, max_value (for a dict: key and value of the largest value)
    """
    def _rank(pair):
        return key(pair[1])
    return max(enumerates(values), key=_rank)
def best(iterable, key=None, n=1, reverse=False):
    """ generate items corresponding to the n best values of key sort order

    :param iterable: finite iterable
    :param key: optional sort key, items themselves by default
    :param n: int number of distinct key values to keep
    :param reverse: bool if True the largest keys are the best
    :result: generator of all items whose key is among the n best, in sorted order
    """
    ranked = sorted(iterable, key=key, reverse=reverse)
    # groupby compares consecutive keys itself; the former `k = None` start
    # marker merged a first group whose key is None with "no group yet"
    for rank, (_, group) in enumerate(itertools.groupby(ranked, key=key)):
        if rank >= n:
            break
        yield from group
def sort_indexes(iterable, key=identity, reverse=False):
    """
    :param iterable: finite iterable
    :param key: optional sort key applied to the items
    :param reverse: bool if True sort in descending order
    :return: list of the indexes of iterable, ordered so that the corresponding items are sorted
    """
    # http://stackoverflow.com/questions/6422700/how-to-get-indices-of-a-sorted-array-in-python
    # reverse was previously accepted but silently ignored
    order = sorted(enumerate(iterable), key=lambda pair: key(pair[1]), reverse=reverse)
    return [i for i, _ in order]
# WARNING: filter2 was named "split" before v1.7.0; it was renamed for consistency
def filter2(iterable, condition):
    """ like `python.filter` but returns 2 lists :

    - list of elements in iterable that satisfy condition
    - list of those that don't
    """
    accepted, rejected = [], []
    for item in iterable:
        (accepted if condition(item) else rejected).append(item)
    return accepted, rejected
def ifind(iterable, f, reverse=False):
    """iterates through items in iterable where f(item) == True.

    :param iterable: any iterable; a sequence (len and reversed) if reverse is True
    :param f: predicate
    :param reverse: bool if True scan from the end
    :result: generator of (index, item) pairs, indexes always counted from the start
    """
    if reverse:
        top = len(iterable) - 1
        for offset, item in enumerate(reversed(iterable)):
            if f(item):
                yield top - offset, item
        return
    for pos, item in enumerate(iterable):
        if f(item):
            yield pos, item
def iremove(iterable, f):
    """
    removes items from an iterable based on condition

    :param iterable: iterable . will be modified in place (needs len, reversed and pop)
    :param f: function of the form lambda line:bool returning True if item should be removed
    :yield: removed items backwards
    """
    # scanning from the end keeps the indexes of not-yet-visited items valid while popping
    hits = ifind(iterable, f, reverse=True)
    for pos, _ in hits:
        yield iterable.pop(pos)
def removef(iterable, f):
    """
    removes items from an iterable based on condition

    :param iterable: iterable . will be modified in place
    :param f: function of the form lambda line:bool returning True if item should be removed
    :result: list of removed items, in their original order.
    """
    removed = list(iremove(iterable, f))
    return removed[::-1]  # iremove works backwards
def find(iterable, f):
    """Return the first (index, item) pair of iterable where f(item) == True.

    :param iterable: any iterable
    :param f: predicate
    :result: (index, item) tuple -- not just the item
    :raise StopIteration: if no item satisfies f
    """
    return next(ifind(iterable, f))
def isplit(iterable, sep, include_sep=False):
    """
    split iterable by separators or condition

    :param iterable: sequence (it is sliced, so iterators are not supported)
    :param sep: value or function(item) returning True for items that separate
    :param include_sep: bool. If True the separators items are included in output, at beginning of each sub-iterator
    :result: iterates through slices before, between, and after separators
    """
    # a plain value used to be passed to ifind() as a predicate and failed
    # with TypeError; compare with it instead
    is_sep = sep if callable(sep) else (lambda item: item == sep)
    start = 0
    for i, item in enumerate(iterable):
        if is_sep(item):
            yield itertools.islice(iterable, start, i)
            start = i if include_sep else i + 1
    yield itertools.islice(iterable, start, None)
def split(iterable, sep, include_sep=False):
    """
    like https://docs.python.org/2/library/stdtypes.html#str.split, but for iterable

    :param sep: value or function(item) returning True for items that separate
    :param include_sep: bool. If True the separators items are included in output, at beginning of each sub-iterator
    :result: list of lists: slices before, between, and after separators
    """
    pieces = isplit(iterable, sep, include_sep)
    return list(map(list, pieces))
def dictsplit(dic, keys):
    """
    extract keys from dic

    :param dic: dict source (not modified)
    :param keys: iterable of dict keys
    :result: dict,dict : the first contains entries present in source, the second the remaining entries
    """
    picked = {}
    rest = dic.copy()
    for k in keys:
        if k not in rest:
            continue
        picked[k] = rest.pop(k)
    return picked, rest
def next_permutation(seq, pred=lambda x, y: -1 if x < y else 0):
    """Like C++ std::next_permutation() but implemented as generator.

    Generates seq itself, then its following permutations in lexicographic
    order, up to the last one. Equal elements are handled: no permutation
    is generated twice.
    see http://blog.bjrn.se/2008/04/lexicographic-permutations-using.html

    :param seq: mutable sequence allowing random access, e.g. a list. It is copied, not modified.
    :param pred: a function (a,b) that returns a negative number if a<b, like cmp(a,b) in Python 2.7
    :result: generator of independent copies of the permutations
    :raise TypeError: if seq does not allow random access
    """
    def _reverse(buf, start, end):
        """Reverse buf[start:end] in place."""
        end -= 1
        while start < end:
            buf[start], buf[end] = buf[end], buf[start]
            start += 1
            end -= 1

    # generators must `return`: `raise StopIteration` inside a generator
    # body becomes RuntimeError since Python 3.7 (PEP 479)
    if not seq:
        return
    try:
        seq[0]
    except TypeError:
        raise TypeError("seq must allow random access.") from None
    size = len(seq)
    seq = seq[:]
    # Yield input sequence as the STL version is often used inside do {} while.
    # Yield a copy: seq is modified in place below
    yield seq[:]
    if size == 1:
        return
    while True:
        pivot = size - 1
        while True:
            # Step 1: scan right to left for seq[pivot] < seq[pivot + 1]
            succ = pivot
            pivot -= 1
            if pred(seq[pivot], seq[succ]) < 0:
                # Step 2: swap the pivot with the rightmost larger element
                mid = size - 1
                while not pred(seq[pivot], seq[mid]) < 0:
                    mid -= 1
                seq[pivot], seq[mid] = seq[mid], seq[pivot]
                # Step 3: reverse the suffix after the pivot
                _reverse(seq, succ, size)
                yield seq[:]
                break
            if pivot == 0:
                return  # fully descending: this was the last permutation
class iter2(object):
    """Iterator wrapper that can be extended while being iterated.

    http://code.activestate.com/recipes/578092-flattening-an-arbitrarily-deep-list-or-any-iterato/

    Allows for the following method calls (that should be built into iterators anyway...):

    - append - appends another iterable onto the iterator.
    - insert - only accepts inserting at the 0 place, inserts an iterable
      before other iterables.
    - adding. an iter2 object can be added to another object that is
      iterable. i.e. iter2 + iter (not iter + iter2); the result is an
      itertools.chain sharing this iterator.
      It's best to make all objects iter2 objects to avoid syntax errors. :D
    """

    def __init__(self, iterable):
        self._iter = iter(iterable)

    def append(self, iterable):
        """Add iterable after the remaining items."""
        tail = iter(iterable)
        self._iter = itertools.chain(self._iter, tail)

    def insert(self, place, iterable):
        """Add iterable before the remaining items; place must be 0."""
        if place == 0:
            self._iter = itertools.chain(iter(iterable), self._iter)
            return
        raise ValueError('Can only insert at index of 0')

    def __add__(self, iterable):
        rest = iter(iterable)
        return itertools.chain(self._iter, rest)

    def __next__(self):
        return next(self._iter)

    next = __next__  # Python2-3 compatibility

    def __iter__(self):
        return self
def subdict(d, keys):
    """extract "sub-dictionary"

    :param d: dict
    :param keys: container of keys to extract; keys missing from d are ignored
    :result: dict:
    :see: http://stackoverflow.com/questions/5352546/best-way-to-extract-subset-of-key-value-pairs-from-python-dictionary-object/5352649#5352649
    """
    return {k: d[k] for k in keys if k in d}
class SortingError(Exception):
    """Raised when an iterable expected to be sorted is not."""
def ensure_sorted(iterable, key=None):
    """
    makes sure iterable is sorted according to key

    :param iterable: any iterable
    :param key: optional sort key, items themselves by default
    :yields: items of iterable
    :raise: SortingError if an item's key is smaller than the previous item's key
    """
    key = key or identity
    has_prev = False  # explicit flag: `prev is not None` skipped the check after a None item
    prev = prev_key = None
    for n, x in enumerate(iterable):
        x_key = key(x)  # computed once per item
        if has_prev and x_key < prev_key:
            raise SortingError("%d: %s < %s" % (n, x, prev))
        prev, prev_key, has_prev = x, x_key, True
        yield x
def sorted_iterable(iterable, key=None, buffer=100):
    """sorts an "almost sorted" (infinite) iterable

    :param iterable: iterable
    :param key: function used as sort key
    :param buffer: int size of buffer. elements to swap should not be further than that.
      A falsy buffer keeps everything until the end (finite iterables only)
    :result: generator of the items, smallest of the buffer first
    """
    from sortedcontainers import SortedListWithKey
    key = key or identity
    window = SortedListWithKey(key=key)
    for item in iterable:
        if buffer and len(window) >= buffer:
            yield window.pop(0)  # release the smallest buffered item
        window.add(item)
    # flush what is left (never reached if iterable is infinite)
    yield from window
# operations on sorted iterators
def diff(iterable1, iterable2):
    """generate items in sorted iterable1 that are not in sorted iterable2

    :param iterable1: sorted iterable
    :param iterable2: sorted iterable (any iterable, not only an iterator)
    :result: generator of the items of iterable1 absent from iterable2
    """
    others = iter(iterable2)
    end = object()  # marks exhaustion of iterable2
    # next() with a default: a StopIteration escaping a generator becomes
    # RuntimeError since Python 3.7 (PEP 479)
    b = next(others, end)
    for a in iterable1:
        while b is not end and b < a:
            b = next(others, end)
        if b is not end and a == b:
            continue
        yield a
# alias: heapq.merge lazily merges several sorted iterables into one sorted iterator
merge = heapq.merge
def intersect(*iterables):
    """
    generates itersection of N iterables

    :param iterables: any number of SORTED iterables
    :yields: elements that belong to all iterables, once each
    :see: http://stackoverflow.com/questions/969709/joining-a-set-of-ordered-integer-yielding-python-iterators
    """
    # tag each value with the index of its source: counting the size of each
    # group was wrong as soon as one input contained duplicates
    tagged = [zip(it, itertools.repeat(i)) for i, it in enumerate(iterables)]
    merged = heapq.merge(*tagged, key=lambda pair: pair[0])
    for value, group in itertools.groupby(merged, key=lambda pair: pair[0]):
        sources = {i for _, i in group}
        if len(sources) == len(iterables):
            yield value
def product(*iterables, **kwargs):
    """
    Cartesian product of (infinite) input iterables.

    Inputs are walked "diagonally", so every tuple of the product is
    eventually generated exactly once even when some inputs are infinite.
    The order is therefore not the one of :func:`itertools.product`.
    Already-read elements are kept in memory.

    :param iterables: any number of iterables
    :param repeat: integer optional number of repetitions
    :result: iterator over tuples. For backward compatibility a single
      iterable (with repeat=1) is returned as is, not wrapped in 1-tuples.
    :raise ValueError: if repeat is negative
    :see: http://stackoverflow.com/questions/12093364/cartesian-product-of-large-iterators-itertools
    """
    # https://github.com/enricobacis/infinite/blob/master/infinite/product.py
    # is not general enough

    def _pairs(left, right):
        """Generate every (x, y) of left x right, reading both inputs in step."""
        end = object()
        it_l, it_r = iter(left), iter(right)
        seen_l, seen_r = [], []
        live_l = live_r = True
        while live_l or live_r:
            if live_r:
                y = next(it_r, end)
                if y is end:
                    live_r = False
                else:
                    seen_r.append(y)
                    for a in seen_l:
                        yield a, y
            if live_l:
                x = next(it_l, end)
                if x is end:
                    live_l = False
                else:
                    seen_l.append(x)
                    for b in seen_r:
                        yield x, b
            # an exhausted empty side makes the product empty, even if the
            # other side is infinite
            if (not live_l and not seen_l) or (not live_r and not seen_r):
                return

    r = kwargs.get('repeat', 1)
    if r < 0:
        raise ValueError('repeat argument cannot be negative')
    if not iterables or r == 0:
        return iter([()])
    if r > 1:
        n = len(iterables)
        # pre-sized: extended slice assignment needs slots to exist
        # (assigning into an empty list failed as soon as n > 1)
        pools = [None] * (n * r)
        for i, it in enumerate(iterables):
            pools[i::n] = itertools.tee(it, r)
        iterables = pools
    if len(iterables) == 1:
        return iterables[0]
    res = _pairs(iterables[0], iterables[1])
    for it in iterables[2:]:
        # only the accumulated left tuple is extended: elements of the new
        # iterable are kept whole, even if they are themselves iterable
        res = (left + (right,) for left, right in _pairs(res, it))
    return res
# cycle detection (Floyd "tortoise and hare" algorithm)
# taken from https://codereview.stackexchange.com/questions/7847/tortoise-and-hare-cycle-detection-algorithm-using-iterators-in-python
# http://ideone.com/fgrwM
class keep(collections.abc.Iterator):
    """iterator with a one-item lookahead stored in ``val``

    ``val`` holds the value the next call to ``next()`` will return; once the
    underlying iterator is exhausted it keeps the last value.
    """
    # collections.Iterator was removed in Python 3.10: the ABC lives in collections.abc

    def __init__(self, iterable):
        self.it = iter(iterable)
        self.stop = False  # True once the underlying iterator is exhausted
        # reads the first value eagerly: StopIteration here if iterable is empty
        self.val = next(self.it)

    def __next__(self):
        if self.stop:
            raise StopIteration
        prev = self.val
        try:
            self.val = next(self.it)
        except StopIteration:
            self.stop = True  # val keeps the last value
        return prev

    next = __next__  # 2.7 compatibility
def first_match(iter1, iter2, limit=None):
    """
    :param iter1: iterable
    :param iter2: iterable, compared element-wise with iter1
    :param limit: int max number of loops (no limit if None or 0)
    :return: integer i first index where iter1[i]==iter2[i], or None if none is found
    """
    for n, pair in enumerate(zip(iter1, iter2)):
        logging.debug(pair)
        a, b = pair
        if a == b:
            return n
        if limit and n > limit:
            return None
    return None
def floyd(iterable, limit=1e6):
    """Detect a cycle in iterable using Floyd "tortoise and hare" algorithm

    The values are assumed to come from iterating a function
    (x[i+1] = f(x[i])), so two equal values imply equal successors.
    Values read ahead are buffered by itertools.tee.

    :see: https://en.wikipedia.org/wiki/Cycle_detection#Floyd's_Tortoise_and_Hare
    :param iterable: iterable
    :param limit: int limit to prevent infinite loop. no limit if None
    :result: (i,l) tuple of integers where i=index of cycle start, l=length
      if no cycle is found, return (None,None)
    """
    source, tortoise, hare = itertools.tee(iterable, 3)
    # phase 1: tortoise at x[nu], hare at x[2*nu], until they meet
    try:
        next(tortoise)
        next(hare)
        nu = 0
        while True:
            nu += 1
            if limit and nu > limit:
                return (None, None)
            t = next(tortoise)
            next(hare)
            h = next(hare)
            if t == h:
                break
    except StopIteration:
        return (None, None)  # finite iterable: no cycle
    # phase 2: x[mu] == x[mu+nu] first happens at the cycle start mu (mu <= nu)
    behind, ahead = itertools.tee(source)
    ahead = itertools.islice(ahead, nu, None)
    for mu, (x, y) in enumerate(zip(behind, ahead)):
        if x == y:
            break
    else:
        return (None, None)  # not reached once phase 1 found a match
    # phase 3: cycle length = distance to the next occurrence of x[mu]
    for lam, z in enumerate(behind, 1):
        if z == x:
            return mu, lam
    return (None, None)
def brent(iterable, limit=1e6):
    """Detect a cycle in iterable using Brent's algorithm

    The values are assumed to come from iterating a function
    (x[i+1] = f(x[i])), so two equal values imply equal successors.
    Values read ahead are buffered by itertools.tee.

    :see: https://en.wikipedia.org/wiki/Cycle_detection#Brent's_algorithm
    :param iterable: iterable
    :param limit: int max number of hare steps to prevent infinite loop. no limit if None
    :result: (i,l) tuple of integers where i=index of cycle start, l=length
      if no cycle is found (finite iterable or limit reached), return (None,None)
    """
    source, probe = itertools.tee(iterable)
    # main phase: search successive powers of two
    power = lam = 1
    steps = 0
    try:
        tortoise = next(probe)  # x0
        hare = next(probe)  # x1
        while tortoise != hare:
            if power == lam:  # time to start a new power of two?
                tortoise = hare
                power *= 2
                lam = 0
            hare = next(probe)
            lam += 1
            steps += 1
            if limit and steps > limit:  # limit was previously ignored
                return (None, None)
    except StopIteration:
        return (None, None)  # finite iterable: no cycle
    # Find the position of the first repetition of length lam:
    # the first mu with x[mu] == x[mu+lam]
    behind, ahead = itertools.tee(source)
    ahead = itertools.islice(ahead, lam, None)
    for mu, (x, y) in enumerate(zip(behind, ahead)):
        if x == y:
            return mu, lam
    return (None, None)
def detect_cycle(iterable, limit=1e6):
    """Detect a cycle in iterable with :func:`brent`

    :param iterable: iterable
    :param limit: int limit to prevent infinite loop
    :result: (i,l) tuple of integers where i=index of cycle start, l=length,
      or (None, None) if no cycle is found
    """
    try:
        result = brent(iterable, limit)
    except StopIteration:
        result = (None, None)  # no cycle found
    return result