Skip to content

Commit e80f65e

Browse files
Fix crash due to asio object lifetime and thread safety issue (#551)
1 parent 23b60d1 commit e80f65e

11 files changed

Lines changed: 494 additions & 327 deletions

.github/workflows/ci-pr-validation.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,15 @@ jobs:
260260
Pop-Location
261261
}
262262
263+
- name: Ensure vcpkg has full history(windows)
264+
if: runner.os == 'Windows'
265+
shell: pwsh
266+
run: |
267+
$isShallow = (git -C "${{ env.VCPKG_ROOT }}" rev-parse --is-shallow-repository).Trim()
268+
if ($isShallow -eq "true") {
269+
git -C "${{ env.VCPKG_ROOT }}" fetch --unshallow
270+
}
271+
263272
- name: remove system vcpkg(windows)
264273
if: runner.os == 'Windows'
265274
run: rm -rf "$VCPKG_INSTALLATION_ROOT"

lib/ClientConnection.cc

Lines changed: 247 additions & 288 deletions
Large diffs are not rendered by default.

lib/ClientConnection.h

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include <any>
2626
#include <atomic>
2727
#include <cstdint>
28+
#include <future>
29+
#include <optional>
2830
#ifdef USE_ASIO
2931
#include <asio/bind_executor.hpp>
3032
#include <asio/io_context.hpp>
@@ -41,8 +43,10 @@
4143
#include <deque>
4244
#include <functional>
4345
#include <memory>
46+
#include <mutex>
4447
#include <string>
4548
#include <unordered_map>
49+
#include <utility>
4650
#include <vector>
4751

4852
#include "AsioTimer.h"
@@ -156,11 +160,8 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
156160
* Close the connection.
157161
*
158162
* @param result all pending futures will complete with this result
159-
* @param detach remove it from the pool if it's true
160-
*
161-
* `detach` should only be false when the connection pool is closed.
162163
*/
163-
void close(Result result = ResultConnectError, bool detach = true);
164+
const std::future<void>& close(Result result = ResultConnectError);
164165

165166
bool isClosed() const;
166167

@@ -193,7 +194,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
193194

194195
const std::string& brokerAddress() const;
195196

196-
const std::string& cnxString() const;
197+
auto cnxString() const { return *std::atomic_load(&cnxStringPtr_); }
197198

198199
int getServerProtocolVersion() const;
199200

@@ -219,28 +220,48 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
219220
mockingRequests_.store(true, std::memory_order_release);
220221
}
221222

222-
void handleKeepAliveTimeout();
223+
void handleKeepAliveTimeout(const ASIO_ERROR& ec);
223224

224225
private:
225226
struct PendingRequestData {
226227
Promise<Result, ResponseData> promise;
227228
DeadlineTimerPtr timer;
228229
std::shared_ptr<std::atomic_bool> hasGotResponse{std::make_shared<std::atomic_bool>(false)};
230+
231+
void fail(Result result) {
232+
cancelTimer(*timer);
233+
promise.setFailed(result);
234+
}
229235
};
230236

231237
struct LookupRequestData {
232238
LookupDataResultPromisePtr promise;
233239
DeadlineTimerPtr timer;
240+
241+
void fail(Result result) {
242+
cancelTimer(*timer);
243+
promise->setFailed(result);
244+
}
234245
};
235246

236247
struct LastMessageIdRequestData {
237248
GetLastMessageIdResponsePromisePtr promise;
238249
DeadlineTimerPtr timer;
250+
251+
void fail(Result result) {
252+
cancelTimer(*timer);
253+
promise->setFailed(result);
254+
}
239255
};
240256

241257
struct GetSchemaRequest {
242258
Promise<Result, SchemaInfo> promise;
243259
DeadlineTimerPtr timer;
260+
261+
void fail(Result result) {
262+
cancelTimer(*timer);
263+
promise.setFailed(result);
264+
}
244265
};
245266

246267
/*
@@ -297,26 +318,26 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
297318
}
298319

299320
template <typename ConstBufferSequence, typename WriteHandler>
300-
inline void asyncWrite(const ConstBufferSequence& buffers, WriteHandler handler) {
321+
inline void asyncWrite(const ConstBufferSequence& buffers, WriteHandler&& handler) {
301322
if (isClosed()) {
302323
return;
303324
}
304325
if (tlsSocket_) {
305-
ASIO::async_write(*tlsSocket_, buffers, ASIO::bind_executor(strand_, handler));
326+
ASIO::async_write(*tlsSocket_, buffers, std::forward<WriteHandler>(handler));
306327
} else {
307-
ASIO::async_write(*socket_, buffers, handler);
328+
ASIO::async_write(*socket_, buffers, std::forward<WriteHandler>(handler));
308329
}
309330
}
310331

311332
template <typename MutableBufferSequence, typename ReadHandler>
312-
inline void asyncReceive(const MutableBufferSequence& buffers, ReadHandler handler) {
333+
inline void asyncReceive(const MutableBufferSequence& buffers, ReadHandler&& handler) {
313334
if (isClosed()) {
314335
return;
315336
}
316337
if (tlsSocket_) {
317-
tlsSocket_->async_read_some(buffers, ASIO::bind_executor(strand_, handler));
338+
tlsSocket_->async_read_some(buffers, std::forward<ReadHandler>(handler));
318339
} else {
319-
socket_->async_receive(buffers, handler);
340+
socket_->async_receive(buffers, std::forward<ReadHandler>(handler));
320341
}
321342
}
322343

@@ -337,7 +358,6 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
337358
*/
338359
SocketPtr socket_;
339360
TlsSocketPtr tlsSocket_;
340-
ASIO::strand<ASIO::io_context::executor_type> strand_;
341361

342362
const std::string logicalAddress_;
343363
/*
@@ -350,7 +370,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
350370
ClientConfiguration::ProxyProtocol proxyProtocol_;
351371

352372
// Represent both endpoint of the tcp connection. eg: [client:1234 -> server:6650]
353-
std::string cnxString_;
373+
std::shared_ptr<std::string> cnxStringPtr_;
354374

355375
/*
356376
* indicates if async connection establishment failed
@@ -360,7 +380,8 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
360380
SharedBuffer incomingBuffer_;
361381

362382
Promise<Result, ClientConnectionWeakPtr> connectPromise_;
363-
std::shared_ptr<PeriodicTask> connectTimeoutTask_;
383+
const std::chrono::milliseconds connectTimeout_;
384+
const DeadlineTimerPtr connectTimer_;
364385

365386
typedef std::map<long, PendingRequestData> PendingRequestsMap;
366387
PendingRequestsMap pendingRequests_;
@@ -419,6 +440,7 @@ class PULSAR_PUBLIC ClientConnection : public std::enable_shared_from_this<Clien
419440
const std::string clientVersion_;
420441
ConnectionPool& pool_;
421442
const size_t poolIndex_;
443+
std::optional<std::future<void>> closeFuture_;
422444

423445
friend class PulsarFriend;
424446
friend class ConsumerTest;

lib/ConnectionPool.cc

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,43 @@ bool ConnectionPool::close() {
5454
return false;
5555
}
5656

57+
std::vector<ClientConnectionPtr> connectionsToClose;
58+
// ClientConnection::close() will remove the connection from the pool, which is not allowed when iterating
59+
// over a map, so we store the connections to close in a vector first and don't iterate the pool when
60+
// closing the connections.
5761
std::unique_lock<std::recursive_mutex> lock(mutex_);
62+
connectionsToClose.reserve(pool_.size());
63+
for (auto&& kv : pool_) {
64+
connectionsToClose.emplace_back(kv.second);
65+
}
66+
pool_.clear();
67+
lock.unlock();
5868

59-
for (auto cnxIt = pool_.begin(); cnxIt != pool_.end(); cnxIt++) {
60-
auto& cnx = cnxIt->second;
69+
for (auto&& cnx : connectionsToClose) {
6170
if (cnx) {
62-
// The 2nd argument is false because removing a value during the iteration will cause segfault
63-
cnx->close(ResultDisconnected, false);
71+
// Close with a fatal error so the client will not retry
72+
auto& future = cnx->close(ResultAlreadyClosed);
73+
using namespace std::chrono_literals;
74+
if (auto status = future.wait_for(5s); status != std::future_status::ready) {
75+
LOG_WARN("Connection close timed out for " << cnx.get()->cnxString());
76+
}
77+
if (cnx.use_count() > 1) {
78+
// There are some asynchronous operations that hold the reference on the connection, we should
79+
// wait for them to finish. Otherwise, `io_context::stop()` will be called in
80+
// `ClientImpl::shutdown()` when closing the `ExecutorServiceProvider`. Then
81+
// `io_context::run()` will return and the `io_context` object will be destroyed. In this
82+
// case, if there is any pending handler, it will crash.
83+
for (int i = 0; i < 500 && cnx.use_count() > 1; i++) {
84+
std::this_thread::sleep_for(10ms);
85+
}
86+
if (cnx.use_count() > 1) {
87+
LOG_WARN("Connection still has " << (cnx.use_count() - 1)
88+
<< " references after waiting for 5 seconds for "
89+
<< cnx.get()->cnxString());
90+
}
91+
}
6492
}
6593
}
66-
pool_.clear();
6794
return true;
6895
}
6996

lib/ExecutorService.cc

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ void ExecutorService::close(long timeoutMs) {
125125
}
126126
}
127127

128-
void ExecutorService::postWork(std::function<void(void)> task) { ASIO::post(io_context_, std::move(task)); }
129-
130128
/////////////////////
131129

132130
ExecutorServiceProvider::ExecutorServiceProvider(int nthreads)

lib/ExecutorService.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,16 @@
2323

2424
#include <atomic>
2525
#ifdef USE_ASIO
26+
#include <asio/dispatch.hpp>
2627
#include <asio/io_context.hpp>
2728
#include <asio/ip/tcp.hpp>
29+
#include <asio/post.hpp>
2830
#include <asio/ssl.hpp>
2931
#else
32+
#include <boost/asio/dispatch.hpp>
3033
#include <boost/asio/io_context.hpp>
3134
#include <boost/asio/ip/tcp.hpp>
35+
#include <boost/asio/post.hpp>
3236
#include <boost/asio/ssl.hpp>
3337
#endif
3438
#include <chrono>
@@ -37,6 +41,7 @@
3741
#include <memory>
3842
#include <mutex>
3943
#include <thread>
44+
#include <utility>
4045

4146
#include "AsioTimer.h"
4247

@@ -62,7 +67,19 @@ class PULSAR_PUBLIC ExecutorService : public std::enable_shared_from_this<Execut
6267
TcpResolverPtr createTcpResolver();
6368
// throws std::runtime_error if failed
6469
DeadlineTimerPtr createDeadlineTimer();
65-
void postWork(std::function<void(void)> task);
70+
71+
// Execute the task in the event loop thread asynchronously, i.e. the task will be put in the event loop
72+
// queue and executed later.
73+
template <typename T>
74+
void postWork(T &&task) {
75+
ASIO::post(io_context_, std::forward<T>(task));
76+
}
77+
78+
// Different from `postWork`, if it's already in the event loop, execute the task immediately
79+
template <typename T>
80+
void dispatch(T &&task) {
81+
ASIO::dispatch(io_context_, std::forward<T>(task));
82+
}
6683

6784
// See TimeoutProcessor for the semantics of the parameter.
6885
void close(long timeoutMs = 3000);

lib/PeriodicTask.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class PeriodicTask : public std::enable_shared_from_this<PeriodicTask> {
5353

5454
void stop() noexcept;
5555

56-
void setCallback(CallbackType callback) noexcept { callback_ = callback; }
56+
void setCallback(CallbackType&& callback) noexcept { callback_ = std::move(callback); }
5757

5858
State getState() const noexcept { return state_; }
5959
int getPeriodMs() const noexcept { return periodMs_; }

tests/BasicEndToEndTest.cc

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3188,7 +3188,17 @@ static void expectTimeoutOnRecv(Consumer &consumer) {
31883188
ASSERT_EQ(ResultTimeout, res);
31893189
}
31903190

3191-
void testNegativeAcks(const std::string &topic, bool batchingEnabled) {
3191+
static std::vector<std::string> expectedNegativeAckMessages(size_t numMessages) {
3192+
std::vector<std::string> expected;
3193+
expected.reserve(numMessages);
3194+
for (size_t i = 0; i < numMessages; i++) {
3195+
expected.emplace_back("test-" + std::to_string(i));
3196+
}
3197+
return expected;
3198+
}
3199+
3200+
void testNegativeAcks(const std::string &topic, bool batchingEnabled, bool expectOrdered = true) {
3201+
constexpr size_t numMessages = 10;
31923202
Client client(lookupUrl);
31933203
Consumer consumer;
31943204
ConsumerConfiguration conf;
@@ -3202,22 +3212,32 @@ void testNegativeAcks(const std::string &topic, bool batchingEnabled) {
32023212
result = client.createProducer(topic, producerConf, producer);
32033213
ASSERT_EQ(ResultOk, result);
32043214

3205-
for (int i = 0; i < 10; i++) {
3215+
for (size_t i = 0; i < numMessages; i++) {
32063216
Message msg = MessageBuilder().setContent("test-" + std::to_string(i)).build();
32073217
producer.sendAsync(msg, nullptr);
32083218
}
32093219

32103220
producer.flush();
32113221

3222+
std::vector<std::string> receivedMessages;
3223+
receivedMessages.reserve(numMessages);
32123224
std::vector<MessageId> toNeg;
3213-
for (int i = 0; i < 10; i++) {
3225+
for (size_t i = 0; i < numMessages; i++) {
32143226
Message msg;
32153227
consumer.receive(msg);
32163228

32173229
LOG_INFO("Received message " << msg.getDataAsString());
3218-
ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i));
3230+
if (expectOrdered) {
3231+
ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i));
3232+
}
3233+
receivedMessages.emplace_back(msg.getDataAsString());
32193234
toNeg.push_back(msg.getMessageId());
32203235
}
3236+
if (!expectOrdered) {
3237+
auto expectedMessages = expectedNegativeAckMessages(numMessages);
3238+
std::sort(receivedMessages.begin(), receivedMessages.end());
3239+
ASSERT_EQ(expectedMessages, receivedMessages);
3240+
}
32213241
// No more messages expected
32223242
expectTimeoutOnRecv(consumer);
32233243

@@ -3228,15 +3248,25 @@ void testNegativeAcks(const std::string &topic, bool batchingEnabled) {
32283248
}
32293249
PulsarFriend::setNegativeAckEnabled(consumer, true);
32303250

3231-
for (int i = 0; i < 10; i++) {
3251+
std::vector<std::string> redeliveredMessages;
3252+
redeliveredMessages.reserve(numMessages);
3253+
for (size_t i = 0; i < numMessages; i++) {
32323254
Message msg;
32333255
consumer.receive(msg);
32343256
LOG_INFO("-- Redelivery -- Received message " << msg.getDataAsString());
32353257

3236-
ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i));
3258+
if (expectOrdered) {
3259+
ASSERT_EQ(msg.getDataAsString(), "test-" + std::to_string(i));
3260+
}
3261+
redeliveredMessages.emplace_back(msg.getDataAsString());
32373262

32383263
consumer.acknowledge(msg);
32393264
}
3265+
if (!expectOrdered) {
3266+
auto expectedMessages = expectedNegativeAckMessages(numMessages);
3267+
std::sort(redeliveredMessages.begin(), redeliveredMessages.end());
3268+
ASSERT_EQ(expectedMessages, redeliveredMessages);
3269+
}
32403270

32413271
// No more messages expected
32423272
expectTimeoutOnRecv(consumer);
@@ -3262,7 +3292,7 @@ TEST(BasicEndToEndTest, testNegativeAcksWithPartitions) {
32623292
LOG_INFO("res = " << res);
32633293
ASSERT_FALSE(res != 204 && res != 409);
32643294

3265-
testNegativeAcks(topicName, true);
3295+
testNegativeAcks(topicName, true, false);
32663296
}
32673297

32683298
void testNegativeAckPrecisionBitCnt(const std::string &topic, int precisionBitCnt) {

0 commit comments

Comments
 (0)