File: _taxonomic_classifier.py

package info (click to toggle)
q2-feature-classifier 2024.2.0-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 1,504 kB
  • sloc: python: 3,730; makefile: 38; sh: 16
file content (124 lines) | stat: -rw-r--r-- 4,137 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# ----------------------------------------------------------------------------
# Copyright (c) 2016-2023, QIIME 2 development team.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------

import json
import tarfile
import os

import sklearn
import joblib
from sklearn.pipeline import Pipeline
import qiime2.plugin
import qiime2.plugin.model as model

from .plugin_setup import plugin


# Semantic Types
# Semantic type named 'TaxonomicClassifier'; it is registered with the plugin
# and bound to a directory format at the bottom of this module.
TaxonomicClassifier = qiime2.plugin.SemanticType('TaxonomicClassifier')


# Formats
class PickleFormat(model.BinaryFileFormat):
    """Binary file format that is accepted when the file is a tar archive."""

    def sniff(self):
        """Return whether ``tarfile`` recognizes this file as a tar archive."""
        archive_path = str(self)
        return tarfile.is_tarfile(archive_path)


# https://github.com/qiime2/q2-types/issues/49
class JSONFormat(model.TextFileFormat):
    """Text file format whose content must parse as a JSON document."""

    def sniff(self):
        """Report whether the file's content parses as JSON.

        Returns
        -------
        bool
            ``True`` when ``json.load`` accepts the content; ``False`` when
            the content is not valid JSON or cannot be decoded as text.
        """
        with self.open() as fh:
            try:
                json.load(fh)
                return True
            except (json.JSONDecodeError, UnicodeDecodeError):
                # Undecodable bytes mean "not JSON": answer False rather than
                # letting a decoding error escape from the sniffer.
                pass
        return False


class TaxonomicClassifierDirFmt(model.DirectoryFormat):
    """Directory format holding preprocessing parameters and a pipeline.

    This layout carries no scikit-learn version file.
    """

    # JSON file stored as 'preprocess_params.json'.
    preprocess_params = model.File(
        'preprocess_params.json',
        format=JSONFormat,
    )
    # Tar archive stored as 'sklearn_pipeline.tar'.
    sklearn_pipeline = model.File(
        'sklearn_pipeline.tar',
        format=PickleFormat,
    )


class TaxonomicClassiferTemporaryPickleDirFmt(model.DirectoryFormat):
    """Directory format pairing a pipeline archive with version metadata."""

    # JSON file stored as 'sklearn_version.json'; the transformers in this
    # module read and write its 'sklearn-version' key.
    version_info = model.File(
        'sklearn_version.json',
        format=JSONFormat,
    )
    # Tar archive stored as 'sklearn_pipeline.tar'.
    sklearn_pipeline = model.File(
        'sklearn_pipeline.tar',
        format=PickleFormat,
    )


# Transformers
@plugin.register_transformer
def _1(dirfmt: TaxonomicClassiferTemporaryPickleDirFmt) -> Pipeline:
    """Load the scikit-learn pipeline stored in a classifier directory.

    The pipeline is only loaded when the scikit-learn version recorded in
    the directory's version file equals the installed version.

    Parameters
    ----------
    dirfmt : TaxonomicClassiferTemporaryPickleDirFmt
        Directory holding the version file and the tarred pipeline.

    Returns
    -------
    Pipeline
        The object loaded from ``sklearn_pipeline.pkl`` inside the archive.

    Raises
    ------
    ValueError
        If the recorded scikit-learn version differs from
        ``sklearn.__version__``.
    """
    sklearn_version = dirfmt.version_info.view(dict)['sklearn-version']
    if sklearn_version != sklearn.__version__:
        raise ValueError('The scikit-learn version (%s) used to generate this'
                         ' artifact does not match the current version'
                         ' of scikit-learn installed (%s). Please retrain your'
                         ' classifier for your current deployment to prevent'
                         ' data-corruption errors.'
                         % (sklearn_version, sklearn.__version__))

    sklearn_pipeline = dirfmt.sklearn_pipeline.view(PickleFormat)

    # SECURITY NOTE: ``extractall`` trusts the member paths in the archive
    # and ``joblib.load`` is pickle-based, so loading can run arbitrary code.
    # Only use classifier artifacts that come from a trusted source.
    with tarfile.open(str(sklearn_pipeline)) as tar:
        tmpdir = model.DirectoryFormat()
        dirname = str(tmpdir)
        tar.extractall(dirname)
        try:
            pipeline = joblib.load(
                os.path.join(dirname, 'sklearn_pipeline.pkl'))
        finally:
            # Remove the extracted members even when loading fails, so the
            # unpacked pickle data does not linger in the scratch directory.
            for fn in tar.getnames():
                os.unlink(os.path.join(dirname, fn))

    return pipeline


@plugin.register_transformer
def _2(data: Pipeline) -> TaxonomicClassiferTemporaryPickleDirFmt:
    """Store a scikit-learn pipeline as a version-stamped directory.

    The pipeline is dumped with joblib, every file joblib produced is added
    to a tar archive under its base name, and the installed scikit-learn
    version is written next to the archive.
    """
    archive = PickleFormat()
    with tarfile.open(str(archive), 'w') as tar:
        scratch = model.DirectoryFormat()
        pickle_path = os.path.join(str(scratch), 'sklearn_pipeline.pkl')
        dumped_files = joblib.dump(data, pickle_path)
        for dumped in dumped_files:
            # Store flat (base name only), then drop the on-disk copy.
            tar.add(dumped, arcname=os.path.basename(dumped))
            os.unlink(dumped)

    result = TaxonomicClassiferTemporaryPickleDirFmt()
    version_record = {'sklearn-version': sklearn.__version__}
    result.version_info.write_data(version_record, dict)
    result.sklearn_pipeline.write_data(archive, PickleFormat)

    return result


@plugin.register_transformer
def _3(dirfmt: TaxonomicClassifierDirFmt) -> Pipeline:
    """Refuse to load a pipeline from the directory format without a version.

    Raises
    ------
    ValueError
        Always; this format records no scikit-learn version to check.
    """
    message = ('The scikit-learn version could not be determined for'
               ' this artifact, please retrain your classifier for your'
               ' current deployment to prevent data-corruption errors.')
    raise ValueError(message)


@plugin.register_transformer
def _4(fmt: JSONFormat) -> dict:
    """Parse a ``JSONFormat`` file and return the decoded object."""
    with fmt.open() as handle:
        document = json.loads(handle.read())
    return document


@plugin.register_transformer
def _5(data: dict) -> JSONFormat:
    """Serialize ``data`` as JSON into a newly created ``JSONFormat``."""
    json_file = JSONFormat()
    with json_file.open() as handle:
        json.dump(data, handle)
    return json_file


# Registrations
# Make the semantic type known to the plugin.
plugin.register_semantic_types(TaxonomicClassifier)
# Register both directory formats defined in this module.
plugin.register_formats(TaxonomicClassifierDirFmt,
                        TaxonomicClassiferTemporaryPickleDirFmt)
# TaxonomicClassifier artifacts use the version-stamped pickle directory
# format as their artifact format.
plugin.register_semantic_type_to_format(
    TaxonomicClassifier,
    artifact_format=TaxonomicClassiferTemporaryPickleDirFmt)