package foundry.veil.mixin.shader.client;

import foundry.veil.Veil;
import foundry.veil.api.client.render.shader.program.ShaderUniformCache;
import foundry.veil.impl.client.render.shader.processor.VanillaShaderProcessor;
import foundry.veil.impl.client.render.shader.program.ShaderProgramImpl;
import it.unimi.dsi.fastutil.objects.Object2ObjectArrayMap;
import org.lwjgl.system.MemoryStack;
import org.lwjgl.system.MemoryUtil;
import org.spongepowered.asm.mixin.Final;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfo;
import org.spongepowered.asm.mixin.injection.callback.CallbackInfoReturnable;

import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.List;
import java.util.Map;
import net.minecraft.class_281;
import net.minecraft.class_284;
import net.minecraft.class_3679;
import net.minecraft.class_5912;
import net.minecraft.class_5944;

import static org.lwjgl.opengl.GL20C.*;

/**
 * Mixin into the shader instance class {@code class_5944}.
 * <p>
 * It does two things:
 * <ul>
 *     <li>While Veil is constructing one of its own programs, it short-circuits stage creation in {@code getOrCreate}.</li>
 *     <li>After {@code updateLocations}, it introspects the linked GL program for active uniforms and samplers that the
 *     instance does not already know about (not in {@code uniformMap} / {@code samplerNames}), and registers them so
 *     they can be looked up, uploaded and released like any other uniform.</li>
 * </ul>
 */
@Mixin(value = class_5944.class, priority = 800)
public abstract class ShaderInstanceMixin implements class_3679 {

    @Shadow
    @Final
    private int programId;
    @Shadow
    @Final
    private List<Integer> uniformLocations;
    @Shadow
    @Final
    public Map<String, class_284> uniformMap;
    @Shadow
    @Final
    private String name;

    @Shadow
    @Final
    private List<String> samplerNames;

    /**
     * Uniforms discovered by program introspection and owned by this mixin, keyed by uniform name.
     * Each entry is also published in {@code uniformMap} while it has a valid location.
     */
    @Unique
    private final Map<String, class_284> veil$uniforms = new Object2ObjectArrayMap<>();

    /**
     * Returns a {@code ShaderProgramImpl.ShaderWrapper} for the requested stage instead of running {@code getOrCreate},
     * but only while {@code ShaderProgramImpl.Wrapper.constructingProgram} is set, i.e. while Veil builds its own program.
     */
    @Inject(method = "getOrCreate", at = @At("HEAD"), cancellable = true)
    private static void veil$cancelDummyProgram(class_5912 provider, class_281.class_282 type, String name, CallbackInfoReturnable<class_281> cir) {
        if (ShaderProgramImpl.Wrapper.constructingProgram != null) {
            cir.setReturnValue(new ShaderProgramImpl.ShaderWrapper(type, ShaderProgramImpl.Wrapper.constructingProgram));
        }
    }

    /**
     * Installs the fallback {@link VanillaShaderProcessor} for the given resource provider before the stage is loaded.
     * Paired with {@link #veil$clearFallbackProcessor}; both are skipped when the platform reports errors.
     */
    @Inject(method = "getOrCreate", at = @At("HEAD"))
    private static void veil$setupFallbackProcessor(class_5912 provider, class_281.class_282 type, String name, CallbackInfoReturnable<class_281> cir) {
        if (Veil.platform().hasErrors()) {
            return;
        }
        VanillaShaderProcessor.setup(provider);
    }

    /**
     * Frees the fallback processor installed by {@link #veil$setupFallbackProcessor} once {@code getOrCreate} returns.
     */
    @Inject(method = "getOrCreate", at = @At("RETURN"))
    private static void veil$clearFallbackProcessor(CallbackInfoReturnable<class_281> cir) {
        if (Veil.platform().hasErrors()) {
            return;
        }
        VanillaShaderProcessor.free();
    }

    /**
     * Closes every uniform owned by this mixin when the shader instance is closed.
     * The map is cleared afterwards so a repeated close cannot close (free) the same uniforms twice.
     */
    @Inject(method = "close", at = @At("HEAD"))
    public void close(CallbackInfo ci) {
        for (class_284 uniform : this.veil$uniforms.values()) {
            uniform.close();
        }
        this.veil$uniforms.clear();
    }

    /**
     * Uploads every owned uniform (via {@code method_1300}, presumably {@code Uniform#upload} — confirm) after the
     * instance's own {@code apply} has run.
     */
    @Inject(method = "apply", at = @At("TAIL"))
    public void apply(CallbackInfo ci) {
        for (class_284 uniform : this.veil$uniforms.values()) {
            uniform.method_1300();
        }
    }

    /**
     * Discovers active uniforms and samplers in the linked program that the instance does not already know about.
     * <p>
     * Samplers are appended to {@code samplerNames}. Supported uniform types get an owned {@code class_284}: it is
     * (re)created or reused, zeroed, given its location, and published in {@code uniformMap}. Owned uniforms are
     * invalidated at the start of every pass. Any that do not receive a location again are closed and unpublished at
     * the end, so the method can safely run more than once.
     * Veil's own wrapper programs are skipped entirely.
     */
    @SuppressWarnings("ConstantValue")
    @Inject(method = "updateLocations", at = @At("TAIL"))
    public void updateLocations(CallbackInfo ci) {
        if ((Object) this instanceof ShaderProgramImpl.Wrapper) {
            return;
        }

        // Invalidate every owned uniform; the ones still active in the program get a location again below
        for (class_284 uniform : this.veil$uniforms.values()) {
            uniform.method_1297(-1);
        }

        int uniformCount = glGetProgrami(this.programId, GL_ACTIVE_UNIFORMS);
        int maxUniformLength = glGetProgrami(this.programId, GL_ACTIVE_UNIFORM_MAX_LENGTH);

        try (MemoryStack stack = MemoryStack.stackPush()) {
            IntBuffer size = stack.mallocInt(1);
            IntBuffer type = stack.mallocInt(1);
            for (int i = 0; i < uniformCount; i++) {
                String uniformName = glGetActiveUniform(this.programId, i, maxUniformLength, size, type);

                // Skip uniforms the instance already manages itself and known samplers. Uniforms this mixin
                // registered on an earlier pass are also in uniformMap, but they must be re-processed so their
                // location is refreshed instead of being swept as stale at the end.
                class_284 registered = this.uniformMap.get(uniformName);
                if ((registered != null && registered != this.veil$uniforms.get(uniformName)) || this.samplerNames.contains(uniformName)) {
                    continue;
                }

                int dataType = type.get(0);
                String typeName = ShaderUniformCache.getName(dataType);
                int length = size.get(0);
                // Active array uniforms are reported once, named "base[0]", with the element count in size
                String baseName = veil$arrayBaseName(uniformName);

                if (ShaderUniformCache.isSampler(dataType)) {
                    for (int j = 0; j < length; j++) {
                        String samplerName = length > 1 ? baseName + '[' + j + ']' : uniformName;
                        Veil.LOGGER.debug("Shader {} detected sampler: {}", this.name, typeName + " " + samplerName);
                        this.samplerNames.add(samplerName);
                    }
                    continue;
                }

                int minecraftType;
                int minecraftCount;
                switch (dataType) {
                    case GL_INT -> {
                        minecraftType = class_284.field_32038;
                        minecraftCount = 1;
                    }
                    case GL_INT_VEC2 -> {
                        minecraftType = class_284.field_32039;
                        minecraftCount = 2;
                    }
                    case GL_INT_VEC3 -> {
                        minecraftType = class_284.field_32040;
                        minecraftCount = 3;
                    }
                    case GL_INT_VEC4 -> {
                        minecraftType = class_284.field_32041;
                        minecraftCount = 4;
                    }
                    case GL_FLOAT -> {
                        minecraftType = class_284.field_32042;
                        minecraftCount = 1;
                    }
                    case GL_FLOAT_VEC2 -> {
                        minecraftType = class_284.field_32043;
                        minecraftCount = 2;
                    }
                    case GL_FLOAT_VEC3 -> {
                        minecraftType = class_284.field_32044;
                        minecraftCount = 3;
                    }
                    case GL_FLOAT_VEC4 -> {
                        minecraftType = class_284.field_32045;
                        minecraftCount = 4;
                    }
                    case GL_FLOAT_MAT2 -> {
                        minecraftType = class_284.field_32046;
                        minecraftCount = 4;
                    }
                    case GL_FLOAT_MAT3 -> {
                        minecraftType = class_284.field_32047;
                        minecraftCount = 9;
                    }
                    case GL_FLOAT_MAT4 -> {
                        minecraftType = class_284.field_32048;
                        minecraftCount = 16;
                    }
                    default -> {
                        Veil.LOGGER.error("Unsupported Uniform Type: {}", typeName);
                        continue;
                    }
                }

                for (int j = 0; j < length; j++) {
                    // Build the element name before the lookup so element j gets its own location
                    String elementName = length > 1 ? baseName + '[' + j + ']' : uniformName;
                    int location = class_284.method_22096(this.programId, elementName);
                    if (location == -1) {
                        // If the length is not 1, then it must be another mod adding a uniform block, so ignore
                        if (length == 1) {
                            Veil.LOGGER.warn("Shader {} could not find uniform named {} in the specified shader program.", this.name, elementName);
                        }

                        // Don't leak resources, and don't leave a closed uniform reachable through uniformMap
                        class_284 stale = this.veil$uniforms.remove(elementName);
                        if (stale != null) {
                            this.veil$releaseUniform(elementName, stale);
                        }
                        continue;
                    }

                    Veil.LOGGER.debug("Shader {} detected uniform: {}", this.name, typeName + " " + elementName);
                    class_284 uniform = this.veil$uniforms.get(elementName);
                    if (uniform == null || uniform.method_35662() != minecraftType) {
                        if (uniform != null) {
                            // The declaration changed type, so the existing instance cannot be reused
                            this.veil$releaseUniform(elementName, uniform);
                        }
                        uniform = new class_284(elementName, minecraftType, minecraftCount, this);
                        this.veil$uniforms.put(elementName, uniform);
                    }

                    // Start from zeroed values, whether the uniform was just created or reused from a previous pass
                    IntBuffer intBuffer = uniform.method_35663();
                    if (intBuffer != null) {
                        MemoryUtil.memSet(intBuffer, 0);
                    }

                    FloatBuffer floatBuffer = uniform.method_35664();
                    if (floatBuffer != null) {
                        MemoryUtil.memSet(floatBuffer, Float.floatToIntBits(0.0F));
                    }

                    this.uniformLocations.add(location);
                    uniform.method_1297(location);
                    this.uniformMap.put(elementName, uniform);
                }
            }
        }

        // Clean up invalid uniforms: anything owned that did not get a location this pass is closed and unpublished
        this.veil$uniforms.entrySet().removeIf(entry -> {
            class_284 uniform = entry.getValue();
            if (uniform.method_35660() != -1) {
                return false;
            }
            this.veil$releaseUniform(entry.getKey(), uniform);
            return true;
        });
    }

    /**
     * Closes an owned uniform and removes it from {@code uniformMap}, but only when the map still points at this exact
     * instance, so a different uniform registered under the same name is left untouched.
     * The caller is responsible for removing it from {@code veil$uniforms}.
     *
     * @param name    the name the uniform was registered under
     * @param uniform the owned uniform to release
     */
    @Unique
    private void veil$releaseUniform(String name, class_284 uniform) {
        this.uniformMap.remove(name, uniform);
        uniform.close();
    }

    /**
     * Returns the base name of an active array uniform. GL reports array uniforms with a trailing {@code [0]}, so
     * {@code "lights[0].color[0]"} becomes {@code "lights[0].color"}. Names without that suffix are returned unchanged,
     * which avoids the index errors a blind {@code indexOf('[')} / fixed-length cut would cause.
     *
     * @param name the name returned by {@code glGetActiveUniform}
     * @return the name without a trailing {@code [0]}
     */
    @Unique
    private static String veil$arrayBaseName(String name) {
        return name.endsWith("[0]") ? name.substring(0, name.length() - 3) : name;
    }
}
