1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
|
{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
module Text.EditDistance.STUArray (
levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
) where
import Text.EditDistance.EditCosts
import Text.EditDistance.MonadUtilities
import Text.EditDistance.ArrayUtilities
import Control.Monad hiding (foldM)
import Control.Monad.ST
import Data.Array.ST
-- | Levenshtein edit distance between @str1@ and @str2@, weighted by the
-- supplied 'EditCosts'.
--
-- This is a convenience entry point: it measures both strings with 'length'
-- and hands everything to 'levenshteinDistanceWithLengths'.
levenshteinDistance :: EditCosts -> String -> String -> Int
levenshteinDistance !costs str1 str2 =
    levenshteinDistanceWithLengths costs (length str1) (length str2) str1 str2
-- | Levenshtein edit distance for callers that already know the lengths of
-- both strings.
--
-- The two 'Int' arguments are passed through unchanged as the lengths of the
-- first and second string respectively; they are not re-checked against the
-- actual strings here. The computation itself runs in 'ST' via
-- 'levenshteinDistanceST' and is exposed as a pure result with 'runST'.
levenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !costs !len1 !len2 s1 s2 =
    runST (levenshteinDistanceST costs len1 len2 s1 s2)
-- | 'ST' implementation of the Levenshtein distance between @str1@ and @str2@.
--
-- The cost matrix is indexed by @(i, j)@, where the column index @i@ ranges
-- over the characters of @str1@ and the row index @j@ over the characters of
-- @str2@. Only two rows are kept: the previously completed one and the one
-- being filled, which swap roles after every row. The result is the last cell
-- (@i = str1_len@) of the last completed row.
--
-- @str1_len@ and @str2_len@ are used as given for array bounds and loop
-- limits, and all accesses go through the @unsafe*@ helpers, so callers must
-- pass the true lengths of @str1@ and @str2@.
levenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
levenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Create string arrays
    -- NOTE(review): characters are read back at indices 1..len below, so
    -- 'stringToArray' presumably builds 1-based arrays -- confirm in ArrayUtilities.
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs for a single row. Say we index costs by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1. We can get away with just storing a single
    -- row of costs at a time, but we use two because it turns out to be faster
    start_cost_row  <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
    start_cost_row' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array

    -- Fill out the first row (j = 0).
    -- The origin cell (0, 0) costs nothing. It is written explicitly because
    -- 'newArray_' makes no promise about the initial contents of the array,
    -- and this cell is read as the diagonal neighbour of cell (1, 1).
    unsafeWriteArray start_cost_row 0 0
    -- The remaining cells accumulate the cost of deleting each prefix of str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        unsafeWriteArray start_cost_row i deletion_cost'
        return (i + 1, deletion_cost')

    -- Fill out the remaining rows (j >= 1)
    (_, final_row, _) <- (\f -> foldM f (0, start_cost_row, start_cost_row') [1..str2_len]) $ \(!insertion_cost, !cost_row, !cost_row') !j -> do
        row_char <- read_str2 j

        -- Initialize the first element of the row (i = 0)
        let insertion_cost' = insertion_cost + insertionCost costs row_char
        unsafeWriteArray cost_row' 0 insertion_cost'

        -- Fill the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i

            left_up <- unsafeReadArray cost_row  (i - 1)
            left    <- unsafeReadArray cost_row' (i - 1)
            here_up <- unsafeReadArray cost_row  i
            let here = standardCosts costs row_char col_char left left_up here_up
            unsafeWriteArray cost_row' i here

        -- Swap the rows: the one just written becomes the "previous" row
        return (insertion_cost', cost_row', cost_row)

    -- Return an actual answer
    unsafeReadArray final_row str1_len
-- | Restricted Damerau-Levenshtein edit distance between @str1@ and @str2@,
-- weighted by the supplied 'EditCosts'.
--
-- This is a convenience entry point: it measures both strings with 'length'
-- and hands everything to 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: EditCosts -> String -> String -> Int
restrictedDamerauLevenshteinDistance !costs str1 str2 =
    restrictedDamerauLevenshteinDistanceWithLengths costs (length str1) (length str2) str1 str2
-- | Restricted Damerau-Levenshtein edit distance for callers that already
-- know the lengths of both strings.
--
-- The two 'Int' arguments are passed through unchanged as the lengths of the
-- first and second string respectively; they are not re-checked against the
-- actual strings here. The computation itself runs in 'ST' via
-- 'restrictedDamerauLevenshteinDistanceST' and is exposed as a pure result
-- with 'runST'.
restrictedDamerauLevenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths !costs !len1 !len2 s1 s2 =
    runST (restrictedDamerauLevenshteinDistanceST costs len1 len2 s1 s2)
-- | 'ST' implementation of the restricted Damerau-Levenshtein distance
-- between @str1@ and @str2@.
--
-- The cost matrix is indexed by @(i, j)@, where the column index @i@ ranges
-- over the characters of @str1@ and the row index @j@ over the characters of
-- @str2@. Row 0 and row 1 are filled here using only the standard
-- (insert / delete / substitute) costs; every row @j >= 2@ is produced by
-- 'restrictedDamerauLevenshteinDistanceSTRowWorker', which additionally
-- considers transpositions. The result is the last cell (@i = str1_len@) of
-- the last completed row.
--
-- @str1_len@ and @str2_len@ are used as given for array bounds and loop
-- limits, and all accesses go through the @unsafe*@ helpers, so callers must
-- pass the true lengths of @str1@ and @str2@.
restrictedDamerauLevenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
restrictedDamerauLevenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Create string arrays
    -- NOTE(review): characters are read back at indices 1..len below, so
    -- 'stringToArray' presumably builds 1-based arrays -- confirm in ArrayUtilities.
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs for a single row. Say we index costs by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1. We can get away with just storing two
    -- rows of costs at a time, but I use three because it turns out to be faster
    cost_row <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

    read_str1 <- unsafeReadArray' str1_array
    read_str2 <- unsafeReadArray' str2_array

    -- Fill out the first row (j = 0).
    -- The origin cell (0, 0) costs nothing. It is written explicitly because
    -- 'newArray_' makes no promise about the initial contents of the array,
    -- and this cell is read as the diagonal neighbour of cell (1, 1) (and is
    -- the final answer when both strings are empty).
    unsafeWriteArray cost_row 0 0
    -- The remaining cells accumulate the cost of deleting each prefix of str1.
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char -> do
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        unsafeWriteArray cost_row i deletion_cost'
        return (i + 1, deletion_cost')

    if str2_len == 0
      then unsafeReadArray cost_row str1_len
      else do
        -- We defer allocation of these arrays to here because they aren't used in the other branch
        cost_row'  <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)
        cost_row'' <- newArray_ (0, str1_len) :: ST s (STUArray s Int Int)

        -- Fill out the second row (j = 1). No transposition can end on this
        -- row, so only the standard costs are needed.
        row_char <- read_str2 1

        -- Initialize the first element of the row (i = 0)
        let zero = insertionCost costs row_char
        unsafeWriteArray cost_row' 0 zero

        -- Fill the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len (firstRowColWorker read_str1 row_char cost_row cost_row')

        -- Fill out the remaining rows (j >= 2)
        (_, _, final_row, _, _) <- foldM (restrictedDamerauLevenshteinDistanceSTRowWorker costs str1_len read_str1 read_str2) (zero, cost_row, cost_row', cost_row'', row_char) [2..str2_len]

        -- Return an actual answer
        unsafeReadArray final_row str1_len
  where
    -- Computes cell (i, 1) from its left, upper-left and upper neighbours and
    -- stores it in the second row. @cost_row@ is row 0, @cost_row'@ is row 1.
    {-# INLINE firstRowColWorker #-}
    firstRowColWorker read_str1 !row_char !cost_row !cost_row' !i = do
        col_char <- read_str1 i

        left_up <- unsafeReadArray cost_row  (i - 1)
        left    <- unsafeReadArray cost_row' (i - 1)
        here_up <- unsafeReadArray cost_row  i
        let here = standardCosts costs row_char col_char left left_up here_up
        unsafeWriteArray cost_row' i here
{-# INLINE restrictedDamerauLevenshteinDistanceSTRowWorker #-}
-- | Fills in one row @j@ of the restricted Damerau-Levenshtein cost matrix.
--
-- The incoming tuple carries, in order: the value of cell @(0, j - 1)@ as an
-- accumulated insertion cost, the row two steps back (@j - 2@), the previous
-- row (@j - 1@), a row buffer to be overwritten with row @j@, and the
-- character of the second string that belonged to row @j - 1@.
--
-- The outgoing tuple has the same shape, rotated one step: the new value of
-- cell @(0, j)@, row @j - 1@, the freshly written row @j@, the old row
-- @j - 2@ (now free to be reused as the next buffer), and the character of
-- row @j@.
restrictedDamerauLevenshteinDistanceSTRowWorker :: EditCosts -> Int
                                                -> (Int -> ST s Char) -> (Int -> ST s Char) -- String array accessors
                                                -> (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char) -> Int -- Incoming rows of the matrix in recency order
                                                -> ST s (Int, STUArray s Int Int, STUArray s Int Int, STUArray s Int Int, Char)  -- Outgoing rows of the matrix in recency order
restrictedDamerauLevenshteinDistanceSTRowWorker !costs !str1_len read_str1 read_str2 (!prev_first_cell, !row_jm2, !row_jm1, !row_j, !prev_row_char) !j = do
    row_char <- read_str2 j

    -- Column 0: only insertions can reach this cell
    first_cell_up <- unsafeReadArray row_jm1 0
    let first_cell = prev_first_cell + insertionCost costs row_char
    unsafeWriteArray row_j 0 first_cell

    -- Column 1: there is no column -1, so a transposition cannot end here
    when (str1_len > 0) $ do
        col_char  <- read_str1 1
        second_up <- unsafeReadArray row_jm1 1
        let second_cell = standardCosts costs row_char col_char first_cell first_cell_up second_up
        unsafeWriteArray row_j 1 second_cell

    -- Columns 2 and up: standard costs plus a possible transposition
    loopM_ 2 str1_len (fillCell row_char)

    return (first_cell, row_jm1, row_j, row_jm2, row_char)
  where
    -- Computes cell (i, j) for i >= 2 and stores it in the row being written.
    fillCell !row_char !i = do
        prev_col_char <- read_str1 (i - 1)
        col_char      <- read_str1 i

        two_back_diag <- unsafeReadArray row_jm2 (i - 2)
        diag          <- unsafeReadArray row_jm1 (i - 1)
        left          <- unsafeReadArray row_j   (i - 1)
        up            <- unsafeReadArray row_jm1 i

        let plain = standardCosts costs row_char col_char left diag up
            -- A transposition applies when the last two characters of both
            -- prefixes are the same pair in swapped order.
            best
              | prev_row_char == col_char && prev_col_char == row_char
                  = plain `min` (two_back_diag + transpositionCost costs col_char row_char)
              | otherwise
                  = plain
        unsafeWriteArray row_j i best
{-# INLINE standardCosts #-}
-- | Cheapest way to reach a matrix cell using only deletion, insertion or
-- substitution.
--
-- Arguments, in order: the edit costs, the character of the current row, the
-- character of the current column, and the already-computed costs of the
-- left, upper-left and upper neighbouring cells.
standardCosts :: EditCosts -> Char -> Char -> Int -> Int -> Int -> Int
standardCosts !costs !row_char !col_char !cost_left !cost_left_up !cost_up =
    min (min via_deletion via_insertion) via_substitution
  where
    -- Arrive from the left by deleting the column character
    via_deletion  = cost_left + deletionCost costs col_char
    -- Arrive from above by inserting the row character
    via_insertion = cost_up + insertionCost costs row_char
    -- Arrive diagonally; matching characters cost nothing extra
    via_substitution
      | row_char == col_char = cost_left_up
      | otherwise            = cost_left_up + substitutionCost costs col_char row_char
|