1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435
|
open Printf
(* Description of a worker process, passed to the [init] callback that
   [Full.create_gen] runs in each child right after the fork. *)
type worker_info = {
(* Index of the worker in the process pool, from 0 to nproc - 1. *)
worker_id : int;
(* Enters the worker's request loop.  It never returns normally, hence the
   fully polymorphic result type. *)
worker_loop : 'a. unit -> 'a;
}
(* Exception that is allowed to escape from a worker's start-up code: the
   child-side handler in [Full.create_gen] re-raises it instead of logging it
   and exiting.  Presumably meant to be caught by the application so that it
   can call [worker_loop] from a place of its choosing -- TODO confirm. *)
exception Start_worker of worker_info
(* Replaceable hooks.  Assign to these references to change how errors and
   informational messages are reported and how exceptions are rendered in
   those messages.  By default messages go to stderr with an "[err]" or
   "[info]" prefix and exceptions are printed with [Printexc.to_string]. *)
let log_error = ref (fun s -> eprintf "[err] %s\n%!" s)
let log_info = ref (fun s -> eprintf "[info] %s\n%!" s)
let string_of_exn = ref Printexc.to_string
(* Send [msg] to the current error logger.  If the logger itself raises,
   fall back to stderr: print [msg], then a notice naming the exception the
   logger raised.  The logger's exception is not propagated. *)
let report_error msg =
  let logger = !log_error in
  try logger msg
  with logger_exn ->
    eprintf "%s\n" msg;
    let reason = Printexc.to_string logger_exn in
    eprintf "*** Critical error *** Error logger raised an exception:\n%s\n%!"
      reason
(* Send [msg] to the current info logger.  If the logger itself raises,
   fall back to stderr: print [msg], then a notice naming the exception the
   logger raised.  The logger's exception is not propagated. *)
let report_info msg =
  let logger = !log_info in
  try logger msg
  with logger_exn ->
    eprintf "%s\n" msg;
    let reason = Printexc.to_string logger_exn in
    eprintf "*** Critical error *** Info logger raised an exception:\n%s\n%!"
      reason
(* Remove up to [n] elements from the front of [strm] and push each of them
   onto [acc], so the element read last ends up first in the result.
   Stops early when the stream runs out; returns [acc] unchanged when
   [n <= 0]. *)
let rec npop acc n strm =
  if n <= 0 then acc
  else
    match Stream.peek strm with
    | None -> acc
    | Some head ->
        Stream.junk strm;
        npop (head :: acc) (n - 1) strm
(* Turn [strm] into a stream of chunks holding at most [n] elements each.
   Every chunk lists its elements in reverse order of arrival.  The
   resulting stream ends as soon as an empty chunk would be produced. *)
let chunkify n strm =
  let next_chunk _index =
    match npop [] n strm with
    | [] -> None
    | chunk -> Some chunk
  in
  Stream.from next_chunk
module Full =
struct
(* Master-side handle on one forked worker process. *)
type worker = {
(* Process id of the child, as returned by [Unix.fork]. *)
worker_pid : int;
(* Read end of the pipe on which the master receives the worker's
   messages. *)
worker_in : Lwt_unix.file_descr;
(* Write end of the pipe on which the master sends messages to the
   worker. *)
worker_out : Lwt_unix.file_descr;
}
(* Messages marshalled by a worker to the master. *)
type ('b, 'c) from_worker =
(* Value returned by the task function. *)
Worker_res of 'b
(* Request for the master's central service; the worker then waits for a
   [Central_res] in reply. *)
| Central_req of 'c
(* Error message: the task function raised an exception, or the worker loop
   itself failed while handling a request. *)
| Worker_error of string
(* Messages marshalled by the master to a worker. *)
type ('a, 'b, 'c, 'd, 'e) to_worker =
(* Task to run: the task function and its input.  When called in the worker
   the function also receives the worker-side central-service stub and the
   worker data. *)
Worker_req of (('c -> 'd) -> 'e -> 'a -> 'b) * 'a
(* Answer to a pending [Central_req]. *)
| Central_res of 'd
(* --worker-- *)
(* Close the master-side ends of a worker's pipes.
   Executed in worker processes right after the fork, or in the master when
   closing the process pool.

   Both descriptors are always attempted: a failure while closing
   [worker_in] no longer prevents [worker_out] from being closed, which
   would otherwise leak it.  If closing [worker_in] failed, that exception
   is re-raised once [worker_out] has been dealt with.

   Raises whatever [Unix.close] raises (normally [Unix.Unix_error]). *)
let close_worker x =
  let close_fd fd = Unix.close (Lwt_unix.unix_file_descr fd) in
  let first_failure =
    try close_fd x.worker_in; None
    with e -> Some e
  in
  close_fd x.worker_out;
  match first_failure with
  | None -> ()
  | Some e -> raise e
(* --worker-- *)
(* Close the pipes of every worker currently registered in the pool array
   [a] and reset the corresponding slots to [None].  Empty slots are
   skipped. *)
let cleanup_proc_pool a =
  Array.iteri
    (fun slot entry ->
       match entry with
       | None -> ()
       | Some registered ->
           close_worker registered;
           a.(slot) <- None)
    a
(* Message describing an exception raised by the task function f.  The
   exception is rendered with the current [string_of_exn] hook. *)
let user_error1 e =
  let render = !string_of_exn in
  sprintf "Exception raised by Nproc task: %s" (render e)
(* Message describing an exception raised by the result handler g.  The
   exception is rendered with the current [string_of_exn] hook. *)
let user_error2 e =
  let render = !string_of_exn in
  sprintf "Error while handling result of Nproc task: exception %s"
    (render e)
(* --worker-- *)
(* Request loop of a worker process; never returns.

   [fd_in] carries marshalled [to_worker] messages from the master and
   [fd_out] carries the worker's [from_worker] replies.  For each
   [Worker_req (f, x)] the loop runs [f central_service worker_data x] and
   sends back [Worker_res] with the result, or [Worker_error] if [f] raised.
   The process exits with status 0 when the master closes the request pipe
   (end of file) or when writing the reply fails with a broken pipe. *)
let start_worker_loop worker_data fd_in fd_out =
  let from_master = Unix.in_channel_of_descr fd_in in
  let to_master = Unix.out_channel_of_descr fd_out in
  (* Marshal one message to the master and flush it immediately. *)
  let send msg =
    Marshal.to_channel to_master msg [Marshal.Closures];
    flush to_master
  in
  (* Stub handed to tasks: forwards the request to the master and blocks
     until the answer comes back. *)
  let central_service x =
    send (Central_req x);
    match Marshal.from_channel from_master with
    | Central_res y -> y
    | Worker_req _ -> assert false
  in
  (* Run one task, turning any exception it raises into an error reply. *)
  let run_task f x =
    try Worker_res (f central_service worker_data x)
    with e -> Worker_error (user_error1 e)
  in
  (* Read the next request and compute the reply to send back.  Anything
     that goes wrong outside the task itself is reported as an internal
     error, except end of file which terminates the worker. *)
  let next_reply () =
    try
      match Marshal.from_channel from_master with
      | Worker_req (f, x) -> run_task f x
      | Central_res _ -> assert false
    with
    | End_of_file -> exit 0
    | e ->
        let msg =
          sprintf "Internal error in Nproc worker: %s" (!string_of_exn e)
        in
        Worker_error msg
  in
  let rec serve () =
    let reply = next_reply () in
    (* NOTE(review): this matches on the text of the system error message,
       which may differ between platforms -- confirm this is intended. *)
    (try send reply
     with Sys_error "Broken pipe" -> exit 0);
    serve ()
  in
  serve ()
(* Marshal [x] (closures allowed) onto the Lwt output channel [oc], then
   flush the channel.  The returned thread terminates once the flush is
   done. *)
let write_value oc x =
  let flush_channel () = Lwt_io.flush oc in
  Lwt.bind
    (Lwt_io.write_value oc ~flags:[Marshal.Closures] x)
    flush_channel
(* Task inputs and outputs are stored type-erased in the task stream;
   [submit] and [lwt_of_stream] convert to these representations with
   [Obj.magic]. *)
type in_t = Obj.t
type out_t = Obj.t
(* A process pool.  ['a -> 'b] is the type of the central service as seen
   by a task, and ['c] is the type of the worker data passed to every
   task. *)
type ('a, 'b, 'c) t = {
(* Pending tasks.  Each task is the task function, its input, and the
   callback that receives [Some result] on success or [None] on failure. *)
stream :
((('a -> 'b) -> 'c -> in_t -> out_t)
* in_t
* (out_t option -> unit))
Lwt_stream.t;
(* Function used to add a task to the pool; [close] calls it with [None]
   when no more tasks will be submitted. *)
push :
(((('a -> 'b) -> 'c -> in_t -> out_t)
* in_t
* (out_t option -> unit))
option -> unit);
(* Closes the master side of the workers' pipes, sends SIGKILL to each
   worker and waits for it. *)
kill_workers : unit -> unit;
(* Stops accepting tasks, waits for the pending ones to complete, then
   kills the workers.  Does nothing if the pool is already closed. *)
close : unit -> unit Lwt.t;
(* Set once the pool no longer accepts tasks; checked by [submit]. *)
closed : bool ref;
}
(* Wait for process [pid] like [Unix.waitpid [] pid], starting over every
   time the call is interrupted by a signal (EINTR).  Any other error is
   propagated. *)
let rec waitpid pid =
  let outcome =
    try Some (Unix.waitpid [] pid)
    with Unix.Unix_error (Unix.EINTR, _, _) -> None
  in
  match outcome with
  | Some status -> status
  | None -> waitpid pid
(* --master-- *)
(* Lightweight thread that keeps one worker busy.

   It repeatedly takes a task [(f, x, g)] from [in_stream] (reads are
   serialised with [in_stream_mutex]), sends [Worker_req (f, x)] to
   [worker], and then handles the worker's messages:
   - [Worker_res r]: call [g (Some r)], then fetch the next task;
   - [Central_req x]: run [central_service x] in the master, send the
     answer back as [Central_res], then keep reading from the worker;
   - [Worker_error msg]: report [msg], call [g None], then fetch the next
     task.
   Exceptions raised by [g] are reported and otherwise ignored.  The thread
   terminates when the stream is exhausted.  If reading from the worker
   fails, the error is reported, [kill_workers] is called and the whole
   program exits with status 1. *)
let pull_task kill_workers in_stream in_stream_mutex central_service worker =
  let ( >>= ) = Lwt.bind in
  (* Original author's note: input and output file descriptors are
     automatically closed when the end of the lwt channel is reached. *)
  let ic = Lwt_io.of_fd ~mode:Lwt_io.input worker.worker_in in
  let oc = Lwt_io.of_fd ~mode:Lwt_io.output worker.worker_out in
  (* Hand an outcome to the user's handler, logging whatever it raises. *)
  let deliver g outcome =
    try g outcome
    with e -> report_error (user_error2 e)
  in
  (* A worker we cannot read from is unrecoverable: tear everything down. *)
  let fatal_read_error e =
    let msg =
      sprintf "Cannot read from Nproc worker: exception %s"
        (!string_of_exn e)
    in
    report_error msg;
    kill_workers ();
    exit 1
  in
  let next_task () =
    Lwt_mutex.with_lock in_stream_mutex (fun () -> Lwt_stream.get in_stream)
  in
  let rec pull () =
    next_task () >>= (function
      | None -> Lwt.return ()
      | Some (f, x, g) ->
          write_value oc (Worker_req (f, x)) >>= read_from_worker g)
  and read_from_worker g () =
    Lwt.try_bind
      (fun () -> Lwt_io.read_value ic)
      (handle_input g)
      fatal_read_error
  and handle_input g = function
    | Worker_res result ->
        deliver g (Some result);
        pull ()
    | Central_req x ->
        central_service x >>= (fun y ->
          write_value oc (Central_res y) >>= read_from_worker g)
    | Worker_error msg ->
        report_error msg;
        deliver g None;
        pull ()
  in
  pull ()
(* --master-- *)
(* Fork [nproc] worker processes and start feeding them tasks.

   Arguments:
   - [init]: callback run in each child right after the fork, before the
     worker enters its request loop; it receives the worker's
     [worker_info].
   - [((in_stream, push), in_stream_mutex)]: the task stream, the function
     used to push tasks onto it, and the mutex serialising reads from it.
   - [nproc]: number of worker processes to fork.
   - [central_service]: service run in the master on behalf of tasks.
   - [worker_data]: value passed to every task executed by a worker.

   Returns [(p, jobs)] where [p] is the pool handle and [jobs] is the Lwt
   thread that terminates once every worker's feeding thread is done.

   In a child, any exception other than [Start_worker] escaping from the
   start-up code is reported and the child exits with status 1;
   [Start_worker] is re-raised.  Reporting goes through [report_error], so
   a failing user-supplied logger can neither let the child escape into the
   master's code path nor interrupt [kill_workers] half-way. *)
let create_gen init ((in_stream, push), in_stream_mutex) nproc central_service worker_data =
  let proc_pool = Array.make nproc None in
  Array.iteri (
    fun i _ ->
      let (in_read, in_write) = Lwt_unix.pipe_in () in
      let (out_read, out_write) = Lwt_unix.pipe_out () in
      match Unix.fork () with
        0 ->
          (try
             (* The child keeps only its own ends of the two pipes: drop
                the master-side ends, including those inherited for the
                workers created before this one. *)
             Unix.close (Lwt_unix.unix_file_descr in_read);
             Unix.close (Lwt_unix.unix_file_descr out_write);
             cleanup_proc_pool proc_pool;
             let start () =
               start_worker_loop worker_data out_read in_write
             in
             init { worker_id = i; worker_loop = start };
             start ()
           with e ->
             match e with
               Start_worker _ -> raise e
             | _ ->
                 report_error
                   (sprintf "Uncaught exception in worker (pid %i): %s"
                      (Unix.getpid ()) (!string_of_exn e));
                 exit 1
          )
      | child_pid ->
          (* The master keeps only its own ends of the two pipes. *)
          Unix.close in_write;
          Unix.close out_read;
          proc_pool.(i) <-
            Some {
              worker_pid = child_pid;
              worker_in = in_read;
              worker_out = out_write;
            }
  ) proc_pool;

  (*
    Create nproc lightweight threads.
    Each lightweight thread pulls tasks from the stream and feeds its worker
    until the stream is empty.
  *)
  let worker_info =
    Array.to_list
      (Array.map (function Some x -> x | None -> assert false) proc_pool)
  in
  (* Close the pipes of every remaining worker, kill it and reap it.
     Each slot is cleared before the worker is touched, which makes the
     function idempotent: a second call (for instance [terminate] after
     [close]) never signals a pid that has already been reaped and that the
     system may have given to another process. *)
  let kill_workers () =
    Array.iteri (
      fun i slot ->
        match slot with
          None -> ()
        | Some x ->
            proc_pool.(i) <- None;
            (* Best effort: the descriptors may already be closed. *)
            (try close_worker x with _ -> ());
            (try
               Unix.kill x.worker_pid Sys.sigkill;
               ignore (waitpid x.worker_pid)
             with e ->
               report_error
                 (sprintf "kill worker %i: %s"
                    x.worker_pid (!string_of_exn e)))
    ) proc_pool
  in
  let jobs =
    Lwt.join
      (List.map
         (pull_task kill_workers in_stream in_stream_mutex central_service)
         worker_info)
  in
  let closed = ref false in
  (* Signal the end of the task stream, wait for the pending tasks, then
     kill the workers.  Only the first call has any effect. *)
  let close_stream () =
    if not !closed then (
      push None;
      closed := true;
      Lwt.bind jobs (fun () -> Lwt.return (kill_workers ()))
    )
    else
      Lwt.return ()
  in
  let p = {
    stream = in_stream;
    push = push;
    kill_workers = kill_workers;
    close = close_stream;
    closed = closed;
  }
  in
  p, jobs
(* Default worker initialisation: ignore the argument and do nothing. *)
let default_init _ = ()
(* Create a pool of [nproc] workers fed from a fresh task stream guarded by
   a fresh mutex.  [init] defaults to [default_init].  Returns what
   [create_gen] returns: the pool handle and the thread that terminates
   when all feeding threads are done. *)
let create ?(init = default_init) nproc central_service worker_data =
  let task_source = Lwt_stream.create () in
  let stream_lock = Lwt_mutex.create () in
  create_gen init (task_source, stream_lock) nproc central_service worker_data
(* Run the pool's [close] operation and return its thread. *)
let close { close = shutdown; _ } =
  shutdown ()
(* Mark the pool as closed, then kill its workers right away without
   waiting for pending tasks. *)
let terminate { closed; kill_workers; _ } =
  closed := true;
  kill_workers ()
(* Queue the task [f] applied to input [x] on pool [p].
   Returns a thread that is woken up with the value passed to the task's
   result callback ([Some result], or [None] on failure).  If the pool is
   already closed, the returned thread fails with [Failure] and nothing is
   queued.  Function, input and callback are stored type-erased. *)
let submit p ~f x =
  match !(p.closed) with
  | true ->
      Lwt.fail (Failure
                  ("Cannot submit task to process pool because it is closed"))
  | false ->
      let promise, resolver = Lwt.task () in
      let on_result y = Lwt.wakeup resolver y in
      let task = (Obj.magic f, Obj.magic x, Obj.magic on_result) in
      p.push (Some task);
      promise
(* Remove and return the first element of stream [x], or return [None]
   without touching the stream when it is empty. *)
let stream_pop x =
  match Stream.peek x with
  | None -> None
  | Some _ as front ->
      Stream.junk x;
      front
(* Build the Lwt task stream for [iter_stream]: each element popped from
   the standard stream [strm] becomes a type-erased task made of the task
   function [f], that element, and the result handler [g].  The Lwt stream
   ends when [strm] is exhausted. *)
let lwt_of_stream f g strm =
  let next_task () =
    Lwt.return
      (match stream_pop strm with
       | None -> None
       | Some x -> Some (Obj.magic f, Obj.magic x, Obj.magic g))
  in
  Lwt_stream.from next_task
(* Outcome of one element of a chunked task in [iter_stream]: the value
   returned by the task function, or the message describing the exception
   it raised. *)
type 'a result_or_error = Result of 'a | Error of string
(* Apply [f] to every element of [in_stream] in worker processes and pass
   each outcome to [g] in the master.  Returns once all tasks are done and
   the workers have been killed.

   Arguments:
   - [granularity] (default 1): number of stream elements sent to a worker
     in a single task.  Must be positive.
   - [init] (default [default_init]): callback run in each worker right
     after the fork.
   - [nproc]: number of worker processes.
   - [serv]: central service, run in the master on behalf of tasks.
   - [env]: worker data passed to every call of [f].
   - [f]: task function, called as [f central_service env x] in a worker.
   - [g]: result handler, called with [Some y] for a result and with [None]
     when [f] raised an exception.  Exceptions raised by [g] are reported
     and otherwise ignored.  With [granularity > 1], when a whole chunk is
     lost only an error is reported and [g] is not called for its elements.

   Raises [Invalid_argument] if [granularity <= 0].  An exception raised
   while running the tasks is re-raised after the workers have been killed.
   The workers are killed exactly once: the cleanup is no longer inside the
   [try], so a failure of the cleanup itself cannot trigger a second
   cleanup on workers that are already dead. *)
let iter_stream
    ?(granularity = 1)
    ?(init = default_init)
    ~nproc ~serv ~env ~f ~g in_stream =
  if granularity <= 0 then
    invalid_arg (sprintf "Nproc.iter_stream: granularity=%i" granularity)
  else
    let task_stream =
      if granularity = 1 then
        lwt_of_stream f g in_stream
      else
        let in_stream' = chunkify granularity in_stream in
        (* Chunks list their elements in reverse order; [List.rev_map]
           puts the results back in the order of the input stream. *)
        let f' central_service worker_data l =
          List.rev_map (
            fun x ->
              try Result (f central_service worker_data x)
              with e -> Error (user_error1 e)
          ) l
        in
        let g' = function
            None ->
              report_error "Nproc error: missing result due to an internal \
                            error in Nproc or due to a killed worker process"
          | Some l ->
              List.iter (
                function
                    Result y ->
                      (try
                         g (Some y)
                       with e ->
                         report_error (user_error2 e)
                      )
                  | Error s ->
                      report_error s;
                      (try
                         g None
                       with e ->
                         report_error (user_error2 e)
                      )
              ) l
        in
        lwt_of_stream f' g' in_stream'
    in
    let p, t =
      create_gen init
        ((task_stream,
          (fun _ -> assert false) (* push *)),
         Lwt_mutex.create ())
        nproc serv env
    in
    (* Remember a failure of the run instead of handling it in place, so
       that the cleanup below is executed once on every path. *)
    let run_failure =
      try Lwt_main.run t; None
      with e -> Some e
    in
    p.kill_workers ();
    match run_failure with
    | None -> ()
    | Some e -> raise e
end
(* Process pool of the simplified interface: the central-service and
   worker-data type parameters of [Full.t] are all [unit]. *)
type t = (unit, unit, unit) Full.t
(* Create a pool of [n] workers with a central service that does nothing
   and [()] as worker data.  Returns the same pair as [Full.create]. *)
let create ?init n =
  let no_service () = Lwt.return () in
  Full.create ?init n no_service ()
(* Same as [Full.close]. *)
let close p = Full.close p
(* Same as [Full.terminate]. *)
let terminate p = Full.terminate p
(* Submit the task [f x] to pool [p].  [f] only receives its input: the
   central service and worker data supplied by [Full.submit] are
   dropped. *)
let submit p ~f x =
  let task _serv _env arg = f arg in
  Full.submit p ~f:task x
(* Simplified [Full.iter_stream]: no central service (a service that does
   nothing is supplied), [()] as worker data, and a task function [f] that
   only receives its input.  The remaining arguments are passed through
   unchanged. *)
let iter_stream ?granularity ?init ~nproc ~f ~g strm =
  let no_service () = Lwt.return () in
  let task _serv _env x = f x in
  Full.iter_stream
    ?granularity
    ?init
    ~nproc
    ~env:()
    ~serv:no_service
    ~f:task
    ~g
    strm
|