Arquivos
Hui Zhou 37cda22eab coll: add MPIR_CVAR_YAKSA_REDUCTION_THRESHOLD
We call MPIR_Typerep_reduce_is_supported to determine whether to do a
collective host buffer swap in reduce and allreduce. We may want to make
a better decision based on message size, so we are adding the count to
the parameters.

Add a cvar to disable yaksa reduction for large messages.
2024-07-25 14:30:26 -05:00

787 linhas
30 KiB
Python

##
## Copyright (C) by Argonne National Laboratory
## See COPYRIGHT in top-level directory
##
from local_python import MPI_API_Global as G
from local_python import RE
from local_python.mpi_api import *
from local_python.binding_common import *
def main():
    """Generate the collective dispatch C file and the algorithm prototype header.

    Loads the C function list and the collective algorithm table, emits the
    blocking, nonblocking and persistent variants for every collective, and
    writes src/mpi/coll/mpir_coll.c and src/mpi/coll/include/coll_algos.h.
    """
    # The returned list is not used here; the call is kept because later code
    # reads G.FUNCS (presumably populated by this loader -- TODO confirm).
    load_C_func_list(G.get_srcdir_path("src/binding"), silent=True)
    G.algos = load_coll_algos("src/mpi/coll/coll_algorithms.txt")

    coll_names = [
        "barrier", "bcast", "gather", "gatherv", "scatter", "scatterv",
        "allgather", "allgatherv", "alltoall", "alltoallv", "alltoallw",
        "reduce", "allreduce", "reduce_scatter", "reduce_scatter_block",
        "scan", "exscan",
        "neighbor_allgather", "neighbor_allgatherv", "neighbor_alltoall",
        "neighbor_alltoallv", "neighbor_alltoallw",
    ]

    # Shared output buffers filled by the dump_* helpers.
    G.out = []
    G.prototypes_hash = {}
    G.prototypes = []
    G.out.append("#include \"mpiimpl.h\"")
    G.out.append("#include \"iallgatherv/iallgatherv.h\"")

    for coll in coll_names:
        for variant in ("blocking", "nonblocking", "persistent"):
            dump_coll(coll, variant)

    dump_c_file("src/mpi/coll/mpir_coll.c", G.out)
    dump_prototypes("src/mpi/coll/include/coll_algos.h", G.prototypes)
def add_prototype(l):
    """Record the C prototype `l` once per function name.

    Lines that do not look like "int <name>(" are ignored, as are
    prototypes whose function name has already been recorded.
    """
    if not RE.match(r'int\s+(\w+)\(', l):
        return
    func_name = RE.m.group(1)
    if func_name in G.prototypes_hash:
        return
    G.prototypes_hash[func_name] = 1
    G.prototypes.append(l)
def load_coll_algos(algo_txt):
    """Parse the algorithm description file into a dict keyed by "<func>-<intra|inter>".

    Each value is a list of algorithm dicts.  Every algorithm dict carries
    "name" and "func-commkind", plus any "key: value" attribute lines that
    follow the algorithm's name line.
    """
    table = {}
    section = None   # current "<func>-<commkind>" header
    entries = None   # algorithm list of the current section
    current = None   # algorithm currently receiving attributes
    with open(algo_txt) as In:
        for line in In:
            if RE.match(r'(\w+-(intra|inter)):', line):
                # Section header, e.g. "bcast-intra:"
                section = RE.m.group(1)
                entries = []
                table[section] = entries
            elif RE.match(r'\s+(\w+)\s*$', line):
                # Indented bare word: a new algorithm in this section
                current = {"name": RE.m.group(1), "func-commkind": section}
                entries.append(current)
            elif RE.match(r'\s+(\w+):\s*(.+)', line):
                # Indented "key: value": attribute of the latest algorithm
                key, value = RE.m.group(1, 2)
                current[key] = value
    return table
def dump_coll(name, blocking_type):
    """Emit all generated functions for collective `name` in the given variant.

    blocking_type is one of "blocking", "nonblocking" or "persistent";
    anything else raises Exception("Wrong blocking_type").
    """
    if blocking_type == "blocking":
        steps = (dump_allcomm_auto_blocking, dump_mpir_impl_blocking)
    elif blocking_type == "nonblocking":
        steps = (dump_allcomm_sched_auto, dump_sched_impl, dump_mpir_impl_nonblocking)
    elif blocking_type == "persistent":
        steps = (dump_mpir_impl_persistent,)
    else:
        raise Exception("Wrong blocking_type")

    for step in steps:
        step(name)
    # The top-level MPIR_Xxx wrapper is emitted for every variant.
    dump_mpir(name, blocking_type)
def dump_allcomm_auto_blocking(name):
    """Emit MPIR_Xxx_allcomm_auto for a blocking collective (Csel selection).

    The generated C function fills an MPIR_Csel_coll_sig_s from the
    collective's arguments, calls MPIR_Csel_search, and switches on the
    returned container id to call the selected algorithm.  A prototype is
    registered through add_prototype for every function referenced.

    name: collective name without the "mpi_" prefix, e.g. "bcast".
    """
    blocking_type = "blocking"
    func = G.FUNCS["mpi_" + name]
    params, args = get_params_and_args(func)
    func_params = get_func_params(params, name, "blocking")
    func_args = get_func_args(args, name, "blocking")

    # e.g. bcast, Bcast, BCAST
    func_name = get_func_name(name, blocking_type)
    Name = func_name.capitalize()
    NAME = func_name.upper()

    G.out.append("")
    G.out.append("/* ---- %s ---- */" % func_name)
    G.out.append("")
    add_prototype("int MPIR_%s_allcomm_auto(%s)" % (Name, func_params))
    dump_split(0, "int MPIR_%s_allcomm_auto(%s)" % (Name, func_params))
    dump_open('{')
    G.out.append("int mpi_errno = MPI_SUCCESS;")
    G.out.append("")

    # -- Csel_search
    dump_open("MPIR_Csel_coll_sig_s coll_sig = {")
    G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME)
    G.out.append(".comm_ptr = comm_ptr,")
    for p in func['parameters']:
        # comm is already passed as comm_ptr above
        if not re.match(r'comm$', p['name']):
            G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name']))
    dump_close("};")
    G.out.append("")
    G.out.append("MPII_Csel_container_s *cnt = MPIR_Csel_search(comm_ptr->csel_comm, coll_sig);")
    G.out.append("MPIR_Assert(cnt);")
    G.out.append("")

    # -- switch
    def dump_cnt_algo_blocking(algo, commkind):
        """Register the prototype of one algorithm and emit the call to it."""
        if "allcomm" in algo:
            commkind = "allcomm"
        algo_name = get_algo_name(algo)
        algo_args = get_algo_args(args, algo, "csel")
        algo_params = get_algo_params(params, algo)
        add_prototype("int MPIR_%s_%s_%s(%s)" % (Name, commkind, algo_name, algo_params))
        dump_split(3, "mpi_errno = MPIR_%s_%s_%s(%s);" % (Name, commkind, algo_name, algo_args))

    dump_open("switch (cnt->id) {")
    for commkind in ("intra", "inter"):
        if commkind == "inter" and re.match(r'(scan|exscan|neighbor_)', name):
            continue
        for algo in G.algos[func_name + "-" + commkind]:
            # Use a per-algorithm copy: overwriting the loop variable would
            # mislabel every later algorithm in this list as "allcomm" and
            # skip later allcomm entries.
            use_commkind = commkind
            if "allcomm" in algo:
                if commkind == "intra":
                    use_commkind = "allcomm"
                else:
                    # skip inter since it is covered already
                    continue
            G.out.append("case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_%s_%s_%s:" % (Name, use_commkind, algo['name']))
            G.out.append("INDENT")
            dump_cnt_algo_blocking(algo, use_commkind)
            G.out.append("break;")
            G.out.append("DEDENT")
            G.out.append("")
    G.out.append("case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_%s_allcomm_nb:" % Name)
    # No trailing ';' here: dump_prototypes appends one to every prototype.
    add_prototype("int MPIR_%s_allcomm_nb(%s)" % (Name, func_params))
    dump_split(2, " mpi_errno = MPIR_%s_allcomm_nb(%s);" % (Name, func_args))
    G.out.append(" break;")
    G.out.append("")
    G.out.append("default:")
    G.out.append(" MPIR_Assert(0);")
    dump_close("}")

    # -- return
    G.out.append("MPIR_ERR_CHECK(mpi_errno);")
    dump_fn_exit()
    dump_close("}")
def dump_allcomm_sched_auto(name):
    """Emit MPIR_Ixxx_allcomm_sched_auto (Csel selection for schedule creation).

    The generated C function runs MPIR_Csel_search and, per selected
    container id, creates a schedule and calls the matching TSP ("tsp_"
    prefixed) or Sched algorithm.  Prototypes of all referenced functions
    are registered through add_prototype.
    """
    func = G.FUNCS["mpi_" + name]
    params, args = get_params_and_args(func)
    func_params = get_func_params(params, name, "allcomm_sched_auto")

    # nonblocking naming, e.g. ibcast, Ibcast, IBCAST
    func_name = get_func_name(name, "nonblocking")
    Name = func_name.capitalize()
    NAME = func_name.upper()
    # scan/exscan/neighbor collectives have no inter-communicator variant
    intra_only = re.match(r'(scan|exscan|neighbor_)', name)

    signature = "int MPIR_%s_allcomm_sched_auto(%s)" % (Name, func_params)
    G.out.extend(["", "/* ---- %s ---- */" % func_name, ""])
    add_prototype(signature)
    dump_split(0, signature)
    dump_open('{')
    G.out.extend(["int mpi_errno = MPI_SUCCESS;", ""])

    # -- Csel_search
    dump_open("MPIR_Csel_coll_sig_s coll_sig = {")
    G.out.append(".coll_type = MPIR_CSEL_COLL_TYPE__%s," % NAME)
    G.out.append(".comm_ptr = comm_ptr,")
    for p in func['parameters']:
        if re.match(r'comm$', p['name']):
            continue
        G.out.append(".u.%s.%s = %s," % (func_name, p['name'], p['name']))
    dump_close("};")
    G.out.extend([
        "",
        "MPII_Csel_container_s *cnt = MPIR_Csel_search(comm_ptr->csel_comm, coll_sig);",
        "MPIR_Assert(cnt);",
        "",
    ])

    # -- add sched_auto prototypes
    sched_auto_params = get_func_params(params, name, "sched_auto")
    add_prototype("int MPIR_%s_intra_sched_auto(%s)" % (Name, sched_auto_params))
    if not intra_only:
        add_prototype("int MPIR_%s_inter_sched_auto(%s)" % (Name, sched_auto_params))

    # -- switch
    def emit_algo(algo, commkind):
        """Emit schedule creation, prototype and call for one algorithm."""
        if algo['name'].startswith('tsp_'):
            create_macro = "MPII_GENTRAN_CREATE_SCHED_P();"
            symbol_fmt = "MPIR_TSP_%s_sched_%s_%s"
        else:
            create_macro = "MPII_SCHED_CREATE_SCHED_P();"
            symbol_fmt = "MPIR_%s_%s_%s"
        G.out.append(create_macro)
        algo_name = get_algo_name(algo)
        algo_args = get_algo_args(args, algo, "csel")
        algo_params = get_algo_params(params, algo)
        symbol = symbol_fmt % (Name, commkind, algo_name)
        add_prototype("int %s(%s)" % (symbol, algo_params))
        dump_split(3, "mpi_errno = %s(%s);" % (symbol, algo_args))

    dump_open("switch (cnt->id) {")
    for commkind in ("intra", "inter"):
        if commkind == "inter" and intra_only:
            continue
        for algo in G.algos[func_name + "-" + commkind]:
            if "allcomm" not in algo:
                use_commkind = commkind
            elif commkind == "intra":
                use_commkind = "allcomm"
            else:
                # allcomm algorithms were already emitted in the intra pass
                continue
            G.out.append("case MPII_CSEL_CONTAINER_TYPE__ALGORITHM__MPIR_%s_%s_%s:" % (Name, use_commkind, algo['name']))
            G.out.append("INDENT")
            emit_algo(algo, use_commkind)
            G.out.append("break;")
            G.out.append("DEDENT")
            G.out.append("")
    G.out.append("default:")
    G.out.append(" MPIR_Assert(0);")
    dump_close("}")

    # -- return
    G.out.append("MPIR_ERR_CHECK(mpi_errno);")
    dump_fn_exit()
    dump_close("}")
def dump_mpir_impl_blocking(name):
    """Emit MPIR_Xxx_impl for a blocking collective.

    The generated C function switches on the intra/inter algorithm CVAR and
    calls the chosen algorithm.  Algorithms that declare restrictions get a
    fallback check, and a "fallback:" label that calls MPIR_Xxx_allcomm_auto
    is emitted when any such check was generated.
    """
    func = G.FUNCS["mpi_" + name]
    params, args = get_params_and_args(func)
    func_params = get_func_params(params, name, "blocking")
    func_args = get_func_args(args, name, "blocking")
    func_name = get_func_name(name, "blocking")
    Name = func_name.capitalize()
    NAME = func_name.upper()
    need_fallback = False

    def emit_algo_call(algo, commkind):
        """Emit the call of one CVAR-selected algorithm."""
        kind_tag = "allcomm" if "allcomm" in algo else commkind
        algo_name = get_algo_name(algo)
        algo_args = get_algo_args(args, algo, "cvar")
        dump_split(3, "mpi_errno = MPIR_%s_%s_%s(%s);" % (Name, kind_tag, algo_name, algo_args))

    def emit_cases(commkind):
        """Emit all case labels of the switch for one communicator kind."""
        nonlocal need_fallback
        cvar_prefix = "MPIR_CVAR_%s_%s_ALGORITHM" % (NAME, commkind.upper())
        for algo in G.algos[func_name + '-' + commkind]:
            # "auto" and "nb" get dedicated cases after the loop
            if algo['name'] in ("auto", "nb"):
                continue
            G.out.append("case %s_%s:" % (cvar_prefix, algo['name']))
            G.out.append("INDENT")
            if 'restrictions' in algo:
                dump_fallback(algo)
                need_fallback = True
            emit_algo_call(algo, commkind)
            G.out.append("break;")
            G.out.append("DEDENT")
        G.out.append("case %s_nb:" % cvar_prefix)
        dump_split(3, " mpi_errno = MPIR_%s_allcomm_nb(%s);" % (Name, func_args))
        G.out.append(" break;")
        G.out.append("case %s_auto:" % cvar_prefix)
        dump_split(3, " mpi_errno = MPIR_%s_allcomm_auto(%s);" % (Name, func_args))
        G.out.append(" break;")
        G.out.append("default:")
        G.out.append(" MPIR_Assert(0);")

    # ----------------
    signature = "int MPIR_%s_impl(%s)" % (Name, func_params)
    G.out.append("")
    add_prototype(signature)
    dump_split(0, signature)
    dump_open('{')
    G.out.extend(["int mpi_errno = MPI_SUCCESS;", ""])

    dump_open("if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {")
    dump_open("switch (MPIR_CVAR_%s_INTRA_ALGORITHM) {" % NAME)
    emit_cases("intra")
    dump_close("}")
    dump_else()
    if re.match(r'(scan|exscan|neighbor_)', name):
        G.out.append("MPIR_Assert_error(\"Only intra-communicator allowed\");")
    else:
        dump_open("switch (MPIR_CVAR_%s_INTER_ALGORITHM) {" % NAME)
        emit_cases("inter")
        dump_close("}")
    dump_close("}")

    G.out.append("MPIR_ERR_CHECK(mpi_errno);")
    if need_fallback:
        # Normal flow must jump over the fallback label
        G.out.extend(["goto fn_exit;", "", "fallback:"])
        dump_split(1, "mpi_errno = MPIR_%s_allcomm_auto(%s);" % (Name, func_args))
        G.out.append("")
    dump_fn_exit()
    dump_close("}")
def dump_sched_impl(name):
    """Emit MPIR_Ixxx_sched_impl, the CVAR-driven schedule builder.

    The generated C function switches on the intra/inter algorithm CVAR,
    creates a TSP or Sched schedule and calls the chosen algorithm.
    Restricted algorithms get a fallback check; the "fallback:" label calls
    MPIR_Ixxx_allcomm_sched_auto.
    """
    func = G.FUNCS["mpi_" + name]
    params, args = get_params_and_args(func)
    func_params = get_func_params(params, name, "sched_impl")
    func_name = get_func_name(name, "nonblocking")
    Name = func_name.capitalize()
    NAME = func_name.upper()
    auto_args = get_func_args(args, name, "allcomm_sched_auto")
    need_fallback = False

    def emit_algo_call(algo, commkind):
        """Emit schedule creation and the call of one algorithm."""
        is_tsp = algo['name'].startswith('tsp_')
        if is_tsp:
            G.out.append("MPII_GENTRAN_CREATE_SCHED_P();")
        else:
            G.out.append("MPII_SCHED_CREATE_SCHED_P();")
        algo_name = get_algo_name(algo)
        algo_args = get_algo_args(args, algo, "cvar")
        kind_tag = "allcomm" if "allcomm" in algo else commkind
        if is_tsp:
            call = "mpi_errno = MPIR_TSP_%s_sched_%s_%s(%s);" % (Name, kind_tag, algo_name, algo_args)
        else:
            call = "mpi_errno = MPIR_%s_%s_%s(%s);" % (Name, kind_tag, algo_name, algo_args)
        dump_split(3, call)

    def emit_cases(commkind):
        """Emit all case labels of the switch for one communicator kind."""
        nonlocal need_fallback
        cvar_prefix = "MPIR_CVAR_%s_%s_ALGORITHM" % (NAME, commkind.upper())
        for algo in G.algos[func_name + '-' + commkind]:
            # "auto" has a dedicated case below; "nb" has no case here
            if algo['name'] in ("auto", "nb"):
                continue
            G.out.append("case %s_%s:" % (cvar_prefix, algo['name']))
            G.out.append("INDENT")
            if 'restrictions' in algo:
                dump_fallback(algo)
                need_fallback = True
            emit_algo_call(algo, commkind)
            G.out.append("break;")
            G.out.append("DEDENT")
        G.out.append("case %s_auto:" % cvar_prefix)
        dump_split(3, " mpi_errno = MPIR_%s_allcomm_sched_auto(%s);" % (Name, auto_args))
        G.out.append(" break;")
        G.out.append("default:")
        G.out.append(" MPIR_Assert(0);")

    # ----------------
    signature = "int MPIR_%s_sched_impl(%s)" % (Name, func_params)
    G.out.append("")
    add_prototype(signature)
    dump_split(0, signature)
    dump_open('{')
    G.out.extend(["int mpi_errno = MPI_SUCCESS;", ""])

    dump_open("if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {")
    dump_open("switch (MPIR_CVAR_%s_INTRA_ALGORITHM) {" % NAME)
    emit_cases("intra")
    dump_close("}")
    dump_else()
    if re.match(r'(scan|exscan|neighbor_)', name):
        G.out.append("MPIR_Assert_error(\"Only intra-communicator allowed\");")
    else:
        dump_open("switch (MPIR_CVAR_%s_INTER_ALGORITHM) {" % NAME)
        emit_cases("inter")
        dump_close("}")
    dump_close("}")

    G.out.append("MPIR_ERR_CHECK(mpi_errno);")
    if need_fallback:
        # Normal flow must jump over the fallback label
        G.out.extend(["goto fn_exit;", "", "fallback:"])
        dump_split(1, "mpi_errno = MPIR_%s_allcomm_sched_auto(%s);" % (Name, auto_args))
        G.out.append("")
    dump_fn_exit()
    dump_close("}")
def dump_mpir_impl_nonblocking(name):
    """Emit MPIR_Ixxx_impl: build the schedule via MPIR_Ixxx_sched_impl and start it."""
    func = G.FUNCS["mpi_" + name]
    params, args = get_params_and_args(func)
    func_params = get_func_params(params, name, "nonblocking")
    Name = get_func_name(name, "nonblocking").capitalize()

    signature = "int MPIR_%s_impl(%s)" % (Name, func_params)
    G.out.append("")
    add_prototype(signature)
    dump_split(0, signature)
    dump_open('{')
    G.out.extend([
        "int mpi_errno = MPI_SUCCESS;",
        "enum MPIR_sched_type sched_type;",
        "void *sched;",
        "",
        "*request = NULL;",
    ])
    sched_args = get_func_args(args, name, "mpir_impl_nonblocking")
    dump_split(1, "mpi_errno = MPIR_%s_sched_impl(%s);" % (Name, sched_args))
    G.out.append("MPIR_ERR_CHECK(mpi_errno);")
    G.out.append("MPII_SCHED_START(sched_type, sched, comm_ptr, request);")
    # Blank line plus the fn_exit/fn_fail epilogue
    dump_fn_exit()
    dump_close("}")
def dump_mpir_impl_persistent(name):
    """Emit MPIR_Xxx_init_impl: create a persistent request and build its schedule.

    The generated C function allocates a PREQUEST_COLL request, attaches the
    communicator, and calls MPIR_Ixxx_sched_impl with is_persistent = true so
    the schedule is stored in the request.
    """
    func = G.FUNCS["mpi_" + name]
    params, args = get_params_and_args(func)
    func_params = get_func_params(params, name, "persistent")
    Name = get_func_name(name, "persistent").capitalize()

    signature = "int MPIR_%s_impl(%s)" % (Name, func_params)
    G.out.append("")
    add_prototype(signature)
    dump_split(0, signature)
    dump_open('{')
    G.out.extend([
        "int mpi_errno = MPI_SUCCESS;",
        "",
        "MPIR_Request *req = MPIR_Request_create(MPIR_REQUEST_KIND__PREQUEST_COLL);",
        "MPIR_ERR_CHKANDJUMP(!req, mpi_errno, MPI_ERR_OTHER, \"**nomem\");",
        "MPIR_Comm_add_ref(comm_ptr);",
        "req->comm = comm_ptr;",
        "MPIR_Comm_save_inactive_request(comm_ptr, req);",
        "req->u.persist_coll.sched_type = MPIR_SCHED_INVALID;",
        "req->u.persist_coll.real_request = NULL;",
    ])
    sched_args = get_func_args(args, name, "mpir_impl_persistent")
    # The schedule builder is the nonblocking one, hence the "I" prefix
    dump_split(1, "mpi_errno = MPIR_I%s_sched_impl(%s);" % (name, sched_args))
    G.out.extend(["MPIR_ERR_CHECK(mpi_errno);", "", "*request = req;"])
    # Blank line plus the fn_exit/fn_fail epilogue
    dump_fn_exit()
    dump_close("}")
def dump_mpir(name, blocking_type):
    """Emit the top-level MPIR_Xxx wrapper for one collective variant.

    The generated C function calls MPID_Xxx when device collectives are
    enabled through the MPIR_CVAR_DEVICE_COLLECTIVES settings and
    MPIR_Xxx_impl otherwise.  For collectives whose name matches
    reduce/allreduce/scan/exscan/reduce_scatter*, it additionally emits the
    host buffer swap before the call and the swap-back afterwards.

    name: collective name without the "mpi_" prefix.
    blocking_type: "blocking", "nonblocking" or "persistent".
    """
    func = G.FUNCS["mpi_" + name]
    params, args = get_params_and_args(func)
    func_params = get_func_params(params, name, blocking_type)
    func_args = get_func_args(args, name, blocking_type)
    func_name = get_func_name(name, blocking_type)
    Name = func_name.capitalize()
    NAME = func_name.upper()

    def dump_buffer_swap_pre():
        """Emit allocation of host buffers and redirection of sendbuf/recvbuf."""
        G.out.append("void *in_recvbuf = recvbuf;")
        G.out.append("void *host_sendbuf = NULL;")
        G.out.append("void *host_recvbuf = NULL;")
        G.out.append("")
        # reduce_scatter variants have no "count" argument; the generated
        # code declares one holding the total element count over all ranks.
        if name == "reduce_scatter":
            G.out.append("MPI_Aint count = 0;")
            G.out.append("for (int i = 0; i < MPIR_Comm_size(comm_ptr); i++) {")
            G.out.append(" count += recvcounts[i];")
            G.out.append("}")
            G.out.append("")
        elif name == "reduce_scatter_block":
            G.out.append("MPI_Aint count = MPIR_Comm_size(comm_ptr) * recvcount;")

        if name == "reduce":
            # Only the root has a meaningful recvbuf in reduce
            use_recvbuf = "(comm_ptr->rank == root || root == MPI_ROOT) ? recvbuf : NULL"
        else:
            use_recvbuf = "recvbuf"
        G.out.append("if(!MPIR_Typerep_reduce_is_supported(op, count, datatype))")
        G.out.append(" MPIR_Coll_host_buffer_alloc(sendbuf, %s, count, datatype, &host_sendbuf, &host_recvbuf);" % use_recvbuf)
        G.out.append("")

        for buf in ("sendbuf", "recvbuf"):
            G.out.append("if (host_%s) {" % buf)
            G.out.append(" %s = host_%s;" % (buf, buf))
            G.out.append("}")
        G.out.append("")

    def dump_buffer_swap_post():
        """Emit the copy-back of results into the caller's recvbuf."""
        # Number of elements the caller's recvbuf receives.  For the
        # reduce_scatter variants this is the per-rank count, not the total
        # stored in the generated "count" variable.
        count = "count"
        if name == "reduce_scatter":
            count = "recvcounts[comm_ptr->rank]"
        elif name == "reduce_scatter_block":
            count = "recvcount"

        if blocking_type == "blocking":
            G.out.append("if (host_recvbuf) {")
            G.out.append(" recvbuf = in_recvbuf;")
            # Copy only the result count; using the total "count" here would
            # write past the end of recvbuf for reduce_scatter variants.
            G.out.append(" MPIR_Localcopy(host_recvbuf, %s, datatype, recvbuf, %s, datatype);" % (count, count))
            G.out.append("}")
            G.out.append("MPIR_Coll_host_buffer_free(host_sendbuf, host_recvbuf);")
        elif blocking_type == "nonblocking":
            G.out.append("MPIR_Coll_host_buffer_swap_back(host_sendbuf, host_recvbuf, in_recvbuf, %s, datatype, *request);" % count)
        elif blocking_type == "persistent":
            G.out.append("MPIR_Coll_host_buffer_persist_set(host_sendbuf, host_recvbuf, in_recvbuf, %s, datatype, *request);" % count)

    G.out.append("")
    add_prototype("int MPIR_%s(%s)" % (Name, func_params))
    dump_split(0, "int MPIR_%s(%s)" % (Name, func_params))
    dump_open('{')
    G.out.append("int mpi_errno = MPI_SUCCESS;")
    G.out.append("")

    # The prefix match also covers reduce_scatter_block
    need_buffer_swap = bool(re.match(r'(reduce|allreduce|scan|exscan|reduce_scatter)', name))

    if need_buffer_swap:
        dump_buffer_swap_pre()

    cond1 = "MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_all"
    cond2 = "MPIR_CVAR_DEVICE_COLLECTIVES == MPIR_CVAR_DEVICE_COLLECTIVES_percoll"
    cond3 = "MPIR_CVAR_%s_DEVICE_COLLECTIVE" % NAME
    G.out.append("if ((%s) ||" % cond1)
    G.out.append(" ((%s) &&" % cond2)
    G.out.append(" %s)) {" % cond3)
    G.out.append("INDENT")
    dump_split(2, "mpi_errno = MPID_%s(%s);" % (Name, func_args))
    dump_else()
    dump_split(2, "mpi_errno = MPIR_%s_impl(%s);" % (Name, func_args))
    dump_close("}")

    if need_buffer_swap:
        dump_buffer_swap_post()

    G.out.append("")
    G.out.append("return mpi_errno;")
    dump_close("}")
# ----
def dump_fallback(algo):
    """Emit MPII_COLLECTIVE_FALLBACK_CHECK for the restrictions of `algo`.

    algo['restrictions'] is a comma-separated list; each entry is turned
    into a C condition and all conditions are joined with " && ".  An
    unknown entry raises Exception.
    """
    restriction_checks = {
        "inplace": "sendbuf == MPI_IN_PLACE",
        "noinplace": "sendbuf != MPI_IN_PLACE",
        "power-of-two": "comm_ptr->local_size == comm_ptr->coll.pof2",
        "size-ge-pof2": "count >= comm_ptr->coll.pof2",
        "commutative": "MPIR_Op_is_commutative(op)",
        "builtin-op": "HANDLE_IS_BUILTIN(op)",
        "parent-comm": "MPIR_Comm_is_parent_comm(comm_ptr)",
        "node-consecutive": "MPII_Comm_is_node_consecutive(comm_ptr)",
        # assume it's allgatherv
        "displs-ordered": "MPII_Iallgatherv_is_displs_ordered(comm_ptr->local_size, recvcounts, displs)",
    }

    cond_list = []
    for restriction in algo['restrictions'].replace(" ", "").split(','):
        if restriction not in restriction_checks:
            raise Exception("Unsupported restrictions - %s" % restriction)
        cond_list.append(restriction_checks[restriction])

    func_name, _commkind = algo['func-commkind'].split('-')
    G.out.append("MPII_COLLECTIVE_FALLBACK_CHECK(comm_ptr->rank, %s, mpi_errno," % ' && '.join(cond_list))
    G.out.append(" \"%s %s cannot be applied.\\n\");" % (func_name.capitalize(), algo['name']))
# ----
def get_func_name(name, blocking_type):
    """Return the collective's function name for the given variant.

    "blocking" returns `name` unchanged, "nonblocking" prefixes it with "i",
    and "persistent" appends "_init".

    Raises Exception for any other blocking_type instead of silently
    returning None, in line with get_func_params and get_func_args.
    """
    if blocking_type == "blocking":
        return name
    elif blocking_type == "nonblocking":
        return 'i' + name
    elif blocking_type == "persistent":
        return name + "_init"
    else:
        raise Exception("get_func_name - unexpected blocking_type = %s" % blocking_type)
def get_params_and_args(func):
    """Return (C parameter list, C argument list) strings for `func`.

    The "comm" parameter is replaced by "MPIR_Comm * comm_ptr" / "comm_ptr".
    For parameters whose kind starts with "POLY", "int " in the declaration
    is rewritten to "MPI_Aint ".
    """
    kind_map = G.MAPS['SMALL_C_KIND_MAP']
    decls = []
    names = []
    for p in func['parameters']:
        if p['name'] == 'comm':
            decl = "MPIR_Comm * comm_ptr"
            arg = "comm_ptr"
        else:
            decl = get_C_param(p, func, kind_map)
            if p['kind'].startswith('POLY'):
                decl = re.sub(r'\bint ', 'MPI_Aint ', decl)
            arg = p['name']
        decls.append(decl)
        names.append(arg)
    return (', '.join(decls), ', '.join(names))
def get_algo_extra_args(algo, kind):
    """Return the C argument string for the extra parameters of `algo`.

    kind "csel" reads each value from the Csel container (cnt->u...); kind
    "cvar" reads it from the matching MPIR_CVAR_ variable.  Entries written
    as "name=value" are constants and contribute "value" for either kind.

    Raises Exception when extra_params and cvar_params differ in length or
    when `kind` is neither "csel" nor "cvar" for a non-constant parameter.
    """
    func_name, commkind = algo['func-commkind'].split('-')
    extra_params = algo['extra_params'].replace(' ', '').split(',')
    cvar_params = algo['cvar_params'].replace(' ', '').split(',')
    if len(extra_params) != len(cvar_params):
        raise Exception("algorithm %s-%s-%s: extra_params and cvar_params sizes mismatch!" % (func_name, commkind, algo['name']))

    out_list = []
    for extra, cvar in zip(extra_params, cvar_params):
        if RE.match(r'\w+=(.+)', extra):
            # constant parameter
            out_list.append(RE.m.group(1))
            continue

        if kind == "csel":
            container = "cnt->u.%s.%s_%s." % (func_name, commkind, algo['name'])
            out_list.append(container + extra)
        elif kind == "cvar":
            prefix = "MPIR_CVAR_%s_" % func_name.upper()
            value = prefix + cvar
            # Two names do not follow the plain MPIR_CVAR_<FUNC>_ scheme
            if re.match(r"%sTREE_TYPE" % prefix, value):
                replacement = "MPIR_%s_tree_type" % func_name.capitalize()
                value = re.sub(r"%sTREE_TYPE" % prefix, replacement, value)
            elif re.match(r"%sTHROTTLE" % prefix, value):
                replacement = "MPIR_CVAR_ALLTOALL_THROTTLE"
                value = re.sub(r"%sTHROTTLE" % prefix, replacement, value)
            out_list.append(value)
        else:
            raise Exception("Wrong kind!")
    return ', '.join(out_list)
def get_algo_extra_params(algo):
    """Return the C parameter declarations ("int x, int y") for algo['extra_params'].

    An entry written as "name=value" (constant parameter) is declared by
    its name only.
    """
    decls = []
    for entry in algo['extra_params'].replace(' ', '').split(','):
        if RE.match(r'(\w+)=.+', entry):
            # constant parameter: keep only the name
            param_name = RE.m.group(1)
        else:
            param_name = entry
        decls.append("int " + param_name)
    return ', '.join(decls)
# additional wrappers
def get_algo_args(args, algo, kind):
    """Return the full C argument string for calling algorithm `algo`.

    Starts from `args`, appends the extra arguments when the algorithm has
    'extra_params', then "*sched_p" for tsp_ algorithms and for functions
    whose name starts with "i", or "errflag" for the remaining
    non-neighbor functions.
    """
    pieces = [args]
    if 'extra_params' in algo:
        pieces.append(get_algo_extra_args(algo, kind))

    func_commkind = algo['func-commkind']
    if algo['name'].startswith('tsp_') or func_commkind.startswith('i'):
        pieces.append("*sched_p")
    elif not func_commkind.startswith('neighbor_'):
        pieces.append("errflag")
    return ", ".join(pieces)
def get_algo_params(params, algo):
    """Return the full C parameter list for the prototype of algorithm `algo`.

    Starts from `params`, appends the extra parameter declarations when the
    algorithm has 'extra_params', then the trailing schedule or errflag
    parameter that matches the algorithm type.
    """
    pieces = [params]
    if 'extra_params' in algo:
        pieces.append(get_algo_extra_params(algo))

    func_commkind = algo['func-commkind']
    if algo['name'].startswith('tsp_'):
        pieces.append("MPIR_TSP_sched_t sched")
    elif func_commkind.startswith('i'):
        pieces.append("MPIR_Sched_t s")
    elif not func_commkind.startswith('neighbor_'):
        pieces.append("MPIR_Errflag_t errflag")
    return ", ".join(pieces)
def get_algo_name(algo):
    """Return the name used for `algo` inside generated function names.

    An explicit 'func_name' entry wins; otherwise the algorithm name is
    used with a leading "tsp_" removed.
    """
    if "func_name" in algo:
        return algo['func_name']
    label = algo['name']
    return label[4:] if label.startswith('tsp_') else label
def get_func_params(params, name, kind):
    """Return `params` extended with the trailing C parameters for `kind`.

    Raises Exception for an unknown kind.
    """
    if kind == "blocking":
        # neighbor collectives take no errflag
        if name.startswith('neighbor_'):
            return params
        return params + ", MPIR_Errflag_t errflag"

    sched_tail = ", bool is_persistent, void **sched_p, enum MPIR_sched_type *sched_type_p"
    tails = {
        "nonblocking": ", MPIR_Request ** request",
        "persistent": ", MPIR_Info * info_ptr, MPIR_Request ** request",
        "sched_auto": ", MPIR_Sched_t s",
        "allcomm_sched_auto": sched_tail,
        "sched_impl": sched_tail,
    }
    if kind not in tails:
        raise Exception("get_func_params - unexpected kind = %s" % kind)
    return params + tails[kind]
def get_func_args(args, name, kind):
    """Return `args` extended with the trailing C arguments for `kind`.

    Raises Exception for an unknown kind.
    """
    if kind == "blocking":
        # neighbor collectives take no errflag
        if name.startswith('neighbor_'):
            return args
        return args + ", errflag"

    tails = {
        "nonblocking": ", request",
        "persistent": ", info_ptr, request",
        "allcomm_sched_auto": ", is_persistent, sched_p, sched_type_p",
        "mpir_impl_nonblocking": ", false, &sched, &sched_type",
        "mpir_impl_persistent": ", true, &req->u.persist_coll.sched, &req->u.persist_coll.sched_type",
    }
    if kind not in tails:
        raise Exception("get_func_args - unexpected kind = %s" % kind)
    return args + tails[kind]
# ----------------------
def dump_c_file(f, lines):
    """Write `lines` to the C file `f`, resolving the INDENT/DEDENT markers.

    The copyright header (G.copyright_c) is written first.  "INDENT" and
    "DEDENT" entries only adjust the indentation level and are not written.
    fn_exit/fn_fail/fallback labels are written with a fixed prefix, and
    #if/#endif lines are written without indentation.
    """
    print(" --> [%s]" % f)
    with open(f, "w") as Out:
        for l in G.copyright_c:
            print(l, file=Out)

        depth = 0
        for l in lines:
            if RE.match(r'(INDENT|DEDENT)', l):
                # indentation markers
                depth += 1 if RE.m.group(1) == "INDENT" else -1
                continue
            if RE.match(r'\s*(fn_exit|fn_fail|fallback):', l):
                # labels
                print(" %s:" % RE.m.group(1), file=Out)
                continue
            # ordinary line: prefix with the current indentation
            if depth > 0 and not RE.match(r'#(if|endif)', l):
                Out.write(" " * depth)
            print(l, file=Out)
def dump_prototypes(f, prototypes):
    """Write the header file `f` holding all collected prototypes.

    Each prototype gets a trailing ';' and is wrapped at 80 columns by
    split_line_with_break; the content is enclosed in the
    COLL_ALGOS_H_INCLUDED include guard.
    """
    print(" --> [%s]" % f)
    with open(f, "w") as Out:
        def emit(text):
            print(text, file=Out)

        for l in G.copyright_c:
            emit(l)
        emit("#ifndef COLL_ALGOS_H_INCLUDED")
        emit("#define COLL_ALGOS_H_INCLUDED")
        emit("")
        for proto in prototypes:
            for piece in split_line_with_break(proto + ';', '', 80):
                emit(piece)
        emit("#endif /* COLL_ALGOS_H_INCLUDED */")
def dump_open(line):
    """Emit `line` (an opening construct) and increase the indentation."""
    G.out.extend([line, "INDENT"])
def dump_close(line):
    """Decrease the indentation and emit `line` (a closing construct)."""
    G.out.extend(["DEDENT", line])
def dump_else():
    """Emit "} else {" at the enclosing indentation level."""
    G.out.extend(["DEDENT", "} else {", "INDENT"])
def dump_fn_exit():
    """Emit a blank line followed by the fn_exit/fn_fail epilogue."""
    G.out.extend([
        "",
        "fn_exit:",
        "return mpi_errno;",
        "fn_fail:",
        "goto fn_exit;",
    ])
def dump_split(indent, l):
    """Emit `l`, wrapped so that it fits in 100 columns at `indent` levels of 4."""
    width = 100 - indent * 4
    G.out.extend(split_line_with_break(l, "", width))
# ---------------------------------------------------------
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":
    main()