libSVM源码解读(6)

接口函数

Posted by Welt Xing on July 15, 2021

引言

我们在这里结束对libSVM代码级的讨论(数学上还有很多值得深挖的地方),我们这里讨论libSVM的接口函数,也是用户直接调用的方法。不少函数在libSVM源码解读(1)中提及,如果代码实现不难就将其略去。

模型训练

给定数据集进行训练,在libSVM中由svm_train完成。其逻辑用伪代码书写如下:

def svm_train(problem):
	if problem is distribution estimate or regression:
    	Initialization
    	if problem is regression and need estimating:
    		call svm_svr_probability /* 噪声估计 */
    	call svm_train_one           /* 单次训练 */
    	process the parameters
    else:                            /* 回归问题 */
    	call svm_group_classes       /* 整理多类别数据 */
    	calculate parameter 'C'
    	for i <- 1 to nr_class:
    		for j <- i+1 to nr_class:
    		/* 用一对一策略处理多分类问题 */
    		if need estimating:
    			call svm_binary_svc_probability
    		call svm_train_one
    	build output
    free the sapce we allocated before
    return model we trained

svm_train工作量最大的地方就是多分类问题,包括训练好$\dfrac{n(n-1)}{2}$个模型后如何处理输出的问题:用矩阵来存储所有分类模型中的$\alpha$向量。

模型预测

svm_predict_values函数中,我们用训练好的模型对测试数据进行预测:

  • 如果任务是回归,我们直接计算拟合函数值$\sum_j\alpha_j K(x_i,x)-\rho$并返回;
  • 如果任务是单类SVM,如果$\sum_j\alpha_j K(x_j,x)-\rho>0$,则认为测试数据可以被归为训练数据集中,返回1,否则返回-1;
  • 如果任务是回归,使用投票法选取票数最多的类作为输出。

函数svm_predict是对svm_predict_values的封装。svm_predict_probablity也是面向分类问题的预测函数,但基于的是概率估计最大原则。注意到函数中调用了sigmoid_trainsigmoid_predict以及multiclass_probability。如果不小心将回归问题输入,则不做处理,直接返回预测值。

交叉验证

svm_cross_validation将数据打乱,分成指定的份数,依次训练子模型,并将结果存储到target中。值得注意的是,在分类问题中,如果随机划分就会存在某类数据在某一份数据中不存在的情况。libSVM采用的是将每类数据平均随机分到每一份数据中的方法解决这一问题。

模型存取

通过将模型以文件的形式进行存取,我们可以节约重复训练的时间。我们先来看模型保存,libSVM通过svm_save_model函数来保存模型:

int svm_save_model(const char *model_file_name, const svm_model *model) {
	...
    char *old_locale = setlocale(LC_ALL, NULL);
    if (old_locale) {
        old_locale = strdup(old_locale);
    }
    setlocale(LC_ALL, "C");
    ...
}

在正式写入模型前,上面的代码先对语言环境(字符集)进行统一,这里是取消用户计算机上的设定,选取默认值“C”。然后就是写入模型的参数,比如

fprintf(fp, "svm_type %s\n", svm_type_table[param.svm_type]);
fprintf(fp, "kernel_type %s\n", kernel_type_table[param.kernel_type]);

这些必备参数,还有一些选择性参数,如果对应指针不为空则写入:

if (model->nSV) {
	fprintf(fp, "nr_sv");
	for (int i = 0; i < nr_class; i++) fprintf(fp, " %d", model->nSV[i]);
	fprintf(fp, "\n");
}

最重要的是将模型训练好的支持向量和对应的$\alpha_i$写入:

fprintf(fp, "SV\n");
const double *const *sv_coef = model->sv_coef;
const svm_node *const *SV = model->SV;

for (int i = 0; i < l; i++) {
    for (int j = 0; j < nr_class - 1; j++) fprintf(fp, "%.17g ", sv_coef[j][i]);

    const svm_node *p = SV[i];

    if (param.kernel_type == PRECOMPUTED)
        fprintf(fp, "0:%d ", (int)(p->value));
    else
        while (p->index != -1) {
            fprintf(fp, "%d:%.8g ", p->index, p->value);
            p++;
        }
    fprintf(fp, "\n");
}

最后是将还原用户的语言环境:

setlocale(LC_ALL, old_locale);
free(old_locale);

这里我们省略了前后的文件打开和关闭操作。

可以看出,libSVM将文件模型分为两部分:模型的基本参数,也就是超参数(被称作model header),以及模型参数。libSVM会先用read_model_header去读取模型超参数,然后再读取全部模型。read_model_header的基本逻辑就是一行一行匹配,如果标签匹配则读取内容,给新模型赋值。这里用到一个宏:

#define FSCANF(_stream, _format, _var)                         \
    do {                                                       \
        if (fscanf(_stream, _format, _var) != 1) return false; \
    } while (0)

是为了“解决scanf失败”等问题。在svm_load_model中,调用read_model_header后就是对模型参数的读取,这里略去。

如果模型格式出现问题,libSVM也会去检查:svm_check_parametersvm_check_probality_model都是用来检查模型参数是否正确的函数,保证了存取模型的安全性。

其他函数

libSVM还剩下用于获取模型参数的接口和删除模型的函数,比较简单,我们在这里略去。

结语

至此,我们已经大致分析完整个libSVM。随着阅读的深入,便越来越觉得自己能力的卑微,而保持这种卑微或许才是不断求知的动力吧。

笔者其实对libSVM并没有完全读懂,对于其中一小部分操作,也无法从说出其作用。后面或许还会继续研究相关内容。