File: MLEngine.worker.mjs

package info (click to toggle)
firefox 143.0.3-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 4,617,328 kB
  • sloc: cpp: 7,478,492; javascript: 6,417,157; ansic: 3,720,058; python: 1,396,372; xml: 627,523; asm: 438,677; java: 186,156; sh: 63,477; makefile: 19,171; objc: 13,059; perl: 12,983; yacc: 4,583; cs: 3,846; pascal: 3,405; lex: 1,720; ruby: 1,003; exp: 762; php: 436; lisp: 258; awk: 247; sql: 66; sed: 53; csh: 10
file content (143 lines) | stat: -rw-r--r-- 4,634 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

// Holder object for lazily-resolved module exports. Each property below is
// populated on first access by the getters that defineESModuleGetters installs.
const lazy = {};

// Register lazy getters for the modules this worker needs. `generateUUID` is a
// named export of Utils.sys.mjs resolved by its property name, like the others.
// NOTE(review): `{ global: "current" }` presumably loads the modules into this
// worker's global rather than the shared JSM global — confirm against the
// ChromeUtils.defineESModuleGetters documentation.
ChromeUtils.defineESModuleGetters(
  lazy,
  {
    PromiseWorker: "resource://gre/modules/workers/PromiseWorker.mjs",
    getBackend: "chrome://global/content/ml/backends/Pipeline.mjs",
    OPFS: "chrome://global/content/ml/OPFS.sys.mjs",
    generateUUID: "chrome://global/content/ml/Utils.sys.mjs",
  },
  { global: "current" }
);

/**
 * The actual MLEngine lives here in a worker.
 */
class MLEngineWorker {
  /** The backend pipeline instance, created by `initializeEngine`. */
  #pipeline;

  /**
   * Session identifier attached to main-thread calls so the parent can
   * correlate model-file requests with this engine instance.
   */
  #sessionId;

  constructor() {
    // Connect the provider to the worker.
    this.#connectToPromiseWorker();
  }

  /**
   * Implements the `match` function from the Cache API for Transformers.js custom cache.
   *
   * See https://developer.mozilla.org/en-US/docs/Web/API/Cache
   *
   * Attempts to match and retrieve a model file based on a provided key.
   * Fetches a model file by delegating the call to the worker's main thread.
   * Then wraps the fetched model file into a response object compatible with Transformers.js expectations.
   *
   * @param {string} key The unique identifier for the model to fetch.
   * @returns {Promise<Response|null>} A promise that resolves with a Response object containing the model file or null if not found.
   */
  async match(key) {
    // if the key starts with NO_LOCAL, we return null immediately to tell transformers.js
    // we don't serve local files, and it will do a second call with the full URL
    if (key.startsWith("NO_LOCAL")) {
      return null;
    }
    const res = await this.getModelFile(key);
    if (res.fail) {
      return null;
    }

    // Transformers.js expects a response object, so we wrap the array buffer.
    // NOTE(review): `res.ok` looks like a tuple where index 2 is the file data
    // and index 1 its metadata/headers — confirm against the main-thread
    // getModelFile implementation.
    return lazy.OPFS.toResponse(res.ok[2], res.ok[1]);
  }

  /**
   * Fetches a model file by delegating to the main thread, tagging the call
   * with this worker's session id.
   *
   * @param {...*} args Arguments forwarded verbatim to the main-thread handler.
   * @returns {Promise<object>} Result object carrying either an `ok` payload
   *   or a `fail` field (see `match` for how it is consumed).
   */
  async getModelFile(...args) {
    return self.callMainThread("getModelFile", [...args, this.#sessionId]);
  }

  /**
   * Notifies the main thread that the model download for this session has
   * completed.
   *
   * @returns {Promise<*>} Resolves once the main thread has acknowledged.
   */
  async notifyModelDownloadComplete() {
    return self.callMainThread("notifyModelDownloadComplete", [
      this.#sessionId,
    ]);
  }

  /**
   * Placeholder for the `put` method from the Cache API for Transformers.js custom cache.
   *
   * @throws {Error} Always thrown to indicate the method is not implemented.
   */
  put() {
    throw new Error("Method not implemented.");
  }

  /**
   * Creates the backend pipeline for this worker and records a fresh session
   * id used to tag subsequent main-thread calls.
   *
   * @param {ArrayBuffer} wasm
   * @param {object} options received as an object, converted to a PipelineOptions instance
   * @returns {Promise<void>} Resolves once the pipeline is ready (or rejects
   *   if backend creation fails — the download-complete notification is sent
   *   either way via `finally`).
   */
  async initializeEngine(wasm, options) {
    this.#sessionId = lazy.generateUUID();
    this.#pipeline = await lazy
      .getBackend(this, wasm, options)
      .finally(async () => {
        // Notifying here means the backend doesn't need to notify. But the backend could notify
        // so that we receive completion as soon as possible. Otherwise, we receive download completion
        // once pipeline is fully initialized.
        await this.notifyModelDownloadComplete();
      });
  }

  /**
   * Run the worker.
   *
   * @param {string} request
   * @param {string} requestId - The identifier used to internally track this request.
   * @param {object} engineRunOptions - Additional run options for the engine.
   * @param {boolean} engineRunOptions.enableInferenceProgress - Whether to enable inference progress.
   */
  async run(request, requestId, engineRunOptions = {}) {
    // Test hook: lets callers exercise the worker's error path on demand.
    if (request === "throw") {
      throw new Error(
        'Received the message "throw", so intentionally throwing an error.'
      );
    }

    return await this.#pipeline.run(
      request,
      requestId,
      engineRunOptions.enableInferenceProgress
        ? data => self.callMainThread("onInferenceProgress", [data])
        : null
    );
  }

  /**
   * Glue code to connect the `MLEngineWorker` to the PromiseWorker interface.
   */
  #connectToPromiseWorker() {
    const worker = new lazy.PromiseWorker.AbstractWorker();
    worker.dispatch = (method, args = []) => {
      if (!this[method]) {
        throw new Error("Method does not exist: " + method);
      }
      return this[method](...args);
    };
    worker.close = () => self.close();
    worker.postMessage = (message, ...transfers) => {
      self.postMessage(message, ...transfers);
    };

    self.callMainThread = worker.callMainThread.bind(worker);
    self.addEventListener("message", msg => worker.handleMessage(msg));
    self.addEventListener("unhandledrejection", function (error) {
      // Surface the structured failure payload when present, otherwise the
      // raw rejection reason, so PromiseWorker reports a useful error.
      throw error.reason?.fail ?? error.reason;
    });
  }
}

new MLEngineWorker();