nameserv: check infohints in MPI_Lookup_name

Check info hints "port_name_size" in MPI_Lookup_name and pass the buffer
length to PMI utilities.

This provides a mechanism for users to safely allocate a portname buffer
and accommodate larger portname size than MPI_MAX_PORT_NAME.
Esse commit está contido em:
Hui Zhou
2023-12-06 10:11:41 -06:00
commit 4a66315cc8
7 arquivos alterados com 45 adições e 15 exclusões
+1 -1
Ver Arquivo
@@ -90,7 +90,7 @@ int MPIR_pmi_bcast_local(char *val, int val_size);
/* name service functions */
int MPIR_pmi_publish(const char name[], const char port[]);
int MPIR_pmi_lookup(const char name[], char port[]);
int MPIR_pmi_lookup(const char name[], char port[], int portlen);
int MPIR_pmi_unpublish(const char name[]);
/* Other misc functions */
+2
Ver Arquivo
@@ -167,6 +167,8 @@ be in the range 0 to %d
**namepubnotfound %s:Lookup failed for service name %s
**namepubnotunpub:Failed to unpublish service name
**namepubnotunpub %s:Failed to unpublish service name %s
**namepubtrunc:Lookup returned port_name is truncated
**namepubtrunc %s:Lookup for service name %s is truncated
**sendbuf_inplace:sendbuf cannot be MPI_IN_PLACE
**recvbuf_inplace:recvbuf cannot be MPI_IN_PLACE
**buf_inplace:buffer cannot be MPI_IN_PLACE
+8 -2
Ver Arquivo
@@ -41,10 +41,16 @@ int MPID_NS_Publish(MPID_NS_Handle handle, const MPIR_Info * info_ptr,
int MPID_NS_Lookup(MPID_NS_Handle handle, const MPIR_Info * info_ptr,
const char service_name[], char port[])
{
MPL_UNREFERENCED_ARG(info_ptr);
MPL_UNREFERENCED_ARG(handle);
return MPIR_pmi_lookup(service_name, port);
int port_name_size = MPI_MAX_PORT_NAME;
if (info_ptr) {
const char *val = MPIR_Info_lookup(info_ptr, "port_name_size");
if (val) {
port_name_size = atoi(val);
}
}
return MPIR_pmi_lookup(service_name, port, port_name_size);
}
int MPID_NS_Unpublish(MPID_NS_Handle handle, const MPIR_Info * info_ptr, const char service_name[])
+4 -3
Ver Arquivo
@@ -740,11 +740,12 @@ int MPIR_pmi_publish(const char name[], const char port[])
return mpi_errno;
}
int MPIR_pmi_lookup(const char name[], char port[])
int MPIR_pmi_lookup(const char name[], char port[], int port_len)
{
int mpi_errno = MPI_SUCCESS;
SWITCH_PMI(mpi_errno = pmi1_lookup(name, port),
mpi_errno = pmi2_lookup(name, port), mpi_errno = pmix_lookup(name, port));
SWITCH_PMI(mpi_errno = pmi1_lookup(name, port, port_len),
mpi_errno = pmi2_lookup(name, port, port_len),
mpi_errno = pmix_lookup(name, port, port_len));
return mpi_errno;
}
+24 -3
Ver Arquivo
@@ -213,15 +213,36 @@ static int pmi1_publish(const char name[], const char port[])
goto fn_exit;
}
static int pmi1_lookup(const char name[], char port[])
static int pmi1_lookup(const char name[], char port[], int port_len)
{
int mpi_errno = MPI_SUCCESS;
int pmi_errno;
pmi_errno = PMI_Lookup_name(name, port);
#ifdef PMI_MAX_PORT_NAME
int maxlen = PMI_MAX_PORT_NAME;
#else
int maxlen = MPI_MAX_PORT_NAME;
#endif
char *tmpbuf = NULL;
if (port_len >= maxlen) {
pmi_errno = PMI_Lookup_name(name, port);
} else {
/* allocate a temporary buffer for safety */
tmpbuf = MPL_malloc(maxlen, MPL_MEM_OTHER);
pmi_errno = PMI_Lookup_name(name, tmpbuf);
if (pmi_errno == PMI_SUCCESS) {
int mpl_err = MPL_strncpy(port, tmpbuf, port_len);
MPIR_ERR_CHKANDJUMP1(mpl_err, mpi_errno, MPI_ERR_NAME, "**namepubtrunc",
"**namepubtrunc %s", name);
}
}
MPIR_ERR_CHKANDJUMP1(pmi_errno, mpi_errno, MPI_ERR_NAME, "**namepubnotfound",
"**namepubnotfound %s", name);
fn_exit:
if (tmpbuf) {
MPL_free(tmpbuf);
}
return mpi_errno;
fn_fail:
goto fn_exit;
@@ -320,7 +341,7 @@ static int pmi1_publish(const char name[], const char port[])
return MPI_ERR_INTERN;
}
static int pmi1_lookup(const char name[], char port[])
static int pmi1_lookup(const char name[], char port[], int port_len)
{
return MPI_ERR_INTERN;
}
+3 -3
Ver Arquivo
@@ -261,13 +261,13 @@ static int pmi2_publish(const char name[], const char port[])
goto fn_exit;
}
static int pmi2_lookup(const char name[], char port[])
static int pmi2_lookup(const char name[], char port[], int port_len)
{
int mpi_errno = MPI_SUCCESS;
int pmi_errno;
/* release the global CS for PMI calls */
MPID_THREAD_CS_EXIT(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
pmi_errno = PMI2_Nameserv_lookup(name, NULL, port, MPI_MAX_PORT_NAME);
pmi_errno = PMI2_Nameserv_lookup(name, NULL, port, port_len);
MPID_THREAD_CS_ENTER(GLOBAL, MPIR_THREAD_GLOBAL_ALLFUNC_MUTEX);
MPIR_ERR_CHKANDJUMP1(pmi_errno, mpi_errno, MPI_ERR_NAME, "**namepubnotfound",
"**namepubnotfound %s", name);
@@ -374,7 +374,7 @@ static int pmi2_publish(const char name[], const char port[])
return MPI_ERR_INTERN;
}
static int pmi2_lookup(const char name[], char port[])
static int pmi2_lookup(const char name[], char port[], int port_len)
{
return MPI_ERR_INTERN;
}
+3 -3
Ver Arquivo
@@ -481,7 +481,7 @@ static int pmix_publish(const char name[], const char port[])
goto fn_exit;
}
static int pmix_lookup(const char name[], char port[])
static int pmix_lookup(const char name[], char port[], int port_len)
{
int mpi_errno = MPI_SUCCESS;
int pmi_errno;
@@ -490,7 +490,7 @@ static int pmix_lookup(const char name[], char port[])
MPL_strncpy(pdata[0].key, name, PMIX_MAX_KEYLEN);
pmi_errno = PMIx_Lookup(pdata, 1, NULL, 0);
if (pmi_errno == PMIX_SUCCESS) {
MPL_strncpy(port, pdata[0].value.data.string, MPI_MAX_PORT_NAME);
MPL_strncpy(port, pdata[0].value.data.string, port_len);
}
PMIX_PDATA_FREE(pdata, 1);
MPIR_ERR_CHKANDJUMP1(pmi_errno, mpi_errno, MPI_ERR_NAME, "**namepubnotfound",
@@ -970,7 +970,7 @@ static int pmix_publish(const char name[], const char port[])
return MPI_ERR_INTERN;
}
static int pmix_lookup(const char name[], char port[])
static int pmix_lookup(const char name[], char port[], int port_len)
{
return MPI_ERR_INTERN;
}