package me.jellysquid.mods.sodium.client.util.color;

import java.util.function.Function;
import net.minecraft.util.Mth;
import net.minecraft.world.phys.Vec3;

/**
 * Blends colors from a block neighborhood around a point using a separable 6x6x6 weighting kernel,
 * with a fast path that skips blending when every sampled color is identical.
 */
public class FastCubicSampler {
    /**
     * Weight control points. The weight of sample {@code i} along one axis is obtained by
     * {@code Mth.m_14139_(delta, DENSITY_CURVE[i + 1], DENSITY_CURVE[i])}, so this array must hold
     * {@code DIAMETER + 1} entries.
     */
    private static final double[] DENSITY_CURVE = new double[] { 0.0D, 1.0D, 4.0D, 6.0D, 4.0D, 1.0D, 0.0D };
    /** Number of samples taken along each axis. */
    private static final int DIAMETER = 6;
    /** Distance (in blocks) of the first sample below the floored sample position on each axis. */
    private static final int OFFSET = 2;

    /**
     * Samples a blended color around {@code pos}.
     *
     * <p>Colors are fetched for the {@code DIAMETER}^3 block neighborhood whose lowest corner lies
     * {@code OFFSET} blocks below the floored position on each axis. If all fetched values are equal,
     * that single color is converted with {@code Vec3.m_82501_}, passed through {@code transformer}
     * and returned. Otherwise every fetched color is converted and transformed, then combined as a
     * weighted average. The weights are the product of per-axis weights derived from
     * {@link #DENSITY_CURVE} and the fractional part of {@code pos} on that axis.
     *
     * @param pos          the position to sample around
     * @param colorFetcher supplies an int-encoded color for integer block coordinates
     * @param transformer  applied to each converted color before blending
     * @return the transformed color (uniform neighborhood) or the weighted average of the
     *         transformed colors
     */
    public static Vec3 sampleColor(Vec3 pos, ColorFetcher colorFetcher, Function<Vec3, Vec3> transformer) {
        int intX = Mth.m_14107_(pos.m_7096_());
        int intY = Mth.m_14107_(pos.m_7098_());
        int intZ = Mth.m_14107_(pos.m_7094_());

        int[] values = new int[DIAMETER * DIAMETER * DIAMETER];

        for (int x = 0; x < DIAMETER; ++x) {
            int blockX = (intX - OFFSET) + x;

            for (int y = 0; y < DIAMETER; ++y) {
                int blockY = (intY - OFFSET) + y;

                for (int z = 0; z < DIAMETER; ++z) {
                    int blockZ = (intZ - OFFSET) + z;

                    values[index(x, y, z)] = colorFetcher.fetch(blockX, blockY, blockZ);
                }
            }
        }

        // Fast path: when every sample holds the same color, blending would only reproduce it.
        if (isHomogenousArray(values)) {
            return transformer.apply(Vec3.m_82501_(values[0]));
        }

        // Per-axis weights depend only on the axis index and that axis's fractional offset, so they
        // are computed once per axis instead of being re-evaluated inside the inner loops.
        double[] densityX = computeDensities(pos.m_7096_() - (double) intX);
        double[] densityY = computeDensities(pos.m_7098_() - (double) intY);
        double[] densityZ = computeDensities(pos.m_7094_() - (double) intZ);

        // Accumulate component-wise into primitives instead of allocating two Vec3 objects per
        // sample. The additions and multiplications happen in the same order as the equivalent
        // Vec3 scale/add chain, so the result is unchanged.
        double sumX = 0.0D;
        double sumY = 0.0D;
        double sumZ = 0.0D;
        double totalFactor = 0.0D;

        for (int x = 0; x < DIAMETER; ++x) {
            for (int y = 0; y < DIAMETER; ++y) {
                for (int z = 0; z < DIAMETER; ++z) {
                    double factor = densityX[x] * densityY[y] * densityZ[z];
                    totalFactor += factor;

                    Vec3 color = transformer.apply(Vec3.m_82501_(values[index(x, y, z)]));
                    sumX += color.m_7096_() * factor;
                    sumY += color.m_7098_() * factor;
                    sumZ += color.m_7094_() * factor;
                }
            }
        }

        double normalize = 1.0D / totalFactor;

        return new Vec3(sumX * normalize, sumY * normalize, sumZ * normalize);
    }

    /**
     * Computes the weight of each of the {@link #DIAMETER} samples along a single axis.
     *
     * @param delta the sample position's offset from its floored coordinate on this axis
     * @return an array whose element {@code i} is
     *         {@code Mth.m_14139_(delta, DENSITY_CURVE[i + 1], DENSITY_CURVE[i])}
     */
    private static double[] computeDensities(double delta) {
        double[] densities = new double[DIAMETER];

        for (int i = 0; i < DIAMETER; ++i) {
            densities[i] = Mth.m_14139_(delta, DENSITY_CURVE[i + 1], DENSITY_CURVE[i]);
        }

        return densities;
    }

    /**
     * Maps a 3D sample offset to its slot in the flat sample array (x varies fastest, then y,
     * then z).
     */
    private static int index(int x, int y, int z) {
        return (DIAMETER * DIAMETER * z) + (DIAMETER * y) + x;
    }

    /** Supplies the int-encoded color at the given integer block coordinates. */
    @FunctionalInterface
    public interface ColorFetcher {
        int fetch(int x, int y, int z);
    }

    /**
     * Returns {@code true} when every element of {@code arr} equals its first element.
     * Expects a non-empty array, since {@code arr[0]} is read unconditionally.
     */
    private static boolean isHomogenousArray(int[] arr) {
        int val = arr[0];

        for (int i = 1; i < arr.length; i++) {
            if (arr[i] != val) {
                return false;
            }
        }

        return true;
    }
}
