import { Terms } from '@blockscholes/ql'
import { syntaxTree } from '@codemirror/language'
import { EditorState } from '@codemirror/state'
import { SyntaxNode, Tree } from '@lezer/common'
import { Diagnostic } from '@codemirror/lint'
import {
  ParseTreeNode,
  ParseTreeRoot,
  IHistoricalAnalyzerDataSource,
  ValueType,
} from './types'
import { getFunction } from './language'

/**
 * Follows `path` from `node` using a lezer tree cursor and returns the node
 * the cursor is on once every step has been matched, or `null` if the cursor
 * stops moving first.
 *
 * The start node's own type id is treated as an implicit first step, so it
 * always matches. For each step, when the current cursor node matches (by
 * numeric type id OR by type name) the cursor advances with `next()`, but only
 * if further steps remain. When it does not match, the cursor advances with
 * `nextSibling()`.
 *
 * @param node node to start walking from
 * @param path node type ids (e.g. values from `Terms`) or node names to match
 * @returns the node matching the final step, or `null`
 */
function walkThrough(
  node: SyntaxNode,
  ...path: (number | string)[]
): SyntaxNode | null {
  const walker = node.cursor
  const steps: (number | string)[] = [walker.type.id, ...path]
  let matched = 0
  let moving = true

  while (moving && matched < steps.length) {
    const wanted = steps[matched]
    const isMatch = walker.type.id === wanted || walker.type.name === wanted
    if (!isMatch) {
      // Not the node we want at this depth: try the next sibling.
      moving = walker.nextSibling()
      continue
    }
    matched += 1
    // Only move on when there is still a step left to satisfy.
    if (matched < steps.length) {
      moving = walker.next()
    }
  }

  return matched >= steps.length ? walker.node : null
}

/**
 * Flattens a chain of nested `recursiveNode` children into its leaf nodes.
 *
 * Starting at `parentNode`, this repeatedly descends into the first child of
 * type `recursiveNode`. At every level visited (including `parentNode` itself)
 * the level's last child is collected when its type id is `leaf`. Leaves are
 * returned deepest level first. For a left-recursive list rule (presumably how
 * e.g. `FunctionCallArgs` is defined in the grammar; verify), that is source
 * order.
 *
 * @param parentNode node to start from; `null` yields an empty array
 * @param recursiveNode type id of the self-nesting list node
 * @param leaf type id of the item nodes to collect
 * @returns the collected leaf nodes
 */
function retrieveAllRecursiveNodes(
  parentNode: SyntaxNode | null,
  recursiveNode: number,
  leaf: number,
): SyntaxNode[] {
  // Gather the levels top-down: parent, its nested list child, and so on.
  const levels: SyntaxNode[] = []
  let current: SyntaxNode | null | undefined = parentNode
  while (current) {
    levels.push(current)
    const nested: SyntaxNode | null = current.getChild(recursiveNode)
    current = nested && nested.type.id === recursiveNode ? nested : null
  }

  // Read the tails bottom-up so the innermost item comes first.
  const leaves: SyntaxNode[] = []
  for (const level of levels.reverse()) {
    const tail = level.lastChild
    if (tail && tail.type.id === leaf) {
      leaves.push(tail)
    }
  }
  return leaves
}

// TODO: make into method to avoid passing through sources and state
/**
 * Infers the value type of an expression node without recording diagnostics.
 *
 * - `Expr` wrappers and unary expressions take the type of their operand.
 * - Function calls take the `returnType` that `getFunction` reports for the
 *   function token found at `node.firstChild.firstChild`.
 * - Identifiers are looked up by their source text among `sources`. A
 *   matching source whose `type` is scalar yields scalar; anything else,
 *   including no match, yields vector.
 * - Number literals are scalar.
 * - Binary expressions are scalar only when both operands are scalar,
 *   otherwise vector.
 * - A missing node or any other node type yields `ValueType.none`.
 *
 * @param node node to type; `null` yields `ValueType.none`
 * @param sources data sources used to resolve identifier types
 * @param state editor state used to read identifier text
 * @returns the inferred type
 */
export function getType(
  node: SyntaxNode | null,
  sources: Array<IHistoricalAnalyzerDataSource>,
  state: EditorState,
): ValueType {
  if (node == null) {
    return ValueType.none
  }
  const typeOf = (child: SyntaxNode | null): ValueType =>
    getType(child, sources, state)
  const kind = node.type.id

  if (kind === Terms.Expr) {
    return typeOf(node.firstChild)
  }
  if (kind === Terms.FunctionCall) {
    const funcToken = node.firstChild?.firstChild
    return funcToken
      ? getFunction(funcToken.type.id).returnType
      : ValueType.none
  }
  if (kind === Terms.Identifier) {
    const label = state.sliceDoc(node.from, node.to)
    // FIXME: we can assume that the vast majority, if not all, identifiers are time series vectors
    // but we need a more robust solution
    const match = sources.find((source) => source.label === label)
    if (match?.type === ValueType.scalar) {
      return ValueType.scalar
    }
    return ValueType.vector
  }
  if (kind === Terms.NumberLiteral) {
    return ValueType.scalar
  }
  if (kind === Terms.UnaryExpression) {
    return typeOf(walkThrough(node, Terms.Expr))
  }
  if (kind === Terms.BinaryExpression) {
    // Both operands are always evaluated, left then right.
    const left = typeOf(node.firstChild)
    const right = typeOf(node.lastChild)
    const bothScalar =
      left === ValueType.scalar && right === ValueType.scalar
    return bothScalar ? ValueType.scalar : ValueType.vector
  }
  return ValueType.none
}

/**
 * Creates a childless parse-tree node.
 *
 * @param type the node's name
 * @param isError whether the node represents a parse error
 * @returns a new node with its own, empty `children` array
 */
function makeParseTreeNode(type: string, isError: boolean): ParseTreeNode {
  const created: ParseTreeNode = { type, isError, children: [] }
  return created
}

/**
 * Validates the query held in a CodeMirror `EditorState`.
 *
 * All work happens in the constructor. It type-checks the syntax tree,
 * collecting diagnostics, and builds a simplified `ParseTreeRoot` copy of the
 * tree. The instance is a snapshot of `state` at construction time.
 */
export class Parser {
  /** Simplified copy of the lezer syntax tree. */
  public readonly parseTree: ParseTreeRoot
  /** Type inferred for the first child of the syntax tree's top node. */
  public readonly rootType: ValueType

  private readonly tree: Tree
  private readonly state: EditorState
  private readonly diagnostics: Diagnostic[]
  private readonly sources: Array<IHistoricalAnalyzerDataSource>

  constructor(
    state: EditorState,
    sources: Array<IHistoricalAnalyzerDataSource>,
  ) {
    this.tree = syntaxTree(state)
    this.state = state
    this.sources = sources
    this.diagnostics = []
    // Type-checking pushes into `this.diagnostics`, so it must run after the
    // array is initialised.
    this.rootType = this.checkParseTree(this.tree.topNode.firstChild)
    this.parseTree = this.parse()
  }

  /**
   * Returns the collected diagnostics ordered by start offset.
   *
   * A sorted copy is returned so that calling this neither reorders the
   * parser's internal list nor lets callers mutate it.
   */
  getDiagnostics(): Diagnostic[] {
    return [...this.diagnostics].sort((a, b) => a.from - b.from)
  }

  /** True when no diagnostics were recorded and the tree has no error nodes. */
  isValid(): boolean {
    return this.diagnostics.length === 0 && !this.parseTree.treeInError
  }

  /** The full document text the parser was built from. */
  inputString(): string {
    return this.state.doc.toString()
  }

  /**
   * Recursively validates `node`, recording diagnostics, and returns its
   * inferred type as computed by `getType`.
   *
   * @param node node to check; `null` yields `ValueType.none`
   */
  checkParseTree(node: SyntaxNode | null): ValueType {
    if (!node) {
      return ValueType.none
    }
    switch (node.type.id) {
      case Terms.Expr:
        return this.checkParseTree(node.firstChild)
      case Terms.FunctionCall:
        this.checkCallFunction(node)
        break
      case Terms.BinaryExpression:
        this.checkBinaryExpr(node)
        break
      case Terms.UnaryExpression: {
        const unaryExprType = this.checkParseTree(walkThrough(node, Terms.Expr))
        if (
          unaryExprType !== ValueType.scalar &&
          unaryExprType !== ValueType.vector
        ) {
          this.addDiagnostic(
            node,
            `unary expression only allowed on expressions of type scalar or vector, got ${unaryExprType}`,
          )
        }
        break
      }
      default: {
        break
      }
    }
    return getType(node, this.sources, this.state)
  }

  /**
   * Checks a function call's argument count against its signature, then
   * validates each argument expression.
   */
  private checkCallFunction(node: SyntaxNode): void {
    const funcID = node.firstChild?.firstChild
    if (!funcID) {
      this.addDiagnostic(node, 'function not defined')
      return
    }

    const args = retrieveAllRecursiveNodes(
      walkThrough(node, Terms.FunctionCallBody),
      Terms.FunctionCallArgs,
      Terms.Expr,
    )
    const funcSignature = getFunction(funcID.type.id)
    const nargs = funcSignature.minArgs

    if (!funcSignature.variadic) {
      if (args.length !== nargs) {
        this.addDiagnostic(
          node,
          `expected ${nargs} argument(s) in call to "${funcSignature.name}", got ${args.length}`,
        )
      }
    } else if (args.length < nargs) {
      this.addDiagnostic(
        node,
        `expected at least ${nargs} argument(s) in call to "${funcSignature.name}", got ${args.length}`,
      )
    }

    // Validate every argument so that errors nested inside a call (e.g. a
    // wrong arity in an inner call) are reported, as they already are for
    // operands of unary and binary expressions.
    // TODO: this code needs updated slightly to match the eventual BSLang type
    // system. Argument types should also be checked against
    // `funcSignature.supportedArgTypes`, e.g. via `this.expectType(arg, ...)`.
    // `expectType` itself runs `checkParseTree`, so replace this loop rather
    // than adding to it, or diagnostics will be duplicated. The earlier draft
    // computed a per-position index (clamped to the last entry for variadic
    // functions) but never used it.
    for (const arg of args) {
      this.checkParseTree(arg)
    }
  }

  /** Checks that both operands exist and are scalar or vector. */
  private checkBinaryExpr(node: SyntaxNode): void {
    const lExpr = node.firstChild
    const rExpr = node.lastChild
    if (!lExpr || !rExpr) {
      this.addDiagnostic(
        node,
        'left or right expression is missing in binary expression',
      )
      return
    }
    const lt = this.checkParseTree(lExpr)
    const rt = this.checkParseTree(rExpr)

    if (lt !== ValueType.scalar && lt !== ValueType.vector) {
      this.addDiagnostic(
        lExpr,
        'binary expression must contain only scalar and vector types',
      )
    }
    if (rt !== ValueType.scalar && rt !== ValueType.vector) {
      this.addDiagnostic(
        rExpr,
        'binary expression must contain only scalar and vector types',
      )
    }
  }

  /**
   * Validates `node` and records a diagnostic if its type is not in `want`.
   *
   * @param context phrase describing where the expression appears, used in the
   *   diagnostic message
   */
  private expectType(
    node: SyntaxNode,
    want: ValueType[],
    context: string,
  ): void {
    const t = this.checkParseTree(node)
    if (!want.includes(t)) {
      this.addDiagnostic(
        node,
        `expected type ${want.join(', ')} in ${context}, got ${t}`,
      )
    }
  }

  /** Records an error diagnostic spanning `node`. */
  private addDiagnostic(node: SyntaxNode, msg: string): void {
    this.diagnostics.push({
      severity: 'error',
      message: msg,
      from: node.from,
      to: node.to,
    })
  }

  /**
   * Copies the lezer tree into a `ParseTreeRoot`. `treeInError` is set when
   * any copied node is an error node. Identifier, number literal and unary
   * operator nodes also carry their source text in `content`.
   */
  private parse(): ParseTreeRoot {
    const head = this.tree.topNode
    const programNode = makeParseTreeNode(head.name, head.type.isError)
    const root: ParseTreeRoot = {
      treeInError: programNode.isError,
      program: programNode,
    }

    // Copies `first` and each following sibling under `parent`. Siblings are
    // walked in a loop and only children recurse, so stack depth tracks tree
    // depth rather than the length of sibling runs.
    const copyRun = (parent: ParseTreeNode, first: SyntaxNode | null): void => {
      for (let node: SyntaxNode | null = first; node; node = node.nextSibling) {
        // A node without a name ends the run (this matches the previous
        // recursive version, which stopped there).
        if (!node.name) {
          return
        }

        const newNode = makeParseTreeNode(node.name, node.type.isError)
        if (newNode.isError) {
          root.treeInError = true
        }

        // TODO: use node type here rather than strings
        if (
          node.name === 'Identifier' ||
          node.name === 'NumberLiteral' ||
          node.name === 'UnaryOp'
        ) {
          newNode.content = this.state.sliceDoc(node.from, node.to)
        }

        parent.children.push(newNode)
        copyRun(newNode, node.firstChild)
      }
    }

    copyRun(root.program, head.firstChild)

    return root
  }
}
