File: ngramcount.pyx

package info (click to toggle)
python-bx 0.13.0-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 5,000 kB
  • sloc: python: 17,136; ansic: 2,326; makefile: 24; sh: 8
file content (85 lines) | stat: -rw-r--r-- 3,818 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""
Tools for counting words (n-grams) in integer sequences.
"""

import numpy


# Py_intptr_t is only declared here so that Cython knows it is an integral
# type; the generated C code uses the real definition from Python.h, so the
# `int` below does not have to match its actual width.
cdef extern from "Python.h":
    ctypedef int Py_intptr_t

# NpyCapsule_AsVoidPtr would normally be declared via
#     cdef extern from "numpy/npy_3kcompat.h":
# NOTE: including npy_3kcompat.h did not compile, so the function is taken
#       from a separate header into which it was explicitly extracted.
# `except NULL` tells Cython that a NULL return value signals a raised
# Python exception, which is then propagated to the caller.
cdef extern from "npy_capsule_as_void_ptr.h":
    void * NpyCapsule_AsVoidPtr(object) except NULL

# Bit masks for the `flags` field of PyArrayInterface (same values as the
# corresponding flags of the NumPy array interface):
CONTIGUOUS=0x01   # data is C-style contiguous
FORTRAN=0x02      # data is Fortran-style contiguous
ALIGNED=0x100     # data is properly aligned for its type
NOTSWAPPED=0x200  # data is in native byte order
WRITEABLE=0x400   # data may be written to

# Struct reached through the capsule returned by an object's
# `__array_struct__` attribute (see count_ngrams).
# NOTE(review): this is declared by Cython itself rather than taken from a
# NumPy header, so the field order and types must keep matching the leading
# fields of NumPy's own PyArrayInterface -- confirm when upgrading NumPy.
ctypedef struct PyArrayInterface:
    int two              # contains the integer 2 as a sanity check
    int nd               # number of dimensions
    char typekind        # kind in array --- character code of typestr
    int itemsize         # size of each element
    int flags            # flags indicating how the data should be interpreted
    Py_intptr_t *shape   # A length-nd array of shape information
    Py_intptr_t *strides # A length-nd array of stride information
    void *data           # A pointer to the first element of the array
    
def count_ngrams( object ints, int n, int radix ):
    """
    Count the number of occurrences of each possible length `n` word in
    `ints`.

    `ints` must expose `__array_struct__` and be a 1d, contiguous, aligned,
    native byte order array of 32bit integers. Letters are expected to lie
    in the range 0 <= letter < `radix`; words containing a letter outside
    that range are not counted.

    Returns a numpy int32 array of length `radix` ** `n` containing the
    counts. A word is stored at the index obtained by reading it as a
    base-`radix` number whose first letter is the least significant digit.

    Raises AssertionError if `ints` does not satisfy the requirements
    above, and ValueError if `n` is negative or `radix` ** `n` is not a
    usable table size.
    """
    cdef PyArrayInterface * ints_desc
    cdef PyArrayInterface * rval_desc
    # Python object copy of radix so that the power below is computed with
    # exact (arbitrary precision) integer arithmetic
    cdef object py_radix = radix
    # Get array interface for input string and validate. The capsule object
    # is kept in a variable for as long as the struct pointer is in use,
    # since the pointer is only borrowed from the capsule.
    ints_desc_obj = ints.__array_struct__
    ints_desc = <PyArrayInterface *> NpyCapsule_AsVoidPtr( ints_desc_obj )
    assert ints_desc.two == 2, "Array interface sanity check failed, got %d" % ints_desc.two
    assert ints_desc.nd == 1, "Input array must be 1d"
    assert ints_desc.typekind == 'i'[0], "Input array must contain integers"
    assert ints_desc.itemsize == 4, "Input array must contain 32bit integers"
    assert ints_desc.flags & CONTIGUOUS > 0, "Input array must be contiguous"
    assert ints_desc.flags & ALIGNED > 0, "Input array must be aligned"
    assert ints_desc.flags & NOTSWAPPED > 0, "Input array must not be byteswapped"
    # Work out the size of the count table. This must be exact: a table
    # that is even one element too short lets the counting loop write past
    # its end, and single precision floats cannot represent every integer
    # above 2 ** 24.
    if n < 0:
        # A negative length would give an empty table that is still
        # written to by the counting loop
        raise ValueError( "n must be non-negative, got %d" % n )
    if n > 31 and ( radix > 1 or radix < -1 ):
        # Certainly too large; bail out before building a huge integer
        raise ValueError( "radix ** n is too large to index with a 32bit integer" )
    table_size = py_radix ** n
    if table_size > 0x7FFFFFFF:
        # Word indexes are computed as C ints in the counting loop
        raise ValueError( "radix ** n is too large to index with a 32bit integer" )
    # Create numpy array for return value, get array interface and validate
    # (a negative table size is rejected by numpy with a ValueError)
    rval = numpy.zeros( table_size, dtype=numpy.int32 )
    rval_desc_obj = rval.__array_struct__
    rval_desc = <PyArrayInterface *> NpyCapsule_AsVoidPtr( rval_desc_obj )
    assert rval_desc.two == 2, "Array interface sanity check failed"
    assert rval_desc.nd == 1, "Result array must be 1d"
    assert rval_desc.typekind == 'i'[0], "Result array must contain integers"
    assert rval_desc.itemsize == 4, "Result array must contain 32bit integers"
    assert rval_desc.flags & CONTIGUOUS > 0, "Result array must be contiguous"
    assert rval_desc.flags & ALIGNED > 0, "Result array must be aligned"
    assert rval_desc.flags & NOTSWAPPED > 0, "Result array must not be byteswapped"
    # Do it
    _count_ngrams( <int*> ints_desc.data, ints_desc.shape[0], <int*> rval_desc.data, n, radix )
    return rval
    
cdef _count_ngrams( int* ints, int n_ints, int* rval, int n, int radix ):
    # Count every length `n` word of the `n_ints` letters in `ints` into
    # `rval`.
    #
    # Each word is turned into an index by reading it as a base-`radix`
    # number whose first letter is the least significant digit, and the
    # counter at that index is incremented. Words containing a letter
    # outside 0 <= letter < radix are skipped.
    #
    # The caller must ensure that `n` is non-negative and that `rval` holds
    # at least radix ** n counters; nothing is bounds checked here.
    cdef int i, j, index, factor, letter
    # Loop over each word in the string: there are n_ints - n + 1 start
    # positions (and none at all when n > n_ints, as the range is empty)
    for i in range( n_ints - n + 1 ):
        # Walk forward through the word to build index into count array
        index = 0
        factor = 1
        for j in range( n ):
            letter = ints[ i + j ]
            if letter < 0 or letter >= radix:
                # This word is bad, break out and do not increment counts
                break
            index = index + letter * factor
            factor = factor * radix
        else:
            # Only reached when the inner loop was not broken out of,
            # i.e. every letter of the word was valid
            rval[ index ] = rval[ index ] + 1