-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathMyRJObject.h
More file actions
84 lines (70 loc) · 2.07 KB
/
MyRJObject.h
File metadata and controls
84 lines (70 loc) · 2.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#ifndef _MyRJObject_
#define _MyRJObject_
#include <cmath>
#include <iostream>
#include <RJObject.h>
/// @brief RJObject<Distribution> whose perturb() also folds the
///        Distribution's log_pn() (a log prior on the number of components)
///        into the returned logH for birth/death and hyperparameter moves.
///
/// The subclass declares no data members of its own; all state lives in
/// the RJObject base.
template<class Distribution>
class MyRJObject : public RJObject<Distribution>
{
	public:
		/// Forward every argument unchanged to the RJObject base constructor.
		MyRJObject(int num_dimensions, int max_num_components,
		           bool fixed, const Distribution& dist);

		/// Propose a change to this object in place and return the
		/// accumulated logH term for the proposal.
		double perturb();

		/// Write hyperparameters, component count and components to `out`.
		void print(std::ostream& out) const;
};
// Construct by handing every argument straight to RJObject<Distribution>;
// MyRJObject adds no state of its own, so there is nothing else to set up.
template<class Distribution>
MyRJObject<Distribution>::MyRJObject(int num_dimensions,
                                     int max_num_components,
                                     bool fixed,
                                     const Distribution& dist)
	: RJObject<Distribution>(num_dimensions, max_num_components, fixed, dist)
{
}
/// Propose a change to this object in place and return its logH term.
///
/// One of three moves is chosen at random:
///   0 - birth/death of components (only eligible when `fixed` is false),
///   1 - change the hyperparameters held in `dist`,
///   2 - perturb the existing components.
/// Moves 0 and 1 are bracketed by log_pn(num_components) evaluated before
/// (subtracted) and after (added) the move, so the Distribution's log_pn()
/// enters logH alongside whatever the underlying move returns.
///
/// @return the accumulated logH. Presumably this is the log prior/proposal
///         ratio the DNest3 sampler uses when accepting the proposal; confirm
///         against the RJObject base class.
template<class Distribution>
double MyRJObject<Distribution>::perturb()
{
	// No components are allowed, so there is nothing to perturb.
	if(this->max_num_components == 0)
		return 0.;

	// Clear the bookkeeping of components added/removed by this proposal.
	this->added.resize(0);
	this->removed.resize(0);

	double logH = 0.;

	// With a fixed component count only moves 1 and 2 are eligible.
	// Merges/splits are left out for now.
	const int which = (this->fixed) ? (1 + DNest3::randInt(2))
	                                : DNest3::randInt(3);

	if(which == 0)
	{
		logH -= this->get_dist().log_pn(this->num_components);
		// Do some birth or death. The step scale is 10^(-6u) with
		// u = randomU(), i.e. 1e-6..1 assuming randomU() lies in [0, 1).
		logH += this->perturb_num_components(
				std::pow(10., - 6.*DNest3::randomU()));
		logH += this->get_dist().log_pn(this->num_components);
	}
	else if(which == 1)
	{
		// Change the hyperparameters. log_pn() belongs to the distribution
		// being changed, so it is evaluated both before and after the change.
		// NOTE(review): this ratio is applied even when `fixed` is true and N
		// can never change -- confirm that is intended.
		logH -= this->get_dist().log_pn(this->num_components);
		if(DNest3::randomU() <= 0.97)
		{
			logH += this->dist.perturb1(this->components, this->u_components);
		}
		else
		{
			// perturb2 is handed the components as well; record the whole
			// old set as removed and the whole resulting set as added.
			this->removed = this->components;
			logH += this->dist.perturb2(this->components, this->u_components);
			this->added = this->components;
		}
		logH += this->get_dist().log_pn(this->num_components);
	}
	else if(which == 2)
	{
		// Perturb the existing components with the same log-scaled step size.
		logH += this->perturb_components(std::pow(10., - 6.*DNest3::randomU()));
	}

	return logH;
}
/// Write this object to `out` as space-separated values:
/// the hyperparameters (via dist.print), the current number of components,
/// and then, for each dimension in turn, that dimension's value for every
/// active component followed by a 0 for each remaining slot up to
/// max_num_components. This gives a constant number of values per dimension,
/// presuming num_components <= max_num_components.
template<class Distribution>
void MyRJObject<Distribution>::print(std::ostream& out) const
{
	this->dist.print(out);
	out << ' ' << this->num_components << ' ';

	for(int dim = 0; dim < this->num_dimensions; ++dim)
	{
		// Active components for this dimension.
		for(int slot = 0; slot < this->num_components; ++slot)
			out << this->components[slot][dim] << ' ';

		// Turned-off slots are padded with zeros; the literal stays a double
		// so it is formatted exactly like the component values.
		for(int slot = this->num_components; slot < this->max_num_components; ++slot)
			out << 0. << ' ';
	}
}
#endif