1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
|
# ----------------------------------------------------------------------------
# Copyright (c) 2016-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------
import unittest
import json
import tempfile
import tarfile
import os
import shutil
import sklearn
import joblib
from sklearn.pipeline import Pipeline
from qiime2.sdk import Artifact
from qiime2.plugins.feature_classifier.methods import \
fit_classifier_naive_bayes
from .._taxonomic_classifier import (
TaxonomicClassifierDirFmt, TaxonomicClassifier,
TaxonomicClassiferTemporaryPickleDirFmt, PickleFormat)
from . import FeatureClassifierTestPluginBase
class TaxonomicClassifierTestBase(FeatureClassifierTestPluginBase):
    """Shared fixture: a small fitted naive Bayes classifier on disk.

    ``setUp`` trains a classifier on the bundled test data and exposes the
    filesystem path of its pickled scikit-learn pipeline (as a string) in
    ``self.sklearn_pipeline``.
    """

    package = 'q2_feature_classifier.tests'

    def setUp(self):
        super().setUp()

        seqs_fp = self.get_data_path('se-dna-sequences.fasta')
        reads = Artifact.import_data('FeatureData[Sequence]', seqs_fp)
        taxa_fp = self.get_data_path('taxonomy.tsv')
        taxonomy = Artifact.import_data('FeatureData[Taxonomy]', taxa_fp)

        # The fitted result's ``classifier`` output is viewed as a
        # scikit-learn Pipeline and then converted to the on-disk format.
        fitted = fit_classifier_naive_bayes(reads, taxonomy)
        pipeline = fitted.classifier.view(Pipeline)

        to_dir_fmt = self.get_transformer(
            Pipeline, TaxonomicClassiferTemporaryPickleDirFmt)
        # Kept on the instance, presumably so the temporary files backing
        # ``self.sklearn_pipeline`` stay alive for the duration of a test.
        self._sklp = to_dir_fmt(pipeline)
        pickle_fmt = self._sklp.sklearn_pipeline.view(PickleFormat)
        self.sklearn_pipeline = str(pickle_fmt)

    def _custom_setup(self, version):
        """Build a classifier directory in ``self.temp_dir``.

        Writes ``sklearn_version.json`` containing
        ``{"sklearn-version": version}`` and copies the pickled pipeline
        next to it.

        Parameters
        ----------
        version : str
            scikit-learn version string to record in the directory.

        Returns
        -------
        TaxonomicClassiferTemporaryPickleDirFmt
            The directory format opened in read mode.
        """
        # NOTE(review): ``self.temp_dir`` is assumed to be created by
        # FeatureClassifierTestPluginBase.setUp — confirm.
        target_dir = self.temp_dir.name
        version_fp = os.path.join(target_dir, 'sklearn_version.json')
        with open(version_fp, 'w') as fh:
            json.dump({'sklearn-version': version}, fh)
        shutil.copy(self.sklearn_pipeline, target_dir)
        return TaxonomicClassiferTemporaryPickleDirFmt(target_dir, mode='r')
class TestTypes(FeatureClassifierTestPluginBase):
    """Check that the TaxonomicClassifier semantic type is registered."""

    def test_taxonomic_classifier_semantic_type_registration(self):
        semantic_type = TaxonomicClassifier
        self.assertRegisteredSemanticType(semantic_type)

    def test_taxonomic_classifier_semantic_type_to_format_registration(self):
        # The semantic type must be bound to the temporary-pickle directory
        # format.
        semantic_type, dir_format = (
            TaxonomicClassifier, TaxonomicClassiferTemporaryPickleDirFmt)
        self.assertSemanticTypeRegisteredToFormat(semantic_type, dir_format)
class TestFormats(TaxonomicClassifierTestBase):
    """Validation of the classifier's on-disk directory format."""

    def test_taxonomic_classifier_dir_fmt(self):
        dir_fmt = self._custom_setup(sklearn.__version__)
        # The directory records the installed scikit-learn version, so
        # validation is expected to succeed without raising.
        dir_fmt.validate()
class TestTransformers(TaxonomicClassifierTestBase):
    """Transformers between scikit-learn ``Pipeline`` and classifier formats."""

    @staticmethod
    def _read_pipeline(pipeline_filepath):
        """Load the ``Pipeline`` stored in the tar archive at the given path.

        The archive is expected to contain ``sklearn_pipeline.pkl`` (plus
        any companion files joblib wrote). It is unpacked into a temporary
        directory that is removed on exit, even if loading fails.

        Parameters
        ----------
        pipeline_filepath : str
            Path to the tar archive backing a ``PickleFormat``.

        Returns
        -------
        sklearn.pipeline.Pipeline
            The unpickled pipeline.
        """
        with tarfile.open(pipeline_filepath) as tar, \
                tempfile.TemporaryDirectory() as dirname:
            # Python 3.12+ warns when extractall() is called without a
            # filter; use the 'data' filter wherever it is supported.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(dirname, filter='data')
            else:
                tar.extractall(dirname)
            return joblib.load(
                os.path.join(dirname, 'sklearn_pipeline.pkl'))

    def test_old_sklearn_version(self):
        transformer = self.get_transformer(
            TaxonomicClassiferTemporaryPickleDirFmt, Pipeline)
        dir_fmt = self._custom_setup('a very old version')
        # A recorded version that is not the installed scikit-learn version
        # must be rejected with ValueError.
        with self.assertRaises(ValueError):
            transformer(dir_fmt)

    def test_old_dirfmt(self):
        transformer = self.get_transformer(TaxonomicClassifierDirFmt, Pipeline)
        with open(os.path.join(self.temp_dir.name,
                               'preprocess_params.json'), 'w') as fh:
            fh.write(json.dumps([]))
        shutil.copy(self.sklearn_pipeline, self.temp_dir.name)
        dir_fmt = TaxonomicClassifierDirFmt(self.temp_dir.name, mode='r')
        # Loading the old-style directory format is expected to fail.
        with self.assertRaises(ValueError):
            transformer(dir_fmt)

    def test_taxo_class_dir_fmt_to_taxo_class_result(self):
        dir_fmt = self._custom_setup(sklearn.__version__)
        transformer = self.get_transformer(
            TaxonomicClassiferTemporaryPickleDirFmt, Pipeline)
        obs = transformer(dir_fmt)
        self.assertIsInstance(obs, Pipeline)
        self.assertTrue(obs.steps)

    def test_taxo_class_result_to_taxo_class_dir_fmt(self):
        exp = self._read_pipeline(self.sklearn_pipeline)
        transformer = self.get_transformer(
            Pipeline, TaxonomicClassiferTemporaryPickleDirFmt)
        # Keep ``dir_fmt`` referenced until the pickle has been read back,
        # since its files back ``pickle_fmt``.
        dir_fmt = transformer(exp)
        pickle_fmt = dir_fmt.sklearn_pipeline.view(PickleFormat)
        obs = self._read_pipeline(str(pickle_fmt))
        self.assertIsInstance(obs, Pipeline)
        # The round trip must preserve the pipeline's step names.
        self.assertEqual([name for name, _ in obs.steps],
                         [name for name, _ in exp.steps])
if __name__ == "__main__":
    # Allow this test module to be run directly as a script.
    unittest.main()
|