package com.bailuntec.domain.pojo;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;

/**
 * 回归实现
 */
public class LinerRegression {
    /**
     * 最小二乘法
     *
     * @param dataNodeList
     */
    public static BigDecimal[] getAB(List<DataNode> dataNodeList) {
        BigDecimal alpha = BigDecimal.ZERO;
        BigDecimal beta = BigDecimal.ZERO;

        int n = dataNodeList.size();
        BigDecimal sumX = BigDecimal.ZERO;
        BigDecimal sumY = BigDecimal.ZERO;
        BigDecimal sumXY = BigDecimal.ZERO;
        BigDecimal sumX2 = BigDecimal.ZERO;

        for (DataNode dataNode : dataNodeList) {
            sumX = sumX.add(dataNode.getX());
            sumY = sumY.add(dataNode.getY());
            sumXY = sumXY.add(dataNode.getXY());
            sumX2 = sumX2.add(dataNode.getX().pow(2));
        }
        BigDecimal var1 = sumY.multiply(sumX).divide(BigDecimal.valueOf(n), 5, RoundingMode.HALF_EVEN).subtract(sumXY);
        BigDecimal var2 = (sumX.multiply(sumX).divide(BigDecimal.valueOf(n), 5, RoundingMode.HALF_EVEN)).subtract(sumX2);
        if (var2.compareTo(BigDecimal.ZERO) == 0) {
            alpha = BigDecimal.ZERO;
        } else {
            alpha = var1.divide(var2, 5, RoundingMode.HALF_EVEN);
        }
        beta = (sumY.subtract((alpha.multiply(sumX)))).divide(BigDecimal.valueOf(n), 5, RoundingMode.HALF_EVEN);
        BigDecimal r2 = getR2(dataNodeList, alpha, beta);
        return new BigDecimal[]{alpha, beta, r2};
    }

    /**
     * 拟合优度
     *
     * @param dataNodeList
     */
    public static BigDecimal getR2(List<DataNode> dataNodeList, BigDecimal alpha, BigDecimal beta) {
        BigDecimal r = BigDecimal.ZERO;
        BigDecimal num = BigDecimal.ZERO;
        BigDecimal den = BigDecimal.ZERO;
        BigDecimal sumY = BigDecimal.ZERO;
        for (DataNode dataNode : dataNodeList) {
            sumY = sumY.add(dataNode.getY());
        }
        BigDecimal avgY = sumY.divide(BigDecimal.valueOf(dataNodeList.size()), 5, RoundingMode.HALF_EVEN);
        for (DataNode dataNode : dataNodeList) {
            num = num.add((dataNode.getY().subtract((dataNode.getX().multiply(alpha).add(beta)))).pow(2));
            den = den.add((dataNode.getY().subtract(avgY)).pow(2));
        }
        if (den.compareTo(BigDecimal.ZERO) == 0) {
            r = BigDecimal.ZERO;
        } else {
            r = BigDecimal.ONE.subtract((num.divide(den, 5, RoundingMode.HALF_EVEN)));
        }
        return r;
    }

}
