import warnings

import numpy as np

from .utils import norm, zeros


def lanczos(A, q_1, k=None, reorthogonalize=False, beta_tol=0):
    """Run the Lanczos iteration on a symmetric/Hermitian ``A``.

    See Stability of Lanczos Method page 5.

    Builds orthonormal Lanczos vectors ``Q`` starting from ``q_1`` together
    with the diagonal ``alpha`` and off-diagonal ``beta`` of a symmetric
    tridiagonal ``T`` such that ``A @ Q = Q @ T + next_q e_k^T``.

    Parameters
    ----------
    A : square matrix of shape (n, n)
        Only ``A.shape``, its dtype and ``A @ vector`` are used, so sparse
        matrices are accepted. Symmetry (Hermitian-ness) is assumed, not
        checked.
    q_1 : vector of length n
        Starting vector; normalized internally. Must be nonzero.
    k : int, optional
        Number of Lanczos vectors to build. Defaults to ``n``.
    reorthogonalize : bool
        Not supported yet; ``True`` fails an assertion before any work.
    beta_tol : float
        The iteration stops early when an off-diagonal entry is
        ``<= beta_tol``.

    Returns
    -------
    Q : array of shape (n, m)
        Lanczos vectors; ``m == k`` unless the iteration stopped early.
    (alpha, beta) : tuple of arrays of lengths m and m - 1
        Diagonal and off-diagonal of ``T`` (``beta`` holds beta_2, beta_3, ...).
    next_q : array of shape (n,)
        Residual vector; all zeros when the iteration stopped early or
        when ``k == n``.

    Raises
    ------
    ValueError
        If ``k < 1`` or ``q_1`` has zero norm.
    """
    # assert np.allclose(A, A.T)  # can't do this with sparse matrix
    n = len(q_1)
    assert A.shape[0] == A.shape[1] == n

    if k is None:
        k = n
    if k < 1:
        raise ValueError(f"k must be at least 1, got {k}")
    # Reject the unsupported option up front instead of after the first step.
    assert not reorthogonalize, "We are doing reorthogonalization all wrong. Don't use it for now"

    result_type = np.result_type(A, q_1)
    if np.issubdtype(result_type, np.integer):
        result_type = np.float64
    Q = np.empty((n, k), dtype=result_type)
    alpha = np.empty(k, dtype=result_type)
    beta = np.empty(k - 1, dtype=result_type)  # this is really beta_2, beta_3, ...

    q_norm = norm(q_1)
    if q_norm == 0:
        raise ValueError("q_1 must be a nonzero vector")
    Q[:, 0] = q_1 / q_norm
    next_q = A @ Q[:, 0]
    # Conjugate the basis vector so alpha = q^H A q is right for complex
    # Hermitian A; for real data .conj() is a no-op.
    alpha[0] = Q[:, 0].conj() @ next_q
    next_q -= alpha[0] * Q[:, 0]

    for i in range(1, k):
        # Compare the (real) norm before storing it: beta may be complex-typed,
        # and ordering comparisons on complex values are not meaningful.
        b = norm(next_q)
        if b <= beta_tol:
            # TODO: in some cases we may just want to continue from a new vector
            return (
                np.atleast_2d(Q[:, :i]),
                (alpha[:i], beta[: (i - 1)]),
                zeros(n, dtype=result_type),
            )
        beta[i - 1] = b
        # Column index i is the (i+1)-th Lanczos vector; once i reaches n the
        # basis already spans the whole space, so a nonzero beta is round-off.
        if i >= n:
            warnings.warn(
                f"Lanczos iteration greater than dimension ({n}) due to numerical error. Try increasing beta_tol ({beta_tol}) or precision.",
                stacklevel=2,
            )

        Q[:, i] = next_q / beta[i - 1]

        next_q = A @ Q[:, i]
        alpha[i] = Q[:, i].conj() @ next_q
        next_q -= alpha[i] * Q[:, i]
        next_q -= beta[i - 1] * Q[:, i - 1]

        if reorthogonalize:
            # Only reachable when assertions are disabled (python -O).
            # Two passes of classical Gram-Schmidt against all previous vectors.
            basis = Q[:, : (i + 1)]
            for _ in range(2):
                next_q -= basis @ (basis.conj().T @ next_q)

    if k == n:
        next_q = zeros(n, dtype=result_type)
    # AQ = QT + next_q e_k.T where T = SymmetricTridiagonal(a, b)
    return Q, (alpha, beta), next_q
