Skip to content

Commit 6f5045b

Browse files
committed
Add support for categorical feature auto-encoding
1 parent a20284d commit 6f5045b

File tree

5 files changed

+139
-1
lines changed

5 files changed

+139
-1
lines changed

Gemfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ gem "matrix"
99
gem "numo-narray", platform: [:ruby, :x64_mingw]
1010
gem "rover-df", platform: [:ruby, :x64_mingw]
1111
gem "csv"
12+
gem "debug"

lib/lightgbm/booster.rb

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
require_relative "categorical_feature_encoder"
2+
13
module LightGBM
24
class Booster
35
attr_accessor :best_iteration, :train_data_name
46

57
def initialize(params: nil, train_set: nil, model_file: nil, model_str: nil)
68
if model_str
79
model_from_string(model_str)
10+
@categorical_feature_encoder = CategoricalFeatureEncoder.new(model_str.each_line)
811
elsif model_file
912
out_num_iterations = ::FFI::MemoryPointer.new(:int)
1013
create_handle do |handle|
1114
check_result FFI.LGBM_BoosterCreateFromModelfile(model_file, out_num_iterations, handle)
1215
end
16+
@categorical_feature_encoder = CategoricalFeatureEncoder.new(File.foreach(model_file))
1317
else
1418
params ||= {}
1519
set_verbosity(params)
@@ -164,7 +168,12 @@ def predict(input, start_iteration: nil, num_iteration: nil, **params)
164168
num_iteration ||= best_iteration
165169
num_class = self.num_class
166170

167-
flat_input = input.flatten
171+
flat_input = if @categorical_feature_encoder
172+
input.flat_map { |row| @categorical_feature_encoder.apply(row) }
173+
else
174+
input.flatten
175+
end
176+
168177
handle_missing(flat_input)
169178
data = ::FFI::MemoryPointer.new(:double, input.count * input.first.count)
170179
data.write_array_of_double(flat_input)
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
require "json"
2+
3+
module LightGBM
4+
# Converts LightGBM categorical features to Float, using label encoding.
# The categorical feature indices and the per-feature category lists are
# extracted from the LightGBM model file.
class CategoricalFeatureEncoder
  # Prefix of the model line listing the categorical feature indices.
  CATEGORICAL_FEATURE_PREFIX = "[categorical_feature:".freeze
  # Prefix of the model line holding the JSON-encoded category lists.
  PANDAS_CATEGORICAL_PREFIX = "pandas_categorical:".freeze

  # Initializes a new CategoricalFeatureEncoder instance.
  #
  # @param model_enumerable [Enumerable] Enumerable with each line of LightGBM model file.
  # @raise [RuntimeError] if category lists are present but their count differs
  #   from the number of categorical feature indices.
  # @raise [JSON::ParserError] if the pandas_categorical payload is not valid JSON.
  def initialize(model_enumerable)
    @categorical_feature = []
    @pandas_categorical = []

    load_categorical_features(model_enumerable)
  end

  # Returns a new array with categorical features converted to Float, using label encoding.
  # Values that are not part of a feature's category list (including nil) become Float::NAN.
  #
  # @param feature_values [Array] one input row; it is not modified.
  # @return [Array] the encoded copy, or +feature_values+ itself when there is nothing to encode.
  def apply(feature_values)
    return feature_values if @categorical_feature.empty?

    transformed_features = feature_values.dup

    @categorical_feature.each_with_index do |feature_index, pandas_categorical_index|
      codes = @pandas_categorical[pandas_categorical_index]
      value = feature_values[feature_index]
      transformed_features[feature_index] = codes.fetch(lookup_key(value), Float::NAN).to_f
    end

    transformed_features
  end

  private

  # Scans the model lines for the two lines of interest and stops reading as
  # soon as both have been seen.
  def load_categorical_features(model_enumerable)
    categorical_found = false
    pandas_found = false

    model_enumerable.each_entry do |line|
      if line.start_with?(CATEGORICAL_FEATURE_PREFIX)
        @categorical_feature = parse_categorical_feature(line)
        categorical_found = true
      elsif line.start_with?(PANDAS_CATEGORICAL_PREFIX)
        mappings = parse_pandas_categorical(line)
        if mappings
          @pandas_categorical = mappings
          pandas_found = true
        end
      end

      # Break the loop if both lines are found
      break if categorical_found && pandas_found
    end

    if @pandas_categorical.empty?
      # No category lists were recorded, so there is nothing to map values
      # through: leave the rows untouched instead of refusing to load the model.
      @categorical_feature = []
    elsif @categorical_feature.size != @pandas_categorical.size
      raise "categorical_feature and pandas_categorical mismatch"
    end
  end

  # Parses a line of the form "[categorical_feature: 0,1,2,3,4,5]" into
  # an array of Integer feature indices.
  def parse_categorical_feature(line)
    line[CATEGORICAL_FEATURE_PREFIX.length..-1].strip.chomp("]").split(",").map(&:to_i)
  end

  # Parses a line of the form
  # "pandas_categorical:[[-1.0, 0.0, 1.0], ["", "a"], [false, true]]" into one
  # { category => code } Hash per categorical feature.
  # Returns nil when nothing follows the prefix.
  def parse_pandas_categorical(line)
    payload = line[PANDAS_CATEGORICAL_PREFIX.length..-1]
    return nil if payload.empty?

    # JSON.parse returns nil for a `null` payload; treat it as "no category lists".
    (JSON.parse(payload) || []).map do |categories|
      categories.each_with_index.map { |category, code| [lookup_key(category), code] }.to_h
    end
  end

  # Hash lookup uses eql?, under which 1 and 1.0 are different keys. Integers
  # are normalised to Float on both the stored categories and the looked-up
  # values so that numeric categories match regardless of Integer/Float.
  def lookup_key(value)
    value.is_a?(Integer) ? value.to_f : value
  end
end
73+
end

test/booster_test.rb

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@ def test_model_file
88
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2)
99
end
1010

11+
# Loads the categorical model from a file path and checks the first two
# predictions for rows whose last feature is given as a String category.
def test_model_file_with_categorical_features
  categorical_booster = LightGBM::Booster.new(model_file: "test/support/model_categorical.txt")
  rows = [
    [3.7, 1.2, 7.2, "9"],
    [7.5, 0.5, 7.9, "0"]
  ]
  predictions = categorical_booster.predict(rows)
  assert_elements_in_delta [1.014580415457883, 0.9327349972866771], predictions.first(2)
end
17+
1118
def test_model_str
1219
x_test = [[3.7, 1.2, 7.2, 9.0], [7.5, 0.5, 7.9, 0.0]]
1320
booster = LightGBM::Booster.new(model_str: File.read("test/support/model.txt"))
@@ -23,6 +30,13 @@ def test_model_from_string
2330
assert_elements_in_delta [0.9823112229173586, 0.9583143724610858], y_pred.first(2)
2431
end
2532

33+
# Loads the categorical model from its file contents (as a String) and checks
# the first two predictions for rows whose last feature is a String category.
def test_model_str_with_categorical_features
  model_text = File.read("test/support/model_categorical.txt")
  categorical_booster = LightGBM::Booster.new(model_str: model_text)
  rows = [
    [3.7, 1.2, 7.2, "9"],
    [7.5, 0.5, 7.9, "0"]
  ]
  predictions = categorical_booster.predict(rows)
  assert_elements_in_delta [1.014580415457883, 0.9327349972866771], predictions.first(2)
end
39+
2640
# Checks the per-feature importances reported by the `booster` helper.
def test_feature_importance
  expected_importance = [280, 285, 335, 148]
  assert_equal expected_importance, booster.feature_importance
end
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
require_relative "test_helper"
2+
3+
# Exercises LightGBM::CategoricalFeatureEncoder#apply against a two-line model
# excerpt in which features 1, 2 and 3 are declared categorical.
class CategoricalFeatureEncoder < Minitest::Test
  # Builds a fresh encoder from the model excerpt before every test.
  def setup
    model_excerpt = <<~LINES
      [categorical_feature: 1,2,3]
      pandas_categorical:[[-1.0, 0.0, 1.0], ["red", "green", "blue"], [false, true]]
    LINES

    @encoder = LightGBM::CategoricalFeatureEncoder.new(model_excerpt.each_line)
  end

  # Known categories are replaced by their position in the category list.
  def test_apply_with_categorical_features
    encoded = @encoder.apply([42.0, 0.0, "green", true])

    assert_equal([42.0, 1.0, 1.0, 1.0], encoded)
  end

  # Values that are not in a feature's category list come back as NaN.
  def test_apply_with_non_categorical_features
    encoded = @encoder.apply([42.0, "non_categorical", 39.0, false])

    assert_equal([42.0, Float::NAN, Float::NAN, 0], encoded)
  end

  # nil in a categorical position comes back as NaN.
  def test_apply_with_missing_values
    row = [42.0, nil, "red", nil]
    encoded = @encoder.apply(row)

    assert_equal([42.0, Float::NAN, 0.0, Float::NAN], encoded)
  end

  # Boolean categories are encoded by position as well (false is first).
  def test_apply_with_boolean_values
    encoded = @encoder.apply([42.0, -1.0, "green", false])

    assert_equal([42.0, 0.0, 1.0, 0.0], encoded)
  end
end

0 commit comments

Comments
 (0)