The package supports a range of existing mutual information estimators. For the full list, see below.


The design of the estimators was inspired by the scikit-learn API[1]. All estimators are classes. Once a class is initialized, one can use the estimate method, which maps arrays containing data points (of shape (n_points, n_dim)) to mutual information estimates:

```python
import bmi

# Generate a sample with 1000 data points
task = bmi.benchmark.BENCHMARK_TASKS['1v1-normal-0.75']
X, Y = task.sample(1000, seed=42)
print(f"X shape: {X.shape}")  # Shape (1000, 1)
print(f"Y shape: {Y.shape}")  # Shape (1000, 1)

# Once an estimator is instantiated, it can be used to estimate mutual information
# by using the `estimate` method.
cca = bmi.estimators.CCAMutualInformationEstimator()
print(f"Estimate by CCA: {cca.estimate(X, Y):.2f}")

ksg = bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(5,))
print(f"Estimate by KSG: {ksg.estimate(X, Y):.2f}")
```

Additionally, the estimators can be queried for their hyperparameters:

```python
print(cca.parameters())  # CCA does not have tunable hyperparameters
# _EmptyParams()

print(ksg.parameters())  # KSG has tunable hyperparameters
# KSGEnsembleParameters(neighborhoods=[5], standardize=True, metric_x='euclidean', metric_y='euclidean')
```

The returned parameter objects are Pydantic models.
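For intuition about what KSGEnsembleFirstEstimator computes, here is a minimal NumPy sketch of the first Kraskov-Stögbauer-Grassberger (KSG) estimator for a single neighborhood size k; we assume this is what each entry of neighborhoods controls. It is not the package's implementation: it uses the maximum norm in every space (the parameters above show the package defaulting to Euclidean metrics and standardization), computes brute-force O(n²) distances, does not average over several values of k, and assumes there are no duplicate points. The names ksg_mutual_information and digamma_int are hypothetical.

```python
import numpy as np

EULER_GAMMA = 0.5772156649015329


def digamma_int(n):
    """Digamma at positive integers: psi(n) = -gamma + sum_{m=1}^{n-1} 1/m."""
    n = np.asarray(n, dtype=int)
    max_n = int(n.max())
    # harmonic[j] = 1 + 1/2 + ... + 1/j, with harmonic[0] = 0
    harmonic = np.concatenate([[0.0], np.cumsum(1.0 / np.arange(1, max_n))])
    return -EULER_GAMMA + harmonic[n - 1]


def ksg_mutual_information(x, y, k=5):
    """KSG estimator (variant 1) of I(X; Y) in nats. Hypothetical helper, not part of bmi."""
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    y = np.asarray(y, dtype=float).reshape(len(y), -1)
    n_points = x.shape[0]
    # Pairwise max-norm distances in the X and Y spaces.
    dist_x = np.abs(x[:, None, :] - x[None, :, :]).max(axis=-1)
    dist_y = np.abs(y[:, None, :] - y[None, :, :]).max(axis=-1)
    # Max-norm distance in the joint space; a point is not its own neighbor.
    dist_xy = np.maximum(dist_x, dist_y)
    np.fill_diagonal(dist_xy, np.inf)
    # Distance from each point to its k-th nearest neighbor in the joint space.
    eps = np.sort(dist_xy, axis=1)[:, k - 1]
    # Number of other points strictly closer than eps in each marginal space
    # (the diagonal has distance 0, so subtract 1 to exclude the point itself).
    n_x = (dist_x < eps[:, None]).sum(axis=1) - 1
    n_y = (dist_y < eps[:, None]).sum(axis=1) - 1
    estimate = (
        digamma_int(k)
        + digamma_int(n_points)
        - np.mean(digamma_int(n_x + 1) + digamma_int(n_y + 1))
    )
    return float(estimate)


# Bivariate normal data with correlation 0.75 (exact MI: -0.5 * log(1 - 0.75**2), about 0.41 nats).
rng = np.random.default_rng(0)
sample = rng.multivariate_normal([0.0, 0.0], [[1.0, 0.75], [0.75, 1.0]], size=1000)
print(ksg_mutual_information(sample[:, :1], sample[:, 1:], k=5))
```

Brute-force distances keep the sketch short; a practical implementation would use a k-d tree or similar structure for the neighbor searches.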

List of estimators

Neural estimators

We support several standard neural estimators implemented in JAX, based on the PyTorch implementations[2]:

  • The Donsker-Varadhan estimator[3] is implemented in DonskerVaradhanEstimator.
  • MINE[3], a Donsker-Varadhan estimator that corrects the bias of the gradient during the fitting phase, is implemented in MINEEstimator.
  • InfoNCE[4], also known as Contrastive Predictive Coding, is implemented in InfoNCEEstimator.
  • The NWJ estimator[5] is implemented in NWJEstimator.
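The Donsker-Varadhan, NWJ, and InfoNCE objectives differ mainly in how they combine critic scores on joint samples and on shuffled pairs. The minimal NumPy sketch below assumes a critic f has already been evaluated on all pairs in a batch, scores[i, j] = f(x_i, y_j), so that diagonal entries stand in for samples from \(P(X, Y)\) and off-diagonal entries for samples from \(P(X)P(Y)\). Critic training, MINE's gradient correction, and the package's JAX code are omitted, and all function names are hypothetical.

```python
import numpy as np


def _logmeanexp(values, axis=None):
    """Numerically stable log(mean(exp(values))) along the given axis."""
    shift = np.max(values, axis=axis, keepdims=True)
    out = shift + np.log(np.mean(np.exp(values - shift), axis=axis, keepdims=True))
    return out.item() if axis is None else np.squeeze(out, axis=axis)


def _split_scores(scores):
    """Mean critic score on joint pairs (diagonal) and all off-diagonal scores."""
    n = scores.shape[0]
    joint_mean = np.mean(np.diag(scores))
    off_diagonal = scores[~np.eye(n, dtype=bool)]
    return joint_mean, off_diagonal


def dv_bound(scores):
    """Donsker-Varadhan: E_P[f] - log E_Q[exp(f)]."""
    joint_mean, off_diagonal = _split_scores(scores)
    return joint_mean - _logmeanexp(off_diagonal)


def nwj_bound(scores):
    """NWJ: E_P[f] - E_Q[exp(f - 1)]."""
    joint_mean, off_diagonal = _split_scores(scores)
    return joint_mean - np.exp(_logmeanexp(off_diagonal) - 1.0)


def infonce_bound(scores):
    """InfoNCE: log N + mean_i [scores[i, i] - logsumexp_j scores[i, j]]; at most log N."""
    # log N - logsumexp_j(.) equals -logmeanexp_j(.), so the log N terms cancel.
    return np.mean(np.diag(scores) - _logmeanexp(scores, axis=1))
```

For example, a critic that scores each positive pair far above all others drives infonce_bound toward log N, which is why InfoNCE estimates cannot exceed the logarithm of the batch size.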

Model-based estimators

  • Canonical correlation analysis[6][7] is suitable when \(P(X, Y)\) is multivariate normal and does not require hyperparameter tuning. It is implemented in CCAMutualInformationEstimator.
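For jointly Gaussian \(X\) and \(Y\), mutual information depends only on the canonical correlations \(\rho_i\) between the two variables: \(I(X; Y) = -\tfrac{1}{2}\sum_i \log(1 - \rho_i^2)\). A minimal NumPy sketch of this plug-in estimate from sample covariances follows; gaussian_mi_via_cca is a hypothetical name, and the package's estimator may differ in details such as numerical regularization.

```python
import numpy as np


def _inv_sqrt(cov):
    """Inverse matrix square root of a symmetric positive-definite matrix."""
    vals, vecs = np.linalg.eigh(cov)
    return vecs @ np.diag(1.0 / np.sqrt(vals)) @ vecs.T


def gaussian_mi_via_cca(x, y):
    """Mutual information (nats) of jointly Gaussian X, Y from canonical correlations."""
    x = np.asarray(x, dtype=float).reshape(len(x), -1)
    y = np.asarray(y, dtype=float).reshape(len(y), -1)
    dim_x = x.shape[1]
    cov = np.cov(np.hstack([x, y]), rowvar=False)
    cov_xx = cov[:dim_x, :dim_x]
    cov_yy = cov[dim_x:, dim_x:]
    cov_xy = cov[:dim_x, dim_x:]
    # Singular values of the whitened cross-covariance are the canonical correlations.
    whitened = _inv_sqrt(cov_xx) @ cov_xy @ _inv_sqrt(cov_yy)
    rho = np.clip(np.linalg.svd(whitened, compute_uv=False), 0.0, 1.0 - 1e-12)
    return float(-0.5 * np.sum(np.log1p(-rho**2)))
```

Because only second moments enter, this estimate needs no hyperparameters, but it is biased whenever the data are not Gaussian.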

Histogram-based estimators

  • We implement a histogram-based estimator[8] in HistogramEstimator. Note, however, that we do not support adaptive binning schemes.
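The idea behind a histogram-based estimator fits in a few lines: discretize both variables with a fixed grid of equal-width bins, estimate joint and marginal probabilities from bin counts, and plug them into \(\hat I = \sum_{a,b} \hat p(a,b)\log\frac{\hat p(a,b)}{\hat p(a)\,\hat p(b)}\). The sketch below, with the hypothetical helper histogram_mi, handles only one-dimensional \(X\) and \(Y\), is not the package's HistogramEstimator, and uses an arbitrary default number of bins.

```python
import numpy as np


def histogram_mi(x, y, bins=10):
    """Plug-in mutual information (in nats) from a fixed equal-width 2-D histogram."""
    counts, _, _ = np.histogram2d(np.ravel(x), np.ravel(y), bins=bins)
    p_xy = counts / counts.sum()
    p_x = p_xy.sum(axis=1, keepdims=True)  # marginal over X bins (rows)
    p_y = p_xy.sum(axis=0, keepdims=True)  # marginal over Y bins (columns)
    # Empty cells contribute 0 to the sum (0 * log 0 = 0), so skip them.
    nonzero = p_xy > 0
    ratio = p_xy[nonzero] / (p_x * p_y)[nonzero]
    return float(np.sum(p_xy[nonzero] * np.log(ratio)))
```

With fixed bins the result depends on the bin count: too few bins discard information, while too many leave cells sparsely populated and bias the plug-in estimate upward.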

Kernel density estimators

Neighborhood-based estimators


Do these estimators work for discrete variables?

When both variables \(X\) and \(Y\) are discrete, we recommend the dit package. When one variable is discrete and the other is continuous, one can approximate mutual information by adding small noise to the discrete variable.
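A minimal NumPy sketch of the noise trick: the hypothetical helper jitter_discrete adds small Gaussian noise to a discrete variable so that ties disappear and the result can be passed, together with the continuous variable, to a continuous estimator such as the estimate method of KSGEnsembleFirstEstimator. The noise scale of 1e-3 is an illustrative assumption; it should be small compared with the spacing between the discrete values, and it is worth checking that the estimate is not sensitive to it.

```python
import numpy as np


def jitter_discrete(values, scale=1e-3, seed=0):
    """Return a float copy of a discrete array with small Gaussian noise added."""
    rng = np.random.default_rng(seed)
    values = np.asarray(values, dtype=float)
    return values + scale * rng.normal(size=values.shape)


# Example: a discrete label in {0, 1, 2} and a continuous variable that depends on it.
rng = np.random.default_rng(42)
labels = rng.integers(0, 3, size=(1000, 1))
continuous = labels + rng.normal(size=(1000, 1))
labels_jittered = jitter_discrete(labels)  # shape (1000, 1), no tied values
```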


Where is the API reference for the estimators?

The API is here.

How can I add a new estimator?

Thank you for considering contributing to this project! Please consult the contributing guidelines and reach out to us on GitHub, so that we can discuss the best way of adding the estimator to the package.

Generally, the following steps are required:

  1. Implement the interface IMutualInformationPointEstimator in a new file inside the src/bmi/estimators directory. Add unit tests to the tests/estimators directory.
  2. Export the new estimator to the public API by adding an entry in src/bmi/estimators/
  3. Export the docstring of the new estimator to docs/api/
  4. Add the estimator to the list of estimators and to the README.
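As a starting point for step 1, the sketch below shows the shape of an estimator class that follows the pattern used on this page (an estimate method taking arrays of shape (n_points, n_dim), plus a parameters method), together with a plain unit test. It is hypothetical: it does not subclass IMutualInformationPointEstimator, whose exact abstract methods should be checked in the source, and its parameters method returns a dict rather than the Pydantic object the package's estimators return. The toy estimator applies the Gaussian formula to the sample correlation of one-dimensional \(X\) and \(Y\).

```python
import numpy as np


class ToyGaussianEstimator:
    """Hypothetical estimator: Gaussian mutual information from the sample correlation."""

    def __init__(self, clip: float = 1e-12) -> None:
        # Keeps 1 - rho^2 away from zero for (nearly) identical inputs.
        self._clip = clip

    def parameters(self) -> dict:
        # bmi estimators return Pydantic objects; a dict keeps this sketch dependency-free.
        return {"clip": self._clip}

    def estimate(self, x, y) -> float:
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.ndim != 2 or y.ndim != 2 or x.shape[1] != 1 or y.shape[1] != 1:
            raise ValueError("Expected arrays of shape (n_points, 1).")
        rho = np.corrcoef(x[:, 0], y[:, 0])[0, 1]
        return float(-0.5 * np.log(max(1.0 - rho**2, self._clip)))


# A unit test in the style one might add under tests/estimators (hypothetical).
def test_toy_estimator_recovers_gaussian_mi():
    rng = np.random.default_rng(0)
    cov = [[1.0, 0.75], [0.75, 1.0]]
    sample = rng.multivariate_normal([0.0, 0.0], cov, size=20_000)
    estimate = ToyGaussianEstimator().estimate(sample[:, :1], sample[:, 1:])
    assert abs(estimate - (-0.5 * np.log(1.0 - 0.75**2))) < 0.03
```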

