Source code for Goulib.math2

#!/usr/bin/env python
# coding: utf8
more math than :mod:`math` standard library, without numpy

from __future__ import division #"true division" everywhere

__author__ = "Philippe Guglielmetti"
__copyright__ = "Copyright 2012, Philippe Guglielmetti"
__credits__ = [
__license__ = "LGPL"

import six, logging
from six.moves import map, reduce, filter, zip_longest, range

import math, cmath, operator, itertools, fractions, numbers, random

from Goulib import itertools2, decorators

inf=float('Inf') #infinity
eps = 2.2204460492503131e-16 # numpy.finfo(np.float64).eps

    nan=float('nan') #Not a Number

# define some math functions that are not available in all supported versions of python

except AttributeError:
[docs] def cmp(x,y): '''Compare the two objects x and y and return an integer according to the outcome. The return value is negative if x < y, zero if x == y and strictly positive if x > y. ''' return sign(x-y)
try: log2=math.log2 except AttributeError: def log2(x):return math.log(x,2) try: isclose=math.isclose except AttributeError: def isclose(a, b, rel_tol=1e-09, abs_tol=0.0): '''approximately equal. Use this instead of a==b in floating point ops implements :param a,b: the two values to be tested to relative closeness :param rel_tol: relative tolerance it is the amount of error allowed, relative to the larger absolute value of a or b. For example, to set a tolerance of 5%, pass tol=0.05. The default tolerance is 1e-9, which assures that the two values are the same within about 9 decimal digits. rel_tol must be greater than 0.0 :param abs_tol: minimum absolute tolerance level -- useful for comparisons near zero. ''' # if a == b: # short-circuit exact equality return True if rel_tol < 0.0 or abs_tol < 0.0: raise ValueError('error tolerances must be non-negative') # use cmath so it will work with complex ot float if math.isinf(abs(a)) or math.isinf(abs(b)): # This includes the case of two infinities of opposite sign, or # one infinity and one finite number. Two infinities of opposite sign # would otherwise have an infinite relative tolerance. return False diff = abs(b - a) return (((diff <= abs(rel_tol * b)) or (diff <= abs(rel_tol * a))) or (diff <= abs_tol))
[docs]def allclose(a, b, rel_tol=1e-09, abs_tol=0.0): ''':return: True if two arrays are element-wise equal within a tolerance.''' # for x,y in zip_longest(a,b): if x is None or y is None: return False if not isclose(x ,y , rel_tol=rel_tol, abs_tol=abs_tol): return False return True
# basic useful functions
[docs]def is_number(x): ''':return: True if x is a number of any type, including Complex''' # return isinstance(x, numbers.Number)
[docs]def is_complex(x): return isinstance(x,complex)
[docs]def is_real(x): return is_number(x) and not is_complex(x)
[docs]def sign(number): ''':return: 1 if number is positive, -1 if negative, 0 if ==0''' if number<0: return -1 if number>0: return 1 return 0
# rounding
[docs]def rint(v): ''' :return: int value nearest to float v ''' return int(round(v))
[docs]def is_integer(x, rel_tol=0, abs_tol=0): ''' :return: True if float x is an integer within tolerances ''' if isinstance(x, six.integer_types): return True try: if rel_tol+abs_tol==0: return x==rint(x) return isclose(x,round(x),rel_tol=rel_tol,abs_tol=abs_tol) except TypeError: # for complex return False
[docs]def int_or_float(x, rel_tol=0, abs_tol=0): ''' :param x: int or float :return: int if x is (almost) an integer, otherwise float ''' return rint(x) if is_integer(x, rel_tol, abs_tol) else x
[docs]def format(x, decimals=3): ''' formats a float with given number of decimals, but not an int :return: string repr of x with decimals if not int ''' if is_integer(x): decimals = 0 return '{0:.{1}f}'.format(x, decimals)
# improved versions of math functions
[docs]def gcd(*args): '''greatest common divisor of an arbitrary number of args''' # L = list(args) #in case args are generated b=L[0] # will be returned if only one arg while len(L) > 1: a = L[-2] b = L[-1] L = L[:-2] while a: a, b = b%a, a L.append(b) return abs(b)
[docs]def lcm(*args): '''least common multiple of any number of integers''' if len(args)<=2: return mul(args) // gcd(*args) # TODO : better res=lcm(args[0],args[1]) for n in args[2:]: res=lcm(res,n) return res
[docs]def xgcd(a,b): '''Extended GCD :return: (gcd, x, y) where gcd is the greatest common divisor of a and b with the sign of b if b is nonzero, and with the sign of a if b is 0. The numbers x,y are such that gcd = ax+by.''' #taken from prevx, x = 1, 0; prevy, y = 0, 1 while b: q, r = divmod(a,b) x, prevx = prevx - q*x, x y, prevy = prevy - q*y, y a, b = b, r return a, prevx, prevy
[docs]def coprime(*args): ''':return: True if args are coprime to each other''' return gcd(*args)==1
[docs]def coprimes_gen(limit): '''generates coprime pairs using Farey sequence ''' # pend = [] n,d = 0,1 # n, d is the start fraction n/d (0,1) initially N = D = 1 # N, D is the stop fraction N/D (1,1) initially while True: mediant_d = d + D if mediant_d <= limit: mediant_n = n + N pend.append((mediant_n, mediant_d, N, D)) N = mediant_n D = mediant_d else: yield n, d #numerator / denominator if pend: n, d, N, D = pend.pop() else: break
[docs]def carmichael(n): ''' Carmichael function :return : int smallest positive integer m such that a^m mod n = 1 for every integer a between 1 and n that is coprime to n. :param n: int :see: :see: also known as the reduced totient function or the least universal exponent function. ''' coprimes = [x for x in range(1, n) if gcd(x, n) == 1] k = 1 while not all(pow(x, k, n) == 1 for x in coprimes): k += 1 return k
# #code decomposed from
[docs]def is_primitive_root(x,m,s={}): '''returns True if x is a primitive root of m :param s: set of coprimes to m, if already known ''' if not s: s={n for n in range(1, m) if coprime(n, m) } return {pow(x, p, m) for p in range(1, m)}==s
[docs]def primitive_root_gen(m): '''generate primitive roots modulo m''' required_set = {num for num in range(1, m) if coprime(num, m) } for n in range(1, m): if is_primitive_root(n,m,required_set): yield n
[docs]def primitive_roots(modulo): return list(primitive_root_gen(modulo))
[docs]def quad(a, b, c, allow_complex=False): ''' solves quadratic equations aX^2+bX+c=0 :param a,b,c: floats :param allow_complex: function returns complex roots if True :return: x1,x2 real or complex solutions ''' discriminant = b*b - 4 *a*c if allow_complex: d=cmath.sqrt(discriminant) else: d=math.sqrt(discriminant) return (-b + d) / (2*a), (-b - d) / (2*a)
[docs]def ceildiv(a, b): return -(-a // b) #simple and clever
[docs]def ipow(x,y,z=0): ''' :param x: number (int or float) :param y: int power :param z: int optional modulus :return: (x**y) % z as integer if possible ''' if y<0 : if z: raise NotImplementedError('no modulus allowed for negative power') else: return 1/ipow(x,-y) a,b=1,x while y>0: if y%2 == 1: a=(a*b)%z if z else a*b b = (b*b)%z if z else b*b y=y//2 return a
[docs]def pow(x,y,z=0): ''' :return: (x**y) % z as integer ''' if not isinstance(y,six.integer_types): if z==0: return six.builtins.pow(x,y) #switches to floats in Py3... else: return six.builtins.pow(x,y,z) #switches to floats in Py3... else: return ipow(x,y,z)
[docs]def sqrt(n): '''square root :return: int, float or complex depending on n ''' if type(n) is int: s=isqrt(n) if s*s==n: return s if n<0: return cmath.sqrt(n) return math.sqrt(n)
[docs]def isqrt(n): '''integer square root :return: largest int x for which x * x <= n ''' # # n=int(n) x = n y = (x + 1) // 2 while y < x: x = y y = (x + n // x) // 2 return x
[docs]def icbrt(n): '''integer cubic root :return: largest int x for which x * x * x <= n ''' # if n <= 0: return 0 x = int(n ** (1. / 3.) * (1 + 1e-12)) while True: y = (2 * x + n // (x * x)) // 3 if y >= x: return x x = y
[docs]def is_square(n): s=isqrt(n) return s*s==n
[docs]def introot(n, r=2): ''' integer r-th root :return: int, greatest integer less than or equal to the r-th root of n. For negative n, returns the least integer greater than or equal to the r-th root of n, or None if r is even. ''' # copied from if n < 0: return None if r%2 == 0 else -introot(-n, r) if n < 2: return n if r == 2: return isqrt(n) lower, upper = 0, n while lower != upper - 1: mid = (lower + upper) // 2 m = mid**r if m == n: return mid elif m < n: lower = mid elif m > n: upper = mid return lower
[docs]def is_power(n): ''' :return: integer that, when squared/cubed/etc, yields n, or 0 if no such integer exists. Note that the power to which this number is raised will be prime.''' # copied from for p in primes_gen(): r = introot(n, p) if r is None: continue if r ** p == n: return r if r == 1: return 0
[docs]def multiply(x, y): ''' Karatsuba fast multiplication algorithm Copyright (c) 2014 Project Nayuki ''' _CUTOFF = 1536 #_CUTOFF >= 64, or else there will be infinite recursion. if x.bit_length() <= _CUTOFF or y.bit_length() <= _CUTOFF: # Base case return x * y else: n = max(x.bit_length(), y.bit_length()) half = (n + 32) // 64 * 32 mask = (1 << half) - 1 xlow = x & mask ylow = y & mask xhigh = x >> half yhigh = y >> half a = multiply(xhigh, yhigh) b = multiply(xlow + xhigh, ylow + yhigh) c = multiply(xlow, ylow) d = b - a - c return (((a << half) + d) << half) + c
#vector operations
[docs]def accsum(it): '''Yield accumulated sums of iterable: accsum(count(1)) -> 1,3,6,10,...''' return itertools2.drop(1, itertools2.ireduce(operator.add, it, 0))
cumsum=accsum #numpy alias
[docs]def mul(nums,init=1): ''' :return: Product of nums ''' return reduce(operator.mul, nums, init)
[docs]def dot_vv(a,b,default=0): '''dot product for vectors :param a: vector (iterable) :param b: vector (iterable) :param default: default value of the multiplication operator ''' return sum(map( operator.mul, a, b),default)
[docs]def dot_mv(a,b,default=0): '''dot product for vectors :param a: matrix (iterable or iterables) :param b: vector (iterable) :param default: default value of the multiplication operator ''' return [dot_vv(line,b,default) for line in a]
[docs]def dot_mm(a,b,default=0): '''dot product for matrices :param a: matrix (iterable or iterables) :param b: matrix (iterable or iterables) :param default: default value of the multiplication operator ''' return transpose([dot_mv(a,col) for col in zip(*b)])
[docs]def dot(a,b,default=0): '''dot product general but slow : use dot_vv, dot_mv or dot_mm if you know a and b's dimensions ''' if itertools2.ndim(a)==2: # matrix if itertools2.ndim(b)==2: # matrix*matrix return dot_mm(a,b,default) else: # matrix*vector return dot_mv(a,b,default) else: #vector*vector return dot_vv(a,b,default)
# some basic matrix ops
[docs]def zeros(shape): ''' :see: ''' return ([0]*shape[1])*shape[0]
[docs]def diag(v): ''' Create a two-dimensional array with the flattened input as a diagonal. :param v: If v is a 2-D array, return a copy of its diagonal. If v is a 1-D array, return a 2-D array with v on the diagonal :see: ''' s=len(v) if itertools2.ndim(v)==2: return [v[i][i] for i in range(s)] res=[] for i,x in enumerate(v): line=[x]+[0]*(s-1) line=line[-i:]+line[:-i] res.append(line) return res
[docs]def identity(n): return diag([1]*n)
eye=identity # alias for now
[docs]def transpose(m): ''' :return: matrix m transposed ''' # ensures the result is a list of lists return list(map(list,list(zip(*m))))
[docs]def maximum(m): ''' Compare N arrays and returns a new array containing the element-wise maxima :param m: list of arrays (matrix) :return: list of maximal values found in each column of m :see: ''' return [max(c) for c in transpose(m)]
[docs]def minimum(m): ''' Compare N arrays and returns a new array containing the element-wise minima :param m: list of arrays (matrix) :return: list of minimal values found in each column of m :see: ''' return [min(c) for c in transpose(m)]
[docs]def vecadd(a,b,fillvalue=0): '''addition of vectors of inequal lengths''' return [l[0]+l[1] for l in zip_longest(a,b,fillvalue=fillvalue)]
[docs]def vecsub(a,b,fillvalue=0): '''substraction of vectors of inequal lengths''' return [l[0]-l[1] for l in zip_longest(a,b,fillvalue=fillvalue)]
[docs]def vecneg(a): '''unary negation''' return list(map(operator.neg,a))
[docs]def vecmul(a,b): '''product of vectors of inequal lengths''' if isinstance(a,(int,float)): return [x*a for x in b] if isinstance(b,(int,float)): return [x*b for x in a] return [reduce(operator.mul,l) for l in zip(a,b)]
[docs]def vecdiv(a,b): '''quotient of vectors of inequal lengths''' if isinstance(b,(int,float)): return [float(x)/b for x in a] return [reduce(operator.truediv,l) for l in zip(a,b)]
[docs]def veccompare(a,b): '''compare values in 2 lists. returns triple number of pairs where [a<b, a==b, a>b]''' res=[0,0,0] for ai,bi in zip(a,b): if ai<bi: res[0]+=1 elif ai==bi: res[1]+=1 else: res[2]+=1 return res
[docs]def sat(x,low=0,high=None): ''' saturates x between low and high ''' if isinstance(x,(int,float)): if low is not None: x=max(x,low) if high is not None: x=min(x,high) return x return [sat(_,low,high) for _ in x]
#norms and distances
[docs]def norm_2(v): ''' :return: "normal" euclidian norm of vector v ''' return sqrt(sum(x*x for x in v))
[docs]def norm_1(v): ''' :return: "manhattan" norm of vector v ''' return sum(abs(x) for x in v)
[docs]def norm_inf(v): ''' :return: infinite norm of vector v ''' return max(abs(x) for x in v)
[docs]def norm(v,order=2): ''' :see: ''' return sum(abs(x)**order for x in v)**(1./order)
[docs]def dist(a,b,norm=norm_2): return norm(vecsub(a,b))
[docs]def vecunit(v,norm=norm_2): ''' :return: vector normalized ''' return vecdiv(v,norm(v))
[docs]def hamming(s1, s2): '''Calculate the Hamming distance between two iterables''' return sum(c1 != c2 for c1, c2 in zip(s1, s2))
[docs]def sets_dist(a,b): ''' :see: ''' c = a.intersection(b) return sqrt(len(a-c)*2 + len(b-c)*2)
[docs]def sets_levenshtein(a,b): '''levenshtein distance on sets :see: ''' c = a.intersection(b) return len(a-c)+len(b-c)
[docs]def levenshtein(seq1, seq2): '''levenshtein distance :return: distance between 2 iterables :see: ''' # oneago = None thisrow = list(range(1, len(seq2) + 1)) + [0] for x in range(len(seq1)): _, oneago, thisrow = oneago, thisrow, [0] * len(seq2) + [x + 1] for y in range(len(seq2)): delcost = oneago[y] + 1 addcost = thisrow[y - 1] + 1 subcost = oneago[y - 1] + (seq1[x] != seq2[y]) thisrow[y] = min(delcost, addcost, subcost) return thisrow[len(seq2) - 1]
#stats #moved to # numbers functions # mostly from
[docs]def recurrence(signature, values, cst=0, max=None, mod=0): '''general generator for recurrences :param signature: factors defining the recurrence :param values: list of initial values ''' values=list(values) # to allow tuples or iterators factors=list(reversed(signature)) for n in values: if mod: n=n%mod yield n values=values[-len(signature):] while True: n=dot_vv(factors,values) if max and n>max: break n=n+cst if mod: n=n%mod yield n values=values[1:] values.append(n)
[docs]def fibonacci_gen(max=None,mod=0): '''Generate fibonacci serie''' return recurrence([1,1],[0,1],max=max,mod=mod)
[docs]def fibonacci(n,mod=0): ''' fibonacci series n-th element :param n: int can be extremely high, like 1e19 ! :param mod: int optional modulo ''' if n < 0: raise ValueError("Negative arguments not implemented") # #uses return mod_matpow([[1,1],[1,0]],n,mod)[0][1]
[docs]def is_fibonacci(n): '''returns True if n is in Fibonacci series''' # return is_square(5*n*n + 4) or is_square(5*n*n - 4)
[docs]def pisano_cycle(mod): if mod<2: return [0] seq=[0,1] l=len(seq) s=[] for i,n in enumerate(fibonacci_gen(mod=mod)): s.append(n) if i>l and s[-l:]==seq: return s[:-l]
[docs]def pisano_period(mod): if mod<2: return 1 flag=False # 0 was found for i,n in enumerate(fibonacci_gen(mod=mod)): if not flag: flag=n==0 elif i>3: if n==1: return i-1 flag=False
[docs]def collatz(n): if n % 2 == 0: return n // 2 else: return 3 * n + 1
[docs]def collatz_gen(n=0): yield n while True: n=collatz(n) yield n
[docs]@decorators.memoize def collatz_period(n): if n==1: return 1 return 1+collatz_period(collatz(n))
[docs]def pascal_gen(): '''Pascal's triangle read by rows: C(n,k) = binomial(n,k) = n!/(k!*(n-k)!), 0<=k<=n. ''' __author__ = 'Nick Hobson <>' # code from with additional related functions for row in itertools.count(): x = 1 yield x for m in range(row): x = (x * (row - m)) // (m + 1) yield x
[docs]def catalan(n): '''Catalan numbers: C(n) = binomial(2n,n)/(n+1) = (2n)!/(n!(n+1)!). ''' return binomial(2*n,n)//(n+1) #result is always int
[docs]def catalan_gen(): '''Generate Catalan numbers: C(n) = binomial(2n,n)/(n+1) = (2n)!/(n!(n+1)!). Also called Segner numbers. ''' yield 1 last=1 yield last for n in itertools.count(1): last=last*(4*n+2)//(n+2) yield last
[docs]def is_pythagorean_triple(a,b,c): return a*a+b*b == c*c
[docs]def primitive_triples(): ''' generates primitive Pythagorean triplets x<y<z sorted by hypotenuse z, then longest side y through Berggren's matrices and breadth first traversal of ternary tree :see: ''' key=lambda x:(x[2],x[1]) from sortedcontainers import SortedListWithKey triples=SortedListWithKey(key=key) triples.add([3,4,5]) A = [[ 1,-2, 2], [ 2,-1, 2], [ 2,-2, 3]] B = [[ 1, 2, 2], [ 2, 1, 2], [ 2, 2, 3]] C = [[-1, 2, 2], [-2, 1, 2], [-2, 2, 3]] while triples: (a,b,c) = triples.pop(0) yield (a,b,c) # expand this triple to 3 new triples using Berggren's matrices for X in [A,B,C]: triple=[sum(x*y for (x,y) in zip([a,b,c],X[i])) for i in range(3)] if triple[0]>triple[1]: # ensure x<y<z triple[0],triple[1]=triple[1],triple[0] triples.add(triple)
[docs]def triples(): ''' generates all Pythagorean triplets triplets x<y<z sorted by hypotenuse z, then longest side y ''' prim=[] #list of primitive triples up to now key=lambda x:(x[2],x[1]) from sortedcontainers import SortedListWithKey samez=SortedListWithKey(key=key) # temp triplets with same z buffer=SortedListWithKey(key=key) # temp for triplets with smaller z for pt in primitive_triples(): z=pt[2] if samez and z!=samez[0][2]: #flush samez while samez: yield samez.pop(0) samez.add(pt) #build buffer of smaller multiples of the primitives already found for i,pm in enumerate(prim): p,m=pm[0:2] while True: mz=m*p[2] if mz < z: buffer.add(tuple(m*x for x in p)) elif mz == z: # we need another buffer because next pt might have # the same z as the previous one, but a smaller y than # a multiple of a previous pt ... samez.add(tuple(m*x for x in p)) else: break m+=1 prim[i][1]=m #update multiplier for next loops while buffer: #flush buffer yield buffer.pop(0) prim.append([pt,2]) #add primitive to the list
[docs]def divisors(n): ''' :param n: int :return: all divisors of n: divisors(12) -> 1,2,3,6,12 including 1 and n, except for 1 which returns a single 1 to avoid messing with sum of divisors... ''' if n==1: yield 1 else: all_factors = [[f**p for p in itertools2.irange(0,fp)] for (f, fp) in factorize(n)] # do not use itertools2.product here as long as the order of the result differs for ns in itertools.product(*all_factors): yield mul(ns)
[docs]def proper_divisors(n): ''':return: all divisors of n except n itself.''' return (divisor for divisor in divisors(n) if divisor != n)
_sieve=list() # array of bool indicating primality
[docs]def sieve(n, oneisprime=False): '''prime numbers from 2 to a prime < n Very fast (n<10,000,000) in 0.4 sec. Example: >>>prime_sieve(25) [2, 3, 5, 7, 11, 13, 17, 19, 23] Algorithm & Python source: Robert William Hanks ''' n=int(n) # to tolerate n=1E9, which is float if n<2: return [] if n==2: return [1] if oneisprime else [] global _sieve if n>len(_sieve): #recompute the sieve #TODO: enlarge the sieve... # _sieve = [False,False,True,True]+[False,True] * ((n-4)//2) #assert(len(_sieve)==n) for i in range(3,int(n**0.5)+1,2): if _sieve[i]: _sieve[i*i::2*i]=[False]*int((n-i*i-1)/(2*i)+1) return ([1,2] if oneisprime else [2]) + [i for i in range(3,n,2) if _sieve[i]]
_primes=sieve(1000) # primes up to 1000 _primes_set = set(_primes) # to speed us primality tests below
[docs]def primes(n): '''memoized list of n first primes :warning: do not call with large n, use prime_gen instead ''' m=n-len(_primes) if m>0: more=list(itertools2.take(m,primes_gen(_primes[-1]+1))) _primes.extend(more) _primes_set.union(set(more)) return _primes[:n]
[docs]def is_prime_euler(n,eb=(2,)): '''Euler's primality test :param n: int number to test :param eb: test basis :return: False if not prime, True if prime, but also for many pseudoprimes... :see: ''' # for b in eb: if b >= n: continue if not pow(b, n-1, n) == 1: return False r = n - 1 while r%2 == 0: r //= 2 c = pow(b, r, n) if c == 1: continue while c != 1 and c != n-1: c = pow(c, 2, n) if c == 1: return False return True # according to Euler, but there are
[docs]def is_prime(n, oneisprime=False, tb=(3,5,7,11), eb=(2,), mrb=None): '''main primality test. :param n: int number to test :param oneisprime: bool True if 1 should be considered prime (it was, a long time ago) :param tb: trial division basis :param eb: Euler's test basis :param mrb: Miller-Rabin basis, automatic if None :see: It’s an implementation of the BPSW test (Baillie-Pomerance-Selfridge-Wagstaff) with some prefiltes for speed and is deterministic for all numbers less than 2^64 Iin fact, while infinitely many false positives are conjectured to exist, no false positives are currently known. The prefilters consist of trial division against 2 and the elements of the tuple tb, checking whether n is square, and Euler’s primality test to the bases in the tuple eb. If the number is less than 3825123056546413051, we use the Miller-Rabin test on a set of bases for which the test is known to be deterministic over this range. ''' # if n <= 0: return False if n == 1: return oneisprime if n<len(_sieve): return _sieve[n] if n in _primes_set: return True if any(n%p == 0 for p in tb): return False if is_square(n): return False # it's quick ... if not is_prime_euler(n): return False # Euler's test s, d = pfactor(n) if not sprp(n, 2, s, d): return False if n < 2047: return True # BPSW has two phases: SPRP with base 2 and SLPRP. # We just did the SPRP; now we do the SLPRP if n >= 3825123056546413051: d = 5 while True: if gcd(d, n) > 1: p, q = 0, 0 break if jacobi(d, n) == -1: p, q = 1, (1 - d) // 4 break d = -d - 2*d//abs(d) if p == 0: return n == d s, t = pfactor(n + 2) u, v, u2, v2, m = 1, p, 1, p, t//2 k = q while m > 0: u2, v2, q = (u2*v2)%n, (v2*v2-2*q)%n, (q*q)%n if m%2 == 1: u, v = u2*v+u*v2, v2*v+u2*u*d if u%2 == 1: u += n if v%2 == 1: v += n u, v, k = (u//2)%n, (v//2)%n, (q*k)%n m //= 2 if (u == 0) or (v == 0): return True for i in range(1, s): v, k = (v*v-2*k)%n, (k*k)%n if v == 0: return True return False # Miller-Rabin if not mrb: if n < 1373653: mrb = [3] elif n < 25326001: mrb = [3,5] elif n < 3215031751: mrb = [3,5,7] elif n < 2152302898747: mrb = [3,5,7,11] elif n < 3474749660383: mrb = [3,5,6,11,13] elif n < 341550071728321: mrb = [3,5,7,11,13,17] # This number is also a false positive for primes(19+1). elif n < 3825123056546413051: mrb = [3,5,7,11,13,17,19,23] # Also a false positive for primes(31+1). return all(sprp(n, b, s, d) for b in mrb)
[docs]def nextprime(n): '''Determines, with some semblance of efficiency, the least prime number strictly greater than n.''' # from if n < 2: return 2 if n == 2: return 3 n = (n + 1) | 1 # first odd larger than n m = n % 6 if m == 3: if is_prime(n+2): return n+2 n += 4 elif m == 5: if is_prime(n): return n n += 2 for m in itertools.count(n, 6): if is_prime(m ): return m if is_prime(m+4): return m+4
[docs]def prevprime(n): '''Determines, very inefficiently, the largest prime number strictly smaller than n.''' if n<3: return None if n==3: return 2 n = n | 1 # n if it is odd, or n+1 if it is even while True: n-=2 if is_prime(n): return n
[docs]def primes_gen(start=2,stop=None): '''generate prime numbers from start''' if start==1: yield 1 #if we asked for it explicitly start=2 if stop is None or stop>start: n=start-1 # to include start if it is prime while True: n=nextprime(n) if (stop is None) or n<=stop: yield n else: break else: # backwards n=start+1 while True: n=prevprime(n) if n and n>=stop: yield n else: break
[docs]def random_prime(bits): '''returns a random number of the specified bit length''' import random while True: n = random.getrandbits(bits-1)+2**(bits-1); n=nextprime(n-1) if n<2**bits: return n
[docs]def euclid_gen(): '''generates Euclid numbers: 1 + product of the first n primes''' n = 1 for p in primes_gen(1): n = n * p yield n+1
[docs]def prime_factors(num, start=2): '''generates all prime factors (ordered) of num''' for p in primes_gen(start): if num==1: break if is_prime(num): #because it's fast yield num break if p>num: break while num % p==0: yield p num=num//p
[docs]def lpf(n): '''greatest prime factor''' if n<4 : return n return itertools2.first(prime_factors(n))
[docs]def gpf(n): '''greatest prime factor''' if n<4 : return n return itertools2.last(prime_factors(n))
[docs]def prime_divisors(num, start=2): '''generates unique prime divisors (ordered) of num''' return itertools2.unique(prime_factors(num,start))
[docs]def is_multiple(n,factors) : '''return True if n has ONLY factors as prime factors''' return not set(prime_divisors(n))-set(factors)
[docs]def factorize(n): '''find the prime factors of n along with their frequencies. Example: >>> factor(786456) [(2,3), (3,3), (11,1), (331,1)] ''' if n==1: #allows to make many things quite simpler... return [(1,1)] return itertools2.compress(prime_factors(n))
[docs]def factors(n): for (p,e) in factorize(n): yield p**e
[docs]def number_of_divisors(n): # res=1 if n>1: for (p,e) in factorize(n): res=res*(e+1) return res
[docs]def omega(n): '''Number of distinct primes dividing n''' return itertools2.count_unique(prime_factors(n))
[docs]def bigomega(n): '''Number of prime divisors of n counted with multiplicity''' return itertools2.ilen(prime_factors(n))
[docs]def moebius(n): '''Möbius (or Moebius) function mu(n). mu(1) = 1; mu(n) = (-1)^k if n is the product of k different primes; otherwise mu(n) = 0. ''' if n==1: return 1 res=1 for p,q in factorize(n): if q>1: return 0 res=-res return res
[docs]def euler_phi(n): '''Euler totient function :see: ''' if n<=1: return n return int(mul((1 - 1.0 / p for p, _ in factorize(n)),n))
totient=euler_phi #alias. totient is available in sympy
[docs]def kempner(n): '''"Kempner function, also called Smarandache function :return: int smallest positive integer m such that n divides m!. :param n: int :see: :see: ''' if n==1: return 1 if is_prime(n) : return n @decorators.memoize def _np(n,p): #n^p . use k=0 while p > n: k += n p -= n + 1 t=k while t%n!=0: t=t//n p-=1 p=max(0,p) return (k + p) * n; return max(_np(f,p) for f,p in factorize(n))
[docs]def prime_ktuple(constellation): ''' generates tuples of primes with specified differences :param constellation: iterable of int differences betwwen primes to return :note: negative int means the difference must NOT be prime :see: (0, 2) twin primes (0, 4) cousin primes (0, 6) sexy primes (0, 2, 6), (0, 4, 6) prime triplets (0, 6, 12, -18) sexy prime triplets (0, 2, 6, 8) prime quadruplets (0, 6, 12, 18) sexy prime quadruplets (0, 2, 6, 8, 12), (0, 4, 6, 10, 12) quintuplet primes (0, 4, 6, 10, 12, 16) sextuplet primes ''' diffs=constellation[1:] for p in primes_gen(): res=[p] for d in diffs: if is_prime(p+abs(d)) == (d<0): res=None break res.append(p+d) if res: yield tuple(res)
[docs]def twin_primes(): return prime_ktuple((0, 2))
[docs]def cousin_primes(): return prime_ktuple((0, 4))
[docs]def sexy_primes(): return prime_ktuple((0, 6))
[docs]def sexy_prime_triplets(): return prime_ktuple((0, 6, 12, -18)) #exclude quatruplets
[docs]def sexy_prime_quadruplets(): return prime_ktuple((0, 6, 12, 18))
[docs]def lucas_lehmer (p): '''Lucas Lehmer primality test for Mersenne exponent p :param p: int :return: True if 2^p-1 is prime ''' # if p == 2: return True elif not is_prime(p): return False else: m_p = (1 << p) - 1 # 2^p-1 s = 4 for i in range(3, p + 1): s = (s*s - 2) % m_p return s == 0
# digits manipulation
[docs]def digits_gen(num, base=10): '''generates int digits of num in base BACKWARDS''' if num == 0: yield 0 while num: num,rem=divmod(num,base) yield rem
[docs]def digits(num, base=10, rev=False): ''' :return: list of digits of num expressed in base, optionally reversed ''' res=list(digits_gen(num,base)) if not rev: res.reverse() return res
[docs]def digsum(num, f=None, base=10): '''sum of digits :param num: number :param f: int power or function applied to each digit :param base: optional base :return: sum of f(digits) of num digsum(num) -> sum of digits digsum(num,base=2) -> number of 1 bits in binary represenation of num digsum(num,2) -> sum of the squares of digits digsum(num,f=lambda x:x**x) -> sum of the digits elevaed to their own power ''' d=digits_gen(num,base) if f is None: return sum(d) if is_number(f): p=f f=lambda x:pow(x,p) try: return sum(map(f,d)) except AttributeError: pass d=[f(x,i) for i,x in enumerate(d)] return sum(d)
[docs]def integer_exponent(a,b=10): ''' :returns: int highest power of b that divides a. :see: ''' res=0 for d in digits_gen(a, b): if d>0 : break res+=1 return res
trailing_zeros= integer_exponent
[docs]def power_tower(v): ''' :return: v[0]**v[1]**v[2] ... :see: ''' return reduce(lambda x,y:y**x, reversed(v))
[docs]def carries(a,b,base=10,pos=0): ''' :return: int number of carries required to add a+b in base ''' carry, answer = 0, 0 # we have no carry terms so far, and we haven't carried anything yet for one,two in zip_longest(digits_gen(a,base), digits_gen(b,base), fillvalue=0): carry = (one+two+carry)//base answer += carry>0 # increment the number of carry terms, if we will carry again return answer
[docs]def powertrain(n): ''' :return: v[0]**v[1]*v[2]**v[3] ...**(v[-1] or 0) :author: # Chai Wah Wu, Jun 16 2017 :see: ''' s = str(n) l = len(s) m = int(s[-1]) if l % 2 else 1 for i in range(0, l-1, 2): m *= int(s[i])**int(s[i+1]) return m
[docs]def str_base(num, base=10, numerals = '0123456789abcdefghijklmnopqrstuvwxyz'): ''' :return: string representation of num in base :param num: int number (decimal) :param base: int base, 10 by default :param numerals: string with all chars representing numbers in base base. chars after the base-th are ignored ''' if base==10 and numerals[:10]=='0123456789': return str(num) if base==2 and numerals[:2]=='01': return "{0:b}".format(int(num)) if base < 2 or base > len(numerals): raise ValueError("str_base: base must be between 2 and %d" % len(numerals)) if num < 0: sign = '-' num = -num else: sign = '' result = '' for d in digits_gen(num,base): result = numerals[d] + result return sign + result
[docs]def int_base(num, base): ''' :return: int representation of num in base :param num: int number (decimal) :param base: int base, <= 10 ''' return int(str_base(num,base))
[docs]def num_from_digits(digits, base=10): ''' :param digits: string or list of digits representing a number in given base :param base: int base, 10 by default :return: int number ''' if isinstance(digits,six.string_types): return int(digits,base) res,f=0,1 for x in reversed(list(digits)): res=res+int(x*f) f=f*base return res
[docs]def reverse(i): return int(str(i)[::-1])
[docs]def is_palindromic(num, base=10): '''Check if 'num' in base 'base' is a palindrome, that's it, if it can be read equally from left to right and right to left.''' if base==10: return num==reverse(num) digitslst = list(digits_gen(num, base)) return digitslst == list(reversed(digitslst))
[docs]def is_anagram(num1, num2, base=10): '''Check if 'num1' and 'num2' have the same digits in base''' digits1=sorted(str_base(num1,base)) digits2=sorted(str_base(num2,base)) return digits1==digits2
[docs]def is_pandigital(num, base=10): ''' :return: True if num contains all digits in specified base ''' n=str_base(num,base) return len(n)>=base and not '123456789abcdefghijklmnopqrstuvwxyz'[:base-1].strip(n)
# return set(sorted(digits_from_num(num,base))) == set(range(base)) #slow
[docs]def bouncy(n,up=False,down=False): ''' :param n: int number to test :param up: bool :param down: bool bouncy(x) returns True for Bouncy numbers (digits form a strictly non-monotonic sequence) (A152054) bouncy(x,True,None) returns True for Numbers with digits in nondecreasing order (OEIS A009994) bouncy(x,None,True) returns True for Numbers with digits in nonincreasing order (OEIS A009996) ''' s=str(n) s1=''.join(sorted(s)) res=True if up is not None: res = res and up==(s==s1) if down is not None: res = res and down==(s==s1[::-1]) return res
[docs]def repunit_gen(base=10, digit=1): '''generate repunits''' n=digit yield 0 # to be coherent with definition while True: yield n n=n*base+digit
[docs]def repunit(n, base=10, digit=1): ''' :return: nth repunit ''' if n==0: return 0 if digit==1: return (base**n - 1)//(base-1) return int(str(digit)*n,base)
# repeating decimals #
[docs]def rational_form(numerator, denominator): '''information about the decimal representation of a rational number. :return: 5 integer : integer, decimal, shift, repeat, cycle * shift is the len of decimal with leading zeroes if any * cycle is the len of repeat with leading zeroes if any ''' def first_divisible_repunit(x): #finds the first number in the sequence (9, 99, 999, 9999, ...) that is divisible by x. assert x%2 != 0 and x%5 != 0 for r in itertools2.drop(1,repunit_gen(digit=9)): if r % x == 0: return r shift,p = 0,1 for x in (10,2,5): while denominator % x == 0: denominator //= x numerator = 10*numerator//x shift += 1 p *= 10 base,numerator = divmod(numerator,denominator) integer,decimal = divmod(base,p) repunit = first_divisible_repunit(denominator) repeat = numerator * (repunit // denominator) cycle = number_of_digits(repunit) if repeat else 0 return integer, decimal, shift, repeat, cycle
[docs]def rational_str(n,d): integer, decimal, shift, repeat, cycle = rational_form(n,d) s = str(integer) if not (decimal or repeat): return s s = s + "." if decimal or shift: s = s + "{:0{}}".format(decimal, shift) if repeat: s = s + "({:0{}})".format(repeat, cycle) return s
[docs]def rational_cycle(num,den): '''periodic part of the decimal expansion of num/den. Any initial 0's are placed at end of cycle. :see: ''' _, _, _, digits, cycle = rational_form(num,den) lz=cycle-number_of_digits(digits) return digits*ipow(10,lz)
# polygonal numbers
[docs]def tetrahedral(n): ''' :return: int n-th tetrahedral number :see: ''' return n*(n+1)*(n+2)//6
[docs]def sum_of_squares(n): ''' :return: 1^2 + 2^2 + 3^2 + ... + n^2 :see: ''' return n*(n+1)*(2*n+1)//6
pyramidal = sum_of_squares
[docs]def sum_of_cubes(n): ''' :return: 1^3 + 2^3 + 3^3 + ... + n^3 :see: ''' a=triangular(n) return a*a # by Nicomachus's theorem
[docs]def bernouilli_gen(init=1): '''generator of Bernouilli numbers :param init: int -1 or +1. * -1 for "first Bernoulli numbers" with B1=-1/2 * +1 for "second Bernoulli numbers" with B1=+1/2 ''' B, m = [], 0 while True: B.append(fractions.Fraction(1, m+1)) for j in range(m, 0, -1): B[j-1] = j*(B[j-1] - B[j]) yield init*B[0] if m==1 else B[0]# (which is Bm) m += 1
[docs]def bernouilli(n,init=1): return itertools2.takenth(n,bernouilli_gen(init))
[docs]def faulhaber(n,p): ''' sum of the p-th powers of the first n positive integers :return: 1^p + 2^p + 3^p + ... + n^p :see: ''' s=0 for j,a in enumerate(bernouilli_gen()): if j>p : break s=s+binomial(p+1,j)*a*n**(p+1-j) return s//(p+1)
[docs]def is_happy(n): # while n > 1 and n != 89 and n != 4: n = digsum(n,2) #sum of squares of digits return n==1
[docs]def lychrel_seq(n): while True: r = reverse(n) yield n,r if n==r : break n += r
[docs]def lychrel_count(n, limit=96): '''number of lychrel iterations before n becomes palindromic :param n: int number to test :param limit: int max number of loops. default 96 corresponds to the known most retarded non lychrel number :warning: there are palindrom lychrel numbers such as 4994 ''' for i in itertools.count(): r=reverse(n) if r == n or i==limit: return i n=n+r
[docs]def is_lychrel(n,limit=96): ''' :warning: there are palindrom lychrel numbers such as 4994 ''' r=lychrel_count(n, limit) if r>=limit: return True if r==0: #palindromic number... return lychrel_count(n+reverse(n),limit)+1>=limit #... can be lychrel ! return False
[docs]def polygonal(s, n): # return ((s-2)*n*n-(s-4)*n)//2
[docs]def triangle(n): ''' :return: nth triangle number, defined as the sum of [1,n] values. :see: ''' return polygonal(3,n) # (n*(n+1))/2
[docs]def is_triangle(x): ''' :return: True if x is a triangle number ''' return is_square(1 + 8*x)
[docs]def square(n): return polygonal(4,n) # n*n
[docs]def pentagonal(n): ''' :return: nth pentagonal number :see: ''' return polygonal(5,n) # n*(3*n - 1)/2
[docs]def is_pentagonal(n): ''' :return: True if x is a pentagonal number ''' if n<1: return False n=1+24*n s=isqrt(n) if s*s != n: return False return is_integer((1+s)/6.0)
[docs]def hexagonal(n): ''' :return: nth hexagonal number :see: ''' return polygonal(6,n) # n*(2*n - 1)
[docs]def is_hexagonal(n): return (1 + sqrt(1 + (8 * n))) % 4 == 0
[docs]def heptagonal(n): return polygonal(7,n) # (n * (5 * n - 3)) / 2
[docs]def is_heptagonal(n): return (3 + sqrt(9 + (40 * n))) % 10 == 0
[docs]def octagonal(n): return polygonal(8,n) # (n * (3 * n - 2))
[docs]def is_octagonal(n): return (2 + sqrt(4 + (12 * n))) % 6 == 0
[docs]@decorators.memoize def partition(n): '''The partition function p(n) gives the number of partitions of a nonnegative integer n into positive integers. (There is one partition of zero into positive integers, i.e. the empty partition, since the empty sum is defined as 0.) :see: ''' #TODO : def non_zero_integers(n): for k in range(1, n): yield k yield -k if n == 0: return 1 elif n < 0: return 0 result = 0 for k in non_zero_integers(n + 1): sign = 1 if (k - 1) % 2==0 else -1 result += sign * partition(n - pentagonal(k)) return result
[docs]@decorators.memoize def partitionsQ(n,d=0): # # # # if n==0: return 1 return sum(partitionsQ(n-k,n-2*k+1) for k in range(1,n-d+1))
[docs]def get_cardinal_name(num): '''Get cardinal name for number (0 to 1 million)''' numbers = { 0: "zero", 1: "one", 2: "two", 3: "three", 4: "four", 5: "five", 6: "six", 7: "seven", 8: "eight", 9: "nine", 10: "ten", 11: "eleven", 12: "twelve", 13: "thirteen", 14: "fourteen", 15: "fifteen", 16: "sixteen", 17: "seventeen", 18: "eighteen", 19: "nineteen", 20: "twenty", 30: "thirty", 40: "forty", 50: "fifty", 60: "sixty", 70: "seventy", 80: "eighty", 90: "ninety", } def _get_tens(n): a, b = divmod(n, 10) return (numbers[n] if (n in numbers) else "%s-%s" % (numbers[10*a], numbers[b])) def _get_hundreds(n): tens = n % 100 hundreds = (n // 100) % 10 return filter(bool,[ hundreds > 0 and numbers[hundreds], hundreds > 0 and "hundred", hundreds > 0 and tens and "and", (not hundreds or tens > 0) and _get_tens(tens), ]) blocks=digits(num,1000,rev=True) #group by 1000 res='' for hdu,word in zip(blocks,['',' thousand ',' million ',' billion ']): if hdu==0: continue #skip try: res=' '.join(_get_hundreds(hdu))+word+res except: pass return res
[docs]def abundance(n): return sum(divisors(n))-2*n
[docs]def is_perfect(n): ''' :return: -1 if n is deficient, 0 if perfect, 1 if abundant :see:,, ''' # return sign(abundance(n)) #simple, but might be slow for large n for s in itertools2.accumulate(divisors(n)): if s>2*n: return 1 return 0 if s==2*n else -1
[docs]def number_of_digits(num, base=10): '''Return number of digits of num (expressed in base 'base')''' # math.log(num,base) is imprecise, and len(str(num,base)) is slow and ugly num=abs(num) for i in itertools.count(): if num<base: return i+1 num=num//base
[docs]def chakravala(n): '''solves x^2 - n*y^2 = 1 for x,y integers ''' # m = 1 k = 1 a = 1 b = 0 while k != 1 or b == 0: m = k * (m//k+1) - m m = m - int((m - sqrt(n))//k) * k tempA = (a*m + n*b) // abs(k) b = (a + b*m) // abs(k) k = (m*m - n) // k a = tempA return a,b
#combinatorics factorial=math.factorial #didn't knew it was there...
[docs]def factorialk(n,k): '''Multifactorial of n of order k, n(!!...!). This is the multifactorial of n skipping k values. For example, factorialk(17, 4) = 17!!!! = 17 * 13 * 9 * 5 * 1 In particular, for any integer ``n``, we have factorialk(n, 1) = factorial(n) factorialk(n, 2) = factorial2(n) :param n: int Calculate multifactorial. If `n` < 0, the return value is 0. :param k : int Order of multifactorial. :return: int Multifactorial of `n`. ''' # code from scipy, with extact=true if n < -1: return 0 if n <= 0: return 1 val=mul(range(n, 0, -k)) return val
[docs]def factorial2(n): return factorialk(n,2)
[docs]def factorial_gen(f=lambda x:x): '''Generator of factorial :param f: optional function to apply at each step ''' last=1 yield last for n in itertools.count(1): last=f(last*n) yield last
[docs]def binomial(n,k): '''binomial coefficient "n choose k" :param: n, k int :return: int, number of ways to chose n items in k, unordered :see: ''' #return factorial(n) // (factorial(k) * factorial(n - k)) # is very slow # code from if k < 0 or k > n: return 0 if k == 0 or k == n: return 1 k = min(k, n - k) # take advantage of symmetry if k>1e8: raise OverflowError('k=%d too large'%k) c = 1 for i in range(k): c = c * (n - i) // (i + 1) return int(c)
choose=binomial #alias ncombinations=binomial #alias
[docs]def binomial_exponent(n,k,p): ''' :return: int largest power of p that divides binomial(n,k) ''' if is_prime(p): return carries(k,n-k,p) # return min(binomial_exponent(n,k,a)//b for a,b in factorize(p))
[docs]def log_factorial(n): ''' :return: float approximation of ln(n!) by Ramanujan formula ''' return n*math.log(n) - n + (math.log(n*(1+4*n*(1+2*n))))/6 + math.log(math.pi)/2
[docs]def log_binomial(n,k): ''' :return: float approximation of ln(binomial(n,k)) ''' return log_factorial(n) - log_factorial(k) - log_factorial(n - k)
[docs]def ilog(a,b,upper_bound=False): '''discrete logarithm x such that b^x=a :parameter a,b: integer :parameter upper_bound: bool. if True, returns smallest x such that b^x>=a :return: x integer such that b^x=a, or upper_bound, or None ''' # TODO: implement using baby_step_giant_step or or similar #for now it's brute force... l = 0 while a >= b: a //= b l += 1 return l p=1 for x in itertools.count(): if p==a: return x if p>a: return x if upper_bound else None p=b*p
#from "the right way to calculate stuff" :
[docs]def angle(u,v,unit=True): ''' :param u,v: iterable vectors :param unit: bool True if vectors are unit vectors. False increases computations :returns: float angle n radians between u and v unit vectors i ''' if not unit: u=vecunit(u) v=vecunit(v) if dot_vv(u,v) >=0: return 2.*math.asin(dist(v,u)/2) else: return math.pi - 2.*math.asin(dist(vecneg(v),u)/2)
[docs]def sin_over_x(x): '''numerically safe sin(x)/x''' if 1. + x*x == 1.: return 1. else: return math.sin(x)/x
[docs]def slerp(u,v,t): '''spherical linear interpolation :param u,v: 3D unit vectors :param t: float in [0,1] interval :return: vector interpolated between u and v ''' a=angle(u,v) fu=(1-t)*sin_over_x((1-t)*a)/sin_over_x(a) fv=t*sin_over_x(t*a)/sin_over_x(a) return vecadd([fu*x for x in u],[fv*x for x in v])
[docs]def proportional(nseats,votes): '''assign n seats proportionaly to votes using the method :param nseats: int number of seats to assign :param votes: iterable of int or float weighting each party :result: list of ints seats allocated to each party ''' quota=sum(votes)/(1.+nseats) #force float frac=[vote/quota for vote in votes] res=[int(f) for f in frac] n=nseats-sum(res) #number of seats remaining to allocate if n==0: return res #done if n<0: return [min(x,nseats) for x in res] #to handle case where votes=[0,0,..,1,0,...,0] #give the remaining seats to the n parties with the largest remainder remainders=vecsub(frac,res) limit=sorted(remainders,reverse=True)[n-1] for i,r in enumerate(remainders): if r>=limit: res[i]+=1 n-=1 # attempt to handle perfect equality if n==0: return res #done
[docs]def triangular_repartition(x,n): ''' divide 1 into n fractions such that: - their sum is 1 - they follow a triangular linear repartition (sorry, no better name for now) where x/1 is the maximum ''' def _integral(x1,x2): '''return integral under triangle between x1 and x2''' if x2<=x: return (x1 + x2) * float(x2 - x1) / x elif x1>=x: return (2-x1-x2) * float(x1-x2) / (x-1) else: #top of the triangle: return _integral(x1,x)+_integral(x,x2) w=1./n #width of a slice return [_integral(i*w,(i+1)*w) for i in range(n)]
[docs]def rectangular_repartition(x,n,h): ''' divide 1 into n fractions such that: - their sum is 1 - they follow a repartition along a pulse of height h<1 ''' w=1./n #width of a slice and of the pulse x=max(x,w/2.) x=min(x,1-w/2.) xa,xb=x-w/2,x+w/2. #start,end of the pulse o=(1.-h)/(n-1) #base level def _integral(x1,x2): '''return integral between x1 and x2''' if x2<=xa or x1>=xb: return o elif x1<xa: return float(o*(xa-x1)+h*(w-(xa-x1)))/w else: # x1<=xb return float(h*(xb-x1)+o*(w-(xb-x1)))/w return [_integral(i*w,(i+1)*w) for i in range(n)]
[docs]def de_bruijn(k, n): ''' De Bruijn sequence for alphabet k and subsequences of length n. ''' try: # let's see if k can be cast to an integer; # if so, make our alphabet a list _ = int(k) alphabet = list(map(str, range(k))) except (ValueError, TypeError): alphabet = k k = len(k) a = [0] * k * n sequence = [] def db(t, p): if t > n: if n % p == 0: sequence.extend(a[1:p + 1]) else: a[t] = a[t - p] db(t + 1, p) for j in range(a[t - p] + 1, k): a[t] = j db(t + 1, t) db(1, 1) return "".join(alphabet[i] for i in sequence)
'''modular arithmetic initial motivation: code translated from see also mathematica code from ''' #see for a MOD class
[docs]def mod_inv(a,b): # if is_prime(b): #Use Euler's Theorem return ipow(a,b-2,b) b0 = b x0, x1 = 0, 1 if b == 1: return 1 while a > 1: q = a // b a, b = b, a%b x0, x1 = x1 - q * x0, x0 if x1 < 0: x1 += b0 return x1
[docs]def mod_div(a,b,m): ''' :return: x such that (b*x) mod m = a mod m ''' return a*mod_inv(b,m)
[docs]def mod_fact(n,m): ''' :return: n! mod m ''' res = 1 while n > 0: for i in range(2,n%m+1): res = (res * i) % m n=n//m if n%2 > 0 : res = m - res return res%m
[docs]def chinese_remainder(m, a): ''' :param m: list of int moduli :param a: list of int remainders :return: smallest int x such that x mod ni=ai ''' # res = 0 prod=mul(m) for m_i, a_i in zip(m, a): p = prod // m_i res += a_i * mod_inv(p, m_i) * p return res % prod
def _count(n, p): ''' :return: power of p in n ''' k=0; while n>=p: k+=n//p n=n//p return k;
[docs]def mod_binomial(n,k,m,q=None): '''calculates C(n,k) mod m for large n,k,m :param n: int total number of elements :param k: int number of elements to pick :param m: int modulo (or iterable of (m,p) tuples used internally) :param q: optional int power of m for prime m, used internally ''' # the function implements 3 cases which are called recursively: # 1 : m is factorized in powers of primes pi^qi # 2 : Chinese remainder theorem is used to combine all C(n,k) mod pi^qi # 3 : Lucas or Andrew Granville's theorem is used to calculate C(n,k) mod pi^qi if type(m) is int: if q is None: return mod_binomial(n,k,factorize(m)) #3 elif q==1: #use'_theorem ni=digits_gen(n,m) ki=digits_gen(k,m) res=1 for a,b in zip(ni,ki): res*=binomial(a,b) if res==0: break return res #see ''' elif q==3: return mod_binomial(n*m,k*m,m) ''' #no choice for the moment.... return binomial(n,k)%(m**q) else: #2 #use r,f=[],[] for p,q in m: r.append(mod_binomial(n,k,p,q)) f.append(p**q) return chinese_remainder(f,r)
[docs]def baby_step_giant_step(y, a, n): ''' solves Discrete Logarithm Problem (DLP) y = a**x mod n ''' # s = int(math.ceil(math.sqrt(n))) A = [y * pow(a, r, n) % n for r in range(s)] for t in range(1,s+1): value = pow(a, t*s, n) if value in A: return (t * s - A.index(value)) % n
# inspired from
[docs]def mod_matmul(A,B, mod=0): if not mod: return dot_mm(A,B) return [[sum(a*b for a,b in zip(A_row,B_col))%mod for B_col in zip(*B)] for A_row in A]
[docs]def mod_matpow(M, power, mod=0): if power < 0: raise NotImplementedError result = identity(2) for power in digits(power,2,True): if power: result = mod_matmul(result, M, mod) M = mod_matmul(M, M, mod) return result
# in fact numpy.linalg.matrix_power has a bug for large powers # # so our function here above is better :-) matrix_power=mod_matpow
[docs]def mod_sqrt(n, p): '''modular sqrt(n) mod p ''' assert is_prime(p),'p must be prime' a = n%p if p%4 == 3: return pow(a, (p+1) >> 2, p) elif p%8 == 5: v = pow(a << 1, (p-5) >> 3, p) i = ((a*v*v << 1) % p) - 1 return (a*v*i)%p elif p%8 == 1: # Shank's method q, e = p-1, 0 while q&1 == 0: e += 1 q >>= 1 n = 2 while legendre(n, p) != -1: n += 1 w, x, y, r = pow(a, q, p), pow(a, (q+1) >> 1, p), pow(n, q, p), e while True: if w == 1: return x v, k = w, 0 while v != 1 and k+1 < r: v = (v*v)%p k += 1 if k == 0: return x d = pow(y, 1 << (r-k-1), p) x, y = (x*d)%p, (d*d)%p w, r = (w*y)%p, k else: return a # p == 2
[docs]def mod_fac(n, mod, mod_is_prime=False): '''modular factorial : return n! % modulo if module is prime, use Wilson's theorem ''' if n >= mod : # then mod is a factor of n! return 0 if n<20: # for small n the naive algorithm can be faster return factorial(n)%mod if mod_is_prime or is_prime(mod): # use Wilson's Theorem : (n-1)! == -1 (mod modulo) result = mod - 1; # avoid negative numbers: -1 == modulo - 1 (mod modulo) for i in range(mod - 1,n,-1): result *= mod_inv(i, mod) result %= mod return result else: return itertools2.nth(n,factorial_gen(lambda x:x%mod))
[docs]def pi_digits_gen(): ''' generates pi digits as a sequence of INTEGERS ! using Jeremy Gibbons spigot generator :see : ''' # code from q, r, t, j = 1, 180, 60, 2 while True: u, y = 3*(3*j+1)*(3*j+2), (q*(27*j-12)+5*r)//(5*t) yield y q, r, t, j = 10*q*j*(2*j-1), 10*u*(q*(5*j-2)+r-y*t), t*u, j+1
#------------------------------------------------------------------------------ # factorization code taken from # # # #------------------------------------------------------------------------------
[docs]def pfactor(n): '''Helper function for sprp. Returns the tuple (x,y) where n - 1 == (2 ** x) * y and y is odd. We have this bit separated out so that we don’t waste time recomputing s and d for each base when we want to check n against multiple bases. ''' # s, d, q = 0, n-1, 2 while not d & q - 1: s, q = s+1, q*2 return s, d // (q // 2)
[docs]def sprp(n, a, s=None, d=None): '''Checks n for primality using the Strong Probable Primality Test to base a. If present, s and d should be the first and second items, respectively, of the tuple returned by the function pfactor(n)''' # if n%2 == 0: return False if (s is None) or (d is None): s, d = pfactor(n) x = pow(a, d, n) if x == 1: return True for i in range(s): if x == n - 1: return True x = pow(x, 2, n) return False
[docs]def jacobi(a, p): '''Computes the Jacobi symbol (a|p), where p is a positive odd number. :see: ''' # if (p%2 == 0) or (p < 0): return None # p must be a positive odd number if (a == 0) or (a == 1): return a a, t = a%p, 1 while a != 0: while not a & 1: a //= 2 if p & 7 in (3, 5): t *= -1 a, p = p, a if (a & 3 == 3) and (p & 3) == 3: t *= -1 a %= p return t if p == 1 else 0
[docs]def pollardRho_brent(n): '''Brent’s improvement on Pollard’s rho algorithm. :return: int n if n is prime otherwise, we keep chugging until we find a factor of n strictly between 1 and n. :see: ''' if is_prime(n): return n g = n while g == n: y, c, m, g, r, q = randrange(1, n), randrange(1, n), randrange(1, n), 1, 1, 1 while g==1: x, k = y, 0 for i in range(r): y = (y**2 + c) % n while k < r and g == 1: ys = y for i in range(min(m, r-k)): y = (y**2 + c) % n q = q * abs(x-y) % n g, k = gcd(q, n), k+m r *= 2 if g==n: while True: ys = (ys**2+c)%n g = gcd(abs(x-ys), n) if g > 1: break return g
[docs]def pollard_pm1(n, B1=100, B2=1000): # TODO: What are the best default bounds and way to increment them? '''Pollard’s p+1 algorithm, two-phase version. :return: n if n is prime; otherwise, we keep chugging until we find a factor of n strictly between 1 and n. ''' if is_prime(n): return n m = is_power(n) if m: return m while True: pg = primes_gen() q = 2 # TODO: what about other initial values of q? p = while p <= B1: q, p = pow(q, p**ilog(B1, p), n), g = gcd(q-1, n) if 1 < g < n: return g while p <= B2: q, p = pow(q, p, n), g = gcd(q-1, n) if 1 < g < n: return g # These bounds failed. Increase and try again. B1 *= 10 B2 *= 10
[docs]def mlucas(v, a, n): ''' Helper function for williams_pp1(). Multiplies along a Lucas sequence modulo n. ''' v1, v2 = v, (v**2 - 2) % n for bit in bin(a)[3:]: if bit == "0": v1, v2 = ((v1**2 - 2) % n, (v1*v2 - v) % n) else: v1, v2 = ((v1*v2 - v) % n, (v2**2 - 2) % n) return v1
[docs]def williams_pp1(n): '''Williams’ p+1 algorithm. :return: n if n is prime otherwise, we keep chugging until we find a factor of n strictly between 1 and n. ''' if is_prime(n): return n m = is_power(n) if m: return m for v in itertools.count(1): for p in primes_gen(): e = ilog(isqrt(n), p) if e == 0: break for _ in range(e): v = mlucas(v, p, n) g = gcd(v - 2, n) if 1 < g < n: return g if g == n: break
[docs]def ecadd(p1, p2, p0, n): # Add two points p1 and p2 given point P0 = P1-P2 modulo n x1,z1 = p1; x2,z2 = p2; x0,z0 = p0 t1, t2 = (x1-z1)*(x2+z2), (x1+z1)*(x2-z2) return (z0*pow(t1+t2,2,n) % n, x0*pow(t1-t2,2,n) % n)
[docs]def ecdub(p, A, n): # double point p on A modulo n x, z = p; An, Ad = A t1, t2 = pow(x+z,2,n), pow(x-z,2,n) t = t1 - t2 return (t1*t2*4*Ad % n, (4*Ad*t2 + t*An)*t % n)
[docs]def ecmul(m, p, A, n): # multiply point p by m on curve A modulo n if m == 0: return (0, 0) elif m == 1: return p else: q = ecdub(p, A, n) if m == 2: return q b = 1 while b < m: b *= 2 b //= 4 r = p while b: if m&b: q, r = ecdub(q, A, n), ecadd(q, r, p, n) else: q, r = ecadd(r, q, p, n), ecdub(r, A, n) b //= 2 return r
[docs]def factor_ecm(n, B1=10, B2=20): ''' Factors n using the elliptic curve method, using Montgomery curves and an algorithm analogous to the two-phase variant of Pollard’s p-1 method. :return: n if n is prime otherwise, we keep chugging until we find a factor of n strictly between 1 and n ''' # TODO: Determine the best defaults for B1 and B2 and the best way to increment them and iters # TODO: We currently compute the prime lists from the sieve as we need them, but this means that we recompute them at every # iteration. While it would not be particularly efficient memory-wise, we might be able to increase time-efficiency # by computing the primes we need ahead of time (say once at the beginning and then once each time we increase the # bounds) and saving them in lists, and then iterate the inner while loops over those lists. if is_prime(n): return n m = is_power(n) if m: return m iters = 1 while True: for _ in range(iters): # TODO: multiprocessing? # TODO: Do we really want to call the randomizer? Why not have seed be a function of B1, B2, and iters? # TODO: Are some seeds better than others? seed = random.randrange(6, n) u, v = (seed**2 - 5) % n, 4*seed % n p = pow(u, 3, n) Q, C = (pow(v-u,3,n)*(3*u+v) % n, 4*p*v % n), (p, pow(v,3,n)) pg = primes_gen() p = while p <= B1: Q, p = ecmul(p**ilog(B1, p), Q, C, n), g = gcd(Q[1], n) if 1 < g < n: return g while p <= B2: # "There is a simple coding trick that can speed up the second stage. Instead of multiplying each prime times Q, # we iterate over i from B1 + 1 to B2, adding 2Q at each step; when i is prime, the current Q can be accumulated # into the running solution. Again, we defer the calculation of the greatest common divisor until the end of the # iteration. # TODO: Implement this trick and compare performance. Q = ecmul(p, Q, C, n) g *= Q[1] g %= n p = g = gcd(g, n) if 1 < g < n: return g # This seed failed. Try again with a new one. # These bounds failed. Increase and try again. B1 *= 3 B2 *= 3 iters *= 2
# legendre symbol (a|m) # TODO: which is faster?
[docs]def legendre(a, p): '''Functions to comptue the Legendre symbol (a|p). The return value isn’t meaningful if p is composite :see: ''' return ((pow(a, (p-1) >> 1, p) + 1) % p) - 1
[docs]def legendre2(a, p): # TODO: pretty sure this computes the Jacobi symbol '''Functions to comptue the Legendre symbol (a|p). The return value isn’t meaningful if p is composite :see: ''' if a == 0: return 0 x, y, L = a, p, 1 while 1: if x > (y >> 1): x = y - x if y & 3 == 3: L = -L while x & 3 == 0: x >>= 2 if x & 1 == 0: x >>= 1 if y & 7 == 3 or y & 7 == 5: L = -L if x == 1: return ((L+1) % p) - 1 if x & 3 == 3 and y & 3 == 3: L = -L x, y = y % x, x