Skip to content

Commit 0513e89

Browse files
ChrisRackauckas-ClaudeChrisRackauckasclaudedevmotion
authored
Fix rrule for broadcasted over empty Tuple{} (#834)
* Fix rrule for broadcasted over empty Tuple{} When broadcasting over an empty tuple, `Broadcast.combine_eltypes` returns `Union{}`. Since `Union{} <: Number` is true, the code entered `may_bc_derivatives` which tried to construct `Tuple{Union{}, ...}` and errored. Fix by treating `Union{}` as a trivial non-differentiable case alongside `Bool`. Fixes #830 Co-Authored-By: Chris Rackauckas <[email protected]> Co-Authored-By: Claude Opus 4.6 <[email protected]> * Update test/rulesets/Base/broadcast.jl Co-authored-by: David Müller-Widmann <[email protected]> * Update test/rulesets/Base/broadcast.jl Co-authored-by: David Müller-Widmann <[email protected]> * Add multi-argument test for broadcast over empty tuples Test atan.((), ()) to cover the multi-arg path through the trivial broadcast rrule when eltype is Union{}. Co-Authored-By: Chris Rackauckas <[email protected]> Co-Authored-By: Claude Opus 4.6 <[email protected]> --------- Co-authored-by: ChrisRackauckas-Claude <[email protected]> Co-authored-by: Claude Opus 4.6 <[email protected]> Co-authored-by: David Müller-Widmann <[email protected]>
1 parent 9dbb830 commit 0513e89

2 files changed

Lines changed: 15 additions & 1 deletion

File tree

src/rulesets/Base/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ end
2828

2929
function rrule(cfg::RCR, ::typeof(broadcasted), ::BroadcastStyle, f::F, args::Vararg{Any,N}) where {F,N}
3030
T = Broadcast.combine_eltypes(f, args)
31-
if T === Bool # TODO use nondifftype here
31+
if T === Bool || T === Union{} # TODO use nondifftype here
3232
# 1: Trivial case: non-differentiable output, e.g. `x .> 0`
3333
@debug("split broadcasting trivial", f, T)
3434
bc_trivial_back(_) = (TRI_NO..., ntuple(Returns(ZeroTangent()), length(args))...)

test/rulesets/Base/broadcast.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,20 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
177177
end
178178

179179
@testset "bugs" begin
180+
@testset "broadcast over empty tuple" begin # https://github.com/JuliaDiff/ChainRules.jl/issues/830
181+
y, bk = rrule(CFG, copybroadcasted, BT1, isone, ())
182+
@test y == ()
183+
@test bk(Tangent{Tuple{}}()) == (NoTangent(), NoTangent(), NoTangent(), ZeroTangent())
184+
185+
y2, bk2 = rrule(CFG, copybroadcasted, BT1, sin, ())
186+
@test y2 == ()
187+
@test bk2(Tangent{Tuple{}}()) == (NoTangent(), NoTangent(), NoTangent(), ZeroTangent())
188+
189+
# Multi-argument case
190+
y3, bk3 = rrule(CFG, copybroadcasted, BT1, atan, (), ())
191+
@test y3 == ()
192+
@test bk3(Tangent{Tuple{}}()) == (NoTangent(), NoTangent(), NoTangent(), ZeroTangent(), ZeroTangent())
193+
end
180194
@testset "unbroadcast with NTuple" begin # https://github.com/JuliaDiff/ChainRules.jl/pull/661
181195
@test ChainRules.unbroadcast((1, 2, [3]), [4, 5, [6]]) isa Tangent # earlier, NTuple demanded same type
182196
@test ChainRules.unbroadcast(broadcasted(-, (1, 2), 3), (4, 5)) == (4, 5) # earlier, called ndims(::Tuple)

0 commit comments

Comments
 (0)