Skip to content

bindings/python/moraine_conversations/src/lib.rs


     1 use moraine_clickhouse::ClickHouseClient;
     2 use moraine_config::ClickHouseConfig;
     3 use moraine_conversations::{
     4     ClickHouseConversationRepository, ConversationDetailOptions, ConversationListFilter,
     5     ConversationMode, ConversationRepository, ConversationSearchQuery, PageRequest, RepoConfig,
     6     SearchEventsQuery, SearchEventsStrategy,
     7 };
     8 use pyo3::exceptions::{PyRuntimeError, PyValueError};
     9 use pyo3::prelude::*;
    10  
    11 #[pyclass]
    12 struct ConversationClient {
    13     repo: ClickHouseConversationRepository,
    14     rt: tokio::runtime::Runtime,
    15 }
    16  
    17 #[pymethods]
    18 impl ConversationClient {
    19     #[new]
    20     #[pyo3(signature = (
    21         url,
    22         database = "moraine".to_string(),
    23         username = "default".to_string(),
    24         password = "".to_string(),
    25         timeout_seconds = 5.0,
    26         max_results = 100,
    27     ))]
    28     fn new(
    29         url: String,
    30         database: String,
    31         username: String,
    32         password: String,
    33         timeout_seconds: f64,
    34         max_results: u16,
    35     ) -> PyResult<Self> {
    36         let clickhouse = ClickHouseClient::new(ClickHouseConfig {
    37             url,
    38             database,
    39             username,
    40             password,
    41             timeout_seconds,
    42             async_insert: true,
    43             wait_for_async_insert: true,
    44         })
    45         .map_err(py_runtime_err)?;
    46  
    47         let repo = ClickHouseConversationRepository::new(
    48             clickhouse,
    49             RepoConfig {
    50                 max_results,
    51                 ..RepoConfig::default()
    52             },
    53         );
    54  
    55         let rt = tokio::runtime::Builder::new_multi_thread()
    56             .enable_all()
    57             .build()
    58             .map_err(py_runtime_err)?;
    59  
    60         Ok(Self { repo, rt })
    61     }
    62  
    63     #[pyo3(signature = (from_unix_ms=None, to_unix_ms=None, mode=None, limit=50, cursor=None))]
    64     fn list_conversations_json(
    65         &self,
    66         from_unix_ms: Option<i64>,
    67         to_unix_ms: Option<i64>,
    68         mode: Option<&str>,
    69         limit: u16,
    70         cursor: Option<String>,
    71     ) -> PyResult<String> {
    72         let parsed_mode = parse_mode(mode)?;
    73         let page = self
    74             .rt
    75             .block_on(self.repo.list_conversations(
    76                 ConversationListFilter {
    77                     from_unix_ms,
    78                     to_unix_ms,
    79                     mode: parsed_mode,
    80                 },
    81                 PageRequest { limit, cursor },
    82             ))
    83             .map_err(py_runtime_err)?;
    84  
    85         serde_json::to_string(&page).map_err(py_runtime_err)
    86     }
    87  
    88     #[pyo3(signature = (session_id, include_turns=false))]
    89     fn get_conversation_json(&self, session_id: String, include_turns: bool) -> PyResult<String> {
    90         let conversation = self
    91             .rt
    92             .block_on(
    93                 self.repo
    94                     .get_conversation(&session_id, ConversationDetailOptions { include_turns }),
    95             )
    96             .map_err(py_runtime_err)?;
    97  
    98         serde_json::to_string(&conversation).map_err(py_runtime_err)
    99     }
   100  
   101     #[pyo3(signature = (
   102         query,
   103         limit=None,
   104         min_score=None,
   105         min_should_match=None,
   106         from_unix_ms=None,
   107         to_unix_ms=None,
   108         mode=None,
   109         include_tool_events=None,
   110         exclude_codex_mcp=None,
   111     ))]
   112     fn search_conversations_json(
   113         &self,
   114         query: String,
   115         limit: Option<u16>,
   116         min_score: Option<f64>,
   117         min_should_match: Option<u16>,
   118         from_unix_ms: Option<i64>,
   119         to_unix_ms: Option<i64>,
   120         mode: Option<&str>,
   121         include_tool_events: Option<bool>,
   122         exclude_codex_mcp: Option<bool>,
   123     ) -> PyResult<String> {
   124         let parsed_mode = parse_mode(mode)?;
   125         let results = self
   126             .rt
   127             .block_on(self.repo.search_conversations(ConversationSearchQuery {
   128                 query,
   129                 limit,
   130                 min_score,
   131                 min_should_match,
   132                 from_unix_ms,
   133                 to_unix_ms,
   134                 mode: parsed_mode,
   135                 include_tool_events,
   136                 exclude_codex_mcp,
   137             }))
   138             .map_err(py_runtime_err)?;
   139  
   140         serde_json::to_string(&results).map_err(py_runtime_err)
   141     }
   142  
   143     #[pyo3(signature = (
   144         query,
   145         limit=None,
   146         session_id=None,
   147         min_score=None,
   148         min_should_match=None,
   149         include_tool_events=None,
   150         exclude_codex_mcp=None,
   151         disable_cache=None,
   152         source=None,
   153         search_strategy=None,
   154     ))]
   155     fn search_events_json(
   156         &self,
   157         query: String,
   158         limit: Option<u16>,
   159         session_id: Option<String>,
   160         min_score: Option<f64>,
   161         min_should_match: Option<u16>,
   162         include_tool_events: Option<bool>,
   163         exclude_codex_mcp: Option<bool>,
   164         disable_cache: Option<bool>,
   165         source: Option<String>,
   166         search_strategy: Option<&str>,
   167     ) -> PyResult<String> {
   168         let parsed_search_strategy = parse_search_strategy(search_strategy)?;
   169         let results = self
   170             .rt
   171             .block_on(self.repo.search_events(SearchEventsQuery {
   172                 query,
   173                 source,
   174                 limit,
   175                 session_id,
   176                 min_score,
   177                 min_should_match,
   178                 include_tool_events,
   179                 event_kinds: None,
   180                 exclude_codex_mcp,
   181                 disable_cache,
   182                 search_strategy: parsed_search_strategy,
   183             }))
   184             .map_err(py_runtime_err)?;
   185  
   186         serde_json::to_string(&results).map_err(py_runtime_err)
   187     }
   188 }
   189  
   190 fn parse_mode(raw: Option<&str>) -> PyResult<Option<ConversationMode>> {
   191     let Some(raw) = raw else {
   192         return Ok(None);
   193     };
   194  
   195     match raw {
   196         "web_search" => Ok(Some(ConversationMode::WebSearch)),
   197         "mcp_internal" => Ok(Some(ConversationMode::McpInternal)),
   198         "tool_calling" => Ok(Some(ConversationMode::ToolCalling)),
   199         "chat" => Ok(Some(ConversationMode::Chat)),
   200         _ => Err(PyValueError::new_err(
   201             "mode must be one of: web_search, mcp_internal, tool_calling, chat",
   202         )),
   203     }
   204 }
   205  
   206 fn parse_search_strategy(raw: Option<&str>) -> PyResult<Option<SearchEventsStrategy>> {
   207     let Some(raw) = raw else {
   208         return Ok(None);
   209     };
   210  
   211     match raw {
   212         "optimized" => Ok(Some(SearchEventsStrategy::Optimized)),
   213         "oracle_exact" => Ok(Some(SearchEventsStrategy::OracleExact)),
   214         _ => Err(PyValueError::new_err(
   215             "search_strategy must be one of: optimized, oracle_exact",
   216         )),
   217     }
   218 }
   219  
   220 fn py_runtime_err(err: impl ToString) -> PyErr {
   221     PyRuntimeError::new_err(err.to_string())
   222 }
   223  
   224 #[pymodule]
   225 fn _core(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
   226     m.add("__version__", env!("CARGO_PKG_VERSION"))?;
   227     m.add_class::<ConversationClient>()?;
   228     Ok(())
   229 }