import unittest

from aiel_sdk.integrations.openai.facade import ChatOpenAI, OpenAIEmbeddings


class _FakeIntegrationsService:
    def __init__(self, embeddings_shape: str = "openai"):
        self.calls = []
        self.embeddings_shape = embeddings_shape

    def invoke_action(self, workspace_id, project_id, connection_id, action, payload=None, request_id=None):
        self.calls.append(
            {
                "workspace_id": workspace_id,
                "project_id": project_id,
                "connection_id": connection_id,
                "action": action,
                "payload": payload,
                "request_id": request_id,
            }
        )
        if action == "openai.chat.completions.create":
            return {"id": "chatcmpl_1", "choices": [{"message": {"role": "assistant", "content": "Bonjour."}}]}
        if action == "openai.embeddings.create":
            inp = (payload or {}).get("params", {}).get("input")
            if self.embeddings_shape == "openai":
                if isinstance(inp, list):
                    return {"data": [{"embedding": [0.1, 0.2]}, {"embedding": [0.3, 0.4]}]}
                return {"data": [{"embedding": [0.9, 0.8]}]}
            if self.embeddings_shape == "single_key":
                return {"embedding": [0.7, 0.6]}
            if self.embeddings_shape == "embeddings_key":
                return {"embeddings": [[0.5, 0.4], [0.3, 0.2]]}
            if self.embeddings_shape == "nested":
                return {"data": {"data": [{"embedding": [0.11, 0.22]}]}}
        return {}


class _FakeClient:
    """Minimal client double: exposes only an ``integrations`` attribute.

    The facades under test talk to ``client.integrations``; this wires that
    attribute to a fresh fake service configured with ``embeddings_shape``.
    """

    def __init__(self, embeddings_shape: str = "openai"):
        service = _FakeIntegrationsService(embeddings_shape=embeddings_shape)
        self.integrations = service


class OpenAILangChainStyleTests(unittest.TestCase):
    """Tests for the LangChain-style ``ChatOpenAI`` / ``OpenAIEmbeddings`` facades.

    Each test wires a facade to a ``_FakeClient``. It then checks the value the
    facade returns and, where relevant, the call recorded on the fake
    integrations service.
    """

    def _make_embeddings(self, embeddings_shape: str = "openai"):
        """Return ``(client, embeddings)`` backed by a fake using ``embeddings_shape``."""
        client = _FakeClient(embeddings_shape=embeddings_shape)
        emb = OpenAIEmbeddings(
            client=client,
            workspace_id="w1",
            project_id="p1",
            connection_id="c1",
            model="text-embedding-3-small",
        )
        return client, emb

    def test_chat_openai_invoke(self):
        client = _FakeClient()
        llm = ChatOpenAI(
            client=client,
            workspace_id="w1",
            project_id="p1",
            connection_id="c1",
            model="gpt-4.1",
            temperature=0,
        )

        out = llm.invoke(
            [
                {"role": "system", "content": "Translate to French."},
                {"role": "user", "content": "I love programming."},
            ],
            request_id="req_1",
        )

        self.assertEqual(out.get("id"), "chatcmpl_1")
        # The facade must route to the chat action and forward the request id
        # and the model settings inside payload["params"].
        call = client.integrations.calls[-1]
        self.assertEqual(call["action"], "openai.chat.completions.create")
        self.assertEqual(call["request_id"], "req_1")
        self.assertEqual(call["payload"]["params"]["model"], "gpt-4.1")
        self.assertEqual(call["payload"]["params"]["temperature"], 0)

    def test_openai_embeddings_helpers(self):
        client, emb = self._make_embeddings()

        q = emb.embed_query("hello")
        d = emb.embed_documents(["a", "b"])

        self.assertEqual(q, [0.9, 0.8])
        self.assertEqual(d, [[0.1, 0.2], [0.3, 0.4]])
        # The batch helper must hit the embeddings action and forward all
        # documents as one list input (not one call per document).
        call = client.integrations.calls[-1]
        self.assertEqual(call["action"], "openai.embeddings.create")
        self.assertEqual(call["payload"]["params"]["input"], ["a", "b"])

    def test_openai_embeddings_single_embedding_shape(self):
        _, emb = self._make_embeddings("single_key")
        self.assertEqual(emb.embed_query("hello"), [0.7, 0.6])

    def test_openai_embeddings_embeddings_key_shape(self):
        _, emb = self._make_embeddings("embeddings_key")
        self.assertEqual(emb.embed_documents(["a", "b"]), [[0.5, 0.4], [0.3, 0.2]])

    def test_openai_embeddings_nested_shape(self):
        _, emb = self._make_embeddings("nested")
        self.assertEqual(emb.embed_query("hello"), [0.11, 0.22])


if __name__ == "__main__":
    # Run this module's tests when executed directly; verbosity=2 reports each test by name.
    unittest.main(verbosity=2)
