Skip to content

Commit a6216cb

Browse files
committed
Add support for categorical feature auto-encoding
1 parent 3f601b0 commit a6216cb

5 files changed

Lines changed: 2075 additions & 1 deletion

File tree

lib/lightgbm/booster.rb

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
require_relative "categorical_feature_encoder"
2+
13
module LightGBM
24
class Booster
35
attr_accessor :best_iteration, :train_data_name
@@ -6,9 +8,11 @@ def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
68
@handle = ::FFI::MemoryPointer.new(:pointer)
79
if model_str
810
model_from_string(model_str)
11+
@categorical_feature_encoder = CategoricalFeatureEncoder.new(model_str.each_line)
912
elsif model_file
1013
out_num_iterations = ::FFI::MemoryPointer.new(:int)
1114
check_result FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, @handle)
15+
@categorical_feature_encoder = CategoricalFeatureEncoder.new(File.foreach(model_file))
1216
else
1317
params ||= {}
1418
set_verbosity(params)
@@ -152,7 +156,12 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params)
152156
num_iteration ||= best_iteration
153157
num_class ||= num_class()
154158

155-
flat_input = input.flatten
159+
flat_input = if @categorical_feature_encoder
160+
input.flat_map { |row| @categorical_feature_encoder.apply(row) }
161+
else
162+
input.flatten
163+
end
164+
156165
handle_missing(flat_input)
157166
data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
158167
data.write_array_of_double(flat_input)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
require "json"
2+
3+
module LightGBM
  # Converts LightGBM categorical features to Float, using label encoding.
  # The categorical feature indices and the category-to-code mappings are
  # extracted from the LightGBM model file.
  class CategoricalFeatureEncoder
    CATEGORICAL_FEATURE_PREFIX = "[categorical_feature:".freeze
    PANDAS_CATEGORICAL_PREFIX = "pandas_categorical:".freeze
    private_constant :CATEGORICAL_FEATURE_PREFIX, :PANDAS_CATEGORICAL_PREFIX

    # Initializes a new CategoricalFeatureEncoder instance.
    #
    # @param model_enumerable [Enumerable] Enumerable with each line of LightGBM model file.
    # @raise [RuntimeError] if the number of categorical features differs from
    #   the number of category mappings found in the model.
    def initialize(model_enumerable)
      @categorical_feature = []
      @pandas_categorical = []

      load_categorical_features(model_enumerable)
    end

    # Returns a new array with categorical features converted to Float, using label encoding.
    #
    # Values that are not a known category (including nil) become Float::NAN.
    # Numeric values are compared by value, so Integer 1 matches category 1.0.
    # When the model declares no categorical features, the input is returned as is.
    #
    # @param feature_values [Array] a single row of raw feature values.
    # @return [Array] the row with categorical positions replaced by Float codes.
    def apply(feature_values)
      return feature_values if @categorical_feature.empty?

      transformed_features = feature_values.dup

      # The n-th categorical feature index is paired with the n-th mapping.
      @categorical_feature.zip(@pandas_categorical) do |feature_index, codes|
        value = normalize_category(feature_values[feature_index])
        transformed_features[feature_index] = codes.fetch(value, Float::NAN).to_f
      end

      transformed_features
    end

    private

    # Scans the model lines for the categorical feature list and the category
    # mappings, stopping as soon as both have been read.
    def load_categorical_features(model_enumerable)
      categorical_found = false
      pandas_found = false

      model_enumerable.each do |line|
        if line.start_with?(CATEGORICAL_FEATURE_PREFIX)
          @categorical_feature = parse_categorical_feature(line)
          categorical_found = true
        elsif line.start_with?(PANDAS_CATEGORICAL_PREFIX)
          mappings = parse_pandas_categorical(line)
          if mappings
            @pandas_categorical = mappings
            pandas_found = true
          end
        end

        # Break the loop if both lines are found
        break if categorical_found && pandas_found
      end

      # NOTE(review): a model that lists categorical features but carries no
      # category mappings is still rejected here - confirm this is intended.
      return if @categorical_feature.size == @pandas_categorical.size

      raise "categorical_feature and pandas_categorical mismatch"
    end

    # Parses a line such as "[categorical_feature: 0,1,2,3,4,5]" into feature indices.
    # An empty list ("[categorical_feature: ]") yields [].
    def parse_categorical_feature(line)
      listing = line.strip.sub(CATEGORICAL_FEATURE_PREFIX, "").chomp("]")
      listing.split(",").map(&:strip).reject(&:empty?).map(&:to_i)
    end

    # Parses a line such as
    #   pandas_categorical:[[-1.0, 0.0, 1.0], ["", "a"], [false, true]]
    # into one Hash per feature, mapping each category to its position.
    #
    # Returns [] for "pandas_categorical:null" and nil when the line has no payload.
    def parse_pandas_categorical(line)
      payload = line.sub(PANDAS_CATEGORICAL_PREFIX, "").strip
      return nil if payload.empty?
      return [] if payload == "null"

      JSON.parse(payload).map do |categories|
        categories.each_with_index.each_with_object({}) do |(category, code), codes|
          codes[normalize_category(category)] = code
        end
      end
    end

    # Hash lookup distinguishes 1 from 1.0, so numeric categories and numeric
    # inputs are both coerced to Float before being used as keys.
    def normalize_category(value)
      value.is_a?(Numeric) ? value.to_f : value
    end
  end
end

test/booster_test.rb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,27 @@ def test_model_file
88
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2)
99
end
1010

11+
# Loads a model with categorical features from a file and predicts on rows
# holding raw boolean and string category values.
def test_model_file_with_categorical_features
  model_path = "test/support/model_with_categorical_features.txt"
  rows = [[false, "green", 7.2, 9.0], [true, "blue", 7.9, 0.0]]
  predictions = LightGBM::Booster.new(model_file: model_path).predict(rows)
  assert_elements_in_delta [0.9948804305465, 0.792909968121466], predictions.first(2)
end
17+
1118
# Builds a booster from the model file's contents passed as a string and
# checks the first two predictions.
def test_model_str
  model_contents = File.read("test/support/model.txt")
  rows = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
  predictions = LightGBM::Booster.new(model_str: model_contents).predict(rows)
  assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], predictions.first(2)
end
1724

25+
# Builds a booster from a model string with categorical features and predicts
# on rows holding raw boolean and string category values.
def test_model_str_with_categorical_features
  model_contents = File.read("test/support/model_with_categorical_features.txt")
  rows = [[false, "green", 7.2, 9.0], [true, "blue", 7.9, 0.0]]
  predictions = LightGBM::Booster.new(model_str: model_contents).predict(rows)
  assert_elements_in_delta [0.9948804305465, 0.792909968121466], predictions.first(2)
end
31+
1832
# Checks the per-feature importance values reported by the shared booster.
def test_feature_importance
  expected_importance = [280, 285, 335, 148]
  assert_equal expected_importance, booster.feature_importance
end
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
require_relative "test_helper"
2+
3+
# Unit tests for LightGBM::CategoricalFeatureEncoder#apply, using an in-memory
# model snippet that declares features 1, 2 and 3 as categorical.
class CategoricalFeatureEncoder < Minitest::Test
  def setup
    model = <<~MODEL
      [categorical_feature: 1,2,3]
      pandas_categorical:[[-1.0, 0.0, 1.0], ["red", "green", "blue"], [false, true]]
    MODEL

    @encoder = LightGBM::CategoricalFeatureEncoder.new(model.each_line)
  end

  def test_apply_with_categorical_features
    input = [42.0, 0.0, "green", true]
    expected = [42.0, 1.0, 1.0, 1.0]

    assert_encoded(expected, @encoder.apply(input))
  end

  def test_apply_with_non_categorical_features
    input = [42.0, "non_categorical", 39.0, false]
    expected = [42.0, Float::NAN, Float::NAN, 0]

    assert_encoded(expected, @encoder.apply(input))
  end

  def test_apply_with_missing_values
    input = [42.0, nil, "red", nil]
    expected = [42.0, Float::NAN, 0.0, Float::NAN]

    assert_encoded(expected, @encoder.apply(input))
  end

  def test_apply_with_boolean_values
    input = [42.0, -1.0, "green", false]
    expected = [42.0, 0.0, 1.0, 0.0]

    assert_encoded(expected, @encoder.apply(input))
  end

  private

  # Compares two rows element by element. NaN never equals itself, so a plain
  # assert_equal on arrays containing NaN only passes when both sides hold the
  # very same Float object; NaN positions are therefore checked with nan?.
  def assert_encoded(expected, actual)
    assert_equal expected.size, actual.size

    expected.zip(actual).each_with_index do |(want, got), index|
      if want.is_a?(Float) && want.nan?
        assert got.is_a?(Float) && got.nan?, "expected NaN at index #{index}, got #{got.inspect}"
      else
        assert_equal want, got, "unexpected value at index #{index}"
      end
    end
  end
end

0 commit comments

Comments
 (0)