diff --git a/Project.toml b/Project.toml index 37f0de7..e468fbc 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161" StatisticalMeasuresBase = "c062fc1d-0d66-479b-b6ac-8b44719de4cc" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" @@ -34,7 +35,7 @@ LossFunctions = "0.10, 0.11, 1" MacroTools = "0.5" OrderedCollections = "1" PrecompileTools = "1.1" -ScientificTypes = "3" +REPL = "1" ScientificTypesBase = "3" StatisticalMeasuresBase = "0.1" Statistics = "1" diff --git a/docs/make.jl b/docs/make.jl index 5105156..f8e8823 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -28,6 +28,7 @@ makedocs( "The Measures" => "auto_generated_list_of_measures.md", "Confusion Matrices" => "confusion_matrices.md", "Receiver Operator Characteristics" => "roc.md", + "Precision-Recall Curves" => "precision_recall.md", "Tools" => "tools.md", "Reference" => "reference.md", ], @@ -41,4 +42,3 @@ deploydocs( devbranch="dev", push_preview=false, ) - diff --git a/docs/src/assets/precision_recall_curve.png b/docs/src/assets/precision_recall_curve.png new file mode 100644 index 0000000..89589b3 Binary files /dev/null and b/docs/src/assets/precision_recall_curve.png differ diff --git a/docs/src/precision_recall.md b/docs/src/precision_recall.md new file mode 100644 index 0000000..40e0718 --- /dev/null +++ b/docs/src/precision_recall.md @@ -0,0 +1,40 @@ +# Precision-Recall Curves + +In binary classification problems, precision-recall curves (or PR curves) are a popular +alternative to [Receiver Operator Characteristics](@ref) when the target values are highly +imbalanced. 
+ +## Example + +```@example 70 +using StatisticalMeasures +using CategoricalArrays +using CategoricalDistributions + +# ground truth: +y = categorical(["X", "O", "X", "X", "O", "X", "X", "O", "O", "X"], ordered=true) + +# probabilistic predictions: +X_probs = [0.3, 0.2, 0.4, 0.9, 0.1, 0.4, 0.5, 0.2, 0.8, 0.7] +ŷ = UnivariateFinite(["O", "X"], X_probs, augment=true, pool=y) +ŷ[1] +``` + +```julia +using Plots +recalls, precisions, thresholds = precision_recall_curve(ŷ, y) +plt = plot(recalls, precisions, legend=false) +plot!(plt, xlab="recall", ylab="precision") + +# proportion of observations that are positive: +p = precisions[end] # threshold=0 +plot!([0, 1], [p, p], linewidth=2, linestyle=:dash, color=:black) +``` + +![](assets/precision_recall_curve.png) + +## Reference + +```@docs +precision_recall_curve +``` diff --git a/docs/src/reference.md b/docs/src/reference.md index c221637..796b668 100644 --- a/docs/src/reference.md +++ b/docs/src/reference.md @@ -1,5 +1,9 @@ # Reference +```@index +Pages = ["reference.md",] +``` + ```@docs StatisticalMeasuresBase.unwrap StatisticalMeasuresBase.is_measure diff --git a/docs/src/roc.md b/docs/src/roc.md index ae78333..785cfd2 100644 --- a/docs/src/roc.md +++ b/docs/src/roc.md @@ -18,8 +18,8 @@ ŷ[1] ```julia using Plots -curve = roc_curve(ŷ, y) -plt = plot(curve, legend=false) +false_positive_rates, true_positive_rates, thresholds = roc_curve(ŷ, y) +plt = plot(false_positive_rates, true_positive_rates; legend=false) plot!(plt, xlab="false positive rate", ylab="true positive rate") plot!([0, 1], [0, 1], linewidth=2, linestyle=:dash, color=:black) ``` diff --git a/docs/src/roc_curve.png b/docs/src/roc_curve.png deleted file mode 100644 index 4c85ba9..0000000 Binary files a/docs/src/roc_curve.png and /dev/null differ diff --git a/src/StatisticalMeasures.jl b/src/StatisticalMeasures.jl index 58c84be..6e6b779 100644 --- a/src/StatisticalMeasures.jl +++ b/src/StatisticalMeasures.jl @@ -14,6 +14,7 @@ using LinearAlgebra using 
StatsBase import Distributions using PrecompileTools +using REPL # needed for `Base.Docs.doc` const SM = "StatisticalMeasures" const CatArrOrSub{T, N} = @@ -34,6 +35,7 @@ include("tools.jl") include("functions.jl") include("confusion_matrices.jl") include("roc.jl") +include("precision_recall.jl") include("docstrings.jl") include("registry.jl") include("continuous.jl") @@ -71,22 +73,7 @@ export measures, supports_missings_measure, fussy_measure -export Functions, ConfusionMatrices, NoAvg, MacroAvg, MicroAvg, roc_curve - -#tod look out for MLJBase.aggregate called on scalars, which is not supported here. -#todo in mljbase, single(measure, array1, array2) - -#todo need a show(::Measure) -#todo `is_measure_type` in MLJBase is not provided here - -#todo: following needs adding to section on continuous measures -# _scale(x, w::Arr, i) = x*w[i] -# _scale(x, ::Nothing, i::Any) = x - -#todo: _skipinvalid from MLJBase/src/data/data.jl is needed for balanced accuracy, barring -# a refactor of that measure to use `skipinvalid` as provided in this package. - -#todo: look for uses of aggregation of dictionaries in MLJBase, which is no longer -# supported, or add support. +export Functions, ConfusionMatrices, NoAvg, MacroAvg, MicroAvg +export roc_curve, precision_recall_curve end diff --git a/src/functions.jl b/src/functions.jl index 0854cac..feef7b3 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -27,15 +27,26 @@ log_cosh(x::T) where T<:Real = x + _softplus(-2x) - log(convert(T, 2)) log_cosh_difference(yhat, y) = log_cosh(yhat - y) -# # ROC CURVE +# # CONFUSION MATRIX AT THRESHOLDS + """ _idx_unique_sorted(v) *Private method.* -Return the index of unique elements in `Real` vector `v` under the assumption that the -vector `v` is sorted in decreasing order. +Return the index of the first appearance of each element within `v`, under the untested +assumption that `v` is sorted in decreasing order. 
+ +```julia-repl +julia> [5, 5, 4, 3, 3, 3, 2, 1] |> _idx_unique_sorted +5-element Vector{Int64}: + 1 + 3 + 4 + 7 + 8 +``` """ function _idx_unique_sorted(v) @@ -55,79 +66,218 @@ function _idx_unique_sorted(v) return idx end -const DOC_ROC(;middle="", footer="") = +const DOC_YHAT_Y = +""" + +Here `ŷ` is a vector of predicted numerical probabilities of the specified +`positive_class`, which is one of two possible values occurring in the provided vector +`y` of ground truth observations. + +The returned probability `thresholds` are the distinct values taken on by `ŷ`, listed in +descending order. In particular, `0` and `1` are only included if they are present in `ŷ`. + +""" + +DOC_THRESHOLDS(; counts="counts") = +""" + +If `thresholds` has length `k`, the interval [0, 1] is partitioned into `k+1` bins. +The $counts are constant within each bin: + +- `[0.0, thresholds[k])` +- `[thresholds[k], thresholds[k - 1])` +- ... +- `[thresholds[1], 1]` + +""" + +const DOC_CONFUSION_CHECK = "Assumes there are no more than two classes but does "* + "not check this. Does not check that "* + "`positive_class` is one of the observed classes. " + +const DOC_CONFUSION_AT_THRESHOLDS(;middle=DOC_YHAT_Y, footer=DOC_CONFUSION_CHECK) = +""" + +For a binary classification problem, return probability thresholds and corresponding +confusion matrix entries, suitable for generating ROC curves and precision-recall curves +(and variations on these). Primarily intended as a backend for implementations of those +two cases. + +$middle + +$(DOC_THRESHOLDS()) + +Consequently, `TN`, `FP`, `FN` and `TP`, will each have length `k + 1` in that case. + +The `j`th raw confusion matrix will be `reshape([TN[j], FP[j], FN[j], TP[j]], 2, 2)`, +according to conventions used elsewhere in StatisticalMeasures.jl, which explains the +chosen order for the return value. 
+ +$footer + +""" + +""" + Functions.confusion_counts_at_thresholds(ŷ, y, positive_class) -> + (TN, FP, FN, TP), thresholds + +$(DOC_CONFUSION_AT_THRESHOLDS()) + +""" +function confusion_counts_at_thresholds(scores, y, positive_class) + n = length(y) + + ranking = sortperm(scores, rev=true) + + scores_sort = scores[ranking] +# Sort samples by score in descending order +# This lets us easily count predictions by threshold: for any threshold t, +# all samples with score ≥ t come before those with score < t +ranking = sortperm(scores, rev=true) +sorted_scores = scores[ranking] +sorted_labels = (y[ranking] .== positive_class) + + # Find where unique thresholds begin + # Since scores are sorted descending, each unique score value marks a threshold + # Example: scores [0.5, 0.5, 0.2, 0.2, 0.1] → thresholds start at indices [1, 3, 5] + threshold_indices = _idx_unique_sorted(sorted_scores) + thresholds = sorted_scores[threshold_indices] + + # detailed computations with example: + # sorted_labels = [ 1 0 0 1 0 0 1] + # s = [0.5 0.5 0.2 0.2 0.1 0.1 0.1] thresh are 0.5 0.2, 0.1 // idx [1, 3, 5] + # ŷ = [ 0 0 0 0 0 0 0] (0.5 - 1.0] # no pos pred + # ŷ = [ 1 1 0 0 0 0 0] (0.2 - 0.5] # 2 pos pred + # ŷ = [ 1 1 1 1 0 0 0] (0.1 - 0.2] # 4 pos pred + # ŷ = [ 1 1 1 1 1 1 1] [0.0 - 0.1] # all pos pre + # Count total positives and negatives in the dataset + cum_positives = cumsum(sorted_labels) # running count of true positives # [1, 1, 1, 2, 2, 2, 3] + P = cum_positives[end] # total number of observed positives (3) + N = n - P # total number of observed negatives (4) + # For each threshold (except the highest), count predictions + # At a given threshold starting at index i, all samples 1..(i-1) are predicted positive + # Example: threshold at index 3 → samples 1-2 predicted positive (2 samples) + n_ŷ_pos = threshold_indices[2:end] .- 1 # [2, 4] implicit [0, 2, 4, 7] + + # Compute true positives and false positives + tp = cum_positives[n_ŷ_pos] # [1, 2] implicit [0, 1, 2, 3] + fp = 
n_ŷ_pos .- tp # [1, 2] implicit [0, 1, 2, 4] + + # add end points + # - First endpoint: threshold > max score → no positive predictions + # - Last endpoint: threshold ≤ min score → all samples predicted positive + tp = [0, tp..., P] # [0, 1, 2, 3] + fp = [0, fp..., N] # [0, 1, 2, 4] + + # Derive the remaining confusion matrix entries + fn = P .- tp # [3, 2, 1, 0] + tn = N .- fp # [4, 3, 2, 0] + + return (tn, fp, fn, tp), thresholds +end + + +# # ROC CURVE + +const DOC_ROC(;middle=DOC_YHAT_Y, footer=DOC_CONFUSION_CHECK) = """ + Return data for plotting the receiver operator characteristic (ROC curve) for a binary classification problem. $middle -If there are `k` unique probabilities, then there are correspondingly `k` thresholds -and `k+1` "bins" over which the false positive and true positive rates are constant.: - -- `[0.0 - thresholds[1]]` -- `[thresholds[1] - thresholds[2]]` -- ... -- `[thresholds[k] - 1]` +$(DOC_THRESHOLDS(counts="`true_positive_rate` and `false_positive_rate`")) -Consequently, `true_positive_rates` and `false_positive_rates` have length `k+1` if -`thresholds` has length `k`. +Accordingly, `true_positive_rates` and `false_positive_rates` have length `k+1` in that +case. -To plot the curve using your favorite plotting backend, do something like +To plot the curve using your favorite plotting library, do something like `plot(false_positive_rates, true_positive_rates)`. $footer """ """ - Functions.roc_curve(probs_of_positive, ground_truth_obs, positive_class) -> + Functions.roc_curve(ŷ, y, positive_class) -> false_positive_rates, true_positive_rates, thresholds $(DOC_ROC()) -Assumes there are no more than two classes but does not check this. Does not check that -`positive_class` is one of the observed classes. +For a method with checks, see [`StatisticalMeasures.roc_curve`](@ref). See also +[`Functions.confusion_counts_at_thresholds`](@ref). 
""" function roc_curve(scores, y, positive_class) - n = length(y) + (tn, fp, fn, tp), thresholds = + confusion_counts_at_thresholds(scores, y, positive_class) - ranking = sortperm(scores, rev=true) + N = tn[1] # num observed negatives + P = fn[1] # num observed positives - scores_sort = scores[ranking] - y_sort_bin = (y[ranking] .== positive_class) + tpr = tp ./ P + fpr = fp ./ N - idx_unique = _idx_unique_sorted(scores_sort) - thresholds = scores_sort[idx_unique] + return fpr, tpr, thresholds +end - # detailed computations with example: - # y = [ 1 0 0 1 0 0 1] - # s = [0.5 0.5 0.2 0.2 0.1 0.1 0.1] thresh are 0.5 0.2, 0.1 // idx [1, 3, 5] - # ŷ = [ 0 0 0 0 0 0 0] (0.5 - 1.0] # no pos pred - # ŷ = [ 1 1 0 0 0 0 0] (0.2 - 0.5] # 2 pos pred - # ŷ = [ 1 1 1 1 0 0 0] (0.1 - 0.2] # 4 pos pred - # ŷ = [ 1 1 1 1 1 1 1] [0.0 - 0.1] # all pos pre - idx_unique_2 = idx_unique[2:end] # [3, 5] - n_ŷ_pos = idx_unique_2 .- 1 # [2, 4] implicit [0, 2, 4, 7] +# # PRECISION RECALL CURVE - cs = cumsum(y_sort_bin) # [1, 1, 1, 2, 2, 2, 3] - n_tp = cs[n_ŷ_pos] # [1, 2] implicit [0, 1, 2, 3] - n_fp = n_ŷ_pos .- n_tp # [1, 2] implicit [0, 1, 2, 4] +tamed_divide(a, b) = b == 0 ? 0 : a/b - # add end points - P = sum(y_sort_bin) # total number of true positives - N = n - P # total number of true negatives +const DOC_ROC_CHECK = DOC_CONFUSION_CHECK* + "That failing to be the case, each returned recall will be `Inf` or `NaN`. " - n_tp = [0, n_tp..., P] # [0, 1, 2, 3] - n_fp = [0, n_fp..., N] # [0, 1, 2, 4] +const DOC_PRECISION_RECALL(;middle=DOC_YHAT_Y, footer=DOC_ROC_CHECK) = +""" + +Return data for plotting the precision-recall curve (PR curve) for a binary classification +problem. The first point on the corresponding curve is always `(recall, precision) = (0, +1)`, while the last point is always `(recall, precision) = (1, p)` where `p` is the +proportion of positives in the observed sample `y`. 
+ +$middle + +$(DOC_THRESHOLDS(counts="precison and recall")) + +Accordingly, `precisions` and `recalls` have length `k+1` in that case. + +To plot the curve using your favorite plotting library, do something like +`plot(recalls, precisions)`. + +$footer +""" + +""" + Functions.precision_recall_curve(ŷ, y, positive_class) -> + precisions, recalls, thresholds - tprs = n_tp ./ P # [0/3, 1/3, 2/3, 1] - fprs = n_fp ./ N # [0/4, 1/4, 2/4, 1] +$(DOC_PRECISION_RECALL()) - return fprs, tprs, thresholds +See also [`StatisticalMeasures.precision_recall_curve`](@ref), which includes some +checks, and [`Functions.confusion_counts_at_thresholds`](@ref). + +""" +function precision_recall_curve(scores, y, positive_class) + (tn, fp, fn, tp), thresholds = + confusion_counts_at_thresholds(scores, y, positive_class) + + k = length(tp) + precisions = Vector{Float64}(undef, k) + @. precisions = tamed_divide(tp, tp + fp) + # force precision = 1 at threshold -> 1: + precisions[1] = 1 + + recalls = Vector{Float64}(undef, k) + P = fn[1] # num observed positives + @. recalls = tp / P + return recalls, precisions, thresholds end + +# # AUC + const DOC_AUC_REF = "Implementation is based on the Mann-Whitney U statistic. See the "* "[*Whitney U "* @@ -251,14 +401,14 @@ end """ Functions.cbi( - probability_of_positive, ground_truth_observations, positive_class, + probability_of_positive, ground_truth_observations, positive_class, nbins, binwidth, ma=maximum(scores), mi=minimum(scores), cor=corspearman ) Return the Continuous Boyce Index (CBI) for a vector of probabilities and ground truth observations. 
""" function cbi( - scores, y, positive_class; + scores, y, positive_class; verbosity, nbins, binwidth, max=maximum(scores), min=minimum(scores), cor=StatsBase.corspearman ) @@ -282,7 +432,7 @@ function cbi( any_empty = true end @inbounds for j in bin_index_first:bin_index_last - if sorted_y[j] == positive_class + if sorted_y[j] == positive_class n_positive[i] += 1 end end @@ -552,7 +702,7 @@ $(docstring( "Functions.multiclass_fscore", sig="(m, β, average[, weights])", the=true, -))*"\n Note that the `MicroAvg` score is insenstive to `β`. " +))*"\n Note that the `MicroAvg` score is insensitive to `β`. " """ multiclass_fscore(m, beta, average::MicroAvg) = multiclass_true_positive_rate(m, MicroAvg()) diff --git a/src/precision_recall.jl b/src/precision_recall.jl new file mode 100644 index 0000000..f0e5db5 --- /dev/null +++ b/src/precision_recall.jl @@ -0,0 +1,80 @@ +const ERR_NEED_CATEGORICAL_PR = ArgumentError( + "Was expecting categorical arguments: "* + "In a call like `precision_recall_curve(ŷ, y)`, `ŷ` must have eltype "* + "`<:CategoricalDistributions.UnivariateFinite` and `y` must have eltype "* + "`<:CategoricalArrays.CategoricalArray` . If using raw probabilities, consider "* + "using `Functions.precision_recall_curve` instead. 
"
+)
+
+const ERR_PR1 = ArgumentError(
+    "probabilistic predictions should be for exactly two classes (levels)"
+)
+
+const ERR_PR2 = ArgumentError(
+    "ground truth observations must have exactly two classes (levels) in the pool"
+)
+
+# perform some argument checks and return the ordered levels:
+function binary_levels_pr(
+    yhat::AbstractArray{<:Union{Missing,UnivariateFinite{<:Finite{2}}}},
+    y::CategoricalArrays.CatArrOrSub
+    )
+    classes = CategoricalArrays.levels(y)
+    length(classes) == 2 || throw(ERR_PR2)
+    API.check_numobs(yhat, y)
+    API.check_pools(yhat, y)
+    warn_unordered(classes)
+    classes
+end
+binary_levels_pr(
+    yhat::AbstractArray{<:Union{Missing,UnivariateFinite{<:Finite}}},
+    y::CategoricalArrays.CatArrOrSub
+) = throw(ERR_PR1)
+binary_levels_pr(yhat, y) = throw(ERR_NEED_CATEGORICAL_PR)
+
+const DOC_PR_EXAMPLE =
+"""
+
+```
+using StatisticalMeasures
+using CategoricalArrays
+using CategoricalDistributions
+
+# ground truth:
+y = categorical(["X", "O", "X", "X", "O", "X", "X", "O", "O", "X"], ordered=true)
+
+# probabilistic predictions:
+X_probs = [0.3, 0.2, 0.4, 0.9, 0.1, 0.4, 0.5, 0.2, 0.8, 0.7]
+ŷ = UnivariateFinite(["O", "X"], X_probs, augment=true, pool=y)
+ŷ[1]
+
+using Plots
+recalls, precisions, thresholds = precision_recall_curve(ŷ, y)
+plt = plot(recalls, precisions, legend=false)
+plot!(plt, xlab="recall", ylab="precision")
+
+# proportion of observations that are positive:
+p = precisions[end] # threshold=0
+plot!([0, 1], [p, p], linewidth=2, linestyle=:dash, color=:black)
+```
+
+"""
+
+"""
+    precision_recall_curve(ŷ, y) -> recalls, precisions, thresholds
+
+$(Functions.DOC_PRECISION_RECALL(
+    middle="Here `ŷ` is a vector of `UnivariateFinite` distributions "*
+        "(from CategoricalDistributions.jl) over the two "*
+        "values taken by the ground truth observations `y`, a `CategoricalVector`. "*
+        "The `thresholds`, listed in descending order, are the distinct predicted "*
+        "probabilities of the positive class. 
", + footer="Core algorithm: [`Functions.precision_recall_curve`](@ref). "*DOC_PR_EXAMPLE +)) +""" +function precision_recall_curve(yhat, y) + # `binary_levels` also performs argument checks and issues warnings about order: + positive_class = binary_levels_pr(yhat, y) |> last + scores = pdf.(yhat, positive_class) + Functions.precision_recall_curve(scores, y, positive_class) +end diff --git a/src/probabilistic.jl b/src/probabilistic.jl index 6c02d0b..efb9cff 100644 --- a/src/probabilistic.jl +++ b/src/probabilistic.jl @@ -548,33 +548,47 @@ const spherical_score = SphericalScore() # --------------------------------------------------------------------- # Continuous Boyce Index -struct _ContinuousBoyceIndex +struct _ContinuousBoyceIndex verbosity::Int nbins::Integer binwidth::Float64 min::Float64 max::Float64 cor::Function - function _ContinuousBoyceIndex(; - verbosity = 1, nbins = 101, binwidth = 0.1, + function _ContinuousBoyceIndex(; + verbosity = 1, nbins = 101, binwidth = 0.1, min = 0, max = 1, cor = StatsBase.corspearman ) new(verbosity, nbins, binwidth, min, max, cor) end end -ContinuousBoyceIndex(; kw...) = _ContinuousBoyceIndex(; kw...) |> robust_measure |> fussy_measure +ContinuousBoyceIndex(; kw...) = + _ContinuousBoyceIndex(; kw...) 
|> robust_measure |> fussy_measure -function (m::_ContinuousBoyceIndex)(ŷ::AbstractArray{<:UnivariateFinite}, y::NonMissingCatArrOrSub) +function (m::_ContinuousBoyceIndex)( + ŷ::AbstractArray{<:UnivariateFinite}, + y::NonMissingCatArrOrSub, + ) m.verbosity > 0 && warn_unordered(levels(y)) positive_class = levels(first(ŷ))|> last scores = pdf.(ŷ, positive_class) - return Functions.cbi(scores, y, positive_class; - verbosity = m.verbosity, nbins = m.nbins, binwidth = m.binwidth, max = m.max, min = m.min, cor = m.cor) + return Functions.cbi( + scores, + y, + positive_class; + verbosity = m.verbosity, + nbins = m.nbins, + binwidth = m.binwidth, + max = m.max, + min = m.min, + cor = m.cor, + ) end -const ContinuousBoyceIndexType = API.FussyMeasure{<:API.RobustMeasure{<:_ContinuousBoyceIndex}} +const ContinuousBoyceIndexType = + API.FussyMeasure{<:API.RobustMeasure{<:_ContinuousBoyceIndex}} @fix_show ContinuousBoyceIndex::ContinuousBoyceIndexType @@ -591,27 +605,37 @@ StatisticalMeasures.@trait( register(ContinuousBoyceIndex, "continuous_boyce_index", "cbi") const ContinuousBoyceIndexDoc = docstring( - "ContinuousBoyceIndex(; verbosity=1, nbins=101, bin_overlap=0.1, min=nothing, max=nothing, cor=StatsBase.corspearman)", + "ContinuousBoyceIndex(; verbosity=1, nbins=101, bin_overlap=0.1, "* + "min=nothing, max=nothing, cor=StatsBase.corspearman)", body= """ -The Continuous Boyce Index is a measure for evaluating the performance of probabilistic predictions for binary classification, -especially for presence-background data in ecological modeling. -It compares the predicted probability scores for the positive class across bins, giving higher scores if the ratio of positive - and negative samples in each bin is strongly correlated to the value at that bin. + +The Continuous Boyce Index is a measure for evaluating the performance of probabilistic +predictions for binary classification, especially for presence-background data in +ecological modeling. 
It compares the predicted probability scores for the positive class +across bins, giving higher scores if the ratio of positive and negative samples in each +bin is strongly correlated to the value at that bin. ## Keywords + - `verbosity`: Verbosity level. + - `nbins`: Number of bins to use for score partitioning. + - `binwidth`: The width of each bin, which defaults to 0.1. -- `min`, `max`: Optional minimum and maximum score values for binning. Default to the 0 and 1, respectively. + +- `min`, `max`: Optional minimum and maximum score values for binning. Default to the 0 + and 1, respectively. + - `cor`: Correlation function (defaults to StatsBase.corspearman, i.e. Spearman correlation). ## Arguments -The predictions `ŷ` should be a vector of `UnivariateFinite` distributions from CategoricalDistributions.jl, - and `y` a CategoricalVector of ground truth labels. +The predictions `ŷ` should be a vector of `UnivariateFinite` distributions from +CategoricalDistributions.jl, and `y` a CategoricalVector of ground truth labels. -Returns the correlation between the ratio of positive to negative samples in each bin and the bin centers. +Returns the correlation between the ratio of positive to negative samples in each bin and +the bin centers. Core implementation: [`Functions.cbi`](@ref). 
diff --git a/src/roc.jl b/src/roc.jl index 3e5c949..768dc56 100644 --- a/src/roc.jl +++ b/src/roc.jl @@ -32,6 +32,32 @@ binary_levels( ) = throw(ERR_ROC1) binary_levels(yhat, y) = throw(ERR_NEED_CATEGORICAL) +const DOC_ROC_EXAMPLE = +""" + +# Example + +``` +using StatisticalMeasures +using CategoricalArrays +using CategoricalDistributions + +# ground truth: +y = categorical(["X", "O", "X", "X", "O", "X", "X", "O", "O", "X"], ordered=true) + +# probabilistic predictions: +X_probs = [0.3, 0.2, 0.4, 0.9, 0.1, 0.4, 0.5, 0.2, 0.8, 0.7] +ŷ = UnivariateFinite(["O", "X"], X_probs, augment=true, pool=y) +ŷ[1] + +using Plots +false_positive_rates, true_positive_rates, thresholds = roc_curve(ŷ, y) +plt = plot(false_positive_rates, true_positive_rates; legend=false) +plot!(plt, xlab="false positive rate", ylab="true positive rate") +plot!([0, 1], [0, 1], linewidth=2, linestyle=:dash, color=:black) +``` + +""" """ roc_curve(ŷ, y) -> false_positive_rates, true_positive_rates, thresholds @@ -39,9 +65,11 @@ binary_levels(yhat, y) = throw(ERR_NEED_CATEGORICAL) $(Functions.DOC_ROC( middle="Here `ŷ` is a vector of `UnivariateFinite` distributions "* "(from CategoricalDistributions.jl) over the two "* - "values taken by the ground truth observations `y`, a `CategoricalVector`. ", + "values taken by the ground truth observations `y`, a `CategoricalVector`. "* + "The `thresholds`, listed in descending order, are the distinct predicted "* + "probabilities of the positive class. ", footer="Core algorithm: [`Functions.roc_curve`](@ref)"* - "\n\nSee also [`AreaUnderCurve`](@ref). ", + "\n\nSee also [`AreaUnderCurve`](@ref). 
"*DOC_ROC_EXAMPLE, )) """ function roc_curve(yhat, y) diff --git a/src/tools.jl b/src/tools.jl index 327dfae..2014d97 100644 --- a/src/tools.jl +++ b/src/tools.jl @@ -59,15 +59,18 @@ function API.check_pools( return nothing end -# Throw a warning if levels are not explicitly ordered -function warn_unordered(levels) - levels isa CategoricalArray && CategoricalArrays.isordered(levels) && return +# string to use in warning for unordered levels: +function warning_unordered(levels) raw_levels = CategoricalArrays.unwrap.(levels) ret = "Levels not explicitly ordered. "* "Using the order $raw_levels. " if length(levels) == 2 ret *= "The \"positive\" level is $(raw_levels[2]). " end - @warn ret return ret -end \ No newline at end of file +end + +# function to throw warning if `levels` are unordered: +warn_unordered(levels) = @warn warning_unordered(levels) +warn_unordered(levels::CategoricalArrays.CatArrOrSub) = + CategoricalArrays.isordered(levels) ? nothing : @warn warning_unordered(levels) diff --git a/test/confusion_matrices.jl b/test/confusion_matrices.jl index 826b088..434e648 100644 --- a/test/confusion_matrices.jl +++ b/test/confusion_matrices.jl @@ -30,7 +30,7 @@ const CM = StatisticalMeasures.ConfusionMatrices rev_index_given_level = Dict("B" => 1, "A" => 2) @test cm == CM.ConfusionMatrix(n, rev_index_given_level) mat = @test_logs( - (:warn, StatisticalMeasures.warn_unordered(levels)), + (:warn, StatisticalMeasures.warning_unordered(levels)), CM.matrix(cm), ) @test mat == m diff --git a/test/finite.jl b/test/finite.jl index 0bfce5e..5f22a37 100644 --- a/test/finite.jl +++ b/test/finite.jl @@ -114,7 +114,7 @@ end 1, 1, 1, 2, 2, 1, 2, 1, 2, 2, 2, 1, 2, 1, 2, 2, missing] - @test_logs (:warn, StatisticalMeasures.warn_unordered([1, 2])) f1score(ŷ, y) + @test_logs (:warn, StatisticalMeasures.warning_unordered([1, 2])) f1score(ŷ, y) f05 = @test_logs FScore(0.5, levels=[1, 2])(ŷ, y) sk_f05 = 0.625 @test f05 ≈ sk_f05 # m.fbeta_score(y, yhat, 0.5, pos_label=2) diff --git 
a/test/functions.jl b/test/functions.jl index 0ac7f94..ad4691e 100644 --- a/test/functions.jl +++ b/test/functions.jl @@ -1,6 +1,37 @@ rng = srng(34234) -@testset "ROC" begin +y = ["0", "1", "1", "1", "1", "1", "1", "0", "0", "0", + "1", "0", "0", "0", "1", "1", "0", "1", "1", "0"] +ŷ = [0.8, 0.9, 0.7, 0.1, 0.7, 0.8, 0.6, 0.7, 0.3, 0.9, + 0.3, 0.8, 0.6, 0.7, 0.3, 0.9, 0.6, 0.7, 0.1, 0.8] + +# sorted versions: +# ["1", "0", "1", "0", "1", "0", "0", "1", "1", "0", +# "0", "1", "1", "0", "0", "0", "1", "1", "1", "1"] +# [0.9, 0.9, 0.9, 0.8, 0.8, 0.8, 0.8, 0.7, 0.7, 0.7, +# 0.7, 0.7, 0.6, 0.6, 0.6, 0.3, 0.3, 0.3, 0.1, 0.1] + +@testset "confusion_counts_at_thresholds" begin + (tn, fp, fn, tp), thresholds = Functions.confusion_counts_at_thresholds(ŷ, y, "1") + + @test thresholds == sort(unique(ŷ)) |> reverse + + P = 11 # num observed positives + N = 9 # num observed negatives + @test tn == [9, 8, 5, 3, 1, 0, 0] # hand-computed + @test tp == [0, 2, 3, 6, 7, 9, 11] # hand-computed + @test all(==(N), tn + fp) + @test all(==(P), tp + fn) +end + +@testset "precision_recall_curve" begin + recalls, precisions, thresholds = Functions.precision_recall_curve(ŷ, y, "1") + @test precisions ≈ [1.0, [2, 3, 6, 7, 9, 11] ./ [3, 7, 12, 15, 18, 20]...] 
+ @test recalls ≈ [0, 2, 3, 6, 7, 9, 11]/11 + @test thresholds == sort(unique(ŷ)) |> reverse +end + +@testset "roc_curve" begin y = [ 0 0 0 1 0 1 1 0] |> vec s = [0.0 0.1 0.1 0.1 0.2 0.2 0.5 0.5] |> vec diff --git a/test/precision_recall.jl b/test/precision_recall.jl new file mode 100644 index 0000000..2330456 --- /dev/null +++ b/test/precision_recall.jl @@ -0,0 +1,47 @@ +@testset "precision_recall_curve" begin + perm = [4, 7, 2, 1, 3, 8, 5, 6] + y = [0 0 0 1 0 1 1 0][perm] |> vec |> categorical + s = [0.0 0.1 0.1 0.1 0.2 0.2 0.5 0.5][perm] |> vec + ŷ = UnivariateFinite([0, 1], s, augment=true, pool=y) + @test_throws(StatisticalMeasures.ERR_NEED_CATEGORICAL_PR, + precision_recall_curve(ŷ, CategoricalArrays.unwrap.(y)), + ) + @test_throws( + StatisticalMeasures.ERR_NEED_CATEGORICAL_PR, + precision_recall_curve(s, y), + ) + @test_throws( + StatisticalMeasures.ERR_PR2, + precision_recall_curve(ŷ, categorical([0,1,2, fill(0, 7)...])), + ) + @test_throws( + StatisticalMeasures.ERR_PR1, + precision_recall_curve(UnivariateFinite([0, 1, 2], rand(0:2,10,3), pool=missing), y) + ) + @test_throws( + API.ERR_POOL, + precision_recall_curve(ŷ, categorical([1, 2, 2, 2, 2, 2, 1, 2])) + ) + + recalls, precisions, ts = @test_logs( + (:warn, StatisticalMeasures.warning_unordered([0, 1])), + precision_recall_curve(ŷ, y), + ) + + core_function_recalls, core_function_precisions = + Functions.precision_recall_curve(s, y, 1) + + @test precisions == core_function_precisions + @test recalls == core_function_recalls + + y = categorical([ 0 0 0 1 0 1 1 0] |> vec, ordered=true) + s = [0.0 0.1 0.1 0.1 0.2 0.2 0.5 0.5] |> vec + ŷ = UnivariateFinite([0, 1], s, augment=true, pool=y) + + recalls2, precisions2, ts2 = @test_logs precision_recall_curve(ŷ, y) + @test precisions2 == precisions + @test recalls2 == recalls + @test ts2 == ts +end + +true diff --git a/test/probabilistic.jl b/test/probabilistic.jl index 2645bbd..196cdd6 100644 --- a/test/probabilistic.jl +++ b/test/probabilistic.jl @@ -185,14 
+185,14 @@ end # Simple synthetic test: perfectly separates positives and negatives c = ["neg", "pos"] probs = repeat(0.0:0.1:0.9, inner = 10) .+ rand(rng, 100) .* 0.1 - y = categorical(probs .> rand(rng, 100)) + y = categorical(probs .> rand(rng, 100), ordered=true) ŷ = UnivariateFinite(levels(y), probs, augment=true) # Should be pretty high @test cbi(ŷ, y) ≈ 0.87 atol=0.01 # Passing different correlation methods works @test ContinuousBoyceIndex(cor=cor)(ŷ, y) ≈ 0.90 atol = 0.01 - @test ContinuousBoyceIndex(nbins = 11, binwidth = 0.03)(ŷ, y) ≈ 0.77 atol = 0.01 + @test ContinuousBoyceIndex(nbins = 11, binwidth = 0.03)(ŷ, y) ≈ 0.77 atol = 0.01 # Randomized test: shuffled labels, should be near 0 y_shuf = copy(y) @@ -204,27 +204,28 @@ end @test isapprox(cbi(ŷ[idx], y[idx]), cbi(ŷ, y), atol=1e-8) # Test with all positives or all negatives return NaN - y_allpos = categorical(trues(100), levels = levels(y)) - y_allneg = categorical(falses(100), levels = levels(y)) + y_allpos = categorical(trues(100), levels = levels(y), ordered=true) + y_allneg = categorical(falses(100), levels = levels(y), ordered=true) @test isnan(cbi(ŷ, y_allpos)) @test isnan(cbi(ŷ, y_allneg)) - unordered_warning = StatisticalMeasures.warn_unordered([false, true]) + yunordered = categorical(y, ordered=false) + unordered_warning = StatisticalMeasures.warning_unordered([false, true]) @test_logs( (:warn, unordered_warning), - cbi(ŷ, y), + cbi(ŷ, yunordered), ) cbi_dropped_bins = @test_logs( - (:warn, unordered_warning), (:info, "removing 91 bins without any observations",), + (:info, "removing 91 bins without any observations"), ContinuousBoyceIndex(; verbosity = 2, min =0.0, max = 2.0, nbins = 191)(ŷ, y), ) # These two are identical because bins are dropped - @test cbi_dropped_bins == + @test cbi_dropped_bins == ContinuousBoyceIndex(; min = 0.0, max = 1.2, nbins = 111)(ŷ, y) - + # cbi is silent for verbosity 0 - @test_logs ContinuousBoyceIndex(; verbosity = 0)(ŷ, y) + @test_logs ContinuousBoyceIndex(; 
verbosity = 0)(ŷ, yunordered) end @testset "l2_check" begin diff --git a/test/registry.jl b/test/registry.jl index e70144d..06e445c 100644 --- a/test/registry.jl +++ b/test/registry.jl @@ -1,17 +1,15 @@ -API.register(LPLossOnScalars) -API.register(LPLossOnVectors, "l2") -metadata = API.measures()[LPLossOnScalars] -measure = LPLossOnScalars() +metadata = measures()[LPLoss] +measure = LPLoss() -@testset "register" begin - @test Set(keys(API.measures())) == Set([LPLossOnScalars, LPLossOnVectors]) +@testset "registration" begin for trait in API.METADATA_TRAITS trait_ex = QuoteNode(trait) quote @test API.$trait(measure) == getproperty(metadata, $trait_ex) end |> eval end - @test measures()[LPLossOnVectors].aliases == ("l2", ) + @test measures()[LPLoss].aliases == + ("l1", "l2", "mae", "mav", "mean_absolute_error", "mean_absolute_value") end @testset "search for needle in docstring" begin diff --git a/test/roc.jl b/test/roc.jl index 82b5c51..414104d 100644 --- a/test/roc.jl +++ b/test/roc.jl @@ -25,7 +25,7 @@ ) fprs, tprs, ts = @test_logs( - (:warn, StatisticalMeasures.warn_unordered([0, 1])), + (:warn, StatisticalMeasures.warning_unordered([0, 1])), roc_curve(ŷ, y), ) diff --git a/test/runtests.jl b/test/runtests.jl index a091b8e..6de31f4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -21,38 +21,27 @@ call(m, args...) = m(args...) 
srng(n=123) = StableRNG(n) -@testset "tools.jl" begin - include("tools.jl") -end - -@testset "functions.jl" begin - include("functions.jl") -end - -@testset "confusion_matrices.jl" begin - include("confusion_matrices.jl") -end - -@testset "roc.jl" begin - include("roc.jl") -end - -@testset "continuous.jl" begin - include("continuous.jl") -end - -@testset "finite.jl" begin - include("finite.jl") -end - -@testset "probabilistic.jl" begin - include("probabilistic.jl") -end - -@testset "LossFunctionsExt.jl" begin - include("LossFunctionsExt.jl") -end -@testset "ScientificTypesExt.jl" begin - include("ScientificTypesExt.jl") +test_files = [ + "tools.jl", + "functions.jl", + "confusion_matrices.jl", + "roc.jl", + "precision_recall.jl", + "continuous.jl", + "finite.jl", + "probabilistic.jl", + "LossFunctionsExt.jl", + "ScientificTypesExt.jl", + "registry.jl", +] + +files = isempty(ARGS) ? test_files : ARGS + +for file in files + quote + @testset $file begin + include($file) + end + end |> eval end