keras.ops.normalize with order=2 produces NaN gradients for zero input vectors on the torch backend. The epsilon argument does not prevent it — it guards only the forward value, not the derivative.
Notably, the same function's general-order branch is gradient-safe: order=1 and order=3 clamp the denominator before dividing. Only the order == 2 "fast path" switches to a raw rsqrt, which reintroduces the singularity:
# keras/src/ops/nn.py :: _normalize (current master)
if 2 == order:
    # A special case: L2 normalization with `x * rsqrt(...)`
    # instead of `x / sqrt(...)`
    square_sum = backend.numpy.sum(backend.numpy.square(x), axis=axis, keepdims=True)
    inv_norm = backend.math.rsqrt(square_sum)  # rsqrt(0) = inf enters the graph
    inv_norm = backend.numpy.minimum(inv_norm, 1.0 / epsilon)  # fixes the VALUE, not the GRADIENT
    return x * inv_norm
norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True)
denom = backend.numpy.maximum(norm, epsilon)  # guard BEFORE divide: safe
return backend.numpy.divide(x, denom)
Backward must differentiate rsqrt at 0: torch computes grad_in = -0.5 * grad_out * y**3 with y = rsqrt(0) = inf. The minimum selects the constant 1 / epsilon, so the gradient it passes back to the rsqrt output is grad_out = 0. But 0 * inf = NaN (IEEE 754), so RsqrtBackward0 emits NaN regardless of epsilon.
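The arithmetic can be traced by hand without any framework. The following is a minimal pure-Python sketch of the two backward steps, not torch's actual autograd code: the helper names minimum_backward and rsqrt_backward are hypothetical, tie-breaking in minimum is ignored, and epsilon = 1e-7 is only an illustrative value.

```python
import math

def minimum_backward(grad_out: float, a: float, b: float) -> float:
    """Gradient of min(a, b) with respect to a (ties ignored for simplicity)."""
    return grad_out if a < b else 0.0

def rsqrt_backward(grad_out: float, y: float) -> float:
    """Backward rule quoted in the issue, with y = rsqrt(s): -0.5 * grad_out * y**3."""
    return -0.5 * grad_out * y ** 3

epsilon = 1e-7                    # illustrative value, not read from Keras
y = math.inf                      # forward: rsqrt(0)
clamped = min(y, 1.0 / epsilon)   # forward value is finite after the clamp

# The clamp selected the constant, so zero gradient flows back to the rsqrt output.
grad_y = minimum_backward(1.0, y, 1.0 / epsilon)
# rsqrt's backward still multiplies that zero by y**3 = inf.
grad_s = rsqrt_backward(grad_y, y)

print(clamped, grad_y, grad_s)
```

The forward value is finite and the incoming gradient is zero, yet the product with inf is NaN, which is why no choice of epsilon helps once rsqrt(0) is in the graph.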
Standalone reproduction
import os; os.environ["KERAS_BACKEND"] = "torch"
import torch, keras
x = torch.zeros(2, 4, requires_grad=True)
keras.ops.normalize(x, axis=-1, order=2).sum().backward()
print(x.grad) # tensor of nan — with ANY epsilon value
# order=1 on the same input: finite gradients (the safe branch)
x2 = torch.zeros(2, 4, requires_grad=True)
keras.ops.normalize(x2, axis=-1, order=1).sum().backward()
print(x2.grad) # finite
# and under anomaly detection the op is pinpointed:
# torch.autograd.set_detect_anomaly(True)
# -> RuntimeError: Function 'RsqrtBackward0' returned nan values in its 0th output.
Verified on keras 3.12.0 and on current master (the _normalize code above is unchanged).
Prior art — both reference implementations guard this
- TensorFlow: tf.math.l2_normalize computes x * rsqrt(maximum(square_sum, 1e-12)), clamping before the singular op; gradients are finite everywhere by construction.
- PyTorch hit this exact bug class in 2017 and fixed it:
  - pytorch/pytorch#2421 "Gradient of zero norm is nan" → fixed in PR #2775 (zero-subgradient for norm at 0)
  - pytorch/pytorch#3264 "F.normalize NaN gradient"
  F.normalize computes v / max(‖v‖, ε), with the denominator clamped before the division.
Ironically, ops.normalize(order=2) on the torch backend bypasses torch's own already-guarded F.normalize/linalg.norm and reintroduces the 2017 singularity via raw rsqrt.
Suggested fix
Semantics-preserving: for square_sum ≥ 0, min(rsqrt(square_sum), 1/ε) = rsqrt(max(square_sum, ε²)), so the existing clamp already amounts to norm ≥ ε. Clamp before the rsqrt instead, mirroring the function's own general-order branch:
inv_norm = backend.math.rsqrt(
    backend.numpy.maximum(square_sum, epsilon * epsilon))
Forward output is unchanged for all inputs up to floating-point rounding (identical clamping threshold); gradients become finite and bounded everywhere, following the same clamp-then-rsqrt pattern as tf.math.l2_normalize.
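As a sanity check on the equivalence claim, here is a small scalar sketch in pure Python. It is an illustration under stated assumptions rather than the Keras implementation: the function names are hypothetical, square_sum is a plain float, epsilon = 1e-7 is an illustrative value, and the derivative is written out by hand from the chain rule.

```python
import math

def inv_norm_current(square_sum: float, epsilon: float) -> float:
    """Current order=2 path: rsqrt first, then clamp the result."""
    r = math.inf if square_sum == 0.0 else 1.0 / math.sqrt(square_sum)
    return min(r, 1.0 / epsilon)

def inv_norm_proposed(square_sum: float, epsilon: float) -> float:
    """Proposed path: clamp the argument, then rsqrt."""
    return 1.0 / math.sqrt(max(square_sum, epsilon * epsilon))

def grad_proposed(square_sum: float, epsilon: float) -> float:
    """d(inv_norm_proposed)/d(square_sum), written out via the chain rule."""
    y = inv_norm_proposed(square_sum, epsilon)  # always finite thanks to the clamp
    passes = 1.0 if square_sum > epsilon * epsilon else 0.0  # gradient of maximum
    return -0.5 * y ** 3 * passes

epsilon = 1e-7  # illustrative value
samples = [0.0, 1e-20, 1e-14, 1e-10, 1.0, 1e6]

# Forward values agree up to rounding on both sides of the threshold.
forward_matches = all(
    math.isclose(inv_norm_current(s, epsilon), inv_norm_proposed(s, epsilon), rel_tol=1e-12)
    for s in samples
)
# With the clamp applied first, rsqrt never sees 0, so every gradient is finite.
grads = [grad_proposed(s, epsilon) for s in samples]

print(forward_matches, all(math.isfinite(g) for g in grads))
```

Because the clamp is applied to the argument, the value fed into the rsqrt backward rule is at most 1/ε, so the zero gradient below the threshold multiplies a finite number instead of inf.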
Happy to send a PR with the one-line fix plus a gradient regression test if useful.