diff --git a/Cargo.lock b/Cargo.lock index ff63e51..fedb379 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -486,6 +486,16 @@ version = "0.2.183" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" +[[package]] +name = "libmimalloc-sys" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "667f4fec20f29dfc6bc7357c582d91796c169ad7e2fce709468aefeb2c099870" +dependencies = [ + "cc", + "libc", +] + [[package]] name = "litemap" version = "0.8.1" @@ -513,6 +523,15 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "mimalloc" +version = "0.1.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1ee66a4b64c74f4ef288bcbb9192ad9c3feaad75193129ac8509af543894fd8" +dependencies = [ + "libmimalloc-sys", +] + [[package]] name = "mio" version = "1.1.1" @@ -532,6 +551,7 @@ dependencies = [ "bytes", "futures-util", "log", + "mimalloc", "pyo3", "pyo3-async-runtimes", "pyo3-log", diff --git a/Cargo.toml b/Cargo.toml index 60dec9a..7395ece 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,8 @@ async-nats = "0.46" bytes = "1.11.1" futures-util = "0.3.32" log = "0.4.29" -pyo3 = { version = "0.28", features = ["abi3", "experimental-inspect"] } +mimalloc = "0.1.48" +pyo3 = { version = "0.28", features = ["experimental-inspect"] } pyo3-async-runtimes = { version = "0.28", features = ["tokio-runtime"] } pyo3-log = "0.13.3" serde = { version = "1.0.228", features = ["derive"] } diff --git a/src/exceptions/rust_err.rs b/src/exceptions/rust_err.rs index 99e6689..75b4dd6 100644 --- a/src/exceptions/rust_err.rs +++ b/src/exceptions/rust_err.rs @@ -8,6 +8,8 @@ pub type NatsrpyResult = Result; pub enum NatsrpyError { #[error(transparent)] StdIOError(#[from] std::io::Error), + #[error("The lock is poisoned")] + PoisonedLock, #[error(transparent)] StdParseIntError(#[from] std::num::ParseIntError), #[error(transparent)] diff --git a/src/js/consumers/pull/consumer.rs b/src/js/consumers/pull/consumer.rs index 7956883..3dffbec 100644 --- a/src/js/consumers/pull/consumer.rs +++ b/src/js/consumers/pull/consumer.rs @@ -2,7 +2,6 @@ use std::sync::Arc; use futures_util::StreamExt; use pyo3::{Bound, PyAny, Python}; -use tokio::sync::RwLock; use crate::{ exceptions::rust_err::NatsrpyResult, @@ -19,7 +18,7 @@ pub struct PullConsumer { name: String, #[pyo3(get)] stream_name: String, - consumer: Arc>, + consumer: Arc, } impl PullConsumer { @@ -29,7 +28,7 @@ impl PullConsumer { Self { name: info.name.clone(), stream_name: info.stream_name.clone(), - consumer: Arc::new(RwLock::new(consumer)), + consumer: Arc::new(consumer), } } } @@ -60,13 +59,9 @@ impl PullConsumer { min_ack_pending: Option, timeout: Option, ) -> NatsrpyResult> { - let ctx = self.consumer.clone(); - - // Because we borrow cosnumer lock - // later for modifications of fetchbuilder. + let consumer = self.consumer.clone(); #[allow(clippy::significant_drop_tightening)] natsrpy_future_with_timeout(py, timeout, async move { - let consumer = ctx.read().await; let mut fetch_builder = consumer.fetch(); if let Some(max_messages) = max_messages { fetch_builder = fetch_builder.max_messages(max_messages); diff --git a/src/js/consumers/push/consumer.rs b/src/js/consumers/push/consumer.rs index e79a276..150fd6e 100644 --- a/src/js/consumers/push/consumer.rs +++ b/src/js/consumers/push/consumer.rs @@ -2,7 +2,6 @@ use std::sync::Arc; use futures_util::StreamExt; use pyo3::{Bound, PyAny, PyRef, Python}; -use tokio::sync::RwLock; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, @@ -20,7 +19,7 @@ pub struct PushConsumer { name: String, #[pyo3(get)] stream_name: String, - consumer: Arc>, + consumer: Arc, } impl PushConsumer { @@ -30,20 +29,20 @@ impl PushConsumer { Self { name: info.name.clone(), stream_name: info.stream_name.clone(), - consumer: Arc::new(RwLock::new(consumer)), + consumer: Arc::new(consumer), } } } #[pyo3::pyclass] pub struct MessagesIterator { - messages: Option>>, + messages: Option>>, } impl From for MessagesIterator { fn from(value: async_nats::jetstream::consumer::push::Messages) -> Self { Self { - messages: Some(Arc::new(RwLock::new(value))), + messages: Some(Arc::new(tokio::sync::Mutex::new(value))), } } } @@ -51,11 +50,9 @@ impl From for MessagesIterator #[pyo3::pymethods] impl PushConsumer { pub fn messages<'py>(&self, py: Python<'py>) -> NatsrpyResult> { - let consumer_guard = self.consumer.clone(); + let consumer = self.consumer.clone(); natsrpy_future(py, async move { - Ok(MessagesIterator::from( - consumer_guard.read().await.messages().await?, - )) + Ok(MessagesIterator::from(consumer.messages().await?)) }) } @@ -87,7 +84,7 @@ impl MessagesIterator { }; #[allow(clippy::significant_drop_tightening)] natsrpy_future_with_timeout(py, timeout, async move { - let mut messages = messages_guard.write().await; + let mut messages = messages_guard.lock().await; let Some(message) = messages.next().await else { return Err(NatsrpyError::AsyncStopIteration); }; diff --git a/src/js/counters.rs b/src/js/counters.rs index 7819cd0..a7918c5 100644 --- a/src/js/counters.rs +++ b/src/js/counters.rs @@ -2,7 +2,6 @@ use std::{collections::HashMap, sync::Arc, time::Duration}; use async_nats::{HeaderMap, jetstream::context::traits::Publisher}; use pyo3::{Bound, PyAny, Python}; -use tokio::sync::RwLock; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, @@ -301,17 +300,18 @@ impl CounterEntry { #[pyo3::pyclass] #[allow(dead_code)] pub struct Counters { - stream: Arc>>, - js: Arc>, + stream: Arc>, + js: Arc, } impl Counters { + #[must_use] pub fn new( stream: async_nats::jetstream::stream::Stream, - js: Arc>, + js: Arc, ) -> Self { Self { - stream: Arc::new(RwLock::new(stream)), + stream: Arc::new(stream), js, } } @@ -357,8 +357,6 @@ impl Counters { headers.insert(COUNTER_INCREMENT_HEADER, value.to_string()); natsrpy_future_with_timeout(py, timeout, async move { let resp = js - .read() - .await .publish_message(async_nats::jetstream::message::OutboundMessage { subject: key.into(), payload: bytes::Bytes::new(), @@ -404,11 +402,7 @@ impl Counters { ) -> NatsrpyResult> { let stream_guard = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - let message = stream_guard - .read() - .await - .direct_get_last_for_subject(key) - .await?; + let message = stream_guard.direct_get_last_for_subject(key).await?; CounterEntry::try_from(message) }) } diff --git a/src/js/jetstream.rs b/src/js/jetstream.rs index 480da14..6457783 100644 --- a/src/js/jetstream.rs +++ b/src/js/jetstream.rs @@ -2,7 +2,6 @@ use std::sync::Arc; use async_nats::{Subject, connection::State, jetstream::context::traits::Publisher}; use pyo3::{Bound, PyAny, Python, types::PyDict}; -use tokio::sync::RwLock; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, @@ -15,15 +14,13 @@ use crate::{ #[pyo3::pyclass] pub struct JetStream { - ctx: Arc>, + ctx: Arc, } impl JetStream { #[must_use] pub fn new(ctx: async_nats::jetstream::Context) -> Self { - Self { - ctx: Arc::new(RwLock::new(ctx)), - } + Self { ctx: Arc::new(ctx) } } } @@ -92,20 +89,16 @@ impl JetStream { err_on_disconnect: bool, wait: bool, ) -> NatsrpyResult> { - let ctx = self.ctx.clone(); let data = payload.into(); let headermap = headers .map(async_nats::HeaderMap::from_pydict) .transpose()?; + let client = self.ctx.clone(); natsrpy_future(py, async move { - if err_on_disconnect - && ctx.read().await.client().connection_state() == State::Disconnected - { + if err_on_disconnect && client.client().connection_state() == State::Disconnected { return Err(NatsrpyError::Disconnected); } - let publication = ctx - .read() - .await + let publication = client .publish_message(async_nats::jetstream::message::OutboundMessage { subject: Subject::from(subject), payload: data, diff --git a/src/js/kv.rs b/src/js/kv.rs index e295d70..07500ef 100644 --- a/src/js/kv.rs +++ b/src/js/kv.rs @@ -13,7 +13,7 @@ use pyo3::{ Bound, Py, PyAny, PyRef, Python, types::{PyBytes, PyDateTime}, }; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::Mutex; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, @@ -229,20 +229,19 @@ pub struct KeyValue { put_prefix: Option, #[pyo3(get)] use_jetstream_prefix: bool, - store: Arc>, + store: Arc, } impl KeyValue { #[must_use] pub fn new(store: async_nats::jetstream::kv::Store) -> Self { - // store. Self { name: store.name.clone(), stream_name: store.stream_name.clone(), prefix: store.prefix.clone(), put_prefix: store.put_prefix.clone(), use_jetstream_prefix: store.use_jetstream_prefix, - store: Arc::new(RwLock::new(store)), + store: Arc::new(store), } } } @@ -253,8 +252,6 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { Ok(store - .read() - .await .get(key) .await? .map(|data| Python::attach(move |gil| PyBytes::new(gil, &data).unbind()))) @@ -273,13 +270,9 @@ impl KeyValue { let data = value.into(); natsrpy_future(py, async move { if let Some(ttl) = ttl { - Ok(store - .read() - .await - .create_with_ttl(key, data, ttl.into()) - .await?) + Ok(store.create_with_ttl(key, data, ttl.into()).await?) } else { - Ok(store.read().await.create(key, data).await?) + Ok(store.create(key, data).await?) } }) } @@ -299,15 +292,9 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { match (ttl, expect_revision) { - (None, _) => Ok(store - .read() - .await - .purge_expect_revision(key, expect_revision) - .await?), - (Some(ttl), None) => Ok(store.read().await.purge_with_ttl(key, ttl.into()).await?), + (None, _) => Ok(store.purge_expect_revision(key, expect_revision).await?), + (Some(ttl), None) => Ok(store.purge_with_ttl(key, ttl.into()).await?), (Some(ttl), Some(revision)) => Ok(store - .read() - .await .purge_expect_revision_with_ttl(key, revision, ttl.into()) .await?), } @@ -322,10 +309,7 @@ impl KeyValue { ) -> NatsrpyResult> { let store = self.store.clone(); let data = value.into(); - natsrpy_future( - py, - async move { Ok(store.read().await.put(key, data).await?) }, - ) + natsrpy_future(py, async move { Ok(store.put(key, data).await?) }) } #[pyo3(signature=( @@ -340,11 +324,7 @@ impl KeyValue { ) -> NatsrpyResult> { let store = self.store.clone(); natsrpy_future(py, async move { - Ok(store - .read() - .await - .delete_expect_revision(key, expect_revision) - .await?) + Ok(store.delete_expect_revision(key, expect_revision).await?) }) } @@ -357,11 +337,7 @@ impl KeyValue { ) -> NatsrpyResult> { let store = self.store.clone(); natsrpy_future(py, async move { - Ok(store - .read() - .await - .update(key, value.into(), revision) - .await?) + Ok(store.update(key, value.into(), revision).await?) }) } @@ -369,7 +345,7 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { Ok(KVEntryIterator::new(Streamer::new( - store.read().await.history(key).await?, + store.history(key).await?, ))) }) } @@ -383,9 +359,9 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { let watch = if let Some(rev) = from_revision { - store.read().await.watch_all_from_revision(rev).await? + store.watch_all_from_revision(rev).await? } else { - store.read().await.watch_all().await? + store.watch_all().await? }; Ok(KVEntryIterator::new(Streamer::new(watch))) }) @@ -401,9 +377,9 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { let watch = if let Some(rev) = from_revision { - store.read().await.watch_from_revision(key, rev).await? + store.watch_from_revision(key, rev).await? } else { - store.read().await.watch(key).await? + store.watch(key).await? }; Ok(KVEntryIterator::new(Streamer::new(watch))) }) @@ -417,7 +393,7 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { Ok(KVEntryIterator::new(Streamer::new( - store.read().await.watch_with_history(key).await?, + store.watch_with_history(key).await?, ))) }) } @@ -430,7 +406,7 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { Ok(KVEntryIterator::new(Streamer::new( - store.read().await.watch_many(keys).await?, + store.watch_many(keys).await?, ))) }) } @@ -443,7 +419,7 @@ impl KeyValue { let store = self.store.clone(); natsrpy_future(py, async move { Ok(KVEntryIterator::new(Streamer::new( - store.read().await.watch_many_with_history(keys).await?, + store.watch_many_with_history(keys).await?, ))) }) } @@ -462,20 +438,12 @@ impl KeyValue { natsrpy_future(py, async move { let entry = if let Some(rev) = revision { store - .read() - .await .entry_for_revision(key, rev) .await? .map(KVEntry::try_from) .transpose()? } else { - store - .read() - .await - .entry(key) - .await? - .map(KVEntry::try_from) - .transpose()? + store.entry(key).await?.map(KVEntry::try_from).transpose()? }; Ok(entry) }) @@ -483,17 +451,13 @@ impl KeyValue { pub fn status<'py>(&self, py: Python<'py>) -> NatsrpyResult> { let store = self.store.clone(); - natsrpy_future(py, async move { - KVStatus::try_from(store.read().await.status().await?) - }) + natsrpy_future(py, async move { KVStatus::try_from(store.status().await?) }) } pub fn keys<'py>(&self, py: Python<'py>) -> NatsrpyResult> { let store = self.store.clone(); natsrpy_future(py, async move { - Ok(KeysIterator::new(Streamer::new( - store.read().await.keys().await?, - ))) + Ok(KeysIterator::new(Streamer::new(store.keys().await?))) }) } } diff --git a/src/js/managers/consumers.rs b/src/js/managers/consumers.rs index 0500a40..334318a 100644 --- a/src/js/managers/consumers.rs +++ b/src/js/managers/consumers.rs @@ -2,7 +2,7 @@ use std::{sync::Arc, time::Duration}; use futures_util::StreamExt; use pyo3::{Bound, FromPyObject, IntoPyObjectExt, PyAny, PyRef, Python}; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::Mutex; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, @@ -25,7 +25,7 @@ pub struct ConsumersIterator { >, >, >, - stream: Arc>>, + stream: Arc>, } #[pyo3::pyclass] @@ -73,10 +73,9 @@ impl ConsumersNamesIterator { } impl ConsumersIterator { + #[must_use] pub fn new( - stream: Arc< - RwLock>, - >, + stream: Arc>, streamer: Streamer< Result< async_nats::jetstream::consumer::Info, @@ -119,12 +118,12 @@ impl ConsumersIterator { // That means that the consumer is PushBased. if info.config.deliver_subject.is_some() { let consumer = consumers::push::consumer::PushConsumer::new( - stream.read().await.get_consumer(&consumer_name).await?, + stream.get_consumer(&consumer_name).await?, ); Ok(Python::attach(|py| consumer.into_py_any(py))?) } else { let consumer = consumers::pull::consumer::PullConsumer::new( - stream.read().await.get_consumer(&consumer_name).await?, + stream.get_consumer(&consumer_name).await?, ); Ok(Python::attach(|py| consumer.into_py_any(py))?) } @@ -141,14 +140,13 @@ impl ConsumersIterator { #[pyo3::pyclass] pub struct ConsumersManager { - stream: Arc>>, + stream: Arc>, } impl ConsumersManager { + #[must_use] pub const fn new( - stream: Arc< - RwLock>, - >, + stream: Arc>, ) -> Self { Self { stream } } @@ -193,15 +191,13 @@ impl ConsumersManager { natsrpy_future(py, async move { match config { ConsumerConfigs::Pull(config) => { - let consumer = PullConsumer::new( - ctx.read().await.create_consumer(config.try_into()?).await?, - ); + let consumer = + PullConsumer::new(ctx.create_consumer(config.try_into()?).await?); Ok(Python::attach(|gil| consumer.into_py_any(gil))?) } ConsumerConfigs::Push(config) => { - let consumer = PushConsumer::new( - ctx.read().await.create_consumer(config.try_into()?).await?, - ); + let consumer = + PushConsumer::new(ctx.create_consumer(config.try_into()?).await?); Ok(Python::attach(|gil| consumer.into_py_any(gil))?) } } @@ -217,15 +213,13 @@ impl ConsumersManager { natsrpy_future(py, async move { match config { ConsumerConfigs::Pull(config) => { - let consumer = PullConsumer::new( - ctx.read().await.update_consumer(config.try_into()?).await?, - ); + let consumer = + PullConsumer::new(ctx.update_consumer(config.try_into()?).await?); Ok(Python::attach(|gil| consumer.into_py_any(gil))?) } ConsumerConfigs::Push(config) => { - let consumer = PushConsumer::new( - ctx.read().await.update_consumer(config.try_into()?).await?, - ); + let consumer = + PushConsumer::new(ctx.update_consumer(config.try_into()?).await?); Ok(Python::attach(|gil| consumer.into_py_any(gil))?) } } @@ -236,7 +230,7 @@ impl ConsumersManager { let ctx = self.stream.clone(); natsrpy_future(py, async move { Ok(consumers::pull::consumer::PullConsumer::new( - ctx.read().await.get_consumer(&name).await?, + ctx.get_consumer(&name).await?, )) }) } @@ -245,7 +239,7 @@ impl ConsumersManager { let ctx = self.stream.clone(); natsrpy_future(py, async move { Ok(consumers::push::consumer::PushConsumer::new( - ctx.read().await.get_consumer(&name).await?, + ctx.get_consumer(&name).await?, )) }) } @@ -259,28 +253,30 @@ impl ConsumersManager { let ctx = self.stream.clone(); let untill = time::OffsetDateTime::now_utc() + Duration::from(delay); natsrpy_future(py, async move { - Ok(ctx.read().await.pause_consumer(&name, untill).await?.paused) + Ok(ctx.pause_consumer(&name, untill).await?.paused) }) } pub fn resume<'py>(&self, py: Python<'py>, name: String) -> NatsrpyResult> { let ctx = self.stream.clone(); - natsrpy_future(py, async move { - Ok(ctx.read().await.resume_consumer(&name).await?.paused) - }) + natsrpy_future( + py, + async move { Ok(ctx.resume_consumer(&name).await?.paused) }, + ) } pub fn delete<'py>(&self, py: Python<'py>, name: String) -> NatsrpyResult> { let ctx = self.stream.clone(); - natsrpy_future(py, async move { - Ok(ctx.read().await.delete_consumer(&name).await?.success) - }) + natsrpy_future( + py, + async move { Ok(ctx.delete_consumer(&name).await?.success) }, + ) } pub fn list<'py>(&self, py: Python<'py>) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future(py, async move { - let consumers = ctx.read().await.consumers(); + let consumers = ctx.consumers(); Ok(ConsumersIterator::new( ctx.clone(), Streamer::new(consumers), @@ -291,7 +287,7 @@ impl ConsumersManager { pub fn list_names<'py>(&self, py: Python<'py>) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future(py, async move { - let consumers = ctx.read().await.consumer_names(); + let consumers = ctx.consumer_names(); Ok(ConsumersNamesIterator::new(Streamer::new(consumers))) }) } diff --git a/src/js/managers/counters.rs b/src/js/managers/counters.rs index c54fa83..d3f9a1c 100644 --- a/src/js/managers/counters.rs +++ b/src/js/managers/counters.rs @@ -5,17 +5,17 @@ use crate::{ js::counters::{Counters, CountersConfig}, }; use pyo3::{Bound, PyAny, Python}; -use tokio::sync::RwLock; use crate::{exceptions::rust_err::NatsrpyResult, utils::natsrpy_future}; #[pyo3::pyclass] pub struct CountersManager { - ctx: Arc>, + ctx: Arc, } impl CountersManager { - pub const fn new(ctx: Arc>) -> Self { + #[must_use] + pub const fn new(ctx: Arc) -> Self { Self { ctx } } } @@ -27,13 +27,13 @@ impl CountersManager { py: Python<'py>, config: CountersConfig, ) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { - let js = ctx.read().await; Ok(Counters::new( - js.create_stream(async_nats::jetstream::stream::Config::try_from(config)?) + client + .create_stream(async_nats::jetstream::stream::Config::try_from(config)?) .await?, - ctx.clone(), + client, )) }) } @@ -43,24 +43,22 @@ impl CountersManager { py: Python<'py>, config: CountersConfig, ) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { - let info = ctx - .read() - .await + let info = client .create_or_update_stream(async_nats::jetstream::stream::Config::try_from(config)?) .await?; Ok(Counters::new( - ctx.read().await.get_stream(info.config.name).await?, - ctx.clone(), + client.get_stream(info.config.name).await?, + client, )) }) } pub fn get<'py>(&self, py: Python<'py>, name: String) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { - let stream = ctx.read().await.get_stream(&name).await?; + let stream = client.get_stream(&name).await?; let config = stream.get_info().await?.config; if !config.allow_direct { return Err(NatsrpyError::SessionError(format!( @@ -72,16 +70,16 @@ impl CountersManager { "Stream {name} doesn't allow message counters.", ))); } - Ok(Counters::new(stream, ctx.clone())) + Ok(Counters::new(stream, client)) }) } pub fn delete<'py>(&self, py: Python<'py>, name: String) -> NatsrpyResult> { - let ctx = self.ctx.clone(); - natsrpy_future(py, async move { - let js = ctx.read().await; - Ok(js.delete_stream(name).await?.success) - }) + let client = self.ctx.clone(); + natsrpy_future( + py, + async move { Ok(client.delete_stream(name).await?.success) }, + ) } pub fn update<'py>( @@ -89,16 +87,14 @@ impl CountersManager { py: Python<'py>, config: CountersConfig, ) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { - let info = ctx - .read() - .await + let info = client .update_stream(async_nats::jetstream::stream::Config::try_from(config)?) .await?; Ok(Counters::new( - ctx.read().await.get_stream(info.config.name).await?, - ctx.clone(), + client.get_stream(info.config.name).await?, + client, )) }) } diff --git a/src/js/managers/kv.rs b/src/js/managers/kv.rs index f805a2f..cde59e8 100644 --- a/src/js/managers/kv.rs +++ b/src/js/managers/kv.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use pyo3::{Bound, PyAny, Python}; -use tokio::sync::RwLock; use crate::{ exceptions::rust_err::NatsrpyResult, @@ -11,11 +10,12 @@ use crate::{ #[pyo3::pyclass] pub struct KVManager { - ctx: Arc>, + ctx: Arc, } impl KVManager { - pub const fn new(ctx: Arc>) -> Self { + #[must_use] + pub const fn new(ctx: Arc) -> Self { Self { ctx } } } @@ -27,13 +27,10 @@ impl KVManager { py: Python<'py>, config: KVConfig, ) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { Ok(KeyValue::new( - ctx.read() - .await - .create_key_value(config.try_into()?) - .await?, + client.create_key_value(config.try_into()?).await?, )) }) } @@ -43,11 +40,10 @@ impl KVManager { py: Python<'py>, config: KVConfig, ) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { Ok(KeyValue::new( - ctx.read() - .await + client .create_or_update_key_value(config.try_into()?) .await?, )) @@ -55,16 +51,16 @@ impl KVManager { } pub fn get<'py>(&self, py: Python<'py>, bucket: String) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { - Ok(KeyValue::new(ctx.read().await.get_key_value(bucket).await?)) + Ok(KeyValue::new(client.get_key_value(bucket).await?)) }) } pub fn delete<'py>(&self, py: Python<'py>, bucket: String) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { - Ok(ctx.read().await.delete_key_value(bucket).await?.success) + Ok(client.delete_key_value(bucket).await?.success) }) } @@ -73,13 +69,10 @@ impl KVManager { py: Python<'py>, config: KVConfig, ) -> NatsrpyResult> { - let ctx = self.ctx.clone(); + let client = self.ctx.clone(); natsrpy_future(py, async move { Ok(KeyValue::new( - ctx.read() - .await - .update_key_value(config.try_into()?) - .await?, + client.update_key_value(config.try_into()?).await?, )) }) } diff --git a/src/js/managers/object_store.rs b/src/js/managers/object_store.rs index c4c9c51..1267d0f 100644 --- a/src/js/managers/object_store.rs +++ b/src/js/managers/object_store.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use pyo3::{Bound, PyAny, Python}; -use tokio::sync::RwLock; use crate::{ exceptions::rust_err::NatsrpyResult, @@ -11,11 +10,12 @@ use crate::{ #[pyo3::pyclass] pub struct ObjectStoreManager { - ctx: Arc>, + ctx: Arc, } impl ObjectStoreManager { - pub const fn new(ctx: Arc>) -> Self { + #[must_use] + pub const fn new(ctx: Arc) -> Self { Self { ctx } } } @@ -25,9 +25,7 @@ impl ObjectStoreManager { pub fn get<'py>(&self, py: Python<'py>, bucket: String) -> NatsrpyResult> { let ctx_guard = self.ctx.clone(); natsrpy_future(py, async move { - Ok(ObjectStore::new( - ctx_guard.read().await.get_object_store(bucket).await?, - )) + Ok(ObjectStore::new(ctx_guard.get_object_store(bucket).await?)) }) } @@ -36,22 +34,18 @@ impl ObjectStoreManager { py: Python<'py>, config: ObjectStoreConfig, ) -> NatsrpyResult> { - let ctx_guard = self.ctx.clone(); + let ctx = self.ctx.clone(); natsrpy_future(py, async move { Ok(ObjectStore::new( - ctx_guard - .read() - .await - .create_object_store(config.into()) - .await?, + ctx.create_object_store(config.into()).await?, )) }) } pub fn delete<'py>(&self, py: Python<'py>, bucket: String) -> NatsrpyResult> { - let ctx_guard = self.ctx.clone(); + let ctx = self.ctx.clone(); natsrpy_future(py, async move { - ctx_guard.read().await.delete_object_store(bucket).await?; + ctx.delete_object_store(bucket).await?; Ok(()) }) } diff --git a/src/js/managers/streams.rs b/src/js/managers/streams.rs index 66700f2..4ad91a9 100644 --- a/src/js/managers/streams.rs +++ b/src/js/managers/streams.rs @@ -2,17 +2,17 @@ use std::sync::Arc; use crate::js::stream::Stream; use pyo3::{Bound, PyAny, Python}; -use tokio::sync::RwLock; use crate::{exceptions::rust_err::NatsrpyResult, js::stream::StreamConfig, utils::natsrpy_future}; #[pyo3::pyclass] pub struct StreamsManager { - ctx: Arc>, + ctx: Arc, } impl StreamsManager { - pub const fn new(ctx: Arc>) -> Self { + #[must_use] + pub const fn new(ctx: Arc) -> Self { Self { ctx } } } @@ -26,9 +26,8 @@ impl StreamsManager { ) -> NatsrpyResult> { let ctx = self.ctx.clone(); natsrpy_future(py, async move { - let js = ctx.read().await; Ok(Stream::new( - js.create_stream(async_nats::jetstream::stream::Config::try_from(config)?) + ctx.create_stream(async_nats::jetstream::stream::Config::try_from(config)?) .await?, )) }) @@ -42,29 +41,26 @@ impl StreamsManager { let ctx = self.ctx.clone(); natsrpy_future(py, async move { let info = ctx - .read() - .await .create_or_update_stream(async_nats::jetstream::stream::Config::try_from(config)?) .await?; - Ok(Stream::new( - ctx.read().await.get_stream(info.config.name).await?, - )) + Ok(Stream::new(ctx.get_stream(info.config.name).await?)) }) } pub fn get<'py>(&self, py: Python<'py>, name: String) -> NatsrpyResult> { let ctx = self.ctx.clone(); - natsrpy_future(py, async move { - Ok(Stream::new(ctx.read().await.get_stream(name).await?)) - }) + natsrpy_future( + py, + async move { Ok(Stream::new(ctx.get_stream(name).await?)) }, + ) } pub fn delete<'py>(&self, py: Python<'py>, name: String) -> NatsrpyResult> { let ctx = self.ctx.clone(); - natsrpy_future(py, async move { - let js = ctx.read().await; - Ok(js.delete_stream(name).await?.success) - }) + natsrpy_future( + py, + async move { Ok(ctx.delete_stream(name).await?.success) }, + ) } pub fn update<'py>( @@ -75,13 +71,9 @@ impl StreamsManager { let ctx = self.ctx.clone(); natsrpy_future(py, async move { let info = ctx - .read() - .await .update_stream(async_nats::jetstream::stream::Config::try_from(config)?) .await?; - Ok(Stream::new( - ctx.read().await.get_stream(info.config.name).await?, - )) + Ok(Stream::new(ctx.get_stream(info.config.name).await?)) }) } } diff --git a/src/js/message.rs b/src/js/message.rs index d8045dd..deb4fc6 100644 --- a/src/js/message.rs +++ b/src/js/message.rs @@ -3,7 +3,6 @@ use pyo3::{ types::{PyBytes, PyDateTime, PyDict}, }; use std::sync::Arc; -use tokio::sync::RwLock; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, @@ -48,7 +47,7 @@ impl From> for JSInfo { pub struct JetStreamMessage { message: crate::message::Message, info: JSInfo, - acker: Arc>, + acker: Arc, } impl TryFrom for JetStreamMessage { @@ -60,7 +59,7 @@ impl TryFrom for JetStreamMessage { Ok(Self { message: message.try_into()?, info: js_info, - acker: Arc::new(RwLock::new(acker)), + acker: Arc::new(acker), }) } } @@ -72,12 +71,12 @@ impl JetStreamMessage { kind: async_nats::jetstream::message::AckKind, double: bool, ) -> NatsrpyResult> { - let acker_guard = self.acker.clone(); + let acker = self.acker.clone(); natsrpy_future(py, async move { if double { - acker_guard.read().await.double_ack_with(kind).await?; + acker.double_ack_with(kind).await?; } else { - acker_guard.read().await.ack_with(kind).await?; + acker.ack_with(kind).await?; } Ok(()) }) diff --git a/src/js/stream.rs b/src/js/stream.rs index 4892fcc..66941db 100644 --- a/src/js/stream.rs +++ b/src/js/stream.rs @@ -14,7 +14,6 @@ use crate::{ }, }; use pyo3::{Bound, PyAny, Python}; -use tokio::sync::RwLock; #[pyo3::pyclass(from_py_object)] #[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] @@ -944,7 +943,7 @@ impl From for PurgeResponse { pub struct Stream { #[pyo3(get)] name: String, - stream: Arc>>, + stream: Arc>, } impl Stream { #[must_use] @@ -954,7 +953,7 @@ impl Stream { let info = stream.cached_info(); Self { name: info.config.name.clone(), - stream: Arc::new(RwLock::new(stream)), + stream: Arc::new(stream), } } } @@ -976,7 +975,7 @@ impl Stream { ) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - let message = ctx.read().await.direct_get(sequence).await?; + let message = ctx.direct_get(sequence).await?; let result = Python::attach(move |gil| StreamMessage::from_nats_message(gil, &message))?; Ok(result) @@ -993,11 +992,7 @@ impl Stream { ) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - let message = ctx - .read() - .await - .direct_get_next_for_subject(subject, sequence) - .await?; + let message = ctx.direct_get_next_for_subject(subject, sequence).await?; let result = Python::attach(move |gil| StreamMessage::from_nats_message(gil, &message))?; Ok(result) @@ -1013,11 +1008,7 @@ impl Stream { ) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - let message = ctx - .read() - .await - .direct_get_first_for_subject(subject) - .await?; + let message = ctx.direct_get_first_for_subject(subject).await?; let result = Python::attach(move |gil| StreamMessage::from_nats_message(gil, &message))?; Ok(result) @@ -1033,11 +1024,7 @@ impl Stream { ) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - let message = ctx - .read() - .await - .direct_get_last_for_subject(subject) - .await?; + let message = ctx.direct_get_last_for_subject(subject).await?; let result = Python::attach(move |gil| StreamMessage::from_nats_message(gil, &message))?; Ok(result) @@ -1052,7 +1039,7 @@ impl Stream { ) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - StreamInfo::try_from(ctx.read().await.get_info().await?) + StreamInfo::try_from(ctx.get_info().await?) }) } @@ -1072,7 +1059,7 @@ impl Stream { ) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - let mut purge_request = ctx.read().await.purge(); + let mut purge_request = ctx.purge(); if let Some(filter) = filter { purge_request = purge_request.filter(filter); } @@ -1105,7 +1092,7 @@ impl Stream { ) -> NatsrpyResult> { let ctx = self.stream.clone(); natsrpy_future_with_timeout(py, timeout, async move { - ctx.read().await.delete_message(sequence).await?; + ctx.delete_message(sequence).await?; Ok(()) }) } diff --git a/src/lib.rs b/src/lib.rs index 7d1174d..f967594 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,10 @@ // to have many args with defaults. clippy::too_many_arguments )] + +#[global_allocator] +static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; + pub mod exceptions; pub mod js; pub mod message; diff --git a/src/message.rs b/src/message.rs index 082e0fd..0a0e534 100644 --- a/src/message.rs +++ b/src/message.rs @@ -17,36 +17,59 @@ pub struct Message { pub length: usize, } +impl Message { + /// Convert from an `async_nats::Message` using an already-held GIL token. + /// This avoids a redundant `Python::attach` call when the caller already has the GIL. + pub fn from_nats_message( + gil: Python<'_>, + value: &async_nats::Message, + ) -> Result { + let headers = match &value.headers { + Some(headermap) => headermap.to_pydict(gil)?.unbind(), + None => PyDict::new(gil).unbind(), + }; + Ok(Self { + subject: value.subject.to_string(), + reply: value.reply.as_deref().map(ToString::to_string), + payload: PyBytes::new(gil, &value.payload).unbind(), + headers, + status: value.status.map(Into::::into), + description: value.description.clone(), + length: value.length, + }) + } +} + impl TryFrom<&async_nats::Message> for Message { type Error = NatsrpyError; fn try_from(value: &async_nats::Message) -> Result { + Python::attach(move |gil| Self::from_nats_message(gil, value)) + } +} + +impl TryFrom for Message { + type Error = NatsrpyError; + + fn try_from(value: async_nats::Message) -> Result { Python::attach(move |gil| { let headers = match &value.headers { Some(headermap) => headermap.to_pydict(gil)?.unbind(), None => PyDict::new(gil).unbind(), }; Ok(Self { - subject: value.subject.to_string(), - reply: value.reply.as_deref().map(ToString::to_string), + subject: value.subject.into_string(), + reply: value.reply.map(async_nats::Subject::into_string), payload: PyBytes::new(gil, &value.payload).unbind(), headers, status: value.status.map(Into::::into), - description: value.description.clone(), + description: value.description, length: value.length, }) }) } } -impl TryFrom for Message { - type Error = NatsrpyError; - - fn try_from(value: async_nats::Message) -> Result { - Self::try_from(&value) - } -} - #[pyo3::pymethods] impl Message { #[must_use] diff --git a/src/nats_cls.rs b/src/nats_cls.rs index 00fd8f1..07f1665 100644 --- a/src/nats_cls.rs +++ b/src/nats_cls.rs @@ -1,7 +1,6 @@ use async_nats::{Subject, client::traits::Publisher, message::OutboundMessage}; use pyo3::{Bound, IntoPyObjectExt, Py, PyAny, Python, types::PyDict}; -use std::sync::Arc; -use tokio::sync::RwLock; +use std::sync::{Arc, RwLock}; use crate::{ exceptions::rust_err::{NatsrpyError, NatsrpyResult}, @@ -16,7 +15,7 @@ use crate::{ #[pyo3::pyclass(name = "Nats")] pub struct NatsCls { - nats_session: Arc>>, + nats_session: Arc>>, addr: Vec, user_and_pass: Option<(String, String)>, nkey: Option, @@ -29,6 +28,15 @@ pub struct NatsCls { request_timeout: Option, } +/// Helper to read the client from the `RwLock`. Returns a clone of the Client if present. +fn get_client(session: &RwLock>) -> NatsrpyResult { + session + .read() + .map_err(|_| NatsrpyError::SessionError("Lock poisoned".to_string()))? + .clone() + .ok_or(NatsrpyError::NotInitialized) +} + #[pyo3::pymethods] impl NatsCls { #[new] @@ -98,15 +106,22 @@ impl NatsCls { let address = self.addr.clone(); let timeout = self.connection_timeout; natsrpy_future_with_timeout(py, Some(timeout), async move { - if session.read().await.is_some() { - return Err(NatsrpyError::SessionError( - "NATS session already exists".to_string(), - )); + { + let guard = session + .read() + .map_err(|_| NatsrpyError::SessionError("Lock poisoned".to_string()))?; + if guard.is_some() { + return Err(NatsrpyError::SessionError( + "NATS session already exists".to_string(), + )); + } } - // Scoping for early-dropping of a guard. + let client = conn_opts.connect(address).await?; { - let mut sesion_guard = session.write().await; - *sesion_guard = Some(conn_opts.connect(address).await?); + let mut guard = session + .write() + .map_err(|_| NatsrpyError::SessionError("Lock poisoned".to_string()))?; + *guard = Some(client); } Ok(()) }) @@ -122,7 +137,7 @@ impl NatsCls { reply: Option, err_on_disconnect: bool, ) -> NatsrpyResult> { - let session = self.nats_session.clone(); + let client = get_client(&self.nats_session)?; let data = bytes::Bytes::from(payload); let headermap = headers .map(async_nats::HeaderMap::from_pydict) @@ -133,24 +148,20 @@ impl NatsCls { data.len() ); natsrpy_future(py, async move { - if let Some(session) = session.read().await.as_ref() { - if err_on_disconnect - && session.connection_state() == async_nats::connection::State::Disconnected - { - return Err(NatsrpyError::Disconnected); - } - session - .publish_message(OutboundMessage { - subject: Subject::from(subject), - payload: data, - headers: headermap, - reply: reply.map(Subject::from), - }) - .await?; - Ok(()) - } else { - Err(NatsrpyError::NotInitialized) + if err_on_disconnect + && client.connection_state() == async_nats::connection::State::Disconnected + { + return Err(NatsrpyError::Disconnected); } + client + .publish_message(OutboundMessage { + subject: Subject::from(subject), + payload: data, + headers: headermap, + reply: reply.map(Subject::from), + }) + .await?; + Ok(()) }) } @@ -164,7 +175,7 @@ impl NatsCls { inbox: Option, timeout: Option, ) -> NatsrpyResult> { - let session = self.nats_session.clone(); + let client = get_client(&self.nats_session)?; let data = payload.map(bytes::Bytes::from); let headermap = headers .map(async_nats::HeaderMap::from_pydict) @@ -175,30 +186,22 @@ impl NatsCls { data.as_ref().map_or(0, bytes::Bytes::len) ); natsrpy_future(py, async move { - if let Some(session) = session.read().await.as_ref() { - let request = async_nats::Request { - payload: data, - headers: headermap, - inbox, - timeout: timeout.map(Into::into).map(Some), - }; - crate::message::Message::try_from(session.send_request(subject, request).await?) - } else { - Err(NatsrpyError::NotInitialized) - } + let request = async_nats::Request { + payload: data, + headers: headermap, + inbox, + timeout: timeout.map(Into::into).map(Some), + }; + crate::message::Message::try_from(client.send_request(subject, request).await?) }) } pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Draining NATS session"); - let session = self.nats_session.clone(); + let client = get_client(&self.nats_session)?; natsrpy_future(py, async move { - if let Some(session) = session.write().await.as_ref() { - session.drain().await?; - Ok(()) - } else { - Err(NatsrpyError::NotInitialized) - } + client.drain().await?; + Ok(()) }) } @@ -211,23 +214,19 @@ impl NatsCls { queue: Option, ) -> NatsrpyResult> { log::debug!("Subscribing to '{subject}'"); - let session = self.nats_session.clone(); + let client = get_client(&self.nats_session)?; natsrpy_future(py, async move { - if let Some(session) = session.read().await.as_ref() { - let subscriber = if let Some(queue) = queue { - session.queue_subscribe(subject, queue).await? - } else { - session.subscribe(subject).await? - }; - if let Some(cb) = callback { - let sub = CallbackSubscription::new(subscriber, cb)?; - Ok(Python::attach(|gil| sub.into_py_any(gil))?) - } else { - let sub = IteratorSubscription::new(subscriber); - Ok(Python::attach(|gil| sub.into_py_any(gil))?) - } + let subscriber = if let Some(queue) = queue { + client.queue_subscribe(subject, queue).await? + } else { + client.subscribe(subject).await? + }; + if let Some(cb) = callback { + let sub = CallbackSubscription::new(subscriber, cb)?; + Ok(Python::attach(|gil| sub.into_py_any(gil))?) } else { - Err(NatsrpyError::NotInitialized) + let sub = IteratorSubscription::new(subscriber); + Ok(Python::attach(|gil| sub.into_py_any(gil))?) } }) } @@ -254,7 +253,12 @@ impl NatsCls { backpressure_on_inflight: Option, ) -> NatsrpyResult> { log::debug!("Creating JetStream context"); - let session = self.nats_session.clone(); + if domain.is_some() && api_prefix.is_some() { + return Err(NatsrpyError::InvalidArgument(String::from( + "Either domain or api_prefix should be specified, not both.", + ))); + } + let client = get_client(&self.nats_session)?; natsrpy_future(py, async move { let mut builder = async_nats::jetstream::ContextBuilder::new().concurrency_limit(concurrency_limit); @@ -270,67 +274,59 @@ impl NatsCls { if let Some(backpressure_on_inflight) = backpressure_on_inflight { builder = builder.backpressure_on_inflight(backpressure_on_inflight); } - if domain.is_some() && api_prefix.is_some() { - return Err(NatsrpyError::InvalidArgument(String::from( - "Either domain or api_prefix should be specified, not both.", - ))); - } - session.read().await.as_ref().map_or_else( - || Err(NatsrpyError::NotInitialized), - |session| { - let js = if let Some(api_prefix) = api_prefix { - builder.api_prefix(api_prefix).build(session.clone()) - } else if let Some(domain) = domain { - builder.domain(domain).build(session.clone()) - } else { - builder.build(session.clone()) - }; - Ok(crate::js::jetstream::JetStream::new(js)) - }, - ) + let js = if let Some(api_prefix) = api_prefix { + builder.api_prefix(api_prefix).build(client) + } else if let Some(domain) = domain { + builder.domain(domain).build(client) + } else { + builder.build(client) + }; + Ok(crate::js::jetstream::JetStream::new(js)) }) } pub fn shutdown<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Closing nats session"); let session = self.nats_session.clone(); + let client = get_client(&session)?; + // Set session to None immediately so no new operations can start. + { + let mut guard = session + .write() + .map_err(|_| NatsrpyError::SessionError("Lock poisoned".to_string()))?; + *guard = None; + } natsrpy_future(py, async move { - let mut write_guard = session.write().await; - let Some(session) = write_guard.as_ref() else { - return Err(NatsrpyError::NotInitialized); - }; - session.drain().await?; - *write_guard = None; - drop(write_guard); + client.drain().await?; Ok(()) }) } pub fn flush<'py>(&self, py: Python<'py>) -> NatsrpyResult> { log::debug!("Flushing streams"); - let session = self.nats_session.clone(); + let client = get_client(&self.nats_session)?; natsrpy_future(py, async move { - if let Some(session) = session.write().await.as_ref() { - session.flush().await?; - Ok(()) - } else { - Err(NatsrpyError::NotInitialized) - } + client.flush().await?; + Ok(()) }) } } impl Drop for NatsCls { fn drop(&mut self) { - pyo3_async_runtimes::tokio::get_runtime().block_on(async move { - let mut write_guard = self.nats_session.write().await; - if let Some(session) = write_guard.as_ref() { - log::warn!( - "NATS session was not closed before dropping. Draining session in drop. Please call `.shutdown()` function before dropping the session to avoid this warning." - ); - session.drain().await.ok(); - } - *write_guard = None; - }); + let client = { + let Ok(mut guard) = self.nats_session.write() else { + return; + }; + guard.take() + }; + if let Some(client) = client { + log::warn!( + "NATS session was not closed before dropping. Draining session in drop. Please call `.shutdown()` function before dropping the session to avoid this warning." + ); + pyo3_async_runtimes::tokio::get_runtime().block_on(async move { + client.drain().await.ok(); + }); + } } } diff --git a/src/subscriptions/callback.rs b/src/subscriptions/callback.rs index f83d0bf..84af5b3 100644 --- a/src/subscriptions/callback.rs +++ b/src/subscriptions/callback.rs @@ -1,22 +1,27 @@ -use std::{sync::Arc, time::Duration}; +use std::sync::Arc; use futures_util::StreamExt; use pyo3::{Bound, Py, PyAny, Python}; -use tokio::sync::Mutex; use crate::{exceptions::rust_err::NatsrpyResult, utils::natsrpy_future}; +enum UnsubscribeCommand { + Unsubscribe, + UnsubscribeAfter(u64), + Drain, +} + #[pyo3::pyclass] pub struct CallbackSubscription { - inner: Option>>, + unsub_sender: Option>, reading_task: tokio::task::AbortHandle, } -async fn process_message(message: async_nats::message::Message, py_callback: Py) { +async fn process_message(message: async_nats::message::Message, py_callback: Arc>) { let task = async || -> NatsrpyResult<()> { log::debug!("Received message: {:?}. Processing ...", &message); - let message = crate::message::Message::try_from(&message)?; let awaitable = Python::attach(|gil| -> NatsrpyResult<_> { + let message = crate::message::Message::from_nats_message(gil, &message)?; let res = py_callback.call1(gil, (message,))?; let rust_task = pyo3_async_runtimes::tokio::into_future(res.into_bound(gil))?; Ok(rust_task) @@ -31,50 +36,58 @@ async fn process_message(message: async_nats::message::Message, py_callback: Py< } async fn start_py_sub( - sub: Arc>, - py_callback: Py, + mut sub: async_nats::Subscriber, + py_callback: Arc>, locals: pyo3_async_runtimes::TaskLocals, + mut unsub_receiver: tokio::sync::mpsc::Receiver, ) { loop { - let message = { - let mut sub_guard = sub.lock().await; - // We wait up to 0.2 second for new messages. - // If this thing doesn't resolve in this period, - // we just release the lock. Otherwise it would be impossible to - // unsubscribe. - match tokio::time::timeout(Duration::from_millis(200), sub_guard.next()).await { - Ok(Some(message)) => message, - Ok(None) => break, - _ => continue, + tokio::select! { + msg = sub.next() => { + match msg { + Some(message) => { + let py_cb = py_callback.clone(); + tokio::spawn(pyo3_async_runtimes::tokio::scope( + locals.clone(), + process_message(message, py_cb), + )); + } + None => break, + } } - }; - let py_cb = Python::attach(|py| py_callback.clone_ref(py)); - tokio::spawn(pyo3_async_runtimes::tokio::scope( - locals.clone(), - process_message(message, py_cb), - )); + cmd = unsub_receiver.recv() => { + match cmd { + Some(UnsubscribeCommand::Unsubscribe) => { + sub.unsubscribe().await.ok(); + break; + } + Some(UnsubscribeCommand::UnsubscribeAfter(limit)) => { + sub.unsubscribe_after(limit).await.ok(); + // Don't break — continue receiving up to `limit` messages. + } + Some(UnsubscribeCommand::Drain) => { + sub.drain().await.ok(); + break; + } + None => break, + } + } + } } - // while let Some(message) = sub.lock().await.next().await { - // let py_cb = Python::attach(|py| py_callback.clone_ref(py)); - // tokio::spawn(pyo3_async_runtimes::tokio::scope( - // locals.clone(), - // process_message(message, py_cb), - // )); - // } } impl CallbackSubscription { pub fn new(sub: async_nats::Subscriber, callback: Py) -> NatsrpyResult { - let sub = Arc::new(Mutex::new(sub)); - let cb_sub = sub.clone(); + let (unsub_tx, unsub_rx) = tokio::sync::mpsc::channel(1); let task_locals = Python::attach(pyo3_async_runtimes::tokio::get_current_locals)?; + let callback = Arc::new(callback); let task_handle = tokio::task::spawn(pyo3_async_runtimes::tokio::scope( task_locals.clone(), - start_py_sub(cb_sub, callback, task_locals), + start_py_sub(sub, callback, task_locals, unsub_rx), )) .abort_handle(); Ok(Self { - inner: Some(sub), + unsub_sender: Some(unsub_tx), reading_task: task_handle, }) } @@ -88,46 +101,43 @@ impl CallbackSubscription { py: Python<'py>, limit: Option, ) -> NatsrpyResult> { - let Some(inner) = self.inner.clone() else { + let Some(sender) = self.unsub_sender.clone() else { unreachable!("Subscription used after del") }; natsrpy_future(py, async move { - if let Some(limit) = limit { - inner.lock().await.unsubscribe_after(limit).await?; - } else { - inner.lock().await.unsubscribe().await?; - } + let cmd = limit.map_or(UnsubscribeCommand::Unsubscribe, |n| { + UnsubscribeCommand::UnsubscribeAfter(n) + }); + sender.send(cmd).await.map_err(|_| { + crate::exceptions::rust_err::NatsrpyError::SessionError( + "Subscription already closed".to_string(), + ) + })?; Ok(()) }) } pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { - let Some(inner) = self.inner.clone() else { + let Some(sender) = self.unsub_sender.clone() else { unreachable!("Subscription used after del") }; natsrpy_future(py, async move { - inner.lock().await.drain().await?; + sender.send(UnsubscribeCommand::Drain).await.map_err(|_| { + crate::exceptions::rust_err::NatsrpyError::SessionError( + "Subscription already closed".to_string(), + ) + })?; Ok(()) }) } } -/// This is required only because -/// in nats library they run async operation on Drop. -/// -/// Because of that we need to execute drop in async -/// runtime's context. -/// -/// And because we want to perform a drop, -/// we need somehow drop the inner variable, -/// but leave self intouch. That is exactly why we have -/// Option>. So we can just assign it to None -/// and it will perform a drop. impl Drop for CallbackSubscription { fn drop(&mut self) { - pyo3_async_runtimes::tokio::get_runtime().block_on(async move { - self.inner = None; - self.reading_task.abort(); - }); + // Drop the sender to signal the reading task to stop, + // then abort the task. Both operations are synchronous + // and don't require an async runtime context. + self.unsub_sender = None; + self.reading_task.abort(); } } diff --git a/src/subscriptions/iterator.rs b/src/subscriptions/iterator.rs index b38955f..101fc71 100644 --- a/src/subscriptions/iterator.rs +++ b/src/subscriptions/iterator.rs @@ -2,23 +2,76 @@ use std::sync::Arc; use futures_util::StreamExt; use pyo3::{Bound, PyAny, PyRef, Python}; -use tokio::sync::Mutex; use crate::exceptions::rust_err::{NatsrpyError, NatsrpyResult}; use crate::utils::futures::natsrpy_future_with_timeout; use crate::utils::natsrpy_future; use crate::utils::py_types::TimeValue; +enum UnsubscribeCommand { + Unsubscribe, + UnsubscribeAfter(u64), + Drain, +} + +/// Background task that owns the subscriber and forwards messages +/// to an mpsc channel. Listens for unsubscribe commands on a separate channel. +async fn sub_forwarder( + mut sub: async_nats::Subscriber, + msg_tx: tokio::sync::mpsc::Sender, + mut unsub_rx: tokio::sync::mpsc::Receiver, +) { + loop { + tokio::select! { + msg = sub.next() => { + match msg { + Some(message) => { + if msg_tx.send(message).await.is_err() { + // Receiver dropped, stop forwarding. + break; + } + } + None => break, + } + } + cmd = unsub_rx.recv() => { + match cmd { + Some(UnsubscribeCommand::Unsubscribe) => { + sub.unsubscribe().await.ok(); + break; + } + Some(UnsubscribeCommand::UnsubscribeAfter(limit)) => { + sub.unsubscribe_after(limit).await.ok(); + // Don't break — continue receiving up to `limit` messages. + } + Some(UnsubscribeCommand::Drain) => { + sub.drain().await.ok(); + break; + } + None => break, + } + } + } + } +} + #[pyo3::pyclass] pub struct IteratorSubscription { - inner: Option>>, + msg_rx: Arc>>, + unsub_tx: Option>, + task_handle: tokio::task::AbortHandle, } impl IteratorSubscription { #[must_use] pub fn new(sub: async_nats::Subscriber) -> Self { + let (msg_tx, msg_rx) = tokio::sync::mpsc::channel(128); + let (unsub_tx, unsub_rx) = tokio::sync::mpsc::channel(1); + let task_handle = tokio::task::spawn(sub_forwarder(sub, msg_tx, unsub_rx)).abort_handle(); Self { - inner: Some(Arc::new(Mutex::new(sub))), + msg_rx: Arc::new(tokio::sync::Mutex::new(msg_rx)), + unsub_tx: Some(unsub_tx), + task_handle, } } } @@ -36,14 +89,13 @@ impl IteratorSubscription { py: Python<'py>, timeout: Option, ) -> NatsrpyResult> { - let Some(inner) = self.inner.clone() else { - unreachable!("Subscription used after del") - }; + let msg_rx = self.msg_rx.clone(); natsrpy_future_with_timeout(py, timeout, async move { - let Some(message) = inner.lock().await.next().await else { - return Err(NatsrpyError::AsyncStopIteration); - }; - crate::message::Message::try_from(message) + let mut rx = msg_rx.lock().await; + rx.recv().await.map_or_else( + || Err(NatsrpyError::AsyncStopIteration), + crate::message::Message::try_from, + ) }) } @@ -57,45 +109,39 @@ impl IteratorSubscription { py: Python<'py>, limit: Option, ) -> NatsrpyResult> { - let Some(inner) = self.inner.clone() else { + let Some(sender) = self.unsub_tx.clone() else { unreachable!("Subscription used after del") }; natsrpy_future(py, async move { - if let Some(limit) = limit { - inner.lock().await.unsubscribe_after(limit).await?; - } else { - inner.lock().await.unsubscribe().await?; - } + let cmd = limit.map_or(UnsubscribeCommand::Unsubscribe, |n| { + UnsubscribeCommand::UnsubscribeAfter(n) + }); + sender.send(cmd).await.map_err(|_| { + NatsrpyError::SessionError("Subscription already closed".to_string()) + })?; Ok(()) }) } pub fn drain<'py>(&self, py: Python<'py>) -> NatsrpyResult> { - let Some(inner) = self.inner.clone() else { + let Some(sender) = self.unsub_tx.clone() else { unreachable!("Subscription used after del") }; natsrpy_future(py, async move { - inner.lock().await.drain().await?; + sender.send(UnsubscribeCommand::Drain).await.map_err(|_| { + NatsrpyError::SessionError("Subscription already closed".to_string()) + })?; Ok(()) }) } } -/// This is required only because -/// in nats library they run async operation on Drop. -/// -/// Because of that we need to execute drop in async -/// runtime's context. -/// -/// And because we want to perform a drop, -/// we need somehow drop the inner variable, -/// but leave self intouch. That is exactly why we have -/// Option>. So we can just assign it to None -/// and it will perform a drop. impl Drop for IteratorSubscription { fn drop(&mut self) { - pyo3_async_runtimes::tokio::get_runtime().block_on(async move { - self.inner = None; - }); + // Drop the sender to signal the forwarder task to stop, + // then abort the task. Both operations are synchronous + // and don't require an async runtime context. + self.unsub_tx = None; + self.task_handle.abort(); } } diff --git a/src/utils/streamer.rs b/src/utils/streamer.rs index 9bd7cd8..82ce3b7 100644 --- a/src/utils/streamer.rs +++ b/src/utils/streamer.rs @@ -24,7 +24,7 @@ impl Streamer { pub fn new( stream: impl futures_util::Stream + std::marker::Unpin + Send + 'static, ) -> Self { - let (tx, rx) = tokio::sync::mpsc::channel(1); + let (tx, rx) = tokio::sync::mpsc::channel(128); let task = tokio::task::spawn(task_pooler(stream, tx)); Self { messages: Arc::new(Mutex::new(rx)),