1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
|
from typing import List, Optional, Tuple
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
AHMetadata,
Choice,
)
class LearnedHeuristic:
"""
LearnedHeuristic is a base class for all learned heuristics.
"""
def __init__(self) -> None:
pass
def check_precondition(
self,
metadata: AHMetadata,
context: AHContext,
) -> bool:
return True
def get_decision(
self, context: AHContext, choices: List[Choice]
) -> Optional[Choice]:
return None
def get_confidence_threshold(self) -> float:
return 1.0
def get_name(self) -> str:
return ""
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
return None
class LearnedHeuristicRegression(LearnedHeuristic):
    """
    Learned heuristic that scores each choice with a predicted feedback value.

    Subclasses override ``get_feedback``. ``get_decision`` picks the choice with
    the highest prediction, but only when the ratio of the highest to the
    second-highest prediction exceeds ``get_confidence_threshold()``.
    NOTE(review): the ratio test presumes feedback values are positive; with
    negative predictions the ratio's direction inverts — confirm with callers.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_feedback(self, context: AHContext, choice: Choice) -> float:
        """Return the predicted feedback for ``choice``; the base class returns 1.0."""
        return 1.0

    def get_decision(
        self, context: AHContext, choices: List[Choice]
    ) -> Optional[Choice]:
        """
        Return the best-scoring choice if it beats the runner-up confidently.

        Args:
            context: passed through to ``get_feedback``.
            choices: candidates to score; duplicates collapse to one entry
                (the last prediction for a duplicate wins).

        Returns:
            The highest-feedback choice when ``highest / second_highest`` is
            greater than the confidence threshold, otherwise None. Also None
            when fewer than two distinct choices are given, since no ratio can
            be formed.
        """
        choice2feedback = {
            choice: self.get_feedback(context, choice) for choice in choices
        }
        # Ascending by predicted feedback, so the best candidate is last.
        sorted_choices_feedback = sorted(choice2feedback.items(), key=lambda t: t[1])
        if len(sorted_choices_feedback) < 2:
            # Indexing [-2] below would raise IndexError; with no runner-up
            # there is no confidence ratio, so defer to the caller's fallback.
            return None
        best_choice, highest_feedback = sorted_choices_feedback[-1]
        second_highest_feedback = sorted_choices_feedback[-2][1]
        if second_highest_feedback == 0:
            # Avoid ZeroDivisionError: the ratio is +inf when only the best
            # choice has positive feedback, and undefined when both are zero.
            return best_choice if highest_feedback > 0 else None
        if highest_feedback / second_highest_feedback > self.get_confidence_threshold():
            return best_choice
        # We are not sure which choice is the best one
        return None
class LearnedHeuristicDecision(LearnedHeuristic):
    """
    Learned heuristic driven by a ranked list of ``(probability, index)`` pairs.

    Subclasses override ``get_best_choices`` to produce the ranked pairs and
    ``get_choice`` to translate an index into a choice name.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_choice(self, idx: int) -> Optional[str]:
        """Translate a model index into a choice name; the base class returns None."""
        return None

    def get_decision(
        self, context: AHContext, choices: List[Choice]
    ) -> Optional[Choice]:
        """
        Return the choice named by the first entry of ``get_best_choices``.

        Returns None when there are no ranked entries or when the first
        entry's probability is at or below the confidence threshold.
        ``choices`` is not consulted by this implementation.
        """
        ranked = self.get_best_choices(context)
        if not ranked:
            return None
        top_proba, top_idx = ranked[0]
        # Written as "<=" so a NaN probability is not treated as too low.
        too_uncertain = top_proba <= self.get_confidence_threshold()
        return None if too_uncertain else self.get_choice(top_idx)

    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
        """
        Return choice names in ``get_best_choices`` order.

        Entries whose index ``get_choice`` cannot map (returns None) are
        dropped. Returns None if ``get_best_choices`` returns None.
        """
        ranked = self.get_best_choices(context)
        if ranked is None:
            return None
        names: List[str] = []
        for entry in ranked:
            name = self.get_choice(entry[1])
            if name is not None:
                names.append(name)
        return names

    def get_best_choices(self, context: AHContext) -> Optional[List[Tuple[float, int]]]:
        """Ranked ``(probability, index)`` pairs; the base class returns an empty list."""
        return []
|