File: validate_derivatives.cpp

package info (click to toggle)
spirv-tools 2025.5-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 28,588 kB
  • sloc: cpp: 470,407; javascript: 5,893; python: 3,326; ansic: 488; sh: 450; ruby: 88; makefile: 18; lisp: 9
file content (117 lines) | stat: -rw-r--r-- 4,486 bytes parent folder | download | duplicates (13)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Validates correctness of derivative SPIR-V instructions.

#include <string>

#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"

namespace spvtools {
namespace val {

// Validation pass for the derivative instructions: OpDPdx, OpDPdy, OpFwidth
// and their Fine/Coarse variants. Any other opcode yields SPV_SUCCESS without
// further inspection.
//
// For a derivative instruction the pass, in this order:
//   1. reports SPV_ERROR_INVALID_DATA through `_.diag` when the Result Type is
//      not a float scalar or vector, fails the 32-bit OpTypeFloat check, or is
//      not the same type id as operand 2 (P);
//   2. registers an execution-model callback on the enclosing function that
//      accepts only Fragment, GLCompute, MeshEXT and TaskEXT;
//   3. registers a second callback on the enclosing function that, for entry
//      points using GLCompute, MeshEXT or TaskEXT, requires one of the
//      DerivativeGroupLinearKHR / DerivativeGroupQuadsKHR execution modes.
// The callbacks are only registered here; they are invoked elsewhere.
//
// `_`    - validation state used for type queries, diagnostics and function
//          lookup.
// `inst` - the instruction being validated.
spv_result_t DerivativesPass(ValidationState_t& _, const Instruction* inst) {
  const spv::Op opcode = inst->opcode();
  const uint32_t result_type_id = inst->type_id();

  // Guard: everything that is not a derivative opcode is out of scope here.
  switch (opcode) {
    case spv::Op::OpDPdx:
    case spv::Op::OpDPdy:
    case spv::Op::OpFwidth:
    case spv::Op::OpDPdxFine:
    case spv::Op::OpDPdyFine:
    case spv::Op::OpFwidthFine:
    case spv::Op::OpDPdxCoarse:
    case spv::Op::OpDPdyCoarse:
    case spv::Op::OpFwidthCoarse:
      break;
    default:
      return SPV_SUCCESS;
  }

  // --- Result Type / operand type checks -----------------------------------
  if (!_.IsFloatScalarOrVectorType(result_type_id)) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected Result Type to be float scalar or vector type: "
           << spvOpcodeString(opcode);
  }

  const bool has_32_bit_float = _.ContainsSizedIntOrFloatType(
      result_type_id, spv::Op::OpTypeFloat, 32);
  if (!has_32_bit_float) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Result type component width must be 32 bits";
  }

  // Operand index 2 is P; its type id must equal the Result Type id.
  const uint32_t operand_type_id = _.GetOperandTypeId(inst, 2);
  if (operand_type_id != result_type_id) {
    return _.diag(SPV_ERROR_INVALID_DATA, inst)
           << "Expected P type and Result Type to be the same: "
           << spvOpcodeString(opcode);
  }

  // --- Limitations attached to the enclosing function ----------------------
  auto* const enclosing_function = _.function(inst->function()->id());

  // Accepts only the execution models in which derivatives may be used.
  const auto model_is_supported = [opcode](spv::ExecutionModel model,
                                           std::string* message) {
    switch (model) {
      case spv::ExecutionModel::Fragment:
      case spv::ExecutionModel::GLCompute:
      case spv::ExecutionModel::MeshEXT:
      case spv::ExecutionModel::TaskEXT:
        return true;
      default:
        break;
    }
    if (message) {
      message->assign(
          "Derivative instructions require Fragment, GLCompute, "
          "MeshEXT or TaskEXT execution model: ");
      message->append(spvOpcodeString(opcode));
    }
    return false;
  };
  enclosing_function->RegisterExecutionModelLimitation(model_is_supported);

  // For GLCompute/MeshEXT/TaskEXT entry points, a derivative group execution
  // mode must have been declared.
  const auto has_derivative_group_mode = [opcode](
                                             const ValidationState_t& state,
                                             const Function* entry_point,
                                             std::string* message) {
    const auto* models = state.GetExecutionModels(entry_point->id());
    const auto* modes = state.GetExecutionModes(entry_point->id());

    const bool needs_group_mode =
        models && (models->count(spv::ExecutionModel::GLCompute) != 0 ||
                   models->count(spv::ExecutionModel::MeshEXT) != 0 ||
                   models->count(spv::ExecutionModel::TaskEXT) != 0);
    if (!needs_group_mode) {
      return true;
    }

    const bool group_mode_declared =
        modes &&
        (modes->count(spv::ExecutionMode::DerivativeGroupLinearKHR) != 0 ||
         modes->count(spv::ExecutionMode::DerivativeGroupQuadsKHR) != 0);
    if (group_mode_declared) {
      return true;
    }

    if (message) {
      message->assign(
          "Derivative instructions require "
          "DerivativeGroupQuadsKHR "
          "or DerivativeGroupLinearKHR execution mode for "
          "GLCompute, MeshEXT or TaskEXT execution model: ");
      message->append(spvOpcodeString(opcode));
    }
    return false;
  };
  enclosing_function->RegisterLimitation(has_derivative_group_mode);

  return SPV_SUCCESS;
}

}  // namespace val
}  // namespace spvtools