import logging
from fastapi import APIRouter

from app.utils.extra import chat_history_to_base_messages
from app.rag import RAG
from app.config.application import settings
from app.exceptions.http import HTTPException
from app.schemas.chat_schema import ChatRequest

from openai import RateLimitError
from langchain_openai import ChatOpenAI
from langchain.callbacks.manager import collect_runs, atrace_as_chain_group

router = APIRouter()
log = logging.getLogger(__name__)


@router.post(
    "/chat",
    summary="Chat with data",
    status_code=200,
)
async def chat(args: ChatRequest):
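    """Answer the user's query with the RAG chain, tracing the call as a single chain group."""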
    log.info("Started POST /chat")
    try:
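        # Build the chat model and the RAG chain from application settings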
        llm = ChatOpenAI(api_key=settings.OPENAI_KEY, model="gpt-4o-2024-08-06")
        chain = RAG(
            llm=llm,
            query=args.query,
            db_user=settings.DB_USERNAME,
            db_password=settings.DB_PASSWORD,
            db_name=settings.DB_NAME,
            db_host=settings.DB_HOST
        )
        chat_history = chat_history_to_base_messages(chat_history=args.chat_history)

        # Group the chain's child runs under a single trace and collect run metadata
        with collect_runs() as runs_cb:
            async with atrace_as_chain_group("Chat_with_data", inputs={"question": args.query}) as group_manager:
                result = await chain.ainvoke(callbacks=group_manager, chat_history=chat_history)
                await group_manager.on_chain_end({"output": result})
                run_id = str(runs_cb.traced_runs[0].id)
        log.info("Traced run id: %s", run_id)
        return result
    
    except RateLimitError as e:
        log.error(e)
        # Raise rather than return so the exception is handled as an error response
        raise HTTPException(status_code=e.status_code, content=e.message)
    except Exception as e:
        log.exception(e)
        raise HTTPException(status_code=500, content="Internal Server Error")