
Table of Contents

  1. Background
  2. Getting Started

Conditional Flow Matching with ODE Solvers

Jason Stock | 01.05.2025 | 3 minutes

Conditional Flow Matching (CFM) is a simulation-free training objective for continuous normalizing flows. We explore a few flow matching variants and ODE solvers on a simple dataset. This post is a brief overview of the mlx-cfm GitHub repo, which implements simple conditional flow matching in MLX with ODE solvers.

*[figure]*

Background

Training: consider a smooth time-varying vector field $u : [0,1] \times \mathbb{R}^d \to \mathbb{R}^d$ that governs the dynamics of an ordinary differential equation (ODE), $dx = u_t(x)\,dt$. The probability path $p_t(x)$ can be generated by transporting mass along the vector field $u_t(x)$ between distributions over time, following the continuity equation

$$\frac{\partial p}{\partial t} = -\nabla \cdot (p_t u_t).$$

However, the target distribution $p_t(x)$ and the vector field $u_t(x)$ are intractable in practice. Therefore, we assume the probability path can be expressed as a marginal over latent variables:

$$p_t(x) = \int p_t(x \mid z)\, q(z)\, dz,$$

where $p_t(x \mid z) = \mathcal{N}\left(x \mid \mu_t(z), \sigma_t^2 I\right)$ is the conditional probability path, with a latent $z$ sampled from a prior distribution $q(z)$. The dynamics of the conditional probability path are now governed by a conditional vector field $u_t(x \mid z)$. We approximate this using a neural network, parameterizing the time-dependent vector field $v_\theta : [0,1] \times \mathbb{R}^d \to \mathbb{R}^d$. We train the network by regressing the conditional flow matching loss:

LCFM(θ)=Et,q(z),pt(x∣z)∥vθ(t,x)−ut(x∣z)∥2,L_{\text{CFM}}(\theta) = \mathrm{E}_{t, q(z), p_t(x | z)} \lVert v_\theta(t, x) - u_t(x | z) \rVert^2,LCFM​(θ)=Et,q(z),pt​(x∣z)​∥vθ​(t,x)−ut​(x∣z)∥2,

such that $t \sim U(0,1)$, $z \sim q(z)$, and $x_t \sim p_t(x \mid z)$. But how do we compute $u_t(x \mid z)$? Well, assuming a Gaussian probability path, we have a unique vector field (Theorem 3; Lipman et al. 2023) given by

$$u_t(x \mid z) = \frac{\dot{\sigma}_t(z)}{\sigma_t(z)} \left(x - \mu_t(z)\right) + \dot{\mu}_t(z),$$

where $\dot{\mu}$ and $\dot{\sigma}$ are the time derivatives of the mean and standard deviation. If we consider $z \equiv (x_0, x_1)$ and $q(z) = q_0(x_0)\, q_1(x_1)$ with

$$\begin{aligned} \mu_t(z) &= t x_1 + (1 - t)\, x_0, \\ \sigma_t(z) &= \sigma > 0, \end{aligned}$$

then we have independent conditional flow matching (Tong et al. 2023) with the resulting conditional probability path and vector field

$$\begin{aligned} p_t(x \mid z) &= \mathcal{N}\left(x \mid t x_1 + (1 - t)\, x_0,\ \sigma^2\right), \\ u_t(x \mid z) &= x_1 - x_0. \end{aligned}$$
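These two equations translate directly into the training-data sampler for independent CFM: draw $t$, interpolate between a prior and a data sample, add a little noise, and regress against $x_1 - x_0$. A minimal sketch in NumPy (the repo itself uses MLX; the function name `icfm_batch` is illustrative, not from the repo):

```python
import numpy as np

def icfm_batch(x0, x1, sigma=0.1, rng=None):
    """Sample (t, x_t, u_t) training triples for independent CFM.

    x0, x1: (batch, dim) arrays drawn from q0 and q1 respectively.
    Returns t ~ U(0,1), x_t ~ N(t*x1 + (1-t)*x0, sigma^2 I),
    and the regression target u_t = x1 - x0.
    """
    rng = rng or np.random.default_rng()
    t = rng.uniform(size=(x0.shape[0], 1))               # t ~ U(0, 1), one per sample
    mu_t = t * x1 + (1 - t) * x0                         # conditional mean mu_t(z)
    x_t = mu_t + sigma * rng.standard_normal(x0.shape)   # x_t ~ p_t(x | z)
    u_t = x1 - x0                                        # conditional vector field
    return t, x_t, u_t
```

The loss is then just the mean squared error between the network output $v_\theta(t, x_t)$ and `u_t`.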

Alternatively, the variance-preserving stochastic interpolant (Albergo & Vanden-Eijnden 2023) has the form

$$\begin{aligned} \mu_t(z) &= \cos(\pi t / 2)\, x_0 + \sin(\pi t / 2)\, x_1 \quad\text{and}\quad \sigma_t(z) = 0, \\ u_t(x \mid z) &= \frac{\pi}{2} \left( \cos(\pi t / 2)\, x_1 - \sin(\pi t / 2)\, x_0 \right). \end{aligned}$$
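Swapping interpolants only changes how $(t, x_t, u_t)$ are computed; the loss is unchanged. A sketch of the variance-preserving variant under the same assumptions as above (NumPy stand-in for MLX, illustrative function name):

```python
import numpy as np

def vp_batch(x0, x1, rng=None):
    """Sample (t, x_t, u_t) for the variance-preserving interpolant (sigma_t = 0)."""
    rng = rng or np.random.default_rng()
    t = rng.uniform(size=(x0.shape[0], 1))
    a, b = np.cos(np.pi * t / 2), np.sin(np.pi * t / 2)
    x_t = a * x0 + b * x1                      # mu_t(z); no noise since sigma_t = 0
    u_t = (np.pi / 2) * (a * x1 - b * x0)      # time derivative of mu_t(z)
    return t, x_t, u_t
```

Note that `u_t` here is exactly $\dot{\mu}_t(z)$, which is all that remains of the general Gaussian-path formula when $\sigma_t(z) = 0$.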

Sampling: now that we have our vector field, we can sample from our prior $x \sim q_0(x)$ and run a forward ODE solver (e.g., fixed-step Euler or higher-order, adaptive Dormand–Prince), the simplest of which is the Euler update

xt+Δ=xt+vθ(t,xt)Δ,\mathbf{x}_{t+\Delta} = \mathbf{x}_{t} + v_\theta (t, \mathbf{x}_t) \Delta,xt+Δ​=xt​+vθ​(t,xt​)Δ,

taking steps in $t$ from $0$ to $1$.
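The Euler update above can be sketched in a few lines. This is a generic fixed-step integrator (again a NumPy stand-in, not the repo's MLX solver), where `v` plays the role of the learned $v_\theta$:

```python
import numpy as np

def euler_sample(v, x0, n_steps=100):
    """Integrate dx = v(t, x) dt from t = 0 to t = 1 with fixed-step Euler.

    v: callable v(t, x) approximating the vector field v_theta.
    x0: (batch, dim) samples from the prior q0.
    """
    x = x0.copy()
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = i * dt
        x = x + v(t, x) * dt   # x_{t + dt} = x_t + v(t, x_t) * dt
    return x
```

For the independent CFM target $u_t(x \mid z) = x_1 - x_0$, the learned field is roughly constant in $t$, so even coarse Euler steps land close to the data distribution; curved interpolants benefit more from adaptive solvers like Dormand–Prince.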

Getting Started

Check out the repo on GitHub to recreate the figure at the top of this post! 😄