treeflow.vi.optimizers.robust_optimizer module

class treeflow.vi.optimizers.robust_optimizer.RobustOptimizer(inner, max_retries=100)

Bases: object

apply_gradients(grads_and_vars, **kwargs)

Apply gradients to variables, skipping the step if any gradient contains NaN.

If any gradient contains NaN, every gradient is replaced with zeros and the inner optimizer is still called with the full list of variables it expects, as Keras 3 requires. Because the gradients are zero, no parameters are updated, which preserves the original skip-step semantics.
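The skip-step behaviour described above can be illustrated with a minimal, framework-free sketch. It is not TreeFlow's implementation. `Variable`, `SGD`, and `SkipNaNOptimizer` are hypothetical stand-ins that use plain Python lists of floats instead of TensorFlow/Keras tensors. The sketch also ignores the `max_retries` argument, because this page does not document how it is used.

```python
import math
from typing import Sequence


class Variable:
    """Hypothetical stand-in for a trainable variable holding a list of floats."""

    def __init__(self, values: Sequence[float]):
        self.values = list(values)

    def assign_sub(self, delta: Sequence[float]) -> None:
        self.values = [v - d for v, d in zip(self.values, delta)]


class SGD:
    """Hypothetical plain gradient-descent optimizer (no momentum, no state)."""

    def __init__(self, learning_rate: float = 0.1):
        self.learning_rate = learning_rate
        self.iterations = 0  # counts every call, including skipped steps

    def apply_gradients(self, grads_and_vars, **kwargs):
        # kwargs are accepted for interface compatibility and ignored here.
        for grad, var in grads_and_vars:
            var.assign_sub([self.learning_rate * g for g in grad])
        self.iterations += 1


class SkipNaNOptimizer:
    """Sketch of the NaN skip-step wrapper described above (hypothetical name)."""

    def __init__(self, inner):
        self.inner = inner
        self.skipped_steps = 0

    def apply_gradients(self, grads_and_vars, **kwargs):
        grads_and_vars = list(grads_and_vars)
        has_nan = any(math.isnan(g) for grad, _ in grads_and_vars for g in grad)
        if has_nan:
            # Zero every gradient but keep every variable, so the inner
            # optimizer is still called with the variables it expects.
            grads_and_vars = [([0.0] * len(grad), var) for grad, var in grads_and_vars]
            self.skipped_steps += 1
        return self.inner.apply_gradients(grads_and_vars, **kwargs)


# Usage: one normal step followed by one step with a NaN gradient.
x = Variable([1.0, 2.0])
opt = SkipNaNOptimizer(SGD(learning_rate=0.5))
opt.apply_gradients([([2.0, 4.0], x)])            # x.values is now [0.0, 0.0]
opt.apply_gradients([([float("nan"), 1.0], x)])   # skipped: x.values unchanged
```

In this sketch the inner optimizer's step counter still advances on the NaN step, because the inner optimizer is always called. Zeroed gradients leave the parameters unchanged for a stateless gradient-descent update like this one. For optimizers that keep state, such as SGD with momentum or Adam, a zero gradient can still produce a nonzero update from the accumulated state.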