Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions lightgbmlss/distributions/LogitNormal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
from torch.distributions import Normal, SigmoidTransform, TransformedDistribution

from .distribution_utils import DistributionClass
from ..utils import *


class _LogitNormalDistribution(TransformedDistribution):
    """
    Logit-Normal distribution: a Normal distribution pushed through a sigmoid
    transform, giving support on (0, 1).

    Defined at module level (rather than created anew inside
    ``LogitNormal.__init__``) so the class object is built exactly once and
    instances of it — and any model object referencing it — remain picklable.
    Local (function-scope) classes cannot be resolved by ``pickle``.
    """

    # Re-expose the base Normal constraints so downstream code that inspects
    # ``arg_constraints`` (e.g. torch's argument validation) keeps working.
    arg_constraints = Normal.arg_constraints

    def __init__(self, loc, scale):
        # The (0, 1) support is derived automatically by
        # TransformedDistribution from the SigmoidTransform.
        super().__init__(Normal(loc, scale), [SigmoidTransform()])


class LogitNormal(DistributionClass):
    """
    Logit-Normal distribution class.

    Distributional Parameters
    -------------------------
    loc: torch.Tensor
        Mean of the normal distribution before applying the logit transformation.
    scale: torch.Tensor
        Standard deviation of the normal distribution before applying the logit transformation.

    Source
    -------------------------
    https://pytorch.org/docs/stable/distributions.html#normal

    Parameters
    -------------------------
    stabilization: str
        Stabilization method for the Gradient and Hessian. Options are "None", "MAD", "L2".
    response_fn: str
        Response function for transforming the distributional parameters to the correct support. Options are
        "identity" (no transformation) or "softplus" (softplus to ensure positivity).
    loss_fn: str
        Loss function. Options are "nll" (negative log-likelihood) or "crps" (continuous ranked probability score).
        Note that if "crps" is used, the Hessian is set to 1, as the current CRPS version is not twice differentiable.
    initialize: bool
        Whether to initialize the distributional parameters with unconditional start values. Initialization can help
        to improve speed of convergence in some cases. However, it may also lead to early stopping or suboptimal
        solutions if the unconditional start values are far from the optimal values.

    Raises
    ------
    ValueError
        If ``stabilization``, ``response_fn``, ``loss_fn`` or ``initialize``
        is not one of the documented options.
    """

    def __init__(self,
                 stabilization: str = "None",
                 response_fn: str = "identity",
                 loss_fn: str = "nll",
                 initialize: bool = False,
                 ):

        # Input Checks
        if stabilization not in ["None", "MAD", "L2"]:
            raise ValueError("Invalid stabilization method. Please choose from 'None', 'MAD' or 'L2'.")
        if loss_fn not in ["nll", "crps"]:
            raise ValueError("Invalid loss function. Please choose from 'nll' or 'crps'.")
        if not isinstance(initialize, bool):
            raise ValueError("Invalid initialize. Please choose from True or False.")

        # Specify Response Functions: map the string option to the actual
        # callable used to transform the raw predictions.
        response_functions = {"identity": identity_fn, "softplus": softplus_fn}
        if response_fn in response_functions:
            response_fn = response_functions[response_fn]
        else:
            raise ValueError("Invalid response function. Please choose from 'identity' or 'softplus'.")

        # Define Parameter Mapping: "loc" is unconstrained; "scale" is passed
        # through the chosen response function (softplus enforces positivity).
        param_dict = {"loc": identity_fn, "scale": response_fn}
        # NOTE(review): this mutates a *global* torch setting, disabling
        # argument validation for every distribution in the process. Kept
        # as-is for behavioral compatibility with the sibling distribution
        # classes in this package.
        torch.distributions.Distribution.set_default_validate_args(False)

        # Specify Distribution Class
        super().__init__(distribution=_LogitNormalDistribution,
                         univariate=True,
                         discrete=False,
                         n_dist_param=len(param_dict),
                         stabilization=stabilization,
                         param_dict=param_dict,
                         distribution_arg_names=list(param_dict.keys()),
                         loss_fn=loss_fn,
                         initialize=initialize,
                         )
3 changes: 2 additions & 1 deletion lightgbmlss/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@
from . import ZALN
from . import SplineFlow
from . import Mixture
from . import Logistic
from . import Logistic
from . import LogitNormal