fix(networks): replace Tensor | None with Optional[Tensor] for TorchScript compatibility #8879
…r] for TorchScript compatibility

The `|` union type syntax (e.g. `torch.Tensor | None`) was introduced in Python 3.10. While `from __future__ import annotations` defers evaluation at runtime, TorchScript's annotation parser does not support this syntax and fails when scripting models that contain these forward method signatures.

Replace `torch.Tensor | None` with `Optional[torch.Tensor]` in the `forward` methods of:
- `monai/networks/blocks/crossattention.py` (CrossAttentionBlock)
- `monai/networks/blocks/selfattention.py` (SABlock)
- `monai/networks/blocks/transformerblock.py` (TransformerBlock)

These three blocks are used in the ViT/UNETR scripting path, causing `RuntimeError: Can't redefine method: forward on class` when `torch.jit.script()` is called on a UNETR model.

Closes Project-MONAI#7939

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
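The effect of `from __future__ import annotations` described above can be seen without TorchScript at all. The following is a minimal sketch using only the standard library; `Tensor` is a hypothetical stand-in class (not `torch.Tensor`), and the two function names are made up for illustration. It only shows that deferred annotations are kept as source text, so any tool that reads them has to parse that text itself; it makes no claim about how TorchScript's parser is implemented.

```python
from __future__ import annotations  # must be the first statement in the file

import typing
from typing import Optional


class Tensor:
    """Hypothetical stand-in for torch.Tensor so the sketch needs no third-party imports."""


def forward_pep604(x: Tensor, context: Tensor | None = None) -> Tensor:
    return x


def forward_optional(x: Tensor, context: Optional[Tensor] = None) -> Tensor:
    return x


# With deferred evaluation, each annotation is stored as its literal source text.
pep604_text = forward_pep604.__annotations__["context"]
optional_text = forward_optional.__annotations__["context"]
print(pep604_text)    # Tensor | None
print(optional_text)  # Optional[Tensor]

# A consumer has to turn the text back into a type; typing.get_type_hints does
# this by evaluating the string in the function's module namespace.
resolved = typing.get_type_hints(forward_optional)["context"]
print(resolved == Optional[Tensor])  # True
```

Because the `|` expression is never evaluated by Python itself here, the module imports cleanly even though a stricter annotation parser may still reject the text.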
Hey @holgerroth @wyli @ericspod, could you please have a look at this?
for more information, see https://pre-commit.ci
📝 Walkthrough
This PR replaces Python 3.10 union type annotations (
Changes
Union to Optional type annotation updates
Hi @AlexanderSanin, I don't see files changed anymore since pre-commit (using ruff I think) autofixed your changes back to what they were. You may have to create a variable for your type with However, we don't support Python 3.9 anymore so we expect the
… compat

TorchScript's annotation parser does not support the PEP 604 `X | None` union syntax. Replace `torch.Tensor | None` with `Optional[torch.Tensor]` in the `forward` methods of CrossAttentionBlock, SABlock, and TransformerBlock.

Add `# noqa: UP045` on each affected line so ruff (pyupgrade) does not auto-revert the annotations back to the `X | None` form.

Closes Project-MONAI#7939

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
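As a rough illustration of the annotation style this commit describes, the sketch below puts `Optional[torch.Tensor]` together with `# noqa: UP045` on the parameter line and scripts the result. `ToyAttention` is a hypothetical module written for this example, not one of the MONAI blocks, and its arithmetic is a placeholder rather than real attention.

```python
from typing import Optional

import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for an attention block (not MONAI code)."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,  # noqa: UP045  keep Optional[...] for TorchScript
    ) -> torch.Tensor:
        # An explicit `is None` branch lets TorchScript refine Optional[Tensor] to Tensor.
        if context is None:
            source = x
        else:
            source = context
        return self.proj(x) + source.mean(dim=1, keepdim=True)


scripted = torch.jit.script(ToyAttention(8))
x = torch.randn(2, 4, 8)
out_default = scripted(x)                        # context omitted, defaults to None
out_context = scripted(x, torch.randn(2, 3, 8))  # explicit context tensor
```

The `noqa` comment sits on the same physical line as the annotation it protects, which is why a multi-line signature is used here.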
…ations' into fix/torchscript-union-type-annotations
Thanks for the feedback @ericspod! I updated the PR to use I confirmed locally that Fully understand if the preference is to phase out TorchScript altogether; happy to close this in favour of that effort if that is the direction.
…chScript

Two follow-up fixes so that scripting UNETR (test_unetr::test_script) fully passes and the static-checks (codeformat) job is green:

1. Move ``from typing import Optional`` into the standard-library import group (before the third-party ``torch`` imports) in crossattention.py, selfattention.py and transformerblock.py. isort (profile=black) requires this ordering; without it the codeformat check failed.
2. Add ``__constants__ = ["with_cross_attention"]`` to TransformerBlock. PR Project-MONAI#8848 made ``cross_attn`` an ``nn.Identity`` when ``with_cross_attention`` is False. TorchScript statically compiles every branch, so the ``self.cross_attn(..., context=context)`` call was checked against ``nn.Identity.forward`` (which has no ``context`` argument) and scripting failed. Marking the flag as a TorchScript constant lets the compiler prune the dead cross-attention branch when it is False, while still keeping ``cross_attn`` as ``nn.Identity`` (no registered params), preserving the behaviour and tests added in Project-MONAI#8848. The typing.Final annotation form cannot be used here because ``from __future__ import annotations`` stringizes it and TorchScript cannot resolve ``'Final[bool]'``; the ``__constants__`` list avoids that.

Closes Project-MONAI#7939

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
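The `__constants__` mechanism from the second fix can be sketched on a small example. `ToyCrossAttention` and `ToyBlock` below are hypothetical modules that only mimic the shape of the situation described in the commit message (an `nn.Identity` placeholder behind a boolean flag); they are not the MONAI implementation.

```python
from typing import Optional

import torch
from torch import nn


class ToyCrossAttention(nn.Module):
    """Hypothetical cross-attention stand-in; it only mixes in the context mean."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        if context is None:
            source = x
        else:
            source = context
        return self.proj(x) + source.mean(dim=1, keepdim=True)


class ToyBlock(nn.Module):
    # Listing the flag here makes TorchScript treat it as a compile-time constant,
    # so the `if` below is resolved at scripting time and the dead branch is skipped.
    __constants__ = ["with_cross_attention"]

    def __init__(self, dim: int, with_cross_attention: bool = False) -> None:
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.norm = nn.LayerNorm(dim)
        # nn.Identity.forward takes no `context` argument, so the call below
        # would not type-check against it if the branch were compiled.
        self.cross_attn = ToyCrossAttention(dim) if with_cross_attention else nn.Identity()

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.norm(x)
        if self.with_cross_attention:
            x = x + self.cross_attn(x, context=context)
        return x


x = torch.randn(2, 4, 8)
plain = torch.jit.script(ToyBlock(8, with_cross_attention=False))
cross = torch.jit.script(ToyBlock(8, with_cross_attention=True))
out_plain = plain(x)
out_cross = cross(x, torch.randn(2, 3, 8))
```

Using the class-level list rather than a `Final[bool]` annotation follows the reasoning in the commit message about stringized annotations.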
Pushed two follow-up fixes after the latest `dev` merge:
Verified locally:
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/blocks/transformerblock.py (1)
111-116: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Add or point to regression tests for the updated `forward` signature.
Please ensure coverage explicitly exercises scripting and both default-`None` optional params on this modified definition.
As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/blocks/transformerblock.py` around lines 111 - 116, The updated forward signature in transformerblock.py (method forward) added optional params context and attn_mask; add unit tests that (1) instantiate the TransformerBlock and call forward with both optional params omitted (defaults None) and with explicit tensors for context and attn_mask to exercise each branch, and (2) ensure the module is TorchScript-able by scripting/tracing the instance and running the scripted model with the same two call patterns; name tests to reflect default-None coverage and scripting so CI will catch regressions.
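A regression test along the lines requested here might look like the sketch below. It is written against a hypothetical `ToyBlock` stand-in rather than MONAI's `TransformerBlock`, so the constructor arguments, tensor shapes, and test names are all assumptions; the structure (script once, then compare eager and scripted outputs for both call patterns) is the part meant to carry over.

```python
import unittest
from typing import Optional

import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in with the same optional parameters as the reviewed forward()."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        out = self.proj(x)
        if context is not None:
            out = out + context.mean(dim=1, keepdim=True)
        if attn_mask is not None:
            out = out * attn_mask.unsqueeze(-1)
        return out


class TestToyBlockScript(unittest.TestCase):
    def setUp(self) -> None:
        torch.manual_seed(0)
        self.block = ToyBlock(8).eval()
        self.scripted = torch.jit.script(self.block)
        self.x = torch.randn(2, 4, 8)

    def test_script_default_none(self) -> None:
        # Both optional parameters omitted, so they take their None defaults.
        torch.testing.assert_close(self.scripted(self.x), self.block(self.x))

    def test_script_explicit_tensors(self) -> None:
        context = torch.randn(2, 3, 8)
        attn_mask = torch.ones(2, 4)
        expected = self.block(self.x, context=context, attn_mask=attn_mask)
        actual = self.scripted(self.x, context=context, attn_mask=attn_mask)
        torch.testing.assert_close(actual, expected)


# Run the two tests programmatically so the sketch works outside a test runner.
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestToyBlockScript)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```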
🧹 Nitpick comments (1)
monai/networks/blocks/transformerblock.py (1)
111-116: ⚡ Quick win
Add a Google-style docstring to `forward`.
This modified definition still lacks a method docstring documenting args/returns/raises.
As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@monai/networks/blocks/transformerblock.py` around lines 111 - 116, Add a Google-style docstring to the forward method for transformerblock.forward that documents the arguments (x: torch.Tensor - input tensor and expected shape/ dtype/ device; context: Optional[torch.Tensor] - optional conditioning tensor and its expected shape/meaning; attn_mask: Optional[torch.Tensor] - optional attention mask and its shape/semantics), the return value (torch.Tensor - shape and dtype of the output), and any exceptions the method may raise (e.g., ValueError for shape mismatches, RuntimeError for device/dtype incompatibility). Keep the docstring concise, include type hints already present, and place it immediately below the forward signature.
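For reference, a Google-style docstring on a `forward` with this signature could be laid out as in the sketch below. `DocumentedBlock` is a hypothetical module, and the shapes, the mask semantics, and the `ValueError` check are assumptions chosen for illustration; the real block's shapes and exceptions would need to be taken from its code.

```python
from typing import Optional

import torch
from torch import nn


class DocumentedBlock(nn.Module):
    """Hypothetical stand-in block used only to show the docstring layout."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(
        self,
        x: torch.Tensor,
        context: Optional[torch.Tensor] = None,
        attn_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block to a sequence of tokens.

        Args:
            x: input tensor, assumed here to have shape (batch, tokens, dim).
            context: optional conditioning tensor, assumed shape
                (batch, context_tokens, dim). Ignored when ``None``.
            attn_mask: optional mask, assumed shape (batch, tokens); zeros
                suppress a token. Ignored when ``None``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(f"expected a 3D input, got {x.dim()}D")
        out = self.proj(x)
        if context is not None:
            out = out + context.mean(dim=1, keepdim=True)
        if attn_mask is not None:
            out = out * attn_mask.unsqueeze(-1)
        return out


doc = DocumentedBlock.forward.__doc__
out = DocumentedBlock(8)(torch.randn(2, 4, 8))
```

Only exceptions the method actually raises belong in the `Raises:` section, which is why the sketch includes an explicit shape check.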
📒 Files selected for processing (3)
- monai/networks/blocks/crossattention.py
- monai/networks/blocks/selfattention.py
- monai/networks/blocks/transformerblock.py
🚧 Files skipped from review as they are similar to previous changes (2)
- monai/networks/blocks/selfattention.py
- monai/networks/blocks/crossattention.py
Summary
Fixes #7939

The `|` union type syntax (e.g. `torch.Tensor | None`) was introduced in Python 3.10. While `from __future__ import annotations` defers annotation evaluation at runtime, TorchScript's annotation parser does not support this syntax and fails when scripting models that include these forward method signatures, producing `RuntimeError: Can't redefine method: forward on class`.

This PR replaces `torch.Tensor | None` with `Optional[torch.Tensor]` (from `typing`) in the `forward` methods of the three blocks that form the ViT/UNETR scripting path:
- `monai/networks/blocks/crossattention.py`: `CrossAttentionBlock.forward`
- `monai/networks/blocks/selfattention.py`: `SABlock.forward`
- `monai/networks/blocks/transformerblock.py`: `TransformerBlock.forward`

Test plan
- `torch.jit.script(CrossAttentionBlock(...))` succeeds after fix
- `torch.jit.script(SABlock(...))` succeeds after fix
- `torch.jit.script(TransformerBlock(...))` succeeds after fix
- `python -m pytest tests/networks/nets/test_unetr.py -k test_script`
- `python -m pytest tests/networks/blocks/test_crossattention.py`
- `python -m pytest tests/networks/blocks/test_selfattention.py`