
ops.normalize(order=2) yields NaN gradients for zero vectors on the torch backend — the L2 fast path bypasses the function's own epsilon guard #23075

@nmandilaras

Description


keras.ops.normalize with order=2 produces NaN gradients for zero input vectors on the torch backend. The epsilon argument does not prevent it — it guards only the forward value, not the derivative.

Notably, the same function's general-order branch is gradient-safe: for order=1 or order=3 it clamps the denominator with maximum before dividing. Only the order == 2 "fast path" special-cases into a raw rsqrt, which reintroduces the singularity:

# keras/src/ops/nn.py :: _normalize  (current master)
if 2 == order:
    # A special case: L2 normalization with `x * rsqrt(...)`
    # instead of `x / sqrt(...)`
    square_sum = backend.numpy.sum(backend.numpy.square(x), axis=axis, keepdims=True)
    inv_norm = backend.math.rsqrt(square_sum)                   # rsqrt(0) = inf enters the graph
    inv_norm = backend.numpy.minimum(inv_norm, 1.0 / epsilon)   # fixes the VALUE, not the GRADIENT
    return x * inv_norm

norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True)
denom = backend.numpy.maximum(norm, epsilon)                    # guard BEFORE divide — safe
return backend.numpy.divide(x, denom)

The backward pass must differentiate rsqrt at 0. Torch computes grad_in = -0.5 * grad_out * y**3, where y = rsqrt(0) = inf. torch.minimum correctly sends a zero gradient (grad_out = 0) to the rsqrt branch, because that branch lost the comparison against 1/ε, but 0 * inf = NaN under IEEE 754, so RsqrtBackward0 emits NaN regardless of epsilon.
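For anyone who wants to check the arithmetic without torch, here is a NumPy sketch that replays the same chain rule by hand. It assumes torch's per-op backward rules as described above (rsqrt: -0.5 * g * y**3; minimum: full gradient to the smaller input, zero to the larger, ties ignored), a 1-D input, and the .sum() loss from the reproduction. fast_path_grad is a hypothetical helper written for this illustration, not a Keras or torch API.

```python
import numpy as np


def fast_path_grad(x, epsilon=1e-7):
    """Hand-traced gradient of sum(x * min(rsqrt(sum(x**2)), 1/eps)) for 1-D x.

    Replays, op by op, the backward rules torch applies to the current
    order == 2 fast path. Hypothetical helper for illustration only.
    """
    x = np.asarray(x, dtype=np.float32)  # float32, like torch.zeros
    with np.errstate(divide="ignore", invalid="ignore"):
        # Forward pass.
        square_sum = np.sum(x * x)
        y = 1.0 / np.sqrt(square_sum)    # rsqrt: inf when square_sum == 0
        cap = np.float32(1.0 / epsilon)
        inv_norm = np.minimum(y, cap)    # forward value stays finite

        # Backward pass for loss = sum(x * inv_norm); upstream gradient = ones.
        grad_x_direct = np.full_like(x, inv_norm)  # through the product x * inv_norm
        grad_inv_norm = np.sum(x)                  # d loss / d inv_norm (0 for zero x)
        # minimum: full gradient to the smaller input, zero to the larger one.
        grad_y = grad_inv_norm if y < cap else np.float32(0.0)
        grad_square_sum = -0.5 * grad_y * y**3     # rsqrt backward: 0 * inf = nan
        grad_x_via_norm = grad_square_sum * 2.0 * x  # square + sum backward
    return inv_norm, grad_x_direct + grad_x_via_norm


inv_norm, grad = fast_path_grad(np.zeros(4))
print(inv_norm, grad)
```

For a nonzero input such as [3, 4] the same trace yields the analytic gradient 1/‖x‖ - sum(x)·x/‖x‖³ = [0.032, -0.024]; only the zero vector runs into 0 * inf.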

Standalone reproduction

import os; os.environ["KERAS_BACKEND"] = "torch"
import torch, keras

x = torch.zeros(2, 4, requires_grad=True)
keras.ops.normalize(x, axis=-1, order=2).sum().backward()
print(x.grad)        # tensor of nan — with ANY epsilon value

# order=1 on the same input: finite gradients (the safe branch)
x2 = torch.zeros(2, 4, requires_grad=True)
keras.ops.normalize(x2, axis=-1, order=1).sum().backward()
print(x2.grad)       # finite

# and under anomaly detection the op is pinpointed:
# torch.autograd.set_detect_anomaly(True)
# -> RuntimeError: Function 'RsqrtBackward0' returned nan values in its 0th output.

Verified on keras 3.12.0 and on current master (the _normalize code above is unchanged).

Prior art — both reference implementations guard this

  • TensorFlow tf.math.l2_normalize: x * rsqrt(maximum(square_sum, 1e-12)) — clamp before the singular op; finite gradients everywhere by construction.
  • PyTorch hit this exact bug class in 2017 and fixed it:
    • pytorch/pytorch#2421 "Gradient of zero norm is nan" → fixed in PR #2775 (zero-subgradient for norm at 0)
    • pytorch/pytorch#3264 "F.normalize NaN gradient"
    • F.normalize computes v / max(‖v‖, ε) — denominator clamped before the division.
      Ironically, ops.normalize(order=2) on the torch backend bypasses torch's own already-guarded F.normalize/linalg.norm and reintroduces the 2017 singularity via raw rsqrt.
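The zero-subgradient convention mentioned above amounts to the following NumPy sketch. l2_norm_grad is a hypothetical helper illustrating the convention, not PyTorch's actual backward code: away from the origin the gradient of ‖v‖₂ is v / ‖v‖₂, and at v = 0 the valid subgradient 0 is returned instead of evaluating 0/0.

```python
import numpy as np


def l2_norm_grad(v):
    """Gradient of ||v||_2 with respect to v, taking the subgradient 0 at v = 0.

    Illustrates the zero-subgradient convention; not PyTorch's actual code.
    """
    v = np.asarray(v, dtype=np.float64)
    norm = np.linalg.norm(v)
    if norm == 0.0:
        # Any vector in the closed unit ball is a subgradient at the origin;
        # picking 0 avoids evaluating 0 / 0 = nan.
        return np.zeros_like(v)
    return v / norm


print(l2_norm_grad([3.0, 4.0]))  # [0.6 0.8]
print(l2_norm_grad([0.0, 0.0]))  # [0. 0.]
```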

Suggested fix

Semantics-preserving fix: clamp before the rsqrt, mirroring the function's own general-order branch. Because rsqrt is decreasing, min(rsqrt(s), 1/ε) = rsqrt(max(s, ε²)) for any ε > 0, so both forms divide x by max(‖x‖, ε):

inv_norm = backend.math.rsqrt(
    backend.numpy.maximum(square_sum, epsilon * epsilon))

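Forward output is unchanged for all inputs up to floating-point rounding, since the current min(inv_norm, 1/ε) and the proposed max(square_sum, ε²) both divide x by max(‖x‖, ε) (assuming ε² does not underflow in the input dtype). Gradients become finite and bounded everywhere. The structure matches tf.math.l2_normalize, with one difference in convention: TF's epsilon bounds the squared sum directly, so its default of 1e-12 corresponds to a norm floor of 1e-6.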
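Both claims can be checked numerically with a small NumPy sketch. It assumes the same per-op backward rules as torch (maximum sends zero gradient to the smaller input; ties ignored), a 1-D float64 input, and a .sum() loss; current_forward and fixed_path are hypothetical helpers written for this illustration, not Keras functions.

```python
import numpy as np


def rsqrt(s):
    """1 / sqrt(s) with IEEE semantics (rsqrt(0) = inf), like torch.rsqrt."""
    return 1.0 / np.sqrt(s)


def current_forward(x, epsilon=1e-7):
    """Forward value of the current fast path: x * min(rsqrt(sum(x**2)), 1/eps)."""
    x = np.asarray(x, dtype=np.float64)
    with np.errstate(divide="ignore"):  # rsqrt(0) = inf is expected here
        inv_norm = np.minimum(rsqrt(np.sum(x * x)), 1.0 / epsilon)
    return x * inv_norm


def fixed_path(x, epsilon=1e-7):
    """Proposed fix, x * rsqrt(max(sum(x**2), eps**2)), plus the hand-traced
    gradient of its .sum() with respect to x."""
    x = np.asarray(x, dtype=np.float64)
    # Forward pass: clamp BEFORE the singular op, so rsqrt never sees 0.
    square_sum = np.sum(x * x)
    floor = epsilon * epsilon
    inv_norm = rsqrt(np.maximum(square_sum, floor))  # at most about 1/eps
    out = x * inv_norm

    # Backward pass for loss = out.sum(); upstream gradient = ones.
    grad_inv_norm = np.sum(x)                          # d loss / d inv_norm
    grad_clamped = -0.5 * grad_inv_norm * inv_norm**3  # finite: inv_norm is finite
    # maximum: gradient reaches square_sum only when it wins the comparison.
    grad_square_sum = grad_clamped if square_sum > floor else 0.0
    grad_x = inv_norm + grad_square_sum * 2.0 * x      # product rule + square/sum backward
    return out, grad_x


for x in ([3.0, 4.0], [1e-9, 0.0], [0.0, 0.0, 0.0, 0.0]):
    out, grad = fixed_path(x)
    print(x, np.allclose(out, current_forward(x), rtol=1e-12, atol=0.0), grad)
```

On the zero vector, fixed_path returns the same all-zero output as the current code and a gradient of about 1/ε in every coordinate instead of NaN.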

Happy to send a PR with the one-line fix plus a gradient regression test if useful.
