我们在这里结束对libSVM代码级的讨论(数学上还有很多值得深挖的地方),我们这里讨论libSVM的接口函数,也是用户直接调用的方法。不少函数在libSVM源码解读(1)中提及,如果代码实现不难就将其略去。
给定数据集进行训练,在libSVM中由svm_train
完成。其逻辑用伪代码书写如下:
def svm_train(problem):
if problem is distribution estimate or regression:
Initialization
if problem is regression and need estimating:
call svm_svr_probability /* 噪声估计 */
call svm_train_one /* 单次训练 */
process the parameters
else: /* 回归问题 */
call svm_group_classes /* 整理多类别数据 */
calculate parameter 'C'
for i <- 1 to nr_class:
for j <- i+1 to nr_class:
/* 用一对一策略处理多分类问题 */
if need estimating:
call svm_binary_svc_probability
call svm_train_one
build output
free the sapce we allocated before
return model we trained
svm_train
工作量最大的地方就是多分类问题,包括训练好$\dfrac{n(n-1)}{2}$个模型后如何处理输出的问题:用矩阵来存储所有分类模型中的$\alpha$向量。
在svm_predict_values
函数中,我们用训练好的模型对测试数据进行预测:
函数svm_predict
是对svm_predict_values
的封装。svm_predict_probablity
也是面向分类问题的预测函数,但基于的是概率估计最大原则。注意到函数中调用了sigmoid_train
,sigmoid_predict
以及multiclass_probability
。如果不小心将回归问题输入,则不做处理,直接返回预测值。
svm_cross_validation
将数据打乱,分成指定的份数,依次训练子模型,并将结果存储到target
中。值得注意的是,在分类问题中,如果随机划分就会存在某类数据在某一份数据中不存在的情况。libSVM采用的是将每类数据平均随机分到每一份数据中的方法解决这一问题。
通过将模型以文件的形式进行存取,我们可以节约重复训练的时间。我们先来看模型保存,libSVM通过svm_save_model
函数来保存模型:
int svm_save_model(const char *model_file_name, const svm_model *model) {
...
char *old_locale = setlocale(LC_ALL, NULL);
if (old_locale) {
old_locale = strdup(old_locale);
}
setlocale(LC_ALL, "C");
...
}
在正式写入模型前,上面的代码先对语言环境(字符集)进行统一,这里是取消用户计算机上的设定,选取默认值“C”。然后就是写入模型的参数,比如
fprintf(fp, "svm_type %s\n", svm_type_table[param.svm_type]);
fprintf(fp, "kernel_type %s\n", kernel_type_table[param.kernel_type]);
这些必备参数,还有一些选择性参数,如果对应指针不为空则写入:
if (model->nSV) {
fprintf(fp, "nr_sv");
for (int i = 0; i < nr_class; i++) fprintf(fp, " %d", model->nSV[i]);
fprintf(fp, "\n");
}
最重要的是将模型训练好的支持向量和对应的$\alpha_i$写入:
fprintf(fp, "SV\n");
const double *const *sv_coef = model->sv_coef;
const svm_node *const *SV = model->SV;
for (int i = 0; i < l; i++) {
for (int j = 0; j < nr_class - 1; j++) fprintf(fp, "%.17g ", sv_coef[j][i]);
const svm_node *p = SV[i];
if (param.kernel_type == PRECOMPUTED)
fprintf(fp, "0:%d ", (int)(p->value));
else
while (p->index != -1) {
fprintf(fp, "%d:%.8g ", p->index, p->value);
p++;
}
fprintf(fp, "\n");
}
最后是将还原用户的语言环境:
setlocale(LC_ALL, old_locale);
free(old_locale);
这里我们省略了前后的文件打开和关闭操作。
可以看出,libSVM将文件模型分为两部分:模型的基本参数,也就是超参数(被称作model header),以及模型参数。libSVM会先用read_model_header
去读取模型超参数,然后再读取全部模型。read_model_header
的基本逻辑就是一行一行匹配,如果标签匹配则读取内容,给新模型赋值。这里用到一个宏:
#define FSCANF(_stream, _format, _var) \
do { \
if (fscanf(_stream, _format, _var) != 1) return false; \
} while (0)
是为了“解决scanf失败”等问题。在svm_load_model
中,调用read_model_header
后就是对模型参数的读取,这里略去。
如果模型格式出现问题,libSVM也会去检查:svm_check_parameter
和svm_check_probality_model
都是用来检查模型参数是否正确的函数,保证了存取模型的安全性。
libSVM还剩下用于获取模型参数的接口和删除模型的函数,比较简单,我们在这里略去。
至此,我们已经大致分析完整个libSVM。随着阅读的深入,便越来越觉得自己能力的卑微,而保持这种卑微或许才是不断求知的动力吧。
笔者其实对libSVM并没有完全读懂,对于其中一小部分操作,也无法从说出其作用。后面或许还会继续研究相关内容。