import { TextureSource } from "../../../types";
import { texture } from "../../../utils/webgl";
import { Blur, BLUR_TYPE } from "../programs/blur/blur";
import { JointBilateral } from "../programs/joint-bilateral/joint-bilateral";
import { WebglPipeline } from "../webgl-pipeline";

/**
 * Bag of numeric threshold settings, grouped as minimum/maximum pairs for the
 * "current" and "previous" values plus a general pair and a scaling factor.
 *
 * NOTE(review): nothing in this file reads this interface, so the exact
 * meaning and valid range of each field cannot be confirmed from here —
 * presumably they tune anti-flickering between consecutive masks; verify
 * against the consumer.
 *
 * @internal
 */
export interface FlickeringOptions {
    // General threshold pair and its factor.
    thresholdMinimum: number;
    thresholdMaximum: number;
    thresholdFactor: number;

    // Threshold pair for the "current" value.
    currentThresholdMinimum: number;
    currentThresholdMaximum: number;

    // Threshold pair for the "previous" value.
    previousThresholdMinimum: number;
    previousThresholdMaximum: number;
}

/**
 * WebGL pipeline that post-processes a segmentation mask in two registered
 * steps: a radius-1 Gaussian blur fed with the mask, then a joint bilateral
 * program fed with both the input frame and the mask.
 *
 * `output` always points at the joint bilateral program's output.
 */
export class ImproveSegmentationMaskPipeline extends WebglPipeline {
    /** Frame supplied by the most recent `setData` call. */
    public inputImage?: ImageBitmap;
    /** Mask supplied by the most recent `setData` call. */
    public inputMask?: ImageBitmap;
    /**
     * Frame that was current before the most recent `setData` call (or that
     * call's own frame when there was none). Not read inside this class.
     */
    public previousInputImage?: ImageBitmap;
    /**
     * Mask that was current before the most recent `setData` call (or that
     * call's own mask when there was none). Not read inside this class.
     */
    public previousInputMask?: ImageBitmap;
    /** Result of the pipeline; mirrors `jointBilateral.output`. */
    public output: TextureSource | WebGLTexture;

    private jointBilateral: JointBilateral;

    /**
     * Builds the blur and joint bilateral programs and registers them as
     * pipeline steps.
     *
     * @param context - WebGL context handed to both programs and used to
     * create the input textures.
     * @param width - Width handed to both programs.
     * @param height - Height handed to both programs.
     */
    constructor(
        private readonly context: WebGLRenderingContext,
        width: number,
        height: number
    ) {
        super();

        const defaultOptions = {
            context,
            height,
            width,
        };

        const blur = new Blur({
            ...defaultOptions,
            radius: 1,
            type: BLUR_TYPE.GAUSSIAN,
        });
        this.jointBilateral = new JointBilateral(defaultOptions);

        // Filter parameters never change between frames, so they are computed
        // once here instead of on every uniform fetch.
        const kSparsityFactor = 0.66; // Higher is more sparse.
        // From https://docs.opencv.org/4.x/d4/d86/group__imgproc__filter.html
        // Filter size: Large filters (d > 5) are very slow, so it is recommended to use d=5 for real-time applications, and perhaps d=9 for offline applications that need heavy noise filtering.
        const sigmaSpace = 3;
        const sparsity = Math.max(1, Math.sqrt(sigmaSpace) * kSparsityFactor);
        const offset = sparsity > 1 ? sparsity * 0.5 : 0;

        this.addStep({
            program: blur,
            getUniforms: () => {
                return {
                    inputs: texture(context, this.inputMask),
                };
            },
        });

        // NOTE(review): this step samples `this.inputMask` directly rather than
        // the blur step's output. Whether WebglPipeline chains step outputs is
        // not visible here — confirm the blurred mask is actually consumed.
        this.addStep({
            program: this.jointBilateral,
            getUniforms: () => {
                // Texel size is derived from the current frame; falls back to
                // a 1x1 frame while no image has been provided yet.
                const texelWidth = 1 / (this.inputImage?.width ?? 1);
                const texelHeight = 1 / (this.inputImage?.height ?? 1);
                return {
                    input_frame: texture(context, this.inputImage),
                    segmentation_mask: texture(context, this.inputMask),
                    texel_size: [texelWidth, texelHeight],
                    step: sparsity,
                    radius: sigmaSpace,
                    offset,
                    sigma_texel: Math.max(texelWidth, texelHeight) * sigmaSpace,
                    sigma_color: sigmaSpace, // Sigma values: For simplicity, you can set the 2 sigma values to be the same - From https://docs.opencv.org/4.x/d4/d86/group__imgproc__filter.html
                };
            },
        });

        this.output = this.jointBilateral.output;
    }

    /**
     * Stores the frame and mask to process next, remembering the ones they
     * replace as the "previous" pair. On the first call (nothing stored yet)
     * the previous pair is set to the new values themselves.
     *
     * @param image - Frame for the next run.
     * @param mask - Segmentation mask for the next run.
     */
    public setData(image?: ImageBitmap, mask?: ImageBitmap): void {
        this.previousInputImage = this.inputImage ?? image;
        this.previousInputMask = this.inputMask ?? mask;
        this.inputImage = image;
        this.inputMask = mask;
    }

    /**
     * Delegates the resize to the base pipeline, then re-reads the joint
     * bilateral output so `output` does not keep a stale reference
     * (presumably the resize replaces the program's output — confirm).
     *
     * @param width - New output width.
     * @param height - New output height.
     */
    public resizeOutput(width: number, height: number): void {
        super.resizeOutput(width, height);
        this.output = this.jointBilateral.output;
    }
}
