package team.lodestar.lodestone.systems.textureloader;

import com.mojang.blaze3d.platform.NativeImage;
import com.mojang.datafixers.util.Pair;
import net.minecraft.client.renderer.texture.TextureAtlasSprite;
import net.minecraft.client.resources.metadata.animation.AnimationMetadataSection;
import net.minecraft.resources.ResourceLocation;
import net.minecraft.server.packs.resources.Resource;
import net.minecraft.util.Mth;
import net.minecraftforge.client.event.RegisterTextureAtlasSpriteLoadersEvent;
import net.minecraftforge.client.event.TextureStitchEvent;
import net.minecraftforge.eventbus.api.IEventBus;
import net.minecraftforge.fml.javafmlmod.FMLJavaModLoadingContext;
import team.lodestar.lodestone.helpers.render.ColorHelper;
import team.lodestar.lodestone.systems.easing.Easing;

import javax.annotation.Nonnull;
import java.awt.*;
import java.io.IOException;
import java.io.InputStream;
import java.util.function.Consumer;


/**
 * Helpers for registering Forge texture-atlas sprite loaders whose sprite data comes from another
 * resource, plus {@link NativeImage} recoloring utilities (grayscale and multi-color gradient maps).
 */
public class LodestoneTextureLoader {

    /** Position derived from the pixel's row within a 16-pixel tile: {@code (y % 16) / 16}. */
    protected static final ColorLerp GRADIENT = (pixel, x, y, luminosity, s) -> ((y % 16) / 16f);
    /** Average of the {@link #GRADIENT} position and the {@link #LUMINOUS} position. */
    protected static final ColorLerp LUMINOUS_GRADIENT = (pixel, x, y, luminosity, s) -> (((y % 16) / 16f) + luminosity / s) / 2f;
    /** Position derived from luminosity divided by the supplied luminosity scale. */
    protected static final ColorLerp LUMINOUS = (pixel, x, y, luminosity, s) -> luminosity / s;

    /**
     * Registers a sprite loader that keeps the image it is handed but takes animation metadata from
     * {@code inputImage}, and requests {@code targetPath} on {@link TextureStitchEvent.Pre}.
     *
     * <p>If {@code inputImage} has no animation metadata, the sprite is built from the loader's own
     * {@code textureInfo}. If the resource cannot be resolved or its metadata cannot be used, the
     * error is printed and the same plain sprite is returned.
     *
     * @param loaderName name under which the loader is registered with {@code event}
     * @param targetPath sprite location added to atlases on {@code TextureStitchEvent.Pre}
     * @param inputImage resource whose animation metadata is applied to the sprite
     * @param event      the loader registration event
     */
    public static void registerTextureLoader(String loaderName, ResourceLocation targetPath, ResourceLocation inputImage, RegisterTextureAtlasSpriteLoadersEvent event) {
        event.register(loaderName, (atlas, resourceManager, textureInfo, resource, atlasWidth, atlasHeight, spriteX, spriteY, mipmapLevel, image) -> {
            try {
                Resource r = resourceManager.getResourceOrThrow(inputImage);
                // Empty anonymous subclass: the sprite constructor is not called directly.
                return new TextureAtlasSprite(atlas, getAnimatedInfo(r, image, textureInfo), mipmapLevel, atlasWidth, atlasHeight, spriteX, spriteY, image) {
                };
            } catch (Throwable throwable) {
                throwable.printStackTrace();
            }

            // Fallback: a plain sprite built from the loader's own inputs.
            return new TextureAtlasSprite(atlas, textureInfo, mipmapLevel, atlasWidth, atlasHeight, spriteX, spriteY, image) {
            };
        });
        addSpriteOnStitch(targetPath);
    }

    /**
     * Registers a sprite loader whose pixels are read from {@code inputImage} and passed through
     * {@code textureModifier}, and requests {@code targetPath} on {@link TextureStitchEvent.Pre}.
     *
     * <p>Animation metadata from {@code inputImage}, when present, is applied to the modified image.
     * Otherwise the loader's own {@code textureInfo} is used. On any error the stack trace is printed
     * and a plain sprite is returned. That sprite holds the modified image if reading and modifying had
     * already succeeded, or the image handed to the loader otherwise.
     *
     * @param loaderName      name under which the loader is registered with {@code event}
     * @param targetPath      sprite location added to atlases on {@code TextureStitchEvent.Pre}
     * @param inputImage      resource whose pixels (and animation metadata) feed the sprite
     * @param textureModifier transformation applied to the freshly read {@code inputImage}
     * @param event           the loader registration event
     */
    public static void registerTextureLoader(String loaderName, ResourceLocation targetPath, ResourceLocation inputImage, TextureModifier textureModifier, RegisterTextureAtlasSpriteLoadersEvent event) {
        event.register(loaderName, (atlas, resourceManager, textureInfo, resource, atlasWidth, atlasHeight, spriteX, spriteY, mipmapLevel, image) -> {
            try {
                Resource r = resourceManager.getResourceOrThrow(inputImage);
                // Close the resource stream ourselves rather than relying on NativeImage.read to do it.
                try (InputStream stream = r.open()) {
                    image = textureModifier.modifyTexture(NativeImage.read(stream));
                }
                return new TextureAtlasSprite(atlas, getAnimatedInfo(r, image, textureInfo), mipmapLevel, atlasWidth, atlasHeight, spriteX, spriteY, image) {
                };
            } catch (Throwable throwable) {
                throwable.printStackTrace();
            }

            // Fallback: `image` was reassigned above if reading/modifying succeeded, so a later failure
            // still yields the modified pixels with the loader's own textureInfo.
            return new TextureAtlasSprite(atlas, textureInfo, mipmapLevel, atlasWidth, atlasHeight, spriteX, spriteY, image) {
            };
        });
        addSpriteOnStitch(targetPath);
    }

    /**
     * Adds a mod-bus listener that calls {@code addSprite(targetPath)} on every
     * {@link TextureStitchEvent.Pre}. NOTE: there is no atlas filter, so the sprite is requested for
     * whichever atlas is being stitched.
     */
    private static void addSpriteOnStitch(ResourceLocation targetPath) {
        IEventBus busMod = FMLJavaModLoadingContext.get().getModEventBus();
        busMod.addListener((Consumer<TextureStitchEvent.Pre>) stitchEvent -> stitchEvent.addSprite(targetPath));
    }

    /**
     * Builds sprite info from {@code r}'s animation metadata, using the frame size that metadata
     * implies for {@code image}. Returns {@code baseInfo} unchanged when {@code r} has no animation
     * section, so plain textures do not go through the exception path.
     *
     * @throws IOException if the resource's metadata cannot be read
     */
    private static TextureAtlasSprite.Info getAnimatedInfo(@Nonnull Resource r, @Nonnull NativeImage image, @Nonnull TextureAtlasSprite.Info baseInfo) throws IOException {
        return r.metadata().getSection(AnimationMetadataSection.SERIALIZER)
                .map(section -> {
                    Pair<Integer, Integer> frameSize = section.getFrameSize(image.getWidth(), image.getHeight());
                    return new TextureAtlasSprite.Info(baseInfo.name(), frameSize.getFirst(), frameSize.getSecond(), section);
                })
                .orElse(baseInfo);
    }

    /**
     * Weighted luminance (0.299 / 0.587 / 0.114) of a packed pixel, truncated to an int.
     * The low byte takes the 0.299 weight and bits 16-23 take the 0.114 weight. This matches the
     * {@code combine(alpha, blue, green, red)} layout used below, which treats the low byte as red.
     */
    private static int luminosity(int pixel) {
        return (int) (0.299D * (pixel & 0xFF) + 0.587D * ((pixel >> 8) & 0xFF) + 0.114D * ((pixel >> 16) & 0xFF));
    }

    /**
     * Converts every pixel of {@code nativeimage} to gray in place, keeping its alpha.
     *
     * @param nativeimage image to modify
     * @return the same {@code nativeimage} instance
     */
    public static NativeImage applyGrayscale(NativeImage nativeimage) {
        for (int x = 0; x < nativeimage.getWidth(); x++) {
            for (int y = 0; y < nativeimage.getHeight(); y++) {
                int pixel = nativeimage.getPixelRGBA(x, y);
                int gray = luminosity(pixel);
                nativeimage.setPixelRGBA(x, y, NativeImage.combine((pixel >> 24) & 0xFF, gray, gray, gray));
            }
        }
        return nativeimage;
    }

    /**
     * Recolors {@code nativeimage} in place by mapping each non-transparent pixel onto a gradient
     * through {@code colors}. Fully transparent pixels are left untouched, and alpha is preserved.
     *
     * <p>A first pass finds the lowest and highest luminosity among non-transparent pixels. A second
     * pass does the following for each such pixel:
     * <ol>
     *   <li>Remaps its luminosity into that range.</li>
     *   <li>Asks {@code colorLerp} for a position and inverts it ({@code 1 - position}).</li>
     *   <li>Scales the result to a color index.</li>
     *   <li>Blends {@code colors[index]} toward the next color with {@code easing}.</li>
     * </ol>
     *
     * @param easing    easing applied to the blend between adjacent colors
     * @param colorLerp supplies the gradient position for each pixel
     * @param colors    gradient stops, at least one
     * @return the same {@code nativeimage} instance
     * @throws IllegalArgumentException if {@code colors} is empty
     */
    public static NativeImage applyMultiColorGradient(Easing easing, NativeImage nativeimage, ColorLerp colorLerp, Color... colors) {
        if (colors.length == 0) {
            throw new IllegalArgumentException("applyMultiColorGradient requires at least one color");
        }
        int colorCount = colors.length - 1;
        int width = nativeimage.getWidth();
        int height = nativeimage.getHeight();

        int lowestLuminosity = 255;
        int highestLuminosity = 0;
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                int pixel = nativeimage.getPixelRGBA(x, y);
                if (((pixel >> 24) & 0xFF) == 0) {
                    continue;
                }
                int luminosity = luminosity(pixel);
                lowestLuminosity = Math.min(lowestLuminosity, luminosity);
                highestLuminosity = Math.max(highestLuminosity, luminosity);
            }
        }

        // NOTE(review): if every opaque pixel is black, highestLuminosity is 0 and the luminosity-based
        // ColorLerps (LUMINOUS, LUMINOUS_GRADIENT) divide by zero and produce NaN.
        for (int x = 0; x < width; x++) {
            for (int y = 0; y < height; y++) {
                int pixel = nativeimage.getPixelRGBA(x, y);
                int alpha = (pixel >> 24) & 0xFF;
                if (alpha == 0) {
                    continue;
                }
                float pct = luminosity(pixel) / 255f; //this should probably be smth else
                float newLuminosity = Mth.lerp(pct, lowestLuminosity, highestLuminosity);
                float lerp = 1 - colorLerp.lerp(pixel, x, y, newLuminosity, highestLuminosity);
                float colorIndex = 2 * colorCount * lerp; //TODO: figure out why this * 2 is here

                int index = (int) Mth.clamp(colorIndex, 0, colorCount);
                Color color = colors[index];
                Color nextColor = index == colorCount ? color : colors[index + 1];
                Color transition = ColorHelper.colorLerp(easing, colorIndex - (int) (colorIndex), color, nextColor);
                nativeimage.setPixelRGBA(x, y, NativeImage.combine(alpha, transition.getBlue(), transition.getGreen(), transition.getRed()));
            }
        }
        return nativeimage;
    }

    /** Chooses a gradient position for one pixel in {@link #applyMultiColorGradient}. */
    @FunctionalInterface
    public interface ColorLerp {
        /**
         * @param pixel           the packed pixel value as read from the image
         * @param x               pixel column
         * @param y               pixel row
         * @param luminosity      the pixel's luminosity remapped into the image's luminosity range
         * @param luminosityScale the highest luminosity among the image's non-transparent pixels
         * @return gradient position; the caller uses {@code 1 - result}
         */
        float lerp(int pixel, int x, int y, float luminosity, float luminosityScale);
    }

    /** Transforms an image read from a resource; the returned image becomes the sprite's data. */
    @FunctionalInterface
    public interface TextureModifier {
        NativeImage modifyTexture(NativeImage nativeImage);
    }
}