package foundry.veil.api.client.registry;

import foundry.veil.api.client.render.rendertype.VeilRenderType;
import foundry.veil.ext.CompositeStateExtension;
import org.jetbrains.annotations.ApiStatus;

import java.util.*;
import java.util.function.Predicate;
import net.minecraft.class_1921;
import net.minecraft.class_4668;

/**
 * <p>Registry for custom render type shards. Registered shards let custom code run during the setup and clear state of
 * a render type, and {@link RenderTypeShardRegistry#addGenericShard(Predicate, class_4668...)} additionally allows
 * arbitrary injection into any composite render type that is created.</p>
 * <p><strong>Shards should be registered during mod construction/init.</strong> The name- and filter-based
 * registration methods also apply newly registered shards to matching render types that were created earlier.</p>
 * <p>All public methods synchronize on this class. Filters passed to
 * {@link #addGenericShard(Predicate, class_4668...)} are evaluated while that lock is held.</p>
 *
 * @author Ocelot
 */
public final class RenderTypeShardRegistry {

    /** Shards registered by render type name, in registration order. */
    private static final Map<String, List<class_4668>> SHARDS = new HashMap<>();
    /**
     * Filter-based registrations, kept in registration order. A hash set of these records would iterate in an
     * identity-hash-dependent (run-to-run varying) order, since the filter lambdas are hashed by identity.
     */
    private static final List<GenericShard> GENERIC_SHARDS = new ArrayList<>();
    /** Every render type passed to {@link #inject}, so later registrations can be applied to it retroactively. */
    private static final Set<class_1921.class_4687> CREATED_RENDER_TYPES = new HashSet<>();

    private RenderTypeShardRegistry() {
    }

    /**
     * Adds shards directly to an existing render type. The shards only affect that specific render type.
     * <p>Unlike the other registration methods, passing no shards is a no-op rather than an error.</p>
     *
     * @param renderType The render type to add the shards to; must be a composite render type
     * @param shards     The shards to add
     * @throws IllegalArgumentException If {@code renderType} is not a composite render type
     */
    public static synchronized void addShard(class_1921 renderType, class_4668... shards) {
        if (shards.length == 0) {
            return;
        }
        if (!(renderType instanceof class_1921.class_4687 compositeRenderType)) {
            throw new IllegalArgumentException("RenderType must be CompositeRenderType");
        }
        applyShards(compositeRenderType, snapshot(shards));
    }

    /**
     * Registers shards for every render type with the given name. They are added to matching render types when those
     * are constructed, and immediately to any matching render type that already exists.
     *
     * @param name   The name of the render type to add the shards to
     * @param shards The shards to add to all matching render types
     * @throws IllegalArgumentException If no shards are provided
     * @throws NullPointerException     If {@code name} is null
     */
    public static synchronized void addShard(String name, class_4668... shards) {
        if (shards.length == 0) {
            throw new IllegalArgumentException("No shards provided");
        }
        // Fail before mutating SHARDS; a null key would otherwise be stored and only blow up later
        Objects.requireNonNull(name, "name");

        List<class_4668> newShards = snapshot(shards);
        SHARDS.computeIfAbsent(name, unused -> new ArrayList<>()).addAll(newShards);

        // Render types created before this registration never saw these shards in inject()
        for (class_1921.class_4687 renderType : CREATED_RENDER_TYPES) {
            if (name.equals(VeilRenderType.getName(renderType))) {
                applyShards(renderType, newShards);
            }
        }
    }

    /**
     * Registers shards for every render type accepted by the filter. They are added to matching render types when
     * those are constructed, and immediately to any matching render type that already exists.
     *
     * @param filter The filter for what render types to add the shards to; evaluated while this registry's lock is held
     * @param shards The shards to add to all matching render types
     * @throws IllegalArgumentException If no shards are provided
     * @throws NullPointerException     If {@code filter} is null
     */
    public static synchronized void addGenericShard(Predicate<class_1921.class_4687> filter, class_4668... shards) {
        if (shards.length == 0) {
            throw new IllegalArgumentException("No shards provided");
        }
        // A null filter would otherwise only fail later, inside inject()
        Objects.requireNonNull(filter, "filter");

        GenericShard generic = new GenericShard(filter, snapshot(shards));
        GENERIC_SHARDS.add(generic);

        // Render types created before this registration never saw these shards in inject()
        for (class_1921.class_4687 renderType : CREATED_RENDER_TYPES) {
            if (filter.test(renderType)) {
                applyShards(renderType, generic.shards());
            }
        }
    }

    // Implementation

    /**
     * Adds every name- and filter-based shard registered so far to a newly created render type, then records the
     * render type so later registrations are applied to it as well. Marked internal; not part of the public API.
     * <p>Synchronized like the registration methods: they iterate {@code CREATED_RENDER_TYPES} while this method adds
     * to it, so an unsynchronized call from another thread could corrupt the set or cause a
     * {@link ConcurrentModificationException}.</p>
     *
     * @param renderType The render type being created
     */
    @ApiStatus.Internal
    public static synchronized void inject(class_1921.class_4687 renderType) {
        List<class_4668> shards = new ArrayList<>(SHARDS.getOrDefault(VeilRenderType.getName(renderType), List.of()));
        for (GenericShard generic : GENERIC_SHARDS) {
            if (generic.filter().test(renderType)) {
                shards.addAll(generic.shards());
            }
        }

        // Registration rejects empty shard arrays, so an empty list means nothing matched this render type
        if (!shards.isEmpty()) {
            applyShards(renderType, shards);
        }
        CREATED_RENDER_TYPES.add(renderType);
    }

    /**
     * Passes shards to the {@link CompositeStateExtension} view of the object returned by the render type's
     * {@code method_35784()}.
     */
    @SuppressWarnings({"UnreachableCode", "DataFlowIssue"})
    private static void applyShards(class_1921.class_4687 renderType, List<class_4668> shards) {
        // The hop through Object allows a cast the compiler cannot prove valid; the extension interface is
        // presumably attached to that class at runtime (e.g. by a mixin) — confirm against the mixin config
        ((CompositeStateExtension) (Object) renderType.method_35784()).veil$addShards(shards);
    }

    /**
     * Copies the caller's varargs array so later mutation of that array cannot change already-registered shards.
     */
    private static List<class_4668> snapshot(class_4668[] shards) {
        return Arrays.asList(shards.clone());
    }

    /**
     * A filter-based shard registration.
     *
     * @param filter Selects the render types the shards are added to
     * @param shards The shards to add, copied at registration time
     */
    private record GenericShard(Predicate<class_1921.class_4687> filter, List<class_4668> shards) {
    }
}
