// Copyright (c) 2017, Apple Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-3-clause license that can be
// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause
/*
* Each tree is a collection of nodes,
* each of which is identified by a unique identifier.
*
* Each node is either a branch or a leaf node.
* A branch node evaluates a value according to a behavior;
* if true, the node identified by ``true_child_node_id`` is evaluated next,
* if false, the node identified by ``false_child_node_id`` is evaluated next.
* A leaf node adds the evaluation value to the base prediction value
* to get the final prediction.
*
* A tree must have exactly one root node,
* which has no parent node.
* A tree must not terminate on a branch node.
* All leaf nodes must be accessible
* by evaluating one or more branch nodes in sequence,
* starting from the root node.
*/
syntax = "proto3";
option optimize_for = LITE_RUNTIME;
import public "DataStructures.proto";
package CoreML.Specification;
/*
 * The transform applied to a tree ensemble's output
 * after the trees have been evaluated.
 *
 * ``NoTransform`` carries the value 0,
 * so it is what a reader sees when the field is left unset.
 */
enum TreeEnsemblePostEvaluationTransform {
    // Zero value; the proto3 default for an unset field.
    NoTransform = 0;

    // NOTE(review): presumably a softmax over the per-class outputs;
    // confirm against the evaluating runtime.
    Classification_SoftMax = 1;

    // NOTE(review): presumably a logistic function applied to the
    // regression output; confirm against the evaluating runtime.
    Regression_Logistic = 2;

    // NOTE(review): presumably a softmax in which class 0 serves as an
    // implicit reference class; confirm against the evaluating runtime.
    Classification_SoftMaxWithZeroClassReference = 3;
}
/*
 * Parameters of a tree ensemble:
 * the nodes of all trees,
 * the number of prediction dimensions,
 * and the base prediction values.
 */
message TreeEnsembleParameters {
    /*
     * A single node of one tree in the ensemble.
     *
     * A node is either a branch node or a leaf node,
     * as selected by ``nodeBehavior``.
     */
    message TreeNode {
        // Identifier of the tree this node belongs to.
        uint64 treeId = 1;

        // Identifier of this node.
        uint64 nodeId = 2;

        /*
         * The possible behaviors of a node.
         *
         * Every ``BranchOn*`` value marks the node as a branch node
         * and names the comparison it performs;
         * ``LeafNode`` marks the node as a leaf.
         *
         * NOTE(review): the operand order of the comparison
         * (presumably feature value on the left,
         * ``branchFeatureValue`` on the right)
         * is not stated in this file -- confirm.
         */
        enum TreeNodeBehavior {
            BranchOnValueLessThanEqual = 0;
            BranchOnValueLessThan = 1;
            BranchOnValueGreaterThanEqual = 2;
            BranchOnValueGreaterThan = 3;
            BranchOnValueEqual = 4;
            BranchOnValueNotEqual = 5;
            LeafNode = 6;
        }

        /*
         * The behavior of this node.
         *
         * If this is one of the branch modes,
         * the branch parameters that follow must be filled in
         * to determine how the branching functions.
         * If this is ``LeafNode``,
         * the node contributes its ``evaluationInfo`` instead.
         */
        TreeNodeBehavior nodeBehavior = 3;

        /*
         * The branch parameters.
         *
         * These values must be filled in
         * whenever ``nodeBehavior`` is a branch mode.
         * When the branch test is true,
         * the node identified by ``trueChildNodeId`` is evaluated next;
         * when it is false,
         * the node identified by ``falseChildNodeId`` is evaluated next.
         *
         * NOTE(review): ``branchFeatureIndex`` presumably selects
         * the input feature that is compared with ``branchFeatureValue``,
         * and ``missingValueTracksTrueChild`` presumably sends
         * a missing feature value to the true child when set -- confirm.
         */
        uint64 branchFeatureIndex = 10;
        double branchFeatureValue = 11;
        uint64 trueChildNodeId = 12;
        uint64 falseChildNodeId = 13;
        bool missingValueTracksTrueChild = 14;

        /*
         * One entry of a leaf node's sparse evaluation vector.
         *
         * ``evaluationValue`` is added to the element of the
         * base prediction vector selected by ``evaluationIndex``.
         * In the single class case,
         * ``evaluationIndex`` is expected to be exactly 0.
         */
        message EvaluationInfo {
            uint64 evaluationIndex = 1;
            double evaluationValue = 2;
        }

        /*
         * The leaf values.
         *
         * If ``nodeBehavior`` == ``LeafNode``,
         * these evaluation values are added to the base prediction value
         * in order to get the final prediction.
         * The values are encoded as a sparse vector
         * so that multiclass classification is supported
         * alongside regression and binary classification.
         */
        repeated EvaluationInfo evaluationInfo = 20;

        /*
         * The relative hit rate of this node, for optimization purposes.
         *
         * This value does not affect the accuracy of the result;
         * it only lets the tree optimize for frequently taken branches.
         * It is relative to the hit rates of the other branch nodes.
         *
         * A typical source for this value is the proportion of
         * training samples that reached this node,
         * or a similar metric.
         */
        double relativeHitRate = 30;
    }

    // The nodes of every tree in the ensemble.
    repeated TreeNode nodes = 1;

    /*
     * The number of prediction dimensions or classes in the model.
     *
     * Every ``evaluationIndex`` in a leaf node
     * must be less than this value,
     * and ``basePredictionValue`` must contain
     * exactly this many values.
     *
     * For regression, this is the dimension of the prediction.
     * For classification, this is the number of classes.
     */
    uint64 numPredictionDimensions = 2;

    /*
     * The base prediction value,
     * to which the leaf evaluation values are added.
     *
     * The number of values here must be equal to
     * ``numPredictionDimensions``.
     */
    repeated double basePredictionValue = 3;
}
/*
 * A tree ensemble classifier:
 * the ensemble parameters,
 * the transform applied after evaluation,
 * and the class labels.
 */
message TreeEnsembleClassifier {
    // The trees that make up the classifier.
    TreeEnsembleParameters treeEnsemble = 1;

    // The transform applied to the ensemble output after evaluation.
    TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2;

    // The class label mapping (required).
    // Exactly one of the two label representations is set.
    oneof ClassLabels {
        StringVector stringClassLabels = 100;
        Int64Vector int64ClassLabels = 101;
    }
}
/*
 * A tree ensemble regressor:
 * the ensemble parameters
 * and the transform applied after evaluation.
 */
message TreeEnsembleRegressor {
    // The trees that make up the regressor.
    TreeEnsembleParameters treeEnsemble = 1;

    // The transform applied to the ensemble output after evaluation.
    TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2;
}
|