01
02
03
04
05
06
07
08
09
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
package algs91; // section 9.9
import stdlib.*;
/* ***********************************************************************
 *  Compilation:  javac GaussianElimination.java
 *  Execution:    java GaussianElimination
 *
 *  Gaussian elimination with partial pivoting.
 *
 *  % java GaussianElimination
 *  -1.0
 *  2.0
 *  2.0
 *
 *************************************************************************/

public class GaussianElimination {
  /** Pivots whose absolute value is at or below this threshold are treated as zero. */
  private static final double EPSILON = 1e-10;

  /**
   * Solves the linear system {@code A x = b} using Gaussian elimination with
   * partial pivoting, followed by back substitution.
   *
   * <p><b>Side effects:</b> both arguments are modified in place. Rows of
   * {@code A} and the matching entries of {@code b} are swapped during
   * pivoting, and {@code A} is reduced to upper-triangular form. Pass copies
   * if the originals are still needed.
   *
   * @param A an N-by-N coefficient matrix, where N = {@code b.length}
   * @param b the right-hand-side vector
   * @return the solution vector {@code x}, of length N
   * @throws IllegalArgumentException if {@code A} does not have exactly N rows
   *     each of length N
   * @throws Error if some pivot has absolute value at most {@code EPSILON},
   *     i.e. the matrix is singular or nearly singular (kept as {@code Error}
   *     for compatibility with existing callers)
   */
  public static double[] lsolve(double[][] A, double[] b) {
    int N = b.length;
    // Reject mismatched shapes up front; otherwise extra rows/columns would be
    // silently ignored or an opaque ArrayIndexOutOfBoundsException would escape.
    checkDimensions(A, N);

    for (int p = 0; p < N; p++) {

      // find pivot row (largest magnitude in column p at or below row p) and swap
      int max = p;
      for (int i = p + 1; i < N; i++) {
        if (Math.abs(A[i][p]) > Math.abs(A[max][p])) {
          max = i;
        }
      }
      double[] temp = A[p]; A[p] = A[max]; A[max] = temp;
      double   t    = b[p]; b[p] = b[max]; b[max] = t;

      // singular or nearly singular: even the best available pivot is ~0
      if (Math.abs(A[p][p]) <= EPSILON) {
        throw new Error("Matrix is singular or nearly singular");
      }

      // eliminate column p from every row below p, in both A and b
      for (int i = p + 1; i < N; i++) {
        double alpha = A[i][p] / A[p][p];
        b[i] -= alpha * b[p];
        for (int j = p; j < N; j++) {
          A[i][j] -= alpha * A[p][j];
        }
      }
    }

    // back substitution on the upper-triangular system
    double[] x = new double[N];
    for (int i = N - 1; i >= 0; i--) {
      double sum = 0.0;
      for (int j = i + 1; j < N; j++) {
        sum += A[i][j] * x[j];
      }
      x[i] = (b[i] - sum) / A[i][i];
    }
    return x;
  }

  /** Throws IllegalArgumentException unless {@code A} is exactly N-by-N. */
  private static void checkDimensions(double[][] A, int N) {
    if (A.length != N) {
      throw new IllegalArgumentException(
          "A has " + A.length + " rows but b has length " + N);
    }
    for (int i = 0; i < N; i++) {
      if (A[i].length != N) {
        throw new IllegalArgumentException(
            "Row " + i + " of A has length " + A[i].length + ", expected " + N);
      }
    }
  }


  // sample client: solves a 3x3 system and prints the solution vector
  public static void main(String[] args) {
    double[][] A = {
        { 0, 1,  1 },
        { 2, 4, -2 },
        { 0, 3, 15 }
    };
    double[] b = { 4, 2, 36 };
    double[] x = lsolve(A, b);


    // print results, one component per line
    for (int i = 0; i < x.length; i++) {
      StdOut.println(x[i]);
    }

  }

}