import { Sprite, Vector2 } from "@babylonjs/core";
import { SpriteManager } from "@babylonjs/core/Sprites/spriteManager";
import { Vector3 } from "@babylonjs/core/Maths/math.vector";
import { Conn, Layer, Neuron, Similarity } from "./domain";
import { neuronColor } from "../color";

/**
 * Builds a grid of neuron sprites and returns operations to push values
 * through the layer and to lay the neurons out with a force-directed
 * simulation driven by pairwise weight similarity.
 *
 * Neurons are stored row-major: neuron `i * dim.y + j` is created in row `i`
 * (of `dim.x` rows) and column `j` (of `dim.y` columns).
 *
 * @param dim - grid size: `dim.x` rows by `dim.y` columns.
 * @param padding - z offset applied to every neuron (part of `centerCompensate`).
 * @param margin - spacing between neighbouring neurons.
 * @param spriteManager - manager the neuron sprites are created in.
 * @param norm - divisor applied to values before `neuronColor` in `setValues`.
 */
export const createLayer =
    (dim: Vector2, padding: number, margin: number, spriteManager: SpriteManager, norm: number = 1) => {

        // Total number of neurons in the grid.
        const M = dim.x * dim.y;

        const arr = new Array<Neuron>(M);

        const sigmoid = (x: number): number => 1 / (1 + Math.exp(-x));

        // Offset added to every initial position (x/y shift derived from the grid
        // size, z = padding). calculateForce also pulls neurons toward this point.
        const centerCompensate = new Vector3(-margin * (dim.y - 1) / 2, -margin * (dim.x - 1) / 2, padding);

        for (let i = 0; i < dim.x; i++) {
            for (let j = 0; j < dim.y; j++) {
                const index = i * dim.y + j;

                const position = new Vector3(margin * dim.y - j * margin, margin * dim.x - i * margin, 0).add(centerCompensate);
                const sprite = new Sprite("neuron", spriteManager);
                // NOTE(review): the neuron's `position` and `sprite.position` receive the same
                // Vector3 here; assuming Sprite stores it by reference, moving the sprite moves the neuron.
                sprite.position = position;
                sprite.size = 0.07;

                arr[index] = {
                    index,
                    position,
                    sprite,
                    weights: null,
                    currentValue: 0,
                    force: new Vector3(0, 0, 0),
                    // Random initial velocity so the layout simulation does not start at rest.
                    velocity: new Vector3(Math.random() - 0.5, Math.random() - 0.5, 0).scaleInPlace(4.4),
                };
            }
        }

        // similarityMatrix[i][j] scales the attraction between neurons i and j in calculateForce.
        const similarityMatrix = new Array<Similarity[]>(M)
            .fill(null)
            .map(() => new Array<Similarity>(M)
                .fill(null)
                .map(() => 0)
            );

        /**
         * Returns neuron `i`'s weights.
         * @throws Error if setWeights has not provided weights for that neuron.
         */
        const requireWeights = (i: number): ArrayLike<number> => {
            const weights = arr[i].weights;
            if (!weights) {
                throw new Error(`Neuron ${i} has no weights; call setWeights() first`);
            }
            return weights;
        };

        /**
         * Fills the similarity matrix with inverse squared Euclidean distances
         * between neuron weight vectors: `1 / (sum_w (a[w] - b[w])^2 + 0.001)`.
         * The 0.001 keeps self-similarity (distance 0) finite.
         *
         * @param wLength - number of leading weight entries to compare.
         */
        const calculateSimilarityEukl = (wLength: number) => {
            const weights = arr.map((_, i) => requireWeights(i));

            for (let i = 0; i < M; i++) {
                const a = weights[i];
                // Squared distance is symmetric: compute each pair once and mirror it.
                for (let j = i; j < M; j++) {
                    const b = weights[j];
                    let d = 0;
                    for (let w = 0; w < wLength; w++) {
                        d += (a[w] - b[w]) ** 2;
                    }
                    const s = 1 / (d + 0.001);
                    similarityMatrix[i][j] = s;
                    similarityMatrix[j][i] = s;
                }
            }
        };

        /**
         * Fills the similarity matrix by "dominant output". Two neurons get
         * similarity 0.6 when their largest weight (among the first `wLength`)
         * is at the same index; otherwise they get 0.
         *
         * @param wLength - number of leading weight entries to search.
         */
        const calculateSimilarityDom = (wLength: number) => {
            // dominant[i] = index of neuron i's largest weight (lowest index wins ties).
            const dominant = new Array<number>(M);
            for (let i = 0; i < M; i++) {
                const weights = requireWeights(i);
                let best = 0;
                for (let w = 1; w < wLength; w++) {
                    if (weights[w] > weights[best]) {
                        best = w;
                    }
                }
                dominant[i] = best;
            }

            for (let i = 0; i < M; i++) {
                for (let j = 0; j < M; j++) {
                    similarityMatrix[i][j] = dominant[i] === dominant[j] ? 0.6 : 0;
                }
            }
        };

        /** Assigns `w[i]` as the outgoing weights of neuron `i`. */
        const setWeights = (w: Float32Array[]) => {
            for (let i = 0; i < M; i++) {
                arr[i].weights = w[i];
            }
        };

        /** Stores each neuron's activation and recolours its sprite from `v[i] / norm`. */
        const setValues = (v: ArrayLike<number>) => {
            for (let i = 0; i < M; i++) {
                arr[i].currentValue = v[i];
                arr[i].sprite.color = neuronColor(v[i] / norm);
            }
        };

        /**
         * Advances the layout simulation by one step of size `dt`.
         *
         * For each neuron i, the force sums over every other neuron j:
         *  - an attraction along (p_j - p_i) scaled by 1.71 * similarityMatrix[i][j];
         *  - a term along the same direction with magnitude 0.1 * d / (r^2 + 0.1),
         *    which repels when the pair is closer than 7.5 and attracts otherwise.
         * It also includes a pull toward `centerCompensate` with magnitude
         * 2 * c * r^0.76. Velocity is integrated and then damped by 0.94 each step.
         *
         * @param dt - time step.
         * @param k - currently unused.
         * @param d - strength of the distance-dependent pairwise term.
         * @param c - strength of the centring pull.
         */
        const calculateForce = (dt: number, k: number, d: number, c: number) => {
            for (let i = 0; i < M; i++) {
                const neuron = arr[i];

                neuron.force.x = 0;
                neuron.force.y = 0;
                neuron.force.z = 0;

                for (let j = 0; j < M; j++) {
                    if (i === j) {
                        continue;
                    }

                    const diff = arr[j].position.subtract(neuron.position);
                    neuron.force.addInPlace(diff.scale(1.71 * similarityMatrix[i][j]));

                    const q = diff.length() < 7.5 ? -1 : 1;
                    // clone() keeps `diff` intact for the lengthSquared() in the scale factor.
                    neuron.force.addInPlace(diff.clone().normalize().scaleInPlace(0.1 * d * q / (diff.lengthSquared() + 0.1)));
                }

                const cF = neuron.position.subtract(centerCompensate);
                neuron.force.addInPlace(cF.clone().normalize().scale(-2 * c * cF.length() ** 0.76));

                neuron.velocity.addInPlace(neuron.force.scale(dt));
                neuron.velocity.scaleInPlace(0.94);
                // NOTE(review): if neuron.position aliases sprite.position (as set up at
                // construction), this move is already visible to neurons processed later
                // in this loop. Confirm whether a snapshot update was intended.
                neuron.sprite.position.addInPlace(neuron.velocity.scale(dt));
            }

            for (let i = 0; i < M; i++) {
                arr[i].position = arr[i].sprite.position;
            }
        };

        const getNeuronPositionByIndex: (i: number) => Vector3 = (i: number) => {
            return arr[i].sprite.position;
        };

        /**
         * Propagates current values to `outputLayer`. Each output receives
         * sigmoid(sum_i currentValue_i * weights_i[w]).
         *
         * @returns `connections`: one Conn per (neuron, output) pair at index
         *   `i * outputCount + w`. `values`: the sigmoid-activated output sums.
         * @throws Error if any neuron has no weights.
         */
        const forwardValues = (outputLayer: Layer) => {
            const outputCount = outputLayer.v.length;
            const sums = new Array<number>(outputCount).fill(0);
            const connections = new Array<Conn>(M * outputCount);

            for (let i = 0; i < M; i++) {
                const neuron = arr[i];
                const weights = requireWeights(i);

                for (let w = 0; w < outputCount; w++) {
                    const conn: Conn = {
                        value: neuron.currentValue * weights[w],
                        normalizedValue: null,
                        positions: [neuron.sprite.position, outputLayer.getNeuronPositionByIndex(w)]
                    };

                    connections[i * outputCount + w] = conn;
                    sums[w] += conn.value;
                }
            }

            return {
                connections,
                values: sums.map(sigmoid)
            };
        };

        return {
            v: arr,
            setValues,
            forwardValues,
            setWeights,
            getNeuronPositionByIndex,
            calculateSimilarityDom,
            calculateSimilarityEukl,
            calculateForce
        };
    }


