1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
|
# ----------------------------------------------------------------------------
# Copyright (c) 2016-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------
import json
import tarfile
import os
import sklearn
import joblib
from sklearn.pipeline import Pipeline
import qiime2.plugin
import qiime2.plugin.model as model
from .plugin_setup import plugin
# Semantic Types
TaxonomicClassifier = qiime2.plugin.SemanticType('TaxonomicClassifier')
# Formats
class PickleFormat(model.BinaryFileFormat):
def sniff(self):
return tarfile.is_tarfile(str(self))
# https://github.com/qiime2/q2-types/issues/49
class JSONFormat(model.TextFileFormat):
def sniff(self):
with self.open() as fh:
try:
json.load(fh)
return True
except json.JSONDecodeError:
pass
return False
class TaxonomicClassifierDirFmt(model.DirectoryFormat):
preprocess_params = model.File('preprocess_params.json', format=JSONFormat)
sklearn_pipeline = model.File('sklearn_pipeline.tar', format=PickleFormat)
class TaxonomicClassiferTemporaryPickleDirFmt(model.DirectoryFormat):
version_info = model.File('sklearn_version.json', format=JSONFormat)
sklearn_pipeline = model.File('sklearn_pipeline.tar', format=PickleFormat)
# Transformers
@plugin.register_transformer
def _1(dirfmt: TaxonomicClassiferTemporaryPickleDirFmt) -> Pipeline:
sklearn_version = dirfmt.version_info.view(dict)['sklearn-version']
if sklearn_version != sklearn.__version__:
raise ValueError('The scikit-learn version (%s) used to generate this'
' artifact does not match the current version'
' of scikit-learn installed (%s). Please retrain your'
' classifier for your current deployment to prevent'
' data-corruption errors.'
% (sklearn_version, sklearn.__version__))
sklearn_pipeline = dirfmt.sklearn_pipeline.view(PickleFormat)
with tarfile.open(str(sklearn_pipeline)) as tar:
tmpdir = model.DirectoryFormat()
dirname = str(tmpdir)
tar.extractall(dirname)
pipeline = joblib.load(os.path.join(dirname, 'sklearn_pipeline.pkl'))
for fn in tar.getnames():
os.unlink(os.path.join(dirname, fn))
return pipeline
@plugin.register_transformer
def _2(data: Pipeline) -> TaxonomicClassiferTemporaryPickleDirFmt:
sklearn_pipeline = PickleFormat()
with tarfile.open(str(sklearn_pipeline), 'w') as tar:
tmpdir = model.DirectoryFormat()
pf = os.path.join(str(tmpdir), 'sklearn_pipeline.pkl')
for fn in joblib.dump(data, pf):
tar.add(fn, os.path.basename(fn))
os.unlink(fn)
dirfmt = TaxonomicClassiferTemporaryPickleDirFmt()
dirfmt.version_info.write_data(
{'sklearn-version': sklearn.__version__}, dict)
dirfmt.sklearn_pipeline.write_data(sklearn_pipeline, PickleFormat)
return dirfmt
@plugin.register_transformer
def _3(dirfmt: TaxonomicClassifierDirFmt) -> Pipeline:
raise ValueError('The scikit-learn version could not be determined for'
' this artifact, please retrain your classifier for your'
' current deployment to prevent data-corruption errors.')
@plugin.register_transformer
def _4(fmt: JSONFormat) -> dict:
with fmt.open() as fh:
return json.load(fh)
@plugin.register_transformer
def _5(data: dict) -> JSONFormat:
result = JSONFormat()
with result.open() as fh:
json.dump(data, fh)
return result
# Registrations
plugin.register_semantic_types(TaxonomicClassifier)
plugin.register_formats(TaxonomicClassifierDirFmt,
TaxonomicClassiferTemporaryPickleDirFmt)
plugin.register_semantic_type_to_format(
TaxonomicClassifier,
artifact_format=TaxonomicClassiferTemporaryPickleDirFmt)
|