Skip to content

bindings/python/moraine_conversations/tests/test_smoke.py


     1 import json
     2 import multiprocessing as mp
     3 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
     4 from urllib.parse import parse_qs, urlparse
     5  
     6 import pytest
     7  
     8 from moraine_conversations import ConversationClient
     9  
    10  
    11 def _session_rows() -> list[dict]:
    12     return [
    13         {
    14             "session_id": "sess_c",
    15             "first_event_time": "2026-01-03 10:00:00",
    16             "first_event_unix_ms": 1767434400000,
    17             "last_event_time": "2026-01-03 10:10:00",
    18             "last_event_unix_ms": 1767435000000,
    19             "total_turns": 3,
    20             "total_events": 30,
    21             "user_messages": 6,
    22             "assistant_messages": 6,
    23             "tool_calls": 3,
    24             "tool_results": 3,
    25             "mode": "web_search",
    26         },
    27         {
    28             "session_id": "sess_b",
    29             "first_event_time": "2026-01-02 10:00:00",
    30             "first_event_unix_ms": 1767348000000,
    31             "last_event_time": "2026-01-02 10:10:00",
    32             "last_event_unix_ms": 1767348600000,
    33             "total_turns": 2,
    34             "total_events": 22,
    35             "user_messages": 4,
    36             "assistant_messages": 4,
    37             "tool_calls": 2,
    38             "tool_results": 2,
    39             "mode": "web_search",
    40         },
    41         {
    42             "session_id": "sess_a",
    43             "first_event_time": "2026-01-01 10:00:00",
    44             "first_event_unix_ms": 1767261600000,
    45             "last_event_time": "2026-01-01 10:10:00",
    46             "last_event_unix_ms": 1767262200000,
    47             "total_turns": 2,
    48             "total_events": 20,
    49             "user_messages": 4,
    50             "assistant_messages": 4,
    51             "tool_calls": 2,
    52             "tool_results": 2,
    53             "mode": "web_search",
    54         },
    55     ]
    56  
    57  
    58 def _rows_for_query(query: str) -> list[dict]:
    59     if "FROM `moraine`.`v_session_summary` AS s" in query:
    60         if "WHERE s.session_id = 'sess_a'" in query:
    61             return [_session_rows()[2]]
    62         return _session_rows()
    63  
    64     if "FROM `moraine`.`search_corpus_stats`" in query:
    65         return [{"docs": 100, "total_doc_len": 5000}]
    66  
    67     if "FROM `moraine`.`search_term_stats`" in query:
    68         return [{"term": "hello", "df": 20}, {"term": "world", "df": 10}]
    69  
    70     if "GROUP BY e.session_id" in query:
    71         return [
    72             {
    73                 "session_id": "sess_c",
    74                 "score": 12.5,
    75                 "matched_terms": 2,
    76                 "event_count_considered": 3,
    77                 "best_event_uid": "evt-c-42",
    78                 "snippet": "best match from session c",
    79             },
    80             {
    81                 "session_id": "sess_a",
    82                 "score": 7.0,
    83                 "matched_terms": 1,
    84                 "event_count_considered": 2,
    85                 "best_event_uid": "evt-a-11",
    86                 "snippet": "weaker match from session a",
    87             },
    88         ]
    89  
    90     if "GROUP BY p.doc_id" in query and "ORDER BY score DESC, event_uid ASC" in query:
    91         return [
    92             {
    93                 "event_uid": "evt-c-42",
    94                 "session_id": "sess_c",
    95                 "source_name": "codex",
    96                 "provider": "codex",
    97                 "event_class": "message",
    98                 "payload_type": "text",
    99                 "actor_role": "assistant",
   100                 "name": "",
   101                 "phase": "",
   102                 "source_ref": "/tmp/sess_c.jsonl:1:42",
   103                 "doc_len": 19,
   104                 "text_preview": "best event in session c",
   105                 "score": 12.5,
   106                 "matched_terms": 2,
   107             },
   108             {
   109                 "event_uid": "evt-a-11",
   110                 "session_id": "sess_a",
   111                 "source_name": "codex",
   112                 "provider": "codex",
   113                 "event_class": "message",
   114                 "payload_type": "text",
   115                 "actor_role": "assistant",
   116                 "name": "",
   117                 "phase": "",
   118                 "source_ref": "/tmp/sess_a.jsonl:1:11",
   119                 "doc_len": 13,
   120                 "text_preview": "weaker event in session a",
   121                 "score": 7.0,
   122                 "matched_terms": 1,
   123             },
   124         ]
   125  
   126     return []
   127  
   128  
   129 def _run_mock_clickhouse(port_queue: mp.Queue) -> None:
   130     class Handler(BaseHTTPRequestHandler):
   131         def _handle(self) -> None:
   132             parsed = urlparse(self.path)
   133             query = parse_qs(parsed.query).get("query", [""])[0]
   134             rows = _rows_for_query(query)
   135             body = json.dumps({"data": rows}).encode("utf-8")
   136  
   137             self.send_response(200)
   138             self.send_header("Content-Type", "application/json")
   139             self.send_header("Content-Length", str(len(body)))
   140             self.end_headers()
   141             self.wfile.write(body)
   142  
   143         def do_GET(self) -> None:  # noqa: N802
   144             self._handle()
   145  
   146         def do_POST(self) -> None:  # noqa: N802
   147             self._handle()
   148  
   149         def log_message(self, format: str, *args) -> None:  # noqa: A003
   150             del format
   151             del args
   152  
   153     server = ThreadingHTTPServer(("127.0.0.1", 0), Handler)
   154     port_queue.put(server.server_address[1])
   155     server.serve_forever()
   156  
   157  
   158 def test_conversation_client_smoke() -> None:
   159     port_queue: mp.Queue[int] = mp.Queue()
   160     proc = mp.Process(target=_run_mock_clickhouse, args=(port_queue,), daemon=True)
   161     proc.start()
   162  
   163     try:
   164         port = port_queue.get(timeout=5)
   165         client = ConversationClient(
   166             url=f"http://127.0.0.1:{port}",
   167             database="moraine",
   168             timeout_seconds=2.0,
   169         )
   170  
   171         conversations = json.loads(
   172             client.list_conversations_json(
   173                 from_unix_ms=1767261600000,
   174                 to_unix_ms=1767500000000,
   175                 mode="web_search",
   176                 limit=2,
   177             )
   178         )
   179         assert len(conversations["items"]) == 2
   180         assert conversations["items"][0]["session_id"] == "sess_c"
   181         assert conversations["next_cursor"] is not None
   182  
   183         conversation = json.loads(client.get_conversation_json("sess_a"))
   184         assert conversation["summary"]["session_id"] == "sess_a"
   185         assert conversation["turns"] == []
   186  
   187         search = json.loads(
   188             client.search_conversations_json(
   189                 query="hello world",
   190                 limit=10,
   191                 min_score=0.0,
   192                 min_should_match=1,
   193                 from_unix_ms=1767261600000,
   194                 to_unix_ms=1767500000000,
   195                 mode="chat",
   196                 include_tool_events=True,
   197                 exclude_codex_mcp=False,
   198             )
   199         )
   200         assert len(search["hits"]) == 2
   201         assert search["hits"][0]["session_id"] == "sess_c"
   202  
   203         event_search = json.loads(
   204             client.search_events_json(
   205                 query="hello world",
   206                 limit=10,
   207                 session_id="sess_c",
   208                 min_score=0.0,
   209                 min_should_match=1,
   210                 include_tool_events=True,
   211                 exclude_codex_mcp=False,
   212                 search_strategy="optimized",
   213             )
   214         )
   215         assert len(event_search["hits"]) == 2
   216         assert event_search["hits"][0]["event_uid"] == "evt-c-42"
   217  
   218         with pytest.raises(ValueError):
   219             client.search_events_json(query="hello world", search_strategy="not_a_strategy")
   220     finally:
   221         proc.terminate()
   222         proc.join(timeout=5)