/-
File: proxy/lean4/Sunbeam/Model/Ensemble.lean
(Lean 4 source; 43 lines, 1.5 KiB)
-/

import Sunbeam.Model.Basic
import Sunbeam.Model.MLP
import Sunbeam.Model.DecisionTree
namespace Sunbeam
/-- Two-stage ensemble classifier.

The decision tree (`treePredictAux`) is consulted first, and its verdict is
final whenever it is `Decision.block` or `Decision.allow`. Only when the tree
returns `Decision.defer` is the MLP evaluated: its score from `mlpForward` is
compared against `threshold`, giving `Decision.block` when the score is
strictly greater and `Decision.allow` otherwise.

NOTE(review): the test is a strict IEEE `Float` `>`, so a NaN score (or a NaN
threshold) lands in the `allow` branch. Confirm that failing open on deferred
inputs is intended. -/
def ensemblePredict {inputDim hiddenDim : Nat}
    (tree : TreeNode) (mlpWeights : MLPWeights inputDim hiddenDim)
    (threshold : Float) (input : FloatVec inputDim) : Decision :=
  match treePredictAux input tree with
  | .block => .block
  | .allow => .allow
  | .defer =>
    -- The MLP forward pass is computed only in this branch, i.e. only when
    -- the tree abstains.
    let mlpScore := mlpForward mlpWeights input
    if mlpScore > threshold then .block else .allow
/-- Tree dominance for blocking: whenever `treePredictAux` already returns
`Decision.block` for `input`, `ensemblePredict` also returns
`Decision.block`, independent of the MLP weights and the threshold. -/
theorem tree_block_implies_ensemble_block {inputDim hiddenDim : Nat}
    (tree : TreeNode) (mlpWeights : MLPWeights inputDim hiddenDim)
    (threshold : Float) (input : FloatVec inputDim)
    (h : treePredictAux input tree = Decision.block) :
    ensemblePredict tree mlpWeights threshold input = Decision.block := by
  -- Unfold the definition, rewrite the match discriminant with `h`; the match
  -- then reduces on the `block` constructor and the goal closes by `rfl`.
  simp only [ensemblePredict, h]
/-- The ensemble never returns `Decision.defer`: every input is resolved to
either `Decision.block` or `Decision.allow`.

Proof sketch: case-split on the tree's verdict. The `block` and `allow`
branches pass the verdict through unchanged. In the `defer` branch, the `let`
is zeta-reduced and the MLP threshold `if` is split; both arms produce a
non-`defer` decision. -/
theorem ensemble_output_valid {inputDim hiddenDim : Nat}
    (tree : TreeNode) (mlpWeights : MLPWeights inputDim hiddenDim)
    (threshold : Float) (input : FloatVec inputDim) :
    ensemblePredict tree mlpWeights threshold input = Decision.block ∨
      ensemblePredict tree mlpWeights threshold input = Decision.allow := by
  unfold ensemblePredict
  -- Split on the tree's verdict (block / allow / defer, in match-arm order).
  split
  · left; rfl
  · right; rfl
  · -- Zeta-reduce the `let` so the inner `if` has no loose bound variable.
    dsimp only []
    split
    · left; rfl
    · right; rfl
end Sunbeam