ModelWrapper

Primary API for zero-copy model sharing with nn.Module models.

Overview

ModelWrapper provides a unified API for wrapping nn.Module models and pipelines that supports both task and actor execution modes:

  • from_model() - Create a wrapper from a model or pipeline

  • for_tasks() - Convenience method for task mode; returns a callable pipeline directly

  • load() - Load the model or pipeline (zero-copy inside an actor; a callable, task-backed pipeline in task mode)

ModelWrapper Class

class ray_zerocopy.model_wrappers.ModelWrapper(skeleton, model_info, is_standalone_module=False, mode='actor')

Bases: WrapperMixin[T], Generic[T]

A unified serializable wrapper with zero-copy loading for nn.Module and Pipeline objects.

Supports both task-based and actor-based execution modes:

  • Task mode: Models are executed via Ray tasks with zero-copy loading.

  • Actor mode: Models are prepared for zero-copy loading inside Ray actors.

skeleton

The skeleton of the model or pipeline

model_refs

A dict of Ray object references to the model tensors

is_standalone_module

Whether the model or pipeline is a standalone module

mode

Execution mode - “task” or “actor”

model_info

Model info dict with method tracking (task mode) or just refs (actor mode)
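To make the split between skeleton and model_refs concrete, here is a minimal pure-Python sketch. TinyModel, split_weights and attach_weights are hypothetical names, and plain lists stand in for tensors. In ray_zerocopy the weights would be tensors placed in Ray's object store, and model_refs would hold ObjectRefs to them rather than the weights themselves.

```python
import copy


class TinyModel:
    """Hypothetical stand-in for an nn.Module: config plus named weights."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale
        # Plain lists stand in for tensors (assumption for this sketch).
        self.weights = {"w": [1.0, 2.0, 3.0], "b": [0.5]}

    def __call__(self, x: float) -> float:
        return self.scale * (sum(self.weights["w"]) * x + self.weights["b"][0])


def split_weights(model):
    """Return (skeleton, weights): a weight-free copy plus the extracted weights."""
    skeleton = copy.copy(model)  # shallow copy; the original model is not modified
    skeleton.weights = {name: None for name in model.weights}  # placeholders only
    return skeleton, model.weights


def attach_weights(skeleton, weights):
    """Rebuild a usable model by pointing a copy of the skeleton at the weights."""
    model = copy.copy(skeleton)
    model.weights = weights  # reuses the same objects; no weight data is copied
    return model


original = TinyModel()
skeleton, weights = split_weights(original)
rebuilt = attach_weights(skeleton, weights)
print(rebuilt(2.0))  # 12.5, same as original(2.0)
```

The skeleton is cheap to pickle because it carries no weight data; only the references to the weights need to travel with it.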

Examples

Task Mode (using for_tasks shortcut):

>>> from ray_zerocopy import ModelWrapper
>>> model = YourModel()
>>> rewritten = ModelWrapper.for_tasks(model)
>>> result = rewritten(data)  # Callable pipeline

Task Mode (using from_model + load):

>>> wrapper = ModelWrapper.from_model(model, mode="task")
>>> rewritten = wrapper.load()  # Get callable pipeline
>>> result = rewritten(data)

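Actor Mode (ds is an existing ray.data.Dataset):

>>> from ray.data import ActorPoolStrategy
>>> wrapper = ModelWrapper.from_model(model, mode="actor")
>>>
>>> class InferenceActor:
...     def __init__(self, model_wrapper):
...         self.model = model_wrapper.load()  # Inside an actor's __init__
...     def __call__(self, batch):
...         # Ray Data expects each batch result as a dict (or DataFrame)
...         return {"output": self.model(batch["data"])}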
>>>
>>> ds.map_batches(
...     InferenceActor,
...     fn_constructor_kwargs={"model_wrapper": wrapper},
...     compute=ActorPoolStrategy(size=4)
... )

__init__(skeleton, model_info, is_standalone_module=False, mode='actor')

Initialize ModelWrapper.

Parameters:
  • skeleton (T) – The skeleton of the model or pipeline

  • model_info (dict[str, tuple[ObjectRef, Set[str] | None]] | dict[str, ObjectRef]) – Model info dict. In task mode: dict[str, tuple[ray.ObjectRef, Optional[Set[str]]]] (with method tracking). In actor mode: dict[str, ray.ObjectRef] (just refs).

  • is_standalone_module (bool) – Whether this is a standalone module

  • mode (Literal['task', 'actor']) – Execution mode (“task” or “actor”)

property model_refs: dict[str, ObjectRef]

Get model references, extracted from model_info.
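The sketch below shows the two model_info shapes and one way refs could be extracted from them. Plain strings stand in for ray.ObjectRef so it runs without Ray. The encoder/decoder keys are illustrative, and extract_model_refs is a hypothetical helper, not the library's implementation of model_refs.

```python
from typing import Optional, Union

# Strings stand in for ray.ObjectRef so the sketch runs without Ray (assumption).
ObjectRef = str

# Task mode: name -> (ref, Set[str] | None used for method tracking)
TaskModelInfo = dict[str, tuple[ObjectRef, Optional[set[str]]]]
# Actor mode: name -> ref
ActorModelInfo = dict[str, ObjectRef]


def extract_model_refs(model_info: Union[TaskModelInfo, ActorModelInfo]) -> dict[str, ObjectRef]:
    """Hypothetical helper: drop the method-tracking part of task-mode entries."""
    return {
        name: entry[0] if isinstance(entry, tuple) else entry
        for name, entry in model_info.items()
    }


task_info: TaskModelInfo = {"encoder": ("ref-enc", {"__call__"}), "decoder": ("ref-dec", None)}
actor_info: ActorModelInfo = {"encoder": "ref-enc", "decoder": "ref-dec"}

print(extract_model_refs(task_info))   # {'encoder': 'ref-enc', 'decoder': 'ref-dec'}
print(extract_model_refs(actor_info))  # same dict
```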

classmethod from_model(model_or_pipeline, mode='actor', model_attr_names=None, method_names=None)

Instantiate a ModelWrapper from a model or pipeline.

A ModelWrapper is serializable and can be put into Ray’s object store by ray.put().

Parameters:
  • model_or_pipeline (T) – The model or pipeline to wrap

  • mode (Literal['task', 'actor']) – Execution mode: “task” for task-based execution or “actor” for loading inside actors. Must be “task” or “actor”. Defaults to “actor”.

  • model_attr_names (list | None) – The attribute names of the models in the pipeline

  • method_names (tuple | None) – Model methods to expose via remote tasks (auto-selected if None)

Returns:

A ModelWrapper instance (not callable - use .load() to get the callable pipeline)

Return type:

ModelWrapper[T]

Example - Task mode:

>>> wrapper = ModelWrapper.from_model(pipeline, mode="task")
>>> rewritten = wrapper.load()  # Get callable pipeline
>>> result = rewritten(data)  # Use the pipeline

Example - Actor mode:

>>> wrapper = ModelWrapper.from_model(pipeline, mode="actor")
>>> # In actor: pipeline = wrapper.load()

classmethod for_tasks(model_or_pipeline, method_names=None)

Convert a model or pipeline into a callable rewritten pipeline with zero-copy model loading.

Note

Under the hood, this wraps from_model() in task mode followed by load(), so the converted pipeline is prepared and loaded immediately. Each call on the returned pipeline runs in a remote Ray task.

Parameters:
  • model_or_pipeline (T) – The model or pipeline to wrap

  • method_names (tuple | None) – Model methods to expose via remote tasks (defaults to ("__call__",))

Returns:

A rewritten pipeline ready for immediate use (callable). Each call will spawn a Ray task.

Return type:

Module | T
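For intuition about task mode, the following sketch models a remote model shim: only the exposed method names are forwarded, each call is submitted to an executor, and the caller blocks on the result. A ThreadPoolExecutor stands in for Ray tasks here, and RemoteMethodShim is a hypothetical class; the shims actually created by for_tasks() and load() may work differently.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any


class RemoteMethodShim:
    """Hypothetical task-mode shim: forwards exposed methods to an executor."""

    def __init__(self, model: Any, method_names: tuple[str, ...], executor: ThreadPoolExecutor):
        self._model = model
        self._method_names = method_names
        self._executor = executor

    def _submit(self, name: str, *args: Any, **kwargs: Any) -> Any:
        # Submit the call (a stand-in for launching a Ray task) and wait for it.
        future = self._executor.submit(getattr(self._model, name), *args, **kwargs)
        return future.result()

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if "__call__" not in self._method_names:
            raise TypeError("__call__ is not exposed by this shim")
        return self._submit("__call__", *args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes not found normally; private names and
        # methods that were not exposed raise AttributeError.
        if name.startswith("_") or name not in self._method_names:
            raise AttributeError(name)
        return lambda *args, **kwargs: self._submit(name, *args, **kwargs)


class Doubler:
    def __call__(self, x: int) -> int:
        return 2 * x

    def describe(self) -> str:
        return "doubles its input"


with ThreadPoolExecutor(max_workers=2) as pool:
    # Mirror the documented default of exposing only ("__call__",).
    shim = RemoteMethodShim(Doubler(), ("__call__",), pool)
    result = shim(21)
    describe_exposed = hasattr(shim, "describe")

print(result, describe_exposed)  # 42 False
```

Blocking on each result keeps the shim a drop-in replacement for the original model, at the cost of one task launch per call.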

load(_use_fast_load=False)

Load the model/pipeline from the wrapper.

  • Task mode: Creates the rewritten pipeline on demand, with remote model shims.

  • Actor mode: Loads the pipeline from Ray’s object store using zero-copy.

Models are loaded on CPU. Users should handle device placement themselves after loading.

Parameters:

_use_fast_load (bool) – Use a faster but slightly riskier loading method. Defaults to False. Only applies to actor mode.

Returns:

The deserialized pipeline ready for inference (on CPU)

Return type:

Module | T

Example - Actor mode:

>>> class InferenceActor:
...     def __init__(self, model_wrapper):
...         # Load model (on CPU)
...         self.model = model_wrapper.load()
...
...     def __call__(self, batch):
...         return self.model(batch["data"])

Example - Task mode:

>>> wrapper = ModelWrapper.from_model(pipeline, mode="task")
>>> rewritten = wrapper.load()  # Get callable pipeline
>>> result = rewritten(data)  # Use the pipeline

__getstate__()

Return state for pickling.

__setstate__(state)

Restore state from pickling.

classmethod deserialize(skeleton, model_info, is_standalone_module=False, mode='actor')

Deserialize a ModelWrapper from a skeleton and model info.

Parameters:
  • skeleton (T) – The skeleton of the model or pipeline

  • model_info (dict[str, tuple[ObjectRef, Set[str] | None]] | dict[str, ObjectRef]) – Model info dict. Task mode: TaskModelInfo, Actor mode: ActorModelInfo

  • is_standalone_module (bool) – Whether this is a standalone module

  • mode (Literal['task', 'actor']) – Execution mode (“task” or “actor”)

serialize()

Serialize the ModelWrapper to a dictionary.
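The pickling hooks and serialize()/deserialize() can be read as one round trip: serialize() produces the constructor arguments, and deserialize() rebuilds the wrapper from them. MiniWrapper below is a hypothetical stand-in that assumes __getstate__ simply returns the serialize() dict; the real state layout and validation may differ.

```python
import pickle
from typing import Any


class MiniWrapper:
    """Hypothetical stand-in for ModelWrapper's serialization hooks."""

    def __init__(self, skeleton: Any, model_info: dict, is_standalone_module: bool = False, mode: str = "actor"):
        # Mirrors the documented "task"/"actor" restriction; the error type is an assumption.
        if mode not in ("task", "actor"):
            raise ValueError(f"mode must be 'task' or 'actor', got {mode!r}")
        self._skeleton = skeleton
        self._model_info = model_info
        self._is_standalone_module = is_standalone_module
        self._mode = mode

    def serialize(self) -> dict[str, Any]:
        # Keys match deserialize()'s parameters, so the dict round-trips.
        return {
            "skeleton": self._skeleton,
            "model_info": self._model_info,
            "is_standalone_module": self._is_standalone_module,
            "mode": self._mode,
        }

    @classmethod
    def deserialize(cls, skeleton: Any, model_info: dict, is_standalone_module: bool = False, mode: str = "actor") -> "MiniWrapper":
        return cls(skeleton, model_info, is_standalone_module, mode)

    def __getstate__(self) -> dict[str, Any]:
        # Only the skeleton and refs are pickled, never the weights themselves.
        return self.serialize()

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(type(self).deserialize(**state).__dict__)


original = MiniWrapper(skeleton="skeleton-object", model_info={"model": "ref-0"}, is_standalone_module=True)
clone = pickle.loads(pickle.dumps(original))
print(clone.serialize() == original.serialize())  # True
```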

Basic Usage

Wrapping a Model

from ray_zerocopy import ModelWrapper

model = MyModel()
wrapper = ModelWrapper.from_model(model)  # "actor" mode by default

Using in Actors

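from ray.data import ActorPoolStrategy

class InferenceActor:
    def __init__(self, model_wrapper):
        # Load pipeline with zero-copy (on CPU)
        self.model = model_wrapper.load()

    def __call__(self, batch):
        # Ray Data expects each batch result as a dict (or DataFrame)
        return {"output": self.model(batch["data"])}

# Pass the wrapper to the actors; ds is an existing ray.data.Dataset
ds.map_batches(
    InferenceActor,
    fn_constructor_kwargs={"model_wrapper": wrapper},
    compute=ActorPoolStrategy(size=4),  # callable classes run in an actor pool
)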
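The actor pattern pays the load cost once per actor, in __init__, and reuses the loaded model for every batch. This sketch checks that with a hypothetical FakeWrapper whose load() counts its calls and returns a trivial doubling function in place of a real model.

```python
from typing import Any, Callable


class FakeWrapper:
    """Hypothetical stand-in for ModelWrapper that counts load() calls."""

    def __init__(self) -> None:
        self.load_calls = 0

    def load(self) -> Callable[[list[float]], list[float]]:
        self.load_calls += 1
        return lambda xs: [2 * x for x in xs]  # trivial "model"


class InferenceActor:
    def __init__(self, model_wrapper: FakeWrapper) -> None:
        self.model = model_wrapper.load()  # paid once, when the actor starts

    def __call__(self, batch: dict[str, Any]) -> dict[str, Any]:
        return {"output": self.model(batch["data"])}  # reused for every batch


wrapper = FakeWrapper()
actor = InferenceActor(wrapper)
outputs = [actor({"data": [1.0, 2.0]}) for _ in range(3)]
print(wrapper.load_calls, outputs[0])  # 1 {'output': [2.0, 4.0]}
```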

Using for Tasks

# Convert the pipeline for task mode (returns a callable pipeline, not a ModelWrapper)
wrapped = ModelWrapper.for_tasks(pipeline)

# Use immediately - each call spawns a Ray task
result = wrapped(data)

Methods

from_model

Class method to create a wrapper from a model or pipeline.

# Actor mode (default)
wrapper = ModelWrapper.from_model(model, mode="actor")

# Task mode
wrapper = ModelWrapper.from_model(model, mode="task")

Parameters:

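  • model_or_pipeline - PyTorch model (nn.Module) or pipeline (object with nn.Module attributes)

  • mode - Execution mode: “actor” (default) or “task”

  • model_attr_names (optional) - Attribute names of the models in the pipeline

  • method_names (optional) - Model methods to expose via remote tasks

Returns:

  • ModelWrapper instance (not callable; use load() to get the pipeline)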

for_tasks

Convenience method for task mode.

wrapped = ModelWrapper.for_tasks(pipeline)

Parameters:

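  • model_or_pipeline - PyTorch model or pipeline

  • method_names (optional) - Model methods to expose via remote tasks (defaults to ("__call__",))

Returns:

  • A callable converted model or pipeline; each call spawns a Ray task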

load

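Load the model or pipeline. In actor mode, this loads it from Ray’s object store with zero-copy; in task mode, it returns a callable pipeline whose calls run as Ray tasks.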

model = wrapper.load()

Parameters:

  • _use_fast_load (optional) - Enable faster but slightly riskier loading (experimental; actor mode only; defaults to False)

Returns:

  • Loaded model or pipeline (on CPU). Users should handle device placement themselves.

Supported Model Types

Standalone nn.Module

model = MyModel()
wrapper = ModelWrapper.from_model(model)

Pipeline with Multiple Models

class Pipeline:
    def __init__(self):
        self.encoder = EncoderModel()
        self.decoder = DecoderModel()

    def __call__(self, x):
        return self.decoder(self.encoder(x))

pipeline = Pipeline()
wrapper = ModelWrapper.from_model(pipeline)

The wrapper automatically detects all nn.Module attributes.
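Attribute detection can be pictured as scanning the pipeline's instance attributes for module instances. The sketch below assumes a stand-in Module class so it runs without PyTorch (with PyTorch installed you would pass torch.nn.Module as module_type). detect_model_attrs is hypothetical and may not match how from_model actually finds models, for example for nested or container attributes; the final isinstance check is one plausible way to tell a standalone module from a pipeline.

```python
from typing import Any


class Module:
    """Stand-in for torch.nn.Module so the sketch runs without PyTorch."""


class EncoderModel(Module):
    pass


class DecoderModel(Module):
    pass


def detect_model_attrs(obj: Any, module_type: type = Module) -> list[str]:
    """Hypothetical helper: names of instance attributes that are modules."""
    return sorted(
        name for name, value in vars(obj).items() if isinstance(value, module_type)
    )


class Pipeline:
    def __init__(self) -> None:
        self.encoder = EncoderModel()
        self.decoder = DecoderModel()
        self.threshold = 0.5  # not a module, so it is skipped


pipeline = Pipeline()
print(detect_model_attrs(pipeline))  # ['decoder', 'encoder']
print(isinstance(pipeline, Module))  # False: a pipeline, not a standalone module
```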

Execution Modes

Actor Mode

For Ray Data and long-running actors:

# Create wrapper in actor mode
wrapper = ModelWrapper.from_model(pipeline, mode="actor")

# Use in actor
class Actor:
    def __init__(self, wrapper):
        self.model = wrapper.load()

Task Mode

For ad-hoc inference with Ray tasks:

# Convert the pipeline for task mode (returns a callable pipeline)
wrapped = ModelWrapper.for_tasks(pipeline)

# Use immediately
result = wrapped(data)  # Each call spawns a Ray task

Advanced Usage

Conditional Fast Loading

class Actor:
    def __init__(self, wrapper, use_fast):
        self.model = wrapper.load(_use_fast_load=use_fast)

Inspecting Wrapper State

wrapper = ModelWrapper.from_model(pipeline)

# Check if it's a standalone module
print(wrapper._is_standalone_module)

# See detected models
print(wrapper._model_info.keys())

See Also