43 lines
1.5 KiB
Lean4
43 lines
1.5 KiB
Lean4
|
|
import Sunbeam.Model.Basic
|
|||
|
|
import Sunbeam.Model.MLP
|
|||
|
|
import Sunbeam.Model.DecisionTree
|
|||
|
|
|
|||
|
|
namespace Sunbeam
|
|||
|
|
|
|||
|
|
/-- Two-stage ensemble classifier.

The decision tree is consulted first. A `Decision.block` or `Decision.allow`
verdict from the tree is returned unchanged. Only when the tree answers
`Decision.defer` is the MLP evaluated on `input`. Its `Float` score is then
compared against `threshold`: strictly greater gives `Decision.block`,
anything else gives `Decision.allow`.

Note: the test is an IEEE `Float` `>`. A NaN score (or a NaN `threshold`)
therefore makes the comparison false and yields `Decision.allow`. -/
def ensemblePredict {inputDim hiddenDim : Nat}
    (tree : TreeNode) (mlpWeights : MLPWeights inputDim hiddenDim)
    (threshold : Float) (input : FloatVec inputDim) : Decision :=
  -- Keep the arm order (block, allow, defer) and the `let` below:
  -- the tactic proof of `ensemble_output_valid` relies on this exact shape.
  match treePredictAux input tree with
  -- Confident tree verdicts pass straight through.
  | .block => .block
  | .allow => .allow
  -- Deferred inputs are settled by thresholding the MLP score.
  | .defer =>
    let mlpScore := mlpForward mlpWeights input
    if mlpScore > threshold then .block else .allow
|
|||
|
|
|
|||
|
|
/-- A `Decision.block` verdict from the tree is final. Whenever the tree blocks
`input`, the ensemble also returns `Decision.block`, for every choice of MLP
weights and threshold. -/
theorem tree_block_implies_ensemble_block {inputDim hiddenDim : Nat}
    (tree : TreeNode) (mlpWeights : MLPWeights inputDim hiddenDim)
    (threshold : Float) (input : FloatVec inputDim)
    (h : treePredictAux input tree = Decision.block) :
    ensemblePredict tree mlpWeights threshold input = Decision.block := by
  -- Expose the `match` on the tree's verdict.
  unfold ensemblePredict
  -- Rewriting the discriminant with `h` lets the match reduce to its
  -- `block` arm, leaving the trivial goal `Decision.block = Decision.block`.
  simp [h]
|
|||
|
|
|
|||
|
|
/-- The ensemble never answers `Decision.defer`. Every result of
`ensemblePredict` is either `Decision.block` or `Decision.allow`. -/
theorem ensemble_output_valid {inputDim hiddenDim : Nat}
    (tree : TreeNode) (mlpWeights : MLPWeights inputDim hiddenDim)
    (threshold : Float) (input : FloatVec inputDim) :
    ensemblePredict tree mlpWeights threshold input = Decision.block ∨
    ensemblePredict tree mlpWeights threshold input = Decision.allow := by
  unfold ensemblePredict
  -- Case on the tree's verdict: block, allow, then defer.
  split
  · exact Or.inl rfl
  · exact Or.inr rfl
  · -- Inline the `let`-bound MLP score so `split` can see the `if`.
    dsimp only
    -- Case on the threshold test: its branches are block, then allow.
    split
    · exact Or.inl rfl
    · exact Or.inr rfl
|
|||
|
|
|
|||
|
|
end Sunbeam
|