User API

The following class is part of the user API:

class fishereyes.FisherEyes(model: Any, optimizer: Any, opt_state: Any, loss_fn: Any, epochs: int, batch_size: int, config: Dict[str, Any] | None = None)

FisherEyes: A class for learning diffeomorphic transformations that normalize heteroskedastic uncertainty.
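To a first-order approximation, a smooth map f with Jacobian J(y) carries a local covariance Σ at y to roughly J(y) Σ J(y)^T. The NumPy sketch below shows why this matters for heteroskedastic data. It is not the FisherEyes algorithm, and reading "normalize" as "map every covariance to the identity" is our assumption. In the sketch, a single linear whitening map turns one covariance into the identity but leaves a second, larger covariance at 4I.

```python
import numpy as np

# Illustration only (not fishereyes code): a map whose Jacobian is A
# transforms a local covariance sigma, to first order, into A @ sigma @ A.T.
rng = np.random.default_rng(0)
D = 3
M = rng.normal(size=(D, D))
sigma1 = M @ M.T + D * np.eye(D)  # symmetric positive-definite covariance at point 1
sigma2 = 4.0 * sigma1             # point 2: same shape, twice the standard deviation

# Whitening map for point 1: with sigma1 = L @ L.T (Cholesky), A = inv(L)
# gives A @ sigma1 @ A.T = I.
L = np.linalg.cholesky(sigma1)
A = np.linalg.inv(L)

pushed1 = A @ sigma1 @ A.T
pushed2 = A @ sigma2 @ A.T

print(np.allclose(pushed1, np.eye(D)))        # True
print(np.allclose(pushed2, 4.0 * np.eye(D)))  # True: the same linear map leaves point 2 at 4 * I
```

A diffeomorphism, by contrast, has a Jacobian that varies with position. In principle, it can therefore bring the covariances at different points close to the identity simultaneously.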

Parameters

model: Any

The model to be trained.

optimizer: Any

The optimizer to be used for training.

opt_state: Any

The initial state of the optimizer.

loss_fn: Any

The loss function to be used for training.

epochs: int

The number of epochs to train the model.

batch_size: int

The size of the batches to be used during training.

config: Dict[str, Any], optional

A dictionary containing the configuration for the model, optimizer, and loss function. If None, the default configuration is used.

Attributes

loss_history: List[float]

A list containing the loss values for each epoch during training.
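The hypothetical loop below shows one way the constructor's epochs and batch_size arguments could relate to the loss_history attribute. It is a sketch, not the library's training code. The function train_sketch is our own, and each of the following is an assumption: data are reshuffled every epoch, the final partial batch is kept, batch losses are averaged into one value per epoch, and no optimizer update is performed.

```python
import math

import numpy as np


def train_sketch(y0, sigma0, loss_fn, epochs, batch_size, seed=0):
    """Hypothetical loop relating epochs, batch_size and loss_history.

    Assumptions (not taken from fishereyes): shuffle every epoch, keep the
    last partial batch, record the mean batch loss per epoch, skip updates.
    """
    rng = np.random.default_rng(seed)
    n = y0.shape[0]
    n_batches = math.ceil(n / batch_size)
    loss_history = []
    for _ in range(epochs):
        perm = rng.permutation(n)  # reshuffle indices every epoch
        batch_losses = []
        for b in range(n_batches):
            idx = perm[b * batch_size:(b + 1) * batch_size]
            batch_losses.append(float(loss_fn(y0[idx], sigma0[idx])))
        loss_history.append(sum(batch_losses) / len(batch_losses))
    return loss_history


def dummy_loss(y, s):
    # Mean trace of the batch covariances; equals 2.0 for 2x2 identities.
    return np.trace(s, axis1=1, axis2=2).mean()


y0 = np.zeros((10, 2))
sigma0 = np.broadcast_to(np.eye(2), (10, 2, 2))
hist = train_sketch(y0, sigma0, dummy_loss, epochs=5, batch_size=4)
print(len(hist), hist[0])  # 5 2.0
```

Under these assumptions, loss_history ends up with exactly epochs entries, each summarizing ceil(N / batch_size) batches.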

fit(y0: Array, sigma0: Array, key: key | int | None = None) → None

Fit the transformation model to data.

Parameters:
- y0: Input data array of shape [N, D].
- sigma0: Covariance matrices of shape [N, D, D]; sigma0[i] is the D × D covariance of y0[i].
- key: Optional jax.random.key or integer seed for reproducibility.
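The sketch below builds toy inputs with the shapes that fit documents: y0 of shape [N, D] and sigma0 of shape [N, D, D], where sigma0[i] is the covariance attached to y0[i]. The helper make_heteroskedastic_data is hypothetical, and so is its noise model, in which the noise scale grows with distance from the origin. Both were chosen only so that every point gets a different, symmetric positive-definite covariance.

```python
import numpy as np


def make_heteroskedastic_data(n, d, seed=0):
    """Hypothetical toy inputs with fit()'s shapes: y0 [N, D], sigma0 [N, D, D].

    The noise scale grows with distance from the origin, so each point has
    its own covariance (heteroskedastic). The noise model is arbitrary.
    """
    rng = np.random.default_rng(seed)
    y0 = rng.normal(size=(n, d))
    # Per-point scale: 0.1 near the origin, larger farther out. Shape [N].
    scale = 0.1 + 0.5 * np.linalg.norm(y0, axis=1)
    # Random per-point factors so the covariances are not diagonal.
    B = rng.normal(size=(n, d, d))
    # sigma_i = s_i^2 * (B_i B_i^T / d + I), which is symmetric positive definite.
    sigma0 = scale[:, None, None] ** 2 * (B @ np.swapaxes(B, 1, 2) / d + np.eye(d))
    return y0, sigma0


y0, sigma0 = make_heteroskedastic_data(n=100, d=2)
print(y0.shape, sigma0.shape)  # (100, 2) (100, 2, 2)
```

Given a FisherEyes instance fe, these arrays would then be passed as fe.fit(y0, sigma0, key=0), using an integer seed for the documented key argument.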

classmethod from_config(data_dim: int, config_path: str | Path | None = None, key: key | int | None = None) → FisherEyes

Create a FisherEyes instance from a configuration file.

Parameters:
- data_dim: Dimensionality of the input/output data.
- config_path: Path to the configuration file. If None, the default configuration is used.
- key: Optional jax.random.key or integer seed for reproducibility.

Returns:
- An instance of the FisherEyes class.
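Because from_config takes data_dim separately from the arrays later passed to fit, it can be useful to check that the two agree before training. The helper check_fit_inputs below is hypothetical and not part of fishereyes. It checks the documented shapes against data_dim. Requiring each sigma0[i] to be symmetric and strictly positive definite is our assumption about what the library expects of a covariance matrix.

```python
import numpy as np


def check_fit_inputs(y0, sigma0, data_dim):
    """Hypothetical pre-flight check (not part of fishereyes).

    Verifies that y0 and sigma0 have the shapes fit() documents and that
    their dimensionality matches the data_dim given to from_config.
    """
    y0 = np.asarray(y0)
    sigma0 = np.asarray(sigma0)
    if y0.ndim != 2 or y0.shape[1] != data_dim:
        raise ValueError(f"y0 must have shape [N, {data_dim}], got {y0.shape}")
    n = y0.shape[0]
    if sigma0.shape != (n, data_dim, data_dim):
        raise ValueError(
            f"sigma0 must have shape [{n}, {data_dim}, {data_dim}], got {sigma0.shape}"
        )
    if not np.allclose(sigma0, np.swapaxes(sigma0, 1, 2)):
        raise ValueError("each sigma0[i] must be symmetric")
    # Cholesky factorization succeeds only for positive-definite matrices
    # (assumed requirement); NumPy factors the whole stack at once.
    try:
        np.linalg.cholesky(sigma0)
    except np.linalg.LinAlgError as err:
        raise ValueError("each sigma0[i] must be positive definite") from err
```

For example, calling check_fit_inputs(y0, sigma0, data_dim=2) before FisherEyes.from_config(data_dim=2) and fit(y0, sigma0) would surface shape mismatches with a readable error.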