1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
|
{-# LANGUAGE PatternGuards, ScopedTypeVariables, BangPatterns, Trustworthy #-}
module Text.EditDistance.SquareSTUArray (
levenshteinDistance, levenshteinDistanceWithLengths, restrictedDamerauLevenshteinDistance, restrictedDamerauLevenshteinDistanceWithLengths
) where
import Text.EditDistance.EditCosts
import Text.EditDistance.MonadUtilities
import Text.EditDistance.ArrayUtilities
import Control.Monad hiding (foldM)
import Control.Monad.ST
import Data.Array.ST
-- | Levenshtein distance between @str1@ and @str2@ under the given edit costs.
--
-- Convenience entry point: measures both strings with 'length' and hands
-- everything to 'levenshteinDistanceWithLengths'. Use that function directly
-- when the lengths are already known, to avoid traversing the strings again.
levenshteinDistance :: EditCosts -> String -> String -> Int
levenshteinDistance !costs str1 str2 =
    levenshteinDistanceWithLengths costs (length str1) (length str2) str1 str2
-- | Levenshtein distance for callers that already know the string lengths.
--
-- Arguments are the edit costs, the length of the first string, the length of
-- the second string, and then the two strings themselves. The computation is
-- carried out by 'levenshteinDistanceST' and run to a pure result with 'runST'.
--
-- NOTE(review): the lengths are passed through unchecked; they are assumed to
-- match the actual strings -- confirm at call sites.
levenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
levenshteinDistanceWithLengths !costs !str1_len !str2_len str1 str2 = runST distance
  where
    -- Kept polymorphic in the state token so that 'runST' accepts it.
    distance :: ST s Int
    distance = levenshteinDistanceST costs str1_len str2_len str1 str2
-- | Computes the Levenshtein distance inside 'ST' by filling a complete
-- @(str1_len + 1)@ by @(str2_len + 1)@ table of costs.
--
-- The table is indexed by @(i, j)@, where @i@ is the column index (a position
-- in @str1@) and @j@ is the row index (a position in @str2@); index 0 stands
-- for the empty prefix. Each cell @(i, j)@ with @i, j >= 1@ is the cheapest of
-- the deletion, insertion and substitution steps computed by 'standardCosts'
-- from its three already-filled neighbours. The result is the bottom-right
-- cell @(str1_len, str2_len)@.
--
-- NOTE(review): all array accesses go through the @unsafe...'@ helpers, which
-- presumably skip bounds checks, so @str1_len@ and @str2_len@ are assumed to
-- be the true lengths of the strings -- confirm at call sites.
levenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
levenshteinDistanceST !costs !str1_len !str2_len str1 str2 = do
    -- Create string arrays
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs. Say we index it by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1.
    cost_array <- newArray_ ((0, 0), (str1_len, str2_len)) :: ST s (STUArray s (Int, Int) Int)

    read_str1  <- unsafeReadArray' str1_array
    read_str2  <- unsafeReadArray' str2_array
    read_cost  <- unsafeReadArray' cost_array
    write_cost <- unsafeWriteArray' cost_array

    -- The origin is the cost of turning the empty prefix into the empty prefix.
    -- 'newArray_' documents its elements as being in an arbitrary state, so the
    -- cell is set explicitly instead of relying on the array being zero-filled.
    -- It is read when computing cell (1, 1) and is the result for two empty strings.
    write_cost (0, 0) 0

    -- Fill out the first row (j = 0): the running total of deletion costs of str1's characters
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char ->
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        in write_cost (i, 0) deletion_cost' >> return (i + 1, deletion_cost')

    -- Fill the remaining rows (j >= 1), threading the running total of insertion costs
    _ <- (\f -> foldM f 0 [1..str2_len]) $ \insertion_cost (!j) -> do
        row_char <- read_str2 j

        -- Initialize the first element of the row (i = 0)
        let insertion_cost' = insertion_cost + insertionCost costs row_char
        write_cost (0, j) insertion_cost'

        -- Fill the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i

            cost <- standardCosts costs read_cost row_char col_char (i, j)
            write_cost (i, j) cost

        return insertion_cost'

    -- Return an actual answer
    read_cost (str1_len, str2_len)
-- | Restricted Damerau-Levenshtein distance between @str1@ and @str2@ under
-- the given edit costs.
--
-- Convenience entry point: measures both strings with 'length' and hands
-- everything to 'restrictedDamerauLevenshteinDistanceWithLengths'.
restrictedDamerauLevenshteinDistance :: EditCosts -> String -> String -> Int
restrictedDamerauLevenshteinDistance costs str1 str2 =
    restrictedDamerauLevenshteinDistanceWithLengths costs (length str1) (length str2) str1 str2
-- | Restricted Damerau-Levenshtein distance for callers that already know the
-- string lengths.
--
-- Arguments are the edit costs, the length of the first string, the length of
-- the second string, and then the two strings themselves. The computation is
-- carried out by 'restrictedDamerauLevenshteinDistanceST' and run to a pure
-- result with 'runST'.
--
-- NOTE(review): the lengths are passed through unchecked; they are assumed to
-- match the actual strings -- confirm at call sites.
restrictedDamerauLevenshteinDistanceWithLengths :: EditCosts -> Int -> Int -> String -> String -> Int
restrictedDamerauLevenshteinDistanceWithLengths costs str1_len str2_len str1 str2 = runST distance
  where
    -- Kept polymorphic in the state token so that 'runST' accepts it.
    distance :: ST s Int
    distance = restrictedDamerauLevenshteinDistanceST costs str1_len str2_len str1 str2
-- | Computes the restricted Damerau-Levenshtein distance inside 'ST' by
-- filling a complete @(str1_len + 1)@ by @(str2_len + 1)@ table of costs.
--
-- The table is indexed by @(i, j)@, where @i@ is the column index (a position
-- in @str1@) and @j@ is the row index (a position in @str2@); index 0 stands
-- for the empty prefix. Cells start from the deletion / insertion /
-- substitution minimum given by 'standardCosts'. From row 2 and column 2
-- onwards a cell may additionally be reached by swapping two adjacent
-- characters, which costs the value at @(i - 2, j - 2)@ plus the
-- transposition cost. Rows 0 and 1 and columns 0 and 1 are filled separately
-- because that look-back would fall outside the table there. The result is
-- the bottom-right cell @(str1_len, str2_len)@.
--
-- The first column holds the running total of the insertion costs of the
-- characters of @str2@, matching what 'levenshteinDistanceST' does. An earlier
-- revision stored @insertionCost costs row_char * j@ there, which is only
-- correct when every character has the same insertion cost.
--
-- NOTE(review): all array accesses go through the @unsafe...'@ helpers, which
-- presumably skip bounds checks, so @str1_len@ and @str2_len@ are assumed to
-- be the true lengths of the strings -- confirm at call sites.
restrictedDamerauLevenshteinDistanceST :: EditCosts -> Int -> Int -> String -> String -> ST s Int
restrictedDamerauLevenshteinDistanceST !costs str1_len str2_len str1 str2 = do
    -- Create string arrays
    str1_array <- stringToArray str1 str1_len
    str2_array <- stringToArray str2 str2_len

    -- Create array of costs. Say we index it by (i, j) where i is the column index and j the row index.
    -- Rows correspond to characters of str2 and columns to characters of str1.
    cost_array <- newArray_ ((0, 0), (str1_len, str2_len)) :: ST s (STUArray s (Int, Int) Int)

    read_str1  <- unsafeReadArray' str1_array
    read_str2  <- unsafeReadArray' str2_array
    read_cost  <- unsafeReadArray' cost_array
    write_cost <- unsafeWriteArray' cost_array

    -- The origin is the cost of turning the empty prefix into the empty prefix.
    -- 'newArray_' documents its elements as being in an arbitrary state, so the
    -- cell is set explicitly instead of relying on the array being zero-filled.
    write_cost (0, 0) 0

    -- Fill out the first row (j = 0): the running total of deletion costs of str1's characters
    _ <- (\f -> foldM f (1, 0) str1) $ \(i, deletion_cost) col_char ->
        let deletion_cost' = deletion_cost + deletionCost costs col_char
        in write_cost (i, 0) deletion_cost' >> return (i + 1, deletion_cost')

    -- Fill out the second row (j = 1)
    when (str2_len > 0) $ do
        initial_row_char <- read_str2 1

        -- Initialize the first element of the second row (i = 0)
        write_cost (0, 1) (insertionCost costs initial_row_char)

        -- Initialize the remaining elements of the row (i >= 1)
        loopM_ 1 str1_len $ \(!i) -> do
            col_char <- read_str1 i

            cost <- standardCosts costs read_cost initial_row_char col_char (i, 1)
            write_cost (i, 1) cost

    -- Fill the remaining rows (j >= 2)
    loopM_ 2 str2_len (\(!j) -> do
        row_char <- read_str2 j
        prev_row_char <- read_str2 (j - 1)

        -- Initialize the first element of the row (i = 0): everything above it
        -- in the first column plus the cost of inserting this row's character.
        -- Cell (0, j - 1) is already filled: by the j = 1 step for j = 2, and by
        -- the previous iteration of this loop otherwise.
        insertion_cost_above <- read_cost (0, j - 1)
        write_cost (0, j) (insertion_cost_above + insertionCost costs row_char)

        -- Initialize the second element of the row (i = 1)
        when (str1_len > 0) $ do
            col_char <- read_str1 1

            cost <- standardCosts costs read_cost row_char col_char (1, j)
            write_cost (1, j) cost

        -- Fill the remaining elements of the row (i >= 2)
        loopM_ 2 str1_len (\(!i) -> do
            col_char <- read_str1 i
            prev_col_char <- read_str1 (i - 1)

            standard_cost <- standardCosts costs read_cost row_char col_char (i, j)
            -- A transposition applies when the last two characters of each
            -- prefix are the same pair in opposite order.
            cost <-
                if prev_row_char == col_char && prev_col_char == row_char
                    then do
                        two_back_cost <- read_cost (i - 2, j - 2)
                        let transpose_cost = two_back_cost + transpositionCost costs col_char row_char
                        return (standard_cost `min` transpose_cost)
                    else return standard_cost
            write_cost (i, j) cost))

    -- Return an actual answer
    read_cost (str1_len, str2_len)
-- | Cheapest way to reach cell @(i, j)@ of the cost table using one deletion,
-- insertion or substitution step.
--
-- @read_cost@ looks up an already-computed cell. The three candidates are:
--
-- * cell @(i - 1, j)@ plus the deletion cost of @col_char@;
-- * cell @(i, j - 1)@ plus the insertion cost of @row_char@;
-- * cell @(i - 1, j - 1)@ plus the substitution cost of @col_char@ by
--   @row_char@, or plus nothing when the two characters are equal.
--
-- Callers must only use this for @i >= 1@ and @j >= 1@, since all three
-- neighbouring cells are read.
{-# INLINE standardCosts #-}
standardCosts :: EditCosts -> ((Int, Int) -> ST s Int) -> Char -> Char -> (Int, Int) -> ST s Int
standardCosts !costs read_cost !row_char !col_char (!i, !j) = do
    left_cell     <- read_cost (i - 1, j)
    above_cell    <- read_cost (i, j - 1)
    diagonal_cell <- read_cost (i - 1, j - 1)
    let subst_step
            | row_char == col_char = 0
            | otherwise            = substitutionCost costs col_char row_char
        via_deletion     = left_cell + deletionCost costs col_char
        via_insertion    = above_cell + insertionCost costs row_char
        via_substitution = diagonal_cell + subst_step
    return (min (min via_deletion via_insertion) via_substitution)
|