import db from "../../Firestore"
import * as tf from "@tensorflow/tfjs"
import * as palette from "../../components/symbols/palette"
import _ from "lodash"

// Given a Node, what types of Edges go out of the Node?
// const ML_MODEL_NODE_TO_EDGE_OUT = "NodeToEdgeOut"
// const ML_MODEL_NODE_TO_EDGE_OUT_ENCODER = "NodeToEdgeOutEncoder"

// Given a Node, what types of Edges go into the Node?
//const ML_MODEL_NODE_TO_EDGE_IN = "NodeToEdgeIn"

// Given a Node and Edge, what Node is at the end of the Edge?
//const ML_MODEL_NODE_AND_EDGE_TO_NODE_OUT = "NodeAndEdgeToNodeOut"

// Given 2 Nodes, what edge(s) should join them?
const ML_MODEL_EDGE_PREDICTION = "EdgePrediction"
// Encoder model paired with ML_MODEL_EDGE_PREDICTION.
// NOTE(review): pairing inferred from the names — confirm against the training code.
const ML_MODEL_EDGE_PREDICTION_ENCODER = "EdgePredictionEncoder"

// These strings are presumably passed as `type` to loadTfLayersModel/loadMLModel, which match it
// against the `type` field of documents in the `ml_models` collection — confirm against callers.
// NOTE(review): NodePrediction is presumably the model used by recommendTargetNode (given a Node type
// and an Edge type, predict the Node type at the end of the Edge) — confirm.
const ML_MODEL_NODE_PREDICTION = "NodePrediction"
const ML_MODEL_NODE_PREDICTION_ENCODER = "NodePredictionEncoder"

/**
 * Load a stored ML model for an account and rebuild it as a TensorFlow.js LayersModel.
 *
 * @param {string} type - model type, matched against the `type` field of `ml_models` documents
 * @param {string} accountId - account whose model should be loaded
 * @returns {Promise<{ modelInfo: object, tfModel: object, meta: * } | undefined>}
 *   undefined when no stored model exists for the given type/account
 */
const loadTfLayersModel = async (type, accountId) => {
    console.log("%cloadTfLayersModel", "color:lightgreen", { type, accountId })

    // loadMLModel resolves to undefined (not an object) when nothing is stored,
    // so the result must be checked before it is destructured.
    const loaded = await loadMLModel(type, accountId)
    if (!loaded || !loaded.modelInfo) {
        return undefined
    }

    const { modelInfo, meta } = loaded

    // Custom IOHandler: load() simply hands back the in-memory model artifacts
    const tfModel = await tf.loadLayersModel({
        load: () => modelInfo,
    })

    return { modelInfo, tfModel, meta }
}

/**
 * Fetch the stored data needed to re-instantiate a TensorFlow tf.Model.
 *
 * Only the first matching record is used. Its base64-encoded
 * `tf_model_info.weightData` is decoded back into an ArrayBuffer.
 *
 * @param {*} type - value matched against the record's `type` field
 * @param {*} accountId - value matched against the record's `account_id` field
 * @returns {Promise<{ modelInfo: object, meta: * } | undefined>} undefined when no record matches
 */
const loadMLModel = async (type, accountId) => {
    const records = await loadMLModelData(type, accountId)
    if (records.length === 0) {
        return undefined
    }

    const { tf_model_info: storedInfo, train_meta: meta } = records[0]
    const modelInfo = {
        ...storedInfo,
        weightData: base64ToArrayBuffer(storedInfo.weightData),
    }

    return { modelInfo, meta }
}

/**
 * Query the `ml_models` collection for records matching an account and a model type.
 *
 * @param {*} type - value matched against each document's `type` field
 * @param {*} accountId - value matched against each document's `account_id` field
 * @returns {Promise<object[]>} each document's data merged with its document `id`
 */
const loadMLModelData = async (type, accountId) => {
    const querySnapshot = await db
        .collection("ml_models")
        .where("account_id", "==", accountId)
        .where("type", "==", type)
        .get()

    const toRecord = (doc) => ({ id: doc.id, ...doc.data() })
    return querySnapshot.docs.map(toRecord)
}

/**
 * Encode binary data as a base64 string.
 *
 * @param {ArrayBuffer|Uint8Array|number[]} buffer - bytes to encode (anything the
 *   Uint8Array constructor accepts)
 * @returns {string} base64 representation of the bytes
 */
const arrayBufferToBase64 = (buffer) => {
    const bytes = new Uint8Array(buffer)

    // Build the binary string in chunks. One String.fromCharCode call per byte is slow
    // for large weight buffers, while spreading the whole array at once can exceed the
    // engine's maximum argument count.
    const CHUNK_SIZE = 0x8000
    const parts = []
    for (let offset = 0; offset < bytes.length; offset += CHUNK_SIZE) {
        parts.push(String.fromCharCode(...bytes.subarray(offset, offset + CHUNK_SIZE)))
    }

    // globalThis.btoa is available in browser windows, Web Workers and Node >= 16,
    // whereas window.btoa only exists on the browser main thread.
    return globalThis.btoa(parts.join(""))
}

/**
 * Decode a base64 string into an ArrayBuffer.
 *
 * @param {string} base64 - base64-encoded data
 * @returns {ArrayBuffer} the decoded bytes
 */
const base64ToArrayBuffer = (base64) => {
    // globalThis.atob is available in browser windows, Web Workers and Node >= 16,
    // whereas window.atob only exists on the browser main thread.
    const binary = globalThis.atob(base64)
    const bytes = new Uint8Array(binary.length)
    for (let i = 0; i < binary.length; i++) {
        bytes[i] = binary.charCodeAt(i)
    }
    return bytes.buffer
}

/**
 * Build a one-hot TensorBuffer of shape [1, elementTypes.length] for an element type.
 *
 * The name is resolved through palette.getElementType(), and the resolved name's
 * position in `elementTypes` is set to 1. The buffer stays all zeros in two cases:
 * the type cannot be resolved, or the resolved name is not in `elementTypes`.
 *
 * @param {string} elementTypeName - element type to encode
 * @param {string[]} elementTypes - ordered vocabulary of element type names
 * @returns {tf.TensorBuffer} one-hot (or all-zero) buffer
 */
const getElementVecTfBuf = (elementTypeName, elementTypes) => {
    const elementBuf = tf.buffer([1, elementTypes.length])

    const elementType = palette.getElementType(elementTypeName)
    // Index with the same resolved name used for the membership test. The raw
    // elementTypeName may differ from it if the palette normalises names, which
    // would yield index -1.
    const index = elementType ? elementTypes.indexOf(elementType.name) : -1
    if (index !== -1) {
        elementBuf.set(1, 0, index)
    }

    return elementBuf
}

/**
 * Compute the embedding for a single graph node.
 *
 * The embedding is the one-hot encoding of the node's own `type` (via
 * getElementVecTfBuf). Neighbouring nodes are not taken into account.
 *
 * @param {{ type: string }} element - node whose `type` is encoded
 * @param {string[]} types - ordered vocabulary of element type names
 * @returns the buffer's flat `values` array (length types.length)
 */
const getNodeEmbedding = (element, types) =>
    tf.tidy(() => getElementVecTfBuf(element.type, types).values)

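// We don't want to train on the 'Association' connector: it is weak and not really something
// we can make recommendations on.
// NOTE(review): getConnectorTypes() below does not filter it out itself — confirm that
// palette.getConnectorTypes() already excludes 'Association'.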

/**
 * Ordered list of connector (relationship) types that forms the edge vocabulary of
 * the ML models. It is taken directly from the palette.
 *
 * @returns {string[]} connector type names
 */
const getConnectorTypes = () => palette.getConnectorTypes()

// https://towardsdatascience.com/using-null-samples-to-shape-decision-spaces-and-defend-against-adversarial-attacks-3ecd16b6596c
/**
 * Connector types plus a trailing "None" class. Prediction indices are mapped back
 * through this list, which lets a model predict "no connection" between two nodes.
 *
 * @returns {string[]} getConnectorTypes() with "None" appended as the last entry
 */
const getConnectorTypesWithNone = () => [...getConnectorTypes(), "None"]

/**
 * Recommend connector types for edges leaving a node of the given type.
 *
 * @param {*} tfModel - classifier whose output indices map into getConnectorTypesWithNone()
 * @param {*} [tfEncoderModel] - optional encoder applied to the input before the classifier
 * @param {string} sourceType - element type of the source node
 * @param {string[]} modelElementTypes - element type vocabulary the model was trained on
 * @returns {{ type: string, match: number }[]} up to 5 candidates scoring above 10%, best first
 */
const recommendEdge = (tfModel, tfEncoderModel, sourceType, modelElementTypes) => {
    const MAX_CANDIDATES = 5
    const MIN_MATCH = 0.1

    return tf.tidy(() => {
        const sourceBuf = getElementVecTfBuf(sourceType, modelElementTypes)
        const xs = tf.tensor(sourceBuf.values, [1, modelElementTypes.length])

        // Encoder is optional, consistent with recommendEdgeBetweenNodes
        const inputXs = tfEncoderModel ? tfEncoderModel.predict(xs) : xs
        const preds = tfModel.predict(inputXs)

        // tf.topk throws if k exceeds the size of the last dimension
        const numClasses = preds.shape[preds.shape.length - 1]
        const { values, indices } = tf.topk(preds, Math.min(MAX_CANDIDATES, numClasses))

        const topVals = values.dataSync()
        const topIndices = indices.dataSync()
        const connectorTypesWithNone = getConnectorTypesWithNone()

        return Array.from(topVals, (match, i) => ({
            type: connectorTypesWithNone[topIndices[i]],
            match,
        })).filter((item) => item.match > MIN_MATCH)
    })
}

/**
 * Recommend connector types for an edge between a source node and a target node.
 *
 * @param {*} tfModel - classifier whose output indices map into getConnectorTypesWithNone()
 * @param {string} sourceType - element type of the source node
 * @param {string} targetType - element type of the target node
 * @param {string[]} modelElementTypes - element type vocabulary the model was trained on
 * @param {*} [tfEncoderModel] - optional encoder to use on the embedding before making the prediction
 * @returns {{ match: number, index: number, type: string }[]} up to 3 candidates scoring
 *   strictly above 40%, best first
 */
const recommendEdgeBetweenNodes = (
    tfModel,
    sourceType,
    targetType,
    modelElementTypes,
    tfEncoderModel
) => {
    const MAX_CANDIDATES = 3
    const MIN_RECOMMENDATION_RATE = 0.4 // a recommendation must score strictly above 40%

    return tf.tidy(() => {
        const width = modelElementTypes.length
        const sourceBuf = getElementVecTfBuf(sourceType, modelElementTypes)
        const targetBuf = getElementVecTfBuf(targetType, modelElementTypes)

        // Model input is the source one-hot followed by the target one-hot
        const xs = tf.concat(
            [tf.tensor(sourceBuf.values, [1, width]), tf.tensor(targetBuf.values, [1, width])],
            1
        )

        const inputXs = tfEncoderModel ? tfEncoderModel.predict(xs) : xs
        const preds = tfModel.predict(inputXs)

        // tf.topk throws if k exceeds the size of the last dimension
        const numClasses = preds.shape[preds.shape.length - 1]
        const { values, indices } = tf.topk(preds, Math.min(MAX_CANDIDATES, numClasses))

        const topVals = values.dataSync()
        const topIndices = indices.dataSync()
        const connectorTypesWithNone = getConnectorTypesWithNone()

        return Array.from(topVals, (match, i) => ({
            match,
            index: topIndices[i],
            type: connectorTypesWithNone[topIndices[i]],
        })).filter((item) => item.match > MIN_RECOMMENDATION_RATE)
    })
}

// Placeholder: not implemented yet. The returned promise resolves to undefined.
// It rejects when called without an argument object, because the parameter is destructured.
// `modelCache` is currently unused.
const createViewEmbeddings = async ({ modelCache }) => {}

/**
 * Find the items whose embeddings are most similar (by cosine similarity) to a search item.
 *
 * @param {object} args
 * @param {{ embedding: number[] }} args.searchItem - item to match against
 * @param {{ embedding: number[] }[]} args.items - candidates. Every embedding must have
 *   the same length as searchItem.embedding.
 * @param {number} [args.threshold=0.7] - minimum cosine similarity for an item to be returned
 * @param {number} [args.maxResults=10] - maximum number of top candidates considered
 * @returns {object[]} shallow copies of matching items with an added `score`, most similar first
 */
const findClosestMatch = ({ searchItem, items, threshold = 0.7, maxResults = 10 }) => {
    // normalize: true turns the dot product into cosine similarity
    const cosineSimilarityCalc = tf.layers.dot({ axes: -1, normalize: true })

    const results = []

    console.log("%cfinding closest matches", "color:yellow", { searchItem, items })

    // tf.tensor2d cannot build a 2D tensor from an empty list
    if (!items || items.length === 0) {
        return results
    }

    tf.tidy(() => {
        const candidates = tf.tensor2d(items.map((item) => item.embedding))
        const numCandidates = candidates.shape[0]
        const searchTiled = tf
            .tensor(searchItem.embedding)
            .reshape([1, -1])
            .tile([numCandidates, 1])

        const similarities = cosineSimilarityCalc
            .apply([candidates, searchTiled])
            .reshape([numCandidates])

        const k = _.clamp(numCandidates, 1, maxResults)
        const { values, indices } = tf.topk(similarities, k)

        const valuesArray = values.arraySync()
        const indicesArray = indices.arraySync()

        console.log("findClosestMatch", { valuesArray, indicesArray })

        for (let i = 0; i < valuesArray.length; i++) {
            // Skip items whose cosine similarity is below the threshold
            if (valuesArray[i] < threshold) continue

            results.push({
                ...items[indicesArray[i]],
                score: valuesArray[i],
            })
        }
    })

    return results
}


// Euclidean distance version
// const findClosestMatch = ({ searchItem, items, threshold = 0.7 }) => {
//     let results = [];

//     tf.tidy(() => {
//         const searchTensors = tf.tensor2d(items.map(item => item.embedding));
//         const tSearchItemTiled = tf.tensor(searchItem.embedding)
//                                  .reshape([1, -1])
//                                  .tile([searchTensors.shape[0], 1]);

//         // Calculate Euclidean distances
//         const distances = tf.sqrt(tf.sum(tf.squaredDifference(searchTensors, tSearchItemTiled), 1));

//         // The smaller the Euclidean distance, the more similar the items are. 
//         // We use topk to get the smallest distances and their indices.
//         const { values, indices } = tf.topk(distances.neg(), 10); // negating the distances to get smallest values

//         const valuesArray = values.arraySync();
//         const indicesArray = indices.arraySync();

//         for (let i = 0; i < valuesArray.length; i++) {
//             // If the Euclidean distance is above the threshold, skip this item
//             if (-valuesArray[i] > threshold) continue; // negate the values back to get the original distances

//             results.push({
//                 ...items[indicesArray[i]],
//                 score: -valuesArray[i] // score is now a distance, not similarity
//             });
//         }
//     });

//     return results;
// };


// const findClosestMatch = ({ searchItem, items }) => {
//     const cosineDistanceCalc = tf.layers.dot({ axes: -1, normalize: true })

//     // Items to search
//     const results = []

//     //console.log("%cfinding closest matches", "color:yellow", { searchItem, items })

//     tf.tidy(() => {
//         const searchTensors = items.map((item) => item.embedding)

//         const tSearchData = tf.tensor2d(searchTensors)

//         const tSearchItem = tf.tensor(searchItem.embedding)
//         const tSearchItem_reshaped = tSearchItem.reshape([1, -1])
//         const tSearchItemTiled = tSearchItem_reshaped.tile([searchTensors.length, 1])

//         const cosineDistance = cosineDistanceCalc.apply([tSearchData, tSearchItemTiled])

//         const reshaped = cosineDistance.reshape([cosineDistance.shape[1], cosineDistance.shape[0]])

//         const maxTopK = _.clamp(reshaped.shape[1], 1, 10)

//         const { values, indices } = tf.topk(reshaped, maxTopK)

//         const indiceVals = indices.dataSync()

//         //console.log("indiceVals", { values, indiceVals })

//         indiceVals.forEach((indice, index) => {
//             const item = items[indice]
//             const result = {
//                 ...item,
//                 score: values.dataSync()[index],
//             }
//             results.push(result)
//         })
//     })

//     return results
// }

/**
 * Rank the views of cached models by similarity to a source view.
 *
 * Each view is embedded with getViewEmbedding, and views are compared by cosine similarity.
 *
 * @param {*} sourceView - view to find similar views for (excluded from the results by `id`)
 * @param {*} model - model containing sourceView, used to resolve its elements
 * @param {object} cachedModels - object whose values are cached models exposing
 *   `model.views`, `model.elements`, `model.file`, `parent_id`, `name` and `type`
 * @returns {object[]} up to 10 candidates, most similar first, each with
 *   `parent_id`, `name`, `type`, `view`, `file` and `score`
 */
const findSimilarViews = (sourceView, model, cachedModels) => {
    const MAX_RESULTS = 10
    const models = Object.values(cachedModels)
    const elementTypes = palette.ELEMENT_INDEX.map((item) => item.name)

    // Must enumerate views in the same order, and with the same empty-view filter, as
    // getViewEmbeddings, so that embedding row i corresponds to viewsAndModel[i].
    const viewsAndModel = _.flatten(
        models.map((cachedModel) =>
            cachedModel.model.views.map((view) => ({ model: cachedModel, view }))
        )
    ).filter((item) => item.view.elements.length > 0)

    // Nothing to compare against (tf.tensor2d cannot build a tensor from an empty list)
    if (viewsAndModel.length === 0) {
        return []
    }

    const sourceEmbedding = getViewEmbedding(sourceView, model, elementTypes)
    const viewEmbeddings = getViewEmbeddings(models, elementTypes)

    // All tensor work runs inside tidy() so intermediate tensors are released, not leaked
    const [scores, order] = tf.tidy(() => {
        const tViews = tf.tensor2d(viewEmbeddings)
        const tSourceTiled = tf.tensor([sourceEmbedding]).tile([viewsAndModel.length, 1])

        // normalize: true => cosine similarity
        const cosineCalc = tf.layers.dot({ axes: -1, normalize: true })
        const similarities = cosineCalc
            .apply([tViews, tSourceTiled])
            .reshape([viewsAndModel.length])

        const k = _.clamp(viewsAndModel.length, 1, MAX_RESULTS)
        const { values, indices } = tf.topk(similarities, k)
        return [values.dataSync(), indices.dataSync()]
    })

    const results = []
    order.forEach((viewIndex, rank) => {
        const { model: cachedModel, view } = viewsAndModel[viewIndex]
        if (view.id === sourceView.id) {
            return
        }
        results.push({
            parent_id: cachedModel.parent_id,
            name: cachedModel.name,
            type: cachedModel.type,
            view,
            file: cachedModel.model.file,
            score: scores[rank],
        })
    })

    return results
}

/**
 * Build one-hot element-type buffers for the model elements referenced by a view.
 *
 * Each view element with a truthy `diagramObject.archimateElement` id is resolved
 * against `model.model.elements`. Dangling ids (no matching element) are skipped
 * instead of crashing.
 *
 * @param {*} view - view whose `elements[].diagramObject.archimateElement` ids are resolved
 * @param {*} model - object whose `model.elements` hold elements with `id` and `type`
 * @param {string[]} elementTypes - ordered vocabulary of element type names
 * @returns {tf.TensorBuffer[]} one buffer per resolved element (from getElementVecTfBuf)
 */
const getViewElementEmbeddings = (view, model, elementTypes) => {
    // Index elements by id once instead of a linear find() per view element.
    // The first element with a given id wins, matching Array.prototype.find.
    const elementsById = new Map()
    for (const element of model.model.elements) {
        if (!elementsById.has(element.id)) {
            elementsById.set(element.id, element)
        }
    }

    return (
        view.elements
            .map((viewElement) => viewElement.diagramObject.archimateElement)
            .filter(Boolean)
            .map((elementId) => elementsById.get(elementId))
            // drop references to elements that are not present in the model
            .filter((element) => element !== undefined)
            .map((element) =>
                getElementVecTfBuf(element.type.replace("Relationship", ""), elementTypes)
            )
    )
}

/**
 * Embed a whole view as the element-wise average of its elements' one-hot type vectors.
 *
 * @param {*} view - view to embed
 * @param {*} model - model used to resolve the view's elements
 * @param {string[]} elementTypes - ordered vocabulary of element type names
 * @returns flat values of the [1, elementTypes.length] average. When the view has
 *   no resolvable elements, the result is all zeros.
 */
const getViewEmbedding = (view, model, elementTypes) => {
    const elementBufs = getViewElementEmbeddings(view, model, elementTypes)

    return tf.tidy(() => {
        if (elementBufs.length === 0) {
            return tf.zeros([1, elementTypes.length]).dataSync()
        }

        const tensors = elementBufs.map((buf) => buf.toTensor())
        if (tensors.length === 1) {
            return tensors[0].dataSync()
        }
        return tf.layers.average().apply(tensors).dataSync()
    })
}
/**
 * Embed every non-empty view of every cached model.
 *
 * Models are visited in array order, and views in each model's `model.views` order.
 * Views with no elements are skipped.
 *
 * @param {object[]} cachedModels - models exposing `model.views`
 * @param {string[]} elementTypes - ordered vocabulary of element type names
 * @returns {Array} one embedding (from getViewEmbedding) per non-empty view
 */
const getViewEmbeddings = (cachedModels, elementTypes) => {
    const embeddings = []
    for (const cachedModel of cachedModels) {
        for (const view of cachedModel.model.views) {
            if (view.elements.length > 0) {
                embeddings.push(getViewEmbedding(view, cachedModel, elementTypes))
            }
        }
    }
    return embeddings
}

/**
 * Find the position of a connector type within a connector type list.
 *
 * A missing type is reported with console.error; the function does not throw.
 *
 * @param {string[]} cnxTypes - ordered connector type names
 * @param {string} relType - connector type to look up
 * @returns {number} index of relType in cnxTypes, or -1 if absent
 */
const getEdgeIndexFromType = (cnxTypes, relType) => {
    const index = cnxTypes.indexOf(relType)
    if (index < 0) {
        console.error(`Cannot find ${relType} in cnxTypes`, cnxTypes)
    }
    return index
}

/**
 * Build a one-hot TensorBuffer of shape [1, cnxTypes.length] for a relationship type.
 *
 * The first "Relationship" substring is removed from `relType` before it is looked up
 * in `cnxTypes`. An unknown type yields an all-zero buffer, and getEdgeIndexFromType
 * logs an error for it.
 *
 * @param {string} relType - relationship type name
 * @param {string[]} cnxTypes - ordered connector type vocabulary. No extra "None" column
 *   is added here; callers that need one must include it in this list.
 * @returns {tf.TensorBuffer} one-hot (or all-zero) buffer
 */
const getEdgeVecTfBufFromType = (relType, cnxTypes) => {
    const baseType = relType.replace("Relationship", "")
    const edgeIndex = getEdgeIndexFromType(cnxTypes, baseType)

    const buf = tf.buffer([1, cnxTypes.length])
    // Don't rely on TensorBuffer.set's handling of an out-of-range index (-1).
    // An unknown type simply leaves the buffer all zeros.
    if (edgeIndex !== -1) {
        buf.set(1, 0, edgeIndex)
    }
    return buf
}

/**
 * Recommend the element type at the far end of an edge leaving a node.
 *
 * @param {*} tfModel - classifier whose output index 0 means "None" and whose index i + 1
 *   corresponds to modelElementTypes[i]
 * @param {*} [tfEncoderModel] - optional encoder applied to the input before the classifier
 * @param {string} sourceType - element type of the source node
 * @param {string} edgeType - relationship type of the outgoing edge
 * @param {string[]} modelElementTypes - element type vocabulary the model was trained on
 * @returns {{ type: string, match: number }[]} up to 3 candidates scoring above 10%, best first
 */
const recommendTargetNode = (tfModel, tfEncoderModel, sourceType, edgeType, modelElementTypes) => {
    const MAX_CANDIDATES = 3
    const MIN_MATCH = 0.1
    // Unlike getConnectorTypesWithNone, the "None" class is the FIRST output here
    const modelElementTypesWithNone = ["None", ...modelElementTypes]
    const connectorTypes = getConnectorTypes()

    return tf.tidy(() => {
        const nodeBuf = getElementVecTfBuf(sourceType, modelElementTypes)
        const edgeBuf = getEdgeVecTfBufFromType(edgeType, connectorTypes)

        // Model input is the source-node one-hot followed by the edge one-hot
        const xs = tf.concat(
            [
                tf.tensor(nodeBuf.values, [1, modelElementTypes.length]),
                tf.tensor(edgeBuf.values, [1, connectorTypes.length]),
            ],
            1
        )

        // Encoder is optional, consistent with recommendEdgeBetweenNodes
        const inputXs = tfEncoderModel ? tfEncoderModel.predict(xs) : xs
        const preds = tfModel.predict(inputXs)

        // tf.topk throws if k exceeds the size of the last dimension
        const numClasses = preds.shape[preds.shape.length - 1]
        const { values, indices } = tf.topk(preds, Math.min(MAX_CANDIDATES, numClasses))

        const topVals = values.dataSync()
        const topIndices = indices.dataSync()

        return Array.from(topVals, (match, i) => ({
            type: modelElementTypesWithNone[topIndices[i]],
            match,
        })).filter((item) => item.match > MIN_MATCH)
    })
}

export {
    arrayBufferToBase64,
    loadMLModel,
    loadMLModelData,
    loadTfLayersModel,
    getNodeEmbedding,
    getEdgeVecTfBufFromType,
    getEdgeIndexFromType,
    getElementVecTfBuf,
    recommendEdge,
    recommendTargetNode,
    recommendEdgeBetweenNodes,
    getConnectorTypes,
    getConnectorTypesWithNone,
    // NOTE(review): the exported name `getViewEmbedding` is bound to getViewElementEmbeddings
    // (an array of per-element buffers), NOT to the module-local getViewEmbedding (the averaged
    // view vector). Presumably existing importers rely on this name — confirm before renaming.
    getViewElementEmbeddings as getViewEmbedding,
    createViewEmbeddings,
    getViewEmbeddings,
    findClosestMatch,
    findSimilarViews,
    //ML_MODEL_NODE_TO_EDGE_OUT,
    //ML_MODEL_NODE_TO_EDGE_OUT_ENCODER,
    //ML_MODEL_NODE_TO_EDGE_IN,
    //ML_MODEL_NODE_AND_EDGE_TO_NODE_OUT,
    ML_MODEL_EDGE_PREDICTION,
    ML_MODEL_EDGE_PREDICTION_ENCODER,
    ML_MODEL_NODE_PREDICTION,
    ML_MODEL_NODE_PREDICTION_ENCODER,
}
