5151# #### `sum(f, x)`
5252# ####
5353
54- # Can't map over Adjoint/Transpose Vector
55- function rrule (
56- config:: RuleConfig{>:HasReverseMode} ,
57- :: typeof (sum),
58- f,
59- xs:: Union{Adjoint{<:Number,<:AbstractVector},Transpose{<:Number,<:AbstractVector}} ;
60- kwargs...
61- )
62- op = xs isa Adjoint ? adjoint : transpose
63- # since summing a vector we don't need to worry about dims which simplifies adjointing
64- vector = parent (xs)
65- y, vector_sum_pb = rrule (config, sum, f, vector; kwargs... )
66- function covector_sum_pb (ȳ)
67- s̄um, f̄, v̄ = vector_sum_pb (ȳ)
68- return s̄um, f̄, op (v̄)
69- end
70-
71- return y, covector_sum_pb
72- end
73-
7454function rrule (
7555 config:: RuleConfig{>:HasReverseMode} ,
7656 :: typeof (sum),
@@ -96,7 +76,8 @@ function rrule(
9676 # see `f.(xs)` but we don't need the pullbacks. Not implemented at present.)
9777
9878 # In the general case, we need to save all the pullbacks:
99- fx_and_pullbacks = map (xᵢ -> rrule_via_ad (config, f, xᵢ), xs)
79+ # (Here `map` or `broadcast` would fail for adjoint vectors.)
80+ fx_and_pullbacks = [rrule_via_ad (config, f, xᵢ) for xᵢ in xs]
10081 y = sum (first, fx_and_pullbacks; dims)
10182
10283 function sum_pullback_f2 (dy)
0 commit comments