@@ -466,3 +466,67 @@ def test_decision_stumps(background_reg_dataset, background_clf_dataset):
466466 continue
467467
468468 assert pred == pytest .approx (efficiency , rel = 1e-5 )
469+
470+
471+ def test_extra_trees_clf (et_clf_model , background_clf_data ):
472+ """Test the shapiq implementation of TreeSHAP vs. SHAP's implementation for Extra Trees."""
473+ explanation_instance = 1
474+ class_label = 1
475+
476+ # the following code is used to get the shap values from the SHAP implementation
477+ """
478+ #import shap
479+ # model_copy = copy.deepcopy(et_clf_model)
480+ # explainer_shap = shap.TreeExplainer(model=model_copy)
481+ # baseline_shap = float(explainer_shap.expected_value[class_label])
482+ # x_explain_shap = copy.deepcopy(background_clf_data[explanation_instance].reshape(1, -1))
483+ # sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
484+ # sv_shap = sv_shap_all_classes[0][:, class_label]
485+ # print(sv_shap_all_classes, format(baseline_shap, '.20f'))
486+ """ # noqa: ERA001
487+ sv_shap = [0.00207427 , 0.00949552 , - 0.00108266 , - 0.03825587 , - 0.02694092 , 0.0170296 , 0.02046364 ]
488+ sv_shap = np .asarray (sv_shap )
489+ baseline_shap = 0.34000000000000002
490+
491+ # compute with shapiq
492+ explainer_shapiq = TreeExplainer (
493+ model = et_clf_model , max_order = 1 , index = "SV" , class_index = class_label
494+ )
495+ x_explain_shapiq = copy .deepcopy (background_clf_data [explanation_instance ])
496+ sv_shapiq = explainer_shapiq .explain (x = x_explain_shapiq )
497+ sv_shapiq_values = sv_shapiq .get_n_order_values (1 )
498+ baseline_shapiq = sv_shapiq .baseline_value
499+
500+ assert baseline_shap == pytest .approx (baseline_shapiq , rel = 1e-4 )
501+ assert np .allclose (sv_shap , sv_shapiq_values , rtol = 1e-5 )
502+
503+
504+ def test_extra_trees_reg (et_reg_model , background_reg_data ):
505+ """Test the shapiq implementation of TreeSHAP vs. SHAP's implementation for Extra Trees."""
506+ explanation_instance = 1
507+
508+ # the following code is used to get the shap values from the SHAP implementation
509+ """
510+ # import shap
511+ # model_copy = copy.deepcopy(et_reg_model)
512+ # explainer_shap = shap.TreeExplainer(model=model_copy)
513+ # baseline_shap = float(explainer_shap.expected_value)
514+ # x_explain_shap = copy.deepcopy(background_reg_data[explanation_instance].reshape(1, -1))
515+ # sv_shap_all_classes = explainer_shap.shap_values(x_explain_shap)
516+ # sv_shap = sv_shap_all_classes[0]
517+ # print(sv_shap_all_classes, format(baseline_shap, '.20f'))
518+ """ # noqa: ERA001
519+ sv_shap = [19.28673017 , - 19.87182634 , 0.0 , 10.89201698 , - 9.62498263 , 0.35992212 , 42.31290091 ]
520+ sv_shap = np .asarray (sv_shap )
521+ print (sv_shap )
522+ baseline_shap = - 2.56682283435175007
523+
524+ # compute with shapiq
525+ explainer_shapiq = TreeExplainer (model = et_reg_model , max_order = 1 , index = "SV" )
526+ x_explain_shapiq = copy .deepcopy (background_reg_data [explanation_instance ])
527+ sv_shapiq = explainer_shapiq .explain (x = x_explain_shapiq )
528+ sv_shapiq_values = sv_shapiq .get_n_order_values (1 )
529+ baseline_shapiq = sv_shapiq .baseline_value
530+
531+ assert baseline_shap == pytest .approx (baseline_shapiq , rel = 1e-4 )
532+ assert np .allclose (sv_shap , sv_shapiq_values , rtol = 1e-5 )
0 commit comments