/*
 * pyspig.cpp: a Python extension which exposes the more sensible end
 * of spigot's internal API.
 */

#include <string>

#include <Python.h>
#include "structmember.h"

#include "spigot.h"
#include "funcs.h"
#include "expr.h"
#include "error.h"
#include "baseout.h"

/*
 * Lifecycle of a SpigotPy object:
 *   SS_EMPTY      - no number loaded (spig == NULL)
 *   SS_FILLED     - spig holds a number; no generator started yet
 *   SS_FORMATTING - og produces formatted text (after base() / ieee())
 *   SS_CFRAC      - cfg produces continued fraction terms
 * Keep SS_EMPTY equal to zero so that zero-filled object memory reads
 * as the empty state.
 */
enum SpigState {
    SS_EMPTY = 0,
    SS_FILLED = 1,
    SS_FORMATTING = 2,
    SS_CFRAC = 3
};

/*
 * Instance layout of the spigot.internal.Spigot Python type. The three
 * C++ pointers are owned by the object and are deleted/replaced by
 * spigot_emplace(); 'state' records which of them is in use.
 */
typedef struct {
    PyObject_HEAD
    Spigot *spig;              /* the number itself; NULL in SS_EMPTY */
    CfracGenerator *cfg;       /* created by cfracterm(), used in SS_CFRAC */
    OutputGenerator *og;       /* created by base()/ieee(), used in SS_FORMATTING */
    bool og_finished;          /* set by readfmt() once og reports no more output */
    SpigState state;
} SpigotPy;

/*
 * tp_new for Spigot: allocate an instance in the SS_EMPTY state with no
 * number and no generators attached. Every field is initialised
 * explicitly (including og_finished) rather than relying on tp_alloc
 * having zero-filled the memory. Returns NULL if tp_alloc fails.
 */
static PyObject *Spigot_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
{
    SpigotPy *self = (SpigotPy *)type->tp_alloc(type, 0);
    if (self == NULL)
        return NULL;

    self->spig = NULL;
    self->cfg = NULL;
    self->og = NULL;
    self->og_finished = false;
    self->state = SS_EMPTY;

    return (PyObject *)self;
}

/*
 * C++ exception used to unwind out of spigot computation when a Python
 * exception has become pending. The thrower has already set the Python
 * error indicator, so catchers simply return NULL to Python.
 */
struct python_error {
};

/*
 * Give Python a chance to run its signal handlers; if one of them
 * raised (PyErr_CheckSignals returned non-zero), abandon the current
 * computation by throwing python_error. This function is non-static
 * and not called in this file, so presumably it is a hook invoked from
 * spigot's core code -- TODO confirm.
 */
void spigot_check_exception()
{
    if (PyErr_CheckSignals() != 0)
        throw python_error();
}

/*
 * Make 'self' own 'spig' (which may be NULL), destroying everything it
 * previously owned: the old number plus any continued fraction or
 * output generator. The state becomes SS_FILLED for a non-NULL spig
 * and SS_EMPTY otherwise. A NULL 'self' is ignored.
 */
static void spigot_emplace(SpigotPy *self, Spigot *spig)
{
    if (!self)
        return;

    /* delete of a null pointer is a no-op, so no guards are needed */
    delete self->cfg;
    delete self->spig;
    delete self->og;

    self->spig = spig;
    self->cfg = NULL;
    self->og = NULL;
    self->state = spig ? SS_FILLED : SS_EMPTY;
}

/*
 * tp_dealloc: destroy the owned C++ objects (by emplacing NULL), then
 * return the object's memory through its type's tp_free.
 */
static void Spigot_dealloc(SpigotPy *self)
{
    PyTypeObject *type = self->ob_type;

    spigot_emplace(self, NULL);
    type->tp_free((PyObject *)self);
}

static PyObject *Spigot_parse(SpigotPy *self, PyObject *args);
static PyObject *Spigot_clone(SpigotPy *self, PyObject *args);
static PyObject *Spigot_base(SpigotPy *self, PyObject *args, PyObject *kwds);
static PyObject *Spigot_ieee(SpigotPy *self, PyObject *args, PyObject *kwds);
static PyObject *Spigot_readfmt(SpigotPy *self, PyObject *args);
static PyObject *Spigot_cfracterm(SpigotPy *self, PyObject *args);
static PyObject *Spigot_sign(SpigotPy *self, PyObject *args);

/*
 * Method table for spigot.internal.Spigot.
 *
 * base() and ieee() have the three-argument (self, args, kwds)
 * signature and parse their arguments with PyArg_ParseTupleAndKeywords,
 * so they must be registered with METH_VARARGS | METH_KEYWORDS. With
 * METH_VARARGS alone, Python calls them with only (self, args),
 * leaving 'kwds' indeterminate, and rejects keyword arguments outright.
 */
static PyMethodDef Spigot_methods[] = {
    {"parse", (PyCFunction)Spigot_parse, METH_VARARGS,
     "Initialise a spigot by parsing an expression."
    },
    {"clone", (PyCFunction)Spigot_clone, METH_VARARGS,
     "Initialise a spigot by cloning from another spigot."
    },
    {"base", (PyCFunction)Spigot_base, METH_VARARGS | METH_KEYWORDS,
     "Prepare a spigot for formatting data in a positional base."
    },
    {"ieee", (PyCFunction)Spigot_ieee, METH_VARARGS | METH_KEYWORDS,
     "Prepare a spigot for formatting data as an (optionally-extended)"
     " IEEE hex representation."
    },
    {"readfmt", (PyCFunction)Spigot_readfmt, METH_NOARGS,
     "Read some data from a spigot that is in formatting mode."
    },
    {"cfracterm", (PyCFunction)Spigot_cfracterm, METH_NOARGS,
     "Read a continued fraction term from a spigot."
    },
    {"sign", (PyCFunction)Spigot_sign, METH_NOARGS,
     "Return the sign (-1, +1, or possibly 0) of the number in a spigot."
    },
    {NULL}  /* Sentinel */
};

/* No data members are exposed directly to Python. */
static PyMemberDef Spigot_members[] = {
    {NULL}  /* Sentinel */
};

/*
 * Python type object for spigot.internal.Spigot, using the Python 2
 * positional initialiser layout. Instances are SpigotPy structs and all
 * functionality beyond construction/destruction is exposed through
 * Spigot_methods.
 *
 * tp_flags is plain Py_TPFLAGS_DEFAULT: there is no Py_TPFLAGS_BASETYPE,
 * so the type cannot be subclassed from Python, and no
 * Py_TPFLAGS_HAVE_GC, so instances are not tracked by the cyclic GC.
 */
static PyTypeObject spigot_SpigotType = {
    PyObject_HEAD_INIT(NULL)
    0,                         /*ob_size*/
    "spigot.internal.Spigot",  /*tp_name*/
    sizeof(SpigotPy),          /*tp_basicsize*/
    0,                         /*tp_itemsize*/
    (destructor)Spigot_dealloc,        /*tp_dealloc*/
    0,                         /*tp_print*/
    0,                         /*tp_getattr*/
    0,                         /*tp_setattr*/
    0,                         /*tp_compare*/
    0,                         /*tp_repr*/
    0,                         /*tp_as_number*/
    0,                         /*tp_as_sequence*/
    0,                         /*tp_as_mapping*/
    0,                         /*tp_hash */
    0,                         /*tp_call*/
    0,                         /*tp_str*/
    0,                         /*tp_getattro*/
    0,                         /*tp_setattro*/
    0,                         /*tp_as_buffer*/
    Py_TPFLAGS_DEFAULT,        /*tp_flags*/
    "An object containing a spigot representation of a real number", /* tp_doc */
    0,		               /* tp_traverse */
    0,		               /* tp_clear */
    0,		               /* tp_richcompare */
    0,		               /* tp_weaklistoffset */
    0,		               /* tp_iter */
    0,		               /* tp_iternext */
    Spigot_methods,            /* tp_methods */
    Spigot_members,            /* tp_members */
    0,                         /* tp_getset */
    0,                         /* tp_base */
    0,                         /* tp_dict */
    0,                         /* tp_descr_get */
    0,                         /* tp_descr_set */
    0,                         /* tp_dictoffset */
    0,                         /* tp_init */
    0,                         /* tp_alloc */
    Spigot_new,                /* tp_new */
};

/*
 * GlobalScope adapter that lets expr_parse resolve free variable names
 * against Python objects.
 *
 * 'scope' is a borrowed reference and may be NULL (no variables). It
 * may be:
 *  - a mapping: the name is looked up as a string key;
 *  - a callable: it is called with the name as its only argument;
 *  - a sequence of any of these, searched in order (first match wins).
 * A matched value must be a spigot.internal.Spigot, or an object whose
 * 'sp' attribute is one. lookup() returns a fresh clone of that
 * number, or NULL if the name is not found or the value is unusable.
 *
 * Every new reference obtained here is released. Any Python exception
 * raised during a lookup is cleared and treated as "not found", so no
 * exception is left pending when control returns to the parser.
 * FIXME: it would be nicer to propagate such exceptions, and to raise
 * TypeError for unsupported scope or value types.
 */
struct PythonGlobalScope : GlobalScope {
    PyObject *scope;

    PythonGlobalScope(PyObject *ascope) : scope(ascope) { }

    /*
     * If 'obj' (borrowed) is an internal Spigot holding a number,
     * return a clone of it; otherwise return NULL. An empty Spigot
     * (spig == NULL) yields NULL instead of being dereferenced.
     */
    static Spigot *clone_if_spigot(PyObject *obj) {
        int is_spigot = PyObject_IsInstance(obj,
                                            (PyObject *)&spigot_SpigotType);
        if (is_spigot < 0)
            PyErr_Clear();     /* the type check itself failed */
        if (is_spigot <= 0)
            return NULL;
        Spigot *spig = ((SpigotPy *)obj)->spig;
        return spig ? spig->clone() : NULL;
    }

    /*
     * Turn a looked-up value (borrowed) into a cloned Spigot: either the
     * value is an internal Spigot itself, or it is the Python-side
     * wrapper, which keeps its internal Spigot in an 'sp' attribute.
     */
    static Spigot *spigot_from_value(PyObject *value) {
        Spigot *ret = clone_if_spigot(value);
        if (ret)
            return ret;

        PyObject *sp = PyObject_GetAttrString(value, "sp");  /* new ref */
        if (!sp) {
            PyErr_Clear();
            return NULL;
        }
        ret = clone_if_spigot(sp);
        Py_DECREF(sp);
        return ret;
    }

    static Spigot *lookup_recursive(PyObject *scope, const char *varname) {
        if (!scope)
            return NULL;

        /*
         * Strings are sequences too, but the single item of a
         * one-character string is an equal string, so recursing into a
         * string never terminates. They are not a valid scope anyway.
         */
        if (PyString_Check(scope) || PyUnicode_Check(scope))
            return NULL;

        if (PySequence_Check(scope)) {
            Py_ssize_t n = PySequence_Length(scope);
            if (n < 0) {
                PyErr_Clear();
                return NULL;
            }
            for (Py_ssize_t i = 0; i < n; i++) {
                PyObject *ith = PySequence_GetItem(scope, i);  /* new ref */
                if (!ith) {
                    PyErr_Clear();
                    continue;
                }
                Spigot *ret = lookup_recursive(ith, varname);
                Py_DECREF(ith);
                if (ret)
                    return ret;
            }
            return NULL;
        }

        PyObject *found;   /* new reference, or NULL */
        if (PyMapping_Check(scope)) {
            if (PyMapping_HasKeyString(scope, (char *)varname))
                found = PyMapping_GetItemString(scope, (char *)varname);
            else
                found = NULL;
        } else if (PyCallable_Check(scope)) {
            found = PyObject_CallFunction(scope, "s", varname);
        } else {
            /* FIXME: it would be nicer to throw a Python TypeError here */
            return NULL;
        }

        if (!found) {
            /* e.g. the callable raised; don't leave the error pending */
            PyErr_Clear();
            return NULL;
        }

        Spigot *ret = spigot_from_value(found);
        Py_DECREF(found);
        return ret;
    }

    virtual Spigot *lookup(const char *varname) {
        return lookup_recursive(scope, varname);
    }
};

/*
 * Spigot.parse(expr[, scope]): replace this spigot's contents with the
 * number described by the expression string 'expr'. The optional
 * 'scope' supplies variables via PythonGlobalScope (a mapping, a
 * callable taking the name, or a sequence of those).
 *
 * On a spigot_error, raises ValueError carrying the parser's message.
 * On python_error, a Python exception is already set and is propagated.
 * In both cases the previous contents are left untouched, because
 * spigot_emplace only runs after expr_parse has returned.
 */
static PyObject *Spigot_parse(SpigotPy *self, PyObject *args)
{
    const char *expr = NULL;
    PyObject *pyscope = NULL;

    if (!PyArg_ParseTuple(args, "s|O", &expr, &pyscope))
        return NULL;

    try {
        PythonGlobalScope scope(pyscope);
        spigot_emplace(self, expr_parse(expr, &scope));
    } catch (const spigot_error &err) {
        PyErr_SetString(PyExc_ValueError, err.errmsg);
        return NULL;
    } catch (const python_error &) {
        return NULL;
    }

    Py_RETURN_NONE;
}

/*
 * Spigot.clone(other): replace this spigot's contents with an
 * independent copy of other's number, or make it empty if other is
 * empty. Any formatting or continued fraction state in self is
 * discarded by spigot_emplace.
 */
static PyObject *Spigot_clone(SpigotPy *self, PyObject *args)
{
    PyObject *source_obj = NULL;

    if (!PyArg_ParseTuple(args, "O!", &spigot_SpigotType, &source_obj))
        return NULL;

    Spigot *source = ((SpigotPy *)source_obj)->spig;
    Spigot *copy = NULL;
    if (source != NULL)
        copy = source->clone();
    spigot_emplace(self, copy);

    Py_RETURN_NONE;
}

static char *base_keywords[] = {
    "base", "digitlimit", "rmode", "minintdigits", "uppercase", NULL
};

/*
 * Spigot.base([base, digitlimit, rmode, minintdigits, uppercase]):
 * move a spigot from SS_FILLED to SS_FORMATTING by creating a
 * base_format output generator over a clone of its number. readfmt()
 * then yields the text. When rmode is left at its default of -1, the
 * rounding flag passed to base_format is false.
 *
 * Raises RuntimeError if the spigot is empty or is already formatting
 * or producing continued fraction terms.
 */
static PyObject *Spigot_base(SpigotPy *self, PyObject *args, PyObject *kwds)
{
    int base = 10;
    int digitlimit = 0;
    int rmode = -1;
    int minintdigits = 0;
    int uppercase = 0;

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|iiiii", base_keywords,
                                     &base, &digitlimit, &rmode,
                                     &minintdigits, &uppercase))
        return NULL;

    switch (self->state) {
      case SS_FILLED:
        break;
      case SS_EMPTY:
        PyErr_SetString(PyExc_RuntimeError, "spigot not initialised");
        return NULL;
      default:
        PyErr_SetString(PyExc_RuntimeError, "spigot already formatting");
        return NULL;
    }

    assert(!self->og);
    bool have_rmode = (rmode != -1);
    self->og = base_format(self->spig->clone(), base, uppercase,
                           have_rmode, digitlimit,
                           (RoundingMode)rmode, minintdigits);
    self->og_finished = false;
    self->state = SS_FORMATTING;

    Py_RETURN_NONE;
}

static char *ieee_keywords[] = {
    "bits", "digitlimit", "rmode", NULL
};

/*
 * Spigot.ieee([bits, digitlimit, rmode]): move a spigot from SS_FILLED
 * to SS_FORMATTING by creating an ieee_format output generator (default
 * 64 bits) over a clone of its number. readfmt() then yields the text.
 * When rmode is left at its default of -1, the rounding flag passed to
 * ieee_format is false.
 *
 * Raises RuntimeError if the spigot is empty or is already formatting
 * or producing continued fraction terms.
 */
static PyObject *Spigot_ieee(SpigotPy *self, PyObject *args, PyObject *kwds)
{
    int ieee_bits = 64;
    int digitlimit = 0;
    int rmode = -1;

    if (!PyArg_ParseTupleAndKeywords(args, kwds, "|iii", ieee_keywords,
                                     &ieee_bits, &digitlimit, &rmode))
        return NULL;

    switch (self->state) {
      case SS_FILLED:
        break;
      case SS_EMPTY:
        PyErr_SetString(PyExc_RuntimeError, "spigot not initialised");
        return NULL;
      default:
        PyErr_SetString(PyExc_RuntimeError, "spigot already formatting");
        return NULL;
    }

    assert(!self->og);
    bool have_rmode = (rmode != -1);
    self->og = ieee_format(self->spig->clone(), ieee_bits,
                           have_rmode, digitlimit, (RoundingMode)rmode);
    self->og_finished = false;
    self->state = SS_FORMATTING;

    Py_RETURN_NONE;
}

/*
 * Spigot.readfmt(): return the next non-empty chunk of formatted output
 * from a spigot in SS_FORMATTING, or "" once the output generator has
 * reported that it has nothing more to produce (remembered in
 * og_finished, so later calls return "" immediately).
 *
 * Raises RuntimeError if the spigot is not formatting, and ValueError
 * if evaluation throws a spigot_error. Both C++ exception types are
 * caught here so that neither unwinds into Python's C frames. On
 * python_error, a Python exception is already set, so NULL is returned.
 */
static PyObject *Spigot_readfmt(SpigotPy *self, PyObject *args)
{
    if (self->state != SS_FORMATTING) {
        PyErr_SetString(PyExc_RuntimeError, "spigot not formatting");
        return NULL;
    }

    if (self->og_finished)
        return PyString_FromString("");

    try {
        std::string out;
        while (true) {
            if (!self->og->get_definite_output(out)) {
                self->og_finished = true;
                return PyString_FromString("");
            }
            /* sized constructor: don't truncate at an embedded NUL */
            if (!out.empty())
                return PyString_FromStringAndSize(out.data(),
                                                  (Py_ssize_t)out.size());
        }
    } catch (const spigot_error &err) {
        PyErr_SetString(PyExc_ValueError, err.errmsg);
        return NULL;
    } catch (const python_error &) {
        return NULL;
    }
}

/*
 * Convert a spigot bigint into a new Python long, returning whatever
 * PyLong_FromString returns (NULL on failure).
 *
 * The value travels as hexadecimal text because no documented part of
 * the Python C API builds a long from raw binary data; hex ASCII is the
 * cheapest interchange format available. bigint_hexstring's buffer is
 * released with free() once Python has parsed it.
 */
static PyObject *bigint_to_pylong(const bigint &n)
{
    char *digits = bigint_hexstring(n);
    PyObject *result = PyLong_FromString(digits, NULL, 16);

    free(digits);
    return result;
}

/*
 * Spigot.cfracterm(): return the next continued fraction term of the
 * number as a Python long, or None when get_term reports there are no
 * more terms.
 *
 * The first call moves the spigot from SS_FILLED to SS_CFRAC by
 * creating a CfracGenerator over a clone of its number. Raises
 * RuntimeError if the spigot is empty or formatting, and ValueError on
 * a spigot_error. Generator creation is inside the try as well, so no
 * C++ exception can unwind into Python's C frames. On python_error, a
 * Python exception is already set, so NULL is returned.
 */
static PyObject *Spigot_cfracterm(SpigotPy *self, PyObject *args)
{
    if (self->state == SS_EMPTY) {
        PyErr_SetString(PyExc_RuntimeError, "spigot not initialised");
        return NULL;
    } else if (self->state != SS_FILLED && self->state != SS_CFRAC) {
        PyErr_SetString(PyExc_RuntimeError, "spigot already formatting");
        return NULL;
    }

    try {
        if (self->state != SS_CFRAC) {
            assert(self->spig && !self->cfg);
            self->cfg = new CfracGenerator(self->spig->clone());
            self->state = SS_CFRAC;   /* only once cfg really exists */
        }

        bigint term;
        if (!self->cfg->get_term(&term)) {
            Py_RETURN_NONE;
        }
        return bigint_to_pylong(term);
    } catch (const spigot_error &err) {
        PyErr_SetString(PyExc_ValueError, err.errmsg);
        return NULL;
    } catch (const python_error &) {
        return NULL;
    }
}

/*
 * Spigot.sign(): return the sign of the number as a Python int, as
 * reported by StaticGenerator::get_sign on a clone of the spigot. This
 * is only allowed in SS_FILLED.
 *
 * Raises RuntimeError if the spigot is empty or is already producing
 * output, and ValueError on a spigot_error, so that no C++ exception
 * unwinds into Python's C frames. On python_error, a Python exception
 * is already set, so NULL is returned.
 */
static PyObject *Spigot_sign(SpigotPy *self, PyObject *args)
{
    if (self->state == SS_EMPTY) {
        PyErr_SetString(PyExc_RuntimeError, "spigot not initialised");
        return NULL;
    } else if (self->state != SS_FILLED) {
        PyErr_SetString(PyExc_RuntimeError, "spigot already formatting");
        return NULL;
    }

    try {
        StaticGenerator test(self->spig->clone());
        return PyInt_FromLong((long)test.get_sign());
    } catch (const spigot_error &err) {
        PyErr_SetString(PyExc_ValueError, err.errmsg);
        return NULL;
    } catch (const python_error &) {
        return NULL;
    }
}

/* No module-level functions: everything is exposed via the Spigot type. */
static PyMethodDef spigot_module_methods[] = {
    {NULL, NULL, 0, NULL}   /* sentinel terminating the table */
};

#ifndef PyMODINIT_FUNC	/* declarations for DLL import/export */
#define PyMODINIT_FUNC void
#endif
PyMODINIT_FUNC initinternal(void) 
{
    PyObject* m;

    spigot_SpigotType.tp_new = PyType_GenericNew;
    if (PyType_Ready(&spigot_SpigotType) < 0)
        return;

    m = Py_InitModule3("spigot.internal", spigot_module_methods,
                       "Module providing Python bindings for spigot-based exact real calculation.");

    Py_INCREF(&spigot_SpigotType);
    PyModule_AddObject(m, "Spigot", (PyObject *)&spigot_SpigotType);
}
