User API
The user-facing API consists of the following class:
- class fishereyes.FisherEyes(model: Any, optimizer: Any, opt_state: Any, loss_fn: Any, epochs: int, batch_size: int, config: Dict[str, Any] | None = None)
FisherEyes: A class for learning diffeomorphic transformations that normalize heteroskedastic uncertainty.
Parameters
- model: Any
The model to be trained.
- optimizer: Any
The optimizer to be used for training.
- opt_state: Any
The initial state of the optimizer.
- loss_fn: Any
The loss function to be used for training.
- epochs: int
The number of epochs to train the model.
- batch_size: int
The size of the batches to be used during training.
- config: Dict[str, Any], optional
A dictionary containing the configuration for the model, optimizer, and loss function. If None, the default configuration is used.
Attributes
- loss_history: List[float]
A list containing the loss values for each epoch during training.
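Since loss_history is described as a plain list with one loss value per epoch, it can be inspected with ordinary Python after training. The helper below is a minimal sketch, not part of fishereyes: the name summarize_loss_history is hypothetical, and the stand-in values are made up purely to show the call.

```python
from typing import List, Tuple


def summarize_loss_history(loss_history: List[float]) -> Tuple[float, float, int]:
    """Return (first loss, last loss, index of the epoch with the lowest loss).

    Hypothetical helper for illustration; it only assumes that
    loss_history is a list of floats, one per epoch.
    """
    if not loss_history:
        raise ValueError("loss_history is empty")
    # Index of the smallest loss value (first occurrence on ties).
    best_epoch = min(range(len(loss_history)), key=loss_history.__getitem__)
    return loss_history[0], loss_history[-1], best_epoch


# Stand-in values; a fitted instance would supply its own loss_history.
example_history = [2.5, 1.75, 1.5, 1.625]
first, last, best_epoch = summarize_loss_history(example_history)
print(f"first={first}, last={last}, best epoch={best_epoch}")
```

A last loss that is higher than the best one, as in the stand-in values, may suggest that training ran past its best epoch.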
- fit(y0: Array, sigma0: Array, key: key | int | None = None) -> None
Fit the transformation model to data.
Parameters:
- y0: Input data array of shape [N, D].
- sigma0: Covariance matrices of shape [N, D, D].
- key: Optional jax.random.key or integer seed for reproducibility.
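The shape requirements for fit can be illustrated by building synthetic inputs. The sketch below uses NumPy, and the helper name make_inputs is hypothetical. It assumes only what is stated above: y0 has shape [N, D] and sigma0 holds one D x D covariance matrix per sample. Whether fit accepts NumPy arrays directly or needs them converted to JAX arrays first is not stated here, so treat that as something to check.

```python
import numpy as np


def make_inputs(n: int, d: int, seed: int = 0):
    """Build synthetic (y0, sigma0) with shapes [N, D] and [N, D, D].

    Hypothetical helper for illustration only.
    """
    rng = np.random.default_rng(seed)
    y0 = rng.normal(size=(n, d))  # one D-dimensional point per sample
    a = rng.normal(size=(n, d, d))
    # A @ A^T is symmetric positive semi-definite; adding a small multiple
    # of the identity makes every per-sample matrix positive definite.
    sigma0 = a @ np.swapaxes(a, -1, -2) + 1e-3 * np.eye(d)
    return y0, sigma0


y0, sigma0 = make_inputs(n=128, d=2)
print(y0.shape, sigma0.shape)  # (128, 2) (128, 2, 2)
```

Each sample gets its own covariance matrix here, which matches the heteroskedastic setting the class is described as targeting.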
- classmethod from_config(data_dim: int, config_path: str | Path | None = None, key: key | int | None = None) -> FisherEyes
Create a FisherEyes instance from a configuration file.
Parameters:
- data_dim: Dimensionality of the input/output data.
- config_path: Path to the configuration file. If None, the default configuration is used.
- key: Optional jax.random.key or integer seed for reproducibility.
Returns:
- An instance of the FisherEyes class.
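Putting the documented pieces together, a typical workflow might look like the sketch below. It uses only the signatures listed on this page (from_config, fit, loss_history). The wrapper name fit_with_default_config is hypothetical, and the code assumes that y0 and sigma0 are arrays in a format the library accepts and that the default configuration suits the data.

```python
def fit_with_default_config(y0, sigma0, seed: int = 0):
    """Create a FisherEyes instance from the default config and fit it.

    Hypothetical convenience wrapper for illustration only.
    """
    # Imported lazily so that defining this function does not require
    # the fishereyes package to be installed.
    from fishereyes import FisherEyes

    data_dim = y0.shape[-1]  # D, taken from the [N, D] input array
    # config_path is left as None, so the default configuration is used.
    fe = FisherEyes.from_config(data_dim=data_dim, key=seed)
    fe.fit(y0, sigma0, key=seed)  # returns None; the instance is updated
    return fe
```

After the call, the returned instance's loss_history attribute holds the per-epoch loss values. Passing the same integer seed to both from_config and fit is a choice made here for reproducibility, not a documented requirement.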