bindings/python/moraine_conversations/src/lib.rs
1 use moraine_clickhouse::ClickHouseClient;
2 use moraine_config::ClickHouseConfig;
3 use moraine_conversations::{
4 ClickHouseConversationRepository, ConversationDetailOptions, ConversationListFilter,
5 ConversationMode, ConversationRepository, ConversationSearchQuery, PageRequest, RepoConfig,
6 SearchEventsQuery, SearchEventsStrategy,
7 };
8 use pyo3::exceptions::{PyRuntimeError, PyValueError};
9 use pyo3::prelude::*;
10
11 #[pyclass]
12 struct ConversationClient {
13 repo: ClickHouseConversationRepository,
14 rt: tokio::runtime::Runtime,
15 }
16
17 #[pymethods]
18 impl ConversationClient {
19 #[new]
20 #[pyo3(signature = (
21 url,
22 database = "moraine".to_string(),
23 username = "default".to_string(),
24 password = "".to_string(),
25 timeout_seconds = 5.0,
26 max_results = 100,
27 ))]
28 fn new(
29 url: String,
30 database: String,
31 username: String,
32 password: String,
33 timeout_seconds: f64,
34 max_results: u16,
35 ) -> PyResult<Self> {
36 let clickhouse = ClickHouseClient::new(ClickHouseConfig {
37 url,
38 database,
39 username,
40 password,
41 timeout_seconds,
42 async_insert: true,
43 wait_for_async_insert: true,
44 })
45 .map_err(py_runtime_err)?;
46
47 let repo = ClickHouseConversationRepository::new(
48 clickhouse,
49 RepoConfig {
50 max_results,
51 ..RepoConfig::default()
52 },
53 );
54
55 let rt = tokio::runtime::Builder::new_multi_thread()
56 .enable_all()
57 .build()
58 .map_err(py_runtime_err)?;
59
60 Ok(Self { repo, rt })
61 }
62
63 #[pyo3(signature = (from_unix_ms=None, to_unix_ms=None, mode=None, limit=50, cursor=None))]
64 fn list_conversations_json(
65 &self,
66 from_unix_ms: Option<i64>,
67 to_unix_ms: Option<i64>,
68 mode: Option<&str>,
69 limit: u16,
70 cursor: Option<String>,
71 ) -> PyResult<String> {
72 let parsed_mode = parse_mode(mode)?;
73 let page = self
74 .rt
75 .block_on(self.repo.list_conversations(
76 ConversationListFilter {
77 from_unix_ms,
78 to_unix_ms,
79 mode: parsed_mode,
80 },
81 PageRequest { limit, cursor },
82 ))
83 .map_err(py_runtime_err)?;
84
85 serde_json::to_string(&page).map_err(py_runtime_err)
86 }
87
88 #[pyo3(signature = (session_id, include_turns=false))]
89 fn get_conversation_json(&self, session_id: String, include_turns: bool) -> PyResult<String> {
90 let conversation = self
91 .rt
92 .block_on(
93 self.repo
94 .get_conversation(&session_id, ConversationDetailOptions { include_turns }),
95 )
96 .map_err(py_runtime_err)?;
97
98 serde_json::to_string(&conversation).map_err(py_runtime_err)
99 }
100
101 #[pyo3(signature = (
102 query,
103 limit=None,
104 min_score=None,
105 min_should_match=None,
106 from_unix_ms=None,
107 to_unix_ms=None,
108 mode=None,
109 include_tool_events=None,
110 exclude_codex_mcp=None,
111 ))]
112 fn search_conversations_json(
113 &self,
114 query: String,
115 limit: Option<u16>,
116 min_score: Option<f64>,
117 min_should_match: Option<u16>,
118 from_unix_ms: Option<i64>,
119 to_unix_ms: Option<i64>,
120 mode: Option<&str>,
121 include_tool_events: Option<bool>,
122 exclude_codex_mcp: Option<bool>,
123 ) -> PyResult<String> {
124 let parsed_mode = parse_mode(mode)?;
125 let results = self
126 .rt
127 .block_on(self.repo.search_conversations(ConversationSearchQuery {
128 query,
129 limit,
130 min_score,
131 min_should_match,
132 from_unix_ms,
133 to_unix_ms,
134 mode: parsed_mode,
135 include_tool_events,
136 exclude_codex_mcp,
137 }))
138 .map_err(py_runtime_err)?;
139
140 serde_json::to_string(&results).map_err(py_runtime_err)
141 }
142
143 #[pyo3(signature = (
144 query,
145 limit=None,
146 session_id=None,
147 min_score=None,
148 min_should_match=None,
149 include_tool_events=None,
150 exclude_codex_mcp=None,
151 disable_cache=None,
152 source=None,
153 search_strategy=None,
154 ))]
155 fn search_events_json(
156 &self,
157 query: String,
158 limit: Option<u16>,
159 session_id: Option<String>,
160 min_score: Option<f64>,
161 min_should_match: Option<u16>,
162 include_tool_events: Option<bool>,
163 exclude_codex_mcp: Option<bool>,
164 disable_cache: Option<bool>,
165 source: Option<String>,
166 search_strategy: Option<&str>,
167 ) -> PyResult<String> {
168 let parsed_search_strategy = parse_search_strategy(search_strategy)?;
169 let results = self
170 .rt
171 .block_on(self.repo.search_events(SearchEventsQuery {
172 query,
173 source,
174 limit,
175 session_id,
176 min_score,
177 min_should_match,
178 include_tool_events,
179 event_kinds: None,
180 exclude_codex_mcp,
181 disable_cache,
182 search_strategy: parsed_search_strategy,
183 }))
184 .map_err(py_runtime_err)?;
185
186 serde_json::to_string(&results).map_err(py_runtime_err)
187 }
188 }
189
190 fn parse_mode(raw: Option<&str>) -> PyResult<Option<ConversationMode>> {
191 let Some(raw) = raw else {
192 return Ok(None);
193 };
194
195 match raw {
196 "web_search" => Ok(Some(ConversationMode::WebSearch)),
197 "mcp_internal" => Ok(Some(ConversationMode::McpInternal)),
198 "tool_calling" => Ok(Some(ConversationMode::ToolCalling)),
199 "chat" => Ok(Some(ConversationMode::Chat)),
200 _ => Err(PyValueError::new_err(
201 "mode must be one of: web_search, mcp_internal, tool_calling, chat",
202 )),
203 }
204 }
205
206 fn parse_search_strategy(raw: Option<&str>) -> PyResult<Option<SearchEventsStrategy>> {
207 let Some(raw) = raw else {
208 return Ok(None);
209 };
210
211 match raw {
212 "optimized" => Ok(Some(SearchEventsStrategy::Optimized)),
213 "oracle_exact" => Ok(Some(SearchEventsStrategy::OracleExact)),
214 _ => Err(PyValueError::new_err(
215 "search_strategy must be one of: optimized, oracle_exact",
216 )),
217 }
218 }
219
220 fn py_runtime_err(err: impl ToString) -> PyErr {
221 PyRuntimeError::new_err(err.to_string())
222 }
223
224 #[pymodule]
225 fn _core(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
226 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
227 m.add_class::<ConversationClient>()?;
228 Ok(())
229 }