JAX, Static-Shape Programming and Polyhedron — Part 2: Deriving the Closed Form

The table indexer works. But we promised you a sqrt — so let’s earn it.

This is Part 2. In Part 1 we taught JAX to index arbitrary polyhedral domains under static shapes, landing on HedraX’s table indexer. Along the way we also hand-wrote a slick, table-free closed form for triangles — a sqrt that maps a flat index k straight to (i, j). The cliffhanger: HedraX can derive that closed form for you, automatically. Let’s deliver.

Recall the hand-rolled version from Part 1:

import jax.numpy as jnp

i = jnp.floor((jnp.sqrt(8.0 * k + 1.0) - 1.0) / 2.0).astype(jnp.int32)  # where did this come from?
j = (k - i * (i + 1) // 2).astype(jnp.int32)

Slick and table-free — but bespoke. Where does that sqrt come from, and can a machine cook one up for any nice domain? Yes. The trick is to notice that indexing is just counting.

Indexing is Counting

Walk the domain in order and hand out indices 0, 1, 2, …. Then a point’s index is simply how many points come before it.

For the triangle, “before $(i, j)$” splits into two piles:

  • everything in an earlier row $i' < i$: that's $T_i = i(i+1)/2$ points,
  • the points to my left in my own row: that's $j$.

So $k = T_i + j$. Same formula as Part 1, but now it means something: it’s a rank.

The red point $(4,2)$ gets index $k=12$: the green earlier rows contribute $T_4 = 10$, the orange points to its left add $j = 2$.
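A brute-force check of the counting view, in plain Python with no HedraX involved: walk the triangle row by row, hand out indices in order, and confirm that every index equals $T_i + j$. The helper names are illustrative.

```python
def triangle_points(N):
    """Row-major walk of { (i, j) : 0 <= j <= i < N }."""
    for i in range(N):
        for j in range(i + 1):
            yield i, j

def T(i):
    """Points in rows 0 .. i-1: the i-th triangular number."""
    return i * (i + 1) // 2

# A point's index is its position in the walk, i.e. how many points came before it.
rank = {p: k for k, p in enumerate(triangle_points(10))}

for (i, j), k in rank.items():
    assert k == T(i) + j  # earlier rows, plus the points to my left

print(rank[(4, 2)])  # 12, the red point from the figure
```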


Counting is a Solved Problem

The number of lattice points in a polytope — even a parametric one whose bounds depend on N, i, … — is exactly what Barvinok’s algorithm computes (Barvinok, 1994; Verdoolaege et al., 2007), and isl hands it to us as .card().

Ask it “how many triangle points have row $< X$?” and it answers with a polynomial:

[X] -> { [i,j] : 0 <= j <= i < X }    ⟶    prefix_count(X) = X(X + 1) / 2

No surprise — that’s $T_X$. HedraX does this once per dimension to get a prefix-count polynomial $P_t(\text{earlier coords}, X)$. No cleverness required on your part; the counting is the cleverness.
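To make "one prefix count per dimension" concrete without isl, here is a brute-force stand-in for the triangle. isl/Barvinok return these counts as symbolic polynomials; this sketch just enumerates, so it only serves to check small cases. `in_domain`, `P0`, `P1` and `rank` are hypothetical names, not HedraX functions.

```python
N = 10

def in_domain(i, j):
    # the triangle [N] -> { [i, j] : 0 <= j <= i < N }
    return 0 <= j <= i < N

def P0(X):
    """Points whose first coordinate is < X."""
    return sum(in_domain(i, j) for i in range(N) for j in range(N) if i < X)

def P1(i, X):
    """Points in row i whose second coordinate is < X (depends on the earlier coordinate i)."""
    return sum(in_domain(i, j) for j in range(N) if j < X)

def rank(i, j):
    # earlier rows, then earlier points in my own row
    return P0(i) + P1(i, j)

# The brute-force counts agree with the polynomials for this domain:
assert all(P0(X) == X * (X + 1) // 2 for X in range(N + 1))           # quadratic in X
assert all(P1(i, X) == min(X, i + 1) for i in range(N) for X in range(N + 1))  # linear (X) inside the row
```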


Unranking = Inverting the Count

Counting gives k from $(i, j)$. Inside the scan we need the reverse: $(i, j)$ from k. So we invert the prefix count, one dimension at a time:

  1. find the largest $i$ with $P_0(i) \le k$. Since $P_0(X) = X(X+1)/2$ is quadratic, just use the quadratic formula → $i = \lfloor (\sqrt{8k+1} - 1)/2 \rfloor$;
  2. subtract the rows we consumed ($k \leftarrow k - T_i$) and repeat for $j$. Its count is linear, so the inverse is a plain divide — here, the leftover $k$ is exactly $j$.

That’s the Part 1 code, line for line — rederived by HedraX, not by you. The inverse of a quadratic is one sqrt and a floor. But why stop at quadratic?
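The two unranking steps, written with exact integers. This sketch swaps the float `sqrt` for Python's `math.isqrt`, so rounding never enters; `unravel_triangle` is an illustrative name, not HedraX's API. The step-1 test uses the equivalence $T(i) \le k \iff (2i+1)^2 \le 8k+1$.

```python
import math

def T(i):
    return i * (i + 1) // 2  # P_0(i): points in rows 0 .. i-1

def unravel_triangle(k):
    # Step 1: largest i with T(i) <= k, i.e. the largest i with (2i + 1)^2 <= 8k + 1.
    i = (math.isqrt(8 * k + 1) - 1) // 2
    # Step 2: peel off the earlier rows; the row count is linear, so what is left is j.
    j = k - T(i)
    return i, j

print(unravel_triangle(12))  # (4, 2)
print(unravel_triangle(35))  # (7, 7)
```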


How Far Does the Closed Form Reach?

The triangle handed us a quadratic to invert. What if the count is a cubic — or worse?

First, the quantity that has to stay small isn’t the dimension, it’s the degree of the prefix-count polynomial — which is the depth of the deepest nested chain. A product domain stays low-degree no matter how many axes you stack:

  • triangle $\times$ line $\times$ line $\times \dots$ → still degree 2 (exactly why compile_closed_form_indexer handles the 3-D triangle × line without blinking);
  • a fully nested $d$-simplex ($0 \le x_{d-1} \le \dots \le x_0 < N$) → the outer prefix count has degree $d$.
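A quick brute-force check of both bullets (plain Python, illustrative names): the outer prefix count of the $d$-simplex is $\binom{X+d-1}{d}$, the number of size-$d$ multisets drawn from $\{0, \dots, X-1\}$, which has degree $d$ in $X$; stacking lines onto the triangle only multiplies its degree-2 count by constants.

```python
import itertools
import math

def simplex_prefix(d, X):
    """Brute-force count of 0 <= x_{d-1} <= ... <= x_0 < X, tuples listed as (x_0, ..., x_{d-1})."""
    return sum(
        all(a >= b for a, b in zip(x, x[1:]))
        for x in itertools.product(range(X), repeat=d)
    )

def product_prefix(X, N1, N2):
    """Brute-force count of triangle x line x line points whose first coordinate is < X."""
    return sum(1 for i in range(X) for j in range(i + 1) for _ in range(N1) for _ in range(N2))

for d in range(1, 5):
    for X in range(7):
        assert simplex_prefix(d, X) == math.comb(X + d - 1, d)  # degree d in X

for X in range(7):
    # T_X * N1 * N2: still degree 2 in X, however many line factors are stacked
    assert product_prefix(X, 3, 4) == X * (X + 1) // 2 * 3 * 4
```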

And inverting a polynomial in radicals follows the textbook ladder — right up until it doesn’t:

| nesting depth | prefix count | invert with |
|---|---|---|
| 1 | linear | a divide |
| 2 | quadratic | sqrt |
| 3 | cubic | Cardano |
| 4 | quartic | Ferrari |
| ≥ 5 | quintic and up | nothing in radicals |

Degree 5 is the Abel–Ruffini wall: the general quintic isn't solvable in radicals (its Galois group is $S_5$, which isn't solvable), and the prefix-count polynomials of deeply nested domains are generic enough to hit that wall. So a closed-form radical inverse is only guaranteed up to degree 4, i.e. simplices up to 4 dimensions, with the 5-simplex the first whose count has no general formula in radicals. (HedraX today wires up only the sqrt rung; Cardano and Ferrari are “just” more algebra.)
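To make "just more algebra" concrete, here is a sketch of the cubic rung for the 3-simplex. This is an illustration in plain Python, not something HedraX does today. Its outer prefix count is $P(X) = X(X+1)(X+2)/6$. Substituting $Y = X + 1$ turns $P(X) = k$ into the depressed cubic $Y^3 - Y - 6k = 0$, which Cardano solves with one sqrt and one cube root. An exact integer nudge then absorbs float rounding. `P3` and `unrank_outer_cubic` are hypothetical names.

```python
import math

def P3(x):
    """Exact prefix count of the 3-simplex 0 <= x2 <= x1 <= x0: points with x0 < x."""
    return x * (x + 1) * (x + 2) // 6

def unrank_outer_cubic(k):
    """Largest x0 with P3(x0) <= k, via Cardano plus an integer nudge."""
    if k == 0:
        return 0
    # P3(X) = k  <=>  Y^3 - Y - 6k = 0 with Y = X + 1 (depressed cubic: p = -1, q = -6k)
    disc = 9.0 * k * k - 1.0 / 27.0            # (q/2)^2 + (p/3)^3 > 0 for k >= 1: one real root
    u = (3.0 * k + math.sqrt(disc)) ** (1.0 / 3.0)
    y = u + 1.0 / (3.0 * u)                    # the two Cardano cube roots multiply to -p/3 = 1/3
    x = max(int(math.floor(y - 1.0)), 0)
    # float rounding can leave x one off near exact counts: repair with exact integers
    while P3(x) > k:
        x -= 1
    while P3(x + 1) <= k:
        x += 1
    return x
```

Writing the second cube root as $1/(3u)$ instead of $\sqrt[3]{3k - \sqrt{\cdot}}$ avoids the cancellation that the difference form suffers for large $k$.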

Past the wall, invert it anyway

The wall blocks the pretty formula, not table-free unranking. The prefix count is monotone, so you can always invert it numerically: seed with the leading term $X_0 = (d!\,k)^{1/d}$ and run Newton. It converges quadratically, so $\sim\!\log\log N$ steps land you within $\pm 1$, and an integer nudge finishes the job. Because $N$ is a static shape, you can unroll a fixed handful of steps ($\approx 6$ for 64-bit) straight into the scan: still O(1) memory, still no table.
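A sketch of that recipe for the 5-simplex, the first domain past the wall, in plain Python rather than JAX. Its outer prefix count is $\binom{x+4}{5} = x(x+1)\cdots(x+4)/5!$. All names are illustrative. Since $P(x) \ge x^5/5!$ for $x \ge 0$, the leading-term seed starts at or above the root. $P$ is increasing and convex for $x > 0$, so Newton then descends monotonically toward it.

```python
import math

D = 5                       # nesting depth: the 5-simplex, first past the Abel-Ruffini wall
FACT = math.factorial(D)

def P_int(x):
    """Exact prefix count C(x + D - 1, D): points whose outer coordinate is < x."""
    return math.comb(x + D - 1, D) if x > 0 else 0

def P_real(x):
    # the same polynomial x (x+1) ... (x+D-1) / D!, in floating point
    return math.prod(x + t for t in range(D)) / FACT

def dP_real(x):
    # product rule: d/dx prod_t (x+t) = sum_t prod_{s != t} (x+s)
    return sum(math.prod(x + s for s in range(D) if s != t) for t in range(D)) / FACT

def unrank_outer(k, steps=6):
    """Largest x >= 0 with P_int(x) <= k: fixed-step Newton, then an integer nudge."""
    x = (FACT * k) ** (1.0 / D)      # leading-term seed: P(x) >= x^D / D! = k, so x starts above the root
    for _ in range(steps):           # fixed trip count, so it could be unrolled into a scan body
        if x <= 0.0:
            break
        x -= (P_real(x) - k) / dP_real(x)
    i = max(int(math.floor(x)), 0)
    while i > 0 and P_int(i) > k:    # nudge down if rounding overshot
        i -= 1
    while P_int(i + 1) <= k:         # nudge up if Newton stopped short
        i += 1
    return i
```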

Two honest caveats:

  • it is not O(1) time — it’s a tiny, compile-time-fixed $\log\log N$ steps. (Then again, sqrt is itself fixed-iteration refinement wearing a single-instruction costume, so the “closed form vs. iterate” line is blurrier than it looks.)
  • once the counts pass $2^{53}$, float Newton loses integer exactness, so the final $\pm 1$ becomes a small bracketed integer search. That bookkeeping — not the step count — is the real tax past quadratic.
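A sketch of that bracketed repair in plain Python integers (the helper name is illustrative): start from a float estimate, widen a bracket until $P(lo) \le k < P(hi)$ holds exactly, then bisect. It is shown on the triangle with $k > 2^{53}$, where `math.isqrt` provides an exact reference answer.

```python
import math

def bracketed_unrank(P, k, guess):
    """Largest integer x >= 0 with P(x) <= k, starting near a float guess.
    Assumes P is nondecreasing on x >= 0 and P(0) <= k; every comparison is exact."""
    lo = hi = max(int(guess), 0)
    step = 1
    while lo > 0 and P(lo) > k:      # widen downward until P(lo) <= k
        lo = max(lo - step, 0)
        step *= 2
    step = 1
    while P(hi) <= k:                # widen upward until P(hi) > k
        hi += step
        step *= 2
    while hi - lo > 1:               # bisect, keeping P(lo) <= k < P(hi)
        mid = (lo + hi) // 2
        if P(mid) <= k:
            lo = mid
        else:
            hi = mid
    return lo

def T(x):
    return x * (x + 1) // 2

k = 2**60 + 12345                                 # well past float64's exact-integer range
guess = (math.sqrt(8.0 * k + 1.0) - 1.0) / 2.0    # float estimate, possibly a little off
i = bracketed_unrank(T, k, guess)
assert i == (math.isqrt(8 * k + 1) - 1) // 2      # exact closed form as a reference
```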

So the whole horizon: degree $\le 4$ → one shot through $\{\texttt{sqrt}, \texttt{cbrt}\}$; degree $\ge 5$ → a few Newton steps and some care with rounding. Either way, the table never comes back.


The Closed-Form Indexer

import hedrax as hdx
from jax import lax
import jax.numpy as jnp

sol = hdx.compile_closed_form_indexer("[N] -> { [i, j] : 0 <= j <= i < N }", N=10)

sol.is_closed_form     # True
sol.addresses          # range(0, 55)  — a range, not an array: there is no table
sol.unravel(35)        # -> i = 7, j = 7 (T_7 = 28 <= 35 < T_8 = 36), decoded with the auto-derived sqrt

Drop it into the same scan as Part 1 — the only difference is that addresses is now a bare range 0 … K-1:

def body(a, k):
    i, j = sol.unravel(k)
    return a.at[i, j].set(f(i, j)), None   # f and a0: the per-point function and initial array from Part 1

a, _ = lax.scan(body, a0, jnp.asarray(sol.addresses))

Versus the table route, we traded an $O(M)$ address table and a gather-per-step for a handful of flops and one sqrt.¹
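One nuance: `jnp.asarray(sol.addresses)` still builds a length-$K$ array of indices. That is far smaller than a coordinate table, but it is not nothing. Here is a sketch that keeps only the loop counter and uses core JAX alone, with nothing from HedraX: `lax.fori_loop` passes the body its iteration number, which is exactly the flat index $k$, and a hand-written `unravel` decodes it with the Part 1 sqrt plus an exact int32 nudge to guard against float32 rounding. `unravel` and `fill_triangle` are hypothetical names.

```python
import jax.numpy as jnp
from jax import lax

def unravel(k):
    """Flat index k -> (i, j) in the triangle 0 <= j <= i, table-free."""
    k = jnp.asarray(k, jnp.int32)
    i = jnp.floor((jnp.sqrt(8.0 * k.astype(jnp.float32) + 1.0) - 1.0) / 2.0).astype(jnp.int32)
    # float32 sqrt can land one row off for large k: nudge with exact int32 arithmetic
    # (valid while (i + 1) * (i + 2) fits in int32, roughly N < 46000)
    i = jnp.where(i * (i + 1) // 2 > k, i - 1, i)
    i = jnp.where((i + 1) * (i + 2) // 2 <= k, i + 1, i)
    return i, k - i * (i + 1) // 2

def fill_triangle(N, f):
    K = N * (N + 1) // 2                      # static trip count: N is a Python int
    def body(k, a):                           # the loop counter k is the flat index
        i, j = unravel(k)
        return a.at[i, j].set(f(i, j))
    return lax.fori_loop(0, K, body, jnp.full((N, N), -1, jnp.int32))

a = fill_triangle(10, lambda i, j: 10 * i + j)
```

Cells above the diagonal keep the fill value -1; every point of the triangle is written exactly once.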


When HedraX Bails (and That’s Fine)

HedraX's closed form needs every dimension's count to be a single-piece polynomial of degree $\le 2$, because sqrt is the only radical it wires up today. When the geometry fights back, HedraX quietly gives up and hdx.compile_indexer falls back to the table:

  • strided / modular domains like $\{\, i : i = 2a \,\}$ → counting yields a quasi-polynomial (periodic $\lfloor \cdot \rfloor$ terms), which has no clean radical inverse;
  • non-convex unions like our GPT unicorn → there’s no single contiguous ranking to invert;
  • deeply nested domains whose counts climb past degree 2 (sqrt is the only rung HedraX wires up today; past degree 4 no radical formula exists at all).

sol = hdx.compile_indexer("[N] -> { [i] : exists (a : i = 2a and 0 <= i < N) }", N=10)
sol.is_closed_form     # False  → table indexer, chosen automatically
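Why the strided example bails, checked by brute force in plain Python (the helper name is illustrative): the prefix count of $\{\, i : i = 2a,\ 0 \le i < N \,\}$ is not one polynomial in $X$ but two, picked by the parity of $X$. That period-2 switching is exactly what a quasi-polynomial encodes.

```python
def even_prefix(X):
    """Brute-force count of points i = 2a with 0 <= i < X."""
    return sum(1 for i in range(X) if i % 2 == 0)

for X in range(50):
    if X % 2 == 0:
        assert 2 * even_prefix(X) == X        # piece for even X: X / 2
    else:
        assert 2 * even_prefix(X) == X + 1    # piece for odd X: (X + 1) / 2

# The two pieces glue into a single floor term, (X + 1) // 2, i.e. a quasi-polynomial.
assert all(even_prefix(X) == (X + 1) // 2 for X in range(50))
```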

So you always get an indexer — and the table-free one whenever the domain is nice enough to deserve it.


Wrapping Up

The whole arc in one breath: you described a loop as a set, a counting algorithm turned that set into a polynomial, and inverting the polynomial turned it back into arithmetic — a static-shape index with no table in sight. Compilers have leaned on this polyhedral machinery for decades (Feautrier, 1992; Bastoul, 2004); HedraX just aims it at JAX.

Rectangles are fine. Weird shapes are fun. And the weirdness, it turns out, was only ever counting.

  1. Whether that is actually faster on your accelerator is a separate question — a gather is cheap, and a sqrt is not free. The honest pitch is “it saves $O(M)$ memory, and it’s cooler.” The table indexer is often all you need. 

References

  1. Barvinok, Alexander I. A Polynomial Time Algorithm for Counting Integral Points in Polyhedra When the Dimension is Fixed. Mathematics of Operations Research, 1994.
  2. Verdoolaege, Sven and Seghir, Kazem and Beyls, Kristof and D’Hollander, Erik and Bruynooghe, Maurice. Counting Integer Points in Parametric Polytopes Using Barvinok’s Rational Functions. Algorithmica, 2007.
  3. Feautrier, Paul. Some Efficient Solutions to the Affine Scheduling Problem, Part I. One-Dimensional Time. International Journal of Parallel Programming, 1992.
  4. Bastoul, Cédric. Code Generation in the Polyhedral Model Is Easier Than You Think. Proceedings of the 13th International Conference on Parallel Architectures and Compilation Techniques (PACT), 2004.


