diff --git a/lib/ClientConnection.cc b/lib/ClientConnection.cc index cc7e1f67..a135ade7 100644 --- a/lib/ClientConnection.cc +++ b/lib/ClientConnection.cc @@ -33,6 +33,7 @@ #include "ConnectionPool.h" #include "ConsumerImpl.h" #include "ExecutorService.h" +#include "Future.h" #include "LogUtils.h" #include "MockServer.h" #include "OpSendMsg.h" @@ -1029,8 +1030,6 @@ void ClientConnection::newPartitionedMetadataLookup(const std::string& topicName void ClientConnection::newLookup(const SharedBuffer& cmd, uint64_t requestId, const char* requestType, const LookupDataResultPromisePtr& promise) { Lock lock(mutex_); - std::shared_ptr lookupDataResult; - lookupDataResult = std::make_shared(); if (isClosed()) { lock.unlock(); promise->setFailed(ResultNotConnected); @@ -1040,18 +1039,30 @@ void ClientConnection::newLookup(const SharedBuffer& cmd, uint64_t requestId, co promise->setFailed(ResultTooManyLookupRequestException); return; } - LookupRequestData requestData; - requestData.promise = promise; - requestData.timer = executor_->createDeadlineTimer(); - requestData.timer->expires_after(operationsTimeout_); - requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { - handleLookupTimeout(ec, requestData); + + auto request = insertRequest( + pendingLookupRequests_, requestId, [weakSelf{weak_from_this()}, requestId, requestType]() { + if (auto self = weakSelf.lock()) { + LOG_WARN(self->cnxString() + << requestType << " request timeout to broker, req_id: " << requestId); + self->numOfPendingLookupRequest_--; + } + }); + request->getFuture().addListener([promise](Result result, const LookupDataResultPtr& lookupDataResult) { + if (result == ResultOk) { + promise->setValue(lookupDataResult); + } else { + promise->setFailed(result); + } }); - pendingLookupRequests_.insert(std::make_pair(requestId, requestData)); numOfPendingLookupRequest_++; lock.unlock(); LOG_DEBUG(cnxString() << "Inserted lookup request " << requestType << " (req_id: " << requestId << ")"); + if (mockingRequests_.load(std::memory_order_acquire) && mockServer_ != nullptr && + mockServer_->sendRequest(requestType, requestId)) { + return; + } sendCommand(cmd); } @@ -1159,21 +1170,19 @@ Future ClientConnection::sendRequestWithId(const SharedBuf if (isClosed()) { lock.unlock(); - Promise promise; LOG_DEBUG(cnxString() << "Fail " << requestType << "(req_id: " << requestId << ") to a closed connection"); + Promise promise; promise.setFailed(ResultNotConnected); return promise.getFuture(); } - PendingRequestData requestData; - requestData.timer = executor_->createDeadlineTimer(); - requestData.timer->expires_after(operationsTimeout_); - requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { - handleRequestTimeout(ec, requestData); - }); - - pendingRequests_.insert(std::make_pair(requestId, requestData)); + auto request = insertRequest( + pendingRequests_, requestId, + [cnxString{cnxString()}, physicalAddress{physicalAddress_}, requestId, requestType]() { + LOG_WARN(cnxString << "Network request timeout to broker, remote: " << physicalAddress + << ", req_id: " << requestId << ", request: " << requestType); + }); lock.unlock(); LOG_DEBUG(cnxString() << "Inserted request " << requestType << " (req_id: " << requestId << ")"); @@ -1187,31 +1196,7 @@ Future ClientConnection::sendRequestWithId(const SharedBuf } else { sendCommand(cmd); } - return requestData.promise.getFuture(); -} - -void ClientConnection::handleRequestTimeout(const ASIO_ERROR& ec, - const 
PendingRequestData& pendingRequestData) { - if (!ec && !pendingRequestData.hasGotResponse->load()) { - LOG_WARN(cnxString() << "Network request timeout to broker, remote: " << physicalAddress_); - pendingRequestData.promise.setFailed(ResultTimeout); - } -} - -void ClientConnection::handleLookupTimeout(const ASIO_ERROR& ec, - const LookupRequestData& pendingRequestData) { - if (!ec) { - LOG_WARN(cnxString() << "Lookup request timeout to broker, remote: " << physicalAddress_); - pendingRequestData.promise->setFailed(ResultTimeout); - } -} - -void ClientConnection::handleGetLastMessageIdTimeout(const ASIO_ERROR& ec, - const ClientConnection::LastMessageIdRequestData& data) { - if (!ec) { - LOG_WARN(cnxString() << "GetLastMessageId request timeout to broker, remote: " << physicalAddress_); - data.promise->setFailed(ResultTimeout); - } + return request->getFuture(); } void ClientConnection::handleKeepAliveTimeout(const ASIO_ERROR& ec) { @@ -1294,7 +1279,7 @@ const std::future& ClientConnection::close(Result result, bool switchClust cancelTimer(*connectTimer_); lock.unlock(); int refCount = weak_from_this().use_count(); - if (!isResultRetryable(result)) { + if (result != ResultAlreadyClosed /* closed by the pool */ && !isResultRetryable(result)) { LOG_ERROR(cnxString() << "Connection closed with " << result << " (refCnt: " << refCount << ")"); } else { LOG_INFO(cnxString() << "Connection disconnected (refCnt: " << refCount << ")"); @@ -1344,25 +1329,25 @@ const std::future& ClientConnection::close(Result result, bool switchClust connectPromise_.setFailed(result); - // Fail all pending requests, all these type are map whose value type contains the Promise object + // Fail all pending requests after releasing the lock. for (auto& kv : pendingRequests) { - kv.second.fail(result); + kv.second->fail(result); } for (auto& kv : pendingLookupRequests) { - kv.second.fail(result); + kv.second->fail(result); } for (auto& kv : pendingConsumerStatsMap) { LOG_ERROR(cnxString() << " Closing Client Connection, please try again later"); kv.second.setFailed(result); } for (auto& kv : pendingGetLastMessageIdRequests) { - kv.second.fail(result); + kv.second->fail(result); } for (auto& kv : pendingGetNamespaceTopicsRequests) { - kv.second.setFailed(result); + kv.second->fail(result); } for (auto& kv : pendingGetSchemaRequests) { - kv.second.fail(result); + kv.second->fail(result); } return *closeFuture_; } @@ -1406,77 +1391,76 @@ Commands::ChecksumType ClientConnection::getChecksumType() const { Future ClientConnection::newGetLastMessageId(uint64_t consumerId, uint64_t requestId) { Lock lock(mutex_); - auto promise = std::make_shared(); if (isClosed()) { lock.unlock(); LOG_ERROR(cnxString() << " Client is not connected to the broker"); + auto promise = std::make_shared(); promise->setFailed(ResultNotConnected); return promise->getFuture(); } - LastMessageIdRequestData requestData; - requestData.promise = promise; - requestData.timer = executor_->createDeadlineTimer(); - requestData.timer->expires_after(operationsTimeout_); - requestData.timer->async_wait([this, self{shared_from_this()}, requestData](const ASIO_ERROR& ec) { - handleGetLastMessageIdTimeout(ec, requestData); - }); - pendingGetLastMessageIdRequests_.insert(std::make_pair(requestId, requestData)); + auto request = + insertRequest(pendingGetLastMessageIdRequests_, requestId, [cnxString = cnxString(), requestId]() { + LOG_WARN(cnxString << "GetLastMessageId request timeout to broker, req_id: " << requestId); + }); lock.unlock(); + + if 
(mockingRequests_.load(std::memory_order_acquire) && mockServer_ != nullptr && + mockServer_->sendRequest("GET_LAST_MESSAGE_ID", requestId)) { + return request->getFuture(); + } sendCommand(Commands::newGetLastMessageId(consumerId, requestId)); - return promise->getFuture(); + return request->getFuture(); } Future ClientConnection::newGetTopicsOfNamespace( const std::string& nsName, CommandGetTopicsOfNamespace_Mode mode, uint64_t requestId) { Lock lock(mutex_); - Promise promise; if (isClosed()) { lock.unlock(); LOG_ERROR(cnxString() << "Client is not connected to the broker"); + Promise promise; promise.setFailed(ResultNotConnected); return promise.getFuture(); } - pendingGetNamespaceTopicsRequests_.insert(std::make_pair(requestId, promise)); + auto request = + insertRequest(pendingGetNamespaceTopicsRequests_, requestId, [cnxString = cnxString(), requestId]() { + LOG_WARN(cnxString << "GetTopicsOfNamespace request timeout to broker, req_id: " << requestId); + }); lock.unlock(); + if (mockingRequests_.load(std::memory_order_acquire) && mockServer_ != nullptr && + mockServer_->sendRequest("GET_TOPICS_OF_NAMESPACE", requestId)) { + return request->getFuture(); + } sendCommand(Commands::newGetTopicsOfNamespace(nsName, mode, requestId)); - return promise.getFuture(); + return request->getFuture(); } Future ClientConnection::newGetSchema(const std::string& topicName, const std::string& version, uint64_t requestId) { Lock lock(mutex_); - Promise promise; if (isClosed()) { lock.unlock(); LOG_ERROR(cnxString() << "Client is not connected to the broker"); + Promise promise; promise.setFailed(ResultNotConnected); return promise.getFuture(); } - auto timer = executor_->createDeadlineTimer(); - pendingGetSchemaRequests_.emplace(requestId, GetSchemaRequest{promise, timer}); + auto request = + insertRequest(pendingGetSchemaRequests_, requestId, [cnxString = cnxString(), requestId]() { + LOG_WARN(cnxString << "GetSchema request timeout to broker, req_id: " << requestId); + }); lock.unlock(); - timer->expires_after(operationsTimeout_); - timer->async_wait([this, self{shared_from_this()}, requestId](const ASIO_ERROR& ec) { - if (ec) { - return; - } - Lock lock(mutex_); - auto it = pendingGetSchemaRequests_.find(requestId); - if (it != pendingGetSchemaRequests_.end()) { - auto promise = std::move(it->second.promise); - pendingGetSchemaRequests_.erase(it); - lock.unlock(); - promise.setFailed(ResultTimeout); - } - }); - + if (mockingRequests_.load(std::memory_order_acquire) && mockServer_ != nullptr && + mockServer_->sendRequest("GET_SCHEMA", requestId)) { + return request->getFuture(); + } sendCommand(Commands::newGetSchema(topicName, version, requestId)); - return promise.getFuture(); + return request->getFuture(); } void ClientConnection::checkServerError(ServerError error, const std::string& message) { @@ -1541,12 +1525,11 @@ void ClientConnection::handleSuccess(const proto::CommandSuccess& success) { Lock lock(mutex_); auto it = pendingRequests_.find(success.request_id()); if (it != pendingRequests_.end()) { - PendingRequestData requestData = it->second; + auto request = std::move(it->second); pendingRequests_.erase(it); lock.unlock(); - requestData.promise.setValue({}); - cancelTimer(*requestData.timer); + request->complete({}); } } @@ -1558,9 +1541,7 @@ void ClientConnection::handlePartitionedMetadataResponse( Lock lock(mutex_); auto it = pendingLookupRequests_.find(partitionMetadataResponse.request_id()); if (it != pendingLookupRequests_.end()) { - cancelTimer(*it->second.timer); - - 
LookupDataResultPromisePtr lookupDataPromise = it->second.promise; + auto request = std::move(it->second); pendingLookupRequests_.erase(it); numOfPendingLookupRequest_--; lock.unlock(); @@ -1574,17 +1555,17 @@ void ClientConnection::handlePartitionedMetadataResponse( << " error: " << partitionMetadataResponse.error() << " msg: " << partitionMetadataResponse.message()); checkServerError(partitionMetadataResponse.error(), partitionMetadataResponse.message()); - lookupDataPromise->setFailed( + request->fail( getResult(partitionMetadataResponse.error(), partitionMetadataResponse.message())); } else { LOG_ERROR(cnxString() << "Failed partition-metadata lookup req_id: " << partitionMetadataResponse.request_id() << " with empty response: "); - lookupDataPromise->setFailed(ResultConnectError); + request->fail(ResultConnectError); } } else { LookupDataResultPtr lookupResultPtr = std::make_shared(); lookupResultPtr->setPartitions(partitionMetadataResponse.partitions()); - lookupDataPromise->setValue(lookupResultPtr); + request->complete(lookupResultPtr); } } else { @@ -1600,7 +1581,7 @@ void ClientConnection::handleConsumerStatsResponse( Lock lock(mutex_); auto it = pendingConsumerStatsMap_.find(consumerStatsResponse.request_id()); if (it != pendingConsumerStatsMap_.end()) { - Promise consumerStatsPromise = it->second; + auto request = std::move(it->second); pendingConsumerStatsMap_.erase(it); lock.unlock(); @@ -1609,7 +1590,7 @@ void ClientConnection::handleConsumerStatsResponse( LOG_ERROR(cnxString() << " Failed to get consumer stats - " << consumerStatsResponse.error_message()); } - consumerStatsPromise.setFailed( + request.setFailed( getResult(consumerStatsResponse.error_code(), consumerStatsResponse.error_message())); } else { LOG_DEBUG(cnxString() << "ConsumerStatsResponse command - Received consumer stats " @@ -1622,7 +1603,7 @@ void ClientConnection::handleConsumerStatsResponse( consumerStatsResponse.blockedconsumeronunackedmsgs(), consumerStatsResponse.address(), consumerStatsResponse.connectedsince(), consumerStatsResponse.type(), consumerStatsResponse.msgrateexpired(), consumerStatsResponse.msgbacklog()); - consumerStatsPromise.setValue(brokerStats); + request.setValue(brokerStats); } } else { LOG_WARN("ConsumerStatsResponse command - Received unknown request id from server: " @@ -1635,8 +1616,7 @@ void ClientConnection::handleLookupTopicRespose( Lock lock(mutex_); auto it = pendingLookupRequests_.find(lookupTopicResponse.request_id()); if (it != pendingLookupRequests_.end()) { - cancelTimer(*it->second.timer); - LookupDataResultPromisePtr lookupDataPromise = it->second.promise; + auto request = std::move(it->second); pendingLookupRequests_.erase(it); numOfPendingLookupRequest_--; lock.unlock(); @@ -1648,12 +1628,11 @@ void ClientConnection::handleLookupTopicRespose( << " error: " << lookupTopicResponse.error() << " msg: " << lookupTopicResponse.message()); checkServerError(lookupTopicResponse.error(), lookupTopicResponse.message()); - lookupDataPromise->setFailed( - getResult(lookupTopicResponse.error(), lookupTopicResponse.message())); + request->fail(getResult(lookupTopicResponse.error(), lookupTopicResponse.message())); } else { LOG_ERROR(cnxString() << "Failed lookup req_id: " << lookupTopicResponse.request_id() << " with empty response: "); - lookupDataPromise->setFailed(ResultConnectError); + request->fail(ResultConnectError); } } else { LOG_DEBUG(cnxString() << "Received lookup response from server. 
req_id: " @@ -1676,7 +1655,7 @@ void ClientConnection::handleLookupTopicRespose( lookupResultPtr->setRedirect(lookupTopicResponse.response() == proto::CommandLookupTopicResponse::Redirect); lookupResultPtr->setShouldProxyThroughServiceUrl(lookupTopicResponse.proxy_through_service_url()); - lookupDataPromise->setValue(lookupResultPtr); + request->complete(lookupResultPtr); } } else { @@ -1692,12 +1671,12 @@ void ClientConnection::handleProducerSuccess(const proto::CommandProducerSuccess Lock lock(mutex_); auto it = pendingRequests_.find(producerSuccess.request_id()); if (it != pendingRequests_.end()) { - PendingRequestData requestData = it->second; + auto request = it->second; if (!producerSuccess.producer_ready()) { LOG_INFO(cnxString() << " Producer " << producerSuccess.producer_name() << " has been queued up at broker. req_id: " << producerSuccess.request_id()); - requestData.hasGotResponse->store(true); + request->disableTimeout(); lock.unlock(); } else { pendingRequests_.erase(it); @@ -1713,8 +1692,7 @@ void ClientConnection::handleProducerSuccess(const proto::CommandProducerSuccess } else { data.topicEpoch = std::nullopt; } - requestData.promise.setValue(data); - cancelTimer(*requestData.timer); + request->complete(data); } } } @@ -1729,30 +1707,27 @@ void ClientConnection::handleError(const proto::CommandError& error) { auto it = pendingRequests_.find(error.request_id()); if (it != pendingRequests_.end()) { - PendingRequestData requestData = it->second; + auto request = std::move(it->second); pendingRequests_.erase(it); lock.unlock(); - requestData.promise.setFailed(result); - cancelTimer(*requestData.timer); + request->fail(result); } else { - PendingGetLastMessageIdRequestsMap::iterator it = - pendingGetLastMessageIdRequests_.find(error.request_id()); + auto it = pendingGetLastMessageIdRequests_.find(error.request_id()); if (it != pendingGetLastMessageIdRequests_.end()) { - auto getLastMessageIdPromise = it->second.promise; + auto request = std::move(it->second); pendingGetLastMessageIdRequests_.erase(it); lock.unlock(); - getLastMessageIdPromise->setFailed(result); + request->fail(result); } else { - PendingGetNamespaceTopicsMap::iterator it = - pendingGetNamespaceTopicsRequests_.find(error.request_id()); + auto it = pendingGetNamespaceTopicsRequests_.find(error.request_id()); if (it != pendingGetNamespaceTopicsRequests_.end()) { - Promise getNamespaceTopicsPromise = it->second; + auto request = std::move(it->second); pendingGetNamespaceTopicsRequests_.erase(it); lock.unlock(); - getNamespaceTopicsPromise.setFailed(result); + request->fail(result); } else { lock.unlock(); } @@ -1904,16 +1879,15 @@ void ClientConnection::handleGetLastMessageIdResponse( auto it = pendingGetLastMessageIdRequests_.find(getLastMessageIdResponse.request_id()); if (it != pendingGetLastMessageIdRequests_.end()) { - auto getLastMessageIdPromise = it->second.promise; + auto request = std::move(it->second); pendingGetLastMessageIdRequests_.erase(it); lock.unlock(); if (getLastMessageIdResponse.has_consumer_mark_delete_position()) { - getLastMessageIdPromise->setValue( - {toMessageId(getLastMessageIdResponse.last_message_id()), - toMessageId(getLastMessageIdResponse.consumer_mark_delete_position())}); + request->complete({toMessageId(getLastMessageIdResponse.last_message_id()), + toMessageId(getLastMessageIdResponse.consumer_mark_delete_position())}); } else { - getLastMessageIdPromise->setValue({toMessageId(getLastMessageIdResponse.last_message_id())}); + 
request->complete({toMessageId(getLastMessageIdResponse.last_message_id())}); } } else { lock.unlock(); @@ -1931,7 +1905,7 @@ void ClientConnection::handleGetTopicOfNamespaceResponse( auto it = pendingGetNamespaceTopicsRequests_.find(response.request_id()); if (it != pendingGetNamespaceTopicsRequests_.end()) { - Promise getTopicsPromise = it->second; + auto request = std::move(it->second); pendingGetNamespaceTopicsRequests_.erase(it); lock.unlock(); @@ -1953,7 +1927,7 @@ void ClientConnection::handleGetTopicOfNamespaceResponse( NamespaceTopicsPtr topicsPtr = std::make_shared>(topicSet.begin(), topicSet.end()); - getTopicsPromise.setValue(topicsPtr); + request->complete(topicsPtr); } else { lock.unlock(); LOG_WARN( @@ -1968,7 +1942,7 @@ void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResp Lock lock(mutex_); auto it = pendingGetSchemaRequests_.find(response.request_id()); if (it != pendingGetSchemaRequests_.end()) { - Promise getSchemaPromise = it->second.promise; + auto request = std::move(it->second); pendingGetSchemaRequests_.erase(it); lock.unlock(); @@ -1981,7 +1955,7 @@ void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResp : "") << " -- req_id: " << response.request_id()); } - getSchemaPromise.setFailed(result); + request->fail(result); return; } @@ -1992,7 +1966,7 @@ void ClientConnection::handleGetSchemaResponse(const proto::CommandGetSchemaResp properties[kv->key()] = kv->value(); } SchemaInfo schemaInfo(static_cast(schema.type()), "", schema.schema_data(), properties); - getSchemaPromise.setValue(schemaInfo); + request->complete(schemaInfo); } else { lock.unlock(); LOG_WARN( @@ -2013,24 +1987,23 @@ void ClientConnection::handleAckResponse(const proto::CommandAckResponse& respon return; } - auto promise = it->second.promise; + auto request = std::move(it->second); pendingRequests_.erase(it); lock.unlock(); if (response.has_error()) { - promise.setFailed(getResult(response.error(), "")); + request->fail(getResult(response.error(), "")); } else { - promise.setValue({}); + request->complete({}); } } void ClientConnection::unsafeRemovePendingRequest(long requestId) { auto it = pendingRequests_.find(requestId); if (it != pendingRequests_.end()) { - it->second.promise.setFailed(ResultDisconnected); - cancelTimer(*it->second.timer); - + auto request = std::move(it->second); pendingRequests_.erase(it); + request->fail(ResultDisconnected); } } diff --git a/lib/ClientConnection.h b/lib/ClientConnection.h index 75e4bca8..c8cd86fe 100644 --- a/lib/ClientConnection.h +++ b/lib/ClientConnection.h @@ -28,6 +28,7 @@ #include #include #include + #ifdef USE_ASIO #include #include @@ -52,8 +53,10 @@ #include "AsioTimer.h" #include "Commands.h" +#include "ExecutorService.h" #include "GetLastMessageIdResponse.h" #include "LookupDataResult.h" +#include "PendingRequest.h" #include "SharedBuffer.h" #include "TimeUtils.h" #include "UtilAllocator.h" @@ -66,9 +69,6 @@ class PulsarFriend; using TcpResolverPtr = std::shared_ptr; -class ExecutorService; -using ExecutorServicePtr = std::shared_ptr; - class ConnectionPool; class ClientConnection; typedef std::shared_ptr ClientConnectionPtr; @@ -225,47 +225,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this promise; - DeadlineTimerPtr timer; - std::shared_ptr hasGotResponse{std::make_shared(false)}; - - void fail(Result result) { - cancelTimer(*timer); - promise.setFailed(result); - } - }; - - struct LookupRequestData { - LookupDataResultPromisePtr promise; - DeadlineTimerPtr timer; 
- - void fail(Result result) { - cancelTimer(*timer); - promise->setFailed(result); - } - }; - - struct LastMessageIdRequestData { - GetLastMessageIdResponsePromisePtr promise; - DeadlineTimerPtr timer; - - void fail(Result result) { - cancelTimer(*timer); - promise->setFailed(result); - } - }; - - struct GetSchemaRequest { - Promise promise; - DeadlineTimerPtr timer; - - void fail(Result result) { - cancelTimer(*timer); - promise.setFailed(result); - } - }; - /* * handler for connectAsync * creates a ConnectionPtr which has a valid ClientConnection object @@ -303,12 +262,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this inline AllocHandler customAllocReadHandler(Handler h) { return AllocHandler(readHandlerAllocator_, h); @@ -385,33 +338,49 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this PendingRequestsMap; - PendingRequestsMap pendingRequests_; + template + using RequestMap = std::unordered_map>; - typedef std::map PendingLookupRequestsMap; - PendingLookupRequestsMap pendingLookupRequests_; + RequestMap pendingRequests_; + RequestMap pendingLookupRequests_; + RequestMap pendingGetLastMessageIdRequests_; + RequestMap pendingGetNamespaceTopicsRequests_; + RequestMap pendingGetSchemaRequests_; - typedef std::map ProducersMap; + typedef std::unordered_map ProducersMap; ProducersMap producers_; - typedef std::map ConsumersMap; + typedef std::unordered_map ConsumersMap; ConsumersMap consumers_; typedef std::map> PendingConsumerStatsMap; PendingConsumerStatsMap pendingConsumerStatsMap_; - typedef std::map PendingGetLastMessageIdRequestsMap; - PendingGetLastMessageIdRequestsMap pendingGetLastMessageIdRequests_; - - typedef std::map> PendingGetNamespaceTopicsMap; - PendingGetNamespaceTopicsMap pendingGetNamespaceTopicsRequests_; - - typedef std::unordered_map PendingGetSchemaMap; - PendingGetSchemaMap pendingGetSchemaRequests_; - mutable std::mutex mutex_; typedef std::unique_lock Lock; + // Note: this method must be called when holding `mutex_` + template + auto insertRequest(RequestMap& pendingRequests, uint64_t requestId, OnTimeout onTimeout) { + auto request = std::make_shared>( + executor_->createTimer(operationsTimeout_), + [this, self{shared_from_this()}, requestId, onTimeout{std::move(onTimeout)}, + &pendingRequests]() mutable { + { + std::lock_guard lock{mutex_}; + if (auto it = pendingRequests.find(requestId); it != pendingRequests.end()) { + pendingRequests.erase(it); + } + } + onTimeout(); + }); + auto [iterator, inserted] = pendingRequests.emplace(requestId, request); + if (inserted) { + request->initialize(); + } // else: the request id is duplicated + return iterator->second; + } + // Pending buffers to write on the socket std::deque pendingWriteBuffers_; int pendingWriteOperations_ = 0; @@ -435,7 +404,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this consumerStatsRequests); uint32_t maxPendingLookupRequest_; - uint32_t numOfPendingLookupRequest_ = 0; + std::atomic_uint32_t numOfPendingLookupRequest_{0}; bool isTlsAllowInsecureConnection_ = false; diff --git a/lib/ExecutorService.h b/lib/ExecutorService.h index 80659d4b..4a36396c 100644 --- a/lib/ExecutorService.h +++ b/lib/ExecutorService.h @@ -28,12 +28,14 @@ #include #include #include +#include #else #include #include #include #include #include +#include #endif #include #include @@ -68,6 +70,13 @@ class PULSAR_PUBLIC ExecutorService : public std::enable_shared_from_this + ASIO::steady_timer createTimer(const Duration &duration) { + auto 
timer = ASIO::steady_timer(io_context_); + timer.expires_after(duration); + return timer; + } + // Execute the task in the event loop thread asynchronously, i.e. the task will be put in the event loop // queue and executed later. template diff --git a/lib/MockServer.h b/lib/MockServer.h index 2d830fc7..6f8d1390 100644 --- a/lib/MockServer.h +++ b/lib/MockServer.h @@ -81,6 +81,26 @@ class MockServer : public std::enable_shared_from_this { proto::CommandConsumerStatsResponse response; response.set_request_id(requestId); connection->handleConsumerStatsResponse(response); + } else if (request == "LOOKUP") { + proto::CommandLookupTopicResponse response; + response.set_request_id(requestId); + response.set_response(proto::CommandLookupTopicResponse_LookupType_Connect); + response.set_brokerserviceurl("pulsar://localhost:6650"); + connection->handleLookupTopicRespose(response); + } else if (request == "GET_LAST_MESSAGE_ID") { + proto::CommandGetLastMessageIdResponse response; + response.set_request_id(requestId); + response.mutable_last_message_id(); + connection->handleGetLastMessageIdResponse(response); + } else if (request == "GET_TOPICS_OF_NAMESPACE") { + proto::CommandGetTopicsOfNamespaceResponse response; + response.set_request_id(requestId); + connection->handleGetTopicOfNamespaceResponse(response); + } else if (request == "GET_SCHEMA") { + proto::CommandGetSchemaResponse response; + response.set_request_id(requestId); + response.mutable_schema()->set_type(proto::Schema_Type_String); + connection->handleGetSchemaResponse(response); } else { proto::CommandSuccess success; success.set_request_id(requestId); diff --git a/lib/PendingRequest.h b/lib/PendingRequest.h new file mode 100644 index 00000000..465073f6 --- /dev/null +++ b/lib/PendingRequest.h @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#pragma once + +#include + +#include +#include +#include + +#include "AsioDefines.h" +#include "AsioTimer.h" +#include "Future.h" + +namespace pulsar { + +template +class PendingRequest : public std::enable_shared_from_this> { + public: + PendingRequest(ASIO::steady_timer timer, std::function timeoutCallback) + : timer_(std::move(timer)), timeoutCallback_(std::move(timeoutCallback)) {} + + void initialize() { + timer_.async_wait([this, weakSelf{this->weak_from_this()}](const auto& error) { + auto self = weakSelf.lock(); + if (!self || error || timeoutDisabled_.load(std::memory_order_acquire)) { + return; + } + timeoutCallback_(); + promise_.setFailed(ResultTimeout); + }); + } + + void complete(const T& value) { + promise_.setValue(value); + cancelTimer(timer_); + } + + void fail(Result result) { + promise_.setFailed(result); + cancelTimer(timer_); + } + + void disableTimeout() { timeoutDisabled_.store(true, std::memory_order_release); } + + auto getFuture() const { return promise_.getFuture(); } + + ~PendingRequest() { cancelTimer(timer_); } + + private: + ASIO::steady_timer timer_; + Promise promise_; + std::function timeoutCallback_; + std::atomic_bool timeoutDisabled_{false}; +}; + +template +using PendingRequestPtr = std::shared_ptr>; + +} // namespace pulsar diff --git a/tests/ClientTest.cc b/tests/ClientTest.cc index 6bd6cc8a..63ac9d16 100644 --- a/tests/ClientTest.cc +++ b/tests/ClientTest.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include @@ -33,8 +34,13 @@ #include "PulsarFriend.h" #include "WaitUtils.h" #include "lib/AsioDefines.h" +#include "lib/AtomicSharedPtr.h" #include "lib/ClientConnection.h" +#include "lib/ConnectionPool.h" +#include "lib/ExecutorService.h" #include "lib/LogUtils.h" +#include "lib/MockServer.h" +#include "lib/TimeUtils.h" #include "lib/checksum/ChecksumProvider.h" #include "lib/stats/ProducerStatsImpl.h" @@ -231,6 +237,97 @@ TEST(ClientTest, testConnectTimeoutAfterTcpConnected) { server->stop(); } +TEST(ClientTest, testConnectionNotReferredAfterClose) { + Client client(lookupUrl); + auto topic = "test-connection-not-referred-after-close-" + std::to_string(time(nullptr)); + Producer producer; + ASSERT_EQ(ResultOk, client.createProducer(topic, producer)); + + Reader reader; + ASSERT_EQ(ResultOk, client.createReader(topic, MessageId::earliest(), {}, reader)); + + bool available; + ASSERT_EQ(ResultOk, reader.hasMessageAvailable(available)); + ASSERT_FALSE(available); + + ASSERT_EQ(ResultOk, producer.send(MessageBuilder().setContent("test").build())); + ASSERT_EQ(ResultOk, reader.hasMessageAvailable(available)); + ASSERT_TRUE(available); + + Message msg; + ASSERT_EQ(ResultOk, reader.readNext(msg)); + ASSERT_EQ("test", msg.getDataAsString()); + + auto start = TimeUtils::currentTimeMillis(); + ASSERT_EQ(ResultOk, client.close()); + auto closeTimeMs = TimeUtils::currentTimeMillis() - start; + ASSERT_LT(closeTimeMs, 3000) << "close time: " << closeTimeMs << " ms"; +} + +TEST(ClientTest, testTimedOutPendingRequestsAreErasedFromConnectionMaps) { + const auto suffix = std::to_string(std::chrono::steady_clock::now().time_since_epoch().count()); + ClientConfiguration conf; + conf.setOperationTimeoutSeconds(1); + + auto executorProvider = std::make_shared(1); + AtomicSharedPtr serviceInfo; + serviceInfo.store(std::make_shared(lookupUrl)); + ConnectionPool pool(serviceInfo, conf, executorProvider, ""); + auto connection = std::make_shared(lookupUrl, lookupUrl, *serviceInfo.load(), + executorProvider->get(), conf, "", pool, 0); + 
PulsarFriend::setServerProtocolVersion(*connection, 8); + + long requestIdGenerator = 0; + auto mockServer = std::make_shared(connection); + connection->attachMockServer(mockServer); + mockServer->setRequestDelay({{"TEST_PENDING_REQUEST", 1200}, + {"LOOKUP", 1200}, + {"GET_LAST_MESSAGE_ID", 1200}, + {"GET_TOPICS_OF_NAMESPACE", 1200}, + {"GET_SCHEMA", 1200}}); + + auto pingFuture = + connection->sendRequestWithId(Commands::newPing(), requestIdGenerator++, "TEST_PENDING_REQUEST"); + + auto lookupPromise = std::make_shared(); + auto lookupFuture = lookupPromise->getFuture(); + connection->newTopicLookup("persistent://public/default/testTimedOutPendingRequests-" + suffix, false, "", + requestIdGenerator++, lookupPromise); + + auto lastMessageIdFuture = connection->newGetLastMessageId(0, requestIdGenerator++); + + auto getTopicsOfNamespaceFuture = connection->newGetTopicsOfNamespace( + "public/default", CommandGetTopicsOfNamespace_Mode_PERSISTENT, requestIdGenerator++); + + auto getSchemaFuture = connection->newGetSchema( + "persistent://public/default/testTimedOutPendingRequests-" + suffix, "", requestIdGenerator++); + + ResponseData responseData; + ASSERT_EQ(ResultTimeout, pingFuture.get(responseData)); + ASSERT_EQ(0u, PulsarFriend::getPendingRequests(*connection)); + + LookupDataResultPtr lookupData; + ASSERT_EQ(ResultTimeout, lookupFuture.get(lookupData)); + ASSERT_EQ(0u, PulsarFriend::getPendingLookupRequests(*connection)); + ASSERT_EQ(0u, PulsarFriend::getNumOfPendingLookupRequests(*connection)); + + GetLastMessageIdResponse lastMessageIdResponse; + ASSERT_EQ(ResultTimeout, lastMessageIdFuture.get(lastMessageIdResponse)); + ASSERT_EQ(0u, PulsarFriend::getPendingGetLastMessageIdRequests(*connection)); + + NamespaceTopicsPtr topics; + ASSERT_EQ(ResultTimeout, getTopicsOfNamespaceFuture.get(topics)); + ASSERT_EQ(0u, PulsarFriend::getPendingGetTopicsOfNamespaceRequests(*connection)); + + SchemaInfo schemaInfo; + ASSERT_EQ(ResultTimeout, getSchemaFuture.get(schemaInfo)); + ASSERT_EQ(0u, PulsarFriend::getPendingGetSchemaRequests(*connection)); + + mockServer->close(); + connection->close(ResultDisconnected).wait(); + executorProvider->close(); +} + TEST(ClientTest, testGetNumberOfReferences) { Client client("pulsar://localhost:6650"); diff --git a/tests/PulsarFriend.h b/tests/PulsarFriend.h index 1f351d16..3296953b 100644 --- a/tests/PulsarFriend.h +++ b/tests/PulsarFriend.h @@ -167,6 +167,41 @@ class PulsarFriend { return cnx.pendingConsumerStatsMap_.size(); } + static size_t getPendingRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return cnx.pendingRequests_.size(); + } + + static size_t getPendingLookupRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return cnx.pendingLookupRequests_.size(); + } + + static size_t getNumOfPendingLookupRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return cnx.numOfPendingLookupRequest_; + } + + static size_t getPendingGetLastMessageIdRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return cnx.pendingGetLastMessageIdRequests_.size(); + } + + static size_t getPendingGetTopicsOfNamespaceRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return cnx.pendingGetNamespaceTopicsRequests_.size(); + } + + static size_t getPendingGetSchemaRequests(const ClientConnection& cnx) { + std::lock_guard lock(cnx.mutex_); + return cnx.pendingGetSchemaRequests_.size(); + } + + static void setServerProtocolVersion(ClientConnection& cnx, 
int serverProtocolVersion) {
+        std::lock_guard<std::mutex> lock(cnx.mutex_);
+        cnx.serverProtocolVersion_ = serverProtocolVersion;
+    }
+
     static void setNegativeAckEnabled(Consumer consumer, bool enabled) {
         consumer.impl_->setNegativeAcknowledgeEnabledForTesting(enabled);
     }
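
The sketch below is a minimal, self-contained model of the PendingRequest/insertRequest pattern introduced above: a pending request owns a deadline timer plus a promise, the response path completes the promise and cancels the timer, and the timeout path erases the map entry and fails the future. It assumes standalone Asio is available and uses hypothetical names (DemoRequest, the requests map, the main() harness); the actual change uses the pulsar Promise/Future types from lib/PendingRequest.h rather than std::promise/std::future.

    // Illustrative only: a trimmed-down model of the pattern above, not the library code.
    #include <asio.hpp>
    #include <atomic>
    #include <chrono>
    #include <cstdint>
    #include <functional>
    #include <future>
    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_map>

    class DemoRequest : public std::enable_shared_from_this<DemoRequest> {
       public:
        DemoRequest(asio::steady_timer timer, std::function<void()> onTimeout)
            : timer_(std::move(timer)), onTimeout_(std::move(onTimeout)) {}

        // Arm the timer only after the shared_ptr exists so weak_from_this() is valid.
        void initialize() {
            timer_.async_wait([this, weakSelf = weak_from_this()](const asio::error_code& ec) {
                auto self = weakSelf.lock();
                if (!self || ec) {
                    return;  // already destroyed, or cancelled because a response arrived first
                }
                onTimeout_();
                trySet("timeout");
            });
        }

        // Response path: fulfil the promise and cancel the pending timeout.
        void complete(const std::string& value) {
            trySet(value);
            timer_.cancel();
        }

        std::future<std::string> getFuture() { return promise_.get_future(); }

       private:
        void trySet(const std::string& v) {
            if (!done_.exchange(true)) promise_.set_value(v);  // guard against double completion
        }

        asio::steady_timer timer_;
        std::function<void()> onTimeout_;
        std::promise<std::string> promise_;
        std::atomic_bool done_{false};
    };

    int main() {
        asio::io_context io;
        // Plays the role of pendingRequests_: requestId -> shared pending-request object.
        std::unordered_map<uint64_t, std::shared_ptr<DemoRequest>> requests;

        uint64_t requestId = 1;
        asio::steady_timer timer(io);
        timer.expires_after(std::chrono::milliseconds(100));
        auto request = std::make_shared<DemoRequest>(std::move(timer), [&requests, requestId] {
            // Timeout removes the map entry, mirroring the insertRequest() callback.
            requests.erase(requestId);
            std::cout << "request " << requestId << " timed out\n";
        });
        requests.emplace(requestId, request);
        request->initialize();

        // No response ever arrives, so the timer fires, the entry is erased and the future fails.
        auto future = request->getFuture();
        io.run();
        std::cout << "result: " << future.get() << ", pending: " << requests.size() << "\n";
    }

The weak_from_this()/lock() guard mirrors PendingRequest::initialize(): erasing the map entry inside the timeout callback drops one owner of the request, so the handler keeps its own shared_ptr alive for the duration of the call, and cancelling the timer on the response path turns the outstanding wait into operation_aborted instead of a spurious timeout.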