Search Posts

Linear Regression in Java

package hk.quantr.linearregression;

/**
 * Simple least-squares linear regression of one response variable on one
 * predictor variable.
 *
 * <p>All statistics are computed once, in the constructor, and stored in
 * final fields; the accessors only read them back.
 */
public class LinearRegression {

	// coefficients of the fitted line y = intercept + slope * x
	private final double intercept, slope;
	// coefficient of determination
	private final double r2;
	// variances of the intercept (svar0) and slope (svar1) estimates;
	// their square roots are reported as the standard errors
	private final double svar0, svar1;

	/**
	 * Performs a linear regression on the data points {@code (x[i], y[i])}.
	 *
	 * <p>No check is made that the data determine a line. With fewer than two
	 * points the slope and intercept are {@code NaN}; with exactly two points
	 * the residual degrees of freedom are zero, so the standard errors are not
	 * finite.
	 *
	 * @param x the values of the predictor variable
	 * @param y the corresponding values of the response variable
	 * @throws IllegalArgumentException if the two arrays differ in length
	 * @throws NullPointerException if either array is {@code null}
	 */
	public LinearRegression(double[] x, double[] y) {
		if (x.length != y.length) {
			throw new IllegalArgumentException("array lengths are not equal");
		}
		int n = x.length;

		// first pass: sample means
		double sumx = 0.0, sumy = 0.0;
		for (int i = 0; i < n; i++) {
			sumx += x[i];
			sumy += y[i];
		}
		double xbar = sumx / n;
		double ybar = sumy / n;

		// second pass: sums of squared / cross deviations about the means
		double xxbar = 0.0, yybar = 0.0, xybar = 0.0;
		for (int i = 0; i < n; i++) {
			xxbar += (x[i] - xbar) * (x[i] - xbar);
			yybar += (y[i] - ybar) * (y[i] - ybar);
			xybar += (x[i] - xbar) * (y[i] - ybar);
		}
		slope = xybar / xxbar;
		intercept = ybar - slope * xbar;

		// third pass: goodness of fit
		double rss = 0.0;      // residual sum of squares
		double ssr = 0.0;      // regression sum of squares
		for (int i = 0; i < n; i++) {
			double fit = slope * x[i] + intercept;
			rss += (fit - y[i]) * (fit - y[i]);
			ssr += (fit - ybar) * (fit - ybar);
		}

		// two parameters (slope and intercept) were estimated from the data
		int degreesOfFreedom = n - 2;
		r2 = ssr / yybar;
		double svar = rss / degreesOfFreedom;
		svar1 = svar / xxbar;
		svar0 = svar / n + xbar * xbar * svar1;
	}

	/**
	 * Returns the <em>y</em>-intercept &alpha; of the best-fit line <em>y</em> = &alpha; + &beta; <em>x</em>.
	 *
	 * @return the fitted intercept
	 */
	public double intercept() {
		return intercept;
	}

	/**
	 * Returns the slope &beta; of the best-fit line <em>y</em> = &alpha; + &beta; <em>x</em>.
	 *
	 * @return the fitted slope
	 */
	public double slope() {
		return slope;
	}

	/**
	 * Returns the coefficient of determination <em>R</em><sup>2</sup>.
	 *
	 * @return the regression sum of squares divided by the total sum of squares of {@code y}
	 */
	public double R2() {
		return r2;
	}

	/**
	 * Returns the standard error of the estimate for the intercept.
	 *
	 * @return the square root of the intercept variance
	 */
	public double interceptStdErr() {
		return Math.sqrt(svar0);
	}

	/**
	 * Returns the standard error of the estimate for the slope.
	 *
	 * @return the square root of the slope variance
	 */
	public double slopeStdErr() {
		return Math.sqrt(svar1);
	}

	/**
	 * Returns the expected response {@code y} given the value of the predictor variable {@code x}.
	 *
	 * @param x the value of the predictor variable
	 * @return {@code slope * x + intercept}
	 */
	public double predict(double x) {
		return slope * x + intercept;
	}

	/**
	 * Returns a string representation of the simple linear regression model:
	 * the slope and intercept to two decimals followed by
	 * <em>R</em><sup>2</sup> to three decimals.
	 *
	 * @return the formatted model description
	 */
	@Override
	public String toString() {
		return String.format("%.2f n + %.2f  (R^2 = %.3f)", slope(), intercept(), R2());
	}

	/**
	 * Demo: fits a line through three sample points and prints the predicted
	 * response at {@code x = 4}.
	 *
	 * @param args ignored
	 */
	public static void main(String[] args) {
		double[] x = {1, 2, 3};
		double[] y = {4, 5, 7};
		LinearRegression lr = new LinearRegression(x, y);
		System.out.println(lr.predict(4));
	}
}

Leave a Reply

Your email address will not be published. Required fields are marked *