/*
 * Decompiled with CFR 0.152.
 */
package jdplus.toolkit.base.core.math.matrices.lapack;

import jdplus.toolkit.base.core.math.matrices.CPointer;
import jdplus.toolkit.base.core.math.matrices.DataPointer;
import jdplus.toolkit.base.core.math.matrices.FastMatrix;
import jdplus.toolkit.base.core.math.matrices.MatrixException;
import jdplus.toolkit.base.core.math.matrices.MatrixTransformation;
import jdplus.toolkit.base.core.math.matrices.RPointer;
import lombok.Generated;

public final class GEMM {
    public static void apply(double alpha, FastMatrix A, FastMatrix B, double beta, FastMatrix C) {
        int m = C.getRowsCount();
        int n = C.getColumnsCount();
        int cstart = C.getStartPosition();
        int cinc = C.getColumnIncrement();
        CPointer pc = new CPointer(C.getStorage(), cstart);
        if (beta == 0.0) {
            for (c = 0; c < n; ++c) {
                pc.set(m, 0.0);
                pc.move(cinc);
            }
            pc.pos(cstart);
        } else if (beta == -1.0) {
            for (c = 0; c < n; ++c) {
                pc.chs(m);
                pc.move(cinc);
            }
            pc.pos(cstart);
        } else if (beta != 1.0) {
            for (c = 0; c < n; ++c) {
                pc.mul(m, beta);
                pc.move(cinc);
            }
            pc.pos(cstart);
        }
        if (alpha != 0.0) {
            int k = A.getColumnsCount();
            int astart = A.getStartPosition();
            int ainc = A.getColumnIncrement();
            int bstart = B.getStartPosition();
            int binc = B.getColumnIncrement();
            CPointer pa = new CPointer(A.getStorage(), astart);
            CPointer pb = new CPointer(B.getStorage(), bstart);
            for (int j = 0; j < n; ++j) {
                for (int i = 0; i < k; ++i) {
                    double w = alpha * pb.value(i);
                    if (w != 0.0) {
                        pc.addAX(m, w, pa);
                    }
                    pa.move(ainc);
                }
                pa.pos(astart);
                pc.move(cinc);
                pb.move(binc);
            }
        }
    }

    public static void apply(double alpha, FastMatrix A, FastMatrix B, double beta, FastMatrix C, MatrixTransformation ta, MatrixTransformation tb) {
        int k;
        int m = C.getRowsCount();
        int n = C.getColumnsCount();
        int ma = A.getRowsCount();
        int na = A.getColumnsCount();
        int mb = B.getRowsCount();
        int nb = B.getColumnsCount();
        if (ta == MatrixTransformation.Transpose && m != na || ta == MatrixTransformation.None && m != ma || tb == MatrixTransformation.Transpose && n != mb || tb == MatrixTransformation.None && n != nb) {
            throw new MatrixException("m_err_dim");
        }
        int n2 = k = ta == MatrixTransformation.Transpose ? ma : na;
        if (tb == MatrixTransformation.Transpose && k != nb || tb == MatrixTransformation.None && k != mb) {
            throw new MatrixException("m_err_dim");
        }
        if (m == 0 || n == 0 || (alpha == 0.0 || k == 0) && beta == 1.0) {
            return;
        }
        int cstart = C.getStartPosition();
        int clda = C.getColumnIncrement();
        double[] pc = C.getStorage();
        if (alpha == 0.0) {
            GEMM.mul(beta, pc, m, n, cstart, clda);
            return;
        }
        block0 : switch (tb) {
            case Transpose: {
                switch (ta) {
                    case Transpose: {
                        GEMM.addAtBt(C, beta, A, B, alpha);
                        break;
                    }
                    case None: {
                        GEMM.addABt(C, beta, A, B, alpha);
                    }
                }
                break;
            }
            case None: {
                switch (ta) {
                    case Transpose: {
                        GEMM.addAtB(C, beta, A, B, alpha);
                        break block0;
                    }
                    case None: {
                        GEMM.addAB(C, beta, A, B, alpha);
                    }
                }
            }
        }
    }

    private static void mul(double beta, double[] pc, int m, int n, int cstart, int clda) {
        block5: {
            block4: {
                if (beta != 0.0) break block4;
                int icmax = cstart + n * clda;
                for (int ic = cstart; ic < icmax; ic += clda) {
                    int jcmax = ic + m;
                    for (int jc = ic; jc < jcmax; ++jc) {
                        pc[jc] = 0.0;
                    }
                }
                break block5;
            }
            if (beta == 1.0) break block5;
            int icmax = cstart + n * clda;
            for (int ic = cstart; ic < icmax; ic += clda) {
                int jcmax = ic + m;
                int jc = ic;
                while (jc < jcmax) {
                    int n2 = jc++;
                    pc[n2] = pc[n2] * beta;
                }
            }
        }
    }

    private static void addAB(FastMatrix C, double beta, FastMatrix A, FastMatrix B, double alpha) {
        int m = C.getRowsCount();
        int n = C.getColumnsCount();
        int k = A.getColumnsCount();
        int astart = A.getStartPosition();
        int alda = A.getColumnIncrement();
        int bstart = B.getStartPosition();
        int blda = B.getColumnIncrement();
        int cstart = C.getStartPosition();
        int clda = C.getColumnIncrement();
        CPointer lc = new CPointer(C.getStorage(), cstart);
        CPointer la = new CPointer(A.getStorage(), astart);
        CPointer lb = new CPointer(B.getStorage(), bstart);
        for (int j = 0; j < n; ++j) {
            if (beta == 0.0) {
                lc.set(m, 0.0);
            } else if (beta != 1.0) {
                lc.mul(m, beta);
            }
            la.pos(astart);
            for (int l = 0; l < k; ++l) {
                double tmp = alpha * lb.value(l);
                if (tmp != 0.0) {
                    lc.addAX(m, tmp, la);
                }
                la.move(alda);
            }
            lc.move(clda);
            lb.move(blda);
        }
    }

    private static void addAtBt(FastMatrix C, double beta, FastMatrix A, FastMatrix B, double alpha) {
        int m = C.getRowsCount();
        int n = C.getColumnsCount();
        int k = A.getRowsCount();
        int astart = A.getStartPosition();
        int alda = A.getColumnIncrement();
        int bstart = B.getStartPosition();
        int blda = B.getColumnIncrement();
        int cstart = C.getStartPosition();
        int clda = C.getColumnIncrement();
        CPointer la = new CPointer(A.getStorage(), astart);
        RPointer lb = new RPointer(B.getStorage(), bstart, blda);
        double[] pc = C.getStorage();
        int j = 0;
        int ic = cstart;
        while (j < n) {
            la.pos(astart);
            int i = 0;
            int ijc = ic;
            while (i < m) {
                double tmp = ((DataPointer)la).dot(k, lb);
                pc[ijc] = beta == 0.0 ? alpha * tmp : alpha * tmp + beta * pc[ijc];
                ++i;
                ++ijc;
                la.move(alda);
            }
            ++j;
            ic += clda;
            lb.next();
        }
    }

    private static void addAtB(FastMatrix C, double beta, FastMatrix A, FastMatrix B, double alpha) {
        int m = C.getRowsCount();
        int n = C.getColumnsCount();
        int k = A.getRowsCount();
        int astart = A.getStartPosition();
        int alda = A.getColumnIncrement();
        int bstart = B.getStartPosition();
        int blda = B.getColumnIncrement();
        int cstart = C.getStartPosition();
        int clda = C.getColumnIncrement();
        CPointer la = new CPointer(A.getStorage(), astart);
        CPointer lb = new CPointer(B.getStorage(), bstart);
        double[] pc = C.getStorage();
        int j = 0;
        int ic = cstart;
        while (j < n) {
            la.pos(astart);
            int i = 0;
            int ijc = ic;
            while (i < m) {
                double tmp = ((DataPointer)la).dot(k, lb);
                pc[ijc] = beta == 0.0 ? alpha * tmp : alpha * tmp + beta * pc[ijc];
                ++i;
                ++ijc;
                la.move(alda);
            }
            ++j;
            ic += clda;
            lb.move(blda);
        }
    }

    private static void addABt(FastMatrix C, double beta, FastMatrix A, FastMatrix B, double alpha) {
        int m = C.getRowsCount();
        int n = C.getColumnsCount();
        int k = A.getColumnsCount();
        int astart = A.getStartPosition();
        int alda = A.getColumnIncrement();
        int bstart = B.getStartPosition();
        int blda = B.getColumnIncrement();
        int cstart = C.getStartPosition();
        int clda = C.getColumnIncrement();
        CPointer lc = new CPointer(C.getStorage(), cstart);
        CPointer la = new CPointer(A.getStorage(), astart);
        RPointer lb = new RPointer(B.getStorage(), bstart, blda);
        for (int j = 0; j < n; ++j) {
            if (beta == 0.0) {
                ((DataPointer)lc).set(m, 0.0);
            } else if (beta != 1.0) {
                ((DataPointer)lc).mul(n, beta);
            }
            la.pos(astart);
            for (int l = 0; l < k; ++l) {
                double tmp = alpha * ((DataPointer)lb).value(l);
                if (tmp != 0.0) {
                    ((DataPointer)lc).addAX(m, tmp, la);
                }
                la.move(alda);
            }
            lc.move(clda);
            lb.next();
        }
    }

    @Generated
    private GEMM() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }
}

