API Reference
The API reference is generated directly from Google-style docstrings.
plait
plait: A PyTorch-inspired framework for LLM inference pipelines.
This package provides tools for building, executing, and optimizing complex LLM inference pipelines with automatic DAG capture and maximum throughput through async execution.
ExecutionSettings
dataclass
Context manager for shared execution configuration.
Provides default settings for all module executions within the context. Bound module settings take precedence over context settings. Settings can be nested, with inner contexts overriding specific fields from outer contexts.
Attributes:

| Name | Type | Description |
|---|---|---|
| `resources` | `ResourceConfig \| ResourceManager \| None` | Optional ResourceConfig or ResourceManager for LLM endpoints. Modules will use these resources for LLM execution. |
| `checkpoint_dir` | `Path \| str \| None` | Optional directory for saving execution checkpoints. When set, a CheckpointManager is created automatically. |
| `max_concurrent` | `int` | Maximum number of concurrent tasks. Defaults to 100. |
| `scheduler` | `Scheduler \| None` | Optional custom Scheduler instance. When None, a new Scheduler is created for each execution. |
| `on_task_complete` | `Callable[[str, TaskResult], None] \| None` | Optional callback invoked when a task completes. Receives the node_id and TaskResult. |
| `on_task_failed` | `Callable[[str, Exception], None] \| None` | Optional callback invoked when a task fails. Receives the node_id and the exception. |
| `task_timeout` | `float \| None` | Maximum seconds for a single task (LLM call) before timeout. When set, tasks exceeding this duration raise TimeoutError and dependent nodes are cancelled. None means no timeout (default). Recommended: 60-300 seconds depending on model and prompt length. |
| `max_task_retries` | `int` | Maximum retry attempts for transient failures. Retries apply to TransientError, not to permanent errors (4xx) or rate limits (handled separately). Default 0 means no retries. |
| `task_retry_delay` | `float` | Base delay in seconds between retry attempts. Uses exponential backoff: the delay doubles each retry. E.g., with delay=1.0, retries occur at 1s, 2s, 4s, etc. Defaults to 1.0. |
| `streaming` | `bool` | When True, batch calls return an async iterator yielding BatchResult objects as they complete. When False (default), batch calls return a list of all results. |
| `preserve_order` | `bool` | When True, streaming results are yielded in input order (may wait on slower items). When False (default), results are yielded as soon as they complete for maximum throughput. Only applies when streaming=True. |
| `on_progress` | `Callable[[int, int], None] \| None` | Optional callback for batch progress updates. Called with (completed_count, total_count) after each input completes. Works with both streaming and non-streaming batch execution. |
| `profile` | `bool` | Whether to enable profiling. Defaults to False. When enabled, a TraceProfiler is created and task execution is recorded. |
| `profile_path` | `Path \| str \| None` | Path for saving profile traces. If None with profile=True, uses './traces/trace_{timestamp}.json'. Defaults to None. |
| `profile_counters` | `bool` | Whether to include counter events in the trace. Defaults to True. |
| `profile_include_args` | `bool` | Whether to include task args in the trace. Defaults to True. |
| `profiler` | `TraceProfiler \| None` | The TraceProfiler instance when profiling is enabled. Access via get_profiler() after entering the context. |
Example
```python
with ExecutionSettings(
    resources=resource_config,
    checkpoint_dir="/data/checkpoints",
    max_concurrent=50,
):
    # All module executions in this block share these settings
    result1 = await pipeline1(input1)
    result2 = await pipeline2(input2)
```
Example with callbacks
```python
def on_complete(node_id: str, result):
    print(f"Completed: {node_id}")

with ExecutionSettings(on_task_complete=on_complete):
    result = await pipeline(input)
```
Example with streaming
```python
async with ExecutionSettings(resources=config, streaming=True):
    async for result in pipeline(["doc1", "doc2"]):
        if result.ok:
            await send_to_client(result.output)
        else:
            logger.error(f"Input {result.index} failed")
```
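The difference between completion-order and input-order delivery (preserve_order=False vs. True) can be illustrated with plain asyncio; the sleeping tasks below are stand-ins for pipeline inputs, not plait code:

```python
# Sketch: completion-order vs input-order result delivery.
import asyncio

async def work(i: int, delay: float) -> int:
    await asyncio.sleep(delay)
    return i

async def main() -> tuple[list[int], list[int]]:
    delays = [0.05, 0.01, 0.03]
    tasks = [asyncio.create_task(work(i, d)) for i, d in enumerate(delays)]
    # preserve_order=False analogue: yield each result as soon as it completes
    completion = [await t for t in asyncio.as_completed(tasks)]
    # preserve_order=True analogue: results delivered in input order
    ordered = list(await asyncio.gather(*tasks))
    return completion, ordered

completion, ordered = asyncio.run(main())
print(completion, ordered)  # [1, 2, 0] [0, 1, 2]
```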
Example with progress tracking
```python
def on_progress(done: int, total: int) -> None:
    print(f"Progress: {done}/{total}")

with ExecutionSettings(resources=config, on_progress=on_progress):
    results = pipeline.run_sync(documents)
```
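The retry schedule documented for task_retry_delay (the delay doubles each attempt) works out as follows; `retry_delays` is an illustrative helper, not part of the API:

```python
# Sketch of the documented exponential backoff: with base delay d,
# retry n (0-indexed) waits d * 2**n seconds -> 1s, 2s, 4s, ... for d=1.0.
def retry_delays(base_delay: float, max_retries: int) -> list[float]:
    return [base_delay * (2 ** attempt) for attempt in range(max_retries)]

print(retry_delays(1.0, 3))  # [1.0, 2.0, 4.0]
```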
Note
ExecutionSettings supports both sync and async context managers.
Use with for synchronous code and async with for async code.
The async version properly awaits checkpoint flushing on exit.
profiler
property
Get the profiler for this context.
Shorthand for get_profiler(). The profiler is created automatically when entering a context with profile=True.
Returns:
| Type | Description |
|---|---|
| `TraceProfiler \| None` | The TraceProfiler instance, or None if profiling is not enabled. |
Example
```python
async with ExecutionSettings(profile=True) as settings:
    if settings.profiler:
        settings.profiler.add_instant_event("custom_marker")
```
__aenter__()
async
Enter the execution settings context (asynchronous).
Activates this settings context, making it available via get_execution_settings(). Creates a CheckpointManager if checkpoint_dir is set.
Returns:
| Type | Description |
|---|---|
| `Self` | This ExecutionSettings instance. |
Example
```python
async with ExecutionSettings(checkpoint_dir="/ckpt") as settings:
    result = await pipeline(input)
# Checkpoints are flushed on exit
```
__aexit__(exc_type, exc_val, exc_tb)
async
Exit the execution settings context (asynchronous).
Flushes any pending checkpoints, exports profiler trace, and resets the context variable to the previous state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `exc_type` | `type[BaseException] \| None` | Exception type if an exception was raised. | *required* |
| `exc_val` | `BaseException \| None` | Exception value if an exception was raised. | *required* |
| `exc_tb` | `object` | Exception traceback if an exception was raised. | *required* |
__enter__()
Enter the execution settings context (synchronous).
Activates this settings context, making it available via get_execution_settings(). Creates a CheckpointManager if checkpoint_dir is set.
Returns:
| Type | Description |
|---|---|
| `Self` | This ExecutionSettings instance. |
Example
```python
with ExecutionSettings(max_concurrent=10) as settings:
    assert get_execution_settings() is settings
```
__exit__(exc_type, exc_val, exc_tb)
Exit the execution settings context (synchronous).
Resets the context variable to the previous state. If a CheckpointManager was created, it is NOT flushed in sync mode (use async context manager for proper cleanup). Profiler traces are exported if profiling was enabled.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `exc_type` | `type[BaseException] \| None` | Exception type if an exception was raised. | *required* |
| `exc_val` | `BaseException \| None` | Exception value if an exception was raised. | *required* |
| `exc_tb` | `object` | Exception traceback if an exception was raised. | *required* |
Note
For proper checkpoint flushing, use the async context manager. The sync version is primarily for simple configurations without active checkpointing.
get_checkpoint_dir()
Get the effective checkpoint directory.
Checks this context and parent contexts for checkpoint_dir.
Returns:
| Type | Description |
|---|---|
| `Path \| None` | The checkpoint directory as a Path, or None if not set. |
get_checkpoint_manager()
Get the checkpoint manager for this context.
The checkpoint manager is created automatically when entering a context with checkpoint_dir set. For nested contexts without their own checkpoint_dir, returns the parent's manager.
Returns:
| Type | Description |
|---|---|
| `CheckpointManager \| None` | The CheckpointManager instance, or None if no checkpointing. |
get_max_concurrent()
Get the effective max_concurrent setting.
Returns:
| Type | Description |
|---|---|
| `int` | The max_concurrent value, defaulting to 100. |
get_max_task_retries()
Get the effective max_task_retries setting.
Returns:
| Type | Description |
|---|---|
| `int` | The maximum number of retry attempts for transient failures. |
get_on_progress()
Get the effective on_progress callback.
Checks this context and parent contexts for the callback.
Returns:
| Type | Description |
|---|---|
| `Callable[[int, int], None] \| None` | The on_progress callback, or None if not set. |
get_preserve_order()
Get the effective preserve_order setting.
Returns:
| Type | Description |
|---|---|
| `bool` | True if results should be yielded in input order, False otherwise. |
get_profiler()
Get the profiler for this context.
The profiler is created automatically when entering a context with profile=True. For nested contexts without their own profile setting, returns the parent's profiler.
Returns:
| Type | Description |
|---|---|
| `TraceProfiler \| None` | The TraceProfiler instance, or None if profiling is not enabled. |
Example
```python
async with ExecutionSettings(profile=True) as settings:
    profiler = settings.get_profiler()
    if profiler:
        profiler.add_instant_event("custom_marker")
```
get_resources()
Get the effective resources configuration.
Checks this context and parent contexts for resources.
Returns:
| Type | Description |
|---|---|
| `ResourceConfig \| ResourceManager \| None` | The ResourceConfig or ResourceManager, or None if not set. |
get_scheduler()
Get the effective scheduler.
Checks this context and parent contexts for a custom scheduler.
Returns:
| Type | Description |
|---|---|
| `Scheduler \| None` | The Scheduler instance, or None if not set. |
get_streaming()
Get the effective streaming setting.
Returns:
| Type | Description |
|---|---|
| `bool` | True if streaming mode is enabled, False otherwise. |
get_task_retry_delay()
Get the effective task_retry_delay setting.
Returns:
| Type | Description |
|---|---|
| `float` | The base delay in seconds between retry attempts. |
get_task_timeout()
Get the effective task_timeout setting.
Checks this context and parent contexts for the timeout value.
Returns:
| Type | Description |
|---|---|
| `float \| None` | The task timeout in seconds, or None if no timeout. |
LLMInference
Bases: Module
Atomic module for LLM API calls.
This is the fundamental building block for LLM operations. All other modules ultimately compose LLMInference instances to build complex inference pipelines.
The alias parameter decouples the module from specific endpoints, allowing the same module to run against different models/endpoints based on resource configuration at runtime.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `alias` | `str` | Resource binding key that maps to an endpoint configuration. This allows the same module to use different LLM providers depending on the ResourceConfig passed to run(). | *required* |
| `system_prompt` | `str \| Parameter` | System prompt for the LLM. Can be a string (converted to a non-learnable Parameter) or a Parameter instance (for learnable prompts). An empty string results in no system prompt. | `''` |
| `temperature` | `float` | Sampling temperature for the LLM. Higher values produce more random outputs. Defaults to 1.0. | `1.0` |
| `max_tokens` | `int \| None` | Maximum number of tokens to generate. None means no limit (use the model default). | `None` |
| `response_format` | `type \| None` | Expected response format type for structured output. None means a plain text response. | `None` |
Example
```python
llm = LLMInference(alias="fast_llm", temperature=0.7)
llm.alias        # 'fast_llm'
llm.temperature  # 0.7
```
Example with system prompt
```python
llm = LLMInference(
    alias="assistant",
    system_prompt="You are a helpful assistant.",
    temperature=0.5,
)
llm.system_prompt.value          # 'You are a helpful assistant.'
llm.system_prompt.requires_grad  # False
```
Note
LLMInference.forward() should not be called directly. Use the run() function to execute modules, which handles tracing and resource management.
__init__(alias, system_prompt='', temperature=1.0, max_tokens=None, response_format=None)
Initialize the LLMInference module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `alias` | `str` | Resource binding key for endpoint resolution. | *required* |
| `system_prompt` | `str \| Parameter` | System prompt string or Parameter. | `''` |
| `temperature` | `float` | Sampling temperature (0.0 to 2.0 typical). | `1.0` |
| `max_tokens` | `int \| None` | Maximum tokens to generate. | `None` |
| `response_format` | `type \| None` | Type for structured output parsing. | `None` |
backward(feedback, ctx)
async
Backward pass for LLM inference.
Generates feedback for both the input prompt and any learnable parameters (like system_prompt). The parameter feedback includes context about what the LLM received and produced to help the optimizer understand how to improve the parameter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `feedback` | `Any` | Combined feedback from downstream nodes. | *required* |
| `ctx` | `Any` | BackwardContext with inputs, output, and graph structure. | *required* |

Returns:

| Type | Description |
|---|---|
| `Any` | BackwardResult with input_feedback and parameter_feedback dictionaries. |
Example
```python
# LLMInference.backward() is called automatically during Value.backward()
# when the module is in the graph:
output, record = await run(llm_module, "Hello", record=True)
loss_val = await loss_fn(output, target)
await loss_val.backward()  # Calls llm_module.backward()
```
Note
The parameter feedback includes:

- The current system prompt value
- The parameter description
- A sample of the input and output
- The feedback received

This context helps the optimizer generate targeted improvements.
forward(prompt)
Execute the LLM call.
This method should not be called directly. During tracing, the tracer intercepts calls and records them in the graph. During execution, the runtime handles the actual API call through the ResourceManager.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `prompt` | `str` | The user prompt to send to the LLM. | *required* |
Returns:
| Type | Description |
|---|---|
| `str` | The LLM's response text. |
Raises:
| Type | Description |
|---|---|
| `RuntimeError` | Always raised, because direct execution is not supported. Use run() to execute modules. |
Note
The runtime replaces this with actual LLM calls. This placeholder exists to define the expected signature and to catch accidental direct invocations.
Module
Base class for all inference operations.
Analogous to torch.nn.Module. Subclass this to define custom inference logic by implementing the forward() method.
Child modules and parameters assigned as attributes are automatically registered, enabling recursive traversal and parameter collection.
Example
```python
from plait.parameter import Parameter

class MyModule(Module):
    def __init__(self):
        super().__init__()
        self.prompt = Parameter("You are helpful.")

module = MyModule()
"prompt" in module._parameters  # True
```
Note
Always call `super().__init__()` in subclass `__init__` methods to ensure proper registration of children and parameters.
training
property
Whether the module is in training mode.
In training mode, forward passes return Value objects with tape ids attached, enabling automatic backward propagation.
Returns:
| Type | Description |
|---|---|
| `bool` | True if the module is in training mode, False otherwise. |
Example
```python
module = MyModule()
module.training  # False
module.train()
module.training  # True
```
__call__(*args, **kwargs)
Execute the module.
Behavior depends on context:

1. If a trace context is active: records the call and returns a Value with ref pointing to the generated node ID (Value-driven tracing).
2. If resources are bound OR an ExecutionSettings context is active: traces and executes.
3. Otherwise: executes forward() directly (for non-LLM modules).
When bound or in an ExecutionSettings context, this method is async and should be awaited. Supports batch execution when the first argument is a list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Positional arguments passed to forward(). | `()` |
| `**kwargs` | `Any` | Keyword arguments passed to forward(). | `{}` |
Returns:
| Name | Type | Description |
|---|---|---|
| If tracing | `Any` | A Value with ref set to the node ID, representing the eventual output of this call. Dependencies are collected from Value.ref attributes in the arguments. |
| If bound/context | `Any` | A coroutine that yields the execution result. |
| Otherwise | `Any` | The result from forward(). |
Example
```python
class Doubler(Module):
    def forward(self, x: int) -> int:
        return x * 2

doubler = Doubler()
doubler(5)  # Without a trace context, calls forward() directly: 10
```
Example with bound resources
```python
pipeline = MyPipeline().bind(resources=config)
result = await pipeline("input")  # Async execution
```
Example with ExecutionSettings
```python
async with ExecutionSettings(resources=config):
    result = await pipeline("input")
```
Note
During tracing, the tracer records this call as a node in the execution graph. The forward() method is not called; instead, dependencies are tracked based on Value refs.
__init__()
Initialize the module with empty registries.
Sets up internal dictionaries for tracking child modules and parameters. Uses `object.__setattr__` to avoid triggering the custom `__setattr__` during initialization.
__setattr__(name, value)
Set an attribute with automatic registration of modules and parameters.
When a value is assigned to an attribute:

- If it's a Module, it's registered as a child module.
- If it's a Parameter, it's registered in the parameters dict.
- If it's a ParameterList or ParameterDict, it's registered for iteration.
- The value's _name is set to the attribute name for introspection.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | The attribute name. | *required* |
| `value` | `Any` | The value to assign. | *required* |
Note
This method is called for all attribute assignments, including those in `__init__`. Internal attributes (starting with '_') that are not modules or parameters are set directly.
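The registration-on-assignment mechanism described above can be sketched with a self-contained toy; `Param` and `Mod` below are illustrative stand-ins, not plait's actual classes:

```python
# Toy sketch of attribute-based registration (NOT plait's implementation):
# assigning a Param or Mod attribute also records it in a registry dict.
class Param:
    def __init__(self, value: str):
        self.value = value
        self._name: str | None = None

class Mod:
    def __init__(self):
        # Bypass the custom __setattr__ while creating the registries.
        object.__setattr__(self, "_children", {})
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name, value):
        if isinstance(value, Mod):
            self._children[name] = value
        elif isinstance(value, Param):
            value._name = name  # record attribute name for introspection
            self._parameters[name] = value
        object.__setattr__(self, name, value)

m = Mod()
m.prompt = Param("You are helpful.")
m.sub = Mod()
print(sorted(m._parameters), sorted(m._children))  # ['prompt'] ['sub']
```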
backward(feedback, ctx)
async
Propagate feedback backward through this module.
Default implementation passes feedback unchanged to all inputs. Override for custom backward logic that generates more targeted feedback for specific inputs or parameters.
This method is called during the backward pass initiated by
Value.backward(). The ctx parameter provides access to the
forward pass context including inputs, outputs, and the
computation graph.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `feedback` | `Any` | Combined feedback from downstream nodes as a Value. | *required* |
| `ctx` | `Any` | BackwardContext with inputs, output, graph structure, and an optional reasoning LLM for generating feedback. | *required* |
Returns:
| Type | Description |
|---|---|
| `Any` | BackwardResult with input_feedback and parameter_feedback dictionaries specifying how feedback should be distributed. |
Example
```python
async def backward(self, feedback, ctx):
    from plait.optimization.backward import BackwardResult
    result = BackwardResult()

    # Pass feedback to all inputs unchanged
    for input_name in ctx.inputs:
        result.input_feedback[input_name] = feedback

    return result
```
Note
The default implementation passes feedback unchanged to all inputs. Override this method to implement custom feedback propagation logic, such as:

- Generating parameter-specific feedback
- Filtering feedback based on input relevance
- Using ctx.reason() for LLM-powered feedback generation
bind(resources, max_concurrent=100, **kwargs)
Bind resources to this module for direct execution.
After binding, the module can be called directly with await:

```python
pipeline = MyPipeline().bind(resources=config)
result = await pipeline("input")
```
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `resources` | `ResourceConfig \| ResourceManager` | Resource configuration or manager for LLM endpoints. | *required* |
| `max_concurrent` | `int` | Maximum concurrent tasks during execution. | `100` |
| `**kwargs` | `Any` | Additional execution options (checkpoint_dir, etc.). | `{}` |
Returns:
| Type | Description |
|---|---|
| `Self` | Self, for method chaining. |
Example
```python
from plait.resources.config import ResourceConfig, EndpointConfig

config = ResourceConfig(endpoints={
    "fast": EndpointConfig(provider_api="openai", model="gpt-4o-mini")
})
pipeline = MyPipeline().bind(resources=config)
result = await pipeline("Hello!")
```
Note
Bound resources and config can be overridden per call by passing keyword arguments to `__call__`, or by using an ExecutionSettings context.
children()
Iterate over immediate child modules.
Yields child modules in the order they were registered. Does not recurse into nested modules.
Yields:
| Type | Description |
|---|---|
| `Module` | Each immediate child Module. |
Example
```python
class Parent(Module):
    def __init__(self):
        super().__init__()
        self.child1 = Module()
        self.child2 = Module()

parent = Parent()
list(parent.children())  # [<...Module...>, <...Module...>]
```
direct_parameters()
Iterate over parameters directly owned by this module.
Yields parameters assigned on this module itself, plus parameters in ParameterList and ParameterDict containers attached directly to it. Does not recurse into child modules.
Returns:
| Type | Description |
|---|---|
| `Iterator[Parameter]` | An iterator of Parameters directly owned by this module. |
eval()
Set the module to evaluation mode.
In evaluation mode, forward passes return raw values without tape wrapping. This is the default mode and is used during inference when backward passes are not needed.
This method recursively sets all child modules to evaluation mode.
Returns:
| Type | Description |
|---|---|
| `Self` | Self, for method chaining. |
Example
```python
module.train()  # Enable training mode
module.eval()   # Disable training mode
output = await module("Hello")
isinstance(output, str)  # True -- raw value, not Value
```
Note
Equivalent to calling .train(False).
forward(*args, **kwargs)
Define the inference computation.
Override this method to implement your module's logic. During tracing, this receives Value objects representing symbolic values. During execution, this receives actual values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Positional arguments for the computation. | `()` |
| `**kwargs` | `Any` | Keyword arguments for the computation. | `{}` |
Returns:
| Type | Description |
|---|---|
| `Any` | The result of the inference computation. |
Raises:
| Type | Description |
|---|---|
| `NotImplementedError` | If not overridden in a subclass. |
Example
```python
class Greeter(Module):
    def forward(self, name: str) -> str:
        return f"Hello, {name}!"

greeter = Greeter()
greeter("World")  # 'Hello, World!'
```
load_state_dict(state_dict)
Load parameter values from a dictionary.
Used for restoring learned prompts/instructions from a saved state. The keys in state_dict must match the hierarchical parameter names from this module's named_parameters().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `state_dict` | `dict[str, str]` | Dictionary mapping parameter names to their values. | *required* |
Raises:
| Type | Description |
|---|---|
| `KeyError` | If a key in state_dict does not match any parameter in this module. Missing keys in state_dict are silently ignored (partial loads are allowed). |
Example
```python
from plait.parameter import Parameter

class MyModule(Module):
    def __init__(self):
        super().__init__()
        self.prompt = Parameter("original")

module = MyModule()
module.load_state_dict({"prompt": "updated"})
module.prompt.value  # 'updated'
```
Example with unknown key
```python
from plait.parameter import Parameter

class MyModule(Module):
    def __init__(self):
        super().__init__()
        self.prompt = Parameter("test")

module = MyModule()
module.load_state_dict({"unknown": "value"})
# Traceback (most recent call last):
#   ...
# KeyError: 'Unknown parameter: unknown'
```
Note
This method modifies the parameter values in-place. If you need to preserve the original values, use state_dict() first to save them.
modules()
Iterate over all modules in the tree, including self.
Performs a depth-first traversal starting from this module. Includes this module as the first item yielded.
Yields:
| Type | Description |
|---|---|
| `Module` | All Modules in the subtree rooted at this module. |
Example
```python
class Nested(Module):
    def __init__(self):
        super().__init__()
        self.inner = Module()

class Outer(Module):
    def __init__(self):
        super().__init__()
        self.nested = Nested()

outer = Outer()
len(list(outer.modules()))  # 3
```
named_children()
Iterate over immediate child modules with their names.
Yields (name, module) pairs for each immediate child. Does not recurse into nested modules.
Yields:
| Type | Description |
|---|---|
| `tuple[str, Module]` | Tuples of (attribute_name, child_module). |
Example
```python
class Parent(Module):
    def __init__(self):
        super().__init__()
        self.child1 = Module()

parent = Parent()
[(name, type(m).__name__) for name, m in parent.named_children()]
# [('child1', 'Module')]
```
named_modules(prefix='')
Iterate over all modules with hierarchical dot-separated names.
Performs a depth-first traversal, yielding (name, module) pairs. Names are hierarchical, e.g., "layer1.sublayer.module".
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `prefix` | `str` | Prefix to prepend to all names. Used internally by recursive calls to build hierarchical names. | `''` |
Yields:
| Type | Description |
|---|---|
| `tuple[str, Module]` | Tuples of (hierarchical_name, module). The root module has an empty string name (or the prefix, if provided). |
Example
```python
class Inner(Module):
    def __init__(self):
        super().__init__()

class Outer(Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

outer = Outer()
[(name, type(m).__name__) for name, m in outer.named_modules()]
# [('', 'Outer'), ('inner', 'Inner')]
```
named_parameters(prefix='', remove_duplicate=True)
Iterate over all parameters with hierarchical dot-separated names.
Recursively yields (name, parameter) pairs from this module and all descendants. Names reflect the module hierarchy. Also yields parameters from ParameterList and ParameterDict containers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `prefix` | `str` | Prefix to prepend to parameter names. Used internally by recursive calls to build hierarchical names. | `''` |
| `remove_duplicate` | `bool` | If True, yield each Parameter instance only once. | `True` |
Yields:
| Type | Description |
|---|---|
| `tuple[str, Parameter]` | Tuples of (hierarchical_name, parameter). |
Example
```python
from plait.parameter import Parameter

class Inner(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter("w")

class Outer(Module):
    def __init__(self):
        super().__init__()
        self.bias = Parameter("b")
        self.inner = Inner()

outer = Outer()
[(name, p.value) for name, p in outer.named_parameters()]
# [('bias', 'b'), ('inner.weight', 'w')]
```
parameters(remove_duplicate=True)
Iterate over all parameters in the module tree.
Recursively yields parameters from this module and all descendant modules in depth-first order. Also yields parameters from ParameterList and ParameterDict containers.
Shared Parameter instances are de-duplicated by default, matching PyTorch's behavior for shared parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `remove_duplicate` | `bool` | If True, yield each Parameter instance only once. | `True` |
Yields:
| Type | Description |
|---|---|
| `Parameter` | All Parameter objects in the subtree. |
Example
```python
from plait.parameter import Parameter

class MyModule(Module):
    def __init__(self):
        super().__init__()
        self.prompt = Parameter("test")

module = MyModule()
list(module.parameters())  # [Parameter(value='test', ...)]
```
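The identity-based de-duplication of shared Parameter instances can be sketched generically; `dedup` below is an illustrative helper, not the library's code:

```python
# Sketch: yield each object once by identity, as done for shared Parameters.
from typing import Any, Iterable, Iterator

def dedup(items: Iterable[Any]) -> Iterator[Any]:
    seen: set[int] = set()
    for item in items:
        if id(item) not in seen:
            seen.add(id(item))
            yield item

shared = object()
# A parameter assigned to two modules appears twice during traversal
# but is yielded only once.
unique = list(dedup([shared, shared, object()]))
print(len(unique))  # 2
```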
run_sync(*args, **kwargs)
Execute synchronously (blocking).
Convenience method for scripts and notebooks where async isn't needed. Blocks until execution completes and returns the result.
This method requires that resources are available either through:
- Prior bind() call on this module, or
- An active ExecutionSettings context
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `*args` | `Any` | Positional arguments passed to forward(). | `()` |
| `**kwargs` | `Any` | Keyword arguments passed to forward(). | `{}` |
Returns:
| Type | Description |
|---|---|
| `Any` | A single result for a single input, or a list for batch input. |
Raises:
| Type | Description |
|---|---|
| `RuntimeError` | If called from within an async context (which would block the event loop), or if no resources are available. |
Example
```python
pipeline = MyPipeline().bind(resources=config)
result = pipeline.run_sync("Hello")
results = pipeline.run_sync(["a", "b", "c"])
```
Example with ExecutionSettings
```python
with ExecutionSettings(resources=config):
    result = pipeline.run_sync("Hello")
```
Note
Use await module(...) in async code instead. This method is
intended for synchronous scripts and REPL environments only.
state_dict()
Return a dictionary of all parameter values.
Used for saving learned prompts/instructions after optimization. Keys are hierarchical parameter names (e.g., "summarizer.system_prompt"), matching the output of named_parameters().
Returns:
| Type | Description |
|---|---|
| `dict[str, str]` | A dictionary mapping parameter names to their string values. |
Example
```python
from plait.parameter import Parameter

class Inner(Module):
    def __init__(self):
        super().__init__()
        self.weight = Parameter("w")

class Outer(Module):
    def __init__(self):
        super().__init__()
        self.bias = Parameter("b")
        self.inner = Inner()

outer = Outer()
outer.state_dict()  # {'bias': 'b', 'inner.weight': 'w'}
```
Note
The returned dict can be serialized to JSON/pickle and later restored with load_state_dict().
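Since the state dict maps strings to strings, it round-trips through JSON without loss; the parameter names below are taken from the example above, and the `state_dict()`/`load_state_dict()` calls in the comments are where the real module would plug in:

```python
# Save and restore a state dict via JSON (str -> str, so lossless).
import json

state = {"bias": "b", "inner.weight": "w"}  # e.g. outer.state_dict()
payload = json.dumps(state)                 # write this to disk
restored = json.loads(payload)              # later: read it back
assert restored == state                    # feed to load_state_dict()
```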
train(mode=True)
Set the module to training mode.
In training mode, forward passes return Value objects with tape ids attached, enabling implicit record flow through the training pipeline. This eliminates manual record management.
This method recursively sets all child modules to the same mode.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mode` | `bool` | If True, enable training mode. If False, disable it. Defaults to True. | `True` |
Returns:
| Type | Description |
|---|---|
| `Self` | Self, for method chaining. |
Example
```python
module = MyModule().bind(resources)
module.train()  # Enable training mode
output = await module("Hello")
hasattr(output, "meta") and "_tape_ids" in output.meta  # True
```
Example with chaining
```python
module.train().bind(resources)
result = await module("input")
```
Note
Use .eval() to switch back to evaluation mode where
raw values are returned without tape wrapping.
ModuleDict
Bases: Module, Mapping[str, Module]
A dict-like container for named module access.
Provides dict-like operations (keys, values, items, etc.) while properly registering modules for parameter collection. This mirrors PyTorch's nn.ModuleDict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `modules` | `Mapping[str, Module] \| Iterable[tuple[str, Module]] \| None` | Optional dict or iterable of (key, module) pairs. | `None` |
Example
```python
modules = ModuleDict({
    'encoder': Encoder(),
    'decoder': Decoder()
})
output = modules['encoder']
modules.keys()  # dict_keys(['encoder', 'decoder'])
```
Note
ModuleDict does not define a forward() method since access patterns depend on use case. Access modules by key and call them directly.
__contains__(key)
Check if a key is in the dict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `object` | The key to check for. | *required* |

Returns:

| Type | Description |
|---|---|
| `bool` | True if the key is in the dict, False otherwise. |
__delitem__(key)
Delete a module by key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `str` | The key to delete. | *required* |

Raises:

| Type | Description |
|---|---|
| `KeyError` | If the key is not in the dict. |
__getattr__(name)
Get a module by name as attribute access.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | The name of the module to retrieve. | *required* |

Returns:

| Type | Description |
|---|---|
| `Module` | The module with the given name. |

Raises:

| Type | Description |
|---|---|
| `AttributeError` | If no module with that name exists. |
__getitem__(key)
Get a module by key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
key
|
str
|
The key of the module to retrieve. |
required |
Returns:
| Type | Description |
|---|---|
Module
|
The module with the given key. |
Raises:
| Type | Description |
|---|---|
KeyError
|
If key is not in the dict. |
__init__(modules=None)
Initialize the ModuleDict container.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
modules
|
Mapping[str, Module] | Iterable[tuple[str, Module]] | None
|
Optional mapping (dict, ModuleDict, etc.) or iterable of (key, module) pairs. |
None
|
Raises:
| Type | Description |
|---|---|
TypeError
|
If any value is not a Module instance. |
__iter__()
Iterate over module keys.
Yields:
| Type | Description |
|---|---|
str
|
Each key in the dict. |
__len__()
Return the number of modules in the dict.
Returns:
| Type | Description |
|---|---|
int
|
The number of modules. |
__setitem__(key, module)
Set a module at the given key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
key
|
str
|
The key to set. |
required |
module
|
Module
|
The module to set. |
required |
Raises:
| Type | Description |
|---|---|
TypeError
|
If module is not a Module instance. |
clear()
Remove all modules from the dict.
forward(x)
Forward is not implemented for ModuleDict.
ModuleDict does not define a forward method since access patterns depend on use case.
Raises:
| Type | Description |
|---|---|
NotImplementedError
|
Always raised. |
items()
Return a view of the (key, module) pairs.
Returns:
| Type | Description |
|---|---|
Any
|
A dict_items view of the items. |
keys()
Return a view of the module keys.
Returns:
| Type | Description |
|---|---|
Any
|
A dict_keys view of the keys. |
pop(key, default=None)
Remove and return the module at the given key, or return default if the key is absent.
update(modules)
Update the dict with modules from another mapping or iterable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
modules
|
Mapping[str, Module] | Iterable[tuple[str, Module]]
|
Mapping (dict, ModuleDict, etc.) or iterable of (key, module) pairs. |
required |
Raises:
| Type | Description |
|---|---|
TypeError
|
If any value is not a Module instance. |
values()
Return a view of the modules.
Returns:
| Type | Description |
|---|---|
Any
|
A dict_values view of the modules. |
ModuleList
Bases: Module
A list-like container for modules.
Provides list-like operations (append, extend, insert, etc.) while properly registering modules for parameter collection. This mirrors PyTorch's nn.ModuleList.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
modules
|
Iterable[Module] | None
|
Optional iterable of modules to initialize with. |
None
|
Example
layers = ModuleList([Layer() for _ in range(3)])
for layer in layers:
    x = layer(x)
layers.append(AnotherLayer())
len(layers)
4
Note
ModuleList does not define a forward() method since the iteration pattern depends on use case. Use iteration or indexing to access and execute modules.
__contains__(module)
Check if a module is in the list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module
|
Module
|
The module to check for. |
required |
Returns:
| Type | Description |
|---|---|
bool
|
True if module is in the list, False otherwise. |
__delitem__(idx)
Delete a module at the given index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
idx
|
int
|
The index to delete. |
required |
Raises:
| Type | Description |
|---|---|
IndexError
|
If index is out of range. |
__getitem__(idx)
Get module(s) by index or slice.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
idx
|
int | slice
|
Integer index or slice. |
required |
Returns:
| Type | Description |
|---|---|
Module | ModuleList
|
Single module for integer index, new ModuleList for slice. |
Raises:
| Type | Description |
|---|---|
IndexError
|
If index is out of range. |
Note
Slicing returns a new ModuleList containing views of the same module objects, but does NOT reparent them. The modules remain children of the original container.
__init__(modules=None)
Initialize the ModuleList container.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
modules
|
Iterable[Module] | None
|
Optional iterable of modules to add. |
None
|
Raises:
| Type | Description |
|---|---|
TypeError
|
If any item in modules is not a Module. |
__iter__()
Iterate over the modules in list order.
__len__()
Return the number of modules in the list.
Returns:
| Type | Description |
|---|---|
int
|
The number of modules. |
__setitem__(idx, module)
Set a module at the given index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
idx
|
int
|
The index to set. |
required |
module
|
Module
|
The module to set. |
required |
Raises:
| Type | Description |
|---|---|
IndexError
|
If index is out of range. |
TypeError
|
If module is not a Module instance. |
append(module)
Append a module to the end of the list.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module
|
Module
|
The module to append. |
required |
Returns:
| Type | Description |
|---|---|
ModuleList
|
Self, for method chaining. |
Raises:
| Type | Description |
|---|---|
TypeError
|
If module is not a Module instance. |
extend(modules)
Extend the list with modules from an iterable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
modules
|
Iterable[Module]
|
Iterable of modules to add. |
required |
Returns:
| Type | Description |
|---|---|
ModuleList
|
Self, for method chaining. |
Raises:
| Type | Description |
|---|---|
TypeError
|
If any item is not a Module instance. |
forward(x)
Forward is not implemented for ModuleList.
ModuleList does not define a forward method since the iteration pattern depends on use case.
Raises:
| Type | Description |
|---|---|
NotImplementedError
|
Always raised. |
insert(idx, module)
Insert a module at the given index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
idx
|
int
|
The index to insert at. |
required |
module
|
Module
|
The module to insert. |
required |
Raises:
| Type | Description |
|---|---|
TypeError
|
If module is not a Module instance. |
pop(idx=-1)
Remove and return the module at the given index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
idx
|
int
|
The index to pop from. Defaults to -1 (last). |
-1
|
Returns:
| Type | Description |
|---|---|
Module
|
The removed module. |
Raises:
| Type | Description |
|---|---|
IndexError
|
If index is out of range or list is empty. |
Parameter
dataclass
A learnable value that can be optimized via backward passes.
Similar to torch.nn.Parameter, but for values (prompts, instructions, structured configs, etc.) that are optimized via LLM feedback rather than gradient descent.
The description field is required when requires_grad=True to enable the optimizer to understand how to improve the parameter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
value
|
Any
|
The current value of the parameter (string, dict, list, etc.). |
required |
description
|
str | None
|
A description of what this parameter does/represents. Required when requires_grad=True to enable self-documenting optimization. Can be None when requires_grad=False. |
None
|
requires_grad
|
bool
|
If True, feedback will be accumulated during backward passes and description is required. If False, the parameter is treated as a constant. |
True
|
Raises:
| Type | Description |
|---|---|
ValueError
|
If requires_grad=True but description is None. |
Example
param = Parameter(
    "You are a helpful assistant.",
    description="Defines the agent's identity and baseline behavior."
)
str(param)
'You are a helpful assistant.'
param.description
"Defines the agent's identity and baseline behavior."
param.accumulate_feedback("Be more concise")
param.get_accumulated_feedback()
['Be more concise']
param.apply_update("You are a concise, helpful assistant.")
str(param)
'You are a concise, helpful assistant.'
Constant parameter (no description required)
const = Parameter({"model": "gpt-4"}, requires_grad=False)
const.requires_grad
False
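The feedback lifecycle described above (accumulate, update, clear) can be sketched as a small dataclass. This is an illustrative reconstruction of the documented behavior, not plait's actual source:

```python
from dataclasses import dataclass, field
from typing import Any, Optional

@dataclass
class Parameter:
    """Minimal sketch of the Parameter dataclass described above."""
    value: Any
    description: Optional[str] = None
    requires_grad: bool = True
    _feedback: list = field(default_factory=list, repr=False)

    def __post_init__(self):
        # description is required for learnable parameters.
        if self.requires_grad and self.description is None:
            raise ValueError("description is required when requires_grad=True")

    def __str__(self):
        return str(self.value)

    def accumulate_feedback(self, feedback):
        if self.requires_grad:  # constants silently ignore feedback
            self._feedback.append(str(feedback))

    def get_accumulated_feedback(self):
        return list(self._feedback)  # a copy, as documented

    def apply_update(self, new_value):
        self.value = new_value
        self._feedback.clear()  # update clears the buffer

    def zero_feedback(self):
        self._feedback.clear()
```

The key invariant is that `apply_update` and `zero_feedback` both empty the feedback buffer, so each backward/step cycle starts clean, mirroring `zero_grad()` in PyTorch.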
__post_init__()
Validate that description is provided when requires_grad=True.
Raises:
| Type | Description |
|---|---|
ValueError
|
If requires_grad=True but description is None. |
__str__()
Return the current value as a string.
Returns:
| Type | Description |
|---|---|
str
|
The string representation of the parameter value. |
accumulate_feedback(feedback)
Collect feedback from backward passes.
Feedback is only accumulated if requires_grad is True.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
feedback
|
str | Value
|
The feedback string or Value to accumulate. |
required |
apply_update(new_value)
Apply an optimizer-computed update.
Updates the parameter value and clears the feedback buffer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
new_value
|
Any
|
The new value to set. |
required |
get_accumulated_feedback()
Get all accumulated feedback.
Returns:
| Type | Description |
|---|---|
list[str]
|
A copy of the list of accumulated feedback strings. |
zero_feedback()
Clear accumulated feedback without updating the value.
Similar to zero_grad() in PyTorch, this clears the feedback buffer to prepare for a new backward pass.
ParameterDict
Bases: MutableMapping[str, 'Parameter']
A dict-like container for Parameters.
Holds a dictionary of Parameters that will be properly collected by Module.parameters() and Module.named_parameters(). This is analogous to torch.nn.ParameterDict.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameters
|
dict[str, Parameter] | Iterable[tuple[str, Parameter]] | None
|
Optional mapping or iterable of (key, Parameter) pairs. |
None
|
Example
class MultiTask(Module):
    def __init__(self):
        super().__init__()
        self.prompts = ParameterDict({
            "summarize": Parameter("Summarize:", description="Summary prompt"),
            "translate": Parameter("Translate:", description="Translation prompt"),
        })

m = MultiTask()
list(m.named_parameters())
[('prompts.summarize', Parameter(...)), ('prompts.translate', Parameter(...))]
Note
The container itself is not a Parameter, but it provides iteration methods that Module uses to collect the contained Parameters.
__delitem__(key)
Delete a parameter by key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
key
|
str
|
The parameter's key. |
required |
__getitem__(key)
Get a parameter by key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
key
|
str
|
The parameter's key. |
required |
Returns:
| Type | Description |
|---|---|
Parameter
|
The Parameter at the given key. |
__init__(parameters=None)
Initialize the ParameterDict, optionally from a mapping or iterable of (key, Parameter) pairs.
__iter__()
Iterate over parameter keys.
Yields:
| Type | Description |
|---|---|
str
|
Each key in the dict. |
__len__()
Return the number of parameters.
Returns:
| Type | Description |
|---|---|
int
|
Number of parameters in the dict. |
__repr__()
Return string representation.
Returns:
| Type | Description |
|---|---|
str
|
String representation of the ParameterDict. |
__setitem__(key, value)
Set a parameter by key.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
key
|
str
|
The parameter's key. |
required |
value
|
Parameter
|
The Parameter to store. |
required |
Raises:
| Type | Description |
|---|---|
TypeError
|
If value is not a Parameter. |
named_parameters(prefix='')
Iterate over parameters with their names.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prefix
|
str
|
Prefix to prepend to parameter names. |
''
|
Yields:
| Type | Description |
|---|---|
tuple[str, Parameter]
|
Tuples of (name, parameter). |
parameters()
Iterate over the contained Parameters.
ParameterList
Bases: MutableSequence['Parameter']
A list-like container for Parameters.
Holds a list of Parameters that will be properly collected by Module.parameters() and Module.named_parameters(). This is analogous to torch.nn.ParameterList.
Parameters are named by their index in the list (e.g., "0", "1", "2").
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameters
|
Iterable[Parameter] | None
|
Optional iterable of Parameter objects to initialize with. |
None
|
Example
class MultiPrompt(Module):
    def __init__(self):
        super().__init__()
        self.prompts = ParameterList([
            Parameter("Be concise", description="Style prompt"),
            Parameter("Be helpful", description="Tone prompt"),
        ])

m = MultiPrompt()
list(m.named_parameters())
[('prompts.0', Parameter(...)), ('prompts.1', Parameter(...))]
Note
The container itself is not a Parameter, but it provides iteration methods that Module uses to collect the contained Parameters.
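The index-based naming described above ("0", "1", "2", optionally prefixed) can be sketched as follows. `Parameter` and `ParameterList` here are illustrative stand-ins showing only the naming logic, not plait's full implementation:

```python
class Parameter:
    """Illustrative stand-in; see the Parameter dataclass above."""
    def __init__(self, value, description=None):
        self.value = value
        self.description = description

class ParameterList:
    """Minimal sketch of index-based parameter naming."""
    def __init__(self, parameters=None):
        self._params = list(parameters or [])

    def named_parameters(self, prefix=""):
        # Names are list indices, joined to the prefix with a dot.
        for i, param in enumerate(self._params):
            name = f"{prefix}.{i}" if prefix else str(i)
            yield name, param

plist = ParameterList([
    Parameter("Be concise", description="Style prompt"),
    Parameter("Be helpful", description="Tone prompt"),
])
[name for name, _ in plist.named_parameters(prefix="prompts")]
# ['prompts.0', 'prompts.1']
```

When a Module collects parameters from an attribute named `prompts`, it passes that attribute name as the prefix, which is how `('prompts.0', ...)` names arise in the example above.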
__delitem__(index)
Delete parameter(s) by index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
index
|
int | slice
|
Integer index or slice. |
required |
__getitem__(index)
Get parameter(s) by integer index or slice.
__init__(parameters=None)
Initialize the ParameterList.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
parameters
|
Iterable[Parameter] | None
|
Optional iterable of Parameter objects. |
None
|
__len__()
Return the number of parameters.
Returns:
| Type | Description |
|---|---|
int
|
Number of parameters in the list. |
__repr__()
Return string representation.
Returns:
| Type | Description |
|---|---|
str
|
String representation of the ParameterList. |
__setitem__(index, value)
Set parameter(s) at the given index or slice.
insert(index, value)
Insert a parameter at the given index.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
index
|
int
|
Index to insert at. Negative indices work like Python list.insert(): insert(-1, x) inserts before the last element, and very negative indices (beyond the start) insert at the beginning. |
required |
value
|
Parameter
|
Parameter to insert. |
required |
Raises:
| Type | Description |
|---|---|
TypeError
|
If value is not a Parameter. |
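The negative-index behavior documented above follows plain Python `list.insert` semantics, which a quick check on a builtin list makes concrete:

```python
# Python list.insert semantics that ParameterList.insert mirrors.
xs = ["a", "b", "c"]
xs.insert(-1, "X")   # -1 inserts before the last element
# xs == ['a', 'b', 'X', 'c']

ys = ["a", "b"]
ys.insert(-99, "Y")  # far-negative indices clamp to the start
# ys == ['Y', 'a', 'b']
```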
named_parameters(prefix='')
Iterate over parameters with their names.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
prefix
|
str
|
Prefix to prepend to parameter names. |
''
|
Yields:
| Type | Description |
|---|---|
tuple[str, Parameter]
|
Tuples of (name, parameter). |
parameters()
Iterate over the contained Parameters.
Sequential
Bases: Module
A sequential container that chains modules together.
Modules are executed in order, with each module's output passed as input to the next module. This mirrors PyTorch's nn.Sequential.
Supports two initialization styles: 1. Positional arguments: Sequential(mod1, mod2, mod3) 2. OrderedDict for named access: Sequential(OrderedDict([('name', mod)]))
When using OrderedDict, modules can be accessed by name as attributes.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
*args
|
Module | OrderedDict[str, Module]
|
Either positional Module instances, or a single OrderedDict mapping names to modules. |
()
|
Example
Positional args
pipeline = Sequential(
    Preprocessor(),
    Analyzer(),
    Formatter()
)
len(pipeline)
3
pipeline[0]  # First module
Example with OrderedDict
from collections import OrderedDict
pipeline = Sequential(OrderedDict([
    ('preprocess', Preprocessor()),
    ('analyze', Analyzer()),
]))
pipeline.preprocess  # Named access
Note
The forward() method passes the input through each module in sequence, so each module must accept a single argument matching the previous module's output type.
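The chaining rule above (each module's output is the next module's input) can be sketched with synchronous stand-ins for clarity; plait modules are actually awaited, and the classes here are illustrative, not plait's implementation:

```python
class Module:
    """Illustrative stand-in for plait's Module base class."""
    def __call__(self, x):
        return self.forward(x)

    def forward(self, x):
        raise NotImplementedError

class Sequential(Module):
    """Minimal sketch of sequential output-to-input chaining."""
    def __init__(self, *modules):
        self._modules = list(modules)

    def forward(self, x):
        for module in self._modules:
            x = module(x)  # each output feeds the next module
        return x

class Upper(Module):
    def forward(self, x):
        return x.upper()

class Exclaim(Module):
    def forward(self, x):
        return x + "!"

pipe = Sequential(Upper(), Exclaim())
pipe("hello")  # 'HELLO!'
```

Because each stage receives exactly one argument, every module in the chain must accept a single input whose type matches the previous module's output, as the note above states.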
__getattr__(name)
Get a module by name for named Sequential containers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name
|
str
|
The name of the module to retrieve. |
required |
Returns:
| Type | Description |
|---|---|
Module
|
The module with the given name. |
Raises:
| Type | Description |
|---|---|
AttributeError
|
If no module with that name exists. |
__getitem__(idx)
Get module(s) by index or slice.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
idx
|
int | slice
|
Integer index or slice. |
required |
Returns:
| Type | Description |
|---|---|
Module | Sequential
|
Single module for integer index, new Sequential for slice. |
Raises:
| Type | Description |
|---|---|
IndexError
|
If index is out of range. |
TypeError
|
If idx is not int or slice. |
Note
Slicing returns a new Sequential containing views of the same module objects, but does NOT reparent them. The modules remain children of the original container.
__init__(*args)
Initialize the Sequential container.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
*args
|
Module | OrderedDict[str, Module]
|
Either positional Module instances, or a single OrderedDict mapping names to modules. |
()
|
Raises:
| Type | Description |
|---|---|
TypeError
|
If a non-Module argument is provided (except OrderedDict). |
ValueError
|
If OrderedDict is provided with positional args. |
__iter__()
Iterate over the modules in sequence order.
__len__()
Return the number of modules in the container.
Returns:
| Type | Description |
|---|---|
int
|
The number of modules. |
append(module)
Append a module to the end of the sequence.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
module
|
Module
|
The module to append. |
required |
Returns:
| Type | Description |
|---|---|
Sequential
|
Self, for method chaining. |
Raises:
| Type | Description |
|---|---|
TypeError
|
If module is not a Module instance. |
forward(x)
Execute modules sequentially, chaining outputs to inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
x
|
Any
|
Input to the first module. |
required |
Returns:
| Type | Description |
|---|---|
Any
|
Output of the last module. |
Note
Each module's output becomes the next module's input.