File: matrix-multiply.sml

package info (click to toggle)
mlton 20100608-2
  • links: PTS
  • area: main
  • in suites: squeeze
  • size: 34,980 kB
  • ctags: 69,089
  • sloc: ansic: 18,421; lisp: 2,879; makefile: 1,570; sh: 1,325; pascal: 256; asm: 97
file content (59 lines) | stat: -rw-r--r-- 1,611 bytes parent folder | download | duplicates (7)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
(* Written by Stephen Weeks (sweeks@sweeks.com). *)
(* Alias: for the rest of this file the name `Array` refers to the Basis
 * Library's two-dimensional Array2 structure, so Array.nRows, Array.nCols,
 * Array.sub, Array.update and Array.tabulate below are all Array2
 * operations on `'a Array2.array` values. *)
structure Array = Array2
   
(* fold (n, b, f) threads an accumulator through f for the indices
 * 0, 1, ..., n-1 in increasing order, i.e. it computes
 *    f (n-1, ... f (1, f (0, b)) ...)
 * and returns b unchanged when n <= 0. *)
fun 'a fold (n : int, b : 'a, f : int * 'a -> 'a) =
   let
      fun loop (i : int, b : 'a) : 'a =
         (* `>=` rather than `=`: with a negative n an equality test would
          * never fire and the loop would count up until Int overflow. *)
         if i >= n
            then b
         else loop (i + 1, f (i, b))
   in loop (0, b)
   end

(* foreach (n, f) applies f to each index 0, 1, ..., n-1 (for n >= 0) in
 * increasing order, purely for its side effects.  It is fold specialised
 * to a unit accumulator that is simply passed through. *)
fun foreach (n : int, f : int -> unit) : unit =
   let
      (* Ignore the (unit) accumulator and just run f on the index. *)
      fun step (idx : int, () : unit) : unit = f idx
   in
      fold (n, (), step)
   end
      
(* mult (a1, a2) returns the matrix product a1 * a2 as a freshly allocated
 * array with (nRows a1) rows and (nCols a2) columns.
 * Each entry (r, c) is accumulated left to right over k = 0 .. nCols a1 - 1,
 * starting from 0.0.
 * Raises Fail "mult" when the inner dimensions disagree
 * (nCols a1 <> nRows a2). *)
fun mult (a1 : real Array.array, a2 : real Array.array) : real Array.array =
   let
      val rows = Array.nRows a1
      val inner = Array.nCols a1
      val cols = Array.nCols a2
      (* Row r of a1 dotted with column c of a2. *)
      fun entry (r, c) =
         fold (inner, 0.0, fn (k, acc) =>
               acc + Array.sub (a1, r, k) * Array.sub (a2, k, c))
      (* Store every entry of row r into result. *)
      fun fillRow (result, r) =
         foreach (cols, fn c => Array.update (result, r, c, entry (r, c)))
   in
      if inner <> Array.nRows a2 then raise Fail "mult"
      else
         let
            val result = Array2.array (rows, cols, 0.0)
         in
            foreach (rows, fn r => fillRow (result, r));
            result
         end
   end

structure Main =
   struct
      (* One benchmark iteration: square a dim x dim matrix whose (r, c)
       * entry is r + c, then check entry (0, 0) of the product.
       * That entry is sum_{i=0}^{dim-1} i * i = (dim-1) * dim * (2*dim-1) / 6,
       * which is 41541750 for dim = 500.
       * Every term and partial sum is an integer well below 2^53, so the
       * floating-point result is exact and Real.== is a safe check.
       * The integer product (dim-1)*dim*(2*dim-1) is about 2.5e8 for
       * dim = 500, which fits in Int. *)
      fun doit () =
         let
            val dim = 500
            val a = Array.tabulate Array.RowMajor (dim, dim, fn (r, c) =>
                                                   Real.fromInt (r + c))
            val expected =
               Real.fromInt ((dim - 1) * dim * (2 * dim - 1) div 6)
         in
            if Real.== (expected, Array2.sub (mult (a, a), 0, 0))
               then ()
            else raise Fail "bug"
         end

      (* Entry point: run the iteration above `size` times.
       * The loop stops on n <= 0 rather than n = 0, so a negative size
       * runs zero iterations instead of counting down forever. *)
      val doit =
         fn size =>
         let
            fun loop n =
               if n <= 0
                  then ()
               else (doit ();
                     loop (n-1))
         in loop size
         end
   end