Skip to content

Commit cd8e355

Browse files
committed
WIP refactor: replace sharding with single connection set (6)
1 parent 824e27b commit cd8e355

1 file changed

Lines changed: 145 additions & 76 deletions

File tree

sqlx-core/src/pool/connection_set.rs

Lines changed: 145 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use std::time::Duration;
1616
use tracing::Instrument;
1717

1818
pub struct ConnectionSet<C> {
19-
global: Arc<Global>,
19+
global: Arc<Global<C>>,
2020
slots: Box<[Arc<Slot<C>>]>,
2121
}
2222

@@ -31,9 +31,9 @@ enum AcquirePreference {
3131
Either,
3232
}
3333

34-
struct Global {
35-
unlock_event: Event<usize>,
36-
disconnect_event: Event<usize>,
34+
struct Global<C> {
35+
unlock_event: Event<ReleaseWithoutNotify<C>>,
36+
disconnect_event: Event<ReleaseWithoutNotify<C>>,
3737
locked_set: Box<[AtomicBool]>,
3838
num_connected: AtomicUsize,
3939
min_connections: usize,
@@ -46,10 +46,12 @@ struct SlotGuard<C> {
4646
locked: Option<AsyncMutexGuardArc<Option<C>>>,
4747
}
4848

49+
struct ReleaseWithoutNotify<C>(SlotGuard<C>);
50+
4951
struct Slot<C> {
5052
// By having each `Slot` hold its own reference to `Global`, we can avoid extra contended clones
5153
// which would sap performance
52-
global: Arc<Global>,
54+
global: Arc<Global<C>>,
5355
index: usize,
5456
// I'd love to eliminate this redundant `Arc` but it's likely not possible without `unsafe`
5557
connection: Arc<AsyncMutex<Option<C>>>,
@@ -129,40 +131,49 @@ impl<C> ConnectionSet<C> {
129131
);
130132

131133
if self.slots.len() == 1 {
132-
span.record("alternate_slot", 0usize);
134+
span.record("preferred_slot", 0usize);
133135
return self.slots[0].acquire(pref).instrument(span).await;
134136
}
135137

136-
// Always try to lock the connection associated with our thread ID first
137-
let preferred_slot = current_thread_id() % self.slots.len();
138+
let preferred_slot = self.choose_preferred_slot();
138139
span.record("preferred_slot", preferred_slot);
139140

140-
// The number of tasks currently interested in this slot. Always at least 1.
141-
let search_offset = Arc::strong_count(&self.slots[preferred_slot].connection);
141+
let acquire_preferred = self.slots[preferred_slot].acquire(pref);
142+
143+
let acquire_global = async {
144+
// Yielding actually improves performance here.
145+
rt::yield_now().await;
146+
147+
// Since we know `preferred_slot` is locked, we offset our search by the number
148+
// of tasks interested in this slot, which is always at least 1.
149+
let search_offset = Arc::strong_count(&self.slots[preferred_slot]);
142150

143-
let acquire_global = pin!(async {
144151
if let Some(locked) = self.try_acquire(pref, preferred_slot.wrapping_add(search_offset))
145152
{
153+
tracing::trace!(
154+
search_offset,
155+
slot = locked.slot.index,
156+
"acquired from try_acquire"
157+
);
146158
return locked;
147159
}
148160

149-
loop {
150-
let slot = self.global.listen(pref).await;
161+
// Since `acquire_global` is fair, we wait
162+
//rt::sleep(Duration::from_millis(50)).await;
151163

152-
if let Some(locked) = self.try_acquire(pref, slot) {
153-
return locked;
154-
}
155-
}
156-
});
164+
rt::yield_now().await;
165+
166+
self.global.listen(pref).await
167+
};
157168

158-
let res = race(self.slots[preferred_slot].acquire(pref), acquire_global)
169+
let res = race(acquire_preferred, acquire_global)
159170
.instrument(span.clone())
160171
.await;
161172

162173
let _span = span.enter();
163174
match res {
164175
Ok(preferred) => {
165-
tracing::trace!("acquired from preferred_slot");
176+
tracing::trace!(slot = preferred_slot, "acquired from acquire_preferred");
166177
preferred
167178
}
168179
Err(global) => {
@@ -186,6 +197,26 @@ impl<C> ConnectionSet<C> {
186197
)
187198
}
188199

200+
/// Find a non-leaked slot starting with the one associated with [`current_thread_id()`].
201+
fn choose_preferred_slot(&self) -> usize {
202+
// Always try to lock the connection associated with our thread ID first
203+
let starting_slot = current_thread_id() % self.slots.len();
204+
205+
let search_slots = (starting_slot..self.slots.len()).chain(0..starting_slot);
206+
207+
for slot in search_slots {
208+
if !self.slots[slot].is_leaked() {
209+
return slot;
210+
}
211+
}
212+
213+
tracing::warn!(
214+
num_slots = self.slots.len(),
215+
"all slots have been leaked! all acquires will time out"
216+
);
217+
starting_slot
218+
}
219+
189220
fn try_acquire(&self, pref: AcquirePreference, starting_slot: usize) -> Option<SlotGuard<C>> {
190221
let starting_slot = starting_slot % self.slots.len();
191222

@@ -255,6 +286,46 @@ impl<C> ConnectionSet<C> {
255286
}
256287
}
257288

289+
const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation";
290+
const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`";
291+
292+
impl<C> ConnectedSlot<C> {
293+
pub fn take(mut self) -> (C, DisconnectedSlot<C>) {
294+
let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED);
295+
(conn, self.0.assert_disconnected())
296+
}
297+
}
298+
299+
impl<C> Deref for ConnectedSlot<C> {
300+
type Target = C;
301+
302+
#[inline(always)]
303+
fn deref(&self) -> &Self::Target {
304+
self.0.get().as_ref().expect(EXPECT_CONNECTED)
305+
}
306+
}
307+
308+
impl<C> DerefMut for ConnectedSlot<C> {
309+
#[inline(always)]
310+
fn deref_mut(&mut self) -> &mut Self::Target {
311+
self.0.get_mut().as_mut().expect(EXPECT_CONNECTED)
312+
}
313+
}
314+
315+
impl<C> DisconnectedSlot<C> {
316+
pub fn put(mut self, conn: C) -> ConnectedSlot<C> {
317+
*self.0.get_mut() = Some(conn);
318+
ConnectedSlot(self.0)
319+
}
320+
321+
pub fn leak(mut self) {
322+
self.0.slot.connected.store(false, Ordering::Relaxed);
323+
self.0.slot.leaked.store(true, Ordering::Release);
324+
// Drop the guard without marking the connection as unlocked
325+
self.0.locked = None;
326+
}
327+
}
328+
258329
impl AcquirePreference {
259330
#[inline(always)]
260331
fn wants_connected(&self, is_connected: bool) -> bool {
@@ -411,71 +482,35 @@ impl<C> SlotGuard<C> {
411482
DisconnectedSlot(self)
412483
}
413484

414-
/// Updates `Slot::connected` without notifying the `ConnectionSet`.
415-
///
416-
/// Returns `Some(connected)` or `None` if this guard was already dropped.
417-
fn drop_without_notify(&mut self) -> Option<bool> {
485+
fn release_without_notify(&mut self) -> Option<ReleaseWithoutNotify<C>> {
418486
self.locked.take().map(|locked| {
419-
let connected = locked.is_some();
420-
self.slot.set_is_connected(connected);
421-
self.slot.locked.store(false, Ordering::Release);
422-
self.slot.global.locked_set[self.slot.index].store(false, Ordering::Relaxed);
423-
connected
487+
ReleaseWithoutNotify(SlotGuard {
488+
slot: self.slot.clone(),
489+
locked: Some(locked),
490+
})
424491
})
425492
}
426493
}
427494

428-
const EXPECT_LOCKED: &str = "BUG: `SlotGuard::locked` should not be `None` in normal operation";
429-
const EXPECT_CONNECTED: &str = "BUG: `ConnectedSlot` expects `Slot::connection` to be `Some`";
430-
431-
impl<C> ConnectedSlot<C> {
432-
pub fn take(mut self) -> (C, DisconnectedSlot<C>) {
433-
let conn = self.0.get_mut().take().expect(EXPECT_CONNECTED);
434-
(conn, self.0.assert_disconnected())
435-
}
436-
}
437-
438-
impl<C> Deref for ConnectedSlot<C> {
439-
type Target = C;
440-
441-
#[inline(always)]
442-
fn deref(&self) -> &Self::Target {
443-
self.0.get().as_ref().expect(EXPECT_CONNECTED)
444-
}
445-
}
446-
447-
impl<C> DerefMut for ConnectedSlot<C> {
448-
#[inline(always)]
449-
fn deref_mut(&mut self) -> &mut Self::Target {
450-
self.0.get_mut().as_mut().expect(EXPECT_CONNECTED)
451-
}
452-
}
453-
454-
impl<C> DisconnectedSlot<C> {
455-
pub fn put(mut self, conn: C) -> ConnectedSlot<C> {
456-
*self.0.get_mut() = Some(conn);
457-
ConnectedSlot(self.0)
458-
}
459-
460-
pub fn leak(mut self) {
461-
self.0.slot.leaked.store(true, Ordering::Release);
462-
self.0.drop_without_notify();
463-
}
464-
}
465-
466495
impl<C> Drop for SlotGuard<C> {
467496
fn drop(&mut self) {
468-
let Some(connected) = self.drop_without_notify() else {
497+
let Some(mut guard) = self.release_without_notify() else {
469498
return;
470499
};
471500

472-
let event = if connected {
501+
let connected = guard.is_connected();
502+
503+
let event = if guard.is_connected() {
473504
&self.slot.global.unlock_event
474505
} else {
475506
&self.slot.global.disconnect_event
476507
};
477508

478-
if event.notify(1.tag(self.slot.index).additional()) != 0 {
509+
if event.notify(
510+
1.tag_with(|| ReleaseWithoutNotify(guard.take()))
511+
.additional(),
512+
) != 0
513+
{
479514
return;
480515
}
481516

@@ -494,13 +529,47 @@ impl<C> Drop for SlotGuard<C> {
494529
}
495530
}
496531

497-
impl Global {
532+
impl<C> ReleaseWithoutNotify<C> {
533+
fn take(&mut self) -> SlotGuard<C> {
534+
SlotGuard {
535+
slot: self.0.slot.clone(),
536+
locked: Some(
537+
self.0
538+
.locked
539+
.take()
540+
.expect("BUG: `SlotGuard.locked` should not be `None` here"),
541+
),
542+
}
543+
}
544+
545+
fn is_connected(&self) -> bool {
546+
self.0
547+
.locked
548+
.as_ref()
549+
.expect("BUG: `SlotGuard.locked` should not be `None` here")
550+
.is_some()
551+
}
552+
}
553+
554+
impl<C> Drop for ReleaseWithoutNotify<C> {
555+
fn drop(&mut self) {
556+
let Some(locked) = self.0.locked.take() else {
557+
return;
558+
};
559+
560+
self.0.slot.set_is_connected(locked.is_some());
561+
self.0.slot.locked.store(false, Ordering::Release);
562+
self.0.slot.global.locked_set[self.0.slot.index].store(false, Ordering::Relaxed);
563+
}
564+
}
565+
566+
impl<C> Global<C> {
498567
#[inline(always)]
499568
fn num_connected(&self) -> usize {
500569
self.num_connected.load(Ordering::Relaxed)
501570
}
502571

503-
async fn listen(&self, pref: AcquirePreference) -> usize {
572+
async fn listen(&self, pref: AcquirePreference) -> SlotGuard<C> {
504573
match pref {
505574
AcquirePreference::Either => race(self.listen_unlocked(), self.listen_disconnected())
506575
.await
@@ -510,14 +579,14 @@ impl Global {
510579
}
511580
}
512581

513-
async fn listen_unlocked(&self) -> usize {
582+
async fn listen_unlocked(&self) -> SlotGuard<C> {
514583
listener!(self.unlock_event => listener);
515-
listener.await
584+
listener.await.take()
516585
}
517586

518-
async fn listen_disconnected(&self) -> usize {
587+
async fn listen_disconnected(&self) -> SlotGuard<C> {
519588
listener!(self.disconnect_event => listener);
520-
listener.await
589+
listener.await.take()
521590
}
522591
}
523592

0 commit comments

Comments
 (0)