# DecompCheck.py
# (committed by Adrian Pope)
import numpy as N
import sympy.ntheory as SN
import itertools

####################
# LOW LEVEL ROUTINES
####################

# independent of dimension

# non-trivial divisors of "n": sympy's divisor list with its first and
# last entries (the trivial divisors 1 and n) sliced off
def divisors(n):
    everyDivisor = SN.divisors(n)
    return everyDivisor[1:-1]

# non-trivial decompositions of integer "n" into "dim" integer factors
# returns an arbitrary length list of lists of length "dim"
# possibly with repeats; every inner list is sorted ascending
# for dim == 2 the outer list is sorted as well, otherwise it is left in
# the order the divisors were visited
# cofactors use floor division ("//") so they stay integers under Python 3
# or "from __future__ import division", where n/d would yield floats
def decompList(n, dim):
    dl = divisors(n)
    if dim == 2:
        # pair every non-trivial divisor with its cofactor
        ret = [sorted((d, n // d)) for d in dl]
        ret.sort()
        return ret
    ret = []
    for d in dl:
        # peel off the factor d, then split the cofactor into dim-1 factors
        for rest in decompList(n // d, dim - 1):
            ret.append(sorted(rest + [d]))
    return ret

# unique decompositions of "n" into "dim" factors
# returns a sorted list of rank-"dim" tuples with duplicates removed
# (descending order when reverse=True)
def decompSet(n, dim, reverse=False):
    uniqueTuples = set(tuple(factors) for factors in decompList(n, dim))
    return sorted(uniqueTuples, reverse=reverse)

# decompositions sorted by increasing aspect ratio of largest/smallest factor
# returns a list of [aspectRatio, decompTuple] pairs
# can also limit by maximum aspect ratio: when maxAspectRatio > 0.0 only
# pairs with aspectRatio <= maxAspectRatio are kept
# built with list comprehensions (not map/filter) so the result is a real,
# sortable list on Python 3 as well as Python 2
def decompAspectRatio(n, dim, maxAspectRatio=0.0):
    # tuples from decompSet are sorted ascending, so x[-1]/x[0] is max/min
    ret = [[1.0 * x[-1] / x[0], x] for x in decompSet(n, dim, False)]
    ret.sort()
    if maxAspectRatio > 0.0:
        ret = [x for x in ret if x[0] <= maxAspectRatio]
    return ret

# decompositions based on ranks, independent of grid dimensions

# test whether a 2D decomp is commensurate with a subset of a 3D decomp,
# i.e. each factor of b is a multiple of the matching factor of a
# both orderings of b are tried; True if either ordering works
# a = 3D subset, 2-tuple
# b = 2D 2-tuple
def check2d2d(a, b):
    sameOrder = b[0] % a[0] == 0 and b[1] % a[1] == 0
    if sameOrder:
        return True
    swappedOrder = b[1] % a[0] == 0 and b[0] % a[1] == 0
    return swappedOrder

# filter list of 2D decomps for those commensurate with subset of 3D decomp
# a = 3D subset 2-tuple
# blist = list of 2D 2-tuples
# returns a new list (input order preserved) of the entries of blist that
# pass check2d2d against a
# a list comprehension is used rather than filter() so the result is a real
# list on Python 3 too, safe to len() and to iterate more than once
def filter2d2d(a, blist):
    return [b for b in blist if check2d2d(a, b)]

# return [xpencils, ypencils, zpencils] commensurate with a 3D decomp
# each pencil list could be empty if there are no co-commensurate decomps
# candidates are the unique 2-factor decomps of the total rank count,
# taken in descending order
# a = 3D 3-tuple
def filter3d(a):
    totalRanks = a[0] * a[1] * a[2]
    candidates = decompSet(totalRanks, 2, reverse=True)
    # 3D factors left over for the x, y and z pencils respectively
    subsets = ((a[1], a[2]), (a[2], a[0]), (a[0], a[1]))
    return [filter2d2d(subset, candidates) for subset in subsets]



# sum the remainders of ng%x over every x in "tuple"
# works for arbitrary-rank tuples; the sum is 0 exactly when every entry
# divides ng
# the remainders are collected in a list before being handed to numpy:
# on Python 3 map() is lazy and N.array(map(...)) would wrap the iterator
# itself instead of its values
# (the parameter name shadows the builtin "tuple"; kept for compatibility)
def sumRemainders(ng, tuple):
    return N.array([ng % x for x in tuple]).sum()

# find pencil decompositions co-commensurate with 3D decomp and ng
# ng = grid size that every decomp factor must divide
# tuple3d = 3D decomp 3-tuple
# filter3doutput = [xpencils, ypencils, zpencils] as returned by filter3d
# returns [xpencils, ypencils, zpencils] restricted to pencils whose
# factors all divide ng, or [] if tuple3d does not divide ng or if any of
# the three restricted lists comes out empty
# list comprehensions replace filter() so the emptiness test also works on
# Python 3, where filter objects have no len()
def filterNg3d2d(ng, tuple3d, filter3doutput):
    # the 3D decomp itself has to divide ng before pencils are considered
    if sumRemainders(ng, tuple3d) != 0:
        return []
    xpencils = [p for p in filter3doutput[0] if sumRemainders(ng, p) == 0]
    ypencils = [p for p in filter3doutput[1] if sumRemainders(ng, p) == 0]
    zpencils = [p for p in filter3doutput[2] if sumRemainders(ng, p) == 0]
    if xpencils and ypencils and zpencils:
        return [xpencils, ypencils, zpencils]
    return []

######################
# CONVENIENCE ROUTINES
######################

# pencil decompositions co-commensurate with "ng" and the 3D decomp
# "tuple3d": runs filter3d on tuple3d and feeds the result to filterNg3d2d
def filterNg3d(ng, tuple3d):
    return filterNg3d2d(ng, tuple3d, filter3d(tuple3d))

# find ng in range [ng0,ng1] that have pencil decomps for a given 3D decomp
# both end points are included; returns the matching ng in increasing order
# range() replaces xrange(), which does not exist on Python 3
def ngCandidates3d(tuple3d, ng0, ng1):
    # computed once and materialised as lists so the same pencil candidates
    # can be re-scanned for every ng
    filter3doutput = [list(pencils) for pencils in filter3d(tuple3d)]
    return [ng for ng in range(ng0, ng1 + 1)
            if filterNg3d2d(ng, tuple3d, filter3doutput)]

# distinct orderings of the entries of "tuple", returned as a sorted list
# of tuples (duplicates caused by repeated entries are removed)
# original author's note: seems unused
def permutations(tuple):
    return sorted(set(itertools.permutations(tuple)))

# consistency check of a 3D decomp and three pencil decomps against
# "nranks" and "ng"; returns True only if every check below passes
# the pencil arguments are indexed with the same axis indices as t3d
# (e.g. t2dx[1], t2dx[2]), so they presumably are 3-tuples with a 1 in the
# pencil's own axis slot -- verify against the caller
# original author's note: seems unused
def doublecheck(nranks, ng, t3d, t2dx, t2dy, t2dz):
    decomps = (t3d, t2dx, t2dy, t2dz)

    # decomps have right number of ranks
    for decomp in decomps:
        if N.array(decomp).prod() != nranks:
            return False

    # 2D decomps are commensurate with 3D decomps
    xFits = t2dx[1] % t3d[1] == 0 and t2dx[2] % t3d[2] == 0
    if not xFits:
        return False
    yFits = t2dy[0] % t3d[0] == 0 and t2dy[2] % t3d[2] == 0
    if not yFits:
        return False
    zFits = t2dz[0] % t3d[0] == 0 and t2dz[1] % t3d[1] == 0
    if not zFits:
        return False

    # decomps are divisors of ng
    for decomp in decomps:
        if sumRemainders(ng, decomp) != 0:
            return False

    # ok looks fine
    return True

#####################
# HIGH LEVEL ROUTINES
#####################

# find ng in range [ng0,ng1] that have pencil decomps for any valid 3D decomp
# of "nranks"; returns the distinct ng values in increasing order
# optionally specify maximum 3D decomp aspect ratio to check
def ngCandidatesNranks(nranks, ng0, ng1, maxAspectRatio=0.0):
    found = set()
    for entry in decompAspectRatio(nranks, 3, maxAspectRatio):
        # entry is [aspectRatio, tuple3d]; only the decomp is needed here
        found.update(ngCandidates3d(entry[1], ng0, ng1))
    return sorted(found)

# find 3D decomps that have co-commensurate pencil decomps for ng and nranks
# optionally specify maximum 3D decomp aspect ratio to check
# returns a list of [aspectRatio, tuple3d] pairs in the order produced by
# decompAspectRatio
# a list comprehension replaces filter() so a real list is returned on
# Python 3 as well
def filterNgNranks(ng, nranks, maxAspectRatio=0.0):
    list3d = decompAspectRatio(nranks, 3, maxAspectRatio)
    # filterNg3d returns an empty list when nothing fits
    return [entry for entry in list3d if filterNg3d(ng, entry[1])]

# expand pencils into 3D tuples in all orders that work
# returns [xret, yret, zret]; each entry is a list of 3-tuples holding a 1
# in the pencil's own axis slot and the two pencil factors in the other two
# slots, kept only for orderings where each factor is a multiple of the
# tuple3d factor in the same slot
# a pencil that works in both orders is listed once per order, so a pencil
# with two equal factors appears twice
# when filterNg3d finds nothing (it returns []) three empty lists are
# returned instead of failing on an index into the empty result
def enumeratePencilsNg3d(ng, tuple3d):
    pencils = filterNg3d(ng, tuple3d)
    if not pencils:
        return [[], [], []]
    ret = []
    for axis in range(3):
        # u, v = the two axes other than the pencil axis, in increasing order
        u, v = [k for k in range(3) if k != axis]
        expanded = []
        for pencil in pencils[axis]:
            # try the pencil factors as given, then swapped
            orderings = ((pencil[0], pencil[1]), (pencil[1], pencil[0]))
            for first, second in orderings:
                if first % tuple3d[u] == 0 and second % tuple3d[v] == 0:
                    entry = [1, 1, 1]
                    entry[u] = first
                    entry[v] = second
                    expanded.append(tuple(entry))
        ret.append(expanded)
    return ret