# @package _global_

# Use `act_aloha_solo_real.yaml` to train on real-world datasets collected on Aloha or Aloha-2 robots.
# It uses 2 cameras: one wrist camera (cam_right_wrist or cam_left_wrist) and cam_high.
# Also, `training.eval_freq` is set to -1. This parameter controls how often (in training steps) checkpoints are evaluated.
# When it is set to -1, it deactivates evaluation. This is because real-world evaluation is done through our `control_robot.py` script.
# Look at the documentation in the header of `control_robot.py` for more information on how to collect data, train and evaluate a policy.
#
# Example of usage for training and inference with `control_robot.py`:
# ```bash
# python lerobot/scripts/train.py \
#   policy=act_aloha_solo_real \
#   env=aloha_solo_real
# ```
#
# Example of usage for training and inference with [Dora-rs](https://github.com/dora-rs/dora-lerobot):
# ```bash
# python lerobot/scripts/train.py \
#   policy=act_aloha_solo_real \
#   env=dora_aloha_solo_real
# ```

# Run seed (presumably used for RNG reproducibility; confirm against the training script).
seed: 1000
# Dataset identifier in `<owner>/<name>` form (presumably a hub repository id; confirm).
dataset_repo_id: "lerobot/aloha_solo_lego"

override_dataset_stats:
  # Every image key below uses the ImageNet statistics, since the vision model is pretrained.
  # Each list holds one value per channel and is laid out as (c,1,1).
  observation.images.cam_right_wrist:
    mean:
      - [[0.485]]
      - [[0.456]]
      - [[0.406]]
    std:
      - [[0.229]]
      - [[0.224]]
      - [[0.225]]
  # To use the left wrist camera instead, uncomment this entry and comment out the right wrist one.
  # observation.images.cam_left_wrist:
  #   mean:
  #     - [[0.485]]
  #     - [[0.456]]
  #     - [[0.406]]
  #   std:
  #     - [[0.229]]
  #     - [[0.224]]
  #     - [[0.225]]
  observation.images.cam_high:
    mean:
      - [[0.485]]
      - [[0.456]]
      - [[0.406]]
    std:
      - [[0.229]]
      - [[0.224]]
      - [[0.225]]


training:
  # Only offline steps are scheduled; the number of online steps is 0.
  offline_steps: 80000
  online_steps: 0
  # -1 deactivates evaluation during training; real-world evaluation is done
  # through the `control_robot.py` script instead.
  eval_freq: -1
  save_freq: 10000
  log_freq: 100
  save_checkpoint: true

  batch_size: 8
  # Floats in scientific notation are written with an explicit decimal point:
  # YAML 1.1 loaders (e.g. PyYAML) read `1e-5` as a string, whereas `1.0e-5`
  # is parsed as the same float by every loader.
  lr: 1.0e-5
  lr_backbone: 1.0e-5
  weight_decay: 1.0e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  delta_timestamps:
    # String holding a list-comprehension expression (presumably evaluated by the
    # training code; confirm). Once `${fps}` and `${policy.chunk_size}` are
    # interpolated it describes `chunk_size` timestamps: 0, 1/fps, 2/fps, ...
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"

# Evaluation settings. `training.eval_freq` is -1 in this config, which deactivates
# evaluation during training, so these values are presumably unused there (confirm).
eval:
  batch_size: 50
  n_episodes: 50

# ACT policy settings for a setup with one wrist camera and one high camera.
# See `configuration_act.py` for more details.
policy:
  name: act

  # Input / output structure.
  n_obs_steps: 1
  # `chunk_size` is also interpolated into `training.delta_timestamps.action`,
  # where it sets how many action timestamps are listed.
  chunk_size: 100
  n_action_steps: 100

  # Image entries are [3, 480, 640], presumably (channels, height, width); confirm.
  # The state and action sizes are interpolated from the env config.
  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.images.cam_right_wrist: [3, 480, 640]
    # observation.images.cam_left_wrist: [3, 480, 640] # Uncomment if using left wrist and comment the right.
    observation.images.cam_high: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  # Every input and output uses `mean_std`. The image keys listed here match the
  # keys whose statistics are set to ImageNet values in `override_dataset_stats`.
  input_normalization_modes:
    observation.images.cam_right_wrist: mean_std
    # observation.images.cam_left_wrist: mean_std # Uncomment if using left wrist and comment the right.
    observation.images.cam_high: mean_std
    observation.state: mean_std
  output_normalization_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
  # ResNet-18 with the ImageNet-pretrained `IMAGENET1K_V1` weights.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3200
  feedforward_activation: relu
  n_encoder_layers: 4
  # Note: Although the original ACT implementation has 7 for `n_decoder_layers`, there is a bug in the code
  # that means only the first layer is used. Here we match the original implementation by setting this to 1.
  # See this issue https://github.com/tonyzhaozh/act/issues/25#issue-2258740521.
  n_decoder_layers: 1
  # VAE.
  use_vae: true
  latent_dim: 32
  n_vae_encoder_layers: 4

  # Inference.
  # NOTE(review): `null` presumably disables temporal ensembling; confirm in `configuration_act.py`.
  temporal_ensemble_momentum: null

  # Training and loss computation.
  dropout: 0.1
  kl_weight: 10.0
