-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlearners.v
More file actions
335 lines (282 loc) · 12.2 KB
/
learners.v
File metadata and controls
335 lines (282 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
Set Implicit Arguments.
Unset Strict Implicit.
Require Import mathcomp.ssreflect.ssreflect.
From mathcomp Require Import all_ssreflect.
Require Import List. Import ListNotations.
Require Import QArith Reals Rpower Ranalysis Fourier.
Require Import OUVerT.chernoff OUVerT.learning OUVerT.bigops OUVerT.dist OUVerT.dyadic.
Require Import MLCert.monads MLCert.noise MLCert.samplers.
(** A learner is a pair of functions over examples [X], labels [Y],
    hyperparameters [Hypers] and learned parameters [Params]:
    - [predict h p x] is the label produced for [x] under [h] and [p];
    - [update h xy p] returns new parameters after processing one labeled
      example [xy] starting from [p]. *)
Module Learner.
  Record t (X Y Hypers Params : Type) :=
    mk {
      predict : Hypers -> Params -> X -> Y;
      update : Hypers -> X * Y -> Params -> Params
    }.
End Learner.
(** Extractible semantics of training.

    A training set is produced by [sample] and then passed through
    [noise_model]. The learner's [update] is then folded over it for
    [epochs] passes. Every definition lives in the continuation monad
    [Cont r] for an arbitrary answer type [r]. *)
Section extractible_semantics.
Variable X Y Params Hypers : Type.
Variable learner : Learner.t X Y Hypers Params.
Variable h : Hypers.
Variable epochs : nat.
Context {training_set} `{Foldable training_set (X*Y)}.
Variable r : Type.
Notation C t := (Cont r t).
Variable noise_model : training_set -> C training_set.
Variable sample : (X*Y -> R) -> C training_set.

(** Draw a training set with [sample d], then pass it to [noise_model]. *)
Definition noised_sample (d:X*Y->R) : C training_set :=
  T <-- sample d;
  noise_model T.

(** Pure training loop. Starting from [init], fold [Learner.update] over
    the examples of [T] (via [foldable_foldM] in the identity monad [Id]),
    once per element of [enum 'I_epochs]. The epoch index itself is never
    used; only the number of epochs matters. *)
Definition learn_func (init:Params) (T:training_set) : Params :=
  foldrM
    (fun _ p_epoch =>
       foldable_foldM (M:=Id)
         (fun xy p => ret (Learner.update learner h xy p))
         p_epoch T)
    init (enum 'I_epochs).

(** [learn_func] lifted into the continuation monad. *)
Definition learn (init:Params) (T:training_set) : C Params :=
  fun f => f (learn_func init T).

(** End-to-end pipeline: noisy sampling, then training from [init T].
    Returns the learned parameters paired with the training set used. *)
Definition extractible_main (d:X*Y->R) (init:training_set -> Params)
  : C (Params * training_set) :=
  T <-- noised_sample d;
  p <-- learn (init T) T;
  ret (p, T).
End extractible_semantics.
(** Specialisation of [extractible_main] whose noise model is [rv_impulse]
    (from the imported noise library), started from [init_sampler_state]. *)
Section rv_impulse_extractible_semantics.
Variable X Y Params Hypers : Type.
Variable x : Type.
Variable x_of_nat : nat -> x.
Variable example_mapM : forall (M:Type->Type), (x -> M x) -> X -> M X.
Variable learner : Learner.t X Y Hypers Params.
Variable h : Hypers.
Variable epochs : nat.
Context {training_set} `{Foldable training_set (X*Y)}.
Context {sampler_state} `{BasicSamplers sampler_state}.
Variable r : Type.
Notation C t := (Cont r t).
Variable sample : (X*Y -> R) -> C training_set.

(** Run [rv_impulse] with parameter [p] on [t] and keep only the training
    set component of its result. The second component (presumably the final
    sampler state; confirm against [rv_impulse]) is discarded. *)
Definition rv_impulse_noise_model (p:D) (t:training_set) : C training_set :=
  Ts <-- rv_impulse example_mapM x_of_nat p t init_sampler_state;
  let: (T, _) := Ts in
  ret T.

(** [extractible_main] with [rv_impulse_noise_model p] as the noise model. *)
Definition rv_impulse_extractible_main (d:X*Y->R) (p:D) (init:training_set -> Params)
  : C (Params * training_set) :=
  extractible_main learner h epochs (rv_impulse_noise_model p) sample d init.
End rv_impulse_extractible_semantics.
(** Concrete training sets of a fixed size [m] over finite types [X] and
    [Y]. They are manipulated as finite functions on ['I_m] and given a
    [Foldable] instance. *)
Section training_set.
Variable m : nat.

(** Training sets of size [m] over finite example and label types. *)
Definition semantic_training_set (X Y:finType) := training_set X Y m.

Variables X Y : finType.

(** Monadic fold of [f] over the examples [t k], one step per index [k] in
    [enum 'I_m], starting from the accumulator [r]. *)
Definition semantic_training_set_foldrM M `(Monad M) (R:Type)
  (f:X*Y -> R -> M R) (r:R) (t:semantic_training_set X Y) : M R :=
  foldrM (fun k acc => f (t k) acc) r (enum 'I_m).

(** Monadic map of [f] over the examples of [t]. Each fold step overwrites
    position [k] of the accumulator with the result of [f (t k)]. Note that
    [f] is always applied to the original example [t k]. *)
Definition semantic_training_set_mapM M `(Monad M)
  (f:X*Y -> M (X*Y)%type) (t:semantic_training_set X Y)
  : M (semantic_training_set X Y) :=
  foldrM
    (fun k (acc:semantic_training_set X Y) =>
       ex' <-- f (t k);
       ret (finfun (fun j => if k == j then ex' else acc j)))
    t (enum 'I_m).

(** [Foldable] instance built from the fold and map above. *)
Global Instance semantic_TrainingSet
  : Foldable (semantic_training_set X Y) (X*Y) :=
  mkFoldable semantic_training_set_foldrM semantic_training_set_mapM.
End training_set.
(** Probabilistic semantics over finite types. Continuations return reals
    ([Cont R]), and sampling a training set is an explicit sum over every
    training set of size [m]. *)
Section semantics.
Variables X Y Params : finType.
Variable Hypers : Type.
Variable learner : Learner.t X Y Hypers Params.
Variable h : Hypers.
Variable m : nat.
Variable m_gt0 : (0 < m)%nat.
Variable epochs : nat.
Notation C := (@Cont R).
(** [semantic_sample d f] sums, over every training set [T] of size [m],
    the weight [prodR (fun _ => d) T] times [f T]. Presumably [prodR] is the
    product of [d] over the [m] examples of [T], i.e. i.i.d. sampling;
    confirm against [prodR]. *)
Definition semantic_sample (d:X*Y -> R) : C (training_set X Y m) :=
fun f => big_sum (enum (training_set X Y m)) (fun T =>
(prodR (fun _:'I_m => d)) T * f T).
(** Conditioning: pass the value to the continuation when [p] holds on it,
    otherwise return 0. NOTE: the value binder [t] in the body shadows the
    type parameter [t]. *)
Definition observe (t:Type) (p:pred t) : t -> C t :=
fun t f => if p t then f t else 0.
(** Accuracy of the learner's predictor under [h], via [accuracy01]
    (by its name, a 0/1 accuracy). *)
Definition accuracy := accuracy01 (m:=m) (Learner.predict learner h).
(** The event bounded in [main_bound]. [post d eps (p, T)] holds when
    [expVal d m_gt0 accuracy p + eps < empVal accuracy T p]. By their names,
    [expVal] is the expected accuracy of [p] under [d] and [empVal] is its
    empirical accuracy on [T]. *)
Definition post (d:X*Y -> R) (eps:R)
(pT : Params * training_set X Y m) : bool :=
let: (p, T) := pT in
Rlt_le_dec (expVal d m_gt0 accuracy p + eps) (empVal accuracy T p).
(** [extractible_main] with the trivial noise model [fun T => ret T] and
    [semantic_sample] as the sampler. *)
Definition semantic_main (d:X*Y -> R) (init:training_set X Y m -> Params) :=
extractible_main learner h epochs (fun T => ret T) semantic_sample d init.
(** Train as in [semantic_main], then condition on [post d eps]. When run
    with the continuation [fun _ => 1], this gives the total weight of the
    runs for which [post] holds. *)
Definition main (d:X*Y -> R) (eps:R) (init:training_set X Y m -> Params)
: C (Params * training_set X Y m) :=
pT <-- semantic_main d init;
let: (p,T) := pT in
observe (post d eps) (p,T).
(** Assumptions on [d]:
    - its weights sum to 1 over all examples;
    - its weights are nonnegative;
    - every parameter's expected accuracy lies strictly between 0 and 1. *)
Variables
(d:X*Y -> R)
(d_dist : big_sum (enum [finType of X*Y]) d = 1)
(d_nonneg : forall x, 0 <= d x)
(not_perfectly_learnable :
forall p : Params, 0 < expVal d m_gt0 accuracy p < 1).
(** For any initialisation [init], the weight [main] assigns to [post] is
    at most [#|Params| * exp (-2 * eps^2 * m)]. This follows from the
    imported lemma [chernoff_bound_accuracy01]. *)
Lemma main_bound (eps:R) (eps_gt0 : 0 < eps) (init:training_set X Y m -> Params) :
main d eps init (fun _ => 1) <=
INR #|Params| * exp (-2%R * eps^2 * mR m).
Proof.
(* Unfold the monadic layers down to a big_sum over training sets. *)
rewrite /main/semantic_main/extractible_main/bind/=/Cbind/Cret.
rewrite /noised_sample/Cbind/=/Cbind/semantic_sample.
(* Bound the goal through the right-hand side of chernoff_bound_accuracy01,
   instantiated with the learner trained from init. *)
rewrite big_sum_pred2; apply: Rle_trans; last first.
{ apply chernoff_bound_accuracy01
with (d:=d) (learn:=fun t => learn_func learner h epochs (init t) t) => //.
move => p; apply: not_perfectly_learnable. }
(* Unfold probOfR and compare the two sums term by term; after Rmult_1_r
   the terms coincide. *)
rewrite /probOfR/=.
apply big_sum_le => c; rewrite /in_mem Rmult_1_r /= => _; apply: Rle_refl.
Qed.
End semantics.
(** Holdout variant of [semantics]. The parameters are learned on one
    training set, and [post] is evaluated on a second, separately sampled
    test set of the same size [m]. *)
Section holdout_semantics.
Variables X Y Params : finType.
Variable Hypers : Type.
Variable learner : Learner.t X Y Hypers Params.
Variable h : Hypers.
Variable m : nat.
Variable m_gt0 : (0 < m)%nat.
Variable epochs : nat.
Notation C := (@Cont R).
Notation semantic_sample m := (@semantic_sample X Y m).
Notation accuracy := (@accuracy X Y Params Hypers learner h m).
Notation post := (@post X Y Params Hypers learner h m m_gt0).
(** [eps_hyp_range d eps p t] holds when [eps < 1 - expVal ... p], i.e.
    when the expected accuracy of [p] plus [eps] stays below 1. The
    training-set argument [t] is ignored. It is there so the predicate can
    be passed to [observe] on training sets. *)
Definition eps_hyp_range (d:X*Y -> R) (eps:R) (p:Params) (t:training_set X Y m) : bool :=
Rlt_le_dec eps
(1 - (expVal (A:=X) (B:=Y) d m_gt0 (Hyp:=Params)
(accuracy01 (A:=X) (B:=Y) (Params:=Params) (Learner.predict learner h)) p)).
(** Holdout experiment:
    1. sample [T_train] and train [p] on it;
    2. condition on [eps_hyp_range];
    3. sample [T_test] and condition on [post] evaluated on [(p, T_test)]. *)
Definition main_holdout (d:X*Y -> R) (eps:R) (init:training_set X Y m -> Params)
: C (Params * training_set X Y m) :=
T_train <-- semantic_sample m d;
p <-- learn learner h epochs (init T_train) T_train;
_ <-- observe (eps_hyp_range d eps p) T_train;
T_test <-- semantic_sample m d;
observe (post d eps) (p,T_test).
(** Assumptions on [d]:
    - its weights sum to 1;
    - its weights are nonnegative;
    - every parameter's expected accuracy lies strictly in (0, 1). *)
Variables
(d:X*Y -> R)
(d_dist : big_sum (enum [finType of X*Y]) d = 1)
(d_nonneg : forall x, 0 <= d x)
(not_perfectly_learnable :
forall p : Params, 0 < expVal d m_gt0 accuracy p < 1).
(** With the continuation [fun _ => 1], the holdout experiment's weight on
    [post] is at most [exp (-2 * eps^2 * m)]. Unlike [main_bound] there is
    no [#|Params|] factor; note that [p] is fixed before [T_test] is drawn. *)
Lemma main_holdout_bound (eps:R) (eps_gt0 : 0 < eps) (init:training_set X Y m -> Params) :
main_holdout d eps init (fun _ => 1) <=
exp (-2%R * eps^2 * mR m).
Proof.
(* Unfold the monadic layers and the training-set sampler. *)
rewrite /main_holdout/bind/=/Cbind/Cret.
rewrite /(semantic_sample m)/Cbind/=/Cbind.
(* H: averaging the constant bound against the training-set weights leaves
   it unchanged (via big_sum_scalar and prodR_dist). *)
have H:
big_sum (enum {ffun 'I_m -> X * Y})
(fun T : {ffun 'I_m -> X * Y} =>
exp (- (2) * (eps * (eps * 1)) * mR m) *
prodR (T:=prod_finType X Y) (fun _ : 'I_m => d) T) =
exp (- (2) * (eps * (eps * 1)) * mR m).
{ by rewrite big_sum_scalar prodR_dist => //; rewrite Rmult_1_r. }
(* Rewrite the bound as that average and compare term by term over T_train. *)
rewrite -H; apply: big_sum_le => T_train Htrain.
(* Factor out the nonnegative weight of T_train. *)
rewrite [exp _ * _]Rmult_comm; apply: Rmult_le_compat_l; first by apply: prodR_nonneg.
(* If eps_hyp_range fails, observe yields 0, which is below the positive exp. *)
rewrite /learn/observe; case Heps: (eps_hyp_range _ _); last by apply: Rlt_le; apply: exp_pos.
(* Otherwise apply the holdout Chernoff bound to the trained parameters.
   Its range side condition is discharged from Heps. *)
apply: Rle_trans; last first.
{ apply: (@chernoff_bound_accuracy01_holdout
_ _ _ d_dist d_nonneg _ _ _ _ not_perfectly_learnable
(learn_func learner h epochs (init T_train) T_train) eps) => //.
move: Heps; rewrite /eps_hyp_range; case: (Rlt_le_dec _ _) => // Hlt. }
(* Compare the remaining sums over test sets term by term. *)
rewrite /probOfR big_sum_pred2; apply: big_sum_le => T_test Htest.
by rewrite Rmult_1_r; apply: Rle_refl.
Qed.
End holdout_semantics.
(** A learner that predicts with [predict] and whose [update] ignores its
    hyperparameters, example and current parameters. Every update returns
    the fixed [oracular_params]. *)
Section OracleLearner.
Variables X Y Hypers Params : Type.
Variable oracular_params : Params.
Variable predict : Hypers -> Params -> X -> Y.
Definition OracleLearner : Learner.t X Y Hypers Params :=
  {| Learner.predict := predict;
     Learner.update := fun _ _ _ => oracular_params |}.
End OracleLearner.
(** Variant of [semantics] in which the parameters come directly from an
    arbitrary function [oracle] of the training set, rather than from
    training. [learner] and [h] are still used, through [accuracy] and
    [post], to evaluate the chosen parameters. *)
Section oracular_semantics.
Variables X Y Params : finType.
Variable Hypers : Type.
Variable learner : Learner.t X Y Hypers Params.
Variable h : Hypers.
Variable m : nat.
Variable m_gt0 : (0 < m)%nat.
Notation C := (@Cont R).
Variable oracle : training_set X Y m -> Params.
Notation semantic_sample m := (@semantic_sample X Y m).
Notation accuracy := (@accuracy X Y Params Hypers learner h m).
Notation post := (@post X Y Params Hypers learner h m m_gt0).
(** Sample [T], choose [p := oracle T], and condition on [post] at [(p, T)]. *)
Definition oracular_main (d:X*Y -> R) (eps:R)
: C (Params * training_set X Y m) :=
T <-- semantic_sample m d;
p <-- ret (oracle T);
observe (post d eps) (p,T).
(** Assumptions on [d]:
    - its weights sum to 1;
    - its weights are nonnegative;
    - every parameter's expected accuracy lies strictly in (0, 1). *)
Variables
(d:X*Y -> R)
(d_dist : big_sum (enum [finType of X*Y]) d = 1)
(d_nonneg : forall x, 0 <= d x)
(not_perfectly_learnable :
forall p : Params, 0 < expVal d m_gt0 accuracy p < 1).
(** The bound of [main_bound], [#|Params| * exp (-2 * eps^2 * m)], holds
    for every [oracle]. *)
Lemma oracular_main_bound (eps:R) (eps_gt0 : 0 < eps) :
oracular_main d eps (fun _ => 1) <=
INR #|Params| * exp (-2%R * eps^2 * mR m).
Proof.
(* Unfold the monadic layers and the sampler. *)
rewrite /oracular_main/bind/=/Cbind/Cret.
rewrite /(semantic_sample m)/Cbind/=/Cbind.
(* Bound the goal through chernoff_bound_accuracy01 with learn := oracle. *)
rewrite big_sum_pred2; apply: Rle_trans; last first.
{ apply chernoff_bound_accuracy01
with (d:=d) (learn:=fun t => oracle t) => //.
move => p; apply: not_perfectly_learnable. }
(* Unfold probOfR and compare the sums term by term. *)
rewrite /probOfR/=.
apply big_sum_le => c; rewrite /in_mem Rmult_1_r /= => _; apply: Rle_refl.
Qed.
End oracular_semantics.
(** Holdout variant with an oracle. Parameters are chosen by [oracle] from
    a training set of size [m], and [post] is evaluated on a test set of
    size [n]. [accuracy] and [post] below are therefore taken at size [n]. *)
Section oracular_holdout_semantics.
Variables X Y Params : finType.
Variable Hypers : Type.
Variable learner : Learner.t X Y Hypers Params.
Variable h : Hypers.
Variable m : nat.
Variable m_gt0 : (0 < m)%nat.
Variable n : nat.
Variable n_gt0 : (0 < n)%nat.
Notation C := (@Cont R).
Variable oracle : forall m:nat, training_set X Y m -> Params.
Notation semantic_sample m := (@semantic_sample X Y m).
Notation accuracy := (@accuracy X Y Params Hypers learner h n).
Notation post := (@post X Y Params Hypers learner h n n_gt0).
(** [eps_hyp_ok d eps t] holds when [eps < 1 - expVal ... (oracle t)].
    NOTE: the binder [m] here shadows the section variable [m]. It is
    inferred from [t], as the use [eps_hyp_ok d eps] below shows, while the
    [expVal] term still uses the section's [m_gt0]. *)
Definition eps_hyp_ok m (d:X*Y -> R) (eps:R) (t:training_set X Y m) : bool :=
Rlt_le_dec eps
(1 - (expVal (A:=X) (B:=Y) d m_gt0 (Hyp:=Params)
(accuracy01 (A:=X) (B:=Y) (Params:=Params) (Learner.predict learner h)) (oracle t))).
(** Holdout experiment with an oracle:
    1. sample [T_train] of size [m] and condition on [eps_hyp_ok];
    2. set [p := oracle T_train];
    3. sample [T_test] of size [n] and condition on [post] at [(p, T_test)]. *)
Definition oracular_main_holdout (d:X*Y -> R) (eps:R)
: C (Params * training_set X Y n) :=
T_train <-- semantic_sample m d;
_ <-- observe (eps_hyp_ok d eps) T_train;
p <-- ret (oracle T_train);
T_test <-- semantic_sample n d;
observe (post d eps) (p,T_test).
(** Assumptions on [d]:
    - its weights sum to 1;
    - its weights are nonnegative;
    - every parameter's expected accuracy (at test size [n]) lies strictly
      in (0, 1). *)
Variables
(d:X*Y -> R)
(d_dist : big_sum (enum [finType of X*Y]) d = 1)
(d_nonneg : forall x, 0 <= d x)
(not_perfectly_learnable :
forall p : Params, 0 < expVal d n_gt0 accuracy p < 1).
(** The weight on [post] is at most [exp (-2 * eps^2 * n)]. The bound
    depends only on the test-set size [n]. *)
Lemma oracular_main_holdout_bound (eps:R) (eps_gt0 : 0 < eps) :
oracular_main_holdout d eps (fun _ => 1) <=
exp (-2%R * eps^2 * mR n).
Proof.
(* Unfold the monadic layers and both samplers. *)
rewrite /oracular_main_holdout/bind/=/Cbind/Cret.
rewrite /(semantic_sample m)/(semantic_sample n)/Cbind/=/Cbind.
(* H: averaging the constant bound against the training-set weights leaves
   it unchanged (via big_sum_scalar and prodR_dist). *)
have H:
big_sum (enum {ffun 'I_m -> X * Y})
(fun T : {ffun 'I_m -> X * Y} =>
exp (- (2) * (eps * (eps * 1)) * mR n) *
prodR (T:=prod_finType X Y) (fun _ : 'I_m => d) T) =
exp (- (2) * (eps * (eps * 1)) * mR n).
{ by rewrite big_sum_scalar prodR_dist => //; rewrite Rmult_1_r. }
(* Rewrite the bound as that average and compare term by term over T_train. *)
rewrite -H; apply: big_sum_le => T_train Htrain.
(* Factor out the nonnegative weight of T_train. *)
rewrite [exp _ * _]Rmult_comm; apply: Rmult_le_compat_l; first by apply: prodR_nonneg.
(* If eps_hyp_ok fails, observe yields 0, which is below the positive exp. *)
rewrite /observe; case Heps: (eps_hyp_ok _ _); last by apply: Rlt_le; apply: exp_pos.
(* Otherwise apply the holdout Chernoff bound at test size n to oracle T_train.
   Its range side condition is discharged from Heps. *)
apply: Rle_trans; last first.
{ apply: (@chernoff_bound_accuracy01_holdout
_ _ _ d_dist d_nonneg n _ _ _ not_perfectly_learnable
(oracle T_train) eps) => //.
move: Heps; rewrite /eps_hyp_ok; case: (Rlt_le_dec _ _) => // Hlt. }
(* Compare the remaining sums over test sets term by term. *)
rewrite /probOfR big_sum_pred2. apply: big_sum_le => T_test Htest.
by rewrite Rmult_1_r; apply: Rle_refl.
Qed.
End oracular_holdout_semantics.