package foundry.veil.impl.client.render.shader.modifier;

import foundry.veil.impl.client.render.shader.transformer.VeilJobParameters;
import io.github.douira.glsl_transformer.ast.node.TranslationUnit;
import io.github.douira.glsl_transformer.ast.node.declaration.DeclarationMember;
import io.github.douira.glsl_transformer.ast.node.external_declaration.ExternalDeclaration;
import io.github.douira.glsl_transformer.ast.node.type.specifier.BuiltinNumericTypeSpecifier;
import io.github.douira.glsl_transformer.ast.node.type.specifier.TypeSpecifier;
import io.github.douira.glsl_transformer.ast.print.ASTPrinter;
import io.github.douira.glsl_transformer.ast.query.Root;
import io.github.douira.glsl_transformer.ast.query.match.Matcher;
import io.github.douira.glsl_transformer.ast.transform.ASTInjectionPoint;
import io.github.douira.glsl_transformer.ast.transform.ASTParser;
import io.github.douira.glsl_transformer.ast.traversal.ASTListener;
import io.github.douira.glsl_transformer.ast.traversal.ASTWalker;
import io.github.douira.glsl_transformer.parser.ParseShape;
import it.unimi.dsi.fastutil.ints.Int2ObjectArrayMap;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.Nullable;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import net.minecraft.class_2960;

@ApiStatus.Internal
public class VertexShaderModification extends SimpleShaderModification {

    /**
     * Pattern for a top-level vertex input declaration shaped like {@code in type name;}.
     * The {@code type} slot is marked as a class wildcard limited to {@link BuiltinNumericTypeSpecifier},
     * and the {@code name*} slot is marked on the {@link DeclarationMember} holding the variable name.
     */
    public static final Matcher<ExternalDeclaration> INPUT = new Matcher<>("in type name;", ParseShape.EXTERNAL_DECLARATION) {
        {
            Root patternRoot = this.pattern.getRoot();
            this.markClassWildcard("type", patternRoot.identifierIndex.getUnique("type").getAncestor(TypeSpecifier.class), BuiltinNumericTypeSpecifier.class);
            this.markClassWildcard("name*", patternRoot.identifierIndex.getUnique("name").getAncestor(DeclarationMember.class));
        }
    };

    /** Vertex attributes this modification requires: expected index, GLSL type text and placeholder name. */
    private final Attribute[] attributes;
    /**
     * Placeholder name -> identifier actually used in the shader; cleared and rebuilt by {@link #inject}
     * whenever attributes are requested.
     * NOTE(review): this is per-instance mutable state; confirm one instance is never injected concurrently.
     */
    private final Map<String, String> mapper;

    /**
     * @param attributes the vertex inputs this modification needs; every other argument is forwarded
     *                   unchanged to {@link SimpleShaderModification}
     */
    public VertexShaderModification(int version, int priority, class_2960[] includes, @Nullable String output, @Nullable String uniform, Function[] functions, Attribute[] attributes) {
        super(version, priority, includes, output, uniform, functions);
        this.attributes = attributes;
        this.mapper = new HashMap<>(attributes.length);
    }

    /**
     * Binds the required vertex attributes to the shader's inputs (only when at least one attribute is
     * requested), then runs the regular {@link SimpleShaderModification#inject} logic.
     *
     * @throws IOException if an existing input at a requested index has a different type than expected,
     *                     or if the superclass injection fails
     */
    @Override
    public void inject(ASTParser parser, TranslationUnit tree, VeilJobParameters parameters) throws IOException {
        if (this.attributes.length != 0) {
            this.bindAttributes(parser, tree);
        }
        super.inject(parser, tree, parameters);
    }

    /**
     * Numbers the shader's inputs matching {@link #INPUT} in the order they are matched. Then, for each
     * required attribute, either reuses the input with that ordinal (after checking the type matches) or
     * injects a new {@code layout(location = N)} input. {@link #mapper} is filled with the resulting names.
     */
    private void bindAttributes(ASTParser parser, TranslationUnit tree) throws IOException {
        Map<Integer, Attribute> declaredInputs = new Int2ObjectArrayMap<>();

        Root treeRoot = tree.getRoot();
        treeRoot.processMatches(parser, tree.getChildren().stream().filter(candidate -> candidate.hasAncestor(INPUT.getPatternClass())), INPUT, match -> {
            // slot 0: printed type specifier, slot 1: declared variable name
            String[] typeAndName = new String[2];
            ASTWalker.walk(new ASTListener() {
                @Override
                public void enterTypeSpecifier(TypeSpecifier node) {
                    typeAndName[0] = ASTPrinter.printSimple(node);
                }

                @Override
                public void enterDeclarationMember(DeclarationMember node) {
                    typeAndName[1] = node.getName().getName();
                }
            }, match);

            // The match ordinal (not any layout qualifier in the source) serves as the attribute index.
            int ordinal = declaredInputs.size();
            declaredInputs.put(ordinal, new Attribute(ordinal, typeAndName[0], typeAndName[1]));
        });

        this.mapper.clear();
        for (Attribute requested : this.attributes) {
            Attribute declared = declaredInputs.get(requested.index());
            if (declared == null) {
                // No matched input has this ordinal, so declare one explicitly at that location.
                // TODO explicit locations might misbehave on macOS; this still needs to be tested.
                tree.parseAndInjectNode(parser, ASTInjectionPoint.BEFORE_DECLARATIONS, "layout(location = " + requested.index() + ") in " + requested.type() + " " + requested.name() + ";");
                this.mapper.put(requested.name(), requested.name());
            } else {
                if (!declared.type().equals(requested.type())) {
                    throw new IOException("Expected attribute " + requested.index() + " to be " + requested.type() + " but was " + declared.type());
                }
                this.mapper.put(requested.name(), declared.name());
            }
        }
    }

    /**
     * Resolves an attribute placeholder to the identifier used in the shader. Keys that are not mapped
     * attribute names are resolved by the superclass.
     */
    @Override
    protected String getPlaceholder(String key) {
        String resolved = this.mapper.get(key);
        if (resolved == null) {
            return super.getPlaceholder(key);
        }
        return resolved;
    }

    /**
     * A vertex input described by its attribute index, its GLSL type as source text, and its variable name.
     * In {@link #attributes} the name doubles as the placeholder key.
     */
    public record Attribute(int index, String type, String name) {
    }
}
