import React, { useState, useEffect } from "react"
import Header from "../components/Header"
import {
  Alert,
  Box,
  Divider,
  List,
  ListItemButton,
  Paper,
  Skeleton,
  Stack,
  Typography,
} from "@mui/material"
import { selectModelState } from "../redux/selectors"
import { useSelector, useDispatch } from "react-redux"
import * as colors from "@mui/material/colors"
import Controls from "../components/controls/Controls"
import * as modelServices from "./services/modelServices"
import { setModelState } from "../redux/actions"
import db from "../Firestore"
import _ from "lodash"
import { useSnackbar } from "notistack"
import * as tf from "@tensorflow/tfjs"
import * as tfvis from "@tensorflow/tfjs-vis"
import firebase from "firebase/compat/app"
import * as palette from "../components/symbols/palette"
import { serverTimestamp } from "./services/dataServices"
import * as mlServices from "./services/mlServices"
import useAccountStatus from "../components/useAccountStatus"
import { spacing } from "./services/styleServices"

// See https://distill.pub/2021/understanding-gnns/ for details on GNN types

// Flexbox fragments shared by several of the style objects below
const flexRowStyle = { display: "flex", flexDirection: "row" }
const flexColumnStyle = { display: "flex", flexDirection: "column" }

/**
 * Style objects for the ML page, keyed by the part of the page they style.
 * Spacing values go through the shared `spacing` helper.
 */
const styles = {
  pageContent: {
    margin: spacing(1),
    padding: spacing(1),
    ...flexRowStyle,
  },
  parentName: {
    backgroundColor: colors.grey[300],
  },
  selectedFiles: {
    ...flexColumnStyle,
    gap: spacing(1),
  },
  listItemRoot: {
    // Highlight colour for the selected list item
    "&.Mui-selected": {
      backgroundColor: colors.pink[100],
    },
  },
  buttons: {
    marginTop: spacing(3),
  },
  row: {
    ...flexRowStyle,
    gap: spacing(1),
    marginBottom: spacing(1),
  },
  elements: {
    marginTop: spacing(2),
  },
  elementInfo: { ...flexColumnStyle },
  rels: {
    ...flexColumnStyle,
    marginLeft: spacing(1),
  },
  vec: {
    fontFamily: "monospace",
  },
  elementTypeCount: {
    ...flexRowStyle,
    marginLeft: spacing(1),
  },
  elementTypeLabel: {
    minWidth: 160,
  },
  edgePrediction: {
    ...flexColumnStyle,
    marginLeft: spacing(1),
    gap: spacing(1),
    maxWidth: 250,
    marginTop: spacing(2),
  },
  files: {
    minWidth: 300,
  },
  trainingLog: {
    padding: spacing(1),
  },
  logLines: { ...flexColumnStyle },
}

const ML = (props) => {
  // --- Component state ------------------------------------------------------
  // NOTE: these hooks must be called in the same order on every render
  // (React's Rules of Hooks), so do not make any of them conditional.

  // Page title state, initially "ML"
  const [title, setTitle] = useState("ML")

  // notistack helper for user-facing notifications
  const { enqueueSnackbar } = useSnackbar()

  // Files picked by the user, as { parentId, fileName } entries (see handleSelectFile)
  const [selectedFiles, setSelectedFiles] = useState([])

  // Element types that we will train on, because there are enough of them
  const [trainTypes, setTrainTypes] = useState([])

  // Sorted { title, id } picker options derived from trainTypes
  const [elementTypeOptions, setElementTypeOptions] = useState([])

  // Projects/components flagged with ml_training, sorted by name; undefined until loaded
  const [parents, setParents] = useState()

  // Saved ml_models records per model type, used to show training loss/accuracy.
  // Set from Firestore on load and again by saveModel after training.
  const [edgePredictionStats, setEdgePredictionStats] = useState({})

  const [nodePredictionStats, setNodePredictionStats] = useState({})

  // Parsed models held in redux; looked up by getCachedModel
  const modelCache = useSelector(selectModelState)

  // Log lines for the current training run (written by handleTrainNodePrediction)
  const [trainLog, setTrainLog] = useState([])

  // Account status flag from useAccountStatus (its use is not visible in this section)
  const { isActive } = useAccountStatus()

  // Form values; edge_prediction.from / .to are the source and target element
  // types passed to the edge-prediction model by handlePredictEdge
  const [values, setValues] = useState({
    edge_prediction: { from: "", to: "" },
  })

  // Result of the last edge prediction (from mlServices.recommendEdgeBetweenNodes)
  const [recommendations, setRecommendations] = useState([])

  // Redux dispatch, used to add loaded models to the model cache
  const dispatch = useDispatch()

  // Account id from the signed-in user's `account_id` token claim; undefined until auth resolves
  const [accountId, setAccountId] = useState()

  // NOTE(review): toggle whose use is not visible here — presumably selects the
  // neighbour-averaged (getConvolutionalElement) features; confirm
  const [useConv, setUseConv] = useState(false)

  // NOTE(review): set outside this section — presumably per-type counts from getElementSummary
  const [elementCounts, setElementCounts] = useState()

  // Elements (each with edges.in / edges.out) used as training input; undefined until built
  const [graph, setGraph] = useState()

  // Resolve the signed-in user's account id from the `account_id` custom claim
  // on their ID token. The returned unsubscribe function detaches the auth
  // listener when the page unmounts.
  useEffect(() => {
    const unsub = firebase.auth().onAuthStateChanged((user) => {
      if (user) {
        user
          .getIdTokenResult(false)
          .then((token) => {
            setAccountId(token.claims.account_id)
          })
          .catch((error) => {
            // Previously an unhandled rejection; report it instead of failing silently
            console.error("Unable to read ID token claims", error)
          })
      }
    })

    return unsub
  }, [])

  // Once the account id is known, load the projects/components flagged for ML
  // training. The load is async, so catch its rejection here; otherwise a
  // failed Firestore query is an unhandled promise rejection.
  useEffect(() => {
    if (accountId) {
      loadProjectsAndComponents(accountId).catch((error) => {
        console.error("Unable to load projects and components for ML", error)
      })
    }
  }, [accountId])

  // Load any current neural network models to show the user training score and coverage of elements.
  // Query failures are caught and logged rather than left as unhandled rejections.
  useEffect(() => {
    if (!accountId) {
      return
    }

    db.collection("ml_models")
      .where("account_id", "==", accountId)
      .get()
      .then((snapshot) => {
        const models = snapshot.docs.map((doc) => doc.data())

        const nodeStats = models.find(
          (model) => model.type === "NodePrediction"
        )
        const edgeStats = models.find(
          (model) => model.type === "EdgePrediction"
        )

        console.log("%cexisting models", "color:yellow", {
          models,
          nodeStats,
          edgeStats,
        })

        setNodePredictionStats(nodeStats)
        setEdgePredictionStats(edgeStats)
      })
      .catch((error) => {
        console.error("Unable to load existing ML models", error)
      })
  }, [accountId])

  // Rebuild the alphabetically sorted { title, id } picker options whenever
  // the trainable element types change.
  useEffect(() => {
    if (!trainTypes) {
      return
    }

    const sortedOptions = trainTypes.map((elementType) => ({
      title: elementType,
      id: elementType,
    }))
    sortedOptions.sort((first, second) =>
      first.title.localeCompare(second.title)
    )
    setElementTypeOptions(sortedOptions)
  }, [trainTypes])

  // Close the tfjs-vis visor (opened by the training handlers) when the page unmounts.
  useEffect(() => {
    const closeVisor = () => {
      tfvis.visor().close()
    }
    return closeVisor
  }, [])

  /**
   * Load every project and component of the account that is flagged with
   * `ml_training == true`, and store them, sorted by name, in `parents`.
   *
   * Each parent is reduced to the fields the file picker needs:
   * id, account_id, parent_type ("project" | "component"), name and files.
   *
   * @param {string} accountId - account whose ML training sources are loaded
   */
  const loadProjectsAndComponents = async (accountId) => {
    // Query one collection for the account's ML-training docs, tagging each with its parent type
    const loadTrainingParents = async (collectionName, parentType) => {
      const querySnapshot = await db
        .collection(collectionName)
        .where("account_id", "==", accountId)
        .where("ml_training", "==", true)
        .get()

      return querySnapshot.docs.map((doc) => ({
        id: doc.id,
        parent_type: parentType,
        ...doc.data(),
      }))
    }

    // The two queries are independent, so run them concurrently
    const [projects, components] = await Promise.all([
      loadTrainingParents("projects", "project"),
      loadTrainingParents("components", "component"),
    ])

    console.log("parents", { projects, components })

    const parents = [...projects, ...components]
      .map((parent) => ({
        id: parent.id,
        account_id: parent.account_id,
        parent_type: parent.parent_type,
        name: parent.name,
        files: parent.files,
      }))
      // A doc without a name would otherwise make localeCompare throw
      .sort((a, b) => (a.name || "").localeCompare(b.name || ""))
    setParents(parents)
  }

  /**
   * Count the graph's elements by type, most frequent type first:
   *
   * ```
   * [{
   * type: "ApplicationComponent",
   * count: 5
   * }]
   * ```
   *
   * @param {*} graph - array of elements, each with a `type`
   * @returns {{ type: string, count: number }[]} per-type counts in descending order
   */
  const getElementSummary = (graph) => {
    const typeOnly = graph.map(({ type }) => ({ type }))
    const countsByType = _.countBy(typeOnly, "type")

    const summary = Object.entries(countsByType).map(([type, count]) => ({
      type,
      count,
    }))
    summary.sort((left, right) => right.count - left.count)

    console.log("%celement summary", "color: pink", {
      result: countsByType,
      asArray: summary,
    })

    return summary
  }

  /**
   * Build a redux model-cache entry. This is a thin pass-through to
   * modelServices.createModelCacheItem with the arguments unchanged.
   *
   * @param {*} model - parsed model
   * @param {string} fileName - file the model was loaded from
   * @param {string} name - name of the owning project/component
   * @param {string} parentId - id of the owning project/component
   * @param {string} parentType - "project" or "component"
   * @returns the cache item produced by modelServices
   */
  const createModelCacheEntry = (model, fileName, name, parentId, parentType) =>
    modelServices.createModelCacheItem(
      model,
      fileName,
      name,
      parentId,
      parentType
    )

  /**
   * Look up the cached (already parsed) model for a selected file.
   *
   * @param {{ parentId: string, fileName: string }} selected - entry from `selectedFiles`
   * @returns the matching entry from modelServices.searchModelCache
   */
  function getCachedModel(selected) {
    // Selected entries only record the parent id, so recover the parent type
    // ("project" or "component") from the loaded parents. It used to be
    // hard-coded to "project". Files from a component were cached under
    // "component" (see loadModelIntoCache), so they got the wrong key.
    // "project" remains the fallback when the parent is unknown.
    const parent =
      parents && parents.find((candidate) => candidate.id === selected.parentId)
    const parentType = parent ? parent.parent_type : "project"

    const key = modelServices.createModelCacheKey(
      selected.fileName,
      selected.parentId,
      parentType
    )
    console.log("key", { key, selected, modelCache })

    return modelServices.searchModelCache({
      modelCacheKey: key,
      modelCache: modelCache,
    })
  }

  /**
   * Callback passed to modelServices.loadFile. It stores the loaded model in
   * the redux model cache, tagged with its owning project/component.
   *
   * @param {*} model - model produced by loadFile
   * @param {string} fileName - name of the loaded file
   * @param {*} rawText - presumably the raw file text; not used here
   * @param {{ parent: { id, name, parent_type } }} props - extra context given to loadFile
   */
  const loadModelIntoCache = (model, fileName, rawText, props) => {
    console.log("loadModelIntoCache", { model, fileName, props })

    const { id, name, parent_type: parentType } = props.parent
    const cacheEntry = createModelCacheEntry(model, fileName, name, id, parentType)

    dispatch(setModelState(cacheEntry))
  }

  /**
   * Toggle whether a file is selected.
   *
   * Selecting a file also loads it from storage, with loadModelIntoCache as
   * the callback. Deselecting only removes it from `selectedFiles`.
   *
   * @param {{ id, account_id, parent_type }} parent - project/component owning the file
   * @param {string} fileName - file to toggle
   */
  const handleSelectFile = (parent, fileName) => {
    console.log("parent and fileName", { parent, fileName })

    const isThisFile = (entry) =>
      entry.parentId === parent.id && entry.fileName === fileName

    const isSelected = selectedFiles.some(isThisFile)
    console.log("isSelected", isSelected)

    if (isSelected) {
      setSelectedFiles(selectedFiles.filter((entry) => !isThisFile(entry)))
      return
    }

    setSelectedFiles([...selectedFiles, { parentId: parent.id, fileName }])

    // parent_type is 'project' or 'component'; the storage folder uses the plural form
    const folderPath = `accounts/${parent.account_id}/${parent.parent_type}s/${parent.id}/`

    modelServices.loadFile(folderPath, fileName, loadModelIntoCache, {
      parent,
    })
  }

  /**
   * Build a multi-hot column vector (a [cnxTypes.length, 1] tf buffer) marking
   * which connector types occur among `edges`. Edge types are matched after
   * removing the "Relationship" suffix.
   *
   * NOTE(review): a type missing from cnxTypes gives index -1, which is passed
   * straight to buffer.set — confirm that this is a harmless no-op.
   */
  const getEdgeVecs = (cnxTypes, edges) => {
    const multiHot = tf.buffer([cnxTypes.length, 1])

    edges.forEach(({ type }) => {
      const slot = cnxTypes.indexOf(type.replace("Relationship", ""))
      multiHot.set(1, slot, 0)
    })

    return multiHot
  }

  // Index of the edge's connector type (its type without the "Relationship"
  // suffix) within cnxTypes, as resolved by mlServices.getEdgeIndexFromType.
  const getEdgeIndex = (cnxTypes, edge) =>
    mlServices.getEdgeIndexFromType(
      cnxTypes,
      edge.type.replace("Relationship", "")
    )

  // One-hot column vector ([cnxTypes.length, 1] tf buffer) for a single edge's connector type.
  const getEdgeVec = (cnxTypes, edge) => {
    const oneHot = tf.buffer([cnxTypes.length, 1])
    oneHot.set(1, getEdgeIndex(cnxTypes, edge), 0)
    return oneHot
  }

  /**
   * Aggregate an element's incoming neighbours into one vector: the average of
   * their one-hot element-type vectors over `elementTypes`.
   *
   * Only incoming neighbours are averaged. The element itself and its outgoing
   * neighbours are not included (earlier experiments with them were dropped).
   *
   * @param {*} element - graph node with `edges.in`; each edge exposes `sourceNode()`
   * @param {string[]} elementTypes - element types that define the vector slots
   * @returns the averaged vector data (from `dataSync()`); all zeros when there
   *   are no usable incoming neighbours
   */
  const getConvolutionalElement = (element, elementTypes) => {
    let convResult

    tf.tidy(() => {
      // sourceNode() is undefined when the edge connects a relationship rather
      // than a node. Those edges are skipped here. Previously the filter was
      // applied to the unused out-neighbour list, so `.type` below threw.
      const inNeighbours = _.flatten(
        element.edges.in.map((edge) => edge.sourceNode())
      ).filter((neighbour) => neighbour !== undefined)

      const tensors = inNeighbours.map((neighbour) =>
        mlServices.getElementVecTfBuf(neighbour.type, elementTypes).toTensor()
      )

      let result
      switch (tensors.length) {
        case 0:
          result = tf.zeros([1, elementTypes.length])
          break
        case 1:
          result = tensors[0]
          break
        default:
          result = tf.layers.average().apply(tensors)
          break
      }

      // dataSync() yields a TypedArray rather than a tensor, so the result stays
      // usable after tidy disposes the tensors created above
      convResult = result.dataSync()
    })

    return convResult
  }

  /**
   * Feature vectors for one element:
   * - a one-hot of its type over `trainTypes`;
   * - multi-hot vectors of its incoming and outgoing connector types over `cnxTypes`;
   * - its palette element type.
   *
   * Each temporary tensor is disposed once its data has been read. Previously
   * every call left three tensors allocated.
   *
   * @returns {{ element, edges_in, edges_out, element_type }}
   */
  const getVec = (element, trainTypes, cnxTypes) => {
    // Read a tf buffer's data via a temporary tensor, then release that tensor
    const readBuffer = (buf) => {
      const tensor = buf.toTensor()
      const data = tensor.dataSync()
      tensor.dispose()
      return data
    }

    const elementBuf = mlServices.getElementVecTfBuf(element.type, trainTypes)

    const incomingEdgeBuf = getEdgeVecs(cnxTypes, element.edges.in)
    const outgoingEdgeBuf = getEdgeVecs(cnxTypes, element.edges.out)

    return {
      element: readBuffer(elementBuf),
      edges_in: readBuffer(incomingEdgeBuf),
      edges_out: readBuffer(outgoingEdgeBuf),
      element_type: palette.getElementType(element.type),
    }
  }

  /**
   * Collect the distinct edge-prediction training samples from `graph`.
   *
   * Every outgoing edge whose target is a node becomes one sample:
   * - input: the source and target element types;
   * - label: the edge type with its "Relationship" suffix removed.
   * Edges whose targetNode() is undefined are skipped. Duplicate samples are
   * dropped; lodash uniqWith keeps the first occurrence.
   *
   * @param {*} graph - elements exposing `edges.out`
   * @returns {{ xsTypes: {source: string, target: string}[], labelTypes: string[] }}
   *   parallel arrays: labelTypes[i] is the label for xsTypes[i]
   */
  const getEdgePredictionXsAndLabels = (graph) => {
    const outgoingEdges = _.flatten(graph.map((element) => element.edges.out))

    const samples = []
    outgoingEdges.forEach((edge) => {
      if (edge.targetNode() === undefined) {
        return
      }
      samples.push({
        xs: {
          source: edge.sourceNode().type,
          target: edge.targetNode().type,
        },
        label: edge.type.replace("Relationship", ""),
      })
    })

    const isSameSample = (first, second) =>
      first.xs.source === second.xs.source &&
      first.xs.target === second.xs.target &&
      first.label === second.label

    const distinctSamples = _.uniqWith(samples, isSameSample)

    return {
      xsTypes: distinctSamples.map((sample) => sample.xs),
      labelTypes: distinctSamples.map((sample) => sample.label),
    }
  }

  // Names of every element type in the palette index, in palette order
  const ALL_ELEMENT_TYPES = palette.ELEMENT_INDEX.map(({ name }) => name)

  /**
   * Train a model that, given a source and target element type, recommends the
   * edge (connector) type that should connect them.
   *
   * Pipeline:
   * 1. Collect the distinct (source type, target type) -> connector samples
   *    from the graph, restricted to elements whose type is in `allowedLabels`.
   * 2. Add a random 70% of the (source, target) combinations that never occur,
   *    each labelled "None".
   * 3. Encode each pair as [one-hot source, one-hot target]. Train an
   *    autoencoder on those vectors and compress them with its encoder.
   * 4. Train a softmax classifier on the compressed vectors. Save it and the
   *    encoder via saveModel, then log a few sample recommendations.
   *
   * Intermediate tensors and the classifier are disposed whether training
   * succeeds or fails. Failures are logged and reported with a snackbar.
   *
   * @param {string[]} allowedLabels - element types to train on; also the one-hot vocabulary
   */
  const handleTrainEdgePrediction = async (allowedLabels) => {
    if (!graph || graph.length === 0) {
      console.log("%cno graph", "color:orange")
      return
    }

    tfvis.visor().open()

    const filteredGraph = graph.filter((element) =>
      allowedLabels.includes(element.type)
    )

    console.log("%cgraphs", "color:yellow", {
      graph,
      filteredGraph,
      allowedLabels,
    })

    const { xsTypes, labelTypes } = getEdgePredictionXsAndLabels(filteredGraph)

    const elementTypes = allowedLabels

    console.log("%cedge prediction", "color:orange", {
      xsTypes,
      labelTypes,
      elementTypes,
    })

    // Every (source, target) combination that never occurs in the graph. A
    // sample of these is labelled "None" so the model can predict "no edge".
    const missingXs = _.flatten(
      elementTypes.map((source) =>
        elementTypes.map((target) => ({
          source,
          target,
        }))
      )
    ).filter(
      (combo) =>
        xsTypes.find(
          (xsType) =>
            xsType.source === combo.source && xsType.target === combo.target
        ) === undefined
    )

    const missingItemsPortion = 0.7

    const missingSampleXs = _.shuffle(missingXs).slice(
      0,
      missingXs.length * missingItemsPortion
    )

    const missingLabels = _.fill(Array(missingSampleXs.length), "None")
    console.log("missingEdgeTypes", {
      missingXs,
      missingLabels,
      xsTypes,
      labelTypes,
    })

    const combinedXs = [...xsTypes, ...missingSampleXs]
    const combinedLabels = [...labelTypes, ...missingLabels]

    console.log("combinedXs", { combinedXs, combinedLabels })

    // Without samples, the autoencoder and the label tensor below would fail
    // on empty input (e.g. when allowedLabels is empty)
    if (combinedXs.length === 0) {
      console.log("%cno edge prediction training data", "color:orange")
      return
    }

    // Convert xs to vectors: one-hot source type followed by one-hot target type
    const xsBufs = combinedXs.map((input) => {
      const source = mlServices.getElementVecTfBuf(input.source, elementTypes)
      const target = mlServices.getElementVecTfBuf(input.target, elementTypes)

      return tf.buffer(
        [1, source.values.length + target.values.length],
        "float32",
        [...source.values, ...target.values]
      )
    })

    const toIntArray = (arr) => arr.map((item) => [...item.values])

    console.log("xsBufs", { xsBufs })

    // Tensors and the classifier created below are released in `finally`
    const tensorsToDispose = []
    let model

    try {
      const { encoder } = await getEncoderAndDecoder(toIntArray(xsBufs))

      const encoderModel = tf.sequential()
      encoderModel.add(encoder)
      const autoencodeInputTensor = tf.tensor2d(toIntArray(xsBufs))
      tensorsToDispose.push(autoencodeInputTensor)
      console.log("autoencodeInputTensor", autoencodeInputTensor)
      const autoencodedXs = encoderModel.predict(autoencodeInputTensor)
      tensorsToDispose.push(autoencodedXs)

      console.log("autoencodedXs", { autoencodedXs })

      const labelBufs = combinedLabels.map((label) => {
        const labelBuf = mlServices.getEdgeVecTfBufFromType(
          label,
          // Use a 'None' entry in the connector list. If the NN can't predict
          // an answer it needs a 'None' to fall back on: the softmax outputs
          // must add up to 1, so a failed prediction is None = 1, all others = 0.
          mlServices.getConnectorTypesWithNone()
        )
        if (labelBuf.values.find((slot) => slot === 1) === undefined) {
          console.log("%cno label", "color:orange", label, labelBuf)
        }
        return labelBuf
      })

      const labelVals = labelBufs.map((buffer) => buffer.values)

      const labelsTensor = tf.tensor(labelVals, [
        labelVals.length,
        labelVals[0].length,
      ])
      tensorsToDispose.push(labelsTensor)

      const inputCount = autoencodedXs.shape[1]
      const outputCount = labelBufs[0].values.length

      model = createNodePredictionModel(inputCount, outputCount)

      const batchSize = 15
      const epochs = 200

      const history = await model.fit(autoencodedXs, labelsTensor, {
        batchSize,
        epochs,
        shuffle: true,
        validationSplit: 0.3,
        callbacks: tfvis.show.fitCallbacks(
          { name: "Train Edge Prediction" },
          ["loss", "mse", "acc"],
          { height: 200, callbacks: ["onEpochEnd"] }
        ),
      })

      // Capture the serialised model through a custom save handler
      let savedModelInfo

      await model.save({
        save: (modelInfo) => {
          console.log("%cmodelInfo", "color:lightgreen", {
            modelInfo,
            weightData: modelInfo.weightData,
            history,
          })
          savedModelInfo = modelInfo
        },
      })

      await saveModel(
        savedModelInfo,
        mlServices.ML_MODEL_EDGE_PREDICTION,
        elementTypes,
        history
      )

      let savedEncoderModelInfo

      await encoderModel.save({
        save: (modelInfo) => {
          console.log("%cencoder modelInfo", "color:lightgreen", {
            modelInfo,
            weightData: modelInfo.weightData,
          })
          savedEncoderModelInfo = modelInfo
        },
      })

      await saveModel(
        savedEncoderModelInfo,
        mlServices.ML_MODEL_EDGE_PREDICTION_ENCODER,
        elementTypes,
        history
      )

      console.log("elementTypes", { elementTypes })

      // Log the trained model's recommendation for one (source, target) pair
      const doRec = (id, sourceType, targetType) => {
        const rec = mlServices.recommendEdgeBetweenNodes(
          model,
          sourceType,
          targetType,
          elementTypes,
          encoderModel
        )
        console.log(`%crecommendation ${id}`, "color:lightgreen", {
          rec,
          sourceType,
          targetType,
        })
      }

      console.log("%cedge predictions for loaded model", "color:lightgreen")

      doRec("1", palette.BUSINESS_FUNCTION, palette.BUSINESS_OBJECT)
      doRec("2", palette.CAPABILITY, palette.CAPABILITY)
      doRec("3", palette.TECHNOLOGY_SERVICE, palette.STAKEHOLDER)
      doRec("4", palette.APPLICATION_INTERFACE, palette.APPLICATION_SERVICE)
      doRec("5", palette.BUSINESS_ACTOR, palette.BUSINESS_FUNCTION)
      doRec("6", palette.BUSINESS_OBJECT, palette.BUSINESS_OBJECT)
    } catch (error) {
      console.error("Edge prediction training failed", error)
      enqueueSnackbar("Edge prediction training failed", { variant: "error" })
    } finally {
      tf.dispose(tensorsToDispose)
      if (model) {
        // LayersModel.dispose() releases the classifier's weight variables
        model.dispose()
      }
    }
  }

  /**
   * Recommend connector types between the element types chosen in
   * `values.edge_prediction` (from -> to).
   *
   * It uses the account's saved edge-prediction model and its encoder, and
   * stores the result in `recommendations`. Failures to load or predict are
   * logged and shown in a snackbar instead of becoming unhandled rejections.
   */
  const handlePredictEdge = async () => {
    try {
      // The classifier and its encoder are independent documents, so load them concurrently
      const [{ tfModel, meta }, { tfModel: tfEncoderModel }] = await Promise.all([
        mlServices.loadTfLayersModel(
          mlServices.ML_MODEL_EDGE_PREDICTION,
          accountId
        ),
        mlServices.loadTfLayersModel(
          mlServices.ML_MODEL_EDGE_PREDICTION_ENCODER,
          accountId
        ),
      ])

      const recommendations = mlServices.recommendEdgeBetweenNodes(
        tfModel,
        values.edge_prediction.from,
        values.edge_prediction.to,
        meta.element_types,
        tfEncoderModel
      )

      console.log("recommendations", recommendations)

      setRecommendations(recommendations)
    } catch (error) {
      console.error("Edge prediction failed", error)
      enqueueSnackbar("Unable to predict edge type", { variant: "error" })
    }
  }

  // Debug helper: log tfjs's live tensor and data-buffer counts, optionally tagged with a label.
  const printTensors = (label) => {
    const suffix = label ? `(${label})` : ""
    const { numTensors, numDataBuffers } = tf.memory()
    console.log(
      `%cTensors ${suffix}`,
      "color:lightgreen",
      numTensors,
      numDataBuffers
    )
  }

  /**
   * Persist a trained model to the `ml_models` collection and update the
   * matching prediction stats shown on the page. Any earlier model of the same
   * `type` for this account is replaced.
   *
   * @param {object} model - model artifacts captured from a tfjs `model.save()`
   *   handler; `weightData` is stored as a base64 string
   * @param {string} type - model type, e.g. mlServices.ML_MODEL_EDGE_PREDICTION
   * @param {string[]} modelElementTypes - element types the model's input vectors were built from
   * @param {object} history - result of `model.fit()`; the final loss and acc values are recorded
   */
  const saveModel = async (model, type, modelElementTypes, history) => {
    console.log("%csaving model", "color:lightgreen", { model, type })

    // Final-epoch value of a metric, or null when it was not recorded.
    // A missing metric used to throw here. Null is stored instead of
    // `undefined` because Firestore rejects undefined field values by default.
    const lastValue = (series) =>
      Array.isArray(series) && series.length > 0
        ? series[series.length - 1]
        : null

    const loss = lastValue(history.history.loss)
    const accuracy = lastValue(history.history.acc)

    const modelRec = {
      account_id: accountId,
      type: type,
      created: serverTimestamp(),
      modified: serverTimestamp(),
      tf_model_info: {
        ...model,

        // convert weightData to base64 string
        weightData: mlServices.arrayBufferToBase64(model.weightData),
      },

      // Metadata describing how this model was trained. It is needed when the
      // model is used in a different context. For example, the element types
      // let us build input vectors that match training, because element types
      // vary with the ArchiMate models used.
      train_meta: {
        element_types: modelElementTypes,
        loss,
        accuracy,
      },
    }

    if (type === mlServices.ML_MODEL_EDGE_PREDICTION) {
      setEdgePredictionStats(modelRec)
    } else if (type === mlServices.ML_MODEL_NODE_PREDICTION) {
      setNodePredictionStats(modelRec)
    }

    // Look up the model(s) this one replaces before writing it
    const existingMLModels = await mlServices.loadMLModelData(type, accountId)

    console.log("%cany existing models?", "color:lightgreen", existingMLModels)

    // Write the new model before deleting the old ones. Previously the old
    // models were deleted first, so a failed write left no saved model at all.
    const result = await db.collection("ml_models").add(modelRec)

    console.log("%csaved new model", "color:lightgreen", { result })

    if (existingMLModels.length > 0) {
      console.log("%cfound existing model(s)", "color:lightgreen", {
        existingMLModels,
      })

      await Promise.all(
        existingMLModels.map((existingModel) =>
          db.collection("ml_models").doc(existingModel.id).delete()
        )
      )
    }
  }

  /**
   * Build and compile a small softmax classifier:
   * dense(sigmoid) hidden layer -> dropout -> dense(softmax) output.
   * It is trained with Adam on categorical cross-entropy and reports accuracy.
   * Also used for edge prediction (see handleTrainEdgePrediction).
   *
   * @param {number} inputSize - width of each input vector
   * @param {number} outputSize - number of output classes
   * @param {object} [options] - optional hyper-parameters; the defaults are the
   *   values this function always used, so existing callers are unaffected
   * @param {number} [options.hiddenUnits=20] - width of the hidden layer
   * @param {number} [options.dropoutRate=0.05] - dropout rate (helps prevent overfitting)
   * @param {number} [options.learningRate=0.1] - Adam learning rate
   * @returns the compiled tf.Sequential model
   */
  const createNodePredictionModel = (
    inputSize,
    outputSize,
    { hiddenUnits = 20, dropoutRate = 0.05, learningRate = 0.1 } = {}
  ) => {
    const model = tf.sequential()

    model.add(
      tf.layers.dense({
        inputShape: [inputSize],
        units: hiddenUnits,
        activation: "sigmoid",
      })
    )

    // Dropout helps prevent overfitting
    model.add(tf.layers.dropout({ rate: dropoutRate }))

    // Output layer: one softmax probability per class
    model.add(tf.layers.dense({ units: outputSize, activation: "softmax" }))

    model.compile({
      optimizer: tf.train.adam(learningRate),
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    })

    return model
  }

  // Store the vector for all views.
  // Create a new vector for the view, for which you want to find similar views.
  // Run the code below, and the results with the higher value are the closest matches.

  // const dotLayer = tf.layers.dot({axes: -1, normalize: true});

  // const view = [0, 0, 1, 1]
  // const x1 = tf.tensor([[0, 0, 1, 0], [0, 1, 0, 0]]);
  // const x2 = tf.tensor([view, view]);

  // // Invoke the layer's apply() method in eager (imperative) mode.
  // const y = dotLayer.apply([x1, x2]);
  // y.print();

  /**
   * Pair up the two tensors returned by `tf.topk` as plain objects.
   *
   * @param {object} topK
   * @param {*} topK.tValues - tensor of match scores (decimal) from topk
   * @param {*} topK.tIndices - tensor of the corresponding integer indices from topk
   * @returns {{ indice: number, match: number }[]} one object per top-k entry, in topk order
   */
  const getTopKAsObjects = ({ tValues, tIndices }) => {
    console.log("%cvalues", "color:lightgreen", {
      values: tValues.dataSync(),
      indices: tIndices.dataSync(),
    })

    // Copy the tensor data into ordinary arrays
    const matchScores = Array.from(tValues.dataSync())
    const matchIndices = Array.from(tIndices.dataSync())

    return matchIndices.map((indice, position) => {
      const result = { indice, match: matchScores[position] }
      console.log("%cresult", "color:lightgreen", result)
      return result
    })
  }

  /**
   * Experimental: compare the views of the selected files by embedding.
   *
   * Steps:
   * - embed every view via mlServices.getViewEmbeddings;
   * - log the top matches, by cosine similarity, against one hard-coded source view;
   * - log the similar views reported by mlServices.findSimilarViews.
   * Results are only written to the console.
   */
  const handleCreateViewEmbeddings = () => {
    // Selected files with no model-cache entry (presumably still loading) are
    // skipped; previously they crashed on `model.model` below
    const cachedModels = selectedFiles
      .map((file) => getCachedModel(file))
      .filter((cachedModel) => Boolean(cachedModel))

    console.log("%ccachedModels", "color:lightgreen", cachedModels)

    const viewsAndModel = _.flatten(
      cachedModels.map((model) =>
        model.model.views.map((view) => ({ model, view }))
      )
    ).filter((item) => item.view.elements.length > 0)

    console.log("%cviews", "color:lightgreen", viewsAndModel)

    const viewEmbeddings = mlServices.getViewEmbeddings(
      cachedModels,
      ALL_ELEMENT_TYPES
    )

    console.log("%cview embeddings", "color:lightgreen", {
      viewEmbeddings,
      viewsAndModel,
      noOfViews: viewsAndModel.length,
    })

    // NOTE(review): fixed index of the view to compare against — looks like a
    // debugging choice. Assumes viewEmbeddings[i] belongs to viewsAndModel[i].
    const sourceViewIndex = 3
    if (
      viewsAndModel.length <= sourceViewIndex ||
      viewEmbeddings.length <= sourceViewIndex
    ) {
      console.log("%cnot enough views to compare", "color:orange", {
        noOfViews: viewsAndModel.length,
        sourceViewIndex,
      })
      return
    }

    const sourceView = viewsAndModel[sourceViewIndex]
    const sourceEmbedding = viewEmbeddings[sourceViewIndex]
    console.log("%cembedding1", "color:lightgreen", {
      view1: sourceView,
      embedding1: sourceEmbedding,
    })

    // tf.topk throws when k exceeds the number of candidates
    const topK = Math.min(5, viewsAndModel.length)

    // tf.tidy disposes every intermediate tensor (previously all of them
    // leaked); only the plain objects from getTopKAsObjects escape
    const combined = tf.tidy(() => {
      const tView2D = tf.tensor2d(viewEmbeddings)
      console.log("%ctView2D", "color:lightgreen", tView2D)

      const tViewToFindSimilar = tf.tensor([sourceEmbedding])
      console.log("%cview to find similar", "color:lightgreen", {
        tViewToFindSimilar,
        shape: tViewToFindSimilar.shape,
      })

      // One copy of the source embedding per view, so every row can be compared
      const repeated = tViewToFindSimilar.tile([viewsAndModel.length, 1])
      console.log("%crepeated", "color:lightgreen", repeated)

      // normalize: true => each row's dot product is the cosine similarity
      const dotLayer = tf.layers.dot({ axes: -1, normalize: true })
      const cosineDistance = dotLayer.apply([tView2D, repeated])
      console.log(
        "%cy",
        "color:lightgreen",
        cosineDistance,
        cosineDistance.dataSync()
      )

      // Swap the two dimensions (presumably [views, 1] -> [1, views]) so topk ranks across views
      const reshaped = cosineDistance.reshape([
        cosineDistance.shape[1],
        cosineDistance.shape[0],
      ])
      console.log("%creshaped", "color:lightgreen", reshaped)

      // Renamed from `values` so the component's `values` state is not shadowed
      const { values: tValues, indices: tIndices } = tf.topk(reshaped, topK)
      console.log("%cvalues", "color:lightgreen", {
        values: tValues.dataSync(),
        indices: tIndices.dataSync(),
      })

      return getTopKAsObjects({ tValues, tIndices })
    })

    console.log("%csimilarityData", "color:lightgreen", { combined })

    const matches = combined.filter((item) => item.match > 0.6)

    console.log("%csimilarity matches", "color:lightgreen", { matches })

    const sv = mlServices.findSimilarViews(
      sourceView.view,
      sourceView.model,
      cachedModels
    )
    console.log("%csimilar views", "color:yellow", {
      sv,
      names: sv.map((v) => `${v.file} > ${v.view.name}`),
    })
  }

  // https://observablehq.com/@jerdak/tensorflow-js-dimensionality-reduction-autoencoder
  /**
   * Train a single-hidden-layer autoencoder on `trainingInputs`, so that its
   * encoder layer can compress input vectors to a few features.
   *
   * @param {number[][]} trainingInputs - non-empty list of equal-width input rows
   * @returns {Promise<{ model, encoder, decoder }>} the trained autoencoder and its two dense layers
   * @throws {Error} when `trainingInputs` is empty
   */
  const getEncoderAndDecoder = async (trainingInputs) => {
    if (!trainingInputs || trainingInputs.length === 0) {
      throw new Error("getEncoderAndDecoder requires at least one training input")
    }

    const model = tf.sequential()

    const embeddingWidth = trainingInputs[0].length
    console.log("%cgetEncoderAndDecoder", "color:lightgreen", {
      trainingInputs,
      embeddingWidth,
    })

    const activation = "sigmoid"

    // Using an autoencoder width of 6 gets good results with the node prediction.
    // Clamped to at least 1 unit: inputs narrower than 4 used to produce a 0-unit layer.
    const encoderUnits = Math.max(
      1,
      Math.floor(Math.min(embeddingWidth / 4, 6))
    )

    console.log("%cusing encoder units", "color:lightgreen", {
      encoderUnits,
      embeddingWidth,
    })

    // Bottleneck layer (PCA-like compression, with a sigmoid activation):
    // maps each row of `embeddingWidth` columns to `encoderUnits` features
    const encoder = tf.layers.dense({
      units: encoderUnits,
      inputShape: [embeddingWidth],
      activation: activation,
    })

    // Reconstructs the original input width from the bottleneck
    const decoder = tf.layers.dense({
      units: embeddingWidth,
      activation: activation,
    })

    model.add(encoder)
    model.add(decoder)
    model.compile({ optimizer: tf.train.adam(0.05), loss: "meanSquaredError" })

    console.log("%cmodel", "color:lightgreen", model)

    const xs = tf.tensor2d(trainingInputs)
    try {
      // An autoencoder learns to reproduce its input, so xs is also the target
      await model.fit(xs, xs, {
        epochs: 200,
        batchSize: 15,
        shuffle: true,
        // Was misspelled `validationSpit`, so no validation split was ever applied
        validationSplit: 0.1,
        callbacks: tfvis.show.fitCallbacks(
          { name: "Encode xs" },
          ["loss", "mse", "acc"],
          {
            height: 200,
            callbacks: ["onEpochEnd"],
          }
        ),
      })
    } finally {
      // Release the input tensor even if training fails
      xs.dispose()
    }

    return {
      model: model,
      encoder: encoder,
      decoder: decoder,
    }
  }

  // const handleCreateEncoders = async (allowedElementTypes) => {
  //     const baseElementTypes = allowedElementTypes
  //     const baseElementTypesWithNone = ["None", ...allowedElementTypes]

  //     // Missing combinations of node + edge out, for which a 'None' answer is suitable

  //     const usedTypes = _.flatten(
  //         graph.map((element) => {
  //             return element.edges.out.map((edgeOut) => ({
  //                 source_type: element.type,
  //                 edge_type: edgeOut.type.replace("Relationship", ""),
  //                 target_type: edgeOut.targetNode().type,
  //             }))
  //         })
  //     )

  //     const uniqueUsedTypes = _.uniqWith(
  //         usedTypes,
  //         (a, b) =>
  //             a.source_type === b.source_type &&
  //             a.edge_type === b.edge_type &&
  //             a.target_type === b.target_type
  //     )

  //     const noneIndex = baseElementTypesWithNone.indexOf("None")
  //     console.log("allowedElementTypesWithNone", { baseElementTypesWithNone, noneIndex })

  //     const xsAndLabels = uniqueUsedTypes.map((input) => {
  //         const sourceVec = mlServices.getElementVecTfBuf(input.source_type, baseElementTypes)
  //         const edgeVec = mlServices.getEdgeVecTfBufFromType(
  //             input.edge_type,
  //             palette.getConnectorTypes()
  //         )
  //         const targetVec = mlServices.getElementVecTfBuf(
  //             input.target_type,
  //             baseElementTypesWithNone
  //         )

  //         const xs = tf.buffer([1, sourceVec.values.length + edgeVec.values.length], "float32", [
  //             ...sourceVec.values,
  //             ...edgeVec.values,
  //         ])

  //         const label = tf.buffer([1, targetVec.values.length], "float32", targetVec.values)

  //         return { xs, label }
  //     })

  //     const xsData = xsAndLabels.map((item) => item.xs)

  //     const toIntArray = (arr) => arr.map((item) => [...item.values])

  //     const { model, encoder, decoder } = await getEncoderAndDecoder(toIntArray(xsData))

  //     return { model, encoder, decoder }
  // }

  /**
   * Trains the node-prediction model: given a source element type and an
   * outgoing relationship type, predict the type of the target element
   * (or "None" when no target is expected).
   *
   * Steps:
   *  1. Collect the unique (source type, edge type, target type) triples from
   *     the outgoing edges of the elements in `graph`.
   *  2. Take a random 20% sample of the (source type, edge type) combinations
   *     that never occur in the graph and label them "None".
   *  3. Train an autoencoder via getEncoderAndDecoder on the source+edge input
   *     rows and encode those rows with its encoder.
   *  4. Fit a dense softmax classifier on the encoded rows, log a fixed set of
   *     sample recommendations to the console, then save both the classifier
   *     and the encoder via saveModel.
   *
   * Progress lines are appended to the training log shown in the UI.
   *
   * @param {string[]} allowedElementTypes - element types used to build the
   *   source-node input vectors; "None" is prepended to form the label set.
   * @returns {Promise<void>} settles once training and saving have finished,
   *   or immediately when the graph has no outgoing relationships. Errors
   *   raised while fitting or saving the classifier are logged and reported
   *   via a snackbar instead of rejecting.
   */
  const handleTrainNodePrediction = async (allowedElementTypes) => {
    const baseElementTypes = allowedElementTypes
    const baseElementTypesWithNone = ["None", ...allowedElementTypes]
    const connectorTypes = palette.getConnectorTypes()

    // Builds a [1, n] float32 input row: the source-node vector followed by
    // the edge-type vector. Shared by observed and "missing" combinations.
    const buildInputBuf = (sourceType, edgeType) => {
      const sourceVec = mlServices.getElementVecTfBuf(
        sourceType,
        baseElementTypes
      )
      const edgeVec = mlServices.getEdgeVecTfBufFromType(
        edgeType,
        connectorTypes
      )
      return tf.buffer(
        [1, sourceVec.values.length + edgeVec.values.length],
        "float32",
        [...sourceVec.values, ...edgeVec.values]
      )
    }

    const logInfo = []
    // Publish a fresh copy each time: React skips re-rendering when a state
    // setter receives the same (mutated) array reference it already holds.
    const publishLog = () => setTrainLog([...logInfo])

    tfvis.visor().open()

    logInfo.push(
      `Element types: ${baseElementTypesWithNone.join(", ") || "None"}`
    )
    publishLog()

    // Unique (source type, edge type, target type) triples seen in the graph
    const usedTypes = _.flatten(
      graph.map((element) =>
        element.edges.out.map((edgeOut) => ({
          source_type: element.type,
          edge_type: edgeOut.type.replace("Relationship", ""),
          target_type: edgeOut.targetNode().type,
        }))
      )
    )

    const uniqueUsedTypes = _.uniqWith(
      usedTypes,
      (a, b) =>
        a.source_type === b.source_type &&
        a.edge_type === b.edge_type &&
        a.target_type === b.target_type
    )

    logInfo.push(`Used types: ${uniqueUsedTypes.length}`)
    uniqueUsedTypes.forEach((item) => {
      logInfo.push(
        `${item.source_type} - ${item.edge_type} - ${item.target_type}`
      )
    })
    publishLog()

    // Nothing to learn from. Checked before any autoencoder or classifier is
    // built, so no tensors or models are left undisposed on this path.
    if (uniqueUsedTypes.length === 0) {
      console.log("No inputs, returning")
      enqueueSnackbar("No relationships found to train on", {
        variant: "info",
      })
      return
    }

    // (source type, edge type) combinations that never occur in the graph.
    // A sample of these is labelled "None" so the model can learn that no
    // target is expected for them.
    const missingSourceElementAndEdgeCombinations = _.flatten(
      palette.ELEMENT_INDEX.map((source) =>
        connectorTypes.map((rel) => ({
          source_type: source.name,
          edge_type: rel,
        }))
      )
    ).filter(
      (combo) =>
        !uniqueUsedTypes.some(
          (used) =>
            used.source_type === combo.source_type &&
            used.edge_type === combo.edge_type
        )
    )

    const missingItemsPortion = 0.2

    const missingSample = _.shuffle(
      missingSourceElementAndEdgeCombinations
    ).slice(
      0,
      Math.floor(
        missingSourceElementAndEdgeCombinations.length * missingItemsPortion
      )
    )

    const missingXsVecs = missingSample.map((combo) =>
      buildInputBuf(combo.source_type, combo.edge_type)
    )

    logInfo.push(
      `Possible prediction node types: ${
        baseElementTypesWithNone.join(", ") || "None"
      }`
    )
    publishLog()

    // One-hot "None" label for every sampled missing combination
    const noneIndex = baseElementTypesWithNone.indexOf("None")
    const missingLabelVecs = missingSample.map(() => {
      const elementBuf = tf.buffer([1, baseElementTypesWithNone.length])
      elementBuf.set(1, 0, noneIndex)
      return elementBuf
    })

    const xsAndLabels = uniqueUsedTypes.map((input) => {
      const xs = buildInputBuf(input.source_type, input.edge_type)
      const targetVec = mlServices.getElementVecTfBuf(
        input.target_type,
        baseElementTypesWithNone
      )
      const label = tf.buffer(
        [1, targetVec.values.length],
        "float32",
        targetVec.values
      )
      return { xs, label }
    })

    const xsData = xsAndLabels.map((item) => item.xs)
    const labelsData = xsAndLabels.map((item) => item.label)

    // Observed samples can be repeated to weight them against the "None" ones
    const repeatCount = 1
    const repeatedXs = _.flatten(_.times(repeatCount, _.constant(xsData)))
    const repeatedLabels = _.flatten(
      _.times(repeatCount, _.constant(labelsData))
    )

    // Plain number arrays (one row per sample) for the autoencoder and tensor2d
    const combinedXs = [...repeatedXs, ...missingXsVecs].map((buf) => [
      ...buf.values,
    ])
    const combinedLabels = [...repeatedLabels, ...missingLabelVecs]

    // Train the autoencoder, then wrap its encoder in a standalone model. It
    // encodes the training rows here and is passed to recommendTargetNode
    // and saved alongside the classifier below.
    const { encoder } = await getEncoderAndDecoder(combinedXs)
    const encoderModel = tf.sequential()
    encoderModel.add(encoder)

    const autoencodeInputTensor = tf.tensor2d(combinedXs)
    const autoencodedXs = encoderModel.predict(autoencodeInputTensor)
    // Only the encoded output is needed from here on
    tf.dispose(autoencodeInputTensor)

    const xsRowSize = autoencodedXs.shape[1]

    const labelTensor = tf.stack(
      combinedLabels.map((buf) => buf.values),
      0
    )

    const model = tf.sequential()
    model.add(
      tf.layers.dense({
        inputShape: [xsRowSize],
        units: xsRowSize + 5,
        activation: "sigmoid",
      })
    )

    // Dropout rates help prevent overfitting
    model.add(tf.layers.dropout({ rate: 0.05 }))

    // Output layer: one probability per candidate target type (incl. "None")
    model.add(
      tf.layers.dense({
        units: baseElementTypesWithNone.length,
        activation: "softmax",
      })
    )

    model.compile({
      optimizer: tf.train.adam(0.1),
      loss: "categoricalCrossentropy",
      metrics: ["accuracy"],
    })

    const batchSize = 32
    const epochs = 250

    try {
      const history = await model.fit(autoencodedXs, labelTensor, {
        batchSize,
        epochs,
        shuffle: true,
        validationSplit: 0.2,
        callbacks: tfvis.show.fitCallbacks(
          { name: "Train Node Prediction" },
          ["loss", "mse", "acc"],
          { height: 200, callbacks: ["onEpochEnd"] }
        ),
      })

      //TODO: This shouldn't be required if tf.tidy() worked, but seems like it too eagerly disposes of tensors
      tf.dispose(autoencodedXs)
      tf.dispose(labelTensor)

      logInfo.push(
        `Predict ${palette.BUSINESS_FUNCTION} - ${palette.ACCESS_RELATIONSHIP} - ?`
      )
      publishLog()

      // Logs the recommended target types for one sample source/edge pair,
      // skipping source types the model was not trained on.
      const doRec = (id, sourceNodeType, edgeType) => {
        if (baseElementTypes.includes(sourceNodeType)) {
          const rec = mlServices.recommendTargetNode(
            model,
            encoderModel,
            sourceNodeType,
            edgeType,
            baseElementTypes
          )
          console.log(
            `rec ${id}`,
            `${sourceNodeType} - ${edgeType} - ${rec
              .map((r) => `${r.type} ${(r.match * 100).toFixed(0)}%`)
              .join(", ")}`
          )
        } else {
          console.log(`rec ${id} - skipping ${sourceNodeType} - ${edgeType}`)
        }
      }

      doRec("1", palette.BUSINESS_FUNCTION, palette.ACCESS_RELATIONSHIP)
      doRec("2", palette.BUSINESS_ACTOR, palette.ASSIGNMENT_RELATIONSHIP)
      doRec("3", palette.BUSINESS_PROCESS, palette.COMPOSITION_RELATIONSHIP)
      doRec("4", palette.DATA_OBJECT, palette.REALIZATION_RELATIONSHIP)
      doRec(
        "5",
        palette.APPLICATION_COMPONENT,
        palette.ASSIGNMENT_RELATIONSHIP
      )
      doRec("6", palette.APPLICATION_SERVICE, palette.SERVING_RELATIONSHIP)
      doRec("7", palette.ASSESSMENT, palette.INFLUENCE_RELATIONSHIP)
      doRec("8", palette.WORK_PACKAGE, palette.REALIZATION_RELATIONSHIP)

      let savedModelInfo
      let savedEncoderModelInfo

      printTensors("NP 90")

      // Custom IO handler: capture the serialized artifacts instead of
      // writing them anywhere, so they can be passed to saveModel.
      await model.save({
        save: (modelInfo) => {
          console.log("%csave modelInfo", "color:lightgreen", {
            modelInfo,
            weightData: modelInfo.weightData,
            history,
          })
          savedModelInfo = modelInfo
        },
      })

      tf.dispose(model)

      await encoderModel.save({
        save: (modelInfo) => {
          console.log("%csave encoder modelInfo", "color:lightgreen", {
            modelInfo,
            weightData: modelInfo.weightData,
          })
          savedEncoderModelInfo = modelInfo
        },
      })

      tf.dispose(encoderModel)

      printTensors("NP 100")

      await saveModel(
        savedModelInfo,
        mlServices.ML_MODEL_NODE_PREDICTION,
        trainTypes,
        history
      )

      await saveModel(
        savedEncoderModelInfo,
        mlServices.ML_MODEL_NODE_PREDICTION_ENCODER,
        trainTypes,
        history
      )
    } catch (error) {
      console.error("Node prediction training failed", error)
      enqueueSnackbar("Node prediction training failed", { variant: "error" })
      // Disposing an already-disposed tensor is a no-op, so this is safe
      // whichever step failed.
      tf.dispose([autoencodedXs, labelTensor, model, encoderModel])
    }
  }

  // const autoencodeXs = async (xs) => {
  //     const beforeTensorCount = tf.memory().numTensors

  //     const featureSize = xs[0].length
  //     const middleUnits = Math.floor(xs[0].length / 2)
  //     console.log("%cautoEncode xs", "color:yellow", {
  //         xs,
  //         len: xs.length,
  //         middleUnits,
  //     })
  //     const model = tf.sequential()

  //     model.add(
  //         tf.layers.dense({
  //             inputShape: [featureSize],
  //             units: middleUnits,
  //             activation: "sigmoid",
  //         })
  //     )

  //     // Dropout rates help prevent overfitting
  //     model.add(tf.layers.dropout({ rate: 0.2 }))

  //     model.add(tf.layers.dense({ units: featureSize, activation: "sigmoid" }))

  //     model.compile({
  //         optimizer: tf.train.adam(0.05),
  //         loss: "meanSquaredError",
  //         metrics: ["accuracy"],
  //     })

  //     const batchSize = 32
  //     const epochs = 120

  //     const t_xs = tf.tensor(xs)

  //     await model.fit(t_xs, t_xs, {
  //         batchSize,
  //         epochs,
  //         shuffle: true,
  //         validationSplit: 0.2,
  //         callbacks: tfvis.show.fitCallbacks(
  //             { name: "Training Performance" },
  //             ["loss", "mse", "acc"],
  //             {
  //                 height: 200,
  //                 callbacks: ["onEpochEnd"],
  //             }
  //         ),
  //     })

  //     const preds = model.predict(t_xs)

  //     tf.dispose(t_xs)
  //     tf.dispose(model)

  //     const afterTensorCount = tf.memory().numTensors

  //     console.log("[autoencodeXs] tensor count change", {
  //         beforeTensorCount,
  //         afterTensorCount,
  //         diff: afterTensorCount - beforeTensorCount,
  //     })

  //     return preds
  // }

  /**
   * Selects the element types that occur at least `min` times.
   *
   * @param {Array} graph - graph passed straight through to getElementSummary.
   * @param {number} min - inclusive minimum count for a type to be selected.
   * @returns {{ elementTypes: string[], summary: Array }} the selected types
   *   (in summary order) plus the summary itself, whose items are read as
   *   `{ type, count }`.
   */
  const getElementTypesWithCount = (graph, min) => {
    const summary = getElementSummary(graph)

    const elementTypes = []
    for (const entry of summary) {
      if (entry.count >= min) {
        elementTypes.push(entry.type)
      }
    }

    console.log("%cgetElementWithCount", "color:yellow", {
      summary,
      elementTypes,
    })
    return { elementTypes, summary }
  }

  // Minimum number of occurrences an element type needs in order to be
  // selected as a training source type by getElementTypesWithCount; counts
  // below it are rendered greyed out in the element type list. With 0, every
  // type qualifies.
  const MIN_ELEMENT_COUNT = 0

  // const addEmbeddings = (graph, elementTypes) => {
  //     const conv1 = graph.map((element) => {
  //         //console.log("%caddEmbeddings", "color:yellow", { element, elementTypes })
  //         const baseEmbedding = mlServices.getElementVecTfBuf(element.type, elementTypes)
  //         const neighbourhood = [
  //             ...element.edges.in.map((inEdge) => inEdge.sourceNode()),
  //             //...element.edges.out.map((outEdge) => outEdge.targetNode()),
  //         ]
  //         const neighbourhoodEmbeddings = neighbourhood.map((neighbour) => {
  //             const embedding = mlServices.getElementVecTfBuf(neighbour.type, elementTypes)
  //             return embedding
  //         })

  //         const bufs = [baseEmbedding, ...neighbourhoodEmbeddings]
  //         //const bufs = [...neighbourhoodEmbeddings]
  //         const neighbourhoodElementVec = tf.tidy(() => {
  //             if (bufs.length === 0) {
  //                 const result = tf.buffer([1, elementTypes.length]).values
  //                 console.log("no elements. result", result)
  //                 return result
  //             } else {
  //                 //console.log("bufs", bufs)
  //                 const tensors = bufs.map((buf) => buf.toTensor())
  //                 const avg =
  //                     tensors.length === 1 ? tensors[0] : tf.layers.average().apply(tensors)
  //                 return avg.dataSync()
  //             }
  //         })

  //         //console.log("%caddEmbeddings result", "color:yellow", neighbourhoodElementVec)

  //         const result = {
  //             ...element,
  //             baseEmbedding: baseEmbedding.values,
  //             convResult: neighbourhoodElementVec,
  //             neighbourhoodEmbeddings: neighbourhoodEmbeddings,
  //         }
  //         return result
  //     })

  //     const conv2 = conv1.map((element) => {
  //         const conv1El = conv1.find((el) => el.id === element.id)

  //         const sourceIds = conv1El.edges.in
  //             .map((inEdge) => inEdge.sourceNode())
  //             .map((node) => node.id)
  //         const sourceEls = conv1.filter((el) => sourceIds.includes(el.id))

  //         // const targetIds = conv1El.edges.out
  //         //     .map((outEdge) => outEdge.targetNode())
  //         //     .map((node) => node.id)
  //         //const targetEls = conv1.filter((el) => targetIds.includes(el.id))
  //         // console.log("%caddEmbeddings result", "color:yellow", {
  //         //     conv1El,
  //         //     sourceIds,
  //         //     targetIds,
  //         //     sourceEls,
  //         //     targetEls,
  //         // })

  //         const conv2Neighbourhood = [
  //             element.convResult,
  //             ...sourceEls.map((el) => el.convResult),
  //             // ...targetEls.map((el) => el.convResult),
  //         ]

  //         //console.log("conv2Neighbourhood", conv2Neighbourhood)

  //         const conv2Result = tf.tidy(() => {
  //             if (conv2Neighbourhood.length === 0) {
  //                 const result = tf.buffer([1, elementTypes.length]).values
  //                 console.log("no elements. result", result)
  //                 return result
  //             } else {
  //                 console.log("conv2Neighbourhood", conv2Neighbourhood)
  //                 const tensors = conv2Neighbourhood.map((arr) => tf.tensor(arr))
  //                 const avg =
  //                     tensors.length === 1 ? tensors[0] : tf.layers.average().apply(tensors)
  //                 return avg.dataSync()
  //             }
  //         })

  //         return {
  //             ...element,
  //             conv2Result,
  //         }
  //     })

  //     return conv2
  // }

  /**
   * Builds the graph for the selected files and stores the results in state:
   * the element types to train on (setTrainTypes), the per-type element
   * counts (setElementCounts) and the graph elements annotated with their
   * embedding vectors (setGraph).
   *
   * Only elements whose type is in the expanded training type list and that
   * have at least one outgoing edge are kept.
   *
   * @param {Object} [options]
   * @param {boolean} [options.useConv=false] - when true, each source vector
   *   is extended with the element's result from getConvolutionalElement.
   */
  const handleGetGraph = ({ useConv = false } = {}) => {
    if (selectedFiles.length === 0) {
      return
    }

    // NOTE(review): getCachedModel may return undefined for files that are
    // not loaded yet — confirm whether buildGraph tolerates that.
    const cachedModels = selectedFiles.map((file) => getCachedModel(file))

    console.log("cachedModels", { cachedModels, trainTypes })
    console.log("tf backend", tf.getBackend())

    // tf.tidy releases intermediate tensors created while building embeddings
    tf.tidy(() => {
      const graph = modelServices.buildGraph(cachedModels)

      console.log("%cgraph", "color:pink", graph)

      const { elementTypes, summary } = getElementTypesWithCount(
        graph,
        MIN_ELEMENT_COUNT
      )

      const { xsTypes, labelTypes } = getEdgePredictionXsAndLabels(graph)

      // Training types: every source type that passed the count filter,
      // plus every type those sources connect to.
      const expandedElementTypes = _.uniq(
        _.flatten(
          xsTypes
            .filter((item) => elementTypes.includes(item.source))
            .map((item) => [item.source, item.target])
        )
      )
      console.log("%cedge prediction", "color:orange", {
        xsTypes,
        labelTypes,
        ALL_ELEMENT_TYPES,
        expandedElementTypes,
      })

      setTrainTypes(expandedElementTypes)
      setElementCounts(summary)
      console.log("%ctraining element types", "color:lightgreen", {
        elementTypes,
        expandedElementTypes,
      })

      if (expandedElementTypes.length === 0) {
        enqueueSnackbar("Not enough elements to train", { variant: "info" })
        return
      }

      const connectorTypes = mlServices.getConnectorTypes()

      const graphWithVecs = graph
        .filter(
          (element) =>
            expandedElementTypes.includes(element.type) &&
            element.edges.out.length > 0
        )
        .map((element) => {
          const baseEmbeddings = getVec(
            element,
            expandedElementTypes,
            connectorTypes
          )

          const convResult = getConvolutionalElement(
            element,
            expandedElementTypes
          )

          // Source features: the element vector, optionally followed by its
          // convolution result (a new array per call when useConv is set).
          const sourceFeatures = () =>
            useConv
              ? [...baseEmbeddings.element, ...convResult]
              : baseEmbeddings.element

          const nodeAndEdgeOutVecs = element.edges.out
            .map((edge) => {
              const targetNode = edge.targetNode()
              // The target can be a relationship rather than an element;
              // such edges are dropped by the filter below.
              if (targetNode === undefined) {
                return undefined
              }

              const ev = getEdgeVec(connectorTypes, edge)
              const targetVec = mlServices.getElementVecTfBuf(
                targetNode.type,
                expandedElementTypes
              )

              return {
                info: `${element.name} (${element.type}) -> [${edge.type}] -> ${targetNode.name} (${targetNode.type})`,
                target_type: targetNode.type,
                source: sourceFeatures(),
                edge_out: ev.values,
                node_and_edge_out: [...sourceFeatures(), ...ev.values],
                target: [...targetVec.values],
              }
            })
            .filter((item) => item !== undefined)

          const nodeAndEdgeInVecs = element.edges.in.map((edge) => {
            // NOTE(review): for an incoming edge the target is presumably
            // this element itself, so `info` reads element -> element;
            // edge.sourceNode() may have been intended — confirm.
            const targetNode = edge.targetNode()
            const ev = getEdgeVec(connectorTypes, edge)
            const targetVec = mlServices.getElementVecTfBuf(
              targetNode.type,
              expandedElementTypes
            )

            return {
              info: `${element.name} (${element.type}) -> ${targetNode.name} (${targetNode.type})`,
              target_type: element.type,
              source: sourceFeatures(),
              edge_in: ev.values,
              node_and_edge_in: [...sourceFeatures(), ...ev.values],
              target: [...targetVec.values],
            }
          })

          return {
            ...element,
            embedding: {
              ...baseEmbeddings,
              element_conv: convResult,
              node_and_edges_out: nodeAndEdgeOutVecs,
              node_and_edges_in: nodeAndEdgeInVecs,
            },
          }
        })

      setGraph(graphWithVecs)
    })
  }

  /**
   * Generic change handler for form inputs. `event.target.name` is a lodash
   * property path into `values` (e.g. "edge_prediction.from"), and the value
   * at that path is replaced with `event.target.value`.
   *
   * The state is deep-cloned before `_.set`, which writes into nested objects
   * in place. With a shallow copy, nested objects such as
   * `values.edge_prediction` would be shared with, and so mutate, the current
   * React state.
   */
  const handleInputChange = (event) => {
    console.log("%chandleInputChange", "color:lightgreen", { event })

    const newValues = _.cloneDeep(values)
    _.set(newValues, event.target.name, event.target.value)
    console.log("%cnewValues", "color:lightgreen", newValues)
    setValues(newValues)
  }

  return (
    <Header title={title}>
      <Box sx={styles.pageContent}>
        <Stack sx={styles.files} gap={1}>
          {!parents && (
            <Box>
              <ParentSkeleton />
              <ParentSkeleton />
              <ParentSkeleton />
            </Box>
          )}

          {parents &&
            parents.map((parent) => (
              <Box key={`${parent.id}-${parent.type}`}>
                <Box>
                  <Typography variant="h6">{parent.name}</Typography>
                </Box>
                <List dense>
                  {parent.files.map((file) => (
                    <ListItemButton
                      key={file}
                      onClick={() => handleSelectFile(parent, file)}
                      sx={styles.listItemRoot}
                      selected={
                        selectedFiles.find(
                          (entry) =>
                            entry.parentId === parent.id &&
                            entry.fileName === file
                        ) !== undefined
                      }
                    >
                      <Typography>{file}</Typography>
                    </ListItemButton>
                  ))}
                </List>
              </Box>
            ))}
          <Divider />
          <Box>
            <Controls.Button
              text="Clear"
              onClick={() => {
                setSelectedFiles([])
                setGraph(undefined)
              }}
            />
          </Box>
        </Stack>
        <Paper
          sx={{
            display: "flex",
            flexDirection: "column",
            flex: 1,
            maxWidth: 2000,
            padding: "10px",
          }}
        >
          <ModelStats modelStats={nodePredictionStats} />
          <ModelStats modelStats={edgePredictionStats} />

          <Stack sx={styles.buttons} direction="row" gap={1}>
            <Controls.Button
              text="Create Embeddings"
              onClick={() => handleGetGraph({ useConv })}
              tooltip="Create embeddings for selected files"
              disabled={selectedFiles.length === 0 || !isActive()}
            />

            {/* <Controls.Button
                            text='Tensors'
                            onClick={() => console.log(tf.memory().numTensors)}
                        /> */}
            <Controls.Button
              text="Train Node Prediction"
              onClick={() => {
                handleTrainNodePrediction(trainTypes)
              }}
              tooltip="Train AI to predict target node from source node and edges"
              disabled={selectedFiles.length === 0 || !graph || !isActive()}
            />
            {/* <Controls.Button
                            text='Create Encoders'
                            onClick={() => {
                                handleCreateEncoders(trainTypes)
                            }}
                            tooltip="Create encoders for node + edge combinations"
                        /> */}
            <Controls.Button
              text="Train Edge Prediction"
              onClick={() => handleTrainEdgePrediction(trainTypes)}
              tooltip="Train AI to predict edge types from source and target nodes"
              disabled={selectedFiles.length === 0 || !graph || !isActive()}
            />
            <Controls.Button
              text="Create View Embeddings"
              onClick={() => handleCreateViewEmbeddings()}
              disabled={selectedFiles.length === 0 || !graph || !isActive()}
              tooltip="Create embeddings to support view similarity detection"
            />
            {/* <Controls.Button
                            text='Chart'
                            onClick={() => {
                                const series1 = Array(100)
                                    .fill(0)
                                    .map((y) => Math.random() * 100 - Math.random() * 50)
                                    .map((y, x) => ({ x, y }))

                                const series2 = Array(100)
                                    .fill(0)
                                    .map((y) => Math.random() * 100 - Math.random() * 150)
                                    .map((y, x) => ({ x, y }))

                                const series = ["First", "Second"]
                                const data = { values: [series1, series2], series }

                                const surface = { name: "Scatterplot", tab: "Charts" }
                                tfvis.render.scatterplot(surface, data)
                            }}
                        /> */}
          </Stack>
          <Box sx={{ margin: "5px", mt: "20px" }}>
            <Typography fontWeight={"bold"}>Element Types</Typography>
            {selectedFiles.length === 0 && (
              <Box sx={{ mt: "10px" }}>
                <Alert severity="info">
                  Select 1 or more files on the left to use for training
                </Alert>
              </Box>
            )}
            {selectedFiles.length > 0 && !graph && (
              <Box sx={{ mt: "10px" }}>
                <Alert severity="info">
                  Select <b>Create Embeddings</b>
                </Alert>
              </Box>
            )}
            {selectedFiles.length > 0 && graph && (
              <Alert severity="info">
                Click <b>Train Node Prediction</b> or{" "}
                <b>Train Edge Prediction</b> to train ArchiBot based on the
                selected models
              </Alert>
            )}
          </Box>
          <Box sx={{ marginLeft: "10px" }}>
            {graph &&
              elementCounts &&
              elementCounts.map((item) => (
                <Box key={item.type} sx={styles.elementTypeCount}>
                  <Box sx={styles.elementTypeLabel}>
                    <Typography variant="caption">{item.type}</Typography>
                  </Box>
                  <Box>
                    <Typography
                      variant="caption"
                      sx={{
                        color:
                          item.count < MIN_ELEMENT_COUNT && colors.grey[400],
                      }}
                    >
                      {item.count}
                    </Typography>
                  </Box>
                </Box>
              ))}
          </Box>
          <Box sx={styles.edgePrediction}>
            <Typography fontWeight="bold">Edge Prediction</Typography>
            <Controls.Select
              name="edge_prediction.from"
              label={
                <Typography variant="body2" component={"span"}>
                  From
                </Typography>
              }
              value={values.edge_prediction.from}
              options={elementTypeOptions}
              onChange={handleInputChange}
            />

            <Controls.Select
              name="edge_prediction.to"
              label={
                <Typography variant="body2" component={"span"}>
                  To
                </Typography>
              }
              value={values.edge_prediction.to}
              options={elementTypeOptions}
              onChange={handleInputChange}
            />
            <Controls.Button text="Predict Edge" onClick={handlePredictEdge} />
            {recommendations &&
              recommendations.map((r) => (
                <Box key={r.type}>
                  <Typography variant="body2">{`${r.type} ${(
                    r.match * 100
                  ).toFixed(0)}%`}</Typography>
                </Box>
              ))}
          </Box>

          {trainLog && trainLog.length > 0 && (
            <Paper sx={styles.trainingLog}>
              <Box sx={{ mt: "20px" }}>
                <Typography fontWeight={"bold"}>Training Log</Typography>
              </Box>

              <Box sx={styles.logLines}>
                {trainLog.map((line, index) => (
                  <Typography key={`${line}-${index}`} variant="caption">
                    {line}
                  </Typography>
                ))}
              </Box>
            </Paper>
          )}

          {/* <Box sx={styles.elements}>
                        {graph &&
                            graph.map((element) => (
                                <Box sx={styles.row} key={`${element.id}-ELEMENT`}>
                                    <Box sx={{ minWidth: 75 }}>
                                        <Typography variant='caption'>{element.layer}</Typography>
                                    </Box>
                                    <Box sx={styles.elementInfo}>
                                        <Typography variant='caption'>{element.name}</Typography>
                                        <Box sx={styles.rels}>
                                            {element.edges.in.map((edge) => (
                                                <Typography
                                                    variant='caption'
                                                    key={`${element.id}-${edge.id}-IN`}
                                                >
                                                    IN {edge.type}
                                                </Typography>
                                            ))}
                                        </Box>
                                        <Box sx={styles.rels}>
                                            {element.edges.out.map((edge) => (
                                                <Typography
                                                    variant='caption'
                                                    key={`${element.id}-${edge.id}-OUT`}
                                                >
                                                    OUT {edge.type}
                                                </Typography>
                                            ))}
                                        </Box>
                                        <Box>
                                            <Typography sx={styles.vec} variant='caption'>
                                                element: {element.embedding.element}
                                            </Typography>
                                        </Box>
                                        <Box>
                                            <Typography sx={styles.vec} variant='caption'>
                                                edges_in: {element.embedding.edges_in}
                                            </Typography>
                                        </Box>
                                        <Box>
                                            <Typography sx={styles.vec} variant='caption'>
                                                edges_out: {element.embedding.edges_out}
                                            </Typography>
                                        </Box>

                                        <Box>
                                            <Typography variant='caption'>
                                                {Object.keys(element.embedding).join(", ")}
                                            </Typography>
                                        </Box>
                                    </Box>
                                </Box>
                            ))}
                    </Box> */}
        </Paper>
      </Box>
    </Header>
  )
}

const ModelStats = (props) => {
  const { modelStats } = props

  const asPercent = (num) => (num * 100).toFixed(2)

  const splitWordBasedOnCamelCase = (word) => {
    return word
      ?.replace(/([a-z0-9])([A-Z])/g, "$1 $2")
      .replace(/([A-Z]+)([A-Z][a-z0-9])/g, "$1 $2")
  }

  return (
    <>
      {modelStats && (
        <Stack sx={{ marginTop: "15px" }}>
          <Box>
            <Typography variant="body2" sx={{ fontWeight: "bold", mb: "5px" }}>
              {splitWordBasedOnCamelCase(modelStats.type)}
            </Typography>
          </Box>
          <Typography>
            Loss: {modelStats?.train_meta?.loss.toFixed(5)}
          </Typography>
          <Typography>
            Accuracy: {asPercent(modelStats?.train_meta?.accuracy)}%
          </Typography>
          <Typography variant="caption">
            Elements: {modelStats?.train_meta?.element_types?.join(", ")}
          </Typography>
        </Stack>
      )}
    </>
  )
}

const ParentSkeleton = (props) => {
  return (
    <Box>
      <Skeleton
        variant="rect"
        width="290px"
        height={20}
        sx={{ marginTop: "15px", marginBottom: "5px" }}
      />
      <FileSkeleton />
      <FileSkeleton />
      <FileSkeleton />
    </Box>
  )
}

const FileSkeleton = (props) => {
  return (
    <Skeleton
      sx={{ marginLeft: "20px", marginTop: "10px" }}
      variant="rect"
      width="270px"
      height={20}
    />
  )
}

export default ML
