JIT Wrappers
Overview
JITModelWrapper provides zero-copy model sharing for TorchScript (compiled) models. Like ModelWrapper, it supports both task-based and actor-based execution modes, but it works with torch.jit.ScriptModule objects instead of nn.Module objects.
API Reference
JITModelWrapper
A unified wrapper that supports both task and actor execution modes for TorchScript models.
Task Mode
For task-based inference with TorchScript models:
import torch
from ray_zerocopy import JITModelWrapper
# Compile your model (model and example_input come from your application)
jit_model = torch.jit.trace(model, example_input)
# Wrap for task-based execution (immediate use)
wrapped = JITModelWrapper.for_tasks(jit_model)
result = wrapped(data)  # Each call spawns a Ray task
Or using from_model():
wrapper = JITModelWrapper.from_model(jit_model, mode="task")
wrapped = wrapper.load()
result = wrapped(data)
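The two task-mode snippets above suggest that for_tasks() is shorthand for from_model(..., mode="task") followed by load(). The sketch below mimics that relationship with a toy class. ToyWrapper and double are hypothetical names invented for illustration; this is not ray_zerocopy's implementation and it does no Ray work at all.

```python
from typing import Any, Callable


class ToyWrapper:
    """Hypothetical stand-in for illustration; NOT ray_zerocopy's JITModelWrapper."""

    def __init__(self, model: Callable[[Any], Any], mode: str) -> None:
        if mode not in ("task", "actor"):
            raise ValueError(f"unknown mode: {mode!r}")
        self.model = model
        self.mode = mode

    @classmethod
    def from_model(cls, model: Callable[[Any], Any], mode: str) -> "ToyWrapper":
        # Two-step path: build the wrapper now, call load() later.
        return cls(model, mode)

    @classmethod
    def for_tasks(cls, model: Callable[[Any], Any]) -> Callable[[Any], Any]:
        # One-step path (assumed equivalence): task-mode wrapper, loaded immediately.
        return cls.from_model(model, mode="task").load()

    def load(self) -> Callable[[Any], Any]:
        # The toy simply hands back the callable it was given.
        return self.model


def double(x: int) -> int:
    return 2 * x


direct = ToyWrapper.for_tasks(double)
two_step = ToyWrapper.from_model(double, mode="task").load()
```

In this toy, both paths return the same callable. Whether the real library returns identical objects from both paths is not stated in this page, so treat the equivalence as an assumption.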
Actor Mode
For actor-based inference with TorchScript models:
import torch
from ray_zerocopy import JITModelWrapper
# Compile your model (model and example_input come from your application)
jit_model = torch.jit.trace(model, example_input)
# Wrap for actor-based execution
wrapper = JITModelWrapper.from_model(jit_model, mode="actor")
# Use in actor
class InferenceActor:
    def __init__(self, wrapper):
        # Load the shared model once, when the actor starts
        self.model = wrapper.load()
Usage Examples
Task Mode with Pipeline
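import torch
from ray_zerocopy import JITModelWrapper

# EncoderModel, DecoderModel, example_input, and example_encoded are user-defined.
class Pipeline:
    def __init__(self):
        self.encoder = torch.jit.trace(EncoderModel(), example_input)
        self.decoder = torch.jit.trace(DecoderModel(), example_encoded)

    def __call__(self, x):
        # Run the encoder, then feed its output to the decoder
        return self.decoder(self.encoder(x))

pipeline = Pipeline()
wrapped = JITModelWrapper.for_tasks(pipeline)
result = wrapped(data)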
Actor Mode with Ray Data
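from ray.data import ActorPoolStrategy
from ray_zerocopy import JITModelWrapper

# jit_pipeline is a pipeline of TorchScript models built beforehand;
# ds is an existing Ray Dataset.
wrapper = JITModelWrapper.from_model(jit_pipeline, mode="actor")

class InferenceActor:
    def __init__(self, wrapper):
        self.pipeline = wrapper.load()

    def __call__(self, batch):
        return self.pipeline(batch["data"])

ds.map_batches(
    InferenceActor,
    fn_constructor_kwargs={"wrapper": wrapper},
    compute=ActorPoolStrategy(size=4),
)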
Comparison with ModelWrapper
JITModelWrapper provides the same unified API as ModelWrapper:
- from_model(): Create a wrapper from a model or pipeline
- for_tasks(): Convenience method for task mode
- load(): Load the model in an actor (actor mode) or get a callable pipeline (task mode)
- Supports both task and actor execution modes
The main difference is that JITModelWrapper works with torch.jit.ScriptModule objects instead of nn.Module objects.
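To make the distinction concrete, here is a minimal sketch of turning an nn.Module into a torch.jit.ScriptModule with torch.jit.trace. It uses only PyTorch; TinyEncoder and its layer sizes are made up for illustration, and nothing here touches ray_zerocopy.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical example model; any nn.Module with tensor inputs would do."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))


model = TinyEncoder().eval()
example_input = torch.randn(1, 4)

# Tracing runs the model once on example_input and records the executed ops.
jit_model = torch.jit.trace(model, example_input)

# The traced result is a ScriptModule; the eager model is not.
traced_is_script = isinstance(jit_model, torch.jit.ScriptModule)
eager_is_script = isinstance(model, torch.jit.ScriptModule)

# Both versions should compute the same output for the same input.
with torch.no_grad():
    outputs_match = torch.allclose(model(example_input), jit_model(example_input))
```

Note that torch.jit.ScriptModule is itself a subclass of nn.Module, so an isinstance check against nn.Module does not tell the two apart; checking against torch.jit.ScriptModule does.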