1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
|
# frozen_string_literal: true
module GraphQL
module StaticValidation
module DirectivesAreInValidLocations
include GraphQL::Language
def on_directive(node, parent)
validate_location(node, parent, context.schema_directives)
super
end
private
LOCATION_MESSAGE_NAMES = {
GraphQL::Schema::Directive::QUERY => "queries",
GraphQL::Schema::Directive::MUTATION => "mutations",
GraphQL::Schema::Directive::SUBSCRIPTION => "subscriptions",
GraphQL::Schema::Directive::FIELD => "fields",
GraphQL::Schema::Directive::FRAGMENT_DEFINITION => "fragment definitions",
GraphQL::Schema::Directive::FRAGMENT_SPREAD => "fragment spreads",
GraphQL::Schema::Directive::INLINE_FRAGMENT => "inline fragments",
}
SIMPLE_LOCATIONS = {
Nodes::Field => GraphQL::Schema::Directive::FIELD,
Nodes::InlineFragment => GraphQL::Schema::Directive::INLINE_FRAGMENT,
Nodes::FragmentSpread => GraphQL::Schema::Directive::FRAGMENT_SPREAD,
Nodes::FragmentDefinition => GraphQL::Schema::Directive::FRAGMENT_DEFINITION,
}
SIMPLE_LOCATION_NODES = SIMPLE_LOCATIONS.keys
def validate_location(ast_directive, ast_parent, directives)
directive_defn = directives[ast_directive.name]
case ast_parent
when Nodes::OperationDefinition
required_location = GraphQL::Schema::Directive.const_get(ast_parent.operation_type.upcase)
assert_includes_location(directive_defn, ast_directive, required_location)
when *SIMPLE_LOCATION_NODES
required_location = SIMPLE_LOCATIONS[ast_parent.class]
assert_includes_location(directive_defn, ast_directive, required_location)
else
add_error(GraphQL::StaticValidation::DirectivesAreInValidLocationsError.new(
"Directives can't be applied to #{ast_parent.class.name}s",
nodes: ast_directive,
target: ast_parent.class.name
))
end
end
def assert_includes_location(directive_defn, directive_ast, required_location)
if !directive_defn.locations.include?(required_location)
location_name = LOCATION_MESSAGE_NAMES[required_location]
allowed_location_names = directive_defn.locations.map { |loc| LOCATION_MESSAGE_NAMES[loc] }
add_error(GraphQL::StaticValidation::DirectivesAreInValidLocationsError.new(
"'@#{directive_defn.graphql_name}' can't be applied to #{location_name} (allowed: #{allowed_location_names.join(", ")})",
nodes: directive_ast,
target: location_name,
name: directive_defn.graphql_name
))
end
end
end
end
end
|