File: add_classifications.py

package info (click to toggle)
pplacer 1.1~alpha19-8
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 17,324 kB
  • sloc: ml: 20,927; ansic: 9,002; python: 1,641; makefile: 171; sh: 77; xml: 50
file content (88 lines) | stat: -rwxr-xr-x 3,332 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python
import collections
import itertools
import argparse
import logging
import sqlite3
import csv
import sys

from taxtastic.taxtable import TaxNode
from taxtastic.refpkg import Refpkg

log = logging.getLogger(__name__)

def group_by_name_and_mass(rows):
    return itertools.groupby(rows, lambda row: (row['seqname'], int(row.get('mass', 1))))

def main():
    logging.basicConfig(
        level=logging.INFO, format="%(levelname)s: %(message)s")

    parser = argparse.ArgumentParser(
        description="Add classifications to a database.")
    parser.add_argument('refpkg', type=Refpkg,
                        help="refpkg containing input taxa")
    parser.add_argument('classification_db', type=sqlite3.connect,
                        help="output sqlite database")
    parser.add_argument('classifications', type=argparse.FileType('r'), nargs='?',
                        default=sys.stdin, help="input query sequences")

    args = parser.parse_args()

    log.info('loading taxonomy')
    taxtable = TaxNode.from_taxtable(args.refpkg.open_resource('taxonomy', 'rU'))
    rank_order = {rank: e for e, rank in enumerate(taxtable.ranks)}
    def full_lineage(node):
        rank_iter = reversed(taxtable.ranks)
        for n in reversed(node.lineage()):
            n_order = rank_order[n.rank]
            yield n, list(itertools.takewhile(lambda r: rank_order[r] >= n_order, rank_iter))

    def multiclass_rows(placement_id, seq, taxa):
        ret = collections.defaultdict(float)
        likelihood = 1. / len(taxa)
        for taxon in taxa:
            for node, want_ranks in full_lineage(taxtable.get_node(taxon)):
                for want_rank in want_ranks:
                    ret[placement_id, seq, want_rank, node.rank, node.tax_id] += likelihood
        for k, v in sorted(ret.items()):
            yield k + (v,)

    curs = args.classification_db.cursor()
    curs.execute('INSERT INTO runs (params) VALUES (?)', (' '.join(sys.argv),))
    run_id = curs.lastrowid

    log.info('inserting classifications')
    for (name, mass), rows in group_by_name_and_mass(csv.DictReader(args.classifications)):
        curs.execute('INSERT INTO placements (classifier, run_id) VALUES ("csv", ?)', (run_id,))
        placement_id = curs.lastrowid
        curs.execute(
            'INSERT INTO placement_names (placement_id, name, origin, mass) VALUES (?, ?, ?, ?)',
            (placement_id, name, args.classifications.name, mass))
        taxa = [row['tax_id'] for row in rows]
        curs.executemany('INSERT INTO multiclass VALUES (?, ?, ?, ?, ?, ?)',
                         multiclass_rows(placement_id, name, taxa))

    log.info('cleaning up `multiclass` table')
    curs.execute("""
        CREATE TEMPORARY TABLE duplicates AS
          SELECT name
            FROM multiclass
                 JOIN placements USING (placement_id)
           GROUP BY name
          HAVING SUM(run_id = ?)
                 AND COUNT(DISTINCT run_id) > 1
    """, (run_id,))
    curs.execute("""
        DELETE FROM multiclass
         WHERE (SELECT run_id
                  FROM placements p
                 WHERE p.placement_id = multiclass.placement_id) <> ?
           AND name IN (SELECT name
                          FROM duplicates)
    """, (run_id,))

    args.classification_db.commit()

main()