fix(networks): replace Tensor | None with Optional[Tensor] for TorchScript compatibility #8879

Open
AlexanderSanin wants to merge 6 commits into Project-MONAI:dev from AlexanderSanin:fix/torchscript-union-type-annotations


AlexanderSanin (Contributor) commented May 27, 2026

Summary

Fixes #7939

The `|` union type syntax (e.g. `torch.Tensor | None`, PEP 604) was introduced in Python 3.10. While `from __future__ import annotations` defers annotation evaluation at runtime, TorchScript's annotation parser does not support this syntax and fails when scripting models whose `forward` signatures use it, producing:

RuntimeError: Can't redefine method: forward on class
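
To illustrate the deferred-evaluation point with the standard library only (a minimal sketch, not part of this PR's diff; `scale` is a hypothetical function): under `from __future__ import annotations` each annotation is stored as its source text instead of being evaluated, so any tool that resolves annotations on its own still has to understand the `|` form.

```python
from __future__ import annotations


def scale(x: int | None = None) -> int:
    """Hypothetical helper; only its annotations matter here."""
    return 0 if x is None else 2 * x


# With the __future__ import, annotations are kept as plain strings
# rather than evaluated type objects.
stored = scale.__annotations__
print(stored)
```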

This PR replaces `torch.Tensor | None` with `Optional[torch.Tensor]` (from `typing`) in the `forward` methods of the three blocks that form the ViT/UNETR scripting path:

  • `monai/networks/blocks/crossattention.py`: `CrossAttentionBlock.forward`
  • `monai/networks/blocks/selfattention.py`: `SABlock.forward`
  • `monai/networks/blocks/transformerblock.py`: `TransformerBlock.forward`
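
The shape of the change can be sketched on a toy module. This assumes PyTorch is installed; `TinyBlock` is a hypothetical stand-in, not one of the MONAI blocks, and the sketch shows only the `Optional[...]` form being scripted, not the failure itself.

```python
from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a MONAI attention block."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    # Optional[...] is the annotation form this PR switches to.
    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = self.proj(x)
        if attn_mask is not None:
            # Inside this branch TorchScript treats attn_mask as a Tensor.
            out = out * attn_mask
        return out


scripted = torch.jit.script(TinyBlock())
x = torch.randn(2, 4, 8)
y_default = scripted(x)                      # attn_mask defaults to None
y_masked = scripted(x, torch.ones(2, 4, 1))  # explicit all-ones mask
```

Both call patterns go through the same scripted module, which is what the `test_script` cases rely on.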

Test plan

  • `torch.jit.script(CrossAttentionBlock(...))` succeeds after fix
  • `torch.jit.script(SABlock(...))` succeeds after fix
  • `torch.jit.script(TransformerBlock(...))` succeeds after fix
  • `python -m pytest tests/networks/nets/test_unetr.py -k test_script`
  • `python -m pytest tests/networks/blocks/test_crossattention.py`
  • `python -m pytest tests/networks/blocks/test_selfattention.py`

…r] for TorchScript compatibility

The `|` union type syntax (e.g. `torch.Tensor | None`) was introduced in
Python 3.10. While `from __future__ import annotations` defers evaluation
at runtime, TorchScript's annotation parser does not support this syntax
and fails when scripting models that contain these forward method signatures.

Replace `torch.Tensor | None` with `Optional[torch.Tensor]` in the
`forward` methods of:
- `monai/networks/blocks/crossattention.py` (CrossAttentionBlock)
- `monai/networks/blocks/selfattention.py` (SABlock)
- `monai/networks/blocks/transformerblock.py` (TransformerBlock)

These three blocks are used in the ViT/UNETR scripting path, causing
`RuntimeError: Can't redefine method: forward on class` when
`torch.jit.script()` is called on a UNETR model.

Closes Project-MONAI#7939

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>

AlexanderSanin (Contributor, Author) commented

Hey @holgerroth @wyli @ericspod, could you please have a look at this?

coderabbitai (Bot, Contributor) commented May 27, 2026

Walkthrough

This PR replaces Python 3.10 union type annotations (torch.Tensor | None) with Optional[torch.Tensor] in three attention block modules, adds the corresponding Optional imports, and declares TransformerBlock.__constants__ = ["with_cross_attention"]. No runtime logic was changed.

Changes

Union to Optional type annotation updates

| Layer / File(s) | Summary |
| --- | --- |
| CrossAttention forward signature and import (`monai/networks/blocks/crossattention.py`) | Adds `Optional` import and changes `CrossAttentionBlock.forward` `context` parameter annotation from `torch.Tensor \| None` to `Optional[torch.Tensor]`. |
| SelfAttention forward signature and import (`monai/networks/blocks/selfattention.py`) | Adds `Optional` import and changes `SABlock.forward` `attn_mask` parameter annotation from `torch.Tensor \| None` to `Optional[torch.Tensor]`. |
| Transformer forward signature, import, and constant (`monai/networks/blocks/transformerblock.py`) | Adds `Optional` import, sets `TransformerBlock.__constants__ = ["with_cross_attention"]`, and changes `TransformerBlock.forward` `context` and `attn_mask` parameter annotations from `torch.Tensor \| None` to `Optional[torch.Tensor]`. |

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~8 minutes

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | Title clearly and concisely describes the main change: replacing PEP 604 union syntax with Optional for TorchScript compatibility. |
| Description check | ✅ Passed | Description covers issue reference, problem context, solution details, and test plans. Sections match template structure despite incomplete checkbox items. |
| Linked Issues check | ✅ Passed | PR directly addresses #7939 by replacing `torch.Tensor \| None` with `Optional[torch.Tensor]` in forward signatures of the three blocks causing TorchScript failures. |
| Out of Scope Changes check | ✅ Passed | All changes are confined to the three target files and directly support the TorchScript compatibility goal. No unrelated modifications introduced. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 83.33%, which is sufficient. The required threshold is 80.00%. |


ericspod (Member) commented May 28, 2026

Hi @AlexanderSanin, I don't see any files changed anymore, since pre-commit (using ruff, I think) autofixed your changes back to what they were. You may have to create a variable for your type with `OptTensor: TypeAlias = Optional[torch.Tensor]` at the top of the source files, with an annotation to make ruff ignore it. This can then be used where you want.

However, we don't support Python 3.9 anymore, so we expect the `X | None` syntax to always work, and TorchScript itself is deprecated and will be phased out of MONAI soon. If you desperately need TorchScript we can look at this fix, but I'd suggest migrating off of it instead.

… compat

TorchScript's annotation parser does not support the PEP 604 `X | None`
union syntax. Replace `torch.Tensor | None` with `Optional[torch.Tensor]`
in the `forward` methods of CrossAttentionBlock, SABlock, and
TransformerBlock.

Add `# noqa: UP045` on each affected line so ruff (pyupgrade) does not
auto-revert the annotations back to the `X | None` form.

Closes Project-MONAI#7939

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
…ations' into fix/torchscript-union-type-annotations

AlexanderSanin (Contributor, Author) commented

Thanks for the feedback @ericspod!

I updated the PR to use `Optional[torch.Tensor]` with a `# noqa: UP045` comment directly on the affected `forward` signatures (rather than a `TypeAlias`). This keeps ruff from rewriting the annotations while preserving the form that TorchScript can parse.

I confirmed locally that `torch.jit.script()` succeeds on all three blocks after the change.

Fully understand if the preference is to phase out TorchScript altogether — happy to close this in favour of that effort if that is the direction.

…chScript

Two follow-up fixes so that scripting UNETR (test_unetr::test_script) fully
passes and the static-checks (codeformat) job is green:

1. Move ``from typing import Optional`` into the standard-library import
   group (before the third-party ``torch`` imports) in crossattention.py,
   selfattention.py and transformerblock.py. isort (profile=black) requires
   this ordering; without it the codeformat check failed.

2. Add ``__constants__ = ["with_cross_attention"]`` to TransformerBlock.
   PR Project-MONAI#8848 made ``cross_attn`` an ``nn.Identity`` when
   ``with_cross_attention`` is False. TorchScript statically compiles every
   branch, so the ``self.cross_attn(..., context=context)`` call was checked
   against ``nn.Identity.forward`` (which has no ``context`` argument) and
   scripting failed. Marking the flag as a TorchScript constant lets the
   compiler prune the dead cross-attention branch when it is False, while
   still keeping ``cross_attn`` as ``nn.Identity`` (no registered params),
   preserving the behaviour and tests added in Project-MONAI#8848.
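
The pattern in item 2 can be sketched on toy modules. This assumes PyTorch is installed; `TinyCrossAttention` and `TinyTransformerBlock` are hypothetical stand-ins and do not reproduce the real MONAI classes.

```python
from __future__ import annotations

from typing import Optional

import torch
import torch.nn as nn


class TinyCrossAttention(nn.Module):
    """Hypothetical stand-in for CrossAttentionBlock."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        kv = context if context is not None else x
        # Pool over tokens so the result broadcasts against x.
        return self.proj(kv).mean(dim=1, keepdim=True)


class TinyTransformerBlock(nn.Module):
    """Hypothetical stand-in for TransformerBlock."""

    # Ask TorchScript to treat the flag as a compile-time constant.
    __constants__ = ["with_cross_attention"]

    def __init__(self, dim: int = 8, with_cross_attention: bool = False) -> None:
        super().__init__()
        self.with_cross_attention = with_cross_attention
        self.mlp = nn.Linear(dim, dim)
        # Parameter-free placeholder when cross-attention is disabled.
        self.cross_attn = TinyCrossAttention(dim) if with_cross_attention else nn.Identity()

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = x + self.mlp(x)
        if self.with_cross_attention:
            # With the flag constant and False, this call is not compiled
            # against nn.Identity.forward, which has no `context` argument.
            x = x + self.cross_attn(x, context=context)
        return x


x = torch.randn(2, 4, 8)
context = torch.randn(2, 3, 8)
scripted_plain = torch.jit.script(TinyTransformerBlock(with_cross_attention=False))
scripted_cross = torch.jit.script(TinyTransformerBlock(with_cross_attention=True))
plain_out = scripted_plain(x)
cross_out = scripted_cross(x, context)
```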

The typing.Final annotation form cannot be used here because
``from __future__ import annotations`` stringizes it and TorchScript cannot
resolve ``'Final[bool]'``; the ``__constants__`` list avoids that.
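
The stringizing mentioned above can be seen with the standard library alone (a minimal sketch; `Block` is a hypothetical class, and the TorchScript failure itself is not reproduced here):

```python
from __future__ import annotations

from typing import Final


class Block:
    """Hypothetical class carrying a Final-annotated attribute."""

    with_cross_attention: Final[bool]

    def __init__(self, flag: bool) -> None:
        self.with_cross_attention = flag


# The class-level annotation is stored as text, not as a typing object.
stored_annotation = Block.__annotations__["with_cross_attention"]
print(repr(stored_annotation))
```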

Closes Project-MONAI#7939

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>

AlexanderSanin (Contributor, Author) commented

Pushed two follow-up fixes after the latest `dev` merge:

  1. isort: moved `from typing import Optional` into the stdlib import group (before `torch`). This was the cause of the failing `static-checks (codeformat)` job.

  2. TorchScript regression from #8848 (fix: only instantiate CrossAttentionBlock when with_cross_attention=True): that PR made `cross_attn` an `nn.Identity` when `with_cross_attention=False`. Since TorchScript statically compiles every branch, the `self.cross_attn(..., context=context)` call was being checked against `nn.Identity.forward` (no `context` arg) and scripting failed. I added `__constants__ = ["with_cross_attention"]` to `TransformerBlock` so the compiler prunes the dead cross-attention branch when the flag is False, while keeping `cross_attn` as a param-less `nn.Identity` (so the tests added in #8848 still hold). I could not use `typing.Final` here because `from __future__ import annotations` stringizes it and TorchScript cannot resolve `'Final[bool]'`.

Verified locally:

  • `python -m unittest tests.networks.nets.test_unetr` → OK (all 9 tests, including the 4 `test_script` cases from this issue)
  • `tests/networks/blocks/test_transformerblock.py` → 197 passed (#8848 invariants intact)
  • `test_crossattention` + `test_selfattention` + `test_vit` → 1592 passed
  • isort / ruff / ruff-format / black all clean on the three files

coderabbitai (Bot, Contributor) left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
monai/networks/blocks/transformerblock.py (1)

111-116: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add or point to regression tests for the updated forward signature.

Please ensure coverage explicitly exercises scripting and both default-None optional params on this modified definition.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/blocks/transformerblock.py` around lines 111 - 116, The
updated forward signature in transformerblock.py (method forward) added optional
params context and attn_mask; add unit tests that (1) instantiate the
TransformerBlock and call forward with both optional params omitted (defaults
None) and with explicit tensors for context and attn_mask to exercise each
branch, and (2) ensure the module is TorchScript-able by scripting/tracing the
instance and running the scripted model with the same two call patterns; name
tests to reflect default-None coverage and scripting so CI will catch
regressions.
🧹 Nitpick comments (1)
monai/networks/blocks/transformerblock.py (1)

111-116: ⚡ Quick win

Add a Google-style docstring to forward.

This modified definition still lacks a method docstring documenting args/returns/raises.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/networks/blocks/transformerblock.py` around lines 111 - 116, Add a
Google-style docstring to the forward method for transformerblock.forward that
documents the arguments (x: torch.Tensor - input tensor and expected shape/
dtype/ device; context: Optional[torch.Tensor] - optional conditioning tensor
and its expected shape/meaning; attn_mask: Optional[torch.Tensor] - optional
attention mask and its shape/semantics), the return value (torch.Tensor - shape
and dtype of the output), and any exceptions the method may raise (e.g.,
ValueError for shape mismatches, RuntimeError for device/dtype incompatibility).
Keep the docstring concise, include type hints already present, and place it
immediately below the forward signature.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: e0d04d4e-816d-48f2-a02a-42b967c41d2e

📥 Commits

Reviewing files that changed from the base of the PR and between ad84460 and 176357d.

📒 Files selected for processing (3)
  • monai/networks/blocks/crossattention.py
  • monai/networks/blocks/selfattention.py
  • monai/networks/blocks/transformerblock.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • monai/networks/blocks/selfattention.py
  • monai/networks/blocks/crossattention.py

Development

Successfully merging this pull request may close these issues.

RuntimeError: Can't redefine method: forward on class in test_unetr
