File: ml_tensor.h

package info (click to toggle)
chromium 138.0.7204.183-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 6,071,908 kB
  • sloc: cpp: 34,937,088; ansic: 7,176,967; javascript: 4,110,704; python: 1,419,953; asm: 946,768; xml: 739,971; pascal: 187,324; sh: 89,623; perl: 88,663; objc: 79,944; sql: 50,304; cs: 41,786; fortran: 24,137; makefile: 21,806; php: 13,980; tcl: 13,166; yacc: 8,925; ruby: 7,485; awk: 3,720; lisp: 3,096; lex: 1,327; ada: 727; jsp: 228; sed: 36
file content (137 lines) | stat: -rw-r--r-- 5,895 bytes parent folder | download | duplicates (5)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_TENSOR_H_
#define THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_TENSOR_H_

#include "base/timer/elapsed_timer.h"
#include "base/types/expected.h"
#include "base/types/pass_key.h"
#include "services/webnn/public/cpp/ml_tensor_usage.h"
#include "services/webnn/public/cpp/operand_descriptor.h"
#include "services/webnn/public/cpp/webnn_trace.h"
#include "services/webnn/public/mojom/webnn_context_provider.mojom-blink.h"
#include "services/webnn/public/mojom/webnn_tensor.mojom-blink.h"
#include "third_party/blink/renderer/bindings/core/v8/script_promise_resolver.h"
#include "third_party/blink/renderer/bindings/modules/v8/v8_ml_operand_data_type.h"
#include "third_party/blink/renderer/modules/ml/webnn/allow_shared_buffer_source_util.h"
#include "third_party/blink/renderer/modules/modules_export.h"
#include "third_party/blink/renderer/platform/bindings/exception_state.h"
#include "third_party/blink/renderer/platform/bindings/script_wrappable.h"
#include "third_party/blink/renderer/platform/heap/collection_support/heap_hash_set.h"
#include "third_party/blink/renderer/platform/heap/member.h"
#include "third_party/blink/renderer/platform/heap/visitor.h"
#include "third_party/blink/renderer/platform/mojo/heap_mojo_associated_remote.h"
#include "third_party/blink/renderer/platform/wtf/text/wtf_string.h"

namespace blink {

class MLTensorDescriptor;
class MLContext;

// Script-exposed WebNN tensor (ml_tensor.idl). Holds the tensor's operand
// descriptor (data type and shape), its usage flags, the token identifying the
// tensor in the service process, and an associated Mojo remote to the
// service-side `WebNNTensor`.
class MODULES_EXPORT MLTensor : public ScriptWrappable {
  DEFINE_WRAPPERTYPEINFO();

 public:
  // Instances should only be constructed via `MLContext.createTensor()`.
  // This method is public as required by the `MakeGarbageCollected` helper;
  // the `base::PassKey<MLContext>` parameter restricts actual callers to
  // `MLContext`.
  //
  // `context` is the `MLContext` the tensor was created from.
  // `descriptor` describes the tensor data type and shape.
  // `usage` holds the usage flags requested for the tensor.
  // `create_tensor_success` contains the resulting handles to the created
  // tensor, which may be used to execute a context operation with respective
  // tensor.
  MLTensor(ExecutionContext* execution_context,
           MLContext* context,
           webnn::OperandDescriptor descriptor,
           webnn::MLTensorUsage usage,
           webnn::mojom::blink::CreateTensorSuccessPtr create_tensor_success,
           base::PassKey<MLContext> pass_key);
  MLTensor(const MLTensor&) = delete;
  MLTensor& operator=(const MLTensor&) = delete;

  ~MLTensor() override;

  // Reports this object's traced (garbage-collected) members to `visitor`.
  void Trace(Visitor* visitor) const override;

  // ml_tensor.idl
  // Attribute getters returning blink/V8 types. NOTE(review): the boolean
  // flags are presumably derived from `usage_`; confirm in ml_tensor.cc.
  V8MLOperandDataType dataType() const;
  Vector<uint32_t> shape() const;
  bool exportableToGPU() const;
  bool readable() const;
  bool writable() const;
  bool constant() const;

  // Implements `MLTensor.destroy()` from ml_tensor.idl. See ml_tensor.cc for
  // its exact effects on `remote_tensor_` and pending reads.
  void destroy();

  // Convenience methods for accessing native types, which avoid a copy
  // compared to using the corresponding methods which return blink types.
  const webnn::OperandDescriptor& Descriptor() const;
  webnn::OperandDataType DataType() const;
  const std::vector<uint32_t>& Shape() const;
  const webnn::MLTensorUsage& Usage() const;

  // Size of the tensor's data in bytes. NOTE(review): presumably computed from
  // `descriptor_` (data type x shape); confirm in ml_tensor.cc.
  uint64_t PackedByteLength() const;

  // Token identifying this tensor's `WebNNTensor` instance in the service
  // process.
  const blink::WebNNTensorToken& handle() const { return webnn_handle_; }

  // The `MLContext` this tensor was created from.
  const MLContext* context() const { return ml_context_.Get(); }

  // True while the Mojo remote to the service-side `WebNNTensor` is bound.
  bool IsValid() const { return remote_tensor_.is_bound(); }

  // Read data from the MLTensor. The resolver should be resolved with a copy of
  // the tensor data. Otherwise, the resolver should be rejected accordingly.
  // This overload's promise settles with a new `DOMArrayBuffer`.
  ScriptPromise<DOMArrayBuffer> ReadTensorImpl(webnn::ScopedTrace scoped_trace,
                                               ScriptState* script_state,
                                               ExceptionState& exception_state);

  // "Bring your own buffer" variant of the read above. The promise settles with
  // `undefined`. NOTE(review): the tensor data is presumably copied into the
  // caller-provided `dst_data`; confirm in ml_tensor.cc.
  ScriptPromise<IDLUndefined> ReadTensorImpl(webnn::ScopedTrace scoped_trace,
                                             ScriptState* script_state,
                                             AllowSharedBufferSource* dst_data,
                                             ExceptionState& exception_state);

  // Write data to the MLTensor. If write was successful, the data will be
  // stored in the MLTensor. `exception_state` is the only error channel
  // visible in this signature.
  void WriteTensorImpl(base::span<const uint8_t> src_data,
                       ExceptionState& exception_state);

 private:
  // The callback of reading from `WebNNTensor` by calling hardware accelerated
  // OS machine learning APIs. It settles `resolver` according to `result`.
  // `read_tensor_timer` measures time since the read was issued, presumably
  // for metrics; confirm in ml_tensor.cc.
  void OnDidReadTensor(webnn::ScopedTrace scoped_trace,
                       ScriptPromiseResolver<DOMArrayBuffer>* resolver,
                       base::ElapsedTimer read_tensor_timer,
                       webnn::mojom::blink::ReadTensorResultPtr result);
  // Same as above, for the BYOB read that targets `dst_data`.
  void OnDidReadTensorByob(webnn::ScopedTrace scoped_trace,
                           ScriptPromiseResolver<IDLUndefined>* resolver,
                           AllowSharedBufferSource* dst_data,
                           base::ElapsedTimer read_tensor_timer,
                           webnn::mojom::blink::ReadTensorResultPtr result);

  // Presumably the disconnect handler for `remote_tensor_`. Per the comment on
  // the pending resolver sets below, outstanding reads are rejected on an
  // unexpected disconnect.
  void OnConnectionError();

  // NOTE: member declaration order is member initialization order. Keep it in
  // sync with the constructor's initializer list in ml_tensor.cc.

  // The context this tensor was created from.
  Member<MLContext> ml_context_;

  // Represents a valid MLTensorDescriptor: the tensor's data type and shape.
  const webnn::OperandDescriptor descriptor_;

  // Represents usage flags for the MLTensor.
  const webnn::MLTensorUsage usage_;

  // Identifies this `WebNNTensor` mojo instance in the service process.
  const blink::WebNNTensorToken webnn_handle_;

  // The `WebNNTensor` is a tensor that can be used by the hardware
  // accelerated OS machine learning API.
  HeapMojoAssociatedRemote<webnn::mojom::blink::WebNNTensor> remote_tensor_;

  // Keep a set of unresolved `ScriptPromiseResolver`s which will be
  // rejected when the Mojo pipe is unexpectedly disconnected.
  HeapHashSet<Member<ScriptPromiseResolver<DOMArrayBuffer>>> pending_resolvers_;
  // Same, for the BYOB reads whose promises resolve with `undefined`.
  HeapHashSet<Member<ScriptPromiseResolver<IDLUndefined>>>
      pending_byob_resolvers_;
};

}  // namespace blink

#endif  // THIRD_PARTY_BLINK_RENDERER_MODULES_ML_WEBNN_ML_TENSOR_H_