# Tatbot VLA Plan 3 — Training on stroke_tool datasets and MCP inference tool
This guide specifies what's needed to train Vision-Language-Action (VLA) policies on datasets recorded via `src/tatbot/tools/robot/stroke.py`, and how to add an MCP tool to run inference from a specific model checkpoint.
## Scope

- Train VLA policies (e.g., SmolVLA, π0) using data recorded by `stroke_tool`.
- Keep the LeRobot-native dataset format produced by `stroke_tool` to avoid conversion steps.
- Provide an MCP tool to load a chosen checkpoint and run inference on the robot.
## Dataset recorded by stroke_tool

`stroke_tool` writes LeRobot-compatible episodic datasets into `/nfs/tatbot/recordings/` with names like `stroke-<scene>-<timestamp>`. Within each dataset directory:

- `scene.yaml`: Scene definition saved at recording start.
- `strokes.yaml`: Stroke list with metadata; large arrays are in `arrays/*.npy` (see `tatbot.data.stroke`).
- `strokebatch.safetensors`: Packed joint trajectories and EE poses (`tatbot.data.stroke.StrokeBatch`).
- `logs/episode_*.txt`: Per-episode logs.
- `episode_000000/`, `episode_000001/`, …: Episode folders created by `LeRobotDataset`, including recorded frames and `episode_cond` with references to `stroke_l`/`stroke_r` metadata and any preview images.
Example directory structure (may vary slightly by LeRobot version/settings):

```text
/nfs/tatbot/recordings/
└── stroke-{scene}-{timestamp}/
    ├── meta_data/
    │   ├── data.parquet
    │   ├── stats.json
    │   └── info.json
    ├── videos/
    │   ├── observation.image_{camera_name}_{episode:06d}.mp4
    │   └── observation.depth_{camera_name}_{episode:06d}.mp4
    ├── logs/
    │   └── episode_{episode:06d}.txt
    ├── episode_{episode:06d}/
    │   ├── stroke_l.png
    │   └── stroke_r.png
    ├── scene.yaml
    ├── strokes.yaml
    └── strokebatch.safetensors
```
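To sanity-check the packed trajectories before training, a minimal sketch using the `safetensors` API (the tensor names inside the file are dataset-specific; the path here is elided like elsewhere in this guide):

```python
# Minimal sketch: inspect the packed stroke tensors in strokebatch.safetensors.
from safetensors.torch import load_file

tensors = load_file("/nfs/tatbot/recordings/stroke-tatbotlogo-.../strokebatch.safetensors")
for name, tensor in tensors.items():
    print(f"{name}: shape={tuple(tensor.shape)} dtype={tensor.dtype}")
```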
Key details from `src/tatbot/tools/robot/stroke.py`:

- The dataset is created (or resumed) via `LeRobotDataset.create(...)`/`LeRobotDataset(...)` with `features` derived from `robot.action_features` and `robot.observation_features`.
- When RealSense/IP cameras are enabled, images are written through LeRobot's image writer threads.
- `fps` defaults to 10 (configurable via tool input).
- Each pose step adds a frame composed of the observation and the action actually sent to the robot.

Implication: these datasets are immediately usable by LeRobot training scripts (no custom converter required).
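For example, a recording should load directly with `LeRobotDataset` and yield indexable frames (a sketch; attribute names can vary slightly across LeRobot releases):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

root = "/nfs/tatbot/recordings/stroke-tatbotlogo-..."  # path elided; point at a real recording dir
ds = LeRobotDataset("tatbot/stroke-tatbotlogo", root=root)
print(ds.num_episodes, ds.fps, list(ds.features.keys()))
frame = ds[0]  # dict of tensors keyed by feature names (observation.*, action, ...)
print({k: getattr(v, "shape", type(v)) for k, v in frame.items()})
```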
## Training requirements

- Python environment with Tatbot installed via `uv`. Install LeRobot separately (in its own checkout) if you need policy code/extras:

  ```bash
  source scripts/setup_env.sh
  uv pip install -e .
  uv pip install -e .[bot,cam,gen,gpu,img,viz,dev,docs]
  # If training with LeRobot policies, do this in your LeRobot repo (not here):
  # cd ~/lerobot && uv pip install -e .[smolvla,pi0]
  set -a; source .env; set +a  # optional secrets for WandB, etc.
  ```

- GPU recommended for training; CPU-only is possible for debugging.
- WandB is optional; install and enable it explicitly in your training environment.
## Preparing datasets for training

You can train directly from a locally recorded dataset directory. Two common options:

1. Local path training (recommended initially):
   - Use the full path to a recording directory, e.g. `/nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s`.
   - Many LeRobot CLIs accept `--dataset.root` or a `repo_id` that points locally; prefer `--dataset.root` where available.
2. Pushing to Hub (optional):
   - If desired, push the dataset to the Hugging Face Hub using LeRobot's dataset utilities.

Minimum checks before training (see the sketch after this list):

- Confirm `strokes.yaml` and `strokebatch.safetensors` exist.
- Confirm episodes exist and contain frames and actions.
- Skim `logs/` for anomalies; ensure joint/action ranges and fps look correct.
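A small pre-flight script covering the file-level checks might look like this (a sketch; it only inspects the filesystem layout produced by `stroke_tool`):

```python
from pathlib import Path

def preflight(dataset_dir: Path) -> None:
    # Required top-level artifacts written by stroke_tool
    for required in ("scene.yaml", "strokes.yaml", "strokebatch.safetensors"):
        assert (dataset_dir / required).exists(), f"missing {required}"
    # At least one episode folder, plus per-episode logs
    episodes = sorted(dataset_dir.glob("episode_*"))
    assert episodes, "no episode folders found"
    logs = sorted((dataset_dir / "logs").glob("episode_*.txt"))
    print(f"{dataset_dir.name}: {len(episodes)} episodes, {len(logs)} logs")

for d in Path("/nfs/tatbot/recordings").glob("stroke-*"):
    preflight(d)
```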
Aggregating multiple recordings (optional):
from pathlib import Path
from lerobot.datasets.lerobot_dataset import LeRobotDataset
recordings_dir = Path("/nfs/tatbot/recordings")
datasets = []
for dataset_dir in recordings_dir.glob("stroke-*"):
repo_id = f"tatbot/{dataset_dir.name}"
datasets.append(LeRobotDataset(repo_id, root=str(dataset_dir)))
# Train across multiple datasets or merge per your pipeline
## Policy training configs

Guidance uses the same policy families as in `docs/models/claude_vla_guide.md`.

- Choose a policy: `smolvla` for faster iteration or `pi0` as needed.
- Observation length vs. action chunking (see the sketch after this list):
  - Typical: `n_obs_steps = 1`, `chunk_size = 50`, `n_action_steps = 50` to match stroke pose steps.
  - If training on sequences spanning multiple poses, adjust accordingly.
- Image size and preprocessing: ensure they match the camera output (e.g., 512×512 with padding for SmolVLA).
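To derive the chunking values from a recording rather than hard-coding 50, a sketch (the `stroke_length` key in `scene.yaml` is an assumption; verify it against your scene schema):

```python
import yaml

with open("/nfs/tatbot/recordings/stroke-tatbotlogo-.../scene.yaml") as f:
    scene_cfg = yaml.safe_load(f)

stroke_length = scene_cfg["stroke_length"]  # assumed key; check your scene.yaml
chunk_size = stroke_length
n_action_steps = stroke_length
print(f"chunk_size={chunk_size}, n_action_steps={n_action_steps}")
```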
Example commands (adjust flags to your CLI wrapper; standardize outputs under `outputs/train/`):
SmolVLA finetune from base on a local dataset root:

```bash
lerobot-train \
  --policy.type=smolvla \
  --dataset.root="/nfs/tatbot/recordings/stroke-tatbotlogo-..." \
  --batch_size=32 \
  --steps=100000 \
  --wandb.enable=true \
  --wandb.project=tatbot_smolvla_finetune \
  --output_dir=outputs/train/smolvla_tatbot
```
Pi0 finetune from base:

```bash
lerobot-train \
  --policy.type=pi0 \
  --dataset.root="/nfs/tatbot/recordings/stroke-tatbotlogo-..." \
  --batch_size=32 \
  --steps=100000 \
  --wandb.enable=true \
  --wandb.project=tatbot_pi0_finetune \
  --output_dir=outputs/train/pi0_tatbot
```
Notes:

- Prefer `--dataset.root` for local datasets; use `--dataset.repo_id` only if pushing to the Hub.
- Do not assume a fixed `chunk_size`/`n_action_steps`; align them with the actual `scene.stroke_length` and model config.
- Keep an evaluation split: either reserve episodes for validation (see the sketch after this list) or use `--dataset.split.*` flags where available.
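One way to reserve episodes is to load disjoint episode subsets (a sketch; recent LeRobot versions accept an `episodes` index list, but verify against your installed release):

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

root = "/nfs/tatbot/recordings/stroke-tatbotlogo-..."
train_ds = LeRobotDataset("tatbot/stroke-tatbotlogo", root=root, episodes=list(range(18)))
val_ds = LeRobotDataset("tatbot/stroke-tatbotlogo", root=root, episodes=[18, 19])
print(len(train_ds), len(val_ds))
```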
## Validating the pipeline

Dry-run a few training steps and check WandB/logs for:

- Loss decreasing, stable gradient norms, and reasonable GPU utilization.
- Sampled batches showing correct image shapes and action ranges.

Save and test a checkpoint by running local evaluation on a held-out episode set, e.g. as sketched below.
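A rough offline check on held-out frames could look like this (a sketch, assuming SmolVLA and that the dataset feature keys match what the policy expects; batching and normalization conventions vary by LeRobot version):

```python
import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

ds = LeRobotDataset("tatbot/stroke-tatbotlogo", root="/nfs/tatbot/recordings/stroke-tatbotlogo-...")
policy = SmolVLAPolicy.from_pretrained("outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model")
policy.eval()

errors = []
for i in range(min(100, len(ds))):
    frame = ds[i]
    # Naive single-sample batching; adapt key filtering/device placement to your policy.
    batch = {k: v.unsqueeze(0) if torch.is_tensor(v) else v for k, v in frame.items()}
    with torch.no_grad():
        pred = policy.select_action(batch)  # pops one action per call from the policy's internal queue
    errors.append(torch.nn.functional.mse_loss(pred.squeeze(0).cpu(), frame["action"]).item())
print(f"mean per-step action MSE: {sum(errors) / len(errors):.6f}")
```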
## MCP inference tool: design and code skeleton

We will add a new MCP tool `vla_infer` that:

- Loads a specified checkpoint directory into the chosen policy class.
- Builds a `Tatbot` instance from a scene, connects, streams observations, and sends policy actions in a timed loop.
- Optionally records an evaluation dataset for later analysis.

Suggested file layout (choose one and keep it consistent):

- Place under `src/tatbot/tools/robot/` (recommended, like other robot tools):
  - `src/tatbot/tools/robot/models_vla.py`
  - `src/tatbot/tools/robot/infer_vla.py`
- Or create a dedicated package `src/tatbot/tools/vla/` and import it in the registry. This guide shows the `robot/` layout.
Input model fields (proposed):

- `policy`: one of `"smolvla" | "pi0"`
- `checkpoint_path`: local path to the policy checkpoint (e.g., `outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model`)
- `scene`: optional scene to connect the robot (default `"default"`)
- `device`: `"cuda" | "cpu"` (default `"cuda"`)
- `max_steps`: integer safety cap on loop iterations (default `500`)
- `enable_realsense`, `fps`, `debug`: same semantics as `stroke_tool`
- `record_eval`: bool to write an eval dataset (default `false`)
- `dry_run`: bool to validate the checkpoint/scene without moving the robot (default `false`)
- `task_prompt`: optional task text passed to the policy (default `null`)
Output model fields (proposed):

- `success`: bool
- `message`: str
- `num_steps`: int
- Optional: `eval_dir`: path to the saved eval dataset
Code skeleton for the tool:

```python
# src/tatbot/tools/robot/models_vla.py
from typing import Literal, Optional

from tatbot.tools.base import ToolInput, ToolOutput


class VLAInferInput(ToolInput):
    policy: Literal["smolvla", "pi0"]
    checkpoint_path: str
    scene: str = "default"
    device: Literal["cuda", "cpu"] = "cuda"
    max_steps: int = 500
    enable_realsense: bool = False
    fps: int = 10
    debug: bool = False
    record_eval: bool = False
    dry_run: bool = False
    task_prompt: Optional[str] = None


class VLAInferOutput(ToolOutput):
    success: bool = True
    message: str = ""
    num_steps: int = 0
    eval_dir: Optional[str] = None
```
```python
# src/tatbot/tools/robot/infer_vla.py
import time
from pathlib import Path

import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.robots import Robot, make_robot_from_config
from lerobot.robots.tatbot.config_tatbot import TatbotConfig
from lerobot.utils.robot_utils import busy_wait

from tatbot.main import compose_and_validate_scene
from tatbot.tools.base import ToolContext
from tatbot.tools.registry import tool
from tatbot.tools.robot.models_vla import VLAInferInput, VLAInferOutput
from tatbot.utils.log import get_logger

log = get_logger("tools.vla_infer", "🧠")


@tool(
    name="vla_infer",
    nodes=["hog"],  # add GPU nodes if you want remote-only
    description="Run VLA policy inference on Tatbot from a checkpoint",
    input_model=VLAInferInput,
    output_model=VLAInferOutput,
)
async def vla_infer(input_data: VLAInferInput, ctx: ToolContext):
    try:
        yield {"progress": 0.01, "message": "Loading scene..."}
        scene = compose_and_validate_scene(input_data.scene)

        # Configure cameras if enabled
        rs_cameras = {}
        if input_data.enable_realsense:
            from lerobot.cameras.realsense import RealSenseCameraConfig
            rs_cameras = {
                cam.name: RealSenseCameraConfig(
                    fps=cam.fps, width=cam.width, height=cam.height,
                    serial_number_or_name=cam.serial_number,
                ) for cam in scene.cams.realsenses
            }

        robot: Robot = make_robot_from_config(TatbotConfig(
            ip_address_l=scene.arms.ip_address_l,
            ip_address_r=scene.arms.ip_address_r,
            arm_l_config_filepath=scene.arms.arm_l_config_filepath,
            arm_r_config_filepath=scene.arms.arm_r_config_filepath,
            goal_time=scene.arms.goal_time_slow,
            connection_timeout=scene.arms.connection_timeout,
            home_pos_l=scene.sleep_pos_l.joints,
            home_pos_r=scene.sleep_pos_r.joints,
            rs_cameras=rs_cameras,
            ip_cameras={},
        ))

        yield {"progress": 0.05, "message": "Connecting to robot..."}
        robot.connect()

        # Load policy
        yield {"progress": 0.1, "message": "Loading policy checkpoint..."}
        if input_data.policy == "smolvla":
            from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
            policy = SmolVLAPolicy.from_pretrained(input_data.checkpoint_path)
        else:
            from lerobot.policies.pi0.modeling_pi0 import PI0Policy
            policy = PI0Policy.from_pretrained(input_data.checkpoint_path)
        policy.eval()
        policy.to(input_data.device)

        # Optional dry-run: validate loading without moving hardware
        if input_data.dry_run:
            yield VLAInferOutput(success=True, message="Loaded scene and policy (dry run)", num_steps=0)
            return

        # Optional eval recording
        dataset = None
        if input_data.record_eval:
            output_dir = Path("/nfs/tatbot/recordings")
            eval_dir = output_dir / f"vla-eval-{scene.name}-{int(time.time())}"
            eval_dir.mkdir(parents=True, exist_ok=True)
            action_features = hw_to_dataset_features(robot.action_features, "action", True)
            obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
            # Align writer threads with stroke.py convention
            num_camera_threads = 0
            if hasattr(robot, "rs_cameras") and len(robot.rs_cameras) > 0:
                num_camera_threads += 4 * len(robot.rs_cameras)
            if hasattr(robot, "ip_cameras") and len(robot.ip_cameras) > 0:
                num_camera_threads += 4 * len(robot.ip_cameras)
            dataset = LeRobotDataset.create(
                repo_id=f"tatbot/{eval_dir.name}", fps=input_data.fps,
                root=str(eval_dir), robot_type=robot.name,
                features={**action_features, **obs_features},
                use_videos=True, image_writer_processes=0, image_writer_threads=num_camera_threads,
            )

        # Move to ready
        robot.send_action(robot._urdf_joints_to_action(scene.ready_pos_full.joints), safe=True)

        yield {"progress": 0.2, "message": "Starting inference loop..."}
        num_steps = 0
        dt_target = 1.0 / max(1, input_data.fps)

        def preprocess_observation_for_policy(observation):
            """Map the Tatbot observation to the policy-expected dict. Include task text if needed."""
            # TODO: implement mapping for the chosen policy; pass input_data.task_prompt
            return observation

        def prepare_robot_action(policy_output):
            """Convert policy output to the robot action format if necessary."""
            try:
                # If the policy outputs 14D URDF joints (left 7 + right 7):
                return robot._urdf_joints_to_action(policy_output)
            except Exception:
                return policy_output

        try:
            while num_steps < input_data.max_steps:
                t0 = time.perf_counter()
                observation = robot.get_observation()
                policy_obs = preprocess_observation_for_policy(observation)
                with torch.no_grad():
                    action = policy.select_action(policy_obs) if hasattr(policy, "select_action") else policy(policy_obs)
                # Send action (fast goal time for continuous control)
                sent_action = robot.send_action(prepare_robot_action(action), scene.arms.goal_time_fast)
                if dataset is not None:
                    obs_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
                    act_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
                    dataset.add_frame({**obs_frame, **act_frame})
                num_steps += 1
                # FPS pacing (busy-wait for precision)
                dt = time.perf_counter() - t0
                busy_wait(dt_target - dt)
            if dataset is not None:
                dataset.save_episode()
        finally:
            robot.send_action(robot._urdf_joints_to_action(scene.ready_pos_full.joints), safe=True)
            robot.disconnect()

        yield VLAInferOutput(success=True, message="Inference completed", num_steps=num_steps,
                             eval_dir=str(dataset.root) if dataset is not None else None)
    except Exception as e:
        log.error(f"vla_infer failed: {e}")
        yield VLAInferOutput(success=False, message=f"❌ {e}", num_steps=0)
```
## Registering and running the MCP tool

Import the new tool so it auto-registers: update `register_all_tools()` in `src/tatbot/tools/registry.py` to include (import path shown for the `robot/` layout chosen above):

```python
# Import VLA tools
try:
    from tatbot.tools.robot import infer_vla  # noqa: F401
    log.debug("Imported VLA tools")
except ImportError as e:
    log.debug(f"VLA tools not available: {e}")
```
Ensure the node config includes this tool (or allow a wildcard):

- Edit `src/conf/mcp/hog.yaml` and/or GPU node configs to include `vla_infer` in `mcp.tools`.
- If the tool requires a GPU, add `requires=["gpu"]` to the decorator and ensure the node has `extras: [gpu]`.

Restart the MCP server on the node:

```bash
./scripts/mcp_run.sh hog
```
Invoke the tool from your MCP client with input JSON like:

```json
{
  "policy": "smolvla",
  "checkpoint_path": "outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model",
  "scene": "tatbotlogo",
  "device": "cuda",
  "max_steps": 500,
  "enable_realsense": true,
  "fps": 10,
  "record_eval": true
}
```
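For a scripted client, the official MCP Python SDK can call the tool; a sketch (the stdio transport and server command here are assumptions — adapt to however `mcp_run.sh` actually exposes the server):

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

async def main() -> None:
    # Assumed invocation; your MCP server may use a different transport (e.g., SSE/HTTP).
    params = StdioServerParameters(command="./scripts/mcp_run.sh", args=["hog"])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            result = await session.call_tool("vla_infer", {
                "policy": "smolvla",
                "checkpoint_path": "outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model",
                "dry_run": True,  # validate checkpoint/scene without moving the robot
            })
            print(result.content)

asyncio.run(main())
```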
## Operational tips

- Always verify the robot is clear to move before inference runs.
- Use conservative `goal_time_*` and `max_steps` values initially.
- Start with `device=cpu` for dry-runs if GPU memory is tight, then switch to `cuda`.
- For reproducibility, snapshot your `scene.yaml` and checkpoint path in run metadata (WandB or the eval dataset name).
References:

- Existing end-to-end guide: `docs/models/claude_vla_guide.md`
- Stroke recording tool: `src/tatbot/tools/robot/stroke.py`
- Stroke datatypes: `src/tatbot/data/stroke.py`