"""Custom AsyncIteratorCallbackHandler for chains that make multiple LLM calls,
forwarding only the incremental `answer` field of a streamed function-call
payload to the token queue."""

from typing import Any

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.output_parsers.json import parse_partial_json
from langchain.schema.output import LLMResult
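
# `parse_partial_json` best-effort repairs a truncated JSON fragment by closing
# any open strings and brackets, e.g. (illustrative values, assuming the
# langchain implementation of this helper):
#
#     parse_partial_json('{"output": {"answer": "Hel')
#     # -> {"output": {"answer": "Hel"}}
#     parse_partial_json('not json')
#     # -> None
#
# `on_llm_new_token` below relies on this to stream the `answer` field while
# the function-call arguments are still being generated.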

class MultipleLLMAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Async iterator callback handler for chains with multiple LLM calls.

    Plain content tokens are forwarded as-is; function-call argument fragments
    are buffered, partially parsed, and only the not-yet-emitted part of the
    `answer` field is pushed onto the queue.
    """

    tokens: str = ""  # buffered function-call argument fragments for the current run
    answer: str = ""  # portion of the answer already emitted to the queue
    previous_run: str = ""  # id of the last run seen, used to detect a new LLM call
    is_answer: bool = False

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Plain content tokens can be forwarded to the consumer directly.
        if token:
            self.queue.put_nowait(token)

        # The full generation chunk may be absent depending on the LangChain
        # version and model wrapper.
        generation_chunk = kwargs.get("chunk")
        if generation_chunk is None:
            return
        chunk = generation_chunk.dict()

        # A new run id means a new LLM call in the chain; discard the buffer
        # accumulated for the previous call.
        run_id = str(kwargs["run_id"])
        if run_id != self.previous_run:
            self.tokens = ""
        self.previous_run = run_id

        # Function-call arguments stream in as raw JSON fragments under
        # additional_kwargs rather than as content tokens.
        message = chunk.get("message", {})
        function_call = message.get("additional_kwargs", {}).get("function_call", {})
        if "arguments" not in function_call:
            return
        self.tokens += function_call["arguments"]
        if not self.tokens:
            return

        # parse_partial_json repairs the truncated fragment where it can and
        # returns None while the buffer is still unparseable.
        parsed = parse_partial_json(self.tokens)
        if (
            isinstance(parsed, dict)
            and isinstance(parsed.get("output"), dict)
            and parsed["output"].get("answer")
        ):
            # Emit only the suffix of the answer that has not been sent yet.
            answer = parsed["output"]["answer"]
            new_token = answer[len(self.answer):]
            if new_token:
                self.queue.put_nowait(new_token)
                self.answer = answer

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Set the done event only when the final LLM call has produced a
        non-empty text response, so that intermediate generations in the chain
        (e.g. pure function calls) do not close the stream early."""
        if response.generations and response.generations[0][0].text:
            self.done.set()
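

if __name__ == "__main__":
    # Minimal usage sketch, not part of the handler: it assumes a
    # streaming-capable chat model (ChatOpenAI here, with OPENAI_API_KEY set
    # in the environment); any LangChain LLM that emits streaming callbacks
    # should work the same way.
    import asyncio

    from langchain.chat_models import ChatOpenAI

    async def _demo() -> None:
        handler = MultipleLLMAsyncIteratorCallbackHandler()
        llm = ChatOpenAI(streaming=True, callbacks=[handler])
        # Run the call concurrently so the queue can be drained while
        # generation is still in progress; aiter() stops once on_llm_end
        # sets the done event.
        task = asyncio.create_task(llm.apredict("Tell me a short joke"))
        async for token in handler.aiter():
            print(token, end="", flush=True)
        await task

    asyncio.run(_demo())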