Skip to content

rust/codex-mcp/src/main.rs


     1 mod clickhouse;
     2 mod config;
     3  
     4 use crate::clickhouse::ClickHouseClient;
     5 use crate::config::{load_config, resolve_config_path, AppConfig};
     6 use anyhow::{anyhow, Context, Result};
     7 use regex::Regex;
     8 use serde::Deserialize;
     9 use serde_json::{json, Value};
    10 use std::collections::HashMap;
    11 use std::path::PathBuf;
    12 use std::sync::{Arc, OnceLock};
    13 use std::time::Instant;
    14 use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
    15 use tracing::{debug, warn};
    16 use tracing_subscriber::EnvFilter;
    17 use uuid::Uuid;
    18  
    19 #[derive(Debug, Clone, Copy, Deserialize)]
    20 #[serde(rename_all = "lowercase")]
    21 enum Verbosity {
    22     Prose,
    23     Full,
    24 }
    25  
    26 impl Default for Verbosity {
    27     fn default() -> Self {
    28         Self::Prose
    29     }
    30 }
    31  
/// A JSON-RPC message as consumed by `AppState::handle_request`.
///
/// `id` is `None` when the field is absent (or JSON `null`), which
/// `handle_request` treats as a notification that gets no response.
#[derive(Debug, Deserialize)]
struct RpcRequest {
    #[serde(default)]
    id: Option<Value>,
    method: String,
    // Missing `params` deserializes to `Value::Null`.
    #[serde(default)]
    params: Value,
}
    40  
/// `params` of a `tools/call` request: which tool to run and its raw arguments.
///
/// `arguments` is kept untyped here and parsed per tool in `AppState::call_tool`.
#[derive(Debug, Deserialize)]
struct ToolCallParams {
    name: String,
    #[serde(default)]
    arguments: Value,
}
    47  
/// Arguments accepted by the `search` tool.
///
/// Only `query` is required. Each `None` field is replaced by a configuration
/// default inside `AppState::search`.
#[derive(Debug, Deserialize)]
struct SearchArgs {
    /// Free-text query; trimmed and tokenized before searching.
    query: String,
    /// Maximum number of hits; clamped to `1..=max_results`.
    #[serde(default)]
    limit: Option<u16>,
    /// Restrict hits to one session (must pass `is_safe_filter_value`).
    #[serde(default)]
    session_id: Option<String>,
    /// Minimum BM25 score a hit must reach.
    #[serde(default)]
    min_score: Option<f64>,
    /// Minimum number of distinct query terms a hit must contain; clamped to
    /// `1..=<number of query terms>`.
    #[serde(default)]
    min_should_match: Option<u16>,
    /// When true, tool events are searchable too (only `token_count` is excluded).
    #[serde(default)]
    include_tool_events: Option<bool>,
    /// When true, drop events whose payload mentions `codex-mcp` or whose name is
    /// `search`/`open`.
    #[serde(default)]
    exclude_codex_mcp: Option<bool>,
    /// Output format; `None` means `Verbosity::default()`.
    #[serde(default)]
    verbosity: Option<Verbosity>,
}
    66  
/// Arguments accepted by the `open` tool.
#[derive(Debug, Deserialize)]
struct OpenArgs {
    /// Uid of the event to open (trimmed; must pass `is_safe_filter_value`).
    event_uid: String,
    /// `event_order` positions of context before the target; defaults to
    /// `default_context_before` from config.
    #[serde(default)]
    before: Option<u16>,
    /// `event_order` positions of context after the target; defaults to
    /// `default_context_after` from config.
    #[serde(default)]
    after: Option<u16>,
    /// Output format; `None` means `Verbosity::default()`.
    #[serde(default)]
    verbosity: Option<Verbosity>,
}
    77  
/// Row returned by the corpus-statistics queries in `AppState::corpus_stats`:
/// number of indexed documents and the sum of their lengths.
#[derive(Debug, Deserialize)]
struct CorpusStatsRow {
    docs: u64,
    total_doc_len: u64,
}
    83  
/// Row returned by the document-frequency queries in `AppState::df_map`.
#[derive(Debug, Deserialize)]
struct DfRow {
    term: String,
    // Document frequency of `term`, used for the BM25 idf.
    df: u64,
}
    89  
/// One scored hit produced by the SQL from `AppState::build_search_sql`.
#[derive(Debug, Deserialize)]
struct SearchRow {
    event_uid: String,
    session_id: String,
    source_name: String,
    provider: String,
    event_class: String,
    payload_type: String,
    actor_role: String,
    name: String,
    phase: String,
    source_ref: String,
    doc_len: u32,
    // Leading `preview_chars` characters of the document text.
    text_preview: String,
    // BM25 score summed over the matched query terms.
    score: f64,
    // Number of distinct query terms found in this document.
    matched_terms: u64,
}
   107  
/// Location of the target event looked up by `AppState::open`: its session and
/// position, used to compute the surrounding context window.
#[derive(Debug, Deserialize)]
struct OpenTargetRow {
    session_id: String,
    event_order: u64,
    turn_seq: u32,
}
   114  
/// One event inside the context window returned by `AppState::open`.
#[derive(Debug, Deserialize)]
struct OpenContextRow {
    session_id: String,
    event_uid: String,
    event_order: u64,
    turn_seq: u32,
    // Selected as `toString(event_time)`, so it arrives already rendered as text.
    event_time: String,
    actor_role: String,
    event_class: String,
    payload_type: String,
    call_id: String,
    name: String,
    phase: String,
    item_id: String,
    text_content: String,
    payload_json: String,
    token_usage_json: String,
    source_ref: String,
}
   134  
   135 #[derive(Debug, Default, Deserialize)]
   136 struct SearchProsePayload {
   137     #[serde(default)]
   138     query_id: String,
   139     #[serde(default)]
   140     query: String,
   141     #[serde(default)]
   142     stats: SearchProseStats,
   143     #[serde(default)]
   144     hits: Vec<SearchProseHit>,
   145 }
   146  
   147 #[derive(Debug, Default, Deserialize)]
   148 struct SearchProseStats {
   149     #[serde(default)]
   150     took_ms: u64,
   151     #[serde(default)]
   152     result_count: u64,
   153 }
   154  
   155 #[derive(Debug, Default, Deserialize)]
   156 struct SearchProseHit {
   157     #[serde(default)]
   158     rank: u64,
   159     #[serde(default)]
   160     event_uid: String,
   161     #[serde(default)]
   162     session_id: String,
   163     #[serde(default)]
   164     score: f64,
   165     #[serde(default)]
   166     event_class: String,
   167     #[serde(default)]
   168     payload_type: String,
   169     #[serde(default)]
   170     actor_role: String,
   171     #[serde(default)]
   172     text_preview: String,
   173 }
   174  
   175 #[derive(Debug, Default, Deserialize)]
   176 struct OpenProsePayload {
   177     #[serde(default)]
   178     found: bool,
   179     #[serde(default)]
   180     event_uid: String,
   181     #[serde(default)]
   182     session_id: String,
   183     #[serde(default)]
   184     turn_seq: u32,
   185     #[serde(default)]
   186     target_event_order: u64,
   187     #[serde(default)]
   188     before: u16,
   189     #[serde(default)]
   190     after: u16,
   191     #[serde(default)]
   192     events: Vec<OpenProseEvent>,
   193 }
   194  
   195 #[derive(Debug, Default, Deserialize)]
   196 struct OpenProseEvent {
   197     #[serde(default)]
   198     is_target: bool,
   199     #[serde(default)]
   200     event_order: u64,
   201     #[serde(default)]
   202     actor_role: String,
   203     #[serde(default)]
   204     event_class: String,
   205     #[serde(default)]
   206     payload_type: String,
   207     #[serde(default)]
   208     text_content: String,
   209 }
   210  
/// Server state shared by all request handlers: the loaded configuration and
/// the ClickHouse client used for every query and log insert.
#[derive(Clone)]
struct AppState {
    cfg: AppConfig,
    ch: ClickHouseClient,
}
   216  
   217 impl AppState {
   218     async fn handle_request(&self, req: RpcRequest) -> Option<Value> {
   219         let id = req.id.clone();
   220  
   221         match req.method.as_str() {
   222             "initialize" => {
   223                 let result = json!({
   224                     "protocolVersion": self.cfg.mcp.protocol_version,
   225                     "capabilities": {
   226                         "tools": {
   227                             "listChanged": false
   228                         }
   229                     },
   230                     "serverInfo": {
   231                         "name": "codex-mcp",
   232                         "version": env!("CARGO_PKG_VERSION")
   233                     }
   234                 });
   235  
   236                 id.map(|msg_id| rpc_ok(msg_id, result))
   237             }
   238             "ping" => id.map(|msg_id| rpc_ok(msg_id, json!({}))),
   239             "notifications/initialized" | "initialized" => None,
   240             "tools/list" => id.map(|msg_id| rpc_ok(msg_id, self.tools_list_result())),
   241             "tools/call" => {
   242                 let Some(msg_id) = id else {
   243                     return None;
   244                 };
   245  
   246                 let parsed: Result<ToolCallParams> =
   247                     serde_json::from_value(req.params).context("invalid tools/call params payload");
   248  
   249                 match parsed {
   250                     Ok(params) => {
   251                         let tool_result = match self.call_tool(params).await {
   252                             Ok(v) => v,
   253                             Err(err) => tool_error_result(err.to_string()),
   254                         };
   255                         Some(rpc_ok(msg_id, tool_result))
   256                     }
   257                     Err(err) => Some(rpc_err(msg_id, -32602, &format!("invalid params: {err}"))),
   258                 }
   259             }
   260             _ => id.map(|msg_id| {
   261                 rpc_err(msg_id, -32601, &format!("method not found: {}", req.method))
   262             }),
   263         }
   264     }
   265  
   266     fn tools_list_result(&self) -> Value {
   267         json!({
   268             "tools": [
   269                 {
   270                     "name": "search",
   271                     "description": "BM25 lexical search over Moraine indexed conversation events.",
   272                     "inputSchema": {
   273                         "type": "object",
   274                         "properties": {
   275                             "query": { "type": "string" },
   276                             "limit": { "type": "integer", "minimum": 1, "maximum": self.cfg.mcp.max_results },
   277                             "session_id": { "type": "string" },
   278                             "min_score": { "type": "number" },
   279                             "min_should_match": { "type": "integer", "minimum": 1 },
   280                             "include_tool_events": { "type": "boolean" },
   281                             "exclude_codex_mcp": { "type": "boolean" },
   282                             "verbosity": {
   283                                 "type": "string",
   284                                 "enum": ["prose", "full"],
   285                                 "default": "prose"
   286                             }
   287                         },
   288                         "required": ["query"]
   289                     }
   290                 },
   291                 {
   292                     "name": "open",
   293                     "description": "Open one event by uid with surrounding conversation context.",
   294                     "inputSchema": {
   295                         "type": "object",
   296                         "properties": {
   297                             "event_uid": { "type": "string" },
   298                             "before": { "type": "integer", "minimum": 0 },
   299                             "after": { "type": "integer", "minimum": 0 },
   300                             "verbosity": {
   301                                 "type": "string",
   302                                 "enum": ["prose", "full"],
   303                                 "default": "prose"
   304                             }
   305                         },
   306                         "required": ["event_uid"]
   307                     }
   308                 }
   309             ]
   310         })
   311     }
   312  
   313     async fn call_tool(&self, params: ToolCallParams) -> Result<Value> {
   314         match params.name.as_str() {
   315             "search" => {
   316                 let args: SearchArgs = serde_json::from_value(params.arguments)
   317                     .context("search expects a JSON object with at least {\"query\": ...}")?;
   318                 let verbosity = args.verbosity.unwrap_or_default();
   319                 let payload = self.search(args).await?;
   320                 match verbosity {
   321                     Verbosity::Full => Ok(tool_ok_full(payload)),
   322                     Verbosity::Prose => Ok(tool_ok_prose(format_search_prose(&payload)?)),
   323                 }
   324             }
   325             "open" => {
   326                 let args: OpenArgs = serde_json::from_value(params.arguments)
   327                     .context("open expects {\"event_uid\": ...}")?;
   328                 let verbosity = args.verbosity.unwrap_or_default();
   329                 let payload = self.open(args).await?;
   330                 match verbosity {
   331                     Verbosity::Full => Ok(tool_ok_full(payload)),
   332                     Verbosity::Prose => Ok(tool_ok_prose(format_open_prose(&payload)?)),
   333                 }
   334             }
   335             other => Err(anyhow!("unknown tool: {other}")),
   336         }
   337     }
   338  
   339     async fn search(&self, args: SearchArgs) -> Result<Value> {
   340         let query = args.query.trim();
   341         if query.is_empty() {
   342             return Err(anyhow!("query cannot be empty"));
   343         }
   344  
   345         let query_id = Uuid::new_v4().to_string();
   346         let started = Instant::now();
   347  
   348         let terms_with_qf = tokenize_query(query, self.cfg.bm25.max_query_terms);
   349         if terms_with_qf.is_empty() {
   350             return Err(anyhow!("query has no searchable terms"));
   351         }
   352  
   353         let terms: Vec<String> = terms_with_qf.iter().map(|(term, _)| term.clone()).collect();
   354  
   355         let limit = args
   356             .limit
   357             .unwrap_or(self.cfg.mcp.max_results)
   358             .max(1)
   359             .min(self.cfg.mcp.max_results);
   360  
   361         let min_should_match = args
   362             .min_should_match
   363             .unwrap_or(self.cfg.bm25.default_min_should_match)
   364             .max(1)
   365             .min(terms.len() as u16);
   366  
   367         let min_score = args.min_score.unwrap_or(self.cfg.bm25.default_min_score);
   368         let include_tool_events = args
   369             .include_tool_events
   370             .unwrap_or(self.cfg.mcp.default_include_tool_events);
   371         let exclude_codex_mcp = args
   372             .exclude_codex_mcp
   373             .unwrap_or(self.cfg.mcp.default_exclude_codex_mcp);
   374  
   375         if let Some(session_id) = args.session_id.as_deref() {
   376             if !is_safe_filter_value(session_id) {
   377                 return Err(anyhow!("session_id contains unsupported characters"));
   378             }
   379         }
   380  
   381         let (docs, total_doc_len) = self.corpus_stats().await?;
   382         if docs == 0 {
   383             return Ok(json!({
   384                 "query_id": query_id,
   385                 "query": query,
   386                 "terms": terms,
   387                 "stats": {
   388                     "docs": 0,
   389                     "avgdl": 0.0,
   390                     "took_ms": started.elapsed().as_millis(),
   391                     "result_count": 0
   392                 },
   393                 "hits": []
   394             }));
   395         }
   396  
   397         let avgdl = (total_doc_len as f64 / docs as f64).max(1.0);
   398         let df_map = self.df_map(&terms).await?;
   399  
   400         let mut idf_by_term = HashMap::<String, f64>::new();
   401         for term in &terms {
   402             let df = *df_map.get(term).unwrap_or(&0);
   403             let idf = if df == 0 {
   404                 (1.0 + ((docs as f64 + 0.5) / 0.5)).ln()
   405             } else {
   406                 let n = docs.max(df) as f64;
   407                 (1.0 + ((n - df as f64 + 0.5) / (df as f64 + 0.5))).ln()
   408             };
   409             idf_by_term.insert(term.clone(), idf.max(0.0));
   410         }
   411  
   412         let query_sql = self.build_search_sql(
   413             &terms,
   414             &idf_by_term,
   415             avgdl,
   416             include_tool_events,
   417             exclude_codex_mcp,
   418             args.session_id.as_deref(),
   419             min_should_match,
   420             min_score,
   421             limit,
   422         )?;
   423  
   424         let mut rows: Vec<SearchRow> = self.ch.query_json_rows(&query_sql).await?;
   425         rows.sort_by(|a, b| b.score.total_cmp(&a.score));
   426  
   427         let took_ms = started.elapsed().as_millis() as u32;
   428  
   429         let hits: Vec<Value> = rows
   430             .iter()
   431             .enumerate()
   432             .map(|(idx, row)| {
   433                 json!({
   434                     "rank": idx + 1,
   435                     "event_uid": row.event_uid,
   436                     "session_id": row.session_id,
   437                     "source_name": row.source_name,
   438                     "provider": row.provider,
   439                     "score": row.score,
   440                     "matched_terms": row.matched_terms,
   441                     "doc_len": row.doc_len,
   442                     "event_class": row.event_class,
   443                     "payload_type": row.payload_type,
   444                     "actor_role": row.actor_role,
   445                     "name": row.name,
   446                     "phase": row.phase,
   447                     "source_ref": row.source_ref,
   448                     "text_preview": row.text_preview
   449                 })
   450             })
   451             .collect();
   452  
   453         let payload = json!({
   454             "query_id": query_id,
   455             "query": query,
   456             "terms": terms,
   457             "stats": {
   458                 "docs": docs,
   459                 "avgdl": avgdl,
   460                 "took_ms": took_ms,
   461                 "result_count": hits.len()
   462             },
   463             "hits": hits
   464         });
   465  
   466         self.log_search(
   467             &query_id,
   468             query,
   469             args.session_id.as_deref().unwrap_or(""),
   470             &terms,
   471             limit,
   472             min_should_match,
   473             min_score,
   474             include_tool_events,
   475             exclude_codex_mcp,
   476             took_ms,
   477             &rows,
   478             docs,
   479             avgdl,
   480         )
   481         .await;
   482  
   483         Ok(payload)
   484     }
   485  
   486     fn build_search_sql(
   487         &self,
   488         terms: &[String],
   489         idf_by_term: &HashMap<String, f64>,
   490         avgdl: f64,
   491         include_tool_events: bool,
   492         exclude_codex_mcp: bool,
   493         session_id: Option<&str>,
   494         min_should_match: u16,
   495         min_score: f64,
   496         limit: u16,
   497     ) -> Result<String> {
   498         if terms.is_empty() {
   499             return Err(anyhow!("cannot build search query with empty terms"));
   500         }
   501  
   502         let terms_array_sql = sql_array_strings(terms);
   503         let idf_vals: Vec<f64> = terms
   504             .iter()
   505             .map(|t| *idf_by_term.get(t).unwrap_or(&0.0))
   506             .collect();
   507         let idf_array_sql = sql_array_f64(&idf_vals);
   508  
   509         let mut where_clauses = vec![format!("p.term IN {}", terms_array_sql)];
   510  
   511         if let Some(sid) = session_id {
   512             where_clauses.push(format!("p.session_id = {}", sql_quote(sid)));
   513         }
   514  
   515         if include_tool_events {
   516             where_clauses.push("p.payload_type != 'token_count'".to_string());
   517         } else {
   518             where_clauses
   519                 .push("p.event_class IN ('message', 'reasoning', 'event_msg')".to_string());
   520             where_clauses.push(
   521                 "p.payload_type NOT IN ('token_count', 'task_started', 'task_complete', 'turn_aborted', 'item_completed')"
   522                     .to_string(),
   523             );
   524         }
   525  
   526         if exclude_codex_mcp {
   527             where_clauses
   528                 .push("positionCaseInsensitiveUTF8(d.payload_json, 'codex-mcp') = 0".to_string());
   529             where_clauses.push("lowerUTF8(d.name) NOT IN ('search', 'open')".to_string());
   530         }
   531  
   532         let where_sql = where_clauses.join("\n  AND ");
   533         let k1 = self.cfg.bm25.k1.max(0.01);
   534         let b = self.cfg.bm25.b.clamp(0.0, 1.0);
   535  
   536         Ok(format!(
   537             "WITH
   538   {k1:.6} AS k1,
   539   {b:.6} AS b,
   540   greatest({avgdl:.6}, 1.0) AS avgdl,
   541   {terms_array_sql} AS q_terms,
   542   {idf_array_sql} AS q_idf
   543 SELECT
   544   p.doc_id AS event_uid,
   545   any(p.session_id) AS session_id,
   546   any(p.source_name) AS source_name,
   547   any(p.provider) AS provider,
   548   any(p.event_class) AS event_class,
   549   any(p.payload_type) AS payload_type,
   550   any(p.actor_role) AS actor_role,
   551   any(p.name) AS name,
   552   any(p.phase) AS phase,
   553   any(p.source_ref) AS source_ref,
   554   any(p.doc_len) AS doc_len,
   555   leftUTF8(any(d.text_content), {preview}) AS text_preview,
   556   sum(
   557     transform(p.term, q_terms, q_idf, 0.0)
   558     *
   559     (
   560       (toFloat64(p.tf) * (k1 + 1.0))
   561       /
   562       (toFloat64(p.tf) + k1 * (1.0 - b + b * (toFloat64(p.doc_len) / avgdl)))
   563     )
   564   ) AS score,
   565   uniqExact(p.term) AS matched_terms
   566 FROM moraine.search_postings AS p
   567 ANY INNER JOIN moraine.search_documents AS d ON d.event_uid = p.doc_id
   568 WHERE {where_sql}
   569 GROUP BY p.doc_id
   570 HAVING matched_terms >= {min_should_match} AND score >= {min_score:.6}
   571 ORDER BY score DESC
   572 LIMIT {limit}
   573 FORMAT JSONEachRow",
   574             preview = self.cfg.mcp.preview_chars,
   575         ))
   576     }
   577  
   578     async fn corpus_stats(&self) -> Result<(u64, u64)> {
   579         let from_stats_query = "SELECT toUInt64(ifNull(sum(docs), 0)) AS docs, toUInt64(ifNull(sum(total_doc_len), 0)) AS total_doc_len FROM moraine.search_corpus_stats FORMAT JSONEachRow";
   580         let from_stats: Vec<CorpusStatsRow> = self.ch.query_json_rows(from_stats_query).await?;
   581  
   582         if let Some(row) = from_stats.first() {
   583             if row.docs > 0 {
   584                 return Ok((row.docs, row.total_doc_len));
   585             }
   586         }
   587  
   588         let fallback_query = "SELECT toUInt64(count()) AS docs, toUInt64(ifNull(sum(doc_len), 0)) AS total_doc_len FROM moraine.search_documents WHERE doc_len > 0 FORMAT JSONEachRow";
   589         let fallback: Vec<CorpusStatsRow> = self.ch.query_json_rows(fallback_query).await?;
   590         if let Some(row) = fallback.first() {
   591             Ok((row.docs, row.total_doc_len))
   592         } else {
   593             Ok((0, 0))
   594         }
   595     }
   596  
   597     async fn df_map(&self, terms: &[String]) -> Result<HashMap<String, u64>> {
   598         let terms_array = sql_array_strings(terms);
   599         let primary_query = format!(
   600             "SELECT term, toUInt64(sum(docs)) AS df FROM moraine.search_term_stats WHERE term IN {} GROUP BY term FORMAT JSONEachRow",
   601             terms_array
   602         );
   603  
   604         let mut map = HashMap::<String, u64>::new();
   605  
   606         let primary_rows: Vec<DfRow> = self.ch.query_json_rows(&primary_query).await?;
   607         for row in primary_rows {
   608             map.insert(row.term, row.df);
   609         }
   610  
   611         if map.len() == terms.len() {
   612             return Ok(map);
   613         }
   614  
   615         let fallback_query = format!(
   616             "SELECT term, count() AS df FROM moraine.search_postings WHERE term IN {} GROUP BY term FORMAT JSONEachRow",
   617             terms_array
   618         );
   619         let fallback_rows: Vec<DfRow> = self.ch.query_json_rows(&fallback_query).await?;
   620         for row in fallback_rows {
   621             map.insert(row.term, row.df);
   622         }
   623  
   624         Ok(map)
   625     }
   626  
   627     async fn log_search(
   628         &self,
   629         query_id: &str,
   630         raw_query: &str,
   631         session_hint: &str,
   632         terms: &[String],
   633         limit: u16,
   634         min_should_match: u16,
   635         min_score: f64,
   636         include_tool_events: bool,
   637         exclude_codex_mcp: bool,
   638         took_ms: u32,
   639         rows: &[SearchRow],
   640         docs: u64,
   641         avgdl: f64,
   642     ) {
   643         let metadata_json = match serde_json::to_string(&json!({
   644             "docs": docs,
   645             "avgdl": avgdl,
   646             "k1": self.cfg.bm25.k1,
   647             "b": self.cfg.bm25.b
   648         })) {
   649             Ok(v) => v,
   650             Err(err) => {
   651                 warn!("failed to encode search metadata: {}", err);
   652                 "{}".to_string()
   653             }
   654         };
   655  
   656         let query_row = json!({
   657             "query_id": query_id,
   658             "source": "codex-mcp",
   659             "session_hint": session_hint,
   660             "raw_query": raw_query,
   661             "normalized_terms": terms,
   662             "term_count": terms.len() as u16,
   663             "result_limit": limit,
   664             "min_should_match": min_should_match,
   665             "min_score": min_score,
   666             "include_tool_events": if include_tool_events { 1 } else { 0 },
   667             "exclude_codex_mcp": if exclude_codex_mcp { 1 } else { 0 },
   668             "response_ms": took_ms,
   669             "result_count": rows.len() as u16,
   670             "metadata_json": metadata_json,
   671         });
   672  
   673         let hit_rows: Vec<Value> = rows
   674             .iter()
   675             .enumerate()
   676             .map(|(idx, row)| {
   677                 json!({
   678                     "query_id": query_id,
   679                     "rank": (idx + 1) as u16,
   680                     "event_uid": row.event_uid,
   681                     "session_id": row.session_id,
   682                     "source_name": row.source_name,
   683                     "provider": row.provider,
   684                     "score": row.score,
   685                     "matched_terms": row.matched_terms as u16,
   686                     "doc_len": row.doc_len,
   687                     "event_class": row.event_class,
   688                     "payload_type": row.payload_type,
   689                     "actor_role": row.actor_role,
   690                     "name": row.name,
   691                     "source_ref": row.source_ref,
   692                 })
   693             })
   694             .collect();
   695  
   696         let ch = self.ch.clone();
   697         if self.cfg.mcp.async_log_writes {
   698             tokio::spawn(async move {
   699                 if let Err(err) = ch.insert_json_rows("search_query_log", &[query_row]).await {
   700                     warn!("failed to write search_query_log: {}", err);
   701                 }
   702                 if !hit_rows.is_empty() {
   703                     if let Err(err) = ch.insert_json_rows("search_hit_log", &hit_rows).await {
   704                         warn!("failed to write search_hit_log: {}", err);
   705                     }
   706                 }
   707             });
   708         } else {
   709             if let Err(err) = self
   710                 .ch
   711                 .insert_json_rows("search_query_log", &[query_row])
   712                 .await
   713             {
   714                 warn!("failed to write search_query_log: {}", err);
   715             }
   716             if !hit_rows.is_empty() {
   717                 if let Err(err) = self.ch.insert_json_rows("search_hit_log", &hit_rows).await {
   718                     warn!("failed to write search_hit_log: {}", err);
   719                 }
   720             }
   721         }
   722     }
   723  
   724     async fn open(&self, args: OpenArgs) -> Result<Value> {
   725         let event_uid = args.event_uid.trim();
   726         if event_uid.is_empty() {
   727             return Err(anyhow!("event_uid cannot be empty"));
   728         }
   729         if !is_safe_filter_value(event_uid) {
   730             return Err(anyhow!("event_uid contains unsupported characters"));
   731         }
   732  
   733         let before = args.before.unwrap_or(self.cfg.mcp.default_context_before);
   734         let after = args.after.unwrap_or(self.cfg.mcp.default_context_after);
   735  
   736         let target_query = format!(
   737             "SELECT session_id, event_order, turn_seq FROM moraine.v_conversation_trace WHERE event_uid = {} ORDER BY event_order DESC LIMIT 1 FORMAT JSONEachRow",
   738             sql_quote(event_uid)
   739         );
   740  
   741         let targets: Vec<OpenTargetRow> = self.ch.query_json_rows(&target_query).await?;
   742         let Some(target) = targets.first() else {
   743             return Ok(json!({
   744                 "found": false,
   745                 "event_uid": event_uid,
   746                 "events": []
   747             }));
   748         };
   749  
   750         let lower = target.event_order.saturating_sub(before as u64).max(1);
   751         let upper = target.event_order + after as u64;
   752  
   753         let context_query = format!(
   754             "SELECT session_id, event_uid, event_order, turn_seq, toString(event_time) AS event_time, actor_role, event_class, payload_type, call_id, name, phase, item_id, text_content, payload_json, token_usage_json, source_ref FROM moraine.v_conversation_trace WHERE session_id = {} AND event_order BETWEEN {} AND {} ORDER BY event_order FORMAT JSONEachRow",
   755             sql_quote(&target.session_id),
   756             lower,
   757             upper
   758         );
   759  
   760         let mut rows: Vec<OpenContextRow> = self.ch.query_json_rows(&context_query).await?;
   761         rows.sort_by_key(|row| row.event_order);
   762  
   763         let events: Vec<Value> = rows
   764             .iter()
   765             .map(|row| {
   766                 json!({
   767                     "is_target": row.event_uid == event_uid,
   768                     "session_id": row.session_id,
   769                     "event_uid": row.event_uid,
   770                     "event_order": row.event_order,
   771                     "turn_seq": row.turn_seq,
   772                     "event_time": row.event_time,
   773                     "actor_role": row.actor_role,
   774                     "event_class": row.event_class,
   775                     "payload_type": row.payload_type,
   776                     "call_id": row.call_id,
   777                     "name": row.name,
   778                     "phase": row.phase,
   779                     "item_id": row.item_id,
   780                     "source_ref": row.source_ref,
   781                     "text_content": row.text_content,
   782                     "payload_json": row.payload_json,
   783                     "token_usage_json": row.token_usage_json,
   784                 })
   785             })
   786             .collect();
   787  
   788         Ok(json!({
   789             "found": true,
   790             "event_uid": event_uid,
   791             "session_id": target.session_id,
   792             "target_event_order": target.event_order,
   793             "turn_seq": target.turn_seq,
   794             "before": before,
   795             "after": after,
   796             "events": events,
   797         }))
   798     }
   799 }
   800  
   801 fn rpc_ok(id: Value, result: Value) -> Value {
   802     json!({
   803         "jsonrpc": "2.0",
   804         "id": id,
   805         "result": result
   806     })
   807 }
   808  
   809 fn rpc_err(id: Value, code: i64, message: &str) -> Value {
   810     json!({
   811         "jsonrpc": "2.0",
   812         "id": id,
   813         "error": {
   814             "code": code,
   815             "message": message
   816         }
   817     })
   818 }
   819  
   820 fn tool_ok_full(payload: Value) -> Value {
   821     let text = serde_json::to_string_pretty(&payload).unwrap_or_else(|_| "{}".to_string());
   822     json!({
   823         "content": [
   824             {
   825                 "type": "text",
   826                 "text": text
   827             }
   828         ],
   829         "structuredContent": payload,
   830         "isError": false
   831     })
   832 }
   833  
   834 fn tool_ok_prose(text: String) -> Value {
   835     json!({
   836         "content": [
   837             {
   838                 "type": "text",
   839                 "text": text
   840             }
   841         ],
   842         "isError": false
   843     })
   844 }
   845  
   846 fn tool_error_result(message: String) -> Value {
   847     json!({
   848         "content": [
   849             {
   850                 "type": "text",
   851                 "text": message
   852             }
   853         ],
   854         "isError": true
   855     })
   856 }
   857  
   858 fn format_search_prose(payload: &Value) -> Result<String> {
   859     let parsed: SearchProsePayload =
   860         serde_json::from_value(payload.clone()).context("failed to parse search payload")?;
   861  
   862     let mut out = String::new();
   863     out.push_str(&format!("Search: \"{}\"\n", parsed.query));
   864     out.push_str(&format!("Query ID: {}\n", parsed.query_id));
   865     out.push_str(&format!(
   866         "Hits: {} ({} ms)\n",
   867         parsed.stats.result_count, parsed.stats.took_ms
   868     ));
   869  
   870     if parsed.hits.is_empty() {
   871         out.push_str("\nNo hits.");
   872         return Ok(out);
   873     }
   874  
   875     for hit in &parsed.hits {
   876         let kind = display_kind(&hit.event_class, &hit.payload_type);
   877         out.push_str(&format!(
   878             "\n{}) session={} score={:.4} kind={} role={}\n",
   879             hit.rank, hit.session_id, hit.score, kind, hit.actor_role
   880         ));
   881  
   882         let snippet = compact_text_line(&hit.text_preview, 220);
   883         if !snippet.is_empty() {
   884             out.push_str(&format!("   snippet: {}\n", snippet));
   885         }
   886  
   887         out.push_str(&format!("   event_uid: {}\n", hit.event_uid));
   888         out.push_str(&format!("   next: open(event_uid=\"{}\")\n", hit.event_uid));
   889     }
   890  
   891     Ok(out.trim_end().to_string())
   892 }
   893  
   894 fn format_open_prose(payload: &Value) -> Result<String> {
   895     let mut parsed: OpenProsePayload =
   896         serde_json::from_value(payload.clone()).context("failed to parse open payload")?;
   897  
   898     let mut out = String::new();
   899     out.push_str(&format!("Open event: {}\n", parsed.event_uid));
   900  
   901     if !parsed.found {
   902         out.push_str("Not found.");
   903         return Ok(out);
   904     }
   905  
   906     out.push_str(&format!("Session: {}\n", parsed.session_id));
   907     out.push_str(&format!("Turn: {}\n", parsed.turn_seq));
   908     out.push_str(&format!(
   909         "Context window: before={} after={}\n",
   910         parsed.before, parsed.after
   911     ));
   912  
   913     parsed.events.sort_by_key(|e| e.event_order);
   914  
   915     let mut before_events = Vec::new();
   916     let mut target_events = Vec::new();
   917     let mut after_events = Vec::new();
   918  
   919     for event in parsed.events {
   920         if event.is_target || event.event_order == parsed.target_event_order {
   921             target_events.push(event);
   922         } else if event.event_order < parsed.target_event_order {
   923             before_events.push(event);
   924         } else {
   925             after_events.push(event);
   926         }
   927     }
   928  
   929     out.push_str("\nBefore:\n");
   930     if before_events.is_empty() {
   931         out.push_str("- (none)\n");
   932     } else {
   933         for event in &before_events {
   934             append_open_event_line(&mut out, event);
   935         }
   936     }
   937  
   938     out.push_str("\nTarget:\n");
   939     if target_events.is_empty() {
   940         out.push_str("- (none)\n");
   941     } else {
   942         for event in &target_events {
   943             append_open_event_line(&mut out, event);
   944         }
   945     }
   946  
   947     out.push_str("\nAfter:\n");
   948     if after_events.is_empty() {
   949         out.push_str("- (none)");
   950     } else {
   951         for event in &after_events {
   952             append_open_event_line(&mut out, event);
   953         }
   954     }
   955  
   956     Ok(out.trim_end().to_string())
   957 }
   958  
   959 fn append_open_event_line(out: &mut String, event: &OpenProseEvent) {
   960     let kind = display_kind(&event.event_class, &event.payload_type);
   961     out.push_str(&format!(
   962         "- [{}] {} {}\n",
   963         event.event_order, event.actor_role, kind
   964     ));
   965  
   966     let text = compact_text_line(&event.text_content, 220);
   967     if !text.is_empty() {
   968         out.push_str(&format!("  {}\n", text));
   969     }
   970 }
   971  
   972 fn display_kind(event_class: &str, payload_type: &str) -> String {
   973     if payload_type.is_empty() || payload_type == event_class || payload_type == "unknown" {
   974         if event_class.is_empty() {
   975             "event".to_string()
   976         } else {
   977             event_class.to_string()
   978         }
   979     } else if event_class.is_empty() {
   980         payload_type.to_string()
   981     } else {
   982         format!("{} ({})", event_class, payload_type)
   983     }
   984 }
   985  
   986 fn compact_text_line(text: &str, max_chars: usize) -> String {
   987     let compact = text.split_whitespace().collect::<Vec<_>>().join(" ");
   988     if compact.chars().count() <= max_chars {
   989         return compact;
   990     }
   991  
   992     let mut trimmed: String = compact.chars().take(max_chars.saturating_sub(3)).collect();
   993     trimmed.push_str("...");
   994     trimmed
   995 }
   996  
   997 fn token_re() -> &'static Regex {
   998     static TOKEN_RE: OnceLock<Regex> = OnceLock::new();
   999     TOKEN_RE.get_or_init(|| Regex::new(r"[A-Za-z0-9_]+").expect("valid token regex"))
  1000 }
  1001  
  1002 fn safe_value_re() -> &'static Regex {
  1003     static SAFE_RE: OnceLock<Regex> = OnceLock::new();
  1004     SAFE_RE
  1005         .get_or_init(|| Regex::new(r"^[A-Za-z0-9._:@/-]{1,256}$").expect("valid safe-value regex"))
  1006 }
  1007  
  1008 fn tokenize_query(text: &str, max_terms: usize) -> Vec<(String, u32)> {
  1009     let mut order = Vec::<String>::new();
  1010     let mut tf = HashMap::<String, u32>::new();
  1011  
  1012     for mat in token_re().find_iter(text) {
  1013         let token = mat.as_str().to_ascii_lowercase();
  1014         if token.len() < 2 || token.len() > 64 {
  1015             continue;
  1016         }
  1017  
  1018         if !tf.contains_key(&token) {
  1019             order.push(token.clone());
  1020         }
  1021         let entry = tf.entry(token).or_insert(0);
  1022         *entry += 1;
  1023  
  1024         if order.len() >= max_terms {
  1025             break;
  1026         }
  1027     }
  1028  
  1029     order
  1030         .into_iter()
  1031         .map(|token| {
  1032             let count = *tf.get(&token).unwrap_or(&1);
  1033             (token, count)
  1034         })
  1035         .collect()
  1036 }
  1037  
  1038 fn is_safe_filter_value(value: &str) -> bool {
  1039     safe_value_re().is_match(value)
  1040 }
  1041  
  1042 fn sql_quote(value: &str) -> String {
  1043     format!("'{}'", value.replace('\\', "\\\\").replace('\'', "''"))
  1044 }
  1045  
  1046 fn sql_array_strings(items: &[String]) -> String {
  1047     let parts = items.iter().map(|item| sql_quote(item)).collect::<Vec<_>>();
  1048     format!("[{}]", parts.join(","))
  1049 }
  1050  
  1051 fn sql_array_f64(items: &[f64]) -> String {
  1052     let parts = items
  1053         .iter()
  1054         .map(|v| format!("{:.12}", v))
  1055         .collect::<Vec<_>>();
  1056     format!("[{}]", parts.join(","))
  1057 }
  1058  
  1059 fn parse_config_flag() -> Option<PathBuf> {
  1060     let mut args = std::env::args().skip(1);
  1061     let mut config_path = None;
  1062  
  1063     while let Some(arg) = args.next() {
  1064         if arg == "--config" {
  1065             if let Some(v) = args.next() {
  1066                 config_path = Some(PathBuf::from(v));
  1067             }
  1068         }
  1069     }
  1070  
  1071     config_path
  1072 }
  1073  
  1074 #[tokio::main(flavor = "multi_thread")]
  1075 async fn main() -> Result<()> {
  1076     tracing_subscriber::fmt()
  1077         .with_env_filter(
  1078             EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
  1079         )
  1080         .with_target(false)
  1081         .init();
  1082  
  1083     let config_path = resolve_config_path(parse_config_flag());
  1084     let cfg = load_config(&config_path)
  1085         .with_context(|| format!("failed to load config {}", config_path.display()))?;
  1086  
  1087     let ch = ClickHouseClient::new(cfg.clickhouse.clone())?;
  1088     ch.ping().await.context("clickhouse ping failed")?;
  1089  
  1090     let state = Arc::new(AppState { cfg, ch });
  1091  
  1092     let stdin = BufReader::new(tokio::io::stdin());
  1093     let mut lines = stdin.lines();
  1094     let mut stdout = tokio::io::stdout();
  1095  
  1096     while let Some(line) = lines.next_line().await? {
  1097         let line = line.trim();
  1098         if line.is_empty() {
  1099             continue;
  1100         }
  1101  
  1102         debug!("incoming rpc line: {}", line);
  1103  
  1104         let parsed = serde_json::from_str::<RpcRequest>(line);
  1105         let req = match parsed {
  1106             Ok(req) => req,
  1107             Err(err) => {
  1108                 warn!("failed to parse rpc request: {}", err);
  1109                 continue;
  1110             }
  1111         };
  1112  
  1113         if let Some(resp) = state.handle_request(req).await {
  1114             let payload = serde_json::to_vec(&resp)?;
  1115             stdout.write_all(&payload).await?;
  1116             stdout.write_all(b"\n").await?;
  1117             stdout.flush().await?;
  1118         }
  1119     }
  1120  
  1121     Ok(())
  1122 }