File: DeserializeOps.cpp

package info (click to toggle)
swiftlang 6.0.3-2
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 2,519,992 kB
  • sloc: cpp: 9,107,863; ansic: 2,040,022; asm: 1,135,751; python: 296,500; objc: 82,456; f90: 60,502; lisp: 34,951; pascal: 19,946; sh: 18,133; perl: 7,482; ml: 4,937; javascript: 4,117; makefile: 3,840; awk: 3,535; xml: 914; fortran: 619; cs: 573; ruby: 573
file content (564 lines) | stat: -rw-r--r-- 19,729 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Deserializer methods for SPIR-V binary instructions.
//
//===----------------------------------------------------------------------===//

#include "Deserializer.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"
#include <optional>

using namespace mlir;

#define DEBUG_TYPE "spirv-deserialization"

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

/// Returns the opcode encoded in the low-order 16 bits of `word`, which is
/// expected to be the first word of a SPIR-V instruction.
static inline spirv::Opcode extractOpcode(uint32_t word) {
  const uint32_t opcodeBits = word & 0xffff;
  return static_cast<spirv::Opcode>(opcodeBits);
}

//===----------------------------------------------------------------------===//
// Instruction
//===----------------------------------------------------------------------===//

/// Resolves the SPIR-V <id> `id` to an MLIR value.
///
/// The lookups are tried in a fixed order: constant, global variable,
/// specialization constant, specialization constant composite, specialization
/// constant operation, undef, and finally the plain `valueMap`. For every kind
/// except the last, a new op is created through `opBuilder` and its result is
/// returned. Callers in this file treat a null result as an unknown <id>.
Value spirv::Deserializer::getValue(uint32_t id) {
  // Constants: materialize a `spirv.Constant` op at every use site.
  if (auto constantInfo = getConstant(id))
    return opBuilder.create<spirv::ConstantOp>(unknownLoc, constantInfo->second,
                                               constantInfo->first);

  // Global variables: take the address of the variable's symbol.
  if (auto globalVar = getGlobalVariable(id)) {
    auto symbol = SymbolRefAttr::get(globalVar.getOperation());
    auto addressOf = opBuilder.create<spirv::AddressOfOp>(
        unknownLoc, globalVar.getType(), symbol);
    return addressOf.getPointer();
  }

  // Specialization constants: reference the symbol, typed after the default
  // value.
  if (auto specConst = getSpecConstant(id)) {
    auto symbol = SymbolRefAttr::get(specConst.getOperation());
    auto referenceOf = opBuilder.create<spirv::ReferenceOfOp>(
        unknownLoc, specConst.getDefaultValue().getType(), symbol);
    return referenceOf.getReference();
  }

  // Specialization constant composites: reference the symbol, typed after the
  // composite itself.
  if (auto specComposite = getSpecConstantComposite(id)) {
    auto symbol = SymbolRefAttr::get(specComposite.getOperation());
    auto referenceOf = opBuilder.create<spirv::ReferenceOfOp>(
        unknownLoc, specComposite.getType(), symbol);
    return referenceOf.getReference();
  }

  // Specialization constant operations are materialized from their recorded
  // opcode, result type <id> and operands.
  if (auto specOpInfo = getSpecConstantOperation(id))
    return materializeSpecConstantOperation(id, specOpInfo->enclodesOpcode,
                                            specOpInfo->resultTypeID,
                                            specOpInfo->enclosedOpOperands);

  // Undefined values: create a `spirv::UndefOp` of the recorded type.
  if (auto undefType = getUndefType(id))
    return opBuilder.create<spirv::UndefOp>(unknownLoc, undefType);

  // Everything else is an ordinary, already-deserialized result.
  return valueMap.lookup(id);
}

/// Slices the instruction starting at `curOffset` out of `binary`.
///
/// On success, `opcode` receives the instruction's opcode, `operands` receives
/// every word after the first one, and `curOffset` is advanced past the
/// instruction. `expectedOpcode` is only used to word the diagnostic emitted
/// when the binary has no words left.
LogicalResult spirv::Deserializer::sliceInstruction(
    spirv::Opcode &opcode, ArrayRef<uint32_t> &operands,
    std::optional<spirv::Opcode> expectedOpcode) {
  const auto totalWords = binary.size();
  if (curOffset >= totalWords) {
    StringRef expected = "more";
    if (expectedOpcode)
      expected = spirv::stringifyOpcode(*expectedOpcode);
    return emitError(unknownLoc, "expected ") << expected << " instruction";
  }

  // The high 16 bits of the first word hold the instruction's total word
  // count (first word included); the low 16 bits hold the opcode.
  const uint32_t firstWord = binary[curOffset];
  const uint32_t wordCount = firstWord >> 16;
  if (wordCount == 0)
    return emitError(unknownLoc, "word count cannot be zero");

  const uint32_t nextOffset = curOffset + wordCount;
  if (nextOffset > totalWords)
    return emitError(unknownLoc, "insufficient words for the last instruction");

  opcode = extractOpcode(firstWord);
  operands = binary.slice(curOffset + 1, wordCount - 1);
  curOffset = nextOffset;
  return success();
}

/// Dispatches one already-sliced instruction to its handler.
///
/// `opcode` selects the handler and `operands` holds the instruction's words
/// after the first one. When `deferInstructions` is true, OpEntryPoint and
/// OpExecutionMode are recorded in `deferredInstructions` instead of being
/// processed. Any opcode that is not handled by an explicit case (or whose
/// case ends in `break`) is forwarded to `dispatchToAutogenDeserialization`.
LogicalResult spirv::Deserializer::processInstruction(
    spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
  LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
                                << spirv::stringifyOpcode(opcode) << "\n");

  // First dispatch all the instructions whose opcode does not correspond to
  // those that have a direct mirror in the SPIR-V dialect
  switch (opcode) {
  case spirv::Opcode::OpCapability:
    return processCapability(operands);
  case spirv::Opcode::OpExtension:
    return processExtension(operands);
  case spirv::Opcode::OpExtInst:
    return processExtInst(operands);
  case spirv::Opcode::OpExtInstImport:
    return processExtInstImport(operands);
  case spirv::Opcode::OpMemberName:
    return processMemberName(operands);
  case spirv::Opcode::OpMemoryModel:
    return processMemoryModel(operands);
  case spirv::Opcode::OpEntryPoint:
  case spirv::Opcode::OpExecutionMode:
    // Either queue the instruction for later, or fall out of the switch so it
    // is handled by the autogenerated dispatch below.
    if (deferInstructions) {
      deferredInstructions.emplace_back(opcode, operands);
      return success();
    }
    break;
  case spirv::Opcode::OpVariable:
    // An OpVariable seen while the builder's block is directly inside a
    // spirv::ModuleOp is a global variable; otherwise it goes through the
    // autogenerated dispatch below.
    if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
      return processGlobalVariable(operands);
    }
    break;
  case spirv::Opcode::OpLine:
    return processDebugLine(operands);
  case spirv::Opcode::OpNoLine:
    clearDebugLine();
    return success();
  case spirv::Opcode::OpName:
    return processName(operands);
  case spirv::Opcode::OpString:
    return processDebugString(operands);
  case spirv::Opcode::OpModuleProcessed:
  case spirv::Opcode::OpSource:
  case spirv::Opcode::OpSourceContinued:
  case spirv::Opcode::OpSourceExtension:
    // These instructions are accepted and ignored.
    // TODO: This is debug information embedded in the binary which should be
    // translated into the spirv.module.
    return success();
  // Type declarations.
  case spirv::Opcode::OpTypeVoid:
  case spirv::Opcode::OpTypeBool:
  case spirv::Opcode::OpTypeInt:
  case spirv::Opcode::OpTypeFloat:
  case spirv::Opcode::OpTypeVector:
  case spirv::Opcode::OpTypeMatrix:
  case spirv::Opcode::OpTypeArray:
  case spirv::Opcode::OpTypeFunction:
  case spirv::Opcode::OpTypeImage:
  case spirv::Opcode::OpTypeSampledImage:
  case spirv::Opcode::OpTypeRuntimeArray:
  case spirv::Opcode::OpTypeStruct:
  case spirv::Opcode::OpTypePointer:
  case spirv::Opcode::OpTypeCooperativeMatrixNV:
    return processType(opcode, operands);
  case spirv::Opcode::OpTypeForwardPointer:
    return processTypeForwardPointer(operands);
  case spirv::Opcode::OpTypeJointMatrixINTEL:
    return processType(opcode, operands);
  // Constants and specialization constants.
  case spirv::Opcode::OpConstant:
    return processConstant(operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstant:
    return processConstant(operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantComposite:
    return processConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantComposite:
    return processSpecConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantOp:
    return processSpecConstantOperation(operands);
  case spirv::Opcode::OpConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantNull:
    return processConstantNull(operands);
  // Decorations.
  case spirv::Opcode::OpDecorate:
    return processDecoration(operands);
  case spirv::Opcode::OpMemberDecorate:
    return processMemberDecoration(operands);
  // Functions, blocks and control flow.
  case spirv::Opcode::OpFunction:
    return processFunction(operands);
  case spirv::Opcode::OpLabel:
    return processLabel(operands);
  case spirv::Opcode::OpBranch:
    return processBranch(operands);
  case spirv::Opcode::OpBranchConditional:
    return processBranchConditional(operands);
  case spirv::Opcode::OpSelectionMerge:
    return processSelectionMerge(operands);
  case spirv::Opcode::OpLoopMerge:
    return processLoopMerge(operands);
  case spirv::Opcode::OpPhi:
    return processPhi(operands);
  case spirv::Opcode::OpUndef:
    return processUndef(operands);
  default:
    break;
  }
  // Everything that reached this point is handled by the autogenerated
  // dispatch function.
  return dispatchToAutogenDeserialization(opcode, operands);
}

/// Deserializes an op named `opName` whose words consist only of an optional
/// (result type <id>, result <id>) pair followed by exactly `numOperands`
/// operand <id>s.
///
/// When `hasResult` is true, the first two words are the result type <id> and
/// the result <id>. Attributes recorded in `decorations` for the result <id>
/// are attached to the created op, and the op's first result is registered in
/// `valueMap`. The debug line is cleared after creating a terminator.
LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
    ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
    unsigned numOperands) {
  SmallVector<Type, 1> resultTypes;
  uint32_t resultID = 0;
  size_t cursor = 0;

  if (hasResult) {
    // First comes the result type <id>.
    if (cursor >= words.size())
      return emitError(unknownLoc,
                       "expected result type <id> while deserializing for ")
             << opName;

    const uint32_t typeID = words[cursor];
    auto resultType = getType(typeID);
    if (!resultType)
      return emitError(unknownLoc, "unknown type result <id>: ") << typeID;
    resultTypes.push_back(resultType);
    ++cursor;

    // Then the result <id> itself.
    if (cursor >= words.size())
      return emitError(unknownLoc,
                       "expected result <id> while deserializing for ")
             << opName;
    resultID = words[cursor];
    ++cursor;
  }

  // Resolve up to `numOperands` operand <id>s, stopping early if the words run
  // out.
  SmallVector<Value, 4> operands;
  size_t decodedOperands = 0;
  while (decodedOperands < numOperands && cursor < words.size()) {
    const uint32_t operandID = words[cursor];
    auto operand = getValue(operandID);
    if (!operand)
      return emitError(unknownLoc, "unknown result <id>: ") << operandID;
    operands.push_back(operand);
    ++decodedOperands;
    ++cursor;
  }

  // Both "too few" and "too many" words are errors.
  if (decodedOperands != numOperands) {
    return emitError(
               unknownLoc,
               "found less operands than expected when deserializing for ")
           << opName << "; only " << decodedOperands << " of " << numOperands
           << " processed";
  }
  if (cursor != words.size()) {
    return emitError(
               unknownLoc,
               "found more operands than expected when deserializing for ")
           << opName << "; only " << cursor << " of " << words.size()
           << " processed";
  }

  // Carry over any attributes that decorations attached to the result <id>.
  SmallVector<NamedAttribute, 4> attributes;
  if (decorations.count(resultID)) {
    auto decorationAttrs = decorations[resultID].getAttrs();
    attributes.append(decorationAttrs.begin(), decorationAttrs.end());
  }

  // Build the op generically from its name and record its result, if any.
  Location loc = createFileLineColLoc(opBuilder);
  OperationState state(loc, opName);
  state.addOperands(operands);
  if (hasResult)
    state.addTypes(resultTypes);
  state.addAttributes(attributes);

  Operation *newOp = opBuilder.create(state);
  if (hasResult)
    valueMap[resultID] = newOp->getResult(0);

  if (newOp->hasTrait<OpTrait::IsTerminator>())
    clearDebugLine();

  return success();
}

/// Handles OpUndef: records the type for the result <id> in `undefMap`.
///
/// `operands` must be exactly the result type <id> followed by the result
/// <id>. No op is created here.
LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
  if (operands.size() != 2)
    return emitError(unknownLoc, "OpUndef instruction must have two operands");

  const uint32_t typeID = operands[0];
  const uint32_t resultID = operands[1];

  auto undefType = getType(typeID);
  if (!undefType)
    return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");

  undefMap[resultID] = undefType;
  return success();
}

/// Handles OpExtInst by forwarding to the autogenerated deserializer of the
/// referenced extended instruction set.
///
/// `operands` is laid out as: result type <id>, result <id>, set <id>,
/// instruction opcode, then the instruction's own operands. The set <id> and
/// the opcode are passed separately, so they are removed from the word list
/// handed to the dispatcher.
LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
  if (operands.size() < 4) {
    return emitError(unknownLoc,
                     "OpExtInst must have at least 4 operands, result type "
                     "<id>, result <id>, set <id> and instruction opcode");
  }

  const uint32_t setID = operands[2];
  const uint32_t extOpcode = operands[3];
  if (!extendedInstSets.count(setID))
    return emitError(unknownLoc, "undefined set <id> in OpExtInst");

  // Keep words [0, 2) and [4, end).
  const ArrayRef<uint32_t> resultWords = operands.take_front(2);
  const ArrayRef<uint32_t> trailingWords = operands.drop_front(4);
  SmallVector<uint32_t, 4> slicedOperands;
  slicedOperands.append(resultWords.begin(), resultWords.end());
  slicedOperands.append(trailingWords.begin(), trailingWords.end());

  return dispatchToExtensionSetAutogenDeserialization(
      extendedInstSets[setID], extOpcode, slicedOperands);
}

namespace mlir {
namespace spirv {

/// Deserializes OpEntryPoint.
///
/// `words` is laid out as: execution model, function <id>, the function name
/// as a string literal, then the <id>s of the interface global variables. The
/// function <id> must resolve to an already-parsed function; a function that
/// still carries a "spirv_fn_"-prefixed name is renamed to the entry point
/// name.
template <>
LogicalResult
Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
  if (words.empty())
    return emitError(unknownLoc,
                     "missing Execution Model specification in OpEntryPoint");

  // `cursor` stays `unsigned` because decodeStringLiteral() advances it by
  // reference.
  unsigned cursor = 0;
  const uint32_t modelWord = words[cursor++];
  auto execModel = spirv::ExecutionModelAttr::get(
      context, static_cast<spirv::ExecutionModel>(modelWord));

  if (cursor >= words.size())
    return emitError(unknownLoc, "missing <id> in OpEntryPoint");

  // The function <id>, followed by the entry point's name.
  const uint32_t fnID = words[cursor++];
  auto fnName = decodeStringLiteral(words, cursor);

  // The <id> must name a known function whose name agrees with `fnName`.
  auto entryFn = getFunction(fnID);
  if (!entryFn)
    return emitError(unknownLoc, "no function matching <id> ") << fnID;

  if (entryFn.getName() != fnName) {
    // The deserializer uses "spirv_fn_<id>" as the function name if the input
    // SPIR-V blob does not contain a name for it. We should use a more clear
    // indication for such case rather than relying on naming details.
    const bool hasPlaceholderName = entryFn.getName().startswith("spirv_fn_");
    if (!hasPlaceholderName)
      return emitError(unknownLoc,
                       "function name mismatch between OpEntryPoint "
                       "and OpFunction with <id> ")
             << fnID << ": " << fnName << " vs. " << entryFn.getName();
    entryFn.setName(fnName);
  }

  // Every remaining word must be the <id> of a global variable.
  SmallVector<Attribute, 4> interfaceVars;
  for (; cursor < words.size(); ++cursor) {
    const uint32_t varID = words[cursor];
    auto globalVar = getGlobalVariable(varID);
    if (!globalVar)
      return emitError(unknownLoc, "undefined result <id> ")
             << varID << " while decoding OpEntryPoint";
    interfaceVars.push_back(SymbolRefAttr::get(globalVar.getOperation()));
  }

  opBuilder.create<spirv::EntryPointOp>(
      unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
      opBuilder.getArrayAttr(interfaceVars));
  return success();
}

/// Deserializes OpExecutionMode.
///
/// `words` is laid out as: function <id>, execution mode, then zero or more
/// literal values which are stored as 32-bit integer attributes.
template <>
LogicalResult
Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
  // Word 0: the <id> of the function the mode applies to.
  if (words.empty())
    return emitError(unknownLoc,
                     "missing function result <id> in OpExecutionMode");

  const uint32_t fnID = words[0];
  auto targetFn = getFunction(fnID);
  if (!targetFn)
    return emitError(unknownLoc, "no function matching <id> ") << fnID;

  // Word 1: the execution mode.
  if (words.size() < 2)
    return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");

  auto execMode = spirv::ExecutionModeAttr::get(
      context, static_cast<spirv::ExecutionMode>(words[1]));

  // Words 2 and up: the mode's literal values.
  SmallVector<Attribute, 4> modeValues;
  for (uint32_t literal : words.drop_front(2))
    modeValues.push_back(opBuilder.getI32IntegerAttr(literal));

  auto values = opBuilder.getArrayAttr(modeValues);
  auto fnSymbol =
      SymbolRefAttr::get(opBuilder.getContext(), targetFn.getName());
  opBuilder.create<spirv::ExecutionModeOp>(unknownLoc, fnSymbol, execMode,
                                           values);
  return success();
}

/// Deserializes OpFunctionCall.
///
/// `operands` is laid out as: result type <id>, result <id>, callee function
/// <id>, then the <id>s of the call arguments. When the result type is void,
/// the op is created without a result and nothing is added to `valueMap`.
template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
  if (operands.size() < 3)
    return emitError(unknownLoc,
                     "OpFunctionCall must have at least 3 operands");

  const uint32_t resultTypeID = operands[0];
  const uint32_t resultID = operands[1];
  const uint32_t calleeID = operands[2];

  Type resultType = getType(resultTypeID);
  if (!resultType)
    return emitError(unknownLoc, "undefined result type from <id> ")
           << resultTypeID;

  // Use null type to mean no result type.
  if (isVoidType(resultType))
    resultType = nullptr;

  auto calleeName = getFunctionSymbol(calleeID);

  // Resolve every argument <id> that follows the callee.
  SmallVector<Value, 4> arguments;
  for (uint32_t argID : operands.drop_front(3)) {
    auto argValue = getValue(argID);
    if (!argValue)
      return emitError(unknownLoc, "unknown <id> ")
             << argID << " used by OpFunctionCall";
    arguments.push_back(argValue);
  }

  auto calleeSymbol = SymbolRefAttr::get(opBuilder.getContext(), calleeName);
  auto callOp = opBuilder.create<spirv::FunctionCallOp>(
      unknownLoc, resultType, calleeSymbol, arguments);

  if (resultType)
    valueMap[resultID] = callOp.getResult(0);
  return success();
}

/// Deserializes OpCopyMemory.
///
/// `words` is laid out as: target <id>, source <id>, then optionally a memory
/// access mask, an alignment literal (only when the mask has the Aligned bit
/// set), a source memory access mask, and a source alignment literal. Any word
/// left over after these is reported as an error.
template <>
LogicalResult
Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
  SmallVector<Type, 1> resultTypes;
  size_t wordIndex = 0;
  SmallVector<Value, 4> operands;
  SmallVector<NamedAttribute, 4> attributes;

  // OpCopyMemory always carries a target and a source pointer; reject a
  // truncated instruction instead of building an op with missing operands.
  if (words.size() < 2) {
    return emitError(unknownLoc,
                     "OpCopyMemory must have at least 2 operands, target <id> "
                     "and source <id>");
  }

  // Decode the target <id> and then the source <id>.
  for (; wordIndex < 2; ++wordIndex) {
    auto arg = getValue(words[wordIndex]);
    if (!arg) {
      return emitError(unknownLoc, "unknown result <id> : ")
             << words[wordIndex];
    }
    operands.push_back(arg);
  }

  bool isAlignedAttr = false;

  if (wordIndex < words.size()) {
    auto attrValue = words[wordIndex++];
    auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
        static_cast<spirv::MemoryAccess>(attrValue));
    attributes.push_back(opBuilder.getNamedAttr("memory_access", attr));
    // The memory access operand is a bitmask, so Aligned may be combined with
    // other bits (e.g. Volatile|Aligned). Test the bit rather than comparing
    // the whole mask, otherwise the alignment literal that follows would be
    // misread as the source memory access.
    isAlignedAttr =
        (attrValue & static_cast<uint32_t>(spirv::MemoryAccess::Aligned)) != 0;
  }

  if (isAlignedAttr && wordIndex < words.size()) {
    attributes.push_back(opBuilder.getNamedAttr(
        "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
  }

  // Optional second memory access mask, applying to the source.
  if (wordIndex < words.size()) {
    auto attrValue = words[wordIndex++];
    auto attr = opBuilder.getAttr<spirv::MemoryAccessAttr>(
        static_cast<spirv::MemoryAccess>(attrValue));
    attributes.push_back(opBuilder.getNamedAttr("source_memory_access", attr));
  }

  // Any word following the source mask is taken as the source alignment.
  if (wordIndex < words.size()) {
    attributes.push_back(opBuilder.getNamedAttr(
        "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
  }

  if (wordIndex != words.size()) {
    return emitError(unknownLoc,
                     "found more operands than expected when deserializing "
                     "spirv::CopyMemoryOp, only ")
           << wordIndex << " of " << words.size() << " processed";
  }

  Location loc = createFileLineColLoc(opBuilder);
  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);

  return success();
}

/// Deserializes a GenericCastToPtrExplicitOp from exactly four words.
///
/// The first three words are the result type <id>, the result <id> and the
/// pointer operand <id>. The fourth word is required to be present but is not
/// read here.
template <>
LogicalResult Deserializer::processOp<spirv::GenericCastToPtrExplicitOp>(
    ArrayRef<uint32_t> words) {
  if (words.size() != 4) {
    return emitError(unknownLoc,
                     "expected 4 words in GenericCastToPtrExplicitOp"
                     " but got : ")
           << words.size();
  }

  const uint32_t typeID = words[0];
  const uint32_t resultID = words[1];
  const uint32_t pointerID = words[2];

  // Resolve the result type.
  auto resultType = getType(typeID);
  if (!resultType)
    return emitError(unknownLoc, "unknown type result <id> : ") << typeID;
  SmallVector<Type, 1> resultTypes;
  resultTypes.push_back(resultType);

  // Resolve the pointer being cast.
  auto pointer = getValue(pointerID);
  if (!pointer)
    return emitError(unknownLoc, "unknown result <id> : ") << pointerID;
  SmallVector<Value, 4> operands;
  operands.push_back(pointer);

  // Create the op and make its result available under the result <id>.
  Location loc = createFileLineColLoc(opBuilder);
  Operation *castOp = opBuilder.create<spirv::GenericCastToPtrExplicitOp>(
      loc, resultTypes, operands);
  valueMap[resultID] = castOp->getResult(0);
  return success();
}

// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
// various Deserializer::processOp<...>() specializations.
#define GET_DESERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"

} // namespace spirv
} // namespace mlir