1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
|
# encoding: utf-8
require File.dirname(__FILE__) + '/../test_helper'
require File.dirname(__FILE__) + '/../../lib/classifier-reborn/validators/classifier_validator'
require_relative '../data/test_data_loader'
# Cross-validation runs over the SMS data set. Each test reports classifier
# accuracy via ClassifierValidator.cross_validate rather than asserting a
# pass/fail threshold.
class ClassifierValidation < Minitest::Test
  # Minimal reporter that frames each validation's output with a centered
  # title banner and a dashed rule, and prints the total run time at the end.
  class ValidationReporter < Minitest::Reporters::BaseReporter
    REPORT_WIDTH = 80

    # Prints the suite name as a "# Suite" heading surrounded by blank lines.
    def before_suite(suite)
      puts
      puts "# #{suite}"
      puts
    end

    # Separates consecutive suites with a blank line.
    def after_suite(_suite)
      puts
    end

    # Prints a banner containing the test name without its "test_" prefix.
    def before_test(test)
      super
      validation_name = test.name.sub(/\Atest_/, '')
      puts " #{validation_name} ".center(REPORT_WIDTH, '=')
    end

    # Closes the test's output with a dashed rule and a blank line.
    def after_test(test)
      super
      puts '-' * REPORT_WIDTH
      puts
    end

    # Prints the overall elapsed time after the run completes.
    def report
      super
      puts format('Finished in %.5fs', total_time)
      puts
    end
  end

  # Route Minitest output through the reporter above.
  # NOTE(review): `use!` is presumably process-wide, affecting every test
  # loaded in the same run — confirm this is intended.
  Minitest::Reporters.use! ValidationReporter.new

  # Number of data lines used for each validation run.
  SAMPLE_SIZE = 5000

  # Loads the SMS data set and keeps the first SAMPLE_SIZE lines, each
  # stripped and split on tabs. Skips the test when too few lines exist.
  def setup
    data = TestDataLoader.sms_data
    if data.length < SAMPLE_SIZE
      TestDataLoader.report_insufficient_data(data.length, SAMPLE_SIZE)
      # No exception is in scope here, so give `skip` an explicit message.
      skip("Insufficient data: #{data.length} of #{SAMPLE_SIZE} samples available")
    end
    @sample_data = data.take(SAMPLE_SIZE).map { |line| line.strip.split("\t") }
  end

  # Cross-validates an in-memory Bayes classifier using
  # ClassifierValidator.cross_validate's default fold count.
  # NOTE(review): the test name says 10 folds — confirm that is the default.
  def test_bayes_classifier_10_fold_cross_validate_memory
    classifier = ClassifierReborn::Bayes.new
    ClassifierValidator.cross_validate(classifier, @sample_data)
  end

  # 3-fold cross-validation of a Bayes classifier backed by Redis.
  # Skipped when no Redis server can be reached.
  def test_bayes_classifier_3_fold_cross_validate_redis
    backend = ClassifierReborn::BayesRedisBackend.new
    # CONFIG SET save "" turns off RDB snapshotting on the connected server
    # so the validation run does not trigger persistence. Note this changes
    # the server's configuration, not just this connection.
    backend.instance_variable_get(:@redis).config(:set, 'save', '')
    classifier = ClassifierReborn::Bayes.new(backend: backend)
    ClassifierValidator.cross_validate(classifier, @sample_data, 3)
  rescue Redis::CannotConnectError => e
    puts 'Unable to connect to Redis server'
    skip(e.message)
  end

  # 5-fold cross-validation of the LSI classifier. It is skipped unless LSI
  # responds to every method listed in required_methods.
  def test_lsi_classifier_5_fold_cross_validate
    lsi = ClassifierReborn::LSI.new
    required_methods = %i[train classify categories]
    unless required_methods.all? { |name| lsi.respond_to?(name) }
      puts "TODO: LSI is not validatable until all of the #{required_methods} methods are implemented!"
      skip
    end
    ClassifierValidator.cross_validate(lsi, @sample_data, 5)
  end
end
|