@@ -38,13 +38,15 @@ import DataFrame.Functions ((.<), (.<=), (.==), (.>), (.>=))
3838data TreeConfig = TreeConfig
3939 { maxTreeDepth :: Int -- ^ Depth limit for the tree; presumably the initial depth budget handed to buildTree (TODO confirm against the caller).
4040 , minSamplesSplit :: Int -- ^ buildTree emits a majority-value leaf when the frame has at most this many rows.
41+ , minLeafSize :: Int -- ^ findBestSplit only keeps candidate conditions whose true and false partitions each have at least this many rows.
4142 , synthConfig :: SynthConfig -- ^ Expression-synthesis settings; findBestSplit reads complexityPenalty and boolExpansion from it.
4243 }
4344
4445data SynthConfig = SynthConfig
4546 { maxExprDepth :: Int
4647 , boolExpansion :: Int
4748 , percentiles :: [Int ]
49+ , complexityPenalty :: Double
4850 , enableStringOps :: Bool
4951 , enableCrossCols :: Bool
5052 , enableArithOps :: Bool
@@ -56,6 +58,7 @@ defaultSynthConfig =
5658 { maxExprDepth = 2
5759 , boolExpansion = 2
5860 , percentiles = [0 , 10 .. 100 ]
61+ , complexityPenalty = 0.05
5962 , enableStringOps = True
6063 , enableCrossCols = True
6164 , enableArithOps = True
@@ -66,6 +69,7 @@ defaultTreeConfig =
6669 TreeConfig
6770 { maxTreeDepth = 10
6871 , minSamplesSplit = 5
72+ , minLeafSize = 1
6973 , synthConfig = defaultSynthConfig
7074 }
7175
@@ -100,7 +104,7 @@ buildTree cfg depth target conds df
100104 | depth <= 0 || nRows df <= minSamplesSplit cfg =
101105 Lit (majorityValue @ a target df)
102106 | otherwise =
103- case findBestSplit @ a (synthConfig cfg) target conds df of
107+ case findBestSplit @ a cfg target conds df of
104108 Nothing -> Lit (majorityValue @ a target df)
105109 Just bestCond ->
106110 let (dfTrue, dfFalse) = partitionDataFrame bestCond df
@@ -240,11 +244,9 @@ partitionDataFrame cond df = (filterWhere cond df, filterWhere (F.not cond) df)
240244findBestSplit ::
241245 forall a .
242246 (Columnable a ) =>
243- SynthConfig -> T. Text -> [Expr Bool ] -> DataFrame -> Maybe (Expr Bool )
247+ TreeConfig -> T. Text -> [Expr Bool ] -> DataFrame -> Maybe (Expr Bool )
244248findBestSplit cfg target conds df =
245249 let
246- minLeafSize = 1
247- lambda = 0.05
248250 initialImpurity = calculateGini @ a target df
249251 evalGain cond =
250252 let (t, f) = partitionDataFrame cond df
@@ -254,7 +256,8 @@ findBestSplit cfg target conds df =
254256 newImpurity =
255257 (weightT * calculateGini @ a target t)
256258 + (weightF * calculateGini @ a target f)
257- in ( (initialImpurity - newImpurity) - lambda * fromIntegral (eSize cond)
259+ in ( (initialImpurity - newImpurity)
260+ - complexityPenalty (synthConfig cfg) * fromIntegral (eSize cond)
258261 , negate (eSize cond)
259262 )
260263
@@ -264,7 +267,7 @@ findBestSplit cfg target conds df =
264267 let
265268 (t, f) = partitionDataFrame c df
266269 in
267- nRows t >= minLeafSize && nRows f >= minLeafSize
270+ nRows t >= minLeafSize cfg && nRows f >= minLeafSize cfg
268271 )
269272 (nubOrd conds)
270273 sortedConditions = take 10 (sortBy (flip compare `on` evalGain) validConds)
@@ -275,7 +278,13 @@ findBestSplit cfg target conds df =
275278 Just $
276279 maximumBy
277280 (compare `on` evalGain)
278- (boolExprs df sortedConditions sortedConditions 0 (boolExpansion cfg))
281+ ( boolExprs
282+ df
283+ sortedConditions
284+ sortedConditions
285+ 0
286+ (boolExpansion (synthConfig cfg))
287+ )
279288
280289calculateGini ::
281290 forall a .
0 commit comments