🚀 Tatbot VLA Plan 3 — Training on stroke_tool datasets and MCP inference tool

This guide specifies what's needed to train Vision-Language-Action (VLA) policies on datasets recorded via src/tatbot/tools/robot/stroke.py, and how to add an MCP tool to run inference from a specific model checkpoint.

Table of Contents

  1. Scope

  2. Dataset recorded by stroke_tool

  3. Training requirements

  4. Preparing datasets for training

  5. Policy training configs

  6. Validating the pipeline

  7. MCP inference tool: design and code skeleton

  8. Registering and running the MCP tool

  9. Operational tips

Scope

  • Train VLA policies (e.g., SmolVLA, π0) using data recorded by stroke_tool.

  • Keep the LeRobot-native dataset format produced by stroke_tool to avoid conversion steps.

  • Provide an MCP tool to load a chosen checkpoint and run inference on the robot.

Dataset recorded by stroke_tool

stroke_tool writes LeRobot-compatible episodic datasets into /nfs/tatbot/recordings/ with names like stroke-<scene>-<timestamp>. Within each dataset directory:

  • scene.yaml: Scene definition saved at recording start.

  • strokes.yaml: Stroke list with metadata; large arrays are in arrays/*.npy (see tatbot.data.stroke).

  • strokebatch.safetensors: Packed joint trajectories and EE poses (tatbot.data.stroke.StrokeBatch).

  • logs/episode_*.txt: Per-episode logs.

  • episode_000000/, episode_000001/, …: Episode folders created by LeRobotDataset, including recorded frames and episode_cond with references to stroke_l/stroke_r metadata and any preview images.

Example directory structure (may vary slightly by LeRobot version/settings):

/nfs/tatbot/recordings/
└── stroke-{scene}-{timestamp}/
    ├── meta_data/
    │   ├── data.parquet
    │   ├── stats.json
    │   └── info.json
    ├── videos/
    │   ├── observation.image_{camera_name}_{episode:06d}.mp4
    │   └── observation.depth_{camera_name}_{episode:06d}.mp4
    ├── logs/
    │   └── episode_{episode:06d}.txt
    ├── episode_{episode:06d}/
    │   ├── stroke_l.png
    │   └── stroke_r.png
    ├── scene.yaml
    ├── strokes.yaml
    └── strokebatch.safetensors

Key details from src/tatbot/tools/robot/stroke.py:

  • The dataset is created (or resumed) via LeRobotDataset.create(...)/LeRobotDataset(...) with features derived from robot.action_features and robot.observation_features.

  • When RealSense/IP cameras are enabled, images are written through LeRobot's image writer threads.

  • fps defaults to 10 (configurable via tool input).

  • Each pose step adds a frame composed of observation and the action actually sent to the robot.

Implication: these datasets are immediately usable by LeRobot training scripts (no custom converter required).
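
For a quick sanity check, a minimal sketch that opens one recording directly (attribute names such as num_frames may vary across LeRobot versions):

from lerobot.datasets.lerobot_dataset import LeRobotDataset

root = "/nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s"
ds = LeRobotDataset(f"tatbot/{root.rsplit('/', 1)[-1]}", root=root)
print(ds.num_episodes, ds.num_frames, ds.fps)  # episode/frame counts and recording rate
print(list(ds.features))                       # action/observation keys recorded by stroke_tool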

Training requirements

  • Python environment with Tatbot installed via uv. Install LeRobot separately (in its own checkout) if you need policy code/extras:

source scripts/setup_env.sh
uv pip install -e .
uv pip install -e .[bot,cam,gen,gpu,img,viz,dev,docs]
# If training with LeRobot policies, do this in your LeRobot repo (not here):
#   cd ~/lerobot && uv pip install -e .[smolvla,pi0]
set -a; source .env; set +a  # optional secrets for WandB, etc.
  • GPU recommended for training; CPU-only is possible for debugging.

  • WandB optional; install/enable explicitly in your training environment.

Preparing datasets for training

You can train directly from a locally recorded dataset directory. Two common options:

  • Local path training (recommended initially):

    • Use the full path to a recording directory, e.g. /nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s.

    • Many LeRobot CLIs accept --dataset.root or a repo_id that points locally; prefer --dataset.root where available.

  • Pushing to Hub (optional):

    • If desired, push the dataset to the Hugging Face Hub using LeRobot's dataset utilities.
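
A hedged sketch of pushing the example recording, assuming Hugging Face credentials are configured (push_to_hub kwargs vary by LeRobot version):

from lerobot.datasets.lerobot_dataset import LeRobotDataset

root = "/nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s"
ds = LeRobotDataset(f"tatbot/{root.rsplit('/', 1)[-1]}", root=root)
ds.push_to_hub(private=True)  # requires huggingface-cli login or HF_TOKEN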

Minimum checks before training:

  • Confirm strokes.yaml and strokebatch.safetensors exist.

  • Confirm episodes exist and contain frames and actions.

  • Skim logs/ for anomalies; ensure joint/action ranges and fps look correct.
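
A minimal pre-flight sketch of these checks (the meta_data/ path follows the example tree above; newer LeRobot versions use meta/ instead):

import json
from pathlib import Path

ds_dir = Path("/nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s")
assert (ds_dir / "strokes.yaml").exists(), "missing strokes.yaml"
assert (ds_dir / "strokebatch.safetensors").exists(), "missing strokebatch.safetensors"
episodes = sorted(ds_dir.glob("episode_*"))
assert episodes, "no episode folders found"
info = json.loads((ds_dir / "meta_data" / "info.json").read_text())
print(f"{len(episodes)} episodes, fps={info.get('fps')}")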

Aggregating multiple recordings (optional):

from pathlib import Path
from lerobot.datasets.lerobot_dataset import LeRobotDataset

recordings_dir = Path("/nfs/tatbot/recordings")
datasets = []
for dataset_dir in recordings_dir.glob("stroke-*"):
    repo_id = f"tatbot/{dataset_dir.name}"
    datasets.append(LeRobotDataset(repo_id, root=str(dataset_dir)))
# Train across multiple datasets or merge per your pipeline

Policy training configs

This guide uses the same policy families as docs/models/claude_vla_guide.md.

  • Choose policy: smolvla for faster iteration or pi0 as needed.

  • Observation length vs. action chunking:

    • Typical: n_obs_steps = 1, chunk_size = 50, n_action_steps = 50 to match stroke pose steps.

    • If training on sequences spanning multiple poses, adjust accordingly.

  • Image size and preprocessing: ensure preprocessing matches the camera output and the policy's expected input (e.g., 512×512 with padding for SmolVLA); a sketch follows below.
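
A sketch of the pad-then-resize geometry using torchvision (training pipelines normally handle this internally; shown only to illustrate the target shape):

import torch
import torchvision.transforms.v2.functional as F

def pad_and_resize(img: torch.Tensor, size: int = 512) -> torch.Tensor:
    # Pad a CHW image to a square canvas, then resize to (size, size).
    _, h, w = img.shape
    side = max(h, w)
    left, top = (side - w) // 2, (side - h) // 2
    img = F.pad(img, [left, top, side - w - left, side - h - top])  # [left, top, right, bottom]
    return F.resize(img, [size, size])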

Example commands (adjust flags to your CLI wrapper; standardize outputs under outputs/train/):

SmolVLA finetune from base on a local dataset root:

lerobot-train \
  --policy.type=smolvla \
  --dataset.root="/nfs/tatbot/recordings/stroke-tatbotlogo-..." \
  --batch_size=32 \
  --steps=100000 \
  --wandb.enable=true \
  --wandb.project=tatbot_smolvla_finetune \
  --output_dir=outputs/train/smolvla_tatbot

Pi0 finetune from base:

lerobot-train \
  --policy.type=pi0 \
  --dataset.root="/nfs/tatbot/recordings/stroke-tatbotlogo-..." \
  --batch_size=32 \
  --steps=100000 \
  --wandb.enable=true \
  --wandb.project=tatbot_pi0_finetune \
  --output_dir=outputs/train/pi0_tatbot

Notes:

  • Prefer --dataset.root for local datasets; use --dataset.repo_id only if pushing to Hub.

  • Do not assume fixed chunk_size/n_action_steps; align with actual scene.stroke_length and model config.

  • Keep evaluation split: either reserve episodes for validation or use --dataset.split.* flags where available.
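
One way to reserve held-out episodes, assuming the episodes= argument of LeRobotDataset (available in recent versions):

from lerobot.datasets.lerobot_dataset import LeRobotDataset

root = "/nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s"
repo_id = f"tatbot/{root.rsplit('/', 1)[-1]}"
num_episodes = 10  # hypothetical count; read it from the dataset metadata in practice
train_ds = LeRobotDataset(repo_id, root=root, episodes=list(range(num_episodes - 2)))
val_ds = LeRobotDataset(repo_id, root=root, episodes=[num_episodes - 2, num_episodes - 1])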

Validating the pipeline

  • Dry-run a few training steps and check WandB/logs for:

    • Decreasing loss, stable gradient norms, and reasonable GPU utilization.

    • Sampled batches show correct image shapes and action ranges.

  • Save and test a checkpoint by running local evaluation on a held-out episode set.
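
A minimal offline-evaluation sketch on a held-out episode (batch handling is schematic; exact keys and the need for a "task" string depend on the trained policy and LeRobot version):

import torch
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy

policy = SmolVLAPolicy.from_pretrained("outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model")
policy.eval()
root = "/nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s"
ds = LeRobotDataset(f"tatbot/{root.rsplit('/', 1)[-1]}", root=root, episodes=[9])  # assumed held-out episode

errors = []
with torch.no_grad():
    for i in range(len(ds)):
        frame = ds[i]
        batch = {k: v.unsqueeze(0) for k, v in frame.items() if isinstance(v, torch.Tensor)}
        batch["task"] = frame.get("task", "")  # language conditioning, if the policy expects it
        pred = policy.select_action(batch)
        errors.append(torch.nn.functional.mse_loss(pred.squeeze(0).cpu(), frame["action"]))
print(f"mean per-step action MSE: {torch.stack(errors).mean().item():.6f}")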

MCP inference tool: design and code skeleton

We will add a new MCP tool vla_infer that:

  • Loads a specified checkpoint directory into the chosen policy class.

  • Builds a Tatbot instance from a scene, connects, streams observations, and sends policy actions in a timed loop.

  • Optionally records an evaluation dataset for later analysis.

Suggested file layout (choose one and keep consistent):

  • Place under src/tatbot/tools/robot/ (recommended, like other robot tools):

    • src/tatbot/tools/robot/models_vla.py

    • src/tatbot/tools/robot/infer_vla.py

  • Or create a dedicated package src/tatbot/tools/vla/ and import it in the registry. This guide shows the robot/ layout.

Input model fields (proposed):

  • policy: one of "smolvla" | "pi0"

  • checkpoint_path: local path to policy checkpoint (e.g., outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model)

  • scene: optional scene to connect the robot (default "default")

  • device: "cuda" | "cpu" (default "cuda")

  • max_steps: integer safety cap on loop iterations (default 500)

  • enable_realsense, fps, debug: same semantics as stroke_tool

  • record_eval: bool to write an eval dataset (default false)

  • dry_run: bool to validate checkpoint/scene without moving robot (default false)

Output model fields (proposed):

  • success: bool

  • message: str

  • num_steps: int

  • Optional: eval_dir: path to saved eval dataset

Code skeleton for the tool:

# src/tatbot/tools/robot/models_vla.py
from typing import Literal, Optional
from tatbot.tools.base import ToolInput, ToolOutput

class VLAInferInput(ToolInput):
    policy: Literal["smolvla", "pi0"]
    checkpoint_path: str
    scene: str = "default"
    device: Literal["cuda", "cpu"] = "cuda"
    max_steps: int = 500
    enable_realsense: bool = False
    fps: int = 10
    debug: bool = False
    record_eval: bool = False
    dry_run: bool = False
    task_prompt: Optional[str] = None

class VLAInferOutput(ToolOutput):
    success: bool = True
    message: str = ""
    num_steps: int = 0
    eval_dir: Optional[str] = None

# src/tatbot/tools/robot/infer_vla.py
import time
import torch
from pathlib import Path
from lerobot.robots import Robot, make_robot_from_config
from lerobot.robots.tatbot.config_tatbot import TatbotConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.robot_utils import busy_wait
from tatbot.main import compose_and_validate_scene
from tatbot.tools.base import ToolContext
from tatbot.tools.registry import tool
from tatbot.tools.robot.models_vla import VLAInferInput, VLAInferOutput
from tatbot.utils.log import get_logger

log = get_logger("tools.vla_infer", "🧠")

@tool(
    name="vla_infer",
    nodes=["hog"],  # add GPU nodes if you want remote-only
    description="Run VLA policy inference on Tatbot from a checkpoint",
    input_model=VLAInferInput,
    output_model=VLAInferOutput,
)
async def vla_infer(input_data: VLAInferInput, ctx: ToolContext):
    try:
        yield {"progress": 0.01, "message": "Loading scene..."}
        scene = compose_and_validate_scene(input_data.scene)

        # Configure cameras if enabled
        rs_cameras = {}
        if input_data.enable_realsense:
            from lerobot.cameras.realsense import RealSenseCameraConfig
            rs_cameras = {
                cam.name: RealSenseCameraConfig(
                    fps=cam.fps, width=cam.width, height=cam.height,
                    serial_number_or_name=cam.serial_number,
                ) for cam in scene.cams.realsenses
            }

        robot: Robot = make_robot_from_config(TatbotConfig(
            ip_address_l=scene.arms.ip_address_l,
            ip_address_r=scene.arms.ip_address_r,
            arm_l_config_filepath=scene.arms.arm_l_config_filepath,
            arm_r_config_filepath=scene.arms.arm_r_config_filepath,
            goal_time=scene.arms.goal_time_slow,
            connection_timeout=scene.arms.connection_timeout,
            home_pos_l=scene.sleep_pos_l.joints,
            home_pos_r=scene.sleep_pos_r.joints,
            rs_cameras=rs_cameras,
            ip_cameras={},
        ))

        yield {"progress": 0.05, "message": "Connecting to robot..."}
        robot.connect()

        # Load policy
        yield {"progress": 0.1, "message": "Loading policy checkpoint..."}
        if input_data.policy == "smolvla":
            from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
            policy = SmolVLAPolicy.from_pretrained(input_data.checkpoint_path)
        else:
            from lerobot.policies.pi0.modeling_pi0 import PI0Policy
            policy = PI0Policy.from_pretrained(input_data.checkpoint_path)
        policy.eval()
        policy.to(input_data.device)

        # Optional dry-run: validate loading without moving hardware
        if input_data.dry_run:
            yield VLAInferOutput(success=True, message="Loaded scene and policy (dry run)", num_steps=0)
            return

        # Optional eval recording
        dataset = None
        if input_data.record_eval:
            output_dir = Path("/nfs/tatbot/recordings")
            eval_dir = output_dir / f"vla-eval-{scene.name}-{int(time.time())}"
            eval_dir.mkdir(parents=True, exist_ok=True)
            action_features = hw_to_dataset_features(robot.action_features, "action", True)
            obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
            # Align writer threads with stroke.py convention
            num_camera_threads = 0
            if hasattr(robot, "rs_cameras") and len(robot.rs_cameras) > 0:
                num_camera_threads += 4 * len(robot.rs_cameras)
            if hasattr(robot, "ip_cameras") and len(robot.ip_cameras) > 0:
                num_camera_threads += 4 * len(robot.ip_cameras)
            dataset = LeRobotDataset.create(
                repo_id=f"tatbot/{eval_dir.name}", fps=input_data.fps,
                root=str(eval_dir), robot_type=robot.name,
                features={**action_features, **obs_features},
                use_videos=True, image_writer_processes=0, image_writer_threads=num_camera_threads,
            )

        # Move to ready
        robot.send_action(robot._urdf_joints_to_action(scene.ready_pos_full.joints), safe=True)

        yield {"progress": 0.2, "message": "Starting inference loop..."}
        num_steps = 0
        dt_target = 1.0 / max(1, input_data.fps)
        
        def preprocess_observation_for_policy(observation):
            """Map Tatbot observation to policy-expected dict. Include task text if needed."""
            # TODO: implement mapping for chosen policy; pass input_data.task_prompt
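            # A hypothetical sketch (verify keys against the policy's training-time features):
            #   obs = {k: torch.as_tensor(v).unsqueeze(0).to(input_data.device)
            #          for k, v in observation.items() if not isinstance(v, str)}
            #   obs["task"] = input_data.task_prompt or ""  # SmolVLA/pi0 condition on a text instruction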
            return observation

        def prepare_robot_action(policy_output):
            """Convert policy output to robot action format if necessary."""
            try:
                # If policy outputs 14D URDF joints (left7+right7):
                return robot._urdf_joints_to_action(policy_output)
            except Exception:
                return policy_output

        try:
            while num_steps < input_data.max_steps:
                t0 = time.perf_counter()
                observation = robot.get_observation()
                policy_obs = preprocess_observation_for_policy(observation)
                with torch.no_grad():
                    action = policy.select_action(policy_obs) if hasattr(policy, "select_action") else policy(policy_obs)
                # Send action (fast time for continuous control)
                sent_action = robot.send_action(prepare_robot_action(action), scene.arms.goal_time_fast)

                if dataset is not None:
                    obs_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
                    act_frame = build_dataset_frame(dataset.features, sent_action, prefix="action")
                    dataset.add_frame({**obs_frame, **act_frame})

                num_steps += 1
                # FPS pacing (busy-wait for precision)
                dt = time.perf_counter() - t0
                busy_wait(dt_target - dt)

            if dataset is not None:
                dataset.save_episode()

        finally:
            robot.send_action(robot._urdf_joints_to_action(scene.ready_pos_full.joints), safe=True)
            robot.disconnect()

        yield VLAInferOutput(success=True, message="Inference completed", num_steps=num_steps,
                             eval_dir=str(dataset.root) if dataset is not None else None)

    except Exception as e:
        log.error(f"vla_infer failed: {e}")
        yield VLAInferOutput(success=False, message=f"❌ {e}", num_steps=0)

Registering and running the MCP tool

  1. Import the new tool so it auto-registers:

  • Update src/tatbot/tools/registry.py register_all_tools() to include:

# Import VLA tools (path matches the recommended robot/ layout above)
try:
    from tatbot.tools.robot import infer_vla  # noqa: F401
    log.debug("Imported VLA tools")
except ImportError as e:
    log.debug(f"VLA tools not available: {e}")
  2. Ensure the node config includes this tool (or allow wildcard):

  • Edit src/conf/mcp/hog.yaml and/or GPU nodes to include vla_infer in mcp.tools.

  • If the tool requires GPU, add requires=["gpu"] to the decorator and ensure the node has extras: [gpu].

  3. Restart the MCP server on the node:

./scripts/mcp_run.sh hog
  4. Invoke the tool from your MCP client with input JSON like:

{
  "policy": "smolvla",
  "checkpoint_path": "outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model",
  "scene": "tatbotlogo",
  "device": "cuda",
  "max_steps": 500,
  "enable_realsense": true,
  "fps": 10,
  "record_eval": true
}

Operational tips

  • Always verify the robot is clear to move before inference runs.

  • Use conservative goal_time_* and max_steps initially.

  • Start with device=cpu for dry-runs if GPU memory is tight, then switch to cuda.

  • For reproducibility, snapshot your scene.yaml and checkpoint path in run metadata (WandB or eval dataset name).
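
A sketch of that snapshot with WandB (the project name here is a placeholder):

import wandb

root = "/nfs/tatbot/recordings/stroke-tatbotlogo-2025y-08m-09d-17h-02m-10s"
run = wandb.init(project="tatbot_vla_eval", config={
    "scene": "tatbotlogo",
    "checkpoint_path": "outputs/train/smolvla_tatbot/checkpoints/last/pretrained_model",
})
wandb.save(f"{root}/scene.yaml", base_path=root)  # attach the exact scene definition used
run.finish()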

References:

  • Existing end-to-end guide: docs/models/claude_vla_guide.md

  • Stroke recording tool: src/tatbot/tools/robot/stroke.py

  • Stroke datatypes: src/tatbot/data/stroke.py