Skip to content

Commit d3f4771

Browse files
authored
feat: add save/load methods for classifier persistence (#85)
* feat: add save/load methods for classifier persistence Provides a cleaner API than raw Marshal for persisting trained classifiers. Users can now save training state and resume later: classifier.save('model.json') loaded = Classifier::Bayes.load('model.json') loaded.train_spam('more data') # continue training Both Bayes and LSI classifiers support: - to_json / from_json for string serialization - save(path) / load(path) for file operations LSI serializes only source data (word_hash, categories), not computed vectors. The index rebuilds on load, making JSON files portable across GSL/non-GSL environments. Closes #17 * feat: add as_json method and accept hash in from_json - Add as_json method that returns a Hash representation - Modify to_json to use as_json internally - Modify from_json to accept both String and Hash arguments This provides more flexibility for serialization workflows. * fix: resolve RuboCop and Steep type check issues - Extract restore_state private method to reduce from_json AbcSize - Change as_json return type to untyped for Steep compatibility - Use assert_path_exists in tests per Minitest/AssertPathExists - Add JSON RBS vendor file for type checking - Regenerate RBS files
1 parent f9cb27c commit d3f4771

File tree

6 files changed

+410
-0
lines changed

6 files changed

+410
-0
lines changed

lib/classifier/bayes.rb

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Copyright:: Copyright (c) 2005 Lucas Carlson
55
# License:: LGPL
66

7+
require 'json'
78
require 'mutex_m'
89

910
module Classifier
@@ -117,6 +118,55 @@ def classify(text)
117118
best.first.to_s
118119
end
119120

121+
# Builds a plain Hash snapshot of the classifier's training state.
# Top-level keys are symbols; category and word names are converted to
# strings so the result can be passed straight to a JSON generator.
#
# @rbs () -> untyped
def as_json(*)
  stringify_keys = ->(hash) { hash.transform_keys(&:to_s) }
  categories = @categories.each_with_object({}) do |(name, words), result|
    result[name.to_s] = stringify_keys.call(words)
  end

  {
    version: 1,
    type: 'bayes',
    categories: categories,
    total_words: @total_words,
    category_counts: stringify_keys.call(@category_counts),
    category_word_count: stringify_keys.call(@category_word_count)
  }
end
135+
136+
# Serializes the classifier state to a JSON string.
# The result can be saved to a file and later loaded with Bayes.from_json.
#
# Arguments are forwarded to Hash#to_json so that generator state passed in
# by the JSON library (pretty-printing options, nesting depth when the
# classifier sits inside another structure) is honoured instead of dropped.
#
# @rbs (*untyped) -> String
def to_json(*args)
  as_json.to_json(*args)
end
143+
144+
# Loads a classifier from a JSON string, or from a Hash such as the one
# returned by #as_json (symbol keys) or by JSON.parse (string keys).
#
# Hash input is normalised through a JSON round-trip, so both key styles
# follow exactly the same path as a JSON string.
#
# Raises ArgumentError when the payload's "type" is not 'bayes'.
#
# @rbs (String | Hash[String | Symbol, untyped]) -> Bayes
def self.from_json(json)
  data = JSON.parse(json.is_a?(String) ? json : JSON.generate(json))
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'bayes'

  # allocate does not run #initialize; restore_state populates the instance.
  instance = allocate
  instance.send(:restore_state, data)
  instance
end
155+
156+
# Writes the classifier's JSON representation to the file at +path+ and
# returns the result of File.write.
#
# @rbs (String) -> Integer
def save(path)
  payload = to_json
  File.write(path, payload)
end
162+
163+
# Loads a classifier from a file previously written by #save.
#
# The file is read as UTF-8 explicitly: JSON text is UTF-8, and relying on
# Encoding.default_external breaks non-ASCII words under locales such as
# LANG=C.
#
# @rbs (String) -> Bayes
def self.load(path)
  from_json(File.read(path, encoding: 'UTF-8'))
end
169+
120170
#
121171
# Provides training and untraining methods for the categories specified in Bayes#new
122172
# For example:
@@ -200,5 +250,29 @@ def remove_category(category)
200250
@category_word_count.delete(category)
201251
end
202252
end
253+
254+
private
255+
256+
# Rebuilds the classifier's instance state from a string-keyed hash
# (called by from_json on a freshly allocated instance). Category and word
# names are converted back to symbols.
# @rbs (Hash[String, untyped]) -> void
def restore_state(data)
  mu_initialize
  @total_words = data['total_words']
  @categories = {} #: Hash[Symbol, Hash[Symbol, Integer]]
  @category_counts = Hash.new(0) #: Hash[Symbol, Integer]
  @category_word_count = Hash.new(0) #: Hash[Symbol, Integer]

  data['categories'].each do |name, words|
    @categories.store(name.to_sym, words.transform_keys(&:to_sym))
  end

  # Both per-category counters are restored the same way.
  { 'category_counts' => @category_counts,
    'category_word_count' => @category_word_count }.each do |key, counters|
    data[key].each { |name, count| counters[name.to_sym] = count }
  end
end
203277
end
204278
end

lib/classifier/lsi.rb

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def matrix_class
5353
require 'classifier/extensions/vector'
5454
end
5555

56+
require 'json'
5657
require 'mutex_m'
5758
require 'classifier/lsi/word_list'
5859
require 'classifier/lsi/content_node'
@@ -365,6 +366,75 @@ def marshal_load(data)
365366
@auto_rebuild, @word_list, @items, @version, @built_at_version = data
366367
end
367368

369+
# Builds a plain Hash snapshot of the LSI index.
# Each item contributes only its source data (word_hash with string keys and
# its categories as strings); computed vectors are not part of the snapshot.
#
# @rbs () -> untyped
def as_json(*)
  serialized_items = @items.each_with_object({}) do |(key, node), result|
    result[key] = {
      word_hash: node.word_hash.transform_keys(&:to_s),
      categories: node.categories.map(&:to_s)
    }
  end

  { version: 1, type: 'lsi', auto_rebuild: @auto_rebuild, items: serialized_items }
end
389+
390+
# Serializes the LSI index to a JSON string.
# Only source data (word_hash, categories) is serialized, not computed vectors.
#
# Arguments are forwarded to Hash#to_json so that generator state passed in
# by the JSON library (pretty-printing options, nesting depth when the index
# sits inside another structure) is honoured instead of dropped.
#
# @rbs (*untyped) -> String
def to_json(*args)
  as_json.to_json(*args)
end
398+
399+
# Loads an LSI index from a JSON string, or from a Hash such as the one
# returned by #as_json (symbol keys) or by JSON.parse (string keys).
# The index is rebuilt after the items have been restored.
#
# Hash input is normalised through a JSON round-trip, so nested symbol keys
# from #as_json follow exactly the same path as a JSON string.
#
# Raises ArgumentError when the payload's "type" is not 'lsi'.
#
# @rbs (String | Hash[String | Symbol, untyped]) -> LSI
def self.from_json(json)
  data = JSON.parse(json.is_a?(String) ? json : JSON.generate(json))
  raise ArgumentError, "Invalid classifier type: #{data['type']}" unless data['type'] == 'lsi'

  # Create the instance with auto_rebuild disabled while items are loaded.
  instance = new(auto_rebuild: false)
  items = instance.instance_variable_get(:@items)
  restored = data['items']

  # Categories stay as strings, matching how they were serialized.
  restored.each do |item_key, item_data|
    word_hash = item_data['word_hash'].transform_keys(&:to_sym)
    items[item_key] = ContentNode.new(word_hash, *item_data['categories'])
  end

  # The version counter advances once per restored item.
  version = instance.instance_variable_get(:@version)
  instance.instance_variable_set(:@version, version + restored.size)

  # Restore the saved auto_rebuild setting, then build the index.
  instance.auto_rebuild = data['auto_rebuild']
  instance.build_index
  instance
end
423+
424+
# Writes the index's JSON representation to the file at +path+ and returns
# the result of File.write.
#
# @rbs (String) -> Integer
def save(path)
  payload = to_json
  File.write(path, payload)
end
430+
431+
# Loads an LSI index from a file previously written by #save.
#
# The file is read as UTF-8 explicitly: JSON text is UTF-8, and relying on
# Encoding.default_external breaks non-ASCII words under locales such as
# LANG=C.
#
# @rbs (String) -> LSI
def self.load(path)
  from_json(File.read(path, encoding: 'UTF-8'))
end
437+
368438
private
369439

370440
# Assigns LSI vectors using native C extension

sig/vendor/json.rbs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Minimal RBS declarations for the JSON module, covering only the two
# singleton methods declared below.
module JSON
2+
# Parses a JSON source string; the result type is left untyped.
def self.parse: (String source, ?symbolize_names: bool) -> untyped
3+
# Produces a JSON String from an arbitrary object.
def self.generate: (untyped obj) -> String
4+
end

test/bayes/bayesian_test.rb

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,4 +435,112 @@ def test_laplace_smoothing_denominator_includes_vocabulary
435435
refute_in_delta scores1['Spam'], scores2['Spam'], 0.1,
436436
'Vocabulary size should affect word probabilities in denominator'
437437
end
438+
439+
# Save/Load tests
440+
441+
# as_json returns a symbol-keyed Hash carrying version, type and the
# trained category names as strings.
def test_as_json
  @classifier.train_interesting 'good words here'
  @classifier.train_uninteresting 'bad words there'

  snapshot = @classifier.as_json

  assert_instance_of Hash, snapshot
  assert_equal 1, snapshot[:version]
  assert_equal 'bayes', snapshot[:type]
  %w[Interesting Uninteresting].each do |category|
    assert_includes snapshot[:categories].keys, category
  end
end
453+
454+
# to_json produces a JSON document that parses back to the expected
# version, type and category names.
def test_to_json
  @classifier.train_interesting 'good words here'
  @classifier.train_uninteresting 'bad words there'

  parsed = JSON.parse(@classifier.to_json)

  assert_equal 1, parsed['version']
  assert_equal 'bayes', parsed['type']
  %w[Interesting Uninteresting].each do |category|
    assert_includes parsed['categories'].keys, category
  end
end
466+
467+
# A classifier rebuilt from its JSON string has the same categories and
# classifies the sample phrases the same way as the original.
def test_from_json_with_string
  @classifier.train_interesting 'good words here'
  @classifier.train_uninteresting 'bad words there'

  restored = Classifier::Bayes.from_json(@classifier.to_json)

  assert_equal @classifier.categories.sort, restored.categories.sort
  ['good words', 'bad words'].each do |phrase|
    assert_equal @classifier.classify(phrase), restored.classify(phrase)
  end
end
478+
479+
# from_json also accepts an already-parsed Hash instead of a JSON string.
def test_from_json_with_hash
  @classifier.train_interesting 'good words here'
  @classifier.train_uninteresting 'bad words there'

  # Parse the JSON output to obtain a string-keyed Hash, as JSON.parse returns.
  parsed = JSON.parse(@classifier.to_json)
  restored = Classifier::Bayes.from_json(parsed)

  assert_equal @classifier.categories.sort, restored.categories.sort
  ['good words', 'bad words'].each do |phrase|
    assert_equal @classifier.classify(phrase), restored.classify(phrase)
  end
end
491+
492+
# A payload whose type is not 'bayes' is rejected with ArgumentError.
def test_from_json_invalid_type
  payload = JSON.generate({ version: 1, type: 'invalid' })

  assert_raises(ArgumentError) do
    Classifier::Bayes.from_json(payload)
  end
end
497+
498+
# save creates the file, and load returns a classifier with the same
# categories and the same classification for a sample word.
def test_save_and_load
  @classifier.train_interesting 'good words here'
  @classifier.train_uninteresting 'bad words there'

  Dir.mktmpdir do |tmp|
    model_path = File.join(tmp, 'classifier.json')
    @classifier.save(model_path)

    assert_path_exists model_path, 'Save should create file'

    restored = Classifier::Bayes.load(model_path)

    assert_equal @classifier.categories.sort, restored.categories.sort
    assert_equal @classifier.classify('good'), restored.classify('good')
  end
end
514+
515+
# After a save/load round-trip, classifications for sample words equal those
# of the original classifier.
def test_save_load_preserves_training_state
  @classifier.train_interesting 'apple banana cherry'
  @classifier.train_uninteresting 'dog elephant fox'

  Dir.mktmpdir do |tmp|
    model_path = File.join(tmp, 'classifier.json')
    @classifier.save(model_path)
    restored = Classifier::Bayes.load(model_path)

    # Scores from the original and the restored classifier must match.
    %w[apple dog].each do |word|
      assert_equal @classifier.classifications(word), restored.classifications(word)
    end
  end
end
529+
530+
# A classifier restored from disk accepts further training and classifies
# according to the combined training data.
def test_loaded_classifier_can_continue_training
  @classifier.train_interesting 'initial training'

  Dir.mktmpdir do |tmp|
    model_path = File.join(tmp, 'classifier.json')
    @classifier.save(model_path)
    restored = Classifier::Bayes.load(model_path)

    # Train the restored classifier further.
    restored.train_interesting 'more interesting content'
    restored.train_uninteresting 'boring content here'

    assert_equal 'Interesting', restored.classify('interesting content')
    assert_equal 'Uninteresting', restored.classify('boring content')
  end
end
438546
end

0 commit comments

Comments
 (0)