Skip to content

Commit 32a0ae6

Browse files
committed
feat: Make complexity penalty
1 parent 5104e0b commit 32a0ae6

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

src/DataFrame/DecisionTree.hs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ import DataFrame.Functions ((.<), (.<=), (.==), (.>), (.>=))
3838
data TreeConfig = TreeConfig
3939
{ maxTreeDepth :: Int
4040
, minSamplesSplit :: Int
41+
, minLeafSize :: Int
4142
, synthConfig :: SynthConfig
4243
}
4344

4445
data SynthConfig = SynthConfig
4546
{ maxExprDepth :: Int
4647
, boolExpansion :: Int
4748
, percentiles :: [Int]
49+
, complexityPenalty :: Double
4850
, enableStringOps :: Bool
4951
, enableCrossCols :: Bool
5052
, enableArithOps :: Bool
@@ -56,6 +58,7 @@ defaultSynthConfig =
5658
{ maxExprDepth = 2
5759
, boolExpansion = 2
5860
, percentiles = [0, 10 .. 100]
61+
, complexityPenalty = 0.05
5962
, enableStringOps = True
6063
, enableCrossCols = True
6164
, enableArithOps = True
@@ -66,6 +69,7 @@ defaultTreeConfig =
6669
TreeConfig
6770
{ maxTreeDepth = 10
6871
, minSamplesSplit = 5
72+
, minLeafSize = 1
6973
, synthConfig = defaultSynthConfig
7074
}
7175

@@ -100,7 +104,7 @@ buildTree cfg depth target conds df
100104
| depth <= 0 || nRows df <= minSamplesSplit cfg =
101105
Lit (majorityValue @a target df)
102106
| otherwise =
103-
case findBestSplit @a (synthConfig cfg) target conds df of
107+
case findBestSplit @a cfg target conds df of
104108
Nothing -> Lit (majorityValue @a target df)
105109
Just bestCond ->
106110
let (dfTrue, dfFalse) = partitionDataFrame bestCond df
@@ -240,11 +244,9 @@ partitionDataFrame cond df = (filterWhere cond df, filterWhere (F.not cond) df)
240244
findBestSplit ::
241245
forall a.
242246
(Columnable a) =>
243-
SynthConfig -> T.Text -> [Expr Bool] -> DataFrame -> Maybe (Expr Bool)
247+
TreeConfig -> T.Text -> [Expr Bool] -> DataFrame -> Maybe (Expr Bool)
244248
findBestSplit cfg target conds df =
245249
let
246-
minLeafSize = 1
247-
lambda = 0.05
248250
initialImpurity = calculateGini @a target df
249251
evalGain cond =
250252
let (t, f) = partitionDataFrame cond df
@@ -254,7 +256,8 @@ findBestSplit cfg target conds df =
254256
newImpurity =
255257
(weightT * calculateGini @a target t)
256258
+ (weightF * calculateGini @a target f)
257-
in ( (initialImpurity - newImpurity) - lambda * fromIntegral (eSize cond)
259+
in ( (initialImpurity - newImpurity)
260+
- complexityPenalty (synthConfig cfg) * fromIntegral (eSize cond)
258261
, negate (eSize cond)
259262
)
260263

@@ -264,7 +267,7 @@ findBestSplit cfg target conds df =
264267
let
265268
(t, f) = partitionDataFrame c df
266269
in
267-
nRows t >= minLeafSize && nRows f >= minLeafSize
270+
nRows t >= minLeafSize cfg && nRows f >= minLeafSiz cfg
268271
)
269272
(nubOrd conds)
270273
sortedConditions = take 10 (sortBy (flip compare `on` evalGain) validConds)
@@ -275,7 +278,13 @@ findBestSplit cfg target conds df =
275278
Just $
276279
maximumBy
277280
(compare `on` evalGain)
278-
(boolExprs df sortedConditions sortedConditions 0 (boolExpansion cfg))
281+
( boolExprs
282+
df
283+
sortedConditions
284+
sortedConditions
285+
0
286+
(boolExpansion (synthConfig cfg))
287+
)
279288

280289
calculateGini ::
281290
forall a.

0 commit comments

Comments
 (0)