treeflow.debug.nonfinite_convergence_criterion module

class treeflow.debug.nonfinite_convergence_criterion.NonfiniteConvergenceCriterion(name='NonfiniteConvergenceCriterion')

Bases: ConvergenceCriterion

bootstrap(loss, grads, parameters)

Returns a structure of Tensors for the rule’s state at step 0.

The shapes of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimensions.

Parameters:
  • loss – float Tensor initial value of loss being optimized.

  • grads – list of float Tensor gradients of loss with respect to parameters.

  • parameters – list of float Tensor initial values of parameters being optimized.

Returns:

  • initial_auxiliary_state – (Structure of) Tensor(s) representing the initial auxiliary state carried forward by this criterion.
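The sketch below is a rough illustration of the step-0 contract. It uses plain Python floats and lists in place of Tensors. The class ToyNonfiniteCriterion and its choice of auxiliary state (a snapshot of the initial parameters) are hypothetical. The actual treeflow criterion may carry a different state, or none at all.

```python
from typing import List, Tuple


class ToyNonfiniteCriterion:
    """Hypothetical pure-Python stand-in; not the treeflow implementation."""

    def bootstrap(self, loss: float, grads: List[float],
                  parameters: List[float]) -> Tuple[List[float]]:
        # Step 0: the auxiliary state is assumed to be a copy of the initial
        # parameters. A later step could use it to report the last parameter
        # values at which everything was still finite.
        # A one-element tuple stands in for a "structure of Tensors".
        return (list(parameters),)


criterion = ToyNonfiniteCriterion()
initial_state = criterion.bootstrap(loss=3.5, grads=[0.2, -1.0], parameters=[1.0, 2.0])
print(initial_state)  # ([1.0, 2.0],)
```

Whatever structure bootstrap returns here is the structure every later update of the state must keep.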

property min_num_steps
property name
one_step(step, loss, grads, parameters, auxiliary_state)

Updates tracked quantities for a new step, and determines if converged.

The shapes of the Tensors specifying loss, grads, and parameters may optionally be prefixed by one or more batch dimensions. In this case, the returned value has_converged will have a shape equal to the broadcast batch shape of whichever of those quantities this convergence criterion uses, together with the quantities defining the convergence criterion (min_num_steps, etc.).
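The sketch below illustrates the batch-shape rule for the simplest case, where every input shares the same batch shape [B] and broadcasting is therefore trivial. The function batched_has_converged is hypothetical. It treats "converged" as "some loss or gradient entry became nonfinite", which is an assumption suggested by the class name rather than something this page states.

```python
import math
from typing import List


def batched_has_converged(loss: List[float], grads: List[List[float]]) -> List[bool]:
    """Hypothetical per-batch-member nonfinite check.

    loss has batch shape [B]; grads holds one length-B list per parameter.
    The result also has batch shape [B], one flag per batch member.
    """
    result = []
    for b in range(len(loss)):
        finite = math.isfinite(loss[b]) and all(math.isfinite(g[b]) for g in grads)
        # A batch member is flagged (i.e. should stop) once anything is nonfinite.
        result.append(not finite)
    return result


loss = [0.5, float("nan"), 2.0]
grads = [[0.1, 0.2, float("inf")], [1.0, -1.0, 0.0]]
print(batched_has_converged(loss, grads))  # [False, True, True]
```

Each batch member is judged independently, so one diverging chain does not force the others to stop.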

Parameters:
  • step – integer Tensor index of the current step, where step >= 1 (on step 0, bootstrap should be called instead).

  • loss – float Tensor value of loss at the current step.

  • grads – list of float Tensor gradients of loss with respect to parameters.

  • parameters – list of float Tensor current values of parameters being optimized.

  • auxiliary_state – the (structure of) Tensor(s) containing state carried forward from the previous step.

Returns:

  • has_converged – boolean Tensor indicating whether the optimization has converged.

  • updated_auxiliary_state – (Structure of) Tensor(s) representing updated quantities tracked by the convergence criterion. This should match the structure of the value returned by bootstrap.
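The hypothetical driver loop below shows how bootstrap and one_step are assumed to fit together. bootstrap supplies the state at step 0, one_step is called for step >= 1, and each updated auxiliary state keeps the structure that bootstrap returned. Everything here is an illustrative assumption in pure Python, not treeflow's or TensorFlow Probability's code. That includes the class, the run helper, the fixed zero gradients, and the stop-on-nonfinite rule.

```python
import math
from typing import List, Tuple

State = Tuple[float]


class ToyNonfiniteCriterion:
    """Hypothetical stand-in: stops as soon as the loss or any gradient is nonfinite."""

    def bootstrap(self, loss: float, grads: List[float], parameters: List[float]) -> State:
        # Assumed auxiliary state: the most recent finite loss.
        return (loss,)

    def one_step(self, step: int, loss: float, grads: List[float],
                 parameters: List[float], auxiliary_state: State) -> Tuple[bool, State]:
        if step < 1:
            raise ValueError("step 0 is handled by bootstrap")
        finite = math.isfinite(loss) and all(math.isfinite(g) for g in grads)
        # Keep the previous finite loss if this step went nonfinite, so the
        # updated state has the same structure as bootstrap's return value.
        updated_state = (loss,) if finite else auxiliary_state
        return (not finite), updated_state


def run(losses: List[float]) -> Tuple[int, State]:
    """Drive the criterion over a precomputed loss trace (gradients held at zero)."""
    criterion = ToyNonfiniteCriterion()
    state = criterion.bootstrap(losses[0], [0.0], [0.0])
    for step in range(1, len(losses)):
        has_converged, state = criterion.one_step(step, losses[step], [0.0], [0.0], state)
        if has_converged:
            return step, state
    return len(losses) - 1, state


print(run([4.0, 3.0, 2.5, float("nan"), 1.0]))  # (3, (2.5,))
```

The loop stops at the first nonfinite loss (step 3), and the carried state still holds the last finite value, 2.5, which is the kind of information a debugging criterion might want to preserve.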