From 1c87a3cf96db9d11762ff9cb8ab662493b63cf21 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown Date: Fri, 27 Feb 2026 15:53:08 +0000 Subject: [PATCH 01/18] Added ENABLE_SUBCOMM build option --- CMakeLists.txt | 28 ++++++++++++++++++++++++---- quest/include/config.h.in | 4 ++++ 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5d308795..df82112d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,6 +145,12 @@ option( ) message(STATUS "Distribution is turned ${ENABLE_DISTRIBUTION}. Set ENABLE_DISTRIBUTION to modify.") +option( + ENABLE_SUBCOMM + "Whether QuEST will be built with support for restricting it to a user-defined MPI communicator. Turned OFF by default." + OFF +) +message(STATUS "Custom communicator support is turned ${ENABLE_SUBCOMM}. Set ENABLE_SUBCOMM to modify.") # GPU Acceleration option( @@ -206,6 +212,9 @@ if (ENABLE_CUQUANTUM AND NOT ENABLE_CUDA) message(FATAL_ERROR "Use of cuQuantum requires CUDA.") endif() +if (ENABLE_SUBCOMM AND NOT ENABLE_DISTRIBUTION) + message(FATAL_ERROR "Distribution must be enabled to make use of a user-defined communicator for QuEST.") +endif() if(WIN32) @@ -381,12 +390,22 @@ endif() # MPI if (ENABLE_DISTRIBUTION) find_package(MPI REQUIRED + # Component CXX is the C api usable from C++ + # NOT the deprecated C++ API COMPONENTS CXX ) - target_link_libraries(QuEST - PRIVATE - MPI::MPI_CXX - ) + + if(ENABLE_SUBCOMM) + target_link_libraries(QuEST + PUBLIC + MPI::MPI_CXX + ) + else() + target_link_libraries(QuEST + PRIVATE + MPI::MPI_CXX + ) + endif() endif() @@ -446,6 +465,7 @@ endif() # set vars which will be written to config.h.in (auto-converted to 0 or 1) set(COMPILE_OPENMP ${ENABLE_MULTITHREADING}) set(COMPILE_MPI ${ENABLE_DISTRIBUTION}) +set(COMPILE_SUBCOMM ${ENABLE_SUBCOMM}) set(COMPILE_CUQUANTUM ${ENABLE_CUQUANTUM}) set(INCLUDE_DEPRECATED_FUNCTIONS ${ENABLE_DEPRECATED_API}) diff --git a/quest/include/config.h.in b/quest/include/config.h.in index 2cb12fa9..070ecf29 100644 --- a/quest/include/config.h.in +++ b/quest/include/config.h.in @@ -37,6 +37,7 @@ #if defined(FLOAT_PRECISION) || \ defined(COMPILE_OPENMP) || \ defined(COMPILE_MPI) || \ + defined(COMPILE_SUBCOMM) || \ defined(COMPILE_CUDA) || \ defined(COMPILE_HIP) || \ defined(COMPILE_CUQUANTUM) || \ @@ -79,6 +80,7 @@ // crucial to QuEST source (informs external library usage) #cmakedefine01 COMPILE_OPENMP #cmakedefine01 COMPILE_MPI +#cmakedefine01 COMPILE_SUBCOMM #cmakedefine01 COMPILE_CUDA #cmakedefine01 COMPILE_CUQUANTUM @@ -118,6 +120,7 @@ #if ! defined(FLOAT_PRECISION) || \ ! defined(COMPILE_OPENMP) || \ ! defined(COMPILE_MPI) || \ + ! defined(COMPILE_SUBCOMM) || \ ! defined(COMPILE_CUDA) || \ ! defined(COMPILE_HIP) || \ ! defined(COMPILE_CUQUANTUM) || \ @@ -144,6 +147,7 @@ #if ! (COMPILE_OPENMP == 0 || COMPILE_OPENMP == 1) || \ ! (COMPILE_MPI == 0 || COMPILE_MPI == 1) || \ + ! (COMPILE_SUBCOMM == 0 || COMPILE_SUBCOMM == 1) || \ ! (COMPILE_CUDA == 0 || COMPILE_CUDA == 1) || \ ! (COMPILE_HIP == 0 || COMPILE_HIP == 1) || \ ! (COMPILE_CUQUANTUM == 0 || COMPILE_CUQUANTUM == 1) || \ From 19ac32042a9352610ca34752940f63738ae39769 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown Date: Fri, 27 Feb 2026 15:53:49 +0000 Subject: [PATCH 02/18] Moved from MPI_COMM_WORLD to mpiQuestComm --- quest/src/comm/comm_config.cpp | 21 ++++++++--- quest/src/comm/comm_config.hpp | 6 ++++ quest/src/comm/comm_routines.cpp | 60 +++++++++++++++++++++++--------- 3 files changed, 66 insertions(+), 21 deletions(-) diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index 854a12bd..d5e448f1 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -20,6 +20,8 @@ #if COMPILE_MPI #include + + static MPI_Comm mpiCommQuest; #endif @@ -60,6 +62,9 @@ bool comm_isMpiCompiled() { return (bool) COMPILE_MPI; } +bool comm_isMpiSubCommunicatorCompiled() { + return (bool) COMPILE_SUBCOMM; +} bool comm_isMpiGpuAware() { @@ -106,6 +111,7 @@ void comm_init() { error_commAlreadyInit(); MPI_Init(NULL, NULL); + MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); #endif } @@ -118,7 +124,8 @@ void comm_end() { if (!comm_isInit()) return; - MPI_Barrier(MPI_COMM_WORLD); + MPI_Barrier(mpiCommQuest); + MPI_Comm_free(&mpiCommQuest); MPI_Finalize(); #endif @@ -135,7 +142,7 @@ int comm_getRank() { return ROOT_RANK; int rank; - MPI_Comm_rank(MPI_COMM_WORLD, &rank); + MPI_Comm_rank(mpiCommQuest, &rank); return rank; #else @@ -164,7 +171,7 @@ int comm_getNumNodes() { return 1; int numNodes; - MPI_Comm_size(MPI_COMM_WORLD, &numNodes); + MPI_Comm_size(mpiCommQuest, &numNodes); return numNodes; #else @@ -182,6 +189,12 @@ void comm_sync() { if (!comm_isInit()) return; - MPI_Barrier(MPI_COMM_WORLD); + MPI_Barrier(mpiCommQuest); #endif } + +#if COMPILE_MPI + MPI_Comm * getMpiComm() { + return &mpiCommQuest; + } +#endif diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index 444d1dbf..980bffba 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -10,6 +10,9 @@ #ifndef COMM_CONFIG_HPP #define COMM_CONFIG_HPP +#if COMPILE_MPI + #include +#endif constexpr int ROOT_RANK = 0; @@ -27,5 +30,8 @@ bool comm_isInit(); bool comm_isRootNode(); bool comm_isRootNode(int rank); +#if COMPILE_MPI + MPI_Comm * getMpiComm(); +#endif #endif // COMM_CONFIG_HPP \ No newline at end of file diff --git a/quest/src/comm/comm_routines.cpp b/quest/src/comm/comm_routines.cpp index 19ebcb9f..90e93808 100644 --- a/quest/src/comm/comm_routines.cpp +++ b/quest/src/comm/comm_routines.cpp @@ -149,7 +149,8 @@ int getMaxNumMessages() { // messages. Beware the max is obtained via a void pointer and might be unset... void* tagUpperBoundPtr; int isAttribSet; - MPI_Comm_get_attr(MPI_COMM_WORLD, MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); + MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm_get_attr(*mpiCommQuest, MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); // if something went wrong with obtaining the tag bound, return the safe minimum if (!isAttribSet) @@ -216,6 +217,8 @@ std::array dividePayloadIntoMessages(qindex numAmps) { void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + // each message is asynchronously dispatched with a final wait, as per arxiv.org/abs/2308.07402 // we will send payload in multiple asynch messages (create two requests per msg for subsequent synch) @@ -226,8 +229,8 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { // so that messages are permitted to arrive out-of-order (supporting UCX adaptive-routing) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m]); - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[2*m+1]); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &requests[2*m]); + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &requests[2*m+1]); } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -248,6 +251,8 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + // we will not track nor wait for the asynch send; instead, the caller will later comm_sync() MPI_Request nullReq = MPI_REQUEST_NULL; @@ -257,7 +262,7 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { // asynchronously send the uniquely-tagged messages for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &nullReq); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &nullReq); } #else @@ -269,6 +274,8 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { void receiveArray(qcomp* dest, qindex numElems, int pairRank) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + // expect the data in multiple messages auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems); @@ -278,7 +285,7 @@ void receiveArray(qcomp* dest, qindex numElems, int pairRank) { // listen to receive each uniquely-tagged message asynchronously (as per arxiv.org/abs/2308.07402) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, MPI_COMM_WORLD, &requests[m]); + MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &requests[m]); } // receivers wait for all messages to be received (while sender asynch proceeds) @@ -303,6 +310,8 @@ void globallyCombineNonUniformSubArrays( ) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + int myRank = comm_getRank(); int numNodes = comm_getNumNodes(); @@ -336,14 +345,14 @@ void globallyCombineNonUniformSubArrays( for (int m=0; m 0) { qindex recvInd = globalRecvIndPerRank[sendRank] + (numBigMsgs * bigMsgSize); requests.push_back(MPI_REQUEST_NULL); - MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, MPI_COMM_WORLD, &requests.back()); + MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, *mpiCommQuest, &requests.back()); } } @@ -639,7 +648,9 @@ void comm_exchangeAmpsToBuffers(Qureg qureg, int pairRank) { void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { #if COMPILE_MPI - MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, MPI_COMM_WORLD); + MPI_Comm * mpiCommQuest = getMpiComm(); + + MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, *mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -650,6 +661,8 @@ void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + // only the sender and root nodes need to continue int recvRank = ROOT_RANK; int myRank = comm_getRank(); @@ -665,8 +678,8 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) for (qindex m=0; m(m); (myRank == sendRank)? - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, MPI_COMM_WORLD, &requests[m]): // sender - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, MPI_COMM_WORLD, &requests[m]); // root + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, *mpiCommQuest, &requests[m]): // sender + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, *mpiCommQuest, &requests[m]); // root } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -680,9 +693,10 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) void comm_broadcastIntsFromRoot(int* arr, qindex length) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_INT, sendRank, MPI_COMM_WORLD); + MPI_Bcast(arr, length, MPI_INT, sendRank, *mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -693,8 +707,10 @@ void comm_broadcastIntsFromRoot(int* arr, qindex length) { void comm_broadcastUnsignedsFromRoot(unsigned* arr, qindex length) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, MPI_COMM_WORLD); + MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, *mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -721,7 +737,9 @@ void comm_combineSubArrays(qcomp* recv, vector recvInds, vector void comm_reduceAmp(qcomp* localAmp) { #if COMPILE_MPI - MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, MPI_COMM_WORLD); + MPI_Comm * mpiCommQuest = getMpiComm(); + + MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, *mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -732,7 +750,9 @@ void comm_reduceAmp(qcomp* localAmp) { void comm_reduceReal(qreal* localReal) { #if COMPILE_MPI - MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, MPI_COMM_WORLD); + MPI_Comm * mpiCommQuest = getMpiComm(); + + MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, *mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -743,7 +763,9 @@ void comm_reduceReal(qreal* localReal) { void comm_reduceReals(qreal* localReals, qindex numLocalReals) { #if COMPILE_MPI - MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, MPI_COMM_WORLD); + MPI_Comm * mpiCommQuest = getMpiComm(); + + MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, *mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -754,10 +776,12 @@ void comm_reduceReals(qreal* localReals, qindex numLocalReals) { bool comm_isTrueOnAllNodes(bool val) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + // perform global AND and broadcast result back to all nodes int local = (int) val; int global; - MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, MPI_COMM_WORLD); + MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, *mpiCommQuest); return (bool) global; #else @@ -793,6 +817,8 @@ bool comm_isTrueOnRootNode(bool val) { vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) { #if COMPILE_MPI + MPI_Comm * mpiCommQuest = getMpiComm(); + // no need to validate array sizes and memory alloc successes; // these are trivial O(#nodes)-size arrays containing <20 chars int numNodes = comm_getNumNodes(); @@ -803,7 +829,7 @@ vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) // all nodes send root all their local chars int recvRank = ROOT_RANK; MPI_Gather(localChars, maxNumLocalChars, MPI_CHAR, allChars.data(), - maxNumLocalChars, MPI_CHAR, recvRank, MPI_COMM_WORLD); + maxNumLocalChars, MPI_CHAR, recvRank, *mpiCommQuest); // divide allChars into stings, delimited by each node's terminal char vector out(numNodes); From 95415b7cc3a1631b1096750a564dd2b17b263210 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Fri, 27 Feb 2026 17:36:21 +0000 Subject: [PATCH 03/18] Decided passing *MPI_Comm was probably overly cautious, and updated function name to comm_getMpiComm --- quest/src/comm/comm_config.cpp | 4 +-- quest/src/comm/comm_config.hpp | 2 +- quest/src/comm/comm_routines.cpp | 62 ++++++++++++++++---------------- 3 files changed, 34 insertions(+), 34 deletions(-) diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index d5e448f1..7ec95b0b 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -194,7 +194,7 @@ void comm_sync() { } #if COMPILE_MPI - MPI_Comm * getMpiComm() { - return &mpiCommQuest; + MPI_Comm comm_getMpiComm() { + return mpiCommQuest; } #endif diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index 980bffba..f17de9d2 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -31,7 +31,7 @@ bool comm_isRootNode(); bool comm_isRootNode(int rank); #if COMPILE_MPI - MPI_Comm * getMpiComm(); + MPI_Comm comm_getMpiComm(); #endif #endif // COMM_CONFIG_HPP \ No newline at end of file diff --git a/quest/src/comm/comm_routines.cpp b/quest/src/comm/comm_routines.cpp index 90e93808..335dceab 100644 --- a/quest/src/comm/comm_routines.cpp +++ b/quest/src/comm/comm_routines.cpp @@ -149,8 +149,8 @@ int getMaxNumMessages() { // messages. Beware the max is obtained via a void pointer and might be unset... void* tagUpperBoundPtr; int isAttribSet; - MPI_Comm * mpiCommQuest = getMpiComm(); - MPI_Comm_get_attr(*mpiCommQuest, MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); + MPI_Comm mpiCommQuest = comm_getMpiComm(); + MPI_Comm_get_attr(mpiCommQuest, MPI_TAG_UB, &tagUpperBoundPtr, &isAttribSet); // if something went wrong with obtaining the tag bound, return the safe minimum if (!isAttribSet) @@ -217,7 +217,7 @@ std::array dividePayloadIntoMessages(qindex numAmps) { void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); // each message is asynchronously dispatched with a final wait, as per arxiv.org/abs/2308.07402 @@ -229,8 +229,8 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { // so that messages are permitted to arrive out-of-order (supporting UCX adaptive-routing) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &requests[2*m]); - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &requests[2*m+1]); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m]); + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m+1]); } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -251,7 +251,7 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); // we will not track nor wait for the asynch send; instead, the caller will later comm_sync() MPI_Request nullReq = MPI_REQUEST_NULL; @@ -262,7 +262,7 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { // asynchronously send the uniquely-tagged messages for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &nullReq); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &nullReq); } #else @@ -274,7 +274,7 @@ void asynchSendArray(qcomp* send, qindex numElems, int pairRank) { void receiveArray(qcomp* dest, qindex numElems, int pairRank) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); // expect the data in multiple messages auto [messageSize, numMessages] = dividePow2PayloadIntoMessages(numElems); @@ -285,7 +285,7 @@ void receiveArray(qcomp* dest, qindex numElems, int pairRank) { // listen to receive each uniquely-tagged message asynchronously (as per arxiv.org/abs/2308.07402) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, *mpiCommQuest, &requests[m]); + MPI_Irecv(&dest[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[m]); } // receivers wait for all messages to be received (while sender asynch proceeds) @@ -310,7 +310,7 @@ void globallyCombineNonUniformSubArrays( ) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); int myRank = comm_getRank(); int numNodes = comm_getNumNodes(); @@ -345,14 +345,14 @@ void globallyCombineNonUniformSubArrays( for (int m=0; m 0) { qindex recvInd = globalRecvIndPerRank[sendRank] + (numBigMsgs * bigMsgSize); requests.push_back(MPI_REQUEST_NULL); - MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, *mpiCommQuest, &requests.back()); + MPI_Ibcast(&recv[recvInd], remMsgSize, MPI_QCOMP, sendRank, mpiCommQuest, &requests.back()); } } @@ -648,9 +648,9 @@ void comm_exchangeAmpsToBuffers(Qureg qureg, int pairRank) { void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); - MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, *mpiCommQuest); + MPI_Bcast(sendAmp, 1, MPI_QCOMP, sendRank, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -661,7 +661,7 @@ void comm_broadcastAmp(int sendRank, qcomp* sendAmp) { void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); // only the sender and root nodes need to continue int recvRank = ROOT_RANK; @@ -678,8 +678,8 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) for (qindex m=0; m(m); (myRank == sendRank)? - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, *mpiCommQuest, &requests[m]): // sender - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, *mpiCommQuest, &requests[m]); // root + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, recvRank, tag, mpiCommQuest, &requests[m]): // sender + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, sendRank, tag, mpiCommQuest, &requests[m]); // root } // wait for all exchanges to complete (MPI will automatically free the request memory) @@ -693,10 +693,10 @@ void comm_sendAmpsToRoot(int sendRank, qcomp* send, qcomp* recv, qindex numAmps) void comm_broadcastIntsFromRoot(int* arr, qindex length) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_INT, sendRank, *mpiCommQuest); + MPI_Bcast(arr, length, MPI_INT, sendRank, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -707,10 +707,10 @@ void comm_broadcastIntsFromRoot(int* arr, qindex length) { void comm_broadcastUnsignedsFromRoot(unsigned* arr, qindex length) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); int sendRank = ROOT_RANK; - MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, *mpiCommQuest); + MPI_Bcast(arr, length, MPI_UNSIGNED, sendRank, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -737,9 +737,9 @@ void comm_combineSubArrays(qcomp* recv, vector recvInds, vector void comm_reduceAmp(qcomp* localAmp) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); - MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, *mpiCommQuest); + MPI_Allreduce(MPI_IN_PLACE, localAmp, 1, MPI_QCOMP, MPI_SUM, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -750,9 +750,9 @@ void comm_reduceAmp(qcomp* localAmp) { void comm_reduceReal(qreal* localReal) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); - MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, *mpiCommQuest); + MPI_Allreduce(MPI_IN_PLACE, localReal, 1, MPI_QREAL, MPI_SUM, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -763,9 +763,9 @@ void comm_reduceReal(qreal* localReal) { void comm_reduceReals(qreal* localReals, qindex numLocalReals) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); - MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, *mpiCommQuest); + MPI_Allreduce(MPI_IN_PLACE, localReals, numLocalReals, MPI_QREAL, MPI_SUM, mpiCommQuest); #else error_commButEnvNotDistributed(); @@ -776,12 +776,12 @@ void comm_reduceReals(qreal* localReals, qindex numLocalReals) { bool comm_isTrueOnAllNodes(bool val) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); // perform global AND and broadcast result back to all nodes int local = (int) val; int global; - MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, *mpiCommQuest); + MPI_Allreduce(&local, &global, 1, MPI_INT, MPI_LAND, mpiCommQuest); return (bool) global; #else @@ -817,7 +817,7 @@ bool comm_isTrueOnRootNode(bool val) { vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) { #if COMPILE_MPI - MPI_Comm * mpiCommQuest = getMpiComm(); + MPI_Comm mpiCommQuest = comm_getMpiComm(); // no need to validate array sizes and memory alloc successes; // these are trivial O(#nodes)-size arrays containing <20 chars @@ -829,7 +829,7 @@ vector comm_gatherStringsToRoot(char* localChars, int maxNumLocalChars) // all nodes send root all their local chars int recvRank = ROOT_RANK; MPI_Gather(localChars, maxNumLocalChars, MPI_CHAR, allChars.data(), - maxNumLocalChars, MPI_CHAR, recvRank, *mpiCommQuest); + maxNumLocalChars, MPI_CHAR, recvRank, mpiCommQuest); // divide allChars into stings, delimited by each node's terminal char vector out(numNodes); From 0699e58b66ac6e068a26292b6550aa9b263b9ad1 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:55:39 +0000 Subject: [PATCH 04/18] environment.cpp: added methods to reset rank and numNodes, and reporting for subcomm compiled --- quest/src/api/environment.cpp | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/quest/src/api/environment.cpp b/quest/src/api/environment.cpp index 54149189..543ffd2f 100644 --- a/quest/src/api/environment.cpp +++ b/quest/src/api/environment.cpp @@ -153,6 +153,11 @@ void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMulti globalEnvPtr->numNodes = (useDistrib)? comm_getNumNodes() : 1; } +void updateQuESTEnvDistInfo() { + globalEnvPtr->rank = (globalEnvPtr->isDistributed)? comm_getRank() : 0; + globalEnvPtr->numNodes = (globalEnvPtr->isDistributed)? comm_getNumNodes() : 1; + return; +} /* @@ -187,10 +192,11 @@ void printCompilationInfo() { print_table( "compilation", { - {"isMpiCompiled", comm_isMpiCompiled()}, - {"isGpuCompiled", gpu_isGpuCompiled()}, - {"isOmpCompiled", cpu_isOpenmpCompiled()}, - {"isCuQuantumCompiled", gpu_isCuQuantumCompiled()}, + {"isMpiCompiled", comm_isMpiCompiled()}, + {"isMpiSubCommunicatorCompiled", comm_isMpiSubCommunicatorCompiled()}, + {"isGpuCompiled", gpu_isGpuCompiled()}, + {"isOmpCompiled", cpu_isOpenmpCompiled()}, + {"isCuQuantumCompiled", gpu_isCuQuantumCompiled()}, }); } @@ -454,8 +460,12 @@ void syncQuESTEnv() { if (globalEnvPtr->isGpuAccelerated) gpu_sync(); - if (globalEnvPtr->isDistributed) + if (globalEnvPtr->isDistributed) { comm_sync(); + #if COMPILE_SUBCOMM + updateQuESTEnvDistInfo(); + #endif + } } From bfec27935b51d1d051752f10a5b167bea27e6b46 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:56:20 +0000 Subject: [PATCH 05/18] comm_config.hpp/cpp: added comm_setMpiComm --- quest/src/comm/comm_config.cpp | 13 +++++++++++++ quest/src/comm/comm_config.hpp | 4 ++++ 2 files changed, 17 insertions(+) diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index 7ec95b0b..90221cd7 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -197,4 +197,17 @@ void comm_sync() { MPI_Comm comm_getMpiComm() { return mpiCommQuest; } + + #if COMPILE_SUBCOMM + void comm_setMpiComm(MPI_Comm newComm) { + if (mpiCommQuest != MPI_COMM_NULL) { + MPI_Barrier(mpiCommQuest); + MPI_Comm_free(&mpiCommQuest); + } + + MPI_Comm_dup(newComm, &mpiCommQuest); + + return; + } + #endif #endif diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index f17de9d2..c772c4d6 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -17,6 +17,7 @@ constexpr int ROOT_RANK = 0; bool comm_isMpiCompiled(); +bool comm_isMpiSubCommunicatorCompiled(); bool comm_isMpiGpuAware(); void comm_init(); @@ -32,6 +33,9 @@ bool comm_isRootNode(int rank); #if COMPILE_MPI MPI_Comm comm_getMpiComm(); + #if COMPILE_SUBCOMM + void comm_setMpiComm(MPI_Comm newComm); + #endif #endif #endif // COMM_CONFIG_HPP \ No newline at end of file From 3a3bdffdb10c2f5576f660b00caeeaad76a86801 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 2 Mar 2026 16:45:36 +0000 Subject: [PATCH 06/18] CMakeLists.txt: PUBLIC MPI::MPI_CXX turned out to be unhelpful, even for SubComm, because of course it enforces CXX --- CMakeLists.txt | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index df82112d..70ec1165 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -395,17 +395,10 @@ if (ENABLE_DISTRIBUTION) COMPONENTS CXX ) - if(ENABLE_SUBCOMM) - target_link_libraries(QuEST - PUBLIC - MPI::MPI_CXX - ) - else() - target_link_libraries(QuEST - PRIVATE - MPI::MPI_CXX - ) - endif() + target_link_libraries(QuEST + PRIVATE + MPI::MPI_CXX + ) endif() From 33585817bb2b48c9468dd3cc7e5a23cc7fd34a4a Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:08:29 +0000 Subject: [PATCH 07/18] Added new custom QuESTEnv initialiser which allow user to positively declare that they take ownership of MPI --- quest/include/environment.h | 3 +++ quest/src/api/environment.cpp | 21 +++++++++++++++------ quest/src/comm/comm_config.cpp | 19 ++++++++++++++----- quest/src/comm/comm_config.hpp | 4 ++-- 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/quest/include/environment.h b/quest/include/environment.h index 04f24bfe..dba9cd06 100644 --- a/quest/include/environment.h +++ b/quest/include/environment.h @@ -36,6 +36,7 @@ typedef struct { int isMultithreaded; int isGpuAccelerated; int isDistributed; + int userOwnsMpi; // deployment modes which cannot be directly changed after compilation int isCuQuantumEnabled; @@ -61,6 +62,8 @@ void initQuESTEnv(); */ void initCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMultithread); +void initCustomMpiQuESTEnv(int useDistrib, int userOwnsMpi, int useGpuAccel, int useMultithread); + /// @notyetdoced void finalizeQuESTEnv(); diff --git a/quest/src/api/environment.cpp b/quest/src/api/environment.cpp index 543ffd2f..50e95c61 100644 --- a/quest/src/api/environment.cpp +++ b/quest/src/api/environment.cpp @@ -71,7 +71,7 @@ static bool hasEnvBeenFinalized = false; */ -void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMultithread, const char* caller) { +void validateAndInitCustomQuESTEnv(int useDistrib, int userOwnsMpi, int useGpuAccel, int useMultithread, const char* caller) { // ensure that we are never re-initialising QuEST (even after finalize) because // this leads to undefined behaviour in distributed mode, as per the MPI @@ -94,7 +94,7 @@ void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMulti // perform that specifically upon the MPI-process-bound GPU(s). Further, // we can make sure validation errors are reported only by the root node. if (useDistrib) - comm_init(); + comm_init(userOwnsMpi); validate_newEnvDistributedBetweenPower2Nodes(caller); @@ -145,6 +145,7 @@ void validateAndInitCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMulti globalEnvPtr->isMultithreaded = useMultithread; globalEnvPtr->isGpuAccelerated = useGpuAccel; globalEnvPtr->isDistributed = useDistrib; + globalEnvPtr->userOwnsMpi = userOwnsMpi; globalEnvPtr->isCuQuantumEnabled = useCuQuantum; globalEnvPtr->isGpuSharingEnabled = permitGpuSharing; @@ -206,6 +207,7 @@ void printDeploymentInfo() { print_table( "deployment", { {"isMpiEnabled", globalEnvPtr->isDistributed}, + {"doesUserOwnMpi", globalEnvPtr->userOwnsMpi}, {"isGpuEnabled", globalEnvPtr->isGpuAccelerated}, {"isOmpEnabled", globalEnvPtr->isMultithreaded}, {"isCuQuantumEnabled", globalEnvPtr->isCuQuantumEnabled}, @@ -403,13 +405,19 @@ extern "C" { void initCustomQuESTEnv(int useDistrib, int useGpuAccel, int useMultithread) { - validateAndInitCustomQuESTEnv(useDistrib, useGpuAccel, useMultithread, __func__); + const int USER_OWNS_MPI = 0; + validateAndInitCustomQuESTEnv(useDistrib, USER_OWNS_MPI, useGpuAccel, useMultithread, __func__); } +void initCustomMpiQuESTEnv(int useDistrib, int userOwnsMpi, int useGpuAccel, int useMultithread) { + validateAndInitCustomQuESTEnv(useDistrib, userOwnsMpi, useGpuAccel, useMultithread, __func__); +} + void initQuESTEnv() { - validateAndInitCustomQuESTEnv(modeflag::USE_AUTO, modeflag::USE_AUTO, modeflag::USE_AUTO, __func__); + const int USER_OWNS_MPI = 0; + validateAndInitCustomQuESTEnv(modeflag::USE_AUTO, USER_OWNS_MPI, modeflag::USE_AUTO, modeflag::USE_AUTO, __func__); } @@ -442,7 +450,7 @@ void finalizeQuESTEnv() { if (globalEnvPtr->isDistributed) { comm_sync(); - comm_end(); + comm_end(globalEnvPtr->userOwnsMpi); } // free global env's heap memory and flag it as unallocated @@ -508,10 +516,11 @@ void getEnvironmentString(char str[200]) { int cuQuantum = env.isGpuAccelerated && gpu_isCuQuantumCompiled(); int gpuDirect = env.isGpuAccelerated && gpu_isDirectGpuCommPossible(); - snprintf(str, 200, "CUDA=%d OpenMP=%d MPI=%d threads=%d ranks=%d cuQuantum=%d gpuDirect=%d", + snprintf(str, 200, "CUDA=%d OpenMP=%d MPI=%d userOwnsMPI=%d threads=%d ranks=%d cuQuantum=%d gpuDirect=%d", env.isGpuAccelerated, env.isMultithreaded, env.isDistributed, + env.userOwnsMpi, numThreads, env.numNodes, cuQuantum, diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index 90221cd7..8b559495 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -103,21 +103,27 @@ bool comm_isInit() { } -void comm_init() { +void comm_init(int userOwnsMpi) { #if COMPILE_MPI // error if attempting re-initialisation - if (comm_isInit()) + if (!userOwnsMpi && comm_isInit()) error_commAlreadyInit(); + + // TODO: error if user has not initialised + if (userOwnsMpi && !comm_isInit()); + + // QuEST must initialise MPI if the user does not own it + if (!userOwnsMpi) + MPI_Init(NULL, NULL); - MPI_Init(NULL, NULL); MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); #endif } -void comm_end() { +void comm_end(int userOwnsMpi) { #if COMPILE_MPI // gracefully permit comm_end() before comm_init(), as input validation can trigger @@ -126,7 +132,10 @@ void comm_end() { MPI_Barrier(mpiCommQuest); MPI_Comm_free(&mpiCommQuest); - MPI_Finalize(); + + // QuEST must finalise MPI if the user does not own it + if (!userOwnsMpi) + MPI_Finalize(); #endif } diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index c772c4d6..ea88f2bd 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -20,8 +20,8 @@ bool comm_isMpiCompiled(); bool comm_isMpiSubCommunicatorCompiled(); bool comm_isMpiGpuAware(); -void comm_init(); -void comm_end(); +void comm_init(int userOwnsMpi); +void comm_end(int userOwnsMpi); void comm_sync(); int comm_getRank(); From 6a093d0caca31d7808a0b7190c1db69f7f505c7b Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 2 Mar 2026 18:09:00 +0000 Subject: [PATCH 08/18] validation.cpp: updated comm_end call --- quest/src/core/validation.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/quest/src/core/validation.cpp b/quest/src/core/validation.cpp index 3ac48505..85b58db9 100644 --- a/quest/src/core/validation.cpp +++ b/quest/src/core/validation.cpp @@ -1149,8 +1149,9 @@ void default_inputErrorHandler(const char* func, const char* msg) { comm_sync(); // finalise MPI before error-exit to avoid scaring user with giant MPI error message + // we always "take ownership" of MPI here since we're about to kill the whole program if (comm_isInit()) - comm_end(); + comm_end(0); // simply exit, interrupting any other process (potentially leaking) exit(EXIT_FAILURE); From 588b636b04f1a9157aec614f4779467718ce298a Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown Date: Fri, 6 Mar 2026 14:42:36 +0000 Subject: [PATCH 09/18] comm_config.hpp: added config.h include so COMPILE_MPI is actually defined --- quest/src/comm/comm_config.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index ea88f2bd..e598ee78 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -10,6 +10,8 @@ #ifndef COMM_CONFIG_HPP #define COMM_CONFIG_HPP +#include "quest/include/config.h" + #if COMPILE_MPI #include #endif @@ -38,4 +40,4 @@ bool comm_isRootNode(int rank); #endif #endif -#endif // COMM_CONFIG_HPP \ No newline at end of file +#endif // COMM_CONFIG_HPP From 7c6caac721ccf3de567683eec5e7b70d6a36acf8 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:38:07 +0000 Subject: [PATCH 10/18] subcommunicator.h/cpp: implemented QuESTEnv initialiser with custom MPI_Comm --- quest/include/subcommunicator.h | 22 ++++++++++++++++++++++ quest/src/api/subcommunicator.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 47 insertions(+) create mode 100644 quest/include/subcommunicator.h create mode 100644 quest/src/api/subcommunicator.cpp diff --git a/quest/include/subcommunicator.h b/quest/include/subcommunicator.h new file mode 100644 index 00000000..16653c23 --- /dev/null +++ b/quest/include/subcommunicator.h @@ -0,0 +1,22 @@ +#ifndef SUBCOMMUNICATOR_H +#define SUBCOMMUNICATOR_H + +#include "quest/include/config.h" + +#if COMPILE_MPI && COMPILE_SUBCOMM + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +void initCustomMpiCommQuESTEnv(MPI_Comm questComm, int useGpuAccel, int useMultithread); + +#ifdef __cplusplus +} +#endif + +#endif + +#endif diff --git a/quest/src/api/subcommunicator.cpp b/quest/src/api/subcommunicator.cpp new file mode 100644 index 00000000..aa61b8ed --- /dev/null +++ b/quest/src/api/subcommunicator.cpp @@ -0,0 +1,25 @@ +#include "quest/include/config.h" +#include "quest/include/environment.h" +#include "quest/include/subcommunicator.h" + +#include "quest/src/comm/comm_config.hpp" + +#if COMPILE_MPI && COMPILE_SUBCOMM + +#include + +void initCustomMpiCommQuESTEnv(MPI_Comm userQuestComm, int useGpuAccel, int useMultithread) { + // useDistrib and userOwnsMpi are implied by the user of this initialiser + const int USE_DISTRIB = 1; + const int USER_OWNS_MPI = 1; + + // set mpiCommQuest to user provided communicator + comm_setMpiComm(userQuestComm); + + // initialise QuEST around that communicator + initCustomMpiQuESTEnv(USE_DISTRIB, USER_OWNS_MPI, useGpuAccel, useMultithread); + + return; +} + +#endif From e6628e4e989bb6cf23ddcf6941c92ef2729fd8d2 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:38:40 +0000 Subject: [PATCH 11/18] CMake: added subcommunicator.cpp --- quest/src/api/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/quest/src/api/CMakeLists.txt b/quest/src/api/CMakeLists.txt index 0979f2f6..43b61df7 100644 --- a/quest/src/api/CMakeLists.txt +++ b/quest/src/api/CMakeLists.txt @@ -12,6 +12,7 @@ target_sources(QuEST operations.cpp paulis.cpp qureg.cpp + subcommunicator.cpp trotterisation.cpp types.cpp -) \ No newline at end of file +) From bb2d5f24ef06155fbb77e7ab36c44e1faea9a4f7 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:39:59 +0000 Subject: [PATCH 12/18] comm_config.hpp: added missing config.h include... --- quest/src/comm/comm_config.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/quest/src/comm/comm_config.hpp b/quest/src/comm/comm_config.hpp index ea88f2bd..e598ee78 100644 --- a/quest/src/comm/comm_config.hpp +++ b/quest/src/comm/comm_config.hpp @@ -10,6 +10,8 @@ #ifndef COMM_CONFIG_HPP #define COMM_CONFIG_HPP +#include "quest/include/config.h" + #if COMPILE_MPI #include #endif @@ -38,4 +40,4 @@ bool comm_isRootNode(int rank); #endif #endif -#endif // COMM_CONFIG_HPP \ No newline at end of file +#endif // COMM_CONFIG_HPP From 7c30b8cbf0207e2fcacf7fb5beecffc5beb1987c Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:40:56 +0000 Subject: [PATCH 13/18] comm_config.cpp: explicitly initialise mpiCommQuest to MPI_COMM_NULL, updated setComm for init only workflow --- quest/src/comm/comm_config.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index 8b559495..57c1b7e5 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -21,7 +21,7 @@ #if COMPILE_MPI #include - static MPI_Comm mpiCommQuest; + static MPI_Comm mpiCommQuest = MPI_COMM_NULL; #endif @@ -117,7 +117,12 @@ void comm_init(int userOwnsMpi) { if (!userOwnsMpi) MPI_Init(NULL, NULL); - MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); + // If user is setting their own comm, mpiCommQuest will be NOT MPI_COMM_NULL, + // and we should not touch it. + // If user is NOT setting their own comm, mpiCommQuest will be MPI_COMM_NULL, + // and we should set it to MPI_COMM_WORLD. + if (mpiCommQuest == MPI_COMM_NULL) + MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); #endif } @@ -209,9 +214,8 @@ void comm_sync() { #if COMPILE_SUBCOMM void comm_setMpiComm(MPI_Comm newComm) { + // TODO:error if mpiCommQuEST is already set! if (mpiCommQuest != MPI_COMM_NULL) { - MPI_Barrier(mpiCommQuest); - MPI_Comm_free(&mpiCommQuest); } MPI_Comm_dup(newComm, &mpiCommQuest); From cb72846b873c4e61a51c66747a8b1e68a41714fd Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 9 Mar 2026 14:53:19 +0000 Subject: [PATCH 14/18] quest.h: added subcommunicator header --- quest/include/quest.h | 1 + 1 file changed, 1 insertion(+) diff --git a/quest/include/quest.h b/quest/include/quest.h index 409253ff..8771b712 100644 --- a/quest/include/quest.h +++ b/quest/include/quest.h @@ -45,6 +45,7 @@ #include "quest/include/operations.h" #include "quest/include/paulis.h" #include "quest/include/qureg.h" +#include "quest/include/subcommunicator.h" #include "quest/include/matrices.h" #include "quest/include/wrappers.h" From 7ff2a2e8b5a24431ec3bdd5f4f70fcec0fe9ac7d Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 9 Mar 2026 14:54:43 +0000 Subject: [PATCH 15/18] CMake: added MPI to application binaries when SUBCOMM is enabled --- CMakeLists.txt | 6 +++++- examples/CMakeLists.txt | 4 ++++ tests/CMakeLists.txt | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 70ec1165..43d4607e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -578,6 +578,10 @@ add_executable(min_example ) target_link_libraries(min_example PRIVATE QuEST::QuEST) +if (ENABLE_DISTRIBUTION AND ENABLE_SUBCOMM) + target_link_libraries(min_example PRIVATE MPI::MPI_CXX) +endif() + if (INSTALL_BINARIES) install(TARGETS min_example RUNTIME @@ -754,4 +758,4 @@ install( if(PROJECT_IS_TOP_LEVEL) include(CPack) -endif () \ No newline at end of file +endif () diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index afc8f85d..d21dc5ae 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,6 +20,10 @@ function(add_example direc in_fn) add_executable(${target} ${in_fn}) target_link_libraries(${target} PUBLIC QuEST) + if (ENABLE_DISTRIBUTION AND ENABLE_SUBCOMM) + target_link_libraries(${target} PRIVATE MPI::MPI_CXX) + endif() + if (INSTALL_BINARIES) install( TARGETS ${target} diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 31b0ea75..c3262235 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,6 +6,10 @@ add_executable(tests target_link_libraries(tests PRIVATE QuEST::QuEST Catch2::Catch2) target_compile_features(tests PUBLIC cxx_std_20) +if (ENABLE_DISTRIBUTION AND ENABLE_SUBCOMM) + target_link_libraries(tests PRIVATE MPI::MPI_CXX) +endif() + add_subdirectory(unit) add_subdirectory(utils) add_subdirectory(integration) From 68e435b4ee8d7a0b1fa0986c63c771fc2a55bbe6 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:55:46 +0100 Subject: [PATCH 16/18] comm_routines.cpp: post Irecv before Isend which probably won't do anything but it makes MPI library implementers less nervous --- quest/src/comm/comm_routines.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quest/src/comm/comm_routines.cpp b/quest/src/comm/comm_routines.cpp index 335dceab..14f3e58f 100644 --- a/quest/src/comm/comm_routines.cpp +++ b/quest/src/comm/comm_routines.cpp @@ -229,8 +229,8 @@ void exchangeArrays(qcomp* send, qcomp* recv, qindex numElems, int pairRank) { // so that messages are permitted to arrive out-of-order (supporting UCX adaptive-routing) for (qindex m=0; m(m); // gauranteed int, but m*messageSize needs qindex - MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m]); - MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m+1]); + MPI_Irecv(&recv[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m]); + MPI_Isend(&send[m*messageSize], messageSize, MPI_QCOMP, pairRank, tag, mpiCommQuest, &requests[2*m+1]); } // wait for all exchanges to complete (MPI will automatically free the request memory) From f15def4595ceeec8fb9d42efc93347092e052f4c Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 13 Apr 2026 15:57:14 +0100 Subject: [PATCH 17/18] tests: added new env test for initCustomMpiQuESTEnv --- tests/unit/environment.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/unit/environment.cpp b/tests/unit/environment.cpp index 6d4efb80..344ac586 100644 --- a/tests/unit/environment.cpp +++ b/tests/unit/environment.cpp @@ -83,6 +83,24 @@ TEST_CASE( "initCustomQuESTEnv", TEST_CATEGORY ) { } +TEST_CASE( "initCustomMpiQuESTEnv", TEST_CATEGORY ) { + + SECTION( LABEL_CORRECTNESS ) { + + // cannot be meaningfully tested since env already active + SUCCEED( ); + } + + SECTION( LABEL_VALIDATION ) { + + REQUIRE_THROWS_WITH( initCustomMpiQuESTEnv(0,0,0,0), ContainsSubstring( "already been initialised") ); + + // cannot check arguments since env-already-initialised + // validation is performed first + } +} + + TEST_CASE( "finalizeQuESTEnv", TEST_CATEGORY ) { SECTION( LABEL_CORRECTNESS ) { @@ -143,6 +161,7 @@ TEST_CASE( "getQuESTEnv", TEST_CATEGORY ) { REQUIRE( (env.isMultithreaded == 0 || env.isMultithreaded == 1) ); REQUIRE( (env.isGpuAccelerated == 0 || env.isGpuAccelerated == 1) ); REQUIRE( (env.isDistributed == 0 || env.isDistributed == 1) ); + REQUIRE( (env.userOwnsMpi == 0 || env.userOwnsMpi == 1) ); REQUIRE( (env.isCuQuantumEnabled == 0 || env.isCuQuantumEnabled == 1) ); REQUIRE( (env.isGpuSharingEnabled == 0 || env.isGpuSharingEnabled == 1) ); From dbe2b1edb267a720dd9dda1d6c985ab2d44fc5c6 Mon Sep 17 00:00:00 2001 From: Oliver Thomson Brown <8394906+otbrown@users.noreply.github.com> Date: Mon, 13 Apr 2026 17:22:30 +0100 Subject: [PATCH 18/18] Added error throws to comm_config to cover new scenarios of badness with user owned MPI --- quest/src/comm/comm_config.cpp | 11 +++++++---- quest/src/core/errors.cpp | 5 +++++ quest/src/core/errors.hpp | 4 +++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/quest/src/comm/comm_config.cpp b/quest/src/comm/comm_config.cpp index 57c1b7e5..15e9f98d 100644 --- a/quest/src/comm/comm_config.cpp +++ b/quest/src/comm/comm_config.cpp @@ -110,8 +110,10 @@ void comm_init(int userOwnsMpi) { if (!userOwnsMpi && comm_isInit()) error_commAlreadyInit(); - // TODO: error if user has not initialised - if (userOwnsMpi && !comm_isInit()); + // error if user owns MPI but has not initialised + if (userOwnsMpi && !comm_isInit()) { + error_commNotInit(); + } // QuEST must initialise MPI if the user does not own it if (!userOwnsMpi) @@ -122,7 +124,7 @@ void comm_init(int userOwnsMpi) { // If user is NOT setting their own comm, mpiCommQuest will be MPI_COMM_NULL, // and we should set it to MPI_COMM_WORLD. if (mpiCommQuest == MPI_COMM_NULL) - MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); + MPI_Comm_dup(MPI_COMM_WORLD, &mpiCommQuest); #endif } @@ -214,8 +216,9 @@ void comm_sync() { #if COMPILE_SUBCOMM void comm_setMpiComm(MPI_Comm newComm) { - // TODO:error if mpiCommQuEST is already set! + // error if mpiCommQuEST is already set! if (mpiCommQuest != MPI_COMM_NULL) { + error_commDoubleSetMpiComm(); } MPI_Comm_dup(newComm, &mpiCommQuest); diff --git a/quest/src/core/errors.cpp b/quest/src/core/errors.cpp index 9e72b1e0..1791d7e9 100644 --- a/quest/src/core/errors.cpp +++ b/quest/src/core/errors.cpp @@ -181,6 +181,11 @@ void error_commNumMessagesExceedTagMax() { raiseInternalError("A function attempted to communicate via more messages than permitted (since there would be more uniquely-tagged messages than the tag upperbound)."); } +void error_commDoubleSetMpiComm() { + + raiseInternalError("An attempt was made to set mpiCommQuest after it had already been set, as indicated by mpiCommQuest != MPI_COMM_NULL."); +} + void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps) { bool valid = ( diff --git a/quest/src/core/errors.hpp b/quest/src/core/errors.hpp index 950ac17e..57feedf3 100644 --- a/quest/src/core/errors.hpp +++ b/quest/src/core/errors.hpp @@ -85,6 +85,8 @@ void error_commGivenInconsistentNumSubArraysANodes(); void error_commNumMessagesExceedTagMax(); +void error_commDoubleSetMpiComm(); + void assert_commBoundsAreValid(Qureg qureg, qindex sendInd, qindex recvInd, qindex numAmps); void assert_commPayloadIsPowerOf2(qindex numAmps); @@ -383,4 +385,4 @@ void error_unexpectedNumLindbladSuperpropTerms(); -#endif // ERRORS_HPP \ No newline at end of file +#endif // ERRORS_HPP