treeflow.debug.nonfinite_convergence_criterion module
- class treeflow.debug.nonfinite_convergence_criterion.NonfiniteConvergenceCriterion(name='NonfiniteConvergenceCriterion')
Bases: ConvergenceCriterion
- bootstrap(loss, grads, parameters)
Returns a structure of Tensors for the rule’s state at step 0.
The shapes of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimensions.
- Parameters:
loss – float Tensor holding the initial value of the loss being optimized.
grads – list of float Tensors holding the gradients of the loss with respect to parameters.
parameters – list of float Tensors holding the initial values of the parameters being optimized.
- Returns:
initial_auxiliary_state – (Structure of) `Tensor`(s) representing the initial auxiliary state carried forward by this criterion.
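A check for non-finite values only needs the current step's values, so a criterion of this kind plausibly carries no auxiliary state at all. The framework-free sketch below assumes that, and uses plain Python floats in place of Tensors. The class name `NonfiniteCheckSketch` is hypothetical, and this is not TreeFlow's implementation.

```python
from typing import Sequence, Tuple


class NonfiniteCheckSketch:
    """Hypothetical, framework-free stand-in; not TreeFlow's implementation."""

    def bootstrap(
        self,
        loss: float,
        grads: Sequence[float],
        parameters: Sequence[float],
    ) -> Tuple[()]:
        # A finiteness check needs only the current step's values, so nothing
        # is carried forward: the initial auxiliary state is an empty tuple.
        return ()


initial_auxiliary_state = NonfiniteCheckSketch().bootstrap(
    loss=1.5, grads=[0.1, -0.2], parameters=[0.3, 0.4]
)
print(initial_auxiliary_state)  # ()
```

Whatever bootstrap returns, one_step must later hand back an updated state with the same structure. An empty tuple is the simplest structure that satisfies this.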
- property min_num_steps
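This page lists min_num_steps without a description. Assume the base class is TensorFlow Probability's `ConvergenceCriterion`, since the inherited docstrings match TFP's. In that case min_num_steps is an optional lower bound: the subclass's decision is ANDed with `step >= min_num_steps`, and no gating happens when the bound is None. The plain-Python function below sketches that rule for a single, unbatched step. `gate_by_min_num_steps` is a hypothetical name, and the value this class actually uses for min_num_steps is not documented here.

```python
from typing import Optional


def gate_by_min_num_steps(
    has_converged: bool, step: int, min_num_steps: Optional[int]
) -> bool:
    """Hypothetical helper: suppress convergence before min_num_steps."""
    if min_num_steps is None:
        # No lower bound configured: pass the decision through unchanged.
        return has_converged
    # Convergence may only be reported once step has reached the bound.
    return has_converged and step >= min_num_steps


print(gate_by_min_num_steps(True, step=2, min_num_steps=5))     # False
print(gate_by_min_num_steps(True, step=5, min_num_steps=5))     # True
print(gate_by_min_num_steps(True, step=1, min_num_steps=None))  # True
```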
- property name
- one_step(step, loss, grads, parameters, auxiliary_state)
Updates tracked quantities for a new step, and determines if converged.
The shapes of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimensions. In this case, the returned value has_converged will have a shape equal to the broadcast batch shape of whichever of those quantities this convergence criterion uses and of the quantities defining the criterion (min_num_steps, etc.).
- Parameters:
step – integer Tensor holding the index of the current step, where step >= 1 (on step 0, bootstrap should be called instead).
loss – float Tensor holding the value of the loss at the current step.
grads – list of float Tensors holding the gradients of the loss with respect to parameters.
parameters – list of float Tensors holding the current values of the parameters being optimized.
auxiliary_state – the (structure of) `Tensor`(s) containing state carried forward from the previous step.
- Returns:
- has_converged – boolean Tensor indicating whether the optimization has converged.
- updated_auxiliary_state – (Structure of) `Tensor`(s) representing updated quantities tracked by the convergence criterion. This should match the structure of the value returned by bootstrap.
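The class name and module suggest one plausible reading: the criterion reports convergence as soon as the loss, a gradient, or a parameter becomes NaN or infinite, so that an optimization loop stops at the first non-finite value. The sketch below assumes that behaviour and uses plain Python lists in place of Tensors, with a leading batch dimension of size B. Here loss is a length-B list, and each entry of grads and parameters is a length-B list holding one value per batch member. `one_step_sketch` is a hypothetical name; this is not TreeFlow's code.

```python
import math
from typing import List, Sequence, Tuple

# Hypothetical layout: one float per batch member.
Batch = Sequence[float]


def one_step_sketch(
    step: int,
    loss: Batch,
    grads: Sequence[Batch],
    parameters: Sequence[Batch],
    auxiliary_state: Tuple[()],
) -> Tuple[List[bool], Tuple[()]]:
    """Flag each batch member whose loss, gradients or parameters are non-finite."""
    has_converged = []
    for b, loss_b in enumerate(loss):
        values = [loss_b] + [g[b] for g in grads] + [p[b] for p in parameters]
        # "Converged" here means "stop": some value is NaN or +/-inf.
        has_converged.append(not all(math.isfinite(v) for v in values))
    # No state is tracked, so the (empty) auxiliary state passes through unchanged.
    return has_converged, auxiliary_state


flags, state = one_step_sketch(
    step=3,
    loss=[0.5, float("nan"), 1.2],
    grads=[[0.1, 0.2, float("inf")]],
    parameters=[[1.0, 2.0, 3.0]],
    auxiliary_state=(),
)
print(flags)  # [False, True, True]
```

Each batch member is checked independently, so has_converged takes the batch shape. This mirrors the broadcast rule described for one_step. The returned auxiliary state keeps the same (empty) structure it was given, as the Returns entry requires.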