Skip to content

Commit 84924a5

Browse files
committed
Chore: Add type hints to manager.py
1 parent c8b9ece commit 84924a5

7 files changed

Lines changed: 96 additions & 72 deletions

File tree

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ clean-dist: clean
2121
rm -rf dist/
2222

2323
lint: venv
24-
$(VENV_ACTIVATE); python -m ruff check .
24+
$(VENV_ACTIVATE); python -m ruff check . && python -m mypy
2525

2626
format: venv
2727
$(VENV_ACTIVATE); python -m ruff format . && python -m ruff check . --fix

mypy.ini

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[mypy]
2+
explicit_package_bases = true
3+
files=plux/runtime/manager.py,tests/test_manager.py
4+
ignore_missing_imports = False
5+
follow_imports = silent
6+
ignore_errors = False
7+
disallow_untyped_defs = True
8+
disallow_untyped_calls = True
9+
disallow_any_generics = True
10+
disallow_subclassing_any = True
11+
warn_unused_ignores = True

plux/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,5 @@
3333
"PluginSpecResolver",
3434
"PluginType",
3535
"plugin",
36-
"__version__"
36+
"__version__",
3737
]

plux/core/plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class PluginDisabled(PluginException):
2727

2828
reason: str
2929

30-
def __init__(self, namespace: str, name: str, reason: str = None):
30+
def __init__(self, namespace: str, name: str, reason: str | None = None):
3131
message = f"plugin {namespace}:{name} is disabled"
3232
if reason:
3333
message = f"{message}, reason: {reason}"

plux/runtime/manager.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import logging
22
import threading
33
import typing as t
4+
from collections.abc import Iterable
5+
from importlib.metadata import EntryPoint
46

57
from plux.core.plugin import (
68
Plugin,
@@ -18,9 +20,10 @@
1820
LOG = logging.getLogger(__name__)
1921

2022
P = t.TypeVar("P", bound=Plugin)
23+
PS = t.ParamSpec("PS")
2124

2225

23-
def _call_safe(func: t.Callable, args: tuple, exception_message: str):
26+
def _call_safe(func: t.Callable[PS, None], args: t.Any, exception_message: str) -> None:
2427
"""
2528
Call the given function with the given arguments, and if it fails, log the given exception_message. If
2629
logging.DEBUG is set for the logger, then we also log the traceback. An exception is made for any
@@ -32,7 +35,7 @@ def _call_safe(func: t.Callable, args: tuple, exception_message: str):
3235
:return: whatever the func returns
3336
"""
3437
try:
35-
return func(*args)
38+
func(*args, **{})
3639
except PluginException:
3740
# re-raise PluginExceptions, since they should be handled by the caller
3841
raise
@@ -54,23 +57,25 @@ class PluginLifecycleNotifierMixin:
5457

5558
listeners: list[PluginLifecycleListener]
5659

57-
def _fire_on_resolve_after(self, plugin_spec):
60+
def _fire_on_resolve_after(self, plugin_spec: PluginSpec) -> None:
5861
for listener in self.listeners:
5962
_call_safe(
6063
listener.on_resolve_after,
6164
(plugin_spec,), #
6265
"error while calling on_resolve_after",
6366
)
6467

65-
def _fire_on_resolve_exception(self, namespace, entrypoint, exception):
68+
def _fire_on_resolve_exception(
69+
self, namespace: str, entrypoint: EntryPoint, exception: Exception
70+
) -> None:
6671
for listener in self.listeners:
6772
_call_safe(
6873
listener.on_resolve_exception,
6974
(namespace, entrypoint, exception),
7075
"error while calling on_resolve_exception",
7176
)
7277

73-
def _fire_on_init_after(self, plugin_spec, plugin):
78+
def _fire_on_init_after(self, plugin_spec: PluginSpec, plugin: P) -> None:
7479
for listener in self.listeners:
7580
_call_safe(
7681
listener.on_init_after,
@@ -81,31 +86,35 @@ def _fire_on_init_after(self, plugin_spec, plugin):
8186
"error while calling on_init_after",
8287
)
8388

84-
def _fire_on_init_exception(self, plugin_spec, exception):
89+
def _fire_on_init_exception(self, plugin_spec: PluginSpec, exception: Exception) -> None:
8590
for listener in self.listeners:
8691
_call_safe(
8792
listener.on_init_exception,
8893
(plugin_spec, exception),
8994
"error while calling on_init_exception",
9095
)
9196

92-
def _fire_on_load_before(self, plugin_spec, plugin, load_args, load_kwargs):
97+
def _fire_on_load_before(
98+
self, plugin_spec: PluginSpec, plugin: P, load_args: t.Any, load_kwargs: t.Any
99+
) -> None:
93100
for listener in self.listeners:
94101
_call_safe(
95102
listener.on_load_before,
96103
(plugin_spec, plugin, load_args, load_kwargs),
97104
"error while calling on_load_before",
98105
)
99106

100-
def _fire_on_load_after(self, plugin_spec, plugin, result):
107+
def _fire_on_load_after(self, plugin_spec: PluginSpec, plugin: P | None, result: t.Any) -> None:
101108
for listener in self.listeners:
102109
_call_safe(
103110
listener.on_load_after,
104111
(plugin_spec, plugin, result),
105112
"error while calling on_load_after",
106113
)
107114

108-
def _fire_on_load_exception(self, plugin_spec, plugin, exception):
115+
def _fire_on_load_exception(
116+
self, plugin_spec: PluginSpec, plugin: P | None, exception: Exception
117+
) -> None:
109118
for listener in self.listeners:
110119
_call_safe(
111120
listener.on_load_exception,
@@ -123,20 +132,20 @@ class PluginContainer(t.Generic[P]):
123132
lock: threading.RLock
124133

125134
plugin_spec: PluginSpec
126-
plugin: P = None
127-
load_value: t.Any | None = None
135+
plugin: P | None = None
136+
load_value: t.Any = None
128137

129138
is_init: bool = False
130139
is_loaded: bool = False
131140

132-
init_error: Exception = None
133-
load_error: Exception = None
141+
init_error: Exception | None = None
142+
load_error: Exception | None = None
134143

135144
is_disabled: bool = False
136-
disabled_reason = str = None
145+
disabled_reason: str | None = None
137146

138147
@property
139-
def distribution(self) -> Distribution:
148+
def distribution(self) -> Distribution | None:
140149
"""
141150
Uses metadata from importlib to resolve the distribution information for this plugin.
142151
@@ -160,19 +169,19 @@ class PluginManager(PluginLifecycleNotifierMixin, t.Generic[P]):
160169

161170
namespace: str
162171

163-
load_args: list | tuple
172+
load_args: list[t.Any] | tuple[t.Any, ...]
164173
load_kwargs: dict[str, t.Any]
165174
listeners: list[PluginLifecycleListener]
166175
filters: list[PluginFilter]
167176

168177
def __init__(
169178
self,
170179
namespace: str,
171-
load_args: list | tuple = None,
172-
load_kwargs: dict = None,
173-
listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] = None,
174-
finder: PluginFinder = None,
175-
filters: list[PluginFilter] = None,
180+
load_args: list[t.Any] | tuple[t.Any, ...] | None = None,
181+
load_kwargs: dict[str, t.Any] | None = None,
182+
listener: PluginLifecycleListener | t.Iterable[PluginLifecycleListener] | None = None,
183+
finder: PluginFinder | None = None,
184+
filters: list[PluginFilter] | None = None,
176185
):
177186
"""
178187
Create a new ``PluginManager`` that can be used to load plugins. The simplest ``PluginManager`` only needs
@@ -231,7 +240,7 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t
231240
self.load_kwargs = load_kwargs or dict()
232241

233242
if listener:
234-
if isinstance(listener, (list, set, tuple)):
243+
if isinstance(listener, Iterable):
235244
self.listeners = list(listener)
236245
else:
237246
self.listeners = [listener]
@@ -243,10 +252,10 @@ def on_load_before(self, plugin_spec: PluginSpec, plugin: Plugin, load_result: t
243252

244253
self.finder = finder or MetadataPluginFinder(self.namespace, self._fire_on_resolve_exception)
245254

246-
self._plugin_index = None
255+
self._plugin_index: dict[str, PluginContainer[P]] | None = None
247256
self._init_mutex = threading.RLock()
248257

249-
def add_listener(self, listener: PluginLifecycleListener):
258+
def add_listener(self, listener: PluginLifecycleListener) -> None:
250259
"""
251260
Adds a lifecycle listener to the plugin manager. The listener will be notified of plugin lifecycle events.
252261
@@ -326,12 +335,12 @@ def load(self, name: str) -> P:
326335
if container.load_error:
327336
raise container.load_error
328337

329-
if not container.is_loaded:
338+
if container.plugin is None or not container.is_loaded:
330339
raise PluginException("plugin did not load correctly", namespace=self.namespace, name=name)
331340

332341
return container.plugin
333342

334-
def load_all(self, propagate_exceptions=False) -> list[P]:
343+
def load_all(self, propagate_exceptions: bool = False) -> list[P]:
335344
"""
336345
Attempts to load all plugins found in the namespace and returns those that were loaded successfully.
337346
@@ -364,10 +373,10 @@ def load_all(self, propagate_exceptions=False) -> list[P]:
364373
:param propagate_exceptions: If True, re-raises any exceptions encountered during loading
365374
:return: A list of successfully loaded plugin instances
366375
"""
367-
plugins = list()
376+
plugins: list[P] = list()
368377

369378
for name, container in self._plugins.items():
370-
if container.is_loaded:
379+
if container.plugin is not None and container.is_loaded:
371380
plugins.append(container.plugin)
372381
continue
373382

@@ -552,7 +561,7 @@ def _require_plugin(self, name: str) -> PluginContainer[P]:
552561

553562
return self._plugins[name]
554563

555-
def _load_plugin(self, container: PluginContainer) -> None:
564+
def _load_plugin(self, container: PluginContainer[P]) -> None:
556565
"""
557566
Implements the core algorithm to load a plugin from a ``PluginSpec`` (contained in the ``PluginContainer``),
558567
and stores all relevant results, such as the Plugin instance, load result, or any errors into the passed
@@ -602,6 +611,7 @@ def _load_plugin(self, container: PluginContainer) -> None:
602611
return
603612

604613
plugin = container.plugin
614+
assert plugin # Make MyPy happy - plugin should exist at this point
605615

606616
if not plugin.should_load():
607617
raise PluginDisabled(
@@ -643,9 +653,9 @@ def _plugin_from_spec(self, plugin_spec: PluginSpec) -> P:
643653
if spec:
644654
factory = spec.factory
645655

646-
return factory()
656+
return factory() # type: ignore[return-value]
647657

648-
def _init_plugin_index(self) -> dict[str, PluginContainer]:
658+
def _init_plugin_index(self) -> dict[str, PluginContainer[P]]:
649659
"""
650660
Initializes the plugin index, which maps plugin names to plugin containers. This method will *resolve* plugins,
651661
meaning it loads the entry point object reference, thereby importing all its code.
@@ -654,7 +664,7 @@ def _init_plugin_index(self) -> dict[str, PluginContainer]:
654664
"""
655665
return {plugin.name: plugin for plugin in self._import_plugins() if plugin}
656666

657-
def _import_plugins(self) -> t.Iterable[PluginContainer]:
667+
def _import_plugins(self) -> t.Iterable[PluginContainer[P]]:
658668
"""
659669
Finds all ``PluginSpace`` instances in the namespace, creates a container for each spec, and yields them one
660670
by one. The plugin finder will typically load the entry point which involves importing the module it lives in.
@@ -671,14 +681,14 @@ def _import_plugins(self) -> t.Iterable[PluginContainer]:
671681

672682
yield self._create_container(spec)
673683

674-
def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer:
684+
def _create_container(self, plugin_spec: PluginSpec) -> PluginContainer[P]:
675685
"""
676686
Factory method to create a ``PluginContainer`` for the given ``PluginSpec``.
677687
678688
:param plugin_spec: The ``PluginSpec`` to create a container for.
679689
:return: A new ``PluginContainer`` with the basic information of the plugin spec.
680690
"""
681-
container = PluginContainer()
691+
container = PluginContainer[P]()
682692
container.lock = threading.RLock()
683693
container.name = plugin_spec.name
684694
container.plugin_spec = plugin_spec

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ dev = [
3333
"setuptools",
3434
"pytest==8.4.1",
3535
"ruff==0.9.1",
36+
"mypy",
3637
]
3738

3839
[tool.hatch.build.hooks.vcs]

0 commit comments

Comments
 (0)