treeflow.model.approximation.mean_field module
- treeflow.model.approximation.mean_field.get_base_distribution(flat_event_size, dtype=tf.float64)
- treeflow.model.approximation.mean_field.get_mean_field_operator_classes(flat_event_size)
- treeflow.model.approximation.mean_field.get_trainable_shift_bijector(flat_event_size, init_loc_unconstrained, dtype=tf.float64)
- treeflow.model.approximation.mean_field.make_trainable(dist_class: Type[Distribution], initial_parameters: Dict[str, object] | None = None, batch_and_event_shape: Tensor = None, parameter_dtype: DType = tf.float64, seed=None, var_name_prefix='', **init_kwargs) -> Distribution
- treeflow.model.approximation.mean_field.get_mean_field_approximation(model: JointDistribution, init_loc=None, dtype=tf.float64, joint_bijector_func: Callable[[JointDistribution], Composition] = get_default_event_space_bijector, event_shape_fn: Callable[[JointDistribution], object] = event_shape_fn, seed=None, var_name_prefix='') -> Tuple[Distribution, Dict[str, Variable]]
- treeflow.model.approximation.mean_field.get_fixed_topology_mean_field_approximation(model: JointDistribution, topology_pins: Dict[str, TensorflowTreeTopology], init_loc=None, dtype=tf.float64, var_name_prefix='') -> Tuple[Distribution, Dict[str, Tensor]]
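The function names above suggest a mean-field family assembled from parts: a standard-normal base over a flat vector of `flat_event_size` components, a diagonal scale, a trainable shift initialised from `init_loc_unconstrained`, and an event-space bijector (by default from `get_default_event_space_bijector`) mapping the unconstrained vector onto the model's support. The NumPy sketch below illustrates that general structure only; it is not treeflow's implementation. The class `MeanFieldNormal`, the softplus scale parameterisation, and the use of `exp` as the event-space bijector (suitable only for all-positive parameters) are assumptions made for illustration, and no training loop is included.

```python
import numpy as np


def softplus(x):
    # Numerically stable log(1 + exp(x)); keeps scales positive.
    return np.logaddexp(0.0, x)


class MeanFieldNormal:
    """Hypothetical sketch: independent normals on a flat unconstrained vector,
    pushed through exp() so that every component is positive (assumed bijector)."""

    def __init__(self, init_loc_unconstrained, init_scale=0.1):
        # Trainable shift, initialised in unconstrained space.
        self.loc = np.asarray(init_loc_unconstrained, dtype=np.float64)
        # Unconstrained scale parameter; inverse softplus of init_scale.
        self.raw_scale = np.full_like(self.loc, np.log(np.expm1(init_scale)))

    @property
    def scale(self):
        return softplus(self.raw_scale)

    def sample(self, n, rng):
        # Base distribution: standard normal with flat_event_size components.
        z = rng.standard_normal((n, self.loc.size))
        # Mean-field affine map: diagonal scale, then shift.
        x = self.loc + self.scale * z
        # Assumed event-space bijector: exp maps R to (0, inf).
        return np.exp(x)

    def log_prob(self, y):
        y = np.asarray(y, dtype=np.float64)
        x = np.log(y)  # inverse of the exp bijector
        z = (x - self.loc) / self.scale
        # Mean-field factorisation: sum of independent normal log densities.
        log_q_x = np.sum(
            -0.5 * z**2 - np.log(self.scale) - 0.5 * np.log(2.0 * np.pi), axis=-1
        )
        # Change of variables for y = exp(x): subtract log|dy/dx| = sum(x).
        return log_q_x - np.sum(x, axis=-1)
```

A usage example: `q = MeanFieldNormal([0.0, 1.0, -1.0])` followed by `q.sample(5, np.random.default_rng(0))` yields a `(5, 3)` array of positive draws, and `q.log_prob` on those draws returns one log density per row. Keeping the variational parameters in unconstrained space (a shift and a softplus-transformed scale) lets gradient-based optimisers update them freely, which is the usual motivation for pairing a mean-field normal with an event-space bijector.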