1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
|
"""A diagnostic engine based on SARIF."""
from __future__ import annotations
from typing import List, Optional
from torch.onnx._internal.diagnostics import infra
from torch.onnx._internal.diagnostics.infra import formatter, sarif
from torch.onnx._internal.diagnostics.infra.sarif import version as sarif_version
class DiagnosticEngine:
    """A generic diagnostic engine based on SARIF.

    This class is the main interface for diagnostics. It manages the creation of
    diagnostic contexts. A DiagnosticContext provides the entry point for
    recording Diagnostics. Each DiagnosticContext is powered by a DiagnosticTool,
    which can be customized with a custom RuleCollection and Diagnostic type.

    See infra.DiagnosticContext and infra.DiagnosticTool for more details.

    Examples:
        Step 1: Create a set of rules.
            >>> rules = infra.RuleCollection.from_list(
            ...     "CustomRuleCollection",
            ...     [
            ...         infra.Rule(
            ...             id="r1",
            ...             name="rule-1",
            ...             message_default_template="Missing xxx",
            ...         ),
            ...     ],
            ... )

        Step 2: Create a diagnostic tool.
            >>> tool = infra.DiagnosticTool(
            ...     name="tool",
            ...     version="1.0.0",
            ...     rules=rules,
            ... )

        Step 3: Create a diagnostic engine.
            >>> engine = DiagnosticEngine()

        Step 4: Create a new diagnostic context; the engine keeps track of it.
            >>> context = engine.create_diagnostic_context(tool)

        Step 5: Add diagnostics in your code.
            >>> context.diagnose(rules.rule1, infra.Level.ERROR)

        Step 6: Afterwards, get the SARIF log.
            >>> sarif_log = engine.sarif_log()
    """

    # Every context created through `create_diagnostic_context`, in creation
    # order; each one contributes a single run to the SARIF log.
    _contexts: List[infra.DiagnosticContext]

    def __init__(self) -> None:
        self._contexts = []

    def sarif_log(self) -> sarif.SarifLog:
        """Build a SARIF log from all diagnostic contexts registered so far.

        Returns:
            A ``sarif.SarifLog`` stamped with ``sarif_version.SARIF_VERSION`` and
            ``sarif_version.SARIF_SCHEMA_LINK``, containing one run per context
            (produced by ``context.sarif()``), in creation order.
        """
        return sarif.SarifLog(
            version=sarif_version.SARIF_VERSION,
            schema_uri=sarif_version.SARIF_SCHEMA_LINK,
            runs=[context.sarif() for context in self._contexts],
        )

    def __str__(self) -> str:
        """Return the SARIF log serialized as JSON."""
        # TODO: pretty print.
        return self.to_json()

    def __repr__(self) -> str:
        """Return the SARIF log serialized as JSON."""
        return self.to_json()

    def to_json(self) -> str:
        """Serialize the engine's SARIF log to a JSON string.

        Returns:
            The output of ``formatter.sarif_to_json`` applied to ``sarif_log()``.
        """
        return formatter.sarif_to_json(self.sarif_log())

    def clear(self) -> None:
        """Clears all diagnostic contexts.

        Contexts created before this call are no longer included in
        ``sarif_log()``.
        """
        self._contexts.clear()

    def create_diagnostic_context(
        self,
        tool: infra.DiagnosticTool,
        options: Optional[infra.DiagnosticOptions] = None,
    ) -> infra.DiagnosticContext:
        """Create a new diagnostic context and register it with this engine.

        Args:
            tool: The diagnostic tool powering the new context.
            options: Optional diagnostic options, forwarded unchanged to
                ``infra.DiagnosticContext``.

        Returns:
            The newly created context. It is registered with the engine, so its
            run appears in subsequent ``sarif_log()`` output.
        """
        context = infra.DiagnosticContext(tool, options)
        self._contexts.append(context)
        return context
|