File: STUArray.hs

package info (click to toggle)
haskell-edit-distance 0.2.2.1-15
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 156 kB
  • sloc: haskell: 682; makefile: 3
file content (173 lines) | stat: -rw-r--r-- 8,700 bytes parent folder | download | duplicates (6)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}

module Text.EditDistance.STUArray (
        levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
    ) where

import Text.EditDistance.EditCosts
import Text.EditDistance.MonadUtilities
import Text.EditDistance.ArrayUtilities

import Control.Monad hiding (foldM)
import Control.Monad.ST
import Data.Array.ST


-- | Edit distance between @str1@ and @str2@ under the supplied 'EditCosts'.
--
-- Measures both strings with 'length' and hands everything over to
-- 'levenshteinDistanceWithLengths'.
levenshteinDistance :: EditCosts -> String -> String -> Int
levenshteinDistance !costs str1 str2 =
    let len1 = length str1
        len2 = length str2
    in levenshteinDistanceWithLengths costs len1 len2 str1 str2

-- | Same computation as 'levenshteinDistance', for callers that already know
-- the string lengths. The lengths are passed through unchecked; presumably
-- they must equal @length str1@ and @length str2@, which is what
-- 'levenshteinDistance' supplies.
levenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 = runST computation
  where
    -- The signature keeps the action polymorphic in its state thread, as 'runST' requires
    computation :: ST s Int
    computation = levenshteinDistanceST costs str1_len str2_len str1 str2

-- | 'ST' worker behind 'levenshteinDistanceWithLengths'.
--
-- Fills a cost matrix indexed by (i, j), where column i follows the characters of
-- @str1@ and row j follows the characters of @str2@, keeping only two rows alive:
-- the row being written and the row above it. The result is the value of cell
-- (str1_len, str2_len).
--
-- Every array access goes through the unchecked helpers and nothing here validates
-- the lengths. NOTE(review): callers are presumably required to pass lengths that
-- match the strings -- confirm against 'stringToArray'.
levenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
levenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Create string arrays (presumably 1-based, since characters are read at indices 1..len below)
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs for a single row. Say we index costs by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1. We can get away with just storing a single
    -- row of costs at a time, but we use two because it turns out to be faster
    start_cost_row  <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
    start_cost_row' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array

    -- Fill out the first row (j = 0)
    -- Cell (0, 0) is the cost of turning the empty prefix into the empty prefix. It is read as
    -- left_up when i = 1 below, so write it explicitly rather than depending on whatever
    -- newArray_ happens to leave in a fresh array.
    unsafeWriteArray start_cost_row 0 0
    -- The remaining cells accumulate the deletion cost of each successive character of str1
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        unsafeWriteArray start_cost_row i deletion_cost'
        return (i + 1, deletion_cost')

    -- Fill out the remaining rows (j >= 1)
    -- The fold state is (cost of cell 0 in the previous row, previous row, row to overwrite)
    (_, final_row, _) <- (\f -> foldM f (0, start_cost_row, start_cost_row') [1..str2_len]) $ \(!insertion_cost, !cost_row, !cost_row') !j -> do
        row_char <- read_str2 j

        -- Initialize the first element of the row (i = 0)
        let insertion_cost' = insertion_cost + insertionCost costs row_char
        unsafeWriteArray cost_row' 0 insertion_cost'

        -- Fill the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i

            left_up <- unsafeReadArray cost_row  (i - 1)
            left    <- unsafeReadArray cost_row' (i - 1)
            here_up <- unsafeReadArray cost_row i
            let here = standardCosts costs row_char col_char left left_up here_up
            unsafeWriteArray cost_row' i here

        -- Swap the two rows: the one just written becomes "previous" for the next j
        return (insertion_cost', cost_row', cost_row)


    -- Return an actual answer: the last cell of the most recently written row
    unsafeReadArray final_row str1_len

-- | Restricted Damerau-Levenshtein distance between @str1@ and @str2@ under the
-- supplied 'EditCosts'.
--
-- Measures both strings with 'length' and hands everything over to
-- 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: EditCosts -> String -> String -> Int
restrictedDamerauLevenshteinDistance !costs str1 str2 =
    let len1 = length str1
        len2 = length str2
    in restrictedDamerauLevenshteinDistanceWithLengths costs len1 len2 str1 str2

-- | Same computation as 'restrictedDamerauLevenshteinDistance', for callers that
-- already know the string lengths. The lengths are passed through unchecked;
-- presumably they must equal @length str1@ and @length str2@, which is what
-- 'restrictedDamerauLevenshteinDistance' supplies.
restrictedDamerauLevenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 = runST computation
  where
    -- The signature keeps the action polymorphic in its state thread, as 'runST' requires
    computation :: ST s Int
    computation = restrictedDamerauLevenshteinDistanceST costs str1_len str2_len str1 str2

-- | 'ST' worker behind 'restrictedDamerauLevenshteinDistanceWithLengths'.
--
-- Fills a cost matrix indexed by (i, j), where column i follows the characters of
-- @str1@ and row j follows the characters of @str2@. Row 0 and row 1 are computed
-- here using only the standard insert/delete/substitute moves; rows 2 and up are
-- delegated to 'restrictedDamerauLevenshteinDistanceSTRowWorker', which also
-- considers a transposition. The result is the value of cell (str1_len, str2_len).
--
-- Every array access goes through the unchecked helpers and nothing here validates
-- the lengths. NOTE(review): callers are presumably required to pass lengths that
-- match the strings -- confirm against 'stringToArray'.
restrictedDamerauLevenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
restrictedDamerauLevenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Create string arrays (presumably 1-based, since characters are read at indices 1..len below)
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs for a single row. Say we index costs by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1. We can get away with just storing two
    -- rows of costs at a time, but I use three because it turns out to be faster
    cost_row <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array

    -- Fill out the first row (j = 0)
    -- Cell (0, 0) is the cost of turning the empty prefix into the empty prefix. Later rows read it
    -- as a diagonal neighbour, so write it explicitly rather than depending on whatever
    -- newArray_ happens to leave in a fresh array.
    unsafeWriteArray cost_row 0 0
    -- The remaining cells accumulate the deletion cost of each successive character of str1
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        unsafeWriteArray cost_row i deletion_cost'
        return (i + 1, deletion_cost')

    if str2_len == 0
      -- With an empty str2 the first row already holds the answer
      then unsafeReadArray cost_row str1_len
      else do
        -- We defer allocation of these arrays to here because they aren't used in the other branch
        cost_row'  <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
        cost_row'' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

        -- Fill out the second row (j = 1)
        row_char <- read_str2 1

        -- Initialize the first element of the row (i = 0)
        let zero = insertionCost costs row_char
        unsafeWriteArray cost_row' 0 zero

        -- Fill the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len (firstRowColWorker read_str1 row_char cost_row cost_row')

        -- Fill out the remaining rows (j >= 2)
        -- The fold state carries the rows in oldest-first order plus the previous row's character;
        -- the middle array of the final state is the most recently completed row.
        (_, _, final_row, _, _) <- foldM (restrictedDamerauLevenshteinDistanceSTRowWorker costs str1_len read_str1 read_str2) (zero, cost_row, cost_row', cost_row'', row_char) [2..str2_len]

        -- Return an actual answer
        unsafeReadArray final_row str1_len
  where
    -- Computes one cell of row j = 1 from its left, upper-left and upper neighbours.
    -- Only the standard moves are considered here.
    {-# INLINE firstRowColWorker #-}
    firstRowColWorker read_str1 !row_char !cost_row !cost_row' !i = do
        col_char <- read_str1 i

        left_up <- unsafeReadArray cost_row  (i - 1)
        left    <- unsafeReadArray cost_row' (i - 1)
        here_up <- unsafeReadArray cost_row  i
        let here = standardCosts costs row_char col_char left left_up here_up
        unsafeWriteArray cost_row' i here

{-# INLINE restrictedDamerauLevenshteinDistanceSTRowWorker #-}
-- | Fold step that writes row @j@ of the cost matrix.
--
-- The incoming state is (cost of cell 0 in row j - 1, row j - 2, row j - 1, row to
-- overwrite, character of row j - 1). The outgoing state has the same shape, moved
-- on by one row: the row just written sits in the middle and the former row j - 2
-- becomes the next row to overwrite.
restrictedDamerauLevenshteinDistanceSTRowWorker
    :: EditCosts
    -> Int                                     -- Index of the last column
    -> (Int -> ST s Char)                      -- Character accessor used for columns
    -> (Int -> ST s Char)                      -- Character accessor used for rows
    -> (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char)
    -> Int                                     -- Index j of the row to compute
    -> ST s (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char)
restrictedDamerauLevenshteinDistanceSTRowWorker !costs !str1_len read_str1 read_str2 (!col0_above, !row_jm2, !row_jm1, !row_j, !char_jm1) !j = do
    char_j <- read_str2 j

    -- Column 0: one more insertion on top of the cell directly above
    diag0 <- unsafeReadArray row_jm1 0
    let col0_here = col0_above + insertionCost costs char_j
    unsafeWriteArray row_j 0 col0_here

    when (str1_len > 0) $ do
        -- Column 1 has no column i - 2 to look back at, so only the standard moves apply
        first_col_char <- read_str1 1
        above1 <- unsafeReadArray row_jm1 1
        unsafeWriteArray row_j 1 (standardCosts costs char_j first_col_char col0_here diag0 above1)

        -- Columns 2 and up may additionally use a transposition
        loopM_ 2 str1_len (fill_cell char_j)

    return (col0_here, row_jm1, row_j, row_jm2, char_j)
  where
    -- Writes cell (i, j) for i >= 2
    fill_cell !this_row_char !i = do
        col_char_before <- read_str1 (i - 1)
        col_char        <- read_str1 i

        diag2 <- unsafeReadArray row_jm2 (i - 2)
        diag  <- unsafeReadArray row_jm1 (i - 1)
        west  <- unsafeReadArray row_j   (i - 1)
        north <- unsafeReadArray row_jm1 i
        let plain   = standardCosts costs this_row_char col_char west diag north
            swapped = diag2 + transpositionCost costs col_char this_row_char
            -- The last two characters on each axis are each other's mirror image
            is_swap = char_jm1 == col_char && col_char_before == this_row_char

        unsafeWriteArray row_j i (if is_swap then min plain swapped else plain)


{-# INLINE standardCosts #-}
-- | Cheapest cost of a cell when it is reached from its left neighbour (a deletion
-- of @col_char@), its upper neighbour (an insertion of @row_char@) or its upper-left
-- neighbour (free when the two characters are equal, otherwise a substitution).
standardCosts :: EditCosts -> Char -> Char -> Int -> Int -> Int -> Int
standardCosts !costs !row_char !col_char !cost_left !cost_left_up !cost_up =
    min (min via_left via_up) via_diagonal
  where
    via_left = cost_left + deletionCost costs col_char
    via_up   = cost_up + insertionCost costs row_char
    via_diagonal
      | row_char == col_char = cost_left_up
      | otherwise            = cost_left_up + substitutionCost costs col_char row_char