多项式回归(Polynomial Regression)(附代码)

浏览: 3216

Image多项式回归有个很重要的因素就是指数(degree)。如果我们发现数据的分布大致是一条曲线,那么很可能符合多项式回归,但是我们不知道degree是多少。所以我们只能一个个去试,直到找到最拟合分布的degree。这个过程我们可以交给数据科学软件完成。需要注意的是,如果degree选择过大的话可能会导致函数过于拟合, 意味着对数据或者函数未来的发展很难预测,也许指向不同的方向。

这个回归的计算需要用到矩阵数据结构。有的编程语言可能需要导入外库。

Image

多项式回归有个很重要的因素就是指数(degree)。如果我们发现数据的分布大致是一条曲线,那么很可能符合多项式回归,但是我们不知道degree是多少。所以我们只能一个个去试,直到找到最拟合分布的degree。这个过程我们可以交给数据科学软件完成。需要注意的是,如果degree选择过大的话可能会导致函数过于拟合, 意味着对数据或者函数未来的发展很难预测,也许指向不同的方向。

这个回归的计算需要用到矩阵数据结构。有的编程语言可能需要导入外库。

Image

我们对所有拟合这个公式的点,用矩阵表示他们的关系

Image

如果用矩阵符号表示:

Image

多项式回归向量的系数(使用最小二乘法):

Image

Java 和 Python 代码如下:


package regression;
import Jama.Matrix;
import Jama.QRDecomposition;

public class PR {

private final int N;
private final int degree;
private final Matrix beta;
private double SSE;
private double SST;

public PR(double[] x, double[] y, int degree) {
this.degree = degree;
N = x.length;

// build Vandermonde matrix
double[][] vandermonde = new double[N][degree+1];
for (int i = 0; i < N; i++) {
for (int j = 0; j <= degree; j++) {
vandermonde[i][j] = Math.pow(x[i], j);
}
}
Matrix X = new Matrix(vandermonde);

// 从向量中增加一个矩阵
Matrix Y = new Matrix(y, N);

// 找到最小的平方值
QRDecomposition qr = new QRDecomposition(X);
beta = qr.solve(Y);


// 得到y的平均值
double sum = 0.0;
for (int i = 0; i < N; i++)
sum += y[i];
double mean = sum / N;

// total variation to be accounted for
for (int i = 0; i < N; i++) {
double dev = y[i] - mean;
SST += dev*dev;
}

// variation not accounted for
Matrix residuals = X.times(beta).minus(Y);
SSE = residuals.norm2() * residuals.norm2();
}

public double beta(int j) {
return beta.get(j, 0);
}

public int degreee() {
return degree;
}

public double R2() {
return 1.0 - SSE/SST;
}

public double predict(double x) {

double y = 0.0;
for (int j = degree; j>=0; j--) {
y = beta(j) + (x*y);
}
return y;
}

public String toString() {
String s = "";
int j = degree;

// 忽略系数为0.
while (Math.abs(beta(j)) < 1E-5)
j--;

// create remaining terms
for (j = j; j >= 0; j--) {
if (j == 0) s += String.format("%.2f ", beta(j));
else if (j == 1) s += String.format("%.2f N + ", beta(j));
else s += String.format("%.2f N^%d + ", beta(j), j);
}
return s + " (R^2 = " + String.format("%.3f", R2()) + ")";
}

}

ref:

Java代码使用了《算法》中的代码,可以在普林斯顿的算法课上下载:[Polynomial Regression](http://introcs.cs.princeton.edu/java/97data/PolynomialRegression.java.html)

推荐 1
本文由 dykin 创作,采用 知识共享署名-相同方式共享 3.0 中国大陆许可协议 进行许可。
转载、引用前需联系作者,并署名作者且注明文章出处。
本站文章版权归原作者及原出处所有 。内容为作者个人观点, 并不代表本站赞同其观点和对其真实性负责。本站是一个个人学习交流的平台,并不用于任何商业目的,如果有任何问题,请及时联系我们,我们将根据著作权人的要求,立即更正或者删除有关内容。本站拥有对此声明的最终解释权。

0 个评论

要回复文章请先登录注册