-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathvm.h
More file actions
49 lines (41 loc) · 1.53 KB
/
Copy pathvm.h
File metadata and controls
49 lines (41 loc) · 1.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#ifndef LIBVM_VM_H_
#define LIBVM_VM_H_
#include "utilities.h"
#include "kernel.h"
#include "knn.h"
#include "svm.h"
#include "mcsvm.h"
// Selectors for the underlying learner family.
// NOTE(review): presumably these are the valid values of
// Parameter::taxonomy_type (an int field below) — confirm against the
// implementation. Values are written out explicitly so the numbering is
// visible; avoid renumbering if these integers are ever persisted (e.g. in a
// saved model file) — unverified from this header.
enum {
  KNN = 0,
  SVM_EL = 1,
  SVM_ES = 2,
  SVM_KM = 3,
  OVA_SVM = 4,
  MCSVM = 5,
  MCSVM_EL = 6
};
// Configuration for training / evaluating a model.
// The three *_param pointers reference learner-specific parameter structs
// declared in the included headers (presumably knn.h, svm.h, mcsvm.h).
// NOTE(review): FreeParam() is declared below, which suggests a Parameter
// owns its sub-parameter objects — confirm ownership before copying one.
struct Parameter {
// Parameters for the k-nearest-neighbour learner (may be unused depending
// on taxonomy_type — TODO confirm which pointers are required per type).
struct KNNParameter *knn_param;
// Parameters for the SVM-based learners.
struct SVMParameter *svm_param;
// Parameters for the multi-class SVM learners.
struct MCSVMParameter *mcsvm_param;
// NOTE(review): meaning of "categories" is not visible here; it is also
// mirrored in Model::num_categories — confirm semantics.
int num_categories;
// Presumably boolean flags (0/1) controlling model persistence — confirm.
int save_model;
int load_model;
// Presumably one of the enum constants above (KNN, SVM_EL, ...) — confirm.
int taxonomy_type;
// Presumably the fold count used by CrossValidation() — confirm.
int num_folds;
// Presumably a flag enabling probability outputs — confirm.
int probability;
};
// A trained model as returned by TrainVM() / LoadModel() and released with
// FreeModel().
struct Model {
// Copy (by value) of the parameters the model was built with. The copy is
// shallow: its *_param members are pointers, so they alias whatever the
// original Parameter pointed to.
struct Parameter param;
// Learner-specific sub-models. NOTE(review): presumably only the one matching
// param.taxonomy_type is non-NULL — confirm against the implementation.
struct SVMModel *svm_model;
struct KNNModel *knn_model;
struct MCSVMModel *mcsvm_model;
// Presumably the number of training examples stored in the model — confirm.
int num_ex;
int num_classes;
int num_categories;
// Heap arrays; their lengths (num_classes? num_ex?) are not established by
// this header — TODO confirm before indexing.
int *labels;
int *categories;
// NOTE(review): contents/length of this array are not visible here — confirm.
double *points;
};
// Trains a model on `train` using `param`.
// NOTE(review): the returned Model* is presumably heap-allocated and must be
// released with FreeModel() — confirm ownership in the implementation.
Model *TrainVM(const struct Problem *train, const struct Parameter *param);
// Predicts for a single example `x` and returns the prediction as a double
// (presumably a class label — confirm).
// `lower` and `upper` are non-const references, i.e. output parameters that
// the function writes (presumably a lower/upper bound pair).
// `avg_prob` is a pointer-to-pointer, presumably used to hand back an
// allocated probability array — TODO confirm who allocates and frees it.
// `train` is passed in addition to `model`; presumably the training data is
// needed at prediction time — confirm.
double PredictVM(const struct Problem *train, const struct Model *model, const struct Node *x, double &lower, double &upper, double **avg_prob);
// Runs cross-validation over `prob` with `param`, writing per-example results
// into the caller-supplied output arrays. NOTE(review): array sizes are not
// stated here; presumably one entry per example in `prob` — confirm.
void CrossValidation(const struct Problem *prob, const struct Parameter *param, double *predict_labels, double *lower_bounds, double *upper_bounds, double *brier, double *logloss);
// Online counterpart of CrossValidation; additionally fills `indices`
// (presumably the order in which examples were processed — confirm).
void OnlinePredict(const struct Problem *prob, const struct Parameter *param, double *predict_labels, int *indices, double *lower_bounds, double *upper_bounds, double *brier, double *logloss);
// Writes `model` to `model_file_name`. The meaning of the int result
// (e.g. 0 on success) is not visible here — confirm before checking it.
int SaveModel(const char *model_file_name, const struct Model *model);
// Reads a model from `model_file_name`. NOTE(review): presumably returns NULL
// on failure and a FreeModel()-owned object on success — confirm.
Model *LoadModel(const char *model_file_name);
// Releases a Model and the resources it owns.
void FreeModel(struct Model *model);
// Releases the resources owned by a Parameter (presumably its *_param
// sub-structures — confirm whether `param` itself is freed).
void FreeParam(struct Parameter *param);
// Validates `param`. NOTE(review): by analogy with LIBSVM's
// svm_check_parameter, presumably returns NULL when valid and a static error
// message otherwise — confirm against the implementation.
const char *CheckParameter(const struct Parameter *param);
#endif // LIBVM_VM_H_