diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d308795..43d4607e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,6 +145,12 @@ option( ) message(STATUS "Distribution is turned ${ENABLE_DISTRIBUTION}. Set ENABLE_DISTRIBUTION to modify.") +option( + ENABLE_SUBCOMM + "Whether QuEST will be built with support for restricting it to a user-defined MPI communicator. Turned OFF by default." + OFF +) +message(STATUS "Custom communicator support is turned ${ENABLE_SUBCOMM}. Set ENABLE_SUBCOMM to modify.") # GPU Acceleration option( @@ -206,6 +212,9 @@ if (ENABLE_CUQUANTUM AND NOT ENABLE_CUDA) message(FATAL_ERROR "Use of cuQuantum requires CUDA.") endif() +if (ENABLE_SUBCOMM AND NOT ENABLE_DISTRIBUTION) + message(FATAL_ERROR "Distribution must be enabled to make use of a user-defined communicator for QuEST.") +endif() if(WIN32) @@ -381,8 +390,11 @@ endif() # MPI if (ENABLE_DISTRIBUTION) find_package(MPI REQUIRED + # Component CXX is the C api usable from C++ + # NOT the deprecated C++ API COMPONENTS CXX ) + target_link_libraries(QuEST PRIVATE MPI::MPI_CXX @@ -446,6 +458,7 @@ endif() # set vars which will be written to config.h.in (auto-converted to 0 or 1) set(COMPILE_OPENMP ${ENABLE_MULTITHREADING}) set(COMPILE_MPI ${ENABLE_DISTRIBUTION}) +set(COMPILE_SUBCOMM ${ENABLE_SUBCOMM}) set(COMPILE_CUQUANTUM ${ENABLE_CUQUANTUM}) set(INCLUDE_DEPRECATED_FUNCTIONS ${ENABLE_DEPRECATED_API}) @@ -565,6 +578,10 @@ add_executable(min_example ) target_link_libraries(min_example PRIVATE QuEST::QuEST) +if (ENABLE_DISTRIBUTION AND ENABLE_SUBCOMM) + target_link_libraries(min_example PRIVATE MPI::MPI_CXX) +endif() + if (INSTALL_BINARIES) install(TARGETS min_example RUNTIME @@ -741,4 +758,4 @@ install( if(PROJECT_IS_TOP_LEVEL) include(CPack) -endif () \ No newline at end of file +endif () diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index afc8f85d..d21dc5ae 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,6 +20,10 @@ function(add_example direc 
in_fn) add_executable(${target} ${in_fn}) target_link_libraries(${target} PUBLIC QuEST) + if (ENABLE_DISTRIBUTION AND ENABLE_SUBCOMM) + target_link_libraries(${target} PRIVATE MPI::MPI_CXX) + endif() + if (INSTALL_BINARIES) install( TARGETS ${target} diff --git a/quest/include/config.h.in b/quest/include/config.h.in index 2cb12fa9..070ecf29 100644 --- a/quest/include/config.h.in +++ b/quest/include/config.h.in @@ -37,6 +37,7 @@ #if defined(FLOAT_PRECISION) || \ defined(COMPILE_OPENMP) || \ defined(COMPILE_MPI) || \ + defined(COMPILE_SUBCOMM) || \ defined(COMPILE_CUDA) || \ defined(COMPILE_HIP) || \ defined(COMPILE_CUQUANTUM) || \ @@ -79,6 +80,7 @@ // crucial to QuEST source (informs external library usage) #cmakedefine01 COMPILE_OPENMP #cmakedefine01 COMPILE_MPI +#cmakedefine01 COMPILE_SUBCOMM #cmakedefine01 COMPILE_CUDA #cmakedefine01 COMPILE_CUQUANTUM @@ -118,6 +120,7 @@ #if ! defined(FLOAT_PRECISION) || \ ! defined(COMPILE_OPENMP) || \ ! defined(COMPILE_MPI) || \ + ! defined(COMPILE_SUBCOMM) || \ ! defined(COMPILE_CUDA) || \ ! defined(COMPILE_HIP) || \ ! defined(COMPILE_CUQUANTUM) || \ @@ -144,6 +147,7 @@ #if ! (COMPILE_OPENMP == 0 || COMPILE_OPENMP == 1) || \ ! (COMPILE_MPI == 0 || COMPILE_MPI == 1) || \ + ! (COMPILE_SUBCOMM == 0 || COMPILE_SUBCOMM == 1) || \ ! (COMPILE_CUDA == 0 || COMPILE_CUDA == 1) || \ ! (COMPILE_HIP == 0 || COMPILE_HIP == 1) || \ ! 
(COMPILE_CUQUANTUM == 0 || COMPILE_CUQUANTUM == 1) || \ diff --git a/quest/include/environment.h b/quest/include/environment.h index 04f24bfe..dba9cd06 100644 --- a/quest/include/environment.h +++ b/quest/include/environment.h @@ -36,6 +36,7 @@ typedef struct { int isMultithreaded; int isGpuAccelerated; int isDistributed; + int userOwnsMpi; // deployment modes which cannot be directly changed after compilation int isCuQuantumEnabled; @@ -61,6 +62,8 @@ void initQuESTEnv(); */ void initCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMultithread); +void initCustomMpiQuESTEnv(int useDistrib, int userOwnsMpi, int useGpuAccel, int useMultithread); + /// @notyetdoced void finalizeQuESTEnv(); diff --git a/quest/include/quest.h b/quest/include/quest.h index 409253ff..8771b712 100644 --- a/quest/include/quest.h +++ b/quest/include/quest.h @@ -45,6 +45,7 @@ #include "quest/include/operations.h" #include "quest/include/paulis.h" #include "quest/include/qureg.h" +#include "quest/include/subcommunicator.h" #include "quest/include/matrices.h" #include "quest/include/wrappers.h" diff --git a/quest/include/subcommunicator.h b/quest/include/subcommunicator.h new file mode 100644 index 00000000..16653c23 --- /dev/null +++ b/quest/include/subcommunicator.h @@ -0,0 +1,22 @@ +#ifndef SUBCOMMUNICATOR_H +#define SUBCOMMUNICATOR_H + +#include "quest/include/config.h" + +#if COMPILE_MPI && COMPILE_SUBCOMM + +#include <mpi.h> + +#ifdef __cplusplus +extern "C" { +#endif + +void initCustomMpiCommQuESTEnv(MPI_Comm questComm, int useGpuAccel, int useMultithread); + +#ifdef __cplusplus +} +#endif + +#endif + +#endif diff --git a/quest/src/api/CMakeLists.txt b/quest/src/api/CMakeLists.txt index 0979f2f6..43b61df7 100644 --- a/quest/src/api/CMakeLists.txt +++ b/quest/src/api/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources(QuEST operations.cpp paulis.cpp qureg.cpp + subcommunicator.cpp trotterisation.cpp types.cpp -) \ No newline at end of file +) diff --git a/quest/src/api/environment.cpp 
b/quest/src/api/environment.cpp index 54149189..50e95c61 100644 --- a/quest/src/api/environment.cpp +++ b/quest/src/api/environment.cpp @@ -71,7 +71,7 @@ static bool hasEnvBeenFinalized = false; */ -void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMultithread, const char* caller) { +void validateAndInitCustomQuESTEnv(int useDistrib, int userOwnsMpi, int useGpuAccel, int useMultithread, const char* caller) { // ensure that we are never re-initialising QuEST (even after finalize) because // this leads to undefined behaviour in distributed mode, as per the MPI @@ -94,7 +94,7 @@ void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMulti // perform that specifically upon the MPI-process-bound GPU(s). Further, // we can make sure validation errors are reported only by the root node. if (useDistrib) - comm_init(); + comm_init(userOwnsMpi); validate_newEnvDistributedBetweenPower2Nodes(caller); @@ -145,6 +145,7 @@ void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMulti globalEnvPtr->isMultithreaded = useMultithread; globalEnvPtr->isGpuAccelerated = useGpuAccel; globalEnvPtr->isDistributed = useDistrib; + globalEnvPtr->userOwnsMpi = userOwnsMpi; globalEnvPtr->isCuQuantumEnabled = useCuQuantum; globalEnvPtr->isGpuSharingEnabled = permitGpuSharing; @@ -153,6 +154,11 @@ void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMulti globalEnvPtr->numNodes = (useDistrib)? comm_getNumNodes() : 1; } +void updateQuESTEnvDistInfo() { + globalEnvPtr->rank = (globalEnvPtr->isDistributed)? comm_getRank() : 0; + globalEnvPtr->numNodes = (globalEnvPtr->isDistributed)? 
comm_getNumNodes() : 1; + return; +} /* @@ -187,10 +193,11 @@ void printCompilationInfo() { print_table( "compilation", { - {"isMpiCompiled", comm_isMpiCompiled()}, - {"isGpuCompiled", gpu_isGpuCompiled()}, - {"isOmpCompiled", cpu_isOpenmpCompiled()}, - {"isCuQuantumCompiled", gpu_isCuQuantumCompiled()}, + {"isMpiCompiled", comm_isMpiCompiled()}, + {"isMpiSubCommunicatorCompiled", comm_isMpiSubCommunicatorCompiled()}, + {"isGpuCompiled", gpu_isGpuCompiled()}, + {"isOmpCompiled", cpu_isOpenmpCompiled()}, + {"isCuQuantumCompiled", gpu_isCuQuantumCompiled()}, }); } @@ -200,6 +207,7 @@ void printDeploymentInfo() { print_table( "deployment", { {"isMpiEnabled", globalEnvPtr->isDistributed}, + {"doesUserOwnMpi", globalEnvPtr->userOwnsMpi}, {"isGpuEnabled", globalEnvPtr->isGpuAccelerated}, {"isOmpEnabled", globalEnvPtr->isMultithreaded}, {"isCuQuantumEnabled", globalEnvPtr->isCuQuantumEnabled}, @@ -397,13 +405,19 @@ extern "C" { void initCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMultithread) { - validateAndInitCustomQuESTEnv(useDistrib, useGpuAccel, useMultithread, __func__); + const int USER_OWNS_MPI = 0; + validateAndInitCustomQuESTEnv(useDistrib, USER_OWNS_MPI, useGpuAccel, useMultithread, __func__); } +void initCustomMpiQuESTEnv(int useDistrib, int userOwnsMpi, int useGpuAccel, int useMultithread) { + validateAndInitCustomQuESTEnv(useDistrib, userOwnsMpi, useGpuAccel, useMultithread, __func__); +} + void initQuESTEnv() { - validateAndInitCustomQuESTEnv(modeflag::USE_AUTO, modeflag::USE_AUTO, modeflag::USE_AUTO, __func__); + const int USER_OWNS_MPI = 0; + validateAndInitCustomQuESTEnv(modeflag::USE_AUTO, USER_OWNS_MPI, modeflag::USE_AUTO, modeflag::USE_AUTO, __func__); } @@ -436,7 +450,7 @@ void finalizeQuESTEnv() { if (globalEnvPtr->isDistributed) { comm_sync(); - comm_end(); + comm_end(globalEnvPtr->userOwnsMpi); } // free global env's heap memory and flag it as unallocated @@ -454,8 +468,12 @@ void syncQuESTEnv() { if (globalEnvPtr->isGpuAccelerated) 
gpu_sync(); - if (globalEnvPtr->isDistributed) + if (globalEnvPtr->isDistributed) { comm_sync(); + #if COMPILE_SUBCOMM + updateQuESTEnvDistInfo(); + #endif + } } @@ -498,10 +516,11 @@ void getEnvironmentString(char str[200]) { int cuQuantum = env.isGpuAccelerated && gpu_isCuQuantumCompiled(); int gpuDirect = env.isGpuAccelerated && gpu_isDirectGpuCommPossible(); - snprintf(str, 200, "CUDA=%d OpenMP=%d MPI=%d threads=%d ranks=%d cuQuantum=%d gpuDirect=%d", + snprintf(str, 200, "CUDA=%d OpenMP=%d MPI=%d userOwnsMPI=%d threads=%d ranks=%d cuQuantum=%d gpuDirect=%d", env.isGpuAccelerated, env.isMultithreaded, env.isDistributed, + env.userOwnsMpi, numThreads, env.numNodes, cuQuantum, diff --git a/quest/src/api/subcommunicator.cpp b/quest/src/api/subcommunicator.cpp new file mode 100644 index 00000000..aa61b8ed --- /dev/null +++ b/quest/src/api/subcommunicator.cpp @@ -0,0 +1,25 @@ +#include "quest/include/config.h" +#include "quest/include/environment.h" +#include "quest/include/subcommunicator.h" + +#include "quest/src/comm/comm_config.hpp" + +#if COMPILE_MPI && COMPILE_SUBCOMM + +#include <mpi.h> + +void initCustomMpiCommQuESTEnv(MPI_Comm userQuestComm, int useGpuAccel, int useMultithread) { + // useDistrib and userOwnsMpi are implied by the use of this initialiser + const int USE_DISTRIB = 1; + const int USER_OWNS_MPI = 1; + + // set mpiCommQuest to user provided communicator + comm_setMpiComm(userQuestComm); + + // initialise QuEST around that communicator + initCustomMpiQuESTEnv(USE_DISTRIB, USER_OWNS_MPI, useGpuAccel, useMultithread); + + return; +} + +#endif diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index 854a12bd..15e9f98d 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -20,6 +20,8 @@ #if COMPILE_MPI #include <mpi.h> + + static MPI_Comm mpiCommQuest = MPI_COMM_NULL; #endif @@ -60,6 +62,9 @@ bool comm_isMpiCompiled() { return (bool) COMPILE_MPI; } +bool comm_isMpiSubCommunicatorCompiled() { + return (bool) 
COMPILE_SUBCOMM; +} bool comm_isMpiGpuAware() { @@ -98,28 +103,46 @@ bool comm_isInit() { } -void comm_init() { +void comm_init(int userOwnsMpi) { #if COMPILE_MPI // error if attempting re-initialisation - if (comm_isInit()) + if (!userOwnsMpi && comm_isInit()) error_commAlreadyInit(); + + // error if user owns MPI but has not initialised + if (userOwnsMpi && !comm_isInit()) { + error_commNotInit(); + } + + // QuEST must initialise MPI if the user does not own it + if (!userOwnsMpi) + MPI_Init(NULL, NULL); - MPI_Init(NULL, NULL); + // If user is setting their own comm, mpiCommQuest will NOT be MPI_COMM_NULL, + // and we should not touch it. + // If user is NOT setting their own comm, mpiCommQuest will be MPI_COMM_NULL, + // and we should set it to a duplicate of MPI_COMM_WORLD. + if (mpiCommQuest == MPI_COMM_NULL) + MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); #endif } -void comm_end() { +void comm_end(int userOwnsMpi) { #if COMPILE_MPI // gracefully permit comm_end() before comm_init(), as input validation can trigger if (!comm_isInit()) return; - MPI_Barrier(MPI_COMM_WORLD); - MPI_Finalize(); + MPI_Barrier(mpiCommQuest); + MPI_Comm_free(&mpiCommQuest); + + // QuEST must finalise MPI if the user does not own it + if (!userOwnsMpi) + MPI_Finalize(); #endif } @@ -135,7 +158,7 @@ int comm_getRank() { return ROOT_RANK; int rank; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_rank(mpiCommQuest, &rank); return rank; #else @@ -164,7 +187,7 @@ int comm_getNumNodes() { return 1; int numNodes; - MPI_Comm_size(MPI_COMM_WORLD, &numNodes); + MPI_Comm_size(mpiCommQuest, &numNodes); return numNodes; #else @@ -182,6 +205,25 @@ void comm_sync() { if (!comm_isInit()) return; - MPI_Barrier(MPI_COMM_WORLD); + MPI_Barrier(mpiCommQuest); #endif } + +#if COMPILE_MPI + MPI_Comm comm_getMpiComm() { + return mpiCommQuest; + } + + #if COMPILE_SUBCOMM + void comm_setMpiComm(MPI_Comm newComm) { + // error if mpiCommQuest is already set! 
+ if (mpiCommQuest != MPI_COMM_NULL) { + error_commDoubleSetMpiComm(); + } + + MPI_Comm_dup(newComm, &mpiCommQuest); + + return; + } + #endif +#endif diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index 444d1dbf..e598ee78 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -10,14 +10,20 @@ #ifndef COMM_CONFIG_HPP #define COMM_CONFIG_HPP +#include "quest/include/config.h" + +#if COMPILE_MPI + #include <mpi.h> +#endif constexpr int ROOT_RANK = 0; bool comm_isMpiCompiled(); +bool comm_isMpiSubCommunicatorCompiled(); bool comm_isMpiGpuAware(); -void comm_init(); -void comm_end(); +void comm_init(int userOwnsMpi); +void comm_end(int userOwnsMpi); void comm_sync(); int comm_getRank(); @@ -27,5 +33,11 @@ bool comm_isInit(); bool comm_isRootNode(); bool comm_isRootNode(int rank); +#if COMPILE_MPI + MPI_Comm comm_getMpiComm(); + #if COMPILE_SUBCOMM + void comm_setMpiComm(MPI_Comm newComm); + #endif +#endif -#endif // COMM_CONFIG_HPP \ No newline at end of file +#endif // COMM_CONFIG_HPP diff --git a/quest/src/comm/comm_routines.cpp b/quest/src/comm/comm_routines.cpp index 19ebcb9f..14f3e58f 100644 --- a/quest/src/comm/comm_routines.cpp +++ b/quest/src/comm/comm_routines.cpp @@ -149,7 +149,8 @@ int getMaxNumMessages() { // messages. Beware the max is obtained via a void pointer and might be unset... 
void* tagUpperBoundPtr; int isAttribSet; - MPI_Comm_get_attr(MPI_COMM_WORLD, MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); + MPI_Comm mpiCommQuest = comm_getMpiComm(); + MPI_Comm_get_attr(mpiCommQuest, MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); // if something went wrong with obtaining the tag bound, return the safe minimum if (!isAttribSet) @@ -216,6 +217,8 @@ std::array dividePayloadIntoMessages(qindex numAmps) { void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + // each message is asynchronously dispatched with a final wait, as per arxiv.org/abs/2308.07402 // we will send payload in multiple asynch messages (create two requests per msg for subsequent synch) @@ -226,8 +229,8 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { // so that messages are permitted to arrive out-of-order (supporting UCX adaptive-routing) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m]); - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m+1]); + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m]); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m+1]); } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -248,6 +251,8 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + // we will not track nor wait for the asynch send; instead, the caller will later comm_sync() MPI_Request nullReq = MPI_REQUEST_NULL; @@ -257,7 +262,7 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { // 
asynchronously send the uniquely-tagged messages for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &nullReq); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &nullReq); } #else @@ -269,6 +274,8 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { void receiveArray(qcomp* dest, qindex numElems, int pairRank) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + // expect the data in multiple messages auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems); @@ -278,7 +285,7 @@ void receiveArray(qcomp* dest, qindex numElems, int pairRank) { // listen to receive each uniquely-tagged message asynchronously (as per arxiv.org/abs/2308.07402) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[m]); + MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[m]); } // receivers wait for all messages to be received (while sender asynch proceeds) @@ -303,6 +310,8 @@ void globallyCombineNonUniformSubArrays( ) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + int myRank = comm_getRank(); int numNodes = comm_getNumNodes(); @@ -336,14 +345,14 @@ void globallyCombineNonUniformSubArrays( for (int m=0; m 0) { qindex recvInd = globalRecvIndPerRank[sendRank] + (numBigMsgs * bigMsgSize); requests.push_back(MPI_REQUEST_NULL); - MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, MPI_COMM_WORLD, &requests.back()); + MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, mpiCommQuest, &requests.back()); } } @@ -639,7 +648,9 @@ void comm_exchangeAmpsToBuffers(Qureg qureg, int pairRank) { void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { #if COMPILE_MPI - MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, 
MPI_COMM_WORLD); + MPI_Comm mpiCommQuest = comm_getMpiComm(); + + MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -650,6 +661,8 @@ void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + // only the sender and root nodes need to continue int recvRank = ROOT_RANK; int myRank = comm_getRank(); @@ -665,8 +678,8 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) for (qindex m=0; m(m); (myRank == sendRank)? - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, MPI_COMM_WORLD, &requests[m]): // sender - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, MPI_COMM_WORLD, &requests[m]); // root + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, mpiCommQuest, &requests[m]): // sender + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, mpiCommQuest, &requests[m]); // root } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -680,9 +693,10 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) void comm_broadcastIntsFromRoot(int* arr, qindex length) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_INT, sendRank, MPI_COMM_WORLD); + MPI_Bcast(arr, length, MPI_INT, sendRank, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -693,8 +707,10 @@ void comm_broadcastIntsFromRoot(int* arr, qindex length) { void comm_broadcastUnsignedsFromRoot(unsigned* arr, qindex length) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, MPI_COMM_WORLD); + MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -721,7 
+737,9 @@ void comm_combineSubArrays(qcomp* recv, vector recvInds, vector void comm_reduceAmp(qcomp* localAmp) { #if COMPILE_MPI - MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, MPI_COMM_WORLD); + MPI_Comm mpiCommQuest = comm_getMpiComm(); + + MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -732,7 +750,9 @@ void comm_reduceAmp(qcomp* localAmp) { void comm_reduceReal(qreal* localReal) { #if COMPILE_MPI - MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, MPI_COMM_WORLD); + MPI_Comm mpiCommQuest = comm_getMpiComm(); + + MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -743,7 +763,9 @@ void comm_reduceReal(qreal* localReal) { void comm_reduceReals(qreal* localReals, qindex numLocalReals) { #if COMPILE_MPI - MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, MPI_COMM_WORLD); + MPI_Comm mpiCommQuest = comm_getMpiComm(); + + MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -754,10 +776,12 @@ void comm_reduceReals(qreal* localReals, qindex numLocalReals) { bool comm_isTrueOnAllNodes(bool val) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + // perform global AND and broadcast result back to all nodes int local = (int) val; int global; - MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, MPI_COMM_WORLD); + MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, mpiCommQuest); return (bool) global; #else @@ -793,6 +817,8 @@ bool comm_isTrueOnRootNode(bool val) { vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) { #if COMPILE_MPI + MPI_Comm mpiCommQuest = comm_getMpiComm(); + // no need to validate array sizes and memory alloc successes; // these are trivial O(#nodes)-size arrays containing <20 chars int numNodes = comm_getNumNodes(); @@ -803,7 +829,7 @@ vector 
comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) // all nodes send root all their local chars int recvRank = ROOT_RANK; MPI_Gather(localChars, maxNumLocalChars, MPI_CHAR, allChars.data(), - maxNumLocalChars, MPI_CHAR, recvRank, MPI_COMM_WORLD); + maxNumLocalChars, MPI_CHAR, recvRank, mpiCommQuest); // divide allChars into stings, delimited by each node's terminal char vector out(numNodes); diff --git a/quest/src/core/errors.cpp b/quest/src/core/errors.cpp index 9e72b1e0..1791d7e9 100644 --- a/quest/src/core/errors.cpp +++ b/quest/src/core/errors.cpp @@ -181,6 +181,11 @@ void error_commNumMessagesExceedTagMax() { raiseInternalError("A function attempted to communicate via more messages than permitted (since there would be more uniquely-tagged messages than the tag upperbound)."); } +void error_commDoubleSetMpiComm() { + + raiseInternalError("An attempt was made to set mpiCommQuest after it had already been set, as indicated by mpiCommQuest != MPI_COMM_NULL."); +} + void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps) { bool valid = ( diff --git a/quest/src/core/errors.hpp b/quest/src/core/errors.hpp index 950ac17e..57feedf3 100644 --- a/quest/src/core/errors.hpp +++ b/quest/src/core/errors.hpp @@ -85,6 +85,8 @@ void error_commGivenInconsistentNumSubArraysANodes(); void error_commNumMessagesExceedTagMax(); +void error_commDoubleSetMpiComm(); + void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps); void assert_commPayloadIsPowerOf2(qindex numAmps); @@ -383,4 +385,4 @@ void error_unexpectedNumLindbladSuperpropTerms(); -#endif // ERRORS_HPP \ No newline at end of file +#endif // ERRORS_HPP diff --git a/quest/src/core/validation.cpp b/quest/src/core/validation.cpp index 3ac48505..85b58db9 100644 --- a/quest/src/core/validation.cpp +++ b/quest/src/core/validation.cpp @@ -1149,8 +1149,9 @@ void default_inputErrorHandler(const char* func, const char* msg) { comm_sync(); // finalise 
MPI before error-exit to avoid scaring user with giant MPI error message + // we always "take ownership" of MPI here since we're about to kill the whole program if (comm_isInit()) - comm_end(); + comm_end(0); // simply exit, interrupting any other process (potentially leaking) exit(EXIT_FAILURE); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 31b0ea75..c3262235 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,6 +6,10 @@ add_executable(tests target_link_libraries(tests PRIVATE QuEST::QuEST Catch2::Catch2) target_compile_features(tests PUBLIC cxx_std_20) +if (ENABLE_DISTRIBUTION AND ENABLE_SUBCOMM) + target_link_libraries(tests PRIVATE MPI::MPI_CXX) +endif() + add_subdirectory(unit) add_subdirectory(utils) add_subdirectory(integration) diff --git a/tests/unit/environment.cpp b/tests/unit/environment.cpp index 6d4efb80..344ac586 100644 --- a/tests/unit/environment.cpp +++ b/tests/unit/environment.cpp @@ -83,6 +83,24 @@ TEST_CASE( "initCustomQuESTEnv", TEST_CATEGORY ) { } +TEST_CASE( "initCustomMpiQuESTEnv", TEST_CATEGORY ) { + + SECTION( LABEL_CORRECTNESS ) { + + // cannot be meaningfully tested since env already active + SUCCEED( ); + } + + SECTION( LABEL_VALIDATION ) { + + REQUIRE_THROWS_WITH( initCustomMpiQuESTEnv(0,0,0,0), ContainsSubstring( "already been initialised") ); + + // cannot check arguments since env-already-initialised + // validation is performed first + } +} + + TEST_CASE( "finalizeQuESTEnv", TEST_CATEGORY ) { SECTION( LABEL_CORRECTNESS ) { @@ -143,6 +161,7 @@ TEST_CASE( "getQuESTEnv", TEST_CATEGORY ) { REQUIRE( (env.isMultithreaded == 0 || env.isMultithreaded == 1) ); REQUIRE( (env.isGpuAccelerated == 0 || env.isGpuAccelerated == 1) ); REQUIRE( (env.isDistributed == 0 || env.isDistributed == 1) ); + REQUIRE( (env.userOwnsMpi == 0 || env.userOwnsMpi == 1) ); REQUIRE( (env.isCuQuantumEnabled == 0 || env.isCuQuantumEnabled == 1) ); REQUIRE( (env.isGpuSharingEnabled == 0 || env.isGpuSharingEnabled == 1) );