
API Reference

The API reference is generated directly from Google-style docstrings.

plait

plait: A PyTorch-inspired framework for LLM inference pipelines.

This package provides tools for building, executing, and optimizing complex LLM inference pipelines with automatic DAG capture and maximum throughput through async execution.

ExecutionSettings dataclass

Context manager for shared execution configuration.

Provides default settings for all module executions within the context. Bound module settings take precedence over context settings. Settings can be nested, with inner contexts overriding specific fields from outer contexts.

Attributes:

Name Type Description
resources ResourceConfig | ResourceManager | None

Optional ResourceConfig or ResourceManager for LLM endpoints. Modules will use these resources for LLM execution.

checkpoint_dir Path | str | None

Optional directory for saving execution checkpoints. When set, a CheckpointManager is created automatically.

max_concurrent int

Maximum number of concurrent tasks. Defaults to 100.

scheduler Scheduler | None

Optional custom Scheduler instance. When None, a new Scheduler is created for each execution.

on_task_complete Callable[[str, TaskResult], None] | None

Optional callback invoked when a task completes. Receives the node_id and TaskResult.

on_task_failed Callable[[str, Exception], None] | None

Optional callback invoked when a task fails. Receives the node_id and the exception.

task_timeout float | None

Maximum seconds for a single task (LLM call) before timeout. When set, tasks exceeding this duration raise TimeoutError and dependent nodes are cancelled. None means no timeout (default). Recommended: 60-300 seconds depending on model and prompt length.

max_task_retries int

Maximum retry attempts for transient failures. Retries apply to TransientError, not to permanent errors (4xx) or rate limits (handled separately). Default 0 means no retries.

task_retry_delay float

Base delay in seconds between retry attempts. Uses exponential backoff: the delay doubles after each retry. For example, with task_retry_delay=1.0, the waits before the first, second, and third retries are 1s, 2s, and 4s. Defaults to 1.0.

streaming bool

When True, batch calls return an async iterator yielding BatchResult objects as they complete. When False (default), batch calls return a list of all results.

preserve_order bool

When True, streaming results are yielded in input order (may wait on slower items). When False (default), results yield as soon as they complete for maximum throughput. Only applies when streaming=True.

on_progress Callable[[int, int], None] | None

Optional callback for batch progress updates. Called with (completed_count, total_count) after each input completes. Works with both streaming and non-streaming batch execution.

profile bool

Whether to enable profiling. Defaults to False. When enabled, a TraceProfiler is created and task execution is recorded.

profile_path Path | str | None

Path for saving profile traces. If None with profile=True, uses './traces/trace_{timestamp}.json'. Defaults to None.

profile_counters bool

Whether to include counter events in trace. Defaults to True.

profile_include_args bool

Whether to include task args in trace. Defaults to True.

profiler TraceProfiler | None

The TraceProfiler instance when profiling is enabled. Access via get_profiler() after entering the context.
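The interplay of task_timeout, max_task_retries, and task_retry_delay can be sketched with asyncio.wait_for and a simple backoff loop. This is an illustration under stated assumptions, not plait's executor: TransientError here is a local stand-in, backoff_delays and run_task are hypothetical helpers, and the sketch assumes a timeout is not itself retried, which the settings above do not specify.

```python
import asyncio
from typing import Awaitable, Callable, List, Optional, TypeVar

T = TypeVar("T")


class TransientError(Exception):
    """Hypothetical stand-in for plait's TransientError."""


def backoff_delays(base_delay: float, max_retries: int) -> List[float]:
    """Wait before each retry: base_delay, 2 * base_delay, 4 * base_delay, ..."""
    return [base_delay * (2 ** attempt) for attempt in range(max_retries)]


async def run_task(
    fn: Callable[[], Awaitable[T]],
    *,
    task_timeout: Optional[float] = None,
    max_task_retries: int = 0,
    task_retry_delay: float = 1.0,
) -> T:
    """Run one task with an optional timeout and retries for transient errors."""
    delays = backoff_delays(task_retry_delay, max_task_retries)
    for attempt in range(max_task_retries + 1):
        try:
            if task_timeout is None:
                return await fn()  # None means no timeout.
            # Raises asyncio.TimeoutError (an alias of TimeoutError on 3.11+).
            return await asyncio.wait_for(fn(), task_timeout)
        except TransientError:
            if attempt == max_task_retries:
                raise  # Retries exhausted: surface the last transient error.
            await asyncio.sleep(delays[attempt])
    raise AssertionError("unreachable")
```

Only TransientError is caught, so permanent errors and timeouts propagate on the first attempt, matching the idea that retries apply to transient failures only.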

Example

>>> async with ExecutionSettings(
...     resources=resource_config,
...     checkpoint_dir="/data/checkpoints",
...     max_concurrent=50,
... ):
...     # All module executions in this block share these settings
...     result1 = await pipeline1(input1)
...     result2 = await pipeline2(input2)

Example with callbacks

>>> def on_complete(node_id: str, result):
...     print(f"Completed: {node_id}")

>>> with ExecutionSettings(on_task_complete=on_complete):
...     result = await pipeline(input)

Example with streaming

>>> async with ExecutionSettings(resources=config, streaming=True):
...     async for result in pipeline(["doc1", "doc2"]):
...         if result.ok:
...             await send_to_client(result.output)
...         else:
...             logger.error(f"Input {result.index} failed")

Example with progress tracking

>>> def on_progress(done: int, total: int) -> None:
...     print(f"Progress: {done}/{total}")

>>> with ExecutionSettings(resources=config, on_progress=on_progress):
...     results = pipeline.run_sync(documents)

Note

ExecutionSettings supports both sync and async context managers. Use with for synchronous code and async with for async code. The async version properly awaits checkpoint flushing on exit.
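The difference between preserve_order=True and the default, and the point at which on_progress fires, can be seen in a small asyncio model. This is a sketch, not plait's scheduler: _work simulates an LLM call with a fixed delay, stream_results is a hypothetical helper, and results are input indices so the ordering is easy to check.

```python
import asyncio
from typing import AsyncIterator, Callable, List, Optional


async def _work(index: int, delay: float) -> int:
    # Simulated LLM call whose latency varies per input.
    await asyncio.sleep(delay)
    return index


async def stream_results(
    delays: List[float],
    preserve_order: bool,
    on_progress: Optional[Callable[[int, int], None]] = None,
) -> AsyncIterator[int]:
    """Yield each input's index in input order or in completion order."""
    total = len(delays)
    completed = 0

    async def tracked(index: int, delay: float) -> int:
        nonlocal completed
        result = await _work(index, delay)
        completed += 1  # No lock needed: asyncio runs on one thread.
        if on_progress is not None:
            on_progress(completed, total)
        return result

    tasks = [asyncio.create_task(tracked(i, d)) for i, d in enumerate(delays)]
    if preserve_order:
        # Await tasks in input order; a slow early item holds back later ones.
        for task in tasks:
            yield await task
    else:
        # Yield whichever task finishes next for maximum throughput.
        for next_done in asyncio.as_completed(tasks):
            yield await next_done
```

With delays [0.05, 0.0, 0.02], preserve_order=True yields 0, 1, 2 while preserve_order=False yields 1, 2, 0. In both cases progress is reported as (1, 3), (2, 3), (3, 3) as inputs finish.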

profiler property

Get the profiler for this context.

Shorthand for get_profiler(). The profiler is created automatically when entering a context with profile=True.

Returns:

Type Description
TraceProfiler | None

The TraceProfiler instance, or None if profiling is not enabled.

Example

>>> async with ExecutionSettings(profile=True) as settings:
...     if settings.profiler:
...         settings.profiler.add_instant_event("custom_marker")

__aenter__() async

Enter the execution settings context (asynchronous).

Activates this settings context, making it available via get_execution_settings(). Creates a CheckpointManager if checkpoint_dir is set.

Returns:

Type Description
Self

This ExecutionSettings instance.

Example

>>> async with ExecutionSettings(checkpoint_dir="/ckpt") as settings:
...     result = await pipeline(input)
>>> # Checkpoints are flushed on exit

__aexit__(exc_type, exc_val, exc_tb) async

Exit the execution settings context (asynchronous).

Flushes any pending checkpoints, exports profiler trace, and resets the context variable to the previous state.

Parameters:

Name Type Description Default
exc_type type[BaseException] | None

Exception type if an exception was raised.

required
exc_val BaseException | None

Exception value if an exception was raised.

required
exc_tb object

Exception traceback if an exception was raised.

required

__enter__()

Enter the execution settings context (synchronous).

Activates this settings context, making it available via get_execution_settings(). Creates a CheckpointManager if checkpoint_dir is set.

Returns:

Type Description
Self

This ExecutionSettings instance.

Example

>>> with ExecutionSettings(max_concurrent=10) as settings:
...     assert get_execution_settings() is settings

__exit__(exc_type, exc_val, exc_tb)

Exit the execution settings context (synchronous).

Resets the context variable to the previous state. If a CheckpointManager was created, it is NOT flushed in sync mode (use async context manager for proper cleanup). Profiler traces are exported if profiling was enabled.

Parameters:

Name Type Description Default
exc_type type[BaseException] | None

Exception type if an exception was raised.

required
exc_val BaseException | None

Exception value if an exception was raised.

required
exc_tb object

Exception traceback if an exception was raised.

required
Note

For proper checkpoint flushing, use the async context manager. The sync version is primarily for simple configurations without active checkpointing.
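The reason only the async form can flush is that __exit__ is a plain function and cannot await anything. A toy class implementing both protocols makes this concrete; FlushingContext and its flushed flag are hypothetical, with _flush standing in for awaiting pending checkpoint writes.

```python
import asyncio
from typing import Any


class FlushingContext:
    """Toy context manager usable with both `with` and `async with`."""

    def __init__(self) -> None:
        self.flushed = False

    async def _flush(self) -> None:
        # Stand-in for awaiting pending checkpoint writes.
        await asyncio.sleep(0)
        self.flushed = True

    def __enter__(self) -> "FlushingContext":
        return self

    def __exit__(self, *exc_info: Any) -> None:
        # A synchronous exit cannot await _flush(), so nothing is flushed here.
        pass

    async def __aenter__(self) -> "FlushingContext":
        return self

    async def __aexit__(self, *exc_info: Any) -> None:
        await self._flush()
```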

get_checkpoint_dir()

Get the effective checkpoint directory.

Checks this context and parent contexts for checkpoint_dir.

Returns:

Type Description
Path | None

The checkpoint directory as a Path, or None if not set.

get_checkpoint_manager()

Get the checkpoint manager for this context.

The checkpoint manager is created automatically when entering a context with checkpoint_dir set. For nested contexts without their own checkpoint_dir, returns the parent's manager.

Returns:

Type Description
CheckpointManager | None

The CheckpointManager instance, or None if no checkpointing.

get_max_concurrent()

Get the effective max_concurrent setting.

Returns:

Type Description
int

The max_concurrent value, defaulting to 100.

get_max_task_retries()

Get the effective max_task_retries setting.

Returns:

Type Description
int

The maximum number of retry attempts for transient failures.

get_on_progress()

Get the effective on_progress callback.

Checks this context and parent contexts for the callback.

Returns:

Type Description
Callable[[int, int], None] | None

The on_progress callback, or None if not set.

get_preserve_order()

Get the effective preserve_order setting.

Returns:

Type Description
bool

True if results should be yielded in input order, False otherwise.

get_profiler()

Get the profiler for this context.

The profiler is created automatically when entering a context with profile=True. For nested contexts without their own profile setting, returns the parent's profiler.

Returns:

Type Description
TraceProfiler | None

The TraceProfiler instance, or None if profiling is not enabled.

Example

>>> async with ExecutionSettings(profile=True) as settings:
...     profiler = settings.get_profiler()
...     if profiler:
...         profiler.add_instant_event("custom_marker")

get_resources()

Get the effective resources configuration.

Checks this context and parent contexts for resources.

Returns:

Type Description
ResourceConfig | ResourceManager | None

The ResourceConfig or ResourceManager, or None if not set.

get_scheduler()

Get the effective scheduler.

Checks this context and parent contexts for a custom scheduler.

Returns:

Type Description
Scheduler | None

The Scheduler instance, or None if not set.

get_streaming()

Get the effective streaming setting.

Returns:

Type Description
bool

True if streaming mode is enabled, False otherwise.

get_task_retry_delay()

Get the effective task_retry_delay setting.

Returns:

Type Description
float

The base delay in seconds between retry attempts.

get_task_timeout()

Get the effective task_timeout setting.

Checks this context and parent contexts for the timeout value.

Returns:

Type Description
float | None

The task timeout in seconds, or None if no timeout.

LLMInference

Bases: Module

Atomic module for LLM API calls.

This is the fundamental building block for LLM operations. All other modules ultimately compose LLMInference instances to build complex inference pipelines.

The alias parameter decouples the module from specific endpoints, allowing the same module to run against different models/endpoints based on resource configuration at runtime.
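Conceptually, the alias is a lookup key into whatever resource configuration is active at run time. The sketch below models that with a plain dict; Endpoint and resolve are hypothetical names, not plait's ResourceConfig or EndpointConfig API, and the model names are example values only.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Endpoint:
    """Hypothetical endpoint description (provider and model name)."""

    provider: str
    model: str


def resolve(alias: str, endpoints: Dict[str, Endpoint]) -> Endpoint:
    """Look up the endpoint bound to an alias in a given resource config."""
    try:
        return endpoints[alias]
    except KeyError:
        raise KeyError(f"No endpoint configured for alias {alias!r}") from None
```

The same module alias, such as "fast_llm", can then resolve to a cheaper model in one configuration and a larger one in another without touching the module definition.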

Parameters:

Name Type Description Default
alias str

Resource binding key that maps to an endpoint configuration. This allows the same module to use different LLM providers depending on the ResourceConfig passed to run().

required
system_prompt str | Parameter

System prompt for the LLM. Can be a string (converted to a non-learnable Parameter) or a Parameter instance (for learnable prompts). Empty string results in no system prompt.

''
temperature float

Sampling temperature for the LLM. Higher values produce more random outputs. Defaults to 1.0.

1.0
max_tokens int | None

Maximum number of tokens to generate. None means no limit (use model default).

None
response_format type | None

Expected response format type for structured output. None means plain text response.

None
Example

>>> llm = LLMInference(alias="fast_llm", temperature=0.7)
>>> llm.alias
'fast_llm'
>>> llm.temperature
0.7

Example with system prompt

>>> llm = LLMInference(
...     alias="assistant",
...     system_prompt="You are a helpful assistant.",
...     temperature=0.5,
... )
>>> llm.system_prompt.value
'You are a helpful assistant.'
>>> llm.system_prompt.requires_grad
False

Note

LLMInference.forward() should not be called directly. Use the run() function to execute modules, which handles tracing and resource management.

__init__(alias, system_prompt='', temperature=1.0, max_tokens=None, response_format=None)

Initialize the LLMInference module.

Parameters:

Name Type Description Default
alias str

Resource binding key for endpoint resolution.

required
system_prompt str | Parameter

System prompt string or Parameter.

''
temperature float

Sampling temperature (0.0 to 2.0 typical).

1.0
max_tokens int | None

Maximum tokens to generate.

None
response_format type | None

Type for structured output parsing.

None
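As described for system_prompt, a plain string becomes a non-learnable Parameter while a Parameter instance passes through unchanged. A minimal sketch of that coercion, using a hypothetical Param dataclass in place of plait's Parameter (whose constructor arguments may differ):

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Param:
    """Hypothetical stand-in for plait.parameter.Parameter."""

    value: str
    requires_grad: bool = True


def coerce_system_prompt(system_prompt: Union[str, Param]) -> Param:
    """Wrap plain strings as fixed (non-learnable) parameters."""
    if isinstance(system_prompt, Param):
        return system_prompt  # Keep learnable prompts as-is.
    return Param(system_prompt, requires_grad=False)
```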

backward(feedback, ctx) async

Backward pass for LLM inference.

Generates feedback for both the input prompt and any learnable parameters (like system_prompt). The parameter feedback includes context about what the LLM received and produced to help the optimizer understand how to improve the parameter.

Parameters:

Name Type Description Default
feedback Any

Combined feedback from downstream nodes.

required
ctx Any

BackwardContext with inputs, output, and graph structure.

required

Returns:

Type Description
Any

A BackwardResult with:
  • input_feedback["prompt"]: Value feedback for the input prompt
  • parameter_feedback["system_prompt"]: Feedback for the system prompt if it's a learnable Parameter

Example

>>> # LLMInference backward is called automatically during
>>> # Value.backward() when the module is in the graph
>>> output, record = await run(llm_module, "Hello", record=True)
>>> loss_val = await loss_fn(output, target)
>>> await loss_val.backward()  # Calls llm_module.backward()

Note

The parameter feedback includes:
- The current system prompt value
- The parameter description
- A sample of the input and output
- The feedback received

This context helps the optimizer generate targeted improvements.

forward(prompt)

Execute the LLM call.

This method should not be called directly. During tracing, the tracer intercepts calls and records them in the graph. During execution, the runtime handles the actual API call through the ResourceManager.

Parameters:

Name Type Description Default
prompt str

The user prompt to send to the LLM.

required

Returns:

Type Description
str

The LLM's response text.

Raises:

Type Description
RuntimeError

Always raised because direct execution is not supported. Use run() to execute modules.

Note

The runtime replaces this with actual LLM calls. This placeholder exists to define the expected signature and to catch accidental direct invocations.

Module

Base class for all inference operations.

Analogous to torch.nn.Module. Subclass this to define custom inference logic by implementing the forward() method.

Child modules and parameters assigned as attributes are automatically registered, enabling recursive traversal and parameter collection.
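The registration mechanism can be sketched with a toy class that overrides __setattr__. MiniModule and Param are hypothetical stand-ins that cover only the module and parameter cases (not ParameterList or ParameterDict); the sketch also shows why __init__ must use object.__setattr__ to create the registries themselves.

```python
from typing import Any


class Param:
    """Hypothetical stand-in for plait's Parameter."""

    def __init__(self, value: str) -> None:
        self.value = value
        self._name = None


class MiniModule:
    """Toy module that auto-registers child modules and parameters."""

    def __init__(self) -> None:
        # Bypass our own __setattr__ while the registries do not exist yet.
        object.__setattr__(self, "_modules", {})
        object.__setattr__(self, "_parameters", {})

    def __setattr__(self, name: str, value: Any) -> None:
        if isinstance(value, MiniModule):
            self._modules[name] = value
        elif isinstance(value, Param):
            self._parameters[name] = value
        if isinstance(value, (MiniModule, Param)):
            # Record the attribute name for introspection.
            object.__setattr__(value, "_name", name)
        object.__setattr__(self, name, value)
```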

Example

>>> from plait.parameter import Parameter
>>> class MyModule(Module):
...     def __init__(self):
...         super().__init__()
...         self.prompt = Parameter("You are helpful.")
...
>>> module = MyModule()
>>> "prompt" in module._parameters
True

Note

Always call super().__init__() in subclass __init__ methods to ensure proper registration of children and parameters.

training property

Whether the module is in training mode.

In training mode, forward passes return Value objects with tape ids attached, enabling automatic backward propagation.

Returns:

Type Description
bool

True if the module is in training mode, False otherwise.

Example

>>> module = MyModule()
>>> module.training
False
>>> module.train()
>>> module.training
True

__call__(*args, **kwargs)

Execute the module.

Behavior depends on context:
1. If a trace context is active: records the call and returns a Value with ref pointing to the generated node ID (Value-driven tracing).
2. If resources are bound OR an ExecutionSettings context is active: traces and executes.
3. Otherwise: executes forward() directly (for non-LLM modules).

When bound or in an ExecutionSettings context, this method is async and should be awaited. Supports batch execution when the first argument is a list.
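The three-way dispatch above can be summarized as a small decision function. This is a hypothetical sketch of the documented order of precedence, not plait's code; the flag and enum names are invented for illustration.

```python
from enum import Enum


class CallMode(Enum):
    TRACE = "record a graph node and return a Value"
    EXECUTE_ASYNC = "trace, then execute; caller must await"
    DIRECT = "call forward() immediately"


def dispatch_mode(tracing: bool, has_bound_resources: bool, in_settings_context: bool) -> CallMode:
    """Mirror the documented precedence: tracing, then bound/context, then direct."""
    if tracing:
        return CallMode.TRACE
    if has_bound_resources or in_settings_context:
        return CallMode.EXECUTE_ASYNC
    return CallMode.DIRECT
```

An active trace wins even when resources are bound, which is why forward() is never called during tracing.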

Parameters:

Name Type Description Default
*args Any

Positional arguments passed to forward().

()
**kwargs Any

Keyword arguments passed to forward().

{}

Returns:

Type Description
Any

If tracing: a Value with ref set to the node ID, representing the eventual output of this call. Dependencies are collected from Value.ref attributes in the arguments.

If bound or inside an ExecutionSettings context: a coroutine that yields the execution result.

Otherwise: the result from forward().

Example

>>> class Doubler(Module):
...     def forward(self, x: int) -> int:
...         return x * 2
...
>>> doubler = Doubler()
>>> doubler(5)  # Without trace context, calls forward() directly
10

Example with bound resources

>>> pipeline = MyPipeline().bind(resources=config)
>>> result = await pipeline("input")  # Async execution

Example with ExecutionSettings

>>> async with ExecutionSettings(resources=config):
...     result = await pipeline("input")

Note

During tracing, the tracer records this call as a node in the execution graph. The forward() method is not called; instead, dependencies are tracked based on Value refs.

__init__()

Initialize the module with empty registries.

Sets up internal dictionaries for tracking child modules and parameters. Uses object.__setattr__ to avoid triggering the custom __setattr__ during initialization.

__setattr__(name, value)

Set an attribute with automatic registration of modules and parameters.

When a value is assigned to an attribute:
- If it's a Module, it's registered as a child module.
- If it's a Parameter, it's registered in the parameters dict.
- If it's a ParameterList or ParameterDict, it's registered for iteration.
- The value's _name is set to the attribute name for introspection.

Parameters:

Name Type Description Default
name str

The attribute name.

required
value Any

The value to assign.

required
Note

This method is called for all attribute assignments, including those in __init__. Internal attributes (starting with '_') that are not modules or parameters are set directly.

backward(feedback, ctx) async

Propagate feedback backward through this module.

Default implementation passes feedback unchanged to all inputs. Override for custom backward logic that generates more targeted feedback for specific inputs or parameters.

This method is called during the backward pass initiated by Value.backward(). The ctx parameter provides access to the forward pass context including inputs, outputs, and the computation graph.

Parameters:

Name Type Description Default
feedback Any

Combined feedback from downstream nodes as a Value.

required
ctx Any

BackwardContext with inputs, output, graph structure, and optional reasoning LLM for generating feedback.

required

Returns:

Type Description
Any

BackwardResult with input_feedback and parameter_feedback

Any

dictionaries specifying how feedback should be distributed.

Example

>>> async def backward(self, feedback, ctx):
...     from plait.optimization.backward import BackwardResult
...     result = BackwardResult()
...
...     # Pass feedback to all inputs unchanged
...     for input_name in ctx.inputs:
...         result.input_feedback[input_name] = feedback
...
...     return result

Note

The default implementation passes feedback unchanged to all inputs. Override this method to implement custom feedback propagation logic, such as:
- Generating parameter-specific feedback
- Filtering feedback based on input relevance
- Using ctx.reason() for LLM-powered feedback generation

bind(resources, max_concurrent=100, **kwargs)

Bind resources to this module for direct execution.

After binding, the module can be called directly with await:

>>> pipeline = MyPipeline().bind(resources=config)
>>> result = await pipeline("input")

Parameters:

Name Type Description Default
resources ResourceConfig | ResourceManager

Resource configuration or manager for LLM endpoints.

required
max_concurrent int

Maximum concurrent tasks during execution.

100
**kwargs Any

Additional execution options (checkpoint_dir, etc.).

{}

Returns:

Type Description
Self

Self, for method chaining.

Example

>>> from plait.resources.config import ResourceConfig, EndpointConfig
>>> config = ResourceConfig(endpoints={
...     "fast": EndpointConfig(provider_api="openai", model="gpt-4o-mini")
... })
>>> pipeline = MyPipeline().bind(resources=config)
>>> result = await pipeline("Hello!")

Example with additional options

>>> pipeline = MyPipeline().bind(
...     resources=config,
...     max_concurrent=50,
...     checkpoint_dir="/data/checkpoints",
... )

Note

Bound resources and config can be overridden per call by passing keyword arguments to __call__, or by using an ExecutionSettings context.

children()

Iterate over immediate child modules.

Yields child modules in the order they were registered. Does not recurse into nested modules.

Yields:

Type Description
Module

Each immediate child Module.

Example

>>> class Parent(Module):
...     def __init__(self):
...         super().__init__()
...         self.child1 = Module()
...         self.child2 = Module()
...
>>> parent = Parent()
>>> list(parent.children())  # doctest: +ELLIPSIS
[<...Module...>, <...Module...>]

direct_parameters()

Iterate over parameters directly owned by this module.

Yields parameters assigned on this module itself, plus parameters in ParameterList and ParameterDict containers attached directly to it. Does not recurse into child modules.

Returns:

Type Description
Iterator[Parameter]

An iterator of Parameters directly owned by this module.

eval()

Set the module to evaluation mode.

In evaluation mode, forward passes return raw values without tape wrapping. This is the default mode and is used during inference when backward passes are not needed.

This method recursively sets all child modules to evaluation mode.

Returns:

Type Description
Self

Self, for method chaining.

Example

>>> module.train()  # Enable training mode
>>> module.eval()  # Disable training mode
>>> output = await module("Hello")
>>> isinstance(output, str)  # Raw value, not Value
True

Note

Equivalent to calling .train(False).

forward(*args, **kwargs)

Define the inference computation.

Override this method to implement your module's logic. During tracing, this receives Value objects representing symbolic values. During execution, this receives actual values.

Parameters:

Name Type Description Default
*args Any

Positional arguments for the computation.

()
**kwargs Any

Keyword arguments for the computation.

{}

Returns:

Type Description
Any

The result of the inference computation.

Raises:

Type Description
NotImplementedError

If not overridden in a subclass.

Example

>>> class Greeter(Module):
...     def forward(self, name: str) -> str:
...         return f"Hello, {name}!"
...
>>> greeter = Greeter()
>>> greeter("World")
'Hello, World!'

load_state_dict(state_dict)

Load parameter values from a dictionary.

Used for restoring learned prompts/instructions from a saved state. The keys in state_dict must match the hierarchical parameter names from this module's named_parameters().

Parameters:

Name Type Description Default
state_dict dict[str, str]

Dictionary mapping parameter names to their values.

required

Raises:

Type Description
KeyError

If a key in state_dict does not match any parameter in this module. Missing keys in state_dict are silently ignored (partial loads are allowed).

Example

>>> from plait.parameter import Parameter
>>> class MyModule(Module):
...     def __init__(self):
...         super().__init__()
...         self.prompt = Parameter("original")
...
>>> module = MyModule()
>>> module.load_state_dict({"prompt": "updated"})
>>> module.prompt.value
'updated'

Example with unknown key

>>> from plait.parameter import Parameter
>>> class MyModule(Module):
...     def __init__(self):
...         super().__init__()
...         self.prompt = Parameter("test")
...
>>> module = MyModule()
>>> module.load_state_dict({"unknown": "value"})
Traceback (most recent call last):
    ...
KeyError: 'Unknown parameter: unknown'

Note

This method modifies the parameter values in-place. If you need to preserve the original values, use state_dict() first to save them.

modules()

Iterate over all modules in the tree, including self.

Performs a depth-first traversal starting from this module. Includes this module as the first item yielded.

Yields:

Type Description
Module

All Modules in the subtree rooted at this module.

Example

>>> class Nested(Module):
...     def __init__(self):
...         super().__init__()
...         self.inner = Module()
...
>>> class Outer(Module):
...     def __init__(self):
...         super().__init__()
...         self.nested = Nested()
...
>>> outer = Outer()
>>> len(list(outer.modules()))
3

named_children()

Iterate over immediate child modules with their names.

Yields (name, module) pairs for each immediate child. Does not recurse into nested modules.

Yields:

Type Description
tuple[str, Module]

Tuples of (attribute_name, child_module).

Example

>>> class Parent(Module):
...     def __init__(self):
...         super().__init__()
...         self.child1 = Module()
...
>>> parent = Parent()
>>> [(name, type(m).__name__) for name, m in parent.named_children()]
[('child1', 'Module')]

named_modules(prefix='')

Iterate over all modules with hierarchical dot-separated names.

Performs a depth-first traversal, yielding (name, module) pairs. Names are hierarchical, e.g., "layer1.sublayer.module".

Parameters:

Name Type Description Default
prefix str

Prefix to prepend to all names. Used internally for recursive calls to build hierarchical names.

''

Yields:

Type Description
tuple[str, Module]

Tuples of (hierarchical_name, module). The root module has an empty string name (or the prefix if provided).

Example

>>> class Inner(Module):
...     def __init__(self):
...         super().__init__()
...
>>> class Outer(Module):
...     def __init__(self):
...         super().__init__()
...         self.inner = Inner()
...
>>> outer = Outer()
>>> [(name, type(m).__name__) for name, m in outer.named_modules()]
[('', 'Outer'), ('inner', 'Inner')]

named_parameters(prefix='', remove_duplicate=True)

Iterate over all parameters with hierarchical dot-separated names.

Recursively yields (name, parameter) pairs from this module and all descendants. Names reflect the module hierarchy. Also yields parameters from ParameterList and ParameterDict containers.

Parameters:

Name Type Description Default
prefix str

Prefix to prepend to parameter names. Used internally for recursive calls to build hierarchical names.

''
remove_duplicate bool

If True, yield each Parameter instance only once.

True

Yields:

Type Description
tuple[str, Parameter]

Tuples of (hierarchical_name, parameter).

Example

>>> from plait.parameter import Parameter
>>> class Inner(Module):
...     def __init__(self):
...         super().__init__()
...         self.weight = Parameter("w")
...
>>> class Outer(Module):
...     def __init__(self):
...         super().__init__()
...         self.bias = Parameter("b")
...         self.inner = Inner()
...
>>> outer = Outer()
>>> [(name, p.value) for name, p in outer.named_parameters()]
[('bias', 'b'), ('inner.weight', 'w')]

parameters(remove_duplicate=True)

Iterate over all parameters in the module tree.

Recursively yields parameters from this module and all descendant modules in depth-first order. Also yields parameters from ParameterList and ParameterDict containers.

Shared Parameter instances are de-duplicated by default, matching PyTorch's behavior for shared parameters.
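The traversal (own parameters first, then children depth-first, with dot-joined names and identity-based de-duplication) can be sketched as a free function over a toy tree. Node, Param, and named_parameters here are hypothetical; the own-before-children ordering matches the named_parameters() example ('bias' before 'inner.weight'), but the real method may differ in details.

```python
from typing import Dict, Iterator, Optional, Set, Tuple


class Param:
    """Hypothetical stand-in for plait's Parameter."""

    def __init__(self, value: str) -> None:
        self.value = value


class Node:
    """Toy module holding its own parameters and child modules."""

    def __init__(
        self,
        params: Optional[Dict[str, Param]] = None,
        children: Optional[Dict[str, "Node"]] = None,
    ) -> None:
        self.params = params or {}
        self.children = children or {}


def named_parameters(
    node: Node,
    prefix: str = "",
    seen: Optional[Set[int]] = None,
    remove_duplicate: bool = True,
) -> Iterator[Tuple[str, Param]]:
    """Depth-first walk yielding dotted names; shared Params appear once by default."""
    if seen is None:
        seen = set()
    for name, param in node.params.items():
        if remove_duplicate:
            if id(param) in seen:
                continue  # Same object already yielded under another name.
            seen.add(id(param))
        yield (f"{prefix}.{name}" if prefix else name), param
    for child_name, child in node.children.items():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        yield from named_parameters(child, child_prefix, seen, remove_duplicate)
```

A Param shared by an outer and inner module is reported once, under the first name reached.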

Parameters:

Name Type Description Default
remove_duplicate bool

If True, yield each Parameter instance only once.

True

Yields:

Type Description
Parameter

All Parameter objects in the subtree.

Example

>>> from plait.parameter import Parameter
>>> class MyModule(Module):
...     def __init__(self):
...         super().__init__()
...         self.prompt = Parameter("test")
...
>>> module = MyModule()
>>> list(module.parameters())  # doctest: +ELLIPSIS
[Parameter(value='test', ...)]

run_sync(*args, **kwargs)

Execute synchronously (blocking).

Convenience method for scripts and notebooks where async isn't needed. Blocks until execution completes and returns the result.

This method requires that resources are available through either:
- A prior bind() call on this module, or
- An active ExecutionSettings context.

Parameters:

Name Type Description Default
*args Any

Positional arguments passed to forward().

()
**kwargs Any

Keyword arguments passed to forward().

{}

Returns:

Type Description
Any

Single result for single input, list for batch input.

Raises:

Type Description
RuntimeError

If called from within an async context (would block the event loop), or if no resources are available.

Example

>>> pipeline = MyPipeline().bind(resources=config)
>>> result = pipeline.run_sync("Hello")
>>> results = pipeline.run_sync(["a", "b", "c"])

Example with ExecutionSettings

>>> with ExecutionSettings(resources=config):
...     result = pipeline.run_sync("Hello")

Note

Use await module(...) in async code instead. This method is intended for synchronous scripts and REPL environments only.
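The "would block the event loop" guard can be sketched with asyncio.get_running_loop(), which raises RuntimeError when no loop is running in the current thread. run_sync below is a hypothetical free function, not plait's method; note that it only creates the coroutine once it knows blocking is safe, so no un-awaited coroutine is left behind.

```python
import asyncio
from typing import Any, Callable, Coroutine


def run_sync(async_fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any) -> Any:
    """Run an async callable to completion from synchronous code."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so blocking is safe.
        return asyncio.run(async_fn(*args, **kwargs))
    # Blocking here would stall the running event loop.
    raise RuntimeError("run_sync() cannot be used inside a running event loop; use await instead")
```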

state_dict()

Return a dictionary of all parameter values.

Used for saving learned prompts/instructions after optimization. Keys are hierarchical parameter names (e.g., "summarizer.system_prompt"), matching the output of named_parameters().

Returns:

Type Description
dict[str, str]

A dictionary mapping parameter names to their string values.

Example

>>> from plait.parameter import Parameter
>>> class Inner(Module):
...     def __init__(self):
...         super().__init__()
...         self.weight = Parameter("w")
...
>>> class Outer(Module):
...     def __init__(self):
...         super().__init__()
...         self.bias = Parameter("b")
...         self.inner = Inner()
...
>>> outer = Outer()
>>> outer.state_dict()
{'bias': 'b', 'inner.weight': 'w'}

Note

The returned dict can be serialized to JSON/pickle and later restored with load_state_dict().
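Because state_dict() returns plain strings keyed by dotted names, a JSON round trip is straightforward. The sketch below models the documented load_state_dict() rules on a plain dict: unknown keys raise KeyError and missing keys leave values untouched. load_values is a hypothetical helper; it also validates every key before writing anything, which is a choice of this sketch rather than documented plait behavior.

```python
import json
from typing import Dict


def load_values(current: Dict[str, str], state: Dict[str, str]) -> None:
    """Apply saved values in place: unknown keys fail, missing keys are left alone."""
    unknown = [key for key in state if key not in current]
    if unknown:
        # Check every key first so a bad snapshot changes nothing.
        raise KeyError(f"Unknown parameter: {unknown[0]}")
    current.update(state)


# Save a snapshot as JSON, change a value, then restore it.
values = {"bias": "b", "inner.weight": "w"}
snapshot = json.dumps(values)
values["bias"] = "edited"
load_values(values, json.loads(snapshot))
```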

train(mode=True)

Set the module to training mode.

In training mode, forward passes return Value objects with tape ids attached, enabling implicit record flow through the training pipeline. This eliminates manual record management.

This method recursively sets all child modules to the same mode.
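The recursive mode switch, with eval() as train(False) and self returned for chaining, can be sketched on a toy tree. ToyModule is hypothetical and tracks only the training flag, not tape wrapping.

```python
from typing import List


class ToyModule:
    """Hypothetical module tree demonstrating recursive mode switching."""

    def __init__(self, *children: "ToyModule") -> None:
        self.children: List[ToyModule] = list(children)
        self.training = False  # Evaluation mode is the default.

    def train(self, mode: bool = True) -> "ToyModule":
        self.training = mode
        for child in self.children:
            child.train(mode)  # Propagate the same mode down the tree.
        return self  # Return self to allow chaining.

    def eval(self) -> "ToyModule":
        return self.train(False)
```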

Parameters:

Name Type Description Default
mode bool

If True, enable training mode. If False, disable it. Defaults to True.

True

Returns:

Type Description
Self

Self, for method chaining.

Example

>>> module = MyModule().bind(resources)
>>> module.train()  # Enable training mode
>>> output = await module("Hello")
>>> hasattr(output, "meta") and "_tape_ids" in output.meta
True

Example with chaining

>>> module.train().bind(resources)
>>> result = await module("input")

Note

Use .eval() to switch back to evaluation mode where raw values are returned without tape wrapping.

ModuleDict

Bases: Module, Mapping[str, Module]

A dict-like container for named module access.

Provides dict-like operations (keys, values, items, etc.) while properly registering modules for parameter collection. This mirrors PyTorch's nn.ModuleDict.
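Inheriting from collections.abc.Mapping, as ModuleDict does, means only __getitem__, __iter__, and __len__ need to be written; keys(), items(), values(), get(), and __contains__ are derived automatically. A minimal sketch of that pattern with a type check on insertion; Block and BlockDict are hypothetical names, and the real ModuleDict additionally registers each value as a child module.

```python
from collections.abc import Mapping
from typing import Dict, Iterable, Iterator, Tuple, Union


class Block:
    """Hypothetical stand-in for a Module subclass."""


class BlockDict(Mapping):
    """Minimal Mapping-based container that only accepts Block values."""

    def __init__(self, items: Union[Mapping, Iterable[Tuple[str, Block]], None] = None) -> None:
        self._blocks: Dict[str, Block] = {}
        if items is not None:
            pairs = items.items() if isinstance(items, Mapping) else items
            for key, value in pairs:
                self[key] = value

    def __setitem__(self, key: str, value: Block) -> None:
        if not isinstance(value, Block):
            raise TypeError(f"{key!r} is not a Block: {type(value).__name__}")
        self._blocks[key] = value

    # Mapping requires only these three; the other read-only methods derive from them.
    def __getitem__(self, key: str) -> Block:
        return self._blocks[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._blocks)

    def __len__(self) -> int:
        return len(self._blocks)
```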

Parameters:

Name Type Description Default
modules Mapping[str, Module] | Iterable[tuple[str, Module]] | None

Optional dict or iterable of (key, module) pairs.

None
Example

>>> modules = ModuleDict({
...     'encoder': Encoder(),
...     'decoder': Decoder()
... })
>>> output = modules['encoder'](x)
>>> modules.keys()
dict_keys(['encoder', 'decoder'])

Note

ModuleDict does not define a forward() method since access patterns depend on use case. Access modules by key and call them directly.

__contains__(key)

Check if a key is in the dict.

Parameters:

Name Type Description Default
key object

The key to check for.

required

Returns:

Type Description
bool

True if key is in the dict, False otherwise.

__delitem__(key)

Delete a module by key.

Parameters:

Name Type Description Default
key str

The key to delete.

required

Raises:

Type Description
KeyError

If key is not in the dict.

__getattr__(name)

Get a module by name as attribute access.

Parameters:

Name Type Description Default
name str

The name of the module to retrieve.

required

Returns:

Type Description
Module

The module with the given name.

Raises:

Type Description
AttributeError

If no module with that name exists.

__getitem__(key)

Get a module by key.

Parameters:

Name Type Description Default
key str

The key of the module to retrieve.

required

Returns:

Type Description
Module

The module with the given key.

Raises:

Type Description
KeyError

If key is not in the dict.

__init__(modules=None)

Initialize the ModuleDict container.

Parameters:

Name Type Description Default
modules Mapping[str, Module] | Iterable[tuple[str, Module]] | None

Optional mapping (dict, ModuleDict, etc.) or iterable of (key, module) pairs.

None

Raises:

Type Description
TypeError

If any value is not a Module instance.

__iter__()

Iterate over module keys.

Yields:

Type Description
str

Each key in the dict.

__len__()

Return the number of modules in the dict.

Returns:

Type Description
int

The number of modules.

__setitem__(key, module)

Set a module at the given key.

Parameters:

Name Type Description Default
key str

The key to set.

required
module Module

The module to set.

required

Raises:

Type Description
TypeError

If module is not a Module instance.

clear()

Remove all modules from the dict.

forward(x)

Forward is not implemented for ModuleDict.

ModuleDict does not implement a forward pass because the access pattern depends on the use case.

Raises:

Type Description
NotImplementedError

Always raised.

items()

Return a view of the (key, module) pairs.

Returns:

Type Description
Any

A dict_items view of the items.

keys()

Return a view of the module keys.

Returns:

Type Description
Any

A dict_keys view of the keys.

pop(key, default=None)

Remove and return a module by key.

Parameters:

Name Type Description Default
key str

The key to pop.

required
default Module | None

Value to return if key is not found.

None

Returns:

Type Description
Module | None

The removed module, or default if not found.

update(modules)

Update the dict with modules from another mapping or iterable.

Parameters:

Name Type Description Default
modules Mapping[str, Module] | Iterable[tuple[str, Module]]

Mapping (dict, ModuleDict, etc.) or iterable of (key, module) pairs.

required

Raises:

Type Description
TypeError

If any value is not a Module instance.
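Here is a sketch of how an `update`-style method can accept either a mapping or an iterable of (key, value) pairs, and reject values of the wrong type. `normalize_pairs` is a hypothetical helper, and the `expected` argument stands in for the Module type. This sketch validates every value before anything is stored (all-or-nothing). That is a choice made here for illustration; the reference above does not say whether `update` behaves this way.

```python
from __future__ import annotations

from collections.abc import Iterable, Mapping
from typing import Any


def normalize_pairs(
    source: Mapping[str, Any] | Iterable[tuple[str, Any]], expected: type
) -> list[tuple[str, Any]]:
    """Turn a mapping or an iterable of (key, value) pairs into a checked list of pairs."""
    # Mappings (dict, or any Mapping subclass) are read via .items(); anything else is
    # treated as an iterable of pairs.
    pairs = list(source.items()) if isinstance(source, Mapping) else list(source)
    for key, value in pairs:
        if not isinstance(value, expected):
            raise TypeError(
                f"value for {key!r} is {type(value).__name__}, expected {expected.__name__}"
            )
    return pairs
```

A container's `update` could then store `normalize_pairs(modules, Module)` in one pass, so a bad value leaves the container unchanged.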

values()

Return a view of the modules.

Returns:

Type Description
Any

A dict_values view of the modules.

ModuleList

Bases: Module

A list-like container for modules.

Provides list-like operations (append, extend, insert, etc.) while properly registering modules for parameter collection. This mirrors PyTorch's nn.ModuleList.

Parameters:

Name Type Description Default
modules Iterable[Module] | None

Optional iterable of modules to initialize with.

None
Example

>>> layers = ModuleList([Layer() for _ in range(3)])
>>> for layer in layers:
...     x = layer(x)
...
>>> layers.append(AnotherLayer())
>>> len(layers)
4

Note

ModuleList does not implement forward(): the iteration pattern depends on the use case, and calling forward() raises NotImplementedError. Use iteration or indexing to access and execute modules.

__contains__(module)

Check if a module is in the list.

Parameters:

Name Type Description Default
module Module

The module to check for.

required

Returns:

Type Description
bool

True if module is in the list, False otherwise.

__delitem__(idx)

Delete a module at the given index.

Parameters:

Name Type Description Default
idx int

The index to delete.

required

Raises:

Type Description
IndexError

If index is out of range.

__getitem__(idx)

__getitem__(idx: int) -> Module
__getitem__(idx: slice) -> ModuleList

Get module(s) by index or slice.

Parameters:

Name Type Description Default
idx int | slice

Integer index or slice.

required

Returns:

Type Description
Module | ModuleList

Single module for integer index, new ModuleList for slice.

Raises:

Type Description
IndexError

If index is out of range.

Note

Slicing returns a new ModuleList holding references to the same module objects, but it does NOT reparent them. The modules remain children of the original container.
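The no-reparenting rule can be illustrated with toy classes that record ownership in a `parent` attribute. `Node`, `NodeList` and `adopt` are hypothetical names; the sketch does not reflect how plait tracks parents internally. It only shows the documented outcome: a slice is a new container over the same objects, and those objects still point at the original owner.

```python
from __future__ import annotations


class Node:
    """Hypothetical module stand-in that records which container owns it."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.parent: object | None = None


class NodeList:
    """Minimal list container: construction adopts nodes, slicing does not."""

    def __init__(self, nodes: list[Node], adopt: bool = True) -> None:
        self._nodes = list(nodes)
        if adopt:
            for node in self._nodes:
                node.parent = self

    def __getitem__(self, idx: int | slice) -> Node | NodeList:
        if isinstance(idx, slice):
            # New container over the same objects; parents are left untouched.
            return NodeList(self._nodes[idx], adopt=False)
        return self._nodes[idx]

    def __len__(self) -> int:
        return len(self._nodes)
```

In this model, `layers[:2][0] is layers[0]` holds, and its `parent` is still `layers`, not the slice.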

__init__(modules=None)

Initialize the ModuleList container.

Parameters:

Name Type Description Default
modules Iterable[Module] | None

Optional iterable of modules to add.

None

Raises:

Type Description
TypeError

If any item in modules is not a Module.

__iter__()

Iterate over modules in order.

Yields:

Type Description
Module

Each module in the list.

__len__()

Return the number of modules in the list.

Returns:

Type Description
int

The number of modules.

__setitem__(idx, module)

Set a module at the given index.

Parameters:

Name Type Description Default
idx int

The index to set.

required
module Module

The module to set.

required

Raises:

Type Description
IndexError

If index is out of range.

TypeError

If module is not a Module instance.

append(module)

Append a module to the end of the list.

Parameters:

Name Type Description Default
module Module

The module to append.

required

Returns:

Type Description
ModuleList

Self, for method chaining.

Raises:

Type Description
TypeError

If module is not a Module instance.

extend(modules)

Extend the list with modules from an iterable.

Parameters:

Name Type Description Default
modules Iterable[Module]

Iterable of modules to add.

required

Returns:

Type Description
ModuleList

Self, for method chaining.

Raises:

Type Description
TypeError

If any item is not a Module instance.

forward(x)

Forward is not implemented for ModuleList.

ModuleList does not implement a forward pass because the iteration pattern depends on the use case.

Raises:

Type Description
NotImplementedError

Always raised.

insert(idx, module)

Insert a module at the given index.

Parameters:

Name Type Description Default
idx int

The index to insert at.

required
module Module

The module to insert.

required

Raises:

Type Description
TypeError

If module is not a Module instance.

pop(idx=-1)

Remove and return the module at the given index.

Parameters:

Name Type Description Default
idx int

The index to pop from. Defaults to -1 (last).

-1

Returns:

Type Description
Module

The removed module.

Raises:

Type Description
IndexError

If index is out of range or list is empty.

Parameter dataclass

A learnable value that can be optimized via backward passes.

Similar to torch.nn.Parameter, but for values (prompts, instructions, structured configs, etc.) that are optimized via LLM feedback rather than gradient descent.

The description field is required when requires_grad=True to enable the optimizer to understand how to improve the parameter.

Parameters:

Name Type Description Default
value Any

The current value of the parameter (string, dict, list, etc.).

required
description str | None

A description of what this parameter does or represents. Required when requires_grad=True to enable self-documenting optimization. Can be None when requires_grad=False.

None
requires_grad bool

If True, feedback will be accumulated during backward passes and description is required. If False, the parameter is treated as a constant.

True

Raises:

Type Description
ValueError

If requires_grad=True but description is None.

Example

>>> param = Parameter(
...     "You are a helpful assistant.",
...     description="Defines the agent's identity and baseline behavior."
... )
>>> str(param)
'You are a helpful assistant.'
>>> param.description
"Defines the agent's identity and baseline behavior."
>>> param.accumulate_feedback("Be more concise")
>>> param.get_accumulated_feedback()
['Be more concise']
>>> param.apply_update("You are a concise, helpful assistant.")
>>> str(param)
'You are a concise, helpful assistant.'

Constant parameter (no description required)

>>> const = Parameter({"model": "gpt-4"}, requires_grad=False)
>>> const.requires_grad
False

__post_init__()

Validate that description is provided when requires_grad=True.

Raises:

Type Description
ValueError

If requires_grad=True but description is None.

__str__()

Return the current value as a string.

Returns:

Type Description
str

The string representation of the parameter value.

accumulate_feedback(feedback)

Collect feedback from backward passes.

Feedback is only accumulated if requires_grad is True.

Parameters:

Name Type Description Default
feedback str | Value

The feedback string or Value to accumulate.

required

apply_update(new_value)

Apply an optimizer-computed update.

Updates the parameter value and clears the feedback buffer.

Parameters:

Name Type Description Default
new_value Any

The new value to set.

required

get_accumulated_feedback()

Get all accumulated feedback.

Returns:

Type Description
list[str]

A copy of the list of accumulated feedback strings.

zero_feedback()

Clear accumulated feedback without updating the value.

Similar to zero_grad() in PyTorch, this clears the feedback buffer to prepare for a new backward pass.
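The Parameter lifecycle documented in this section (validation on creation, feedback accumulation, update, reset) can be summarized with a minimal stand-in dataclass. `SketchParameter` and its `_feedback` buffer are hypothetical names, not plait's code. The sketch also simplifies feedback: it accepts strings only, whereas the real `accumulate_feedback` also takes a `Value`.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class SketchParameter:
    """Hypothetical stand-in mirroring the documented Parameter contract."""

    value: Any
    description: str | None = None
    requires_grad: bool = True
    _feedback: list[str] = field(default_factory=list, init=False, repr=False)

    def __post_init__(self) -> None:
        # Learnable values must explain themselves to the optimizer.
        if self.requires_grad and self.description is None:
            raise ValueError("description is required when requires_grad=True")

    def __str__(self) -> str:
        return str(self.value)

    def accumulate_feedback(self, feedback: str) -> None:
        # Feedback is only kept when requires_grad is True.
        if self.requires_grad:
            self._feedback.append(feedback)

    def get_accumulated_feedback(self) -> list[str]:
        # Return a copy so callers cannot mutate the internal buffer.
        return list(self._feedback)

    def apply_update(self, new_value: Any) -> None:
        # An optimizer step replaces the value and consumes the feedback.
        self.value = new_value
        self._feedback.clear()

    def zero_feedback(self) -> None:
        # Like zero_grad(): discard feedback but keep the current value.
        self._feedback.clear()
```

A training step in this model reads feedback, computes a new value, and calls `apply_update`. `zero_feedback` covers the case where feedback should be dropped without changing the value.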

ParameterDict

Bases: MutableMapping[str, 'Parameter']

A dict-like container for Parameters.

Holds a dictionary of Parameters that will be properly collected by Module.parameters() and Module.named_parameters(). This is analogous to torch.nn.ParameterDict.

Parameters:

Name Type Description Default
parameters dict[str, Parameter] | Iterable[tuple[str, Parameter]] | None

Optional mapping or iterable of (key, Parameter) pairs.

None
Example

>>> class MultiTask(Module):
...     def __init__(self):
...         super().__init__()
...         self.prompts = ParameterDict({
...             "summarize": Parameter("Summarize:", description="Summary prompt"),
...             "translate": Parameter("Translate:", description="Translation prompt"),
...         })
...
>>> m = MultiTask()
>>> list(m.named_parameters())
[('prompts.summarize', Parameter(...)), ('prompts.translate', Parameter(...))]

Note

The container itself is not a Parameter, but it provides iteration methods that Module uses to collect the contained Parameters.
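ParameterDict is based on `collections.abc.MutableMapping`. That ABC needs only `__getitem__`, `__setitem__`, `__delitem__`, `__iter__` and `__len__`, all documented below. From those five it supplies mixin methods such as `get`, `pop`, `update`, `keys`, `items` and `values`. The sketch below demonstrates that standard-library mechanism with a hypothetical `CheckedMapping` that stores only strings, standing in for the Parameter type check. Whether ParameterDict overrides any of the mixins is not stated in this reference.

```python
from __future__ import annotations

from collections.abc import Iterable, Iterator, Mapping, MutableMapping


class CheckedMapping(MutableMapping[str, str]):
    """Hypothetical MutableMapping that only accepts str values."""

    def __init__(self, data: Mapping[str, str] | Iterable[tuple[str, str]] | None = None) -> None:
        self._data: dict[str, str] = {}
        if data is not None:
            # MutableMapping.update accepts a mapping or (key, value) pairs
            # and routes every entry through __setitem__.
            self.update(data)

    def __getitem__(self, key: str) -> str:
        return self._data[key]

    def __setitem__(self, key: str, value: str) -> None:
        if not isinstance(value, str):
            raise TypeError(f"expected str, got {type(value).__name__}")
        self._data[key] = value

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)
```

Because every write goes through `__setitem__`, the type check also applies to the inherited `update` and `setdefault`.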

__delitem__(key)

Delete a parameter by key.

Parameters:

Name Type Description Default
key str

The parameter's key.

required

__getitem__(key)

Get a parameter by key.

Parameters:

Name Type Description Default
key str

The parameter's key.

required

Returns:

Type Description
Parameter

The Parameter at the given key.

__init__(parameters=None)

Initialize the ParameterDict.

Parameters:

Name Type Description Default
parameters dict[str, Parameter] | Iterable[tuple[str, Parameter]] | None

Optional dict or iterable of (key, Parameter) pairs.

None

__iter__()

Iterate over parameter keys.

Yields:

Type Description
str

Each key in the dict.

__len__()

Return the number of parameters.

Returns:

Type Description
int

Number of parameters in the dict.

__repr__()

Return string representation.

Returns:

Type Description
str

String representation of the ParameterDict.

__setitem__(key, value)

Set a parameter by key.

Parameters:

Name Type Description Default
key str

The parameter's key.

required
value Parameter

The Parameter to store.

required

Raises:

Type Description
TypeError

If value is not a Parameter.

named_parameters(prefix='')

Iterate over parameters with their names.

Parameters:

Name Type Description Default
prefix str

Prefix to prepend to parameter names.

''

Yields:

Type Description
tuple[str, Parameter]

Tuples of (name, parameter).
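The qualified names in the ParameterDict example (`'prompts.summarize'`) suggest that keys are joined to the prefix with a dot. The helper below sketches that naming rule under this assumption. `named_entries` is a hypothetical name, and how plait formats the `prefix` argument internally (with or without a trailing dot) is not documented here.

```python
from collections.abc import Iterator, Mapping
from typing import Any


def named_entries(entries: Mapping[str, Any], prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield (qualified_name, value) pairs in insertion order."""
    for key, value in entries.items():
        # Assumed convention: join with "." only when a prefix is given.
        yield (f"{prefix}.{key}" if prefix else key), value
```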

parameters()

Iterate over all parameters.

Yields:

Type Description
Parameter

Each Parameter in the dict.

ParameterList

Bases: MutableSequence['Parameter']

A list-like container for Parameters.

Holds a list of Parameters that will be properly collected by Module.parameters() and Module.named_parameters(). This is analogous to torch.nn.ParameterList.

Parameters are named by their index in the list (e.g., "0", "1", "2").

Parameters:

Name Type Description Default
parameters Iterable[Parameter] | None

Optional iterable of Parameter objects to initialize with.

None
Example

>>> class MultiPrompt(Module):
...     def __init__(self):
...         super().__init__()
...         self.prompts = ParameterList([
...             Parameter("Be concise", description="Style prompt"),
...             Parameter("Be helpful", description="Tone prompt"),
...         ])
...
>>> m = MultiPrompt()
>>> list(m.named_parameters())
[('prompts.0', Parameter(...)), ('prompts.1', Parameter(...))]

Note

The container itself is not a Parameter, but it provides iteration methods that Module uses to collect the contained Parameters.

__delitem__(index)

__delitem__(index: int) -> None
__delitem__(index: slice) -> None

Delete parameter(s) by index.

Parameters:

Name Type Description Default
index int | slice

Integer index or slice.

required

__getitem__(index)

__getitem__(index: int) -> Parameter
__getitem__(index: slice) -> list[Parameter]

Get parameter(s) by index.

Parameters:

Name Type Description Default
index int | slice

Integer index or slice.

required

Returns:

Type Description
Parameter | list[Parameter]

Single Parameter for int index, list for slice.

__init__(parameters=None)

Initialize the ParameterList.

Parameters:

Name Type Description Default
parameters Iterable[Parameter] | None

Optional iterable of Parameter objects.

None

__len__()

Return the number of parameters.

Returns:

Type Description
int

Number of parameters in the list.

__repr__()

Return string representation.

Returns:

Type Description
str

String representation of the ParameterList.

__setitem__(index, value)

__setitem__(index: int, value: Parameter) -> None
__setitem__(
    index: slice, value: Iterable[Parameter]
) -> None

Set parameter(s) by index.

Parameters:

Name Type Description Default
index int | slice

Integer index or slice.

required
value Parameter | Iterable[Parameter]

Parameter or iterable of Parameters.

required

insert(index, value)

Insert a parameter at the given index.

Parameters:

Name Type Description Default
index int

Index to insert at. Negative indices work like Python list.insert(): insert(-1, x) inserts before the last element, and very negative indices (beyond the start) insert at the beginning.

required
value Parameter

Parameter to insert.

required

Raises:

Type Description
TypeError

If value is not a Parameter.
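ParameterList.insert is documented to follow Python's own `list.insert` for negative indices. Plain strings stand in for Parameters in this sketch; it shows the two cases the docstring names.

```python
# Plain Python list.insert semantics, which ParameterList.insert is documented to mirror.
items = ["a", "b", "c"]

# A negative index counts from the end: -1 inserts before the last element.
items.insert(-1, "x")        # ['a', 'b', 'x', 'c']

# An index before the start is clamped, so the value goes to the front.
items.insert(-100, "start")  # ['start', 'a', 'b', 'x', 'c']
```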

named_parameters(prefix='')

Iterate over parameters with their names.

Parameters:

Name Type Description Default
prefix str

Prefix to prepend to parameter names.

''

Yields:

Type Description
tuple[str, Parameter]

Tuples of (name, parameter).
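ParameterList names each entry by its position ("0", "1", ...). The sketch below assumes the same dot-joined prefix seen in the `'prompts.0'` example. `named_by_index` is a hypothetical helper, not part of plait, and plain strings stand in for Parameters.

```python
from collections.abc import Iterator, Sequence
from typing import Any


def named_by_index(items: Sequence[Any], prefix: str = "") -> Iterator[tuple[str, Any]]:
    """Yield (name, item) pairs where each name is the item's current index."""
    for i, item in enumerate(items):
        name = str(i)
        # Assumed convention: join with "." only when a prefix is given.
        yield (f"{prefix}.{name}" if prefix else name), item
```

One consequence of index-based names is that they are positional, not stable. Deleting or inserting an element changes the names of every element after it, so code that stores parameter names across edits to the list should expect them to shift.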

parameters()

Iterate over all parameters.

Yields:

Type Description
Parameter

Each Parameter in the list.

Sequential

Bases: Module

A sequential container that chains modules together.

Modules are executed in order, with each module's output passed as input to the next module. This mirrors PyTorch's nn.Sequential.

Supports two initialization styles:

1. Positional arguments: Sequential(mod1, mod2, mod3)
2. OrderedDict for named access: Sequential(OrderedDict([('name', mod)]))

When using OrderedDict, modules can be accessed by name as attributes.

Parameters:

Name Type Description Default
*args Module | OrderedDict[str, Module]

Either positional Module instances, or a single OrderedDict mapping names to modules.

()
Example

Positional args

>>> pipeline = Sequential(
...     Preprocessor(),
...     Analyzer(),
...     Formatter()
... )
>>> len(pipeline)
3
>>> pipeline[0]  # First module

Example with OrderedDict

>>> from collections import OrderedDict
>>> pipeline = Sequential(OrderedDict([
...     ('preprocess', Preprocessor()),
...     ('analyze', Analyzer()),
... ]))
>>> pipeline.preprocess  # Named access

Note

The forward() method passes the input through each module in sequence, so each module must accept a single argument matching the previous module's output type.

__getattr__(name)

Get a module by name for named Sequential containers.

Parameters:

Name Type Description Default
name str

The name of the module to retrieve.

required

Returns:

Type Description
Module

The module with the given name.

Raises:

Type Description
AttributeError

If no module with that name exists.

__getitem__(idx)

__getitem__(idx: int) -> Module
__getitem__(idx: slice) -> Sequential

Get module(s) by index or slice.

Parameters:

Name Type Description Default
idx int | slice

Integer index or slice.

required

Returns:

Type Description
Module | Sequential

Single module for integer index, new Sequential for slice.

Raises:

Type Description
IndexError

If index is out of range.

TypeError

If idx is not int or slice.

Note

Slicing returns a new Sequential holding references to the same module objects, but it does NOT reparent them. The modules remain children of the original container.

__init__(*args)

Initialize the Sequential container.

Parameters:

Name Type Description Default
*args Module | OrderedDict[str, Module]

Either positional Module instances, or a single OrderedDict mapping names to modules.

()

Raises:

Type Description
TypeError

If a non-Module argument is provided (except OrderedDict).

ValueError

If an OrderedDict is passed together with other positional arguments.
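The two initialization styles and their documented errors can be sketched as one normalization step. `normalize_sequential_args` is a hypothetical helper, and plain callables stand in for Module instances. Naming positional modules "0", "1", ... is an assumption borrowed from PyTorch's nn.Sequential and is not stated in this reference.

```python
from __future__ import annotations

from collections import OrderedDict
from typing import Any, Callable

Step = Callable[[Any], Any]


def normalize_sequential_args(*args: Any) -> OrderedDict[str, Step]:
    """Map either positional steps or a single OrderedDict to named steps."""
    if len(args) == 1 and isinstance(args[0], OrderedDict):
        # Named style: keep the caller's names and order.
        named = args[0]
    elif any(isinstance(a, OrderedDict) for a in args):
        # Mixing an OrderedDict with other positional arguments is rejected.
        raise ValueError("an OrderedDict cannot be combined with positional modules")
    else:
        # Positional style: assumed index names, as in PyTorch.
        named = OrderedDict((str(i), a) for i, a in enumerate(args))
    for name, step in named.items():
        if not callable(step):
            raise TypeError(f"{name!r} is {type(step).__name__}, not a module")
    return OrderedDict(named)
```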

__iter__()

Iterate over modules in order.

Yields:

Type Description
Module

Each module in the sequence.

__len__()

Return the number of modules in the container.

Returns:

Type Description
int

The number of modules.

append(module)

Append a module to the end of the sequence.

Parameters:

Name Type Description Default
module Module

The module to append.

required

Returns:

Type Description
Sequential

Self, for method chaining.

Raises:

Type Description
TypeError

If module is not a Module instance.

forward(x)

Execute modules sequentially, chaining outputs to inputs.

Parameters:

Name Type Description Default
x Any

Input to the first module.

required

Returns:

Type Description
Any

Output of the last module.

Note

Each module's output becomes the next module's input.
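The chaining rule amounts to a left fold over the modules. The sketch below is eager and synchronous, and plain callables stand in for modules. plait's execution may capture calls into a DAG rather than run them immediately, so read this only as the data-flow contract. `run_sequential` is a hypothetical name.

```python
from __future__ import annotations

from collections.abc import Iterable
from functools import reduce
from typing import Any, Callable


def run_sequential(steps: Iterable[Callable[[Any], Any]], x: Any) -> Any:
    """Feed x through each step in order; each output becomes the next input."""
    # reduce starts from x and applies each step to the running value.
    return reduce(lambda acc, step: step(acc), steps, x)
```

For example, `run_sequential([str.strip, str.upper], "  hi ")` returns `'HI'`. With no steps, the input is returned unchanged. This is also why each module must accept a single argument matching the previous module's output type.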