"""Utilities for enumeration of finite and countably infinite sets.
"""

from __future__ import absolute_import, division, print_function

###
# Countable iteration


# Simplifies some calculations
class Aleph0(int):
    """Singleton stand-in for the countably infinite cardinal (aleph-null).

    It subclasses ``int`` so it can be mixed into integer arithmetic. Its
    underlying integer value is 0, so every operator the enumeration code
    relies on is overridden:

    * comparisons: ``aleph0`` is greater than every other value and equal
      only to itself;
    * ``+`` and ``*`` absorb their other operand (except ``0 * aleph0 == 0``);
    * ``//`` and ``/`` by a nonzero value give ``aleph0``, by zero raise
      ``ZeroDivisionError``;
    * ``aleph0 ** 0 == 1``; any other power is ``aleph0``;
    * subtraction is undefined and raises ``ValueError``.
    """

    _singleton = None

    def __new__(cls):
        if cls._singleton is None:
            cls._singleton = int.__new__(cls)
        return cls._singleton

    def __repr__(self):
        return "<aleph0>"

    def __str__(self):
        return "inf"

    def __cmp__(self, b):  # Python 2 only; Python 3 ignores __cmp__.
        return 0 if b is self else 1

    # Python 3 uses only the rich comparisons. The inherited int versions
    # would compare the underlying value 0 and make aleph0 "smaller" than
    # every positive integer. Because Aleph0 subclasses int, Python tries
    # these reflected methods first for expressions like ``5 < aleph0`` too.
    def __eq__(self, b):
        return b is self

    def __ne__(self, b):
        return b is not self

    def __lt__(self, b):
        return False

    def __le__(self, b):
        return b is self

    def __gt__(self, b):
        return b is not self

    def __ge__(self, b):
        return True

    # Defining __eq__ would otherwise set __hash__ to None.
    __hash__ = int.__hash__

    # The underlying int value is 0, which would otherwise make aleph0 falsy.
    def __bool__(self):
        return True

    __nonzero__ = __bool__  # Python 2 spelling

    def __sub__(self, b):
        raise ValueError("Cannot subtract aleph0")

    __rsub__ = __sub__

    def __add__(self, b):
        return self

    __radd__ = __add__

    def __mul__(self, b):
        if b == 0:
            return b
        return self

    __rmul__ = __mul__

    def __floordiv__(self, b):
        if b == 0:
            raise ZeroDivisionError
        return self

    # NOTE(review): the reflected forms reuse the forward implementation, so
    # e.g. ``5 // aleph0`` yields aleph0 rather than 0. Nothing in this module
    # divides by aleph0.
    __rfloordiv__ = __floordiv__
    __truediv__ = __floordiv__
    __rtruediv__ = __floordiv__
    __div__ = __floordiv__  # Python 2
    __rdiv__ = __floordiv__  # Python 2

    def __pow__(self, b):
        if b == 0:
            return 1
        return self


aleph0 = Aleph0()


def base(line):
    """Return line * (line + 1) // 2.

    This is the number of pairs on diagonals 0 .. line-1, i.e. the index of
    the first pair on diagonal ``line`` (the pairs with x + y == line).
    """
    return line * (line + 1) // 2


def pairToN(pair):
    """Return the index of ``pair`` in diagonal order; inverse of getNthPair."""
    x, y = pair
    line, index = x + y, y
    return base(line) + index


def getNthPairInfo(N):
    """Return (line, index) for the N-th pair in diagonal order.

    ``line`` is the diagonal (x + y) holding the pair and ``index`` its
    offset along that diagonal, so that N == base(line) + index.
    """
    # Avoid various singularities
    if N == 0:
        return (0, 0)

    # Gallop to find bounds for line
    line = 1
    upper = 2
    while base(upper) <= N:
        line = upper
        upper = line << 1

    # Binary search for starting line
    lo = line
    hi = line << 1
    while lo + 1 != hi:
        # invariant: base(lo) <= N < base(hi)
        mid = (lo + hi) >> 1
        if base(mid) <= N:
            lo = mid
        else:
            hi = mid

    line = lo
    return line, N - base(line)


def getNthPair(N):
    """Return the N-th pair (x, y) of naturals in diagonal order.

    The order is (0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (0, 2), ...
    """
    line, index = getNthPairInfo(N)
    return (line - index, index)


def getNthPairBounded(N, W=aleph0, H=aleph0, useDivmod=False):
    """getNthPairBounded(N, W, H) -> (x, y)

    Return the N-th pair such that 0 <= x < W and 0 <= y < H.

    Either bound may be ``aleph0``. With ``useDivmod`` the coordinate along
    the shorter side varies fastest (plain divmod order). Otherwise pairs
    are listed by sliding a diagonal across the W x H rectangle.

    Raises ValueError if a bound is not positive or N >= W * H.
    """
    if W <= 0 or H <= 0:
        raise ValueError("Invalid bounds")
    elif N >= W * H:
        raise ValueError("Invalid input (out of bounds)")

    # Simple case...
    if W is aleph0 and H is aleph0:
        return getNthPair(N)

    # Otherwise simplify by assuming W < H
    if H < W:
        x, y = getNthPairBounded(N, H, W, useDivmod=useDivmod)
        return y, x

    if useDivmod:
        return N % W, N // W
    else:
        # Conceptually we want to slide a diagonal line across a
        # rectangle. This gives more interesting results for large
        # bounds than using divmod.

        # If in lower left, just return as usual
        cornerSize = base(W)
        if N < cornerSize:
            return getNthPair(N)

        # Otherwise if in upper right, subtract from corner
        if H is not aleph0:
            M = W * H - N - 1
            if M < cornerSize:
                x, y = getNthPair(M)
                return (W - 1 - x, H - 1 - y)

        # Otherwise, compute line and index from the number of times we
        # wrap; each full-width diagonal holds exactly W pairs.
        N = N - cornerSize
        index, offset = N % W, N // W
        # p = (W-1, 1+offset) + (-1,1)*index
        return (W - 1 - index, 1 + offset + index)


def getNthPairBoundedChecked(
    N, W=aleph0, H=aleph0, useDivmod=False, GNP=getNthPairBounded
):
    """Call ``GNP`` like getNthPairBounded and assert the result is in bounds."""
    x, y = GNP(N, W, H, useDivmod)
    assert 0 <= x < W and 0 <= y < H
    return x, y


def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_{W-1})

    Return the N-th W-tuple, where 0 <= x_i < H for every element.

    By default the index is split recursively between the left and right
    halves of the tuple. With ``useLeftToRight`` each element is peeled off
    in turn with getNthPairBounded(N, H).
    """
    if useLeftToRight:
        elts = [None] * W
        for i in range(W):
            elts[i], N = getNthPairBounded(N, H)
        return tuple(elts)
    else:
        if W == 0:
            return ()
        elif W == 1:
            return (N,)
        elif W == 2:
            return getNthPairBounded(N, H, H)
        else:
            LW, RW = W // 2, W - (W // 2)
            L, R = getNthPairBounded(N, H**LW, H**RW)
            return getNthNTuple(
                L, LW, H=H, useLeftToRight=useLeftToRight
            ) + getNthNTuple(R, RW, H=H, useLeftToRight=useLeftToRight)


def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    """Call ``GNT`` like getNthNTuple and assert the shape of the result."""
    t = GNT(N, W, H, useLeftToRight)
    assert len(t) == W
    for i in t:
        assert i < H
    return t


def getNthTuple(
    N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False
):
    """getNthTuple(N, maxSize, maxElement) -> x

    Return the N-th tuple where len(x) <= maxSize and, for y in x,
    0 <= y < maxElement. Index 0 is the empty tuple.

    A finite ``maxElement`` requires a finite ``maxSize``; otherwise
    NotImplementedError is raised.
    """
    # All zero sized tuples are isomorphic, don't ya know.
    if N == 0:
        return ()
    N -= 1
    if maxElement is not aleph0:
        if maxSize is aleph0:
            raise NotImplementedError("Max element size without max size unhandled")
        # bounds[S] is the number of tuples of length S + 1.
        bounds = [maxElement**i for i in range(1, maxSize + 1)]
        S, M = getNthPairVariableBounds(N, bounds)
    else:
        S, M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    return getNthNTuple(M, S + 1, maxElement, useLeftToRight=useLeftToRight)


def getNthTupleChecked(
    N,
    maxSize=aleph0,
    maxElement=aleph0,
    useDivmod=False,
    useLeftToRight=False,
    GNT=getNthTuple,
):
    """Call ``GNT`` like getNthTuple and assert the shape of the result."""
    # maxSize is inclusive: valid tuples have len(t) <= maxSize.
    t = GNT(N, maxSize, maxElement, useDivmod, useLeftToRight)
    assert len(t) <= maxSize
    for i in t:
        assert i < maxElement
    return t


def getNthPairVariableBounds(N, bounds):
    """getNthPairVariableBounds(N, bounds) -> (x, y)

    Given a finite list of bounds (which may be finite or aleph0), return
    the N-th pair such that 0 <= x < len(bounds) and 0 <= y < bounds[x].

    Columns are visited in increasing order of their bound. Each "level" is
    the band of rows between two successive bounds; its width is the
    number of columns whose bound reaches that high. Each band is
    enumerated with getNthPairBounded.

    Raises ValueError for empty bounds or N outside [0, sum(bounds)).
    """
    if not bounds:
        raise ValueError("Invalid bounds")
    if not (0 <= N < sum(bounds)):
        raise ValueError("Invalid input (out of bounds)")

    active = list(range(len(bounds)))
    active.sort(key=lambda i: bounds[i])
    prevLevel = 0
    for i, index in enumerate(active):
        level = bounds[index]
        W = len(active) - i
        if level is aleph0:
            H = aleph0
        else:
            H = level - prevLevel
        levelSize = W * H
        if N < levelSize:  # Found the level
            idelta, delta = getNthPairBounded(N, W, H)
            return active[i + idelta], prevLevel + delta
        else:
            N -= levelSize
            prevLevel = level
    else:
        raise RuntimeError("Unexpected loop completion")


def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    """Call ``GNVP`` like getNthPairVariableBounds and assert the result."""
    x, y = GNVP(N, bounds)
    assert 0 <= x < len(bounds) and 0 <= y < bounds[x]
    return (x, y)


###


def _printGrid(label, grid):
    """Print ``grid`` top row last (so y grows upward), skipping empty rows."""
    print(label)
    for ln in grid[::-1]:
        if "".join(ln).strip():
            print(" ".join(ln))


def testPairs():
    """Print a 3x6 rectangle filled by the diagonal and divmod orders."""
    W = 3
    H = 6
    # Blank cells are two characters wide to line up with "%2d".
    a = [["  " for _ in range(10)] for _ in range(10)]
    b = [["  " for _ in range(10)] for _ in range(10)]
    for i in range(min(W * H, 40)):
        x, y = getNthPairBounded(i, W, H)
        x2, y2 = getNthPairBounded(i, W, H, useDivmod=True)
        print(i, (x, y), (x2, y2))
        a[y][x] = "%2d" % i
        b[y2][x2] = "%2d" % i

    _printGrid("-- a --", a)
    _printGrid("-- b --", b)


def testPairsVB():
    """Print the first indices of a variable-bounds enumeration."""
    bounds = [2, 2, 4, aleph0, 5, aleph0]
    # Blank cells are two characters wide to line up with "%2d".
    a = [["  " for _ in range(15)] for _ in range(15)]
    for i in range(min(sum(bounds), 40)):
        x, y = getNthPairVariableBounds(i, bounds)
        print(i, (x, y))
        a[y][x] = "%2d" % i

    _printGrid("-- a --", a)


###

# Toggle to use checked versions of enumeration routines.
if False:
    getNthPairVariableBounds = getNthPairVariableBoundsChecked
    getNthPairBounded = getNthPairBoundedChecked
    getNthNTuple = getNthNTupleChecked
    getNthTuple = getNthTupleChecked

if __name__ == "__main__":
    testPairs()
    testPairsVB()