import _ from "lodash"
import * as rules from "./rules"
import * as mlServices from "./mlServices"
import * as modelServices from "./modelServices"
import { setMlModels } from "../..//redux/actions"
import * as palette from "../../components/symbols/palette"

// Load and run rules
/**
 * Run the selected (hand-written) rules and, when `mlRules` is truthy, the ML edge- and
 * node-prediction checks, all concurrently.
 *
 * @param {object} model - project model passed to every rule runner
 * @param {object[]} selectedRules - rule definitions for runAllRules
 * @param {object} mlModels - redux slice whose `value` caches loaded TF models
 * @param {boolean} mlRules - whether to include the ML checks
 * @param accountId - passed through to the ML model loader
 * @param {Function} dispatch - redux dispatch, used when an ML model has to be loaded
 * @returns {Promise<object[]>} one flat list of messages from every runner
 */
const runRules = async (model, selectedRules, mlModels, mlRules, accountId, dispatch) => {
    // The rule-based check always runs; the ML checks are appended (edge first, then node).
    const pending = [runAllRules(model, selectedRules)]

    if (mlRules) {
        pending.push(
            runEdgePredictionRules(model, mlModels, accountId, dispatch),
            runNodePredictionRules(model, mlModels, accountId, dispatch)
        )
    }

    const results = await Promise.all(pending)
    console.log("%callMsgs", "color:lightgreen", { results, model })

    return _.flatten(results)
}

/**
 * Execute the user-selected (non-ML) rules against `model`.
 *
 * A rule may check in either or both directions. `source_to_target` builds a SourceTargetRule2
 * over (sources, targets); `target_to_source` builds a TargetSourceRule2 over (targets, sources).
 * Connector names get a "Relationship" suffix, and `required` defaults to "Y" when falsy.
 *
 * @param {object} model - passed to each rule's `execute`
 * @param {object[]} selectedRules - rules with `id`, `name` and `rule_data`
 * @returns {Promise<object[]>} flat list of every rule's messages, in rule order
 */
const runAllRules = async (model, selectedRules) => {
    // One entry per direction: the enabling flag, the rule class, which rule_data lists act as
    // the "from"/"to" sides, and which message field applies.
    const directions = [
        {
            flag: "source_to_target",
            RuleClass: rules.SourceTargetRule2,
            fromKey: "sources",
            toKey: "targets",
            msgKey: "source_to_target_msg",
        },
        {
            flag: "target_to_source",
            RuleClass: rules.TargetSourceRule2,
            fromKey: "targets",
            toKey: "sources",
            msgKey: "target_to_source_msg",
        },
    ]

    const messagesForRule = (rule) => {
        const data = rule.rule_data
        const perDirection = []

        for (const { flag, RuleClass, fromKey, toKey, msgKey } of directions) {
            if (!data[flag]) {
                continue
            }

            const check = new RuleClass(
                rule.id,
                rule.name,
                "No category",
                data[fromKey],
                data[toKey],
                data.connectors.map((connectorType) => `${connectorType}Relationship`),
                data[msgKey],
                data.required || "Y"
            )
            perDirection.push(check.execute(model))
        }

        return _.flatten(perDirection)
    }

    return _.flatten(selectedRules.map(messagesForRule))
}

/**
 * Return the TF model of type `modelType` and the element types from its metadata.
 *
 * The model is taken from `mlModels.value[modelType]` when present. Otherwise it is loaded with
 * getTfModel, which also dispatches it into redux.
 *
 * @param {string} modelType - key of the model in `mlModels.value`
 * @param {object} mlModels - redux slice whose `value` caches loaded models as { tfModel, meta }
 * @param accountId - passed through to getTfModel
 * @param {Function} dispatch - passed through to getTfModel
 * @returns {Promise<{tfModel: object, elementTypes: string[]}>}
 * @throws {Error} when no model could be loaded, or the model has no metadata
 */
const loadOrRetrieveTfModel = async (modelType, mlModels, accountId, dispatch) => {
    let entry = mlModels.value[modelType]

    if (entry === undefined) {
        console.log("%cml model not loaded. loading...", "color:orange", modelType)
        entry = await getTfModel(accountId, mlModels, modelType, dispatch)

        // getTfModel returns undefined when no model is available. Fail with a clear message
        // rather than an opaque "cannot destructure" TypeError.
        if (entry === undefined) {
            throw new Error(`ML model '${modelType}' could not be loaded`)
        }
    }

    const { tfModel, meta } = entry

    if (meta == null) {
        throw new Error(`ML model '${modelType}' has no metadata`)
    }

    return { tfModel, elementTypes: meta.element_types }
}

/**
 * Run the ML "node prediction" check over every relationship drawn on a view.
 *
 * For each triple from getTriplesByView, the node-prediction model recommends target element
 * types for the (source type, relationship type) pair. If the actual target type is not among
 * them, a message is produced. The message is shown only on the view the triple came from.
 *
 * @param {object} model - reads `model.model.elements` (and views, via getTriplesByView)
 * @param {object} mlModels - redux slice whose `value` caches loaded TF models by model type
 * @param accountId - passed through to the model loader
 * @param {Function} dispatch - redux dispatch, used if a model has to be loaded
 * @returns {Promise<object[]>} ML messages (`ruleId`/`rule` "ML", `isMlRule` true)
 */
const runNodePredictionRules = async (model, mlModels, accountId, dispatch) => {
    const allMsgs = []

    const { tfModel: tfModelNodePrediction, elementTypes: tfNodePredictionElementTypes } =
        await loadOrRetrieveTfModel(
            mlServices.ML_MODEL_NODE_PREDICTION,
            mlModels,
            accountId,
            dispatch
        )

    // Only the encoder's TF model is needed here, not its element types.
    const { tfModel: tfModelNodePredictionEncoder } = await loadOrRetrieveTfModel(
        mlServices.ML_MODEL_NODE_PREDICTION_ENCODER,
        mlModels,
        accountId,
        dispatch
    )

    const triplesByView = getTriplesByView(tfNodePredictionElementTypes, model)
    const embeddings = createEmbeddings(triplesByView, tfNodePredictionElementTypes)

    // Drop exact duplicates: same view, endpoints and relationship type. These would otherwise
    // yield identical messages with identical ids. Compare by id, because the source/target
    // objects are built fresh per triple and never compare equal by reference.
    const uniqueEmbeddings = _.uniqWith(
        embeddings,
        (a, b) =>
            a.view.id === b.view.id &&
            a.source.id === b.source.id &&
            a.target.id === b.target.id &&
            a.edge.type === b.edge.type
    )

    uniqueEmbeddings.forEach((em) => {
        const recommendations = mlServices.recommendTargetNode(
            tfModelNodePrediction,
            tfModelNodePredictionEncoder,
            em.source.type,
            em.edge.type,
            tfNodePredictionElementTypes
        )

        // Nothing to report when the model has no recommendation, or the actual target type
        // is one of the recommended ones.
        if (
            recommendations.length === 0 ||
            recommendations.some((r) => r.type === em.target.type)
        ) {
            return
        }

        // A 'None' recommendation means the model explicitly found no target for this edge,
        // so it is not offered as an alternative.
        const recommendationsMinusNone = recommendations.filter((r) => r.type !== "None")

        if (recommendationsMinusNone.length === 0) {
            return
        }

        const expected = recommendationsMinusNone
            .map((rec) => `${rec.type} (${(rec.match * 100).toFixed(0)}%)`)
            .join(", ")

        allMsgs.push({
            id: `${em.source.id}-${em.target.id}-${em.edge.type}-${em.view.id}-node-prediction`,
            // ML = Machine Learning
            ruleId: "ML",
            element: model.model.elements.find((el) => el.id === em.source.id),
            // Show only on the originating view, and only when that view contains diagram
            // objects for both the source and the target element.
            isShowForView: (view) => {
                if (view.id !== em.view.id) {
                    return false
                }

                const source = view.elements.find(
                    (el) => el.diagramObject.archimateElement === em.source.id
                )

                const target = view.elements.find(
                    (el) => el.diagramObject.archimateElement === em.target.id
                )

                return source && target
            },
            // ML = Machine Learning
            rule: "ML",
            isMlRule: true,
            msg: `Expecting ${expected} from '${em.source.name}' to '${em.edge.type}' - not ${em.target.type}`,
        })
    })

    return allMsgs
}

/**
 * Collect a (source, relationship, target) triple for every relationship connection drawn on
 * every view, in view order.
 *
 * These connections are skipped:
 * - the relationship is not among the model elements (e.g. a connection into a Note);
 * - an endpoint cannot be found;
 * - an endpoint's type is not in `allowedElementTypes`;
 * - the target is itself a relationship (it has a `source`).
 *
 * @param {string[]} allowedElementTypes - element types the ML model knows about
 * @param {object} model - reads `model.model.elements` and
 *     `model.model.views[].elements[].sourceConnections[].connection.archimateRelationship`
 * @returns {object[]} triples shaped
 *     `{ view: {id}, source: {id, type, name}, target: {id, type, name}, edge: {id, type} }`,
 *     where `edge.type` has the "Relationship" suffix removed
 */
const getTriplesByView = (allowedElementTypes, model) => {
    // Index elements once instead of scanning the list per connection. Keep the first element
    // for any id so lookups agree with Array.prototype.find.
    const elementsById = new Map()
    model.model.elements.forEach((el) => {
        if (!elementsById.has(el.id)) {
            elementsById.set(el.id, el)
        }
    })

    const isAllowed = (el) => el !== undefined && allowedElementTypes.includes(el.type)

    // Build the triple for one source connection, or undefined when it should be skipped.
    const toTriple = (view, sc) => {
        const rel = elementsById.get(sc.connection.archimateRelationship)

        // This happens if the connection goes into a Note. Ignore these.
        if (rel === undefined) {
            return undefined
        }

        const target = elementsById.get(rel.target)
        const source = elementsById.get(rel.source)

        // A target with a `source` is a relationship, not an element; those are ignored.
        if (!isAllowed(target) || target.source || !isAllowed(source)) {
            return undefined
        }

        return {
            view: { id: view.id },
            source: { id: source.id, type: source.type, name: source.name },
            target: { id: target.id, type: target.type, name: target.name },
            edge: { id: rel.id, type: rel.type.replace("Relationship", "") },
        }
    }

    return _.flatten(
        model.model.views.map((view) =>
            _.flatten(
                view.elements.map((element) =>
                    // Some view elements may have no outgoing connections array at all.
                    (element.sourceConnections || [])
                        .map((sc) => toTriple(view, sc))
                        .filter((item) => item !== undefined)
                )
            )
        )
    )
}

/**
 * Attach ML input vectors to every triple from getTriplesByView.
 *
 * Each result keeps the triple's `view`, `source` and `target` objects and reduces `edge` to its
 * `type`. It adds `source_vec`/`target_vec` (element-type vectors over `allowedElementTypes`) and
 * `edge_vec` (edge-type vector over the palette's connector types).
 *
 * @param {object[]} triplesByView - triples as returned by getTriplesByView
 * @param {string[]} allowedElementTypes - element-type vocabulary of the ML model
 * @returns {object[]} one embedding per triple, same order
 */
const createEmbeddings = (triplesByView, allowedElementTypes) =>
    triplesByView.map(({ view, edge, source, target }) => {
        const sourceVec = mlServices.getElementVecTfBuf(source.type, allowedElementTypes)
        const targetVec = mlServices.getElementVecTfBuf(target.type, allowedElementTypes)
        const edgeVec = mlServices.getEdgeVecTfBufFromType(edge.type, palette.getConnectorTypes())

        return {
            view,
            edge: { type: edge.type },
            source,
            target,
            source_vec: sourceVec,
            target_vec: targetVec,
            edge_vec: edgeVec,
        }
    })

/**
 * Run the ML "edge prediction" check over every relationship drawn on a view.
 *
 * For each triple from getTriplesByView, the edge-prediction model recommends relationship types
 * for the (source type, target type) pair. If the actual relationship type is not among them, a
 * message is produced. The message is shown only on the view the triple came from.
 *
 * @param {object} model - reads `model.model.elements` (and views, via getTriplesByView)
 * @param {object} mlModels - redux slice whose `value` caches loaded TF models by model type
 * @param accountId - passed through to the model loader
 * @param {Function} dispatch - redux dispatch, used if a model has to be loaded
 * @returns {Promise<object[]>} ML messages (`ruleId`/`rule` "ML", `isMlRule` true)
 */
const runEdgePredictionRules = async (model, mlModels, accountId, dispatch) => {
    const allMsgs = []

    const { tfModel: tfModelEdgePrediction, elementTypes: tfEdgePredictionElementTypes } =
        await loadOrRetrieveTfModel(
            mlServices.ML_MODEL_EDGE_PREDICTION,
            mlModels,
            accountId,
            dispatch
        )

    // Only the encoder's TF model is needed here, not its element types.
    const { tfModel: tfModelEdgePredictionEncoder } = await loadOrRetrieveTfModel(
        mlServices.ML_MODEL_EDGE_PREDICTION_ENCODER,
        mlModels,
        accountId,
        dispatch
    )

    const triplesByView = getTriplesByView(tfEdgePredictionElementTypes, model)
    const embeddings = createEmbeddings(triplesByView, tfEdgePredictionElementTypes)

    // Drop exact duplicates: same view, endpoints and relationship type. These would otherwise
    // yield identical messages with identical ids. Compare by id, because the source/target
    // objects are built fresh per triple and never compare equal by reference.
    const uniqueEmbeddings = _.uniqWith(
        embeddings,
        (a, b) =>
            a.view.id === b.view.id &&
            a.source.id === b.source.id &&
            a.target.id === b.target.id &&
            a.edge.type === b.edge.type
    )

    uniqueEmbeddings.forEach((em) => {
        const recommendations = mlServices.recommendEdgeBetweenNodes(
            tfModelEdgePrediction,
            em.source.type,
            em.target.type,
            tfEdgePredictionElementTypes,
            tfModelEdgePredictionEncoder
        )

        // Nothing to report when the model has no recommendation, or the actual relationship
        // type is one of the recommended ones.
        if (
            recommendations.length === 0 ||
            recommendations.some((r) => r.type === em.edge.type)
        ) {
            return
        }

        // A 'None' recommendation means the model explicitly found no relationship for this
        // pair, so it is not offered as an alternative.
        const recommendationsMinusNone = recommendations.filter((r) => r.type !== "None")

        if (recommendationsMinusNone.length === 0) {
            return
        }

        const expected = recommendationsMinusNone
            .map((rec) => `${rec.type} (${(rec.match * 100).toFixed(0)}%)`)
            .join(", ")

        allMsgs.push({
            id: `${em.source.id}-${em.target.id}-${em.edge.type}-${em.view.id}-edge-prediction`,
            // ML = Machine Learning
            ruleId: "ML",
            element: model.model.elements.find((el) => el.id === em.source.id),
            // Show only on the originating view, and only when that view contains diagram
            // objects for both the source and the target element.
            isShowForView: (view) => {
                if (view.id !== em.view.id) {
                    return false
                }

                const source = view.elements.find(
                    (el) => el.diagramObject.archimateElement === em.source.id
                )

                const target = view.elements.find(
                    (el) => el.diagramObject.archimateElement === em.target.id
                )

                return source && target
            },
            // ML = Machine Learning
            rule: "ML",
            isMlRule: true,
            msg: `Expecting ${expected} from '${em.source.name}' to '${em.target.name}' - not ${em.edge.type}`,
        })
    })

    return allMsgs
}

/**
 * Older ML check (its call in runRules is commented out).
 *
 * For every graph element whose type the node-prediction model knows, it compares the target type
 * of each outgoing non-Association relationship with the model's recommended target types, and
 * reports mismatches.
 *
 * @param {object} model - reads `model.model.elements`; also passed to modelServices.buildGraph
 * @param {object} mlModels - redux slice whose `value` caches loaded TF models by model type
 * @param accountId - passed through to the model loader
 * @param {Function} dispatch - redux dispatch, used if a model has to be loaded
 * @returns {Promise<object[]>} messages whose `ruleId`/`rule` is the node-prediction model type
 */
const runMlRules = async (model, mlModels, accountId, dispatch) => {
    const allMsgs = []
    const graph = modelServices.buildGraph([model])
    const tfModelType = mlServices.ML_MODEL_NODE_PREDICTION

    const { tfModel: tfModelToUse, elementTypes: tfModelElementTypes } =
        await loadOrRetrieveTfModel(tfModelType, mlModels, accountId, dispatch)

    // The encoder is passed as recommendTargetNode's second argument, matching the argument
    // order used in runNodePredictionRules (NOTE(review): signature not visible here — confirm).
    const { tfModel: tfEncoderModelToUse } = await loadOrRetrieveTfModel(
        mlServices.ML_MODEL_NODE_PREDICTION_ENCODER,
        mlModels,
        accountId,
        dispatch
    )

    // Index elements once. Keep the first element per id so lookups agree with Array.prototype.find.
    const elementsById = new Map()
    model.model.elements.forEach((el) => {
        if (!elementsById.has(el.id)) {
            elementsById.set(el.id, el)
        }
    })

    // Collect the relationships to check for one element. Only the first relationship per
    // (relType, relSource) pair is kept. lodash's uniqBy with an array iteratee is a
    // matchesProperty shorthand, not a compound key, so it cannot express this.
    const edgesToCheck = (element) => {
        const seen = new Set()

        return model.model.elements
            .filter((rel) => rel.source && rel.source === element.id)
            .map((rel) => ({
                relType: rel.type.replace("Relationship", ""),
                relSource: rel.source,
                source: elementsById.get(rel.source),
                target: elementsById.get(rel.target),
            }))
            // Too difficult to train a model to work out when an Association should be used or
            // not. Relationships with an endpoint missing from the model cannot be checked.
            .filter(
                (et) =>
                    et.relType !== "Association" &&
                    et.source !== undefined &&
                    et.target !== undefined
            )
            .filter((et) => {
                const key = JSON.stringify([et.relType, et.relSource])
                if (seen.has(key)) {
                    return false
                }
                seen.add(key)
                return true
            })
    }

    graph
        .map((element) => ({
            ...element,
            embedding: mlServices.getNodeEmbedding(element, tfModelElementTypes),
        }))
        // Skip elements whose embedding is all zeros (presumably types the model cannot encode).
        .filter((element) => element.embedding.some((slot) => slot !== 0))
        .filter((element) => tfModelElementTypes.includes(element.type))
        .forEach((element) => {
            edgesToCheck(element).forEach((et) => {
                const recommendations = mlServices.recommendTargetNode(
                    tfModelToUse,
                    tfEncoderModelToUse,
                    et.source.type,
                    et.relType,
                    tfModelElementTypes
                )

                // With no recommendations there is nothing to suggest; otherwise report only when
                // the actual target type is not recommended.
                if (
                    recommendations.length === 0 ||
                    recommendations.some(
                        (r) => r.type === et.target.type.replace("Relationship", "")
                    )
                ) {
                    return
                }

                const expected = recommendations
                    .map((rec) => `${rec.type} (${(rec.match * 100).toFixed(0)}%)`)
                    .join(", ")

                allMsgs.push({
                    id: `${element.id}-${et.target.id}-${et.relType}`,
                    ruleId: tfModelType,
                    element,
                    rule: tfModelType,
                    isMlRule: true,
                    msg: `Expecting ${et.relType} to ${expected} - not ${et.target.type}`,
                })
            })
        })

    return allMsgs
}

/**
 * Load the TF layers model `tfModelType` for `accountId`. If a model came back, store it in
 * redux under `tfModelType`, next to the entries already in `mlModels.value`.
 *
 * NOTE(review): the dispatched state is built from the `mlModels` snapshot passed in. When
 * several models are loaded during one runRules call, every loader receives the same snapshot,
 * so each dispatch omits the models the others just loaded. Whether that drops models depends on
 * how the setMlModels reducer merges — confirm.
 *
 * @param accountId - forwarded to mlServices.loadTfLayersModel
 * @param {object} mlModels - redux slice whose `value` holds the currently cached models
 * @param {string} tfModelType - model type to load
 * @param {Function} dispatch - redux dispatch
 * @returns {Promise<{modelInfo, tfModel, meta}|undefined>} the loaded model, or undefined if none
 */
const getTfModel = async (accountId, mlModels, tfModelType, dispatch) => {
    const { modelInfo, tfModel, meta } = await mlServices.loadTfLayersModel(tfModelType, accountId)

    if (!tfModel) {
        console.log("%cno TF model", "color: orange")
        return undefined
    }

    const newMlModels = { ...mlModels.value, [tfModelType]: { tfModel, meta } }
    console.log("%cnew MlModels", "color: lightgreen", { newMlModels, modelInfo, meta })

    dispatch(setMlModels(newMlModels))
    return { modelInfo, tfModel, meta }
}

export { runRules }
