11 changes: 8 additions & 3 deletions diffrax/misc/sde_kl_divergence.py
@@ -9,8 +9,11 @@


 def _kl(drift1, drift2, diffusion):
-    inv_diffusion = jnp.linalg.pinv(diffusion)
-    scale = inv_diffusion @ (drift1 - drift2)
+    if diffusion.ndim == 1:
+        scale = (drift1 - drift2) / diffusion
Owner
So my original code here in sde_kl_divergence was pretty hacky and not library-ready, and I think it'll still need some more work to get it ready.

In particular I think it would make the most sense to operate at the level of terms. This would allow for abstracting over the kind of diffusion used -- e.g. ControlTerm versus WeaklyDiagonalControlTerm etc. -- rather than the current vector-field-based approach.
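A rough sketch of how a term-level dispatch could look (the helper name _kl_from_terms and its exact signature are hypothetical, not part of this PR):

```python
import jax.numpy as jnp
from diffrax import ControlTerm, WeaklyDiagonalControlTerm


def _kl_from_terms(drift1, drift2, control_term, t, y, args):
    # Dispatch on the type of the control term rather than on the shape of
    # the diffusion array returned by a bare vector field.
    delta = drift1 - drift2
    diffusion = control_term.vf(t, y, args)
    if isinstance(control_term, WeaklyDiagonalControlTerm):
        # Diagonal diffusion: elementwise division replaces the pseudo-inverse.
        scale = delta / diffusion
    elif isinstance(control_term, ControlTerm):
        scale = jnp.linalg.pinv(diffusion) @ delta
    else:
        raise NotImplementedError(f"Unsupported control term: {type(control_term)}")
    return 0.5 * jnp.sum(scale**2)
```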

Author
I think we may need to bump the version number. Here, I have changed the sde_kl_divergence API from taking drift functions, a diffusion function ... to taking two MultiTerms. Although there is some duplication, since the two SDEs share the same control term, this seems more natural as we are comparing two SDEs.
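For concreteness, a hypothetical usage of the reworked API (the exact call signature of sde_kl_divergence is an assumption, not taken from this PR) -- each SDE is passed as a MultiTerm of its drift and the shared diffusion/control term:

```python
import jax
import jax.numpy as jnp
from diffrax import ControlTerm, MultiTerm, ODETerm, VirtualBrownianTree

bm = VirtualBrownianTree(t0=0.0, t1=1.0, tol=1e-3, shape=(2,), key=jax.random.PRNGKey(0))
# The same control term appears in both SDEs -- this is the duplication
# mentioned above.
diffusion = ControlTerm(lambda t, y, args: 0.5 * jnp.eye(2), bm)

sde1 = MultiTerm(ODETerm(lambda t, y, args: -y), diffusion)                 # posterior
sde2 = MultiTerm(ODETerm(lambda t, y, args: jnp.zeros_like(y)), diffusion)  # prior

# sde_kl_divergence(sde1, sde2, ...) would then build the augmented SDE whose
# extra state component accumulates the KL divergence along each sample path.
```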

+    else:
+        inv_diffusion = jnp.linalg.pinv(diffusion)
+        scale = inv_diffusion @ (drift1 - drift2)
     return 0.5 * jnp.sum(scale**2)


@@ -23,7 +26,7 @@ class _AugDrift(eqx.Module):
     def __call__(self, t, y, args):
         y, _ = y
         context = self.context(t)
-        aug_y = jnp.concatenate([y, context], axis=-1)
+        aug_y = jnp.concatenate([y, context], axis=-1) if context is not None else y
Owner
Nit: flipping the if and else branches allows for switching if context is not None down to just if context is None.
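That is, something like:

```python
aug_y = y if context is None else jnp.concatenate([y, context], axis=-1)
```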

         drift1 = self.drift1(t, aug_y, args)
         drift2 = self.drift2(t, y, args)
         diffusion = self.diffusion(t, y, args)
@@ -66,6 +69,8 @@ def sde_kl_divergence(
     bm: AbstractBrownianPath,
 ):
     aug_y0 = (y0, 0.0)
+    if context is None:
+        context = lambda t: None
     return (
         _AugDrift(drift1, drift2, diffusion, context),
         _AugDiffusion(diffusion),
13 changes: 8 additions & 5 deletions examples/neural_sde.ipynb → examples/neural_sde_gan.ipynb

Large diffs are not rendered by default.
