Masking and Precheck
SMDPfier provides sophisticated action masking and precheck validation to handle invalid options gracefully and ensure robust option execution.
Action Masking
Action masking restricts which options are available based on the current environment state. This is essential for handling state-dependent action validity in complex environments.
Key Concepts
- Availability Function: Determines which option indices are valid
- Action Mask: Binary array indicating available options
- Index Interface Only: Masking only works with `action_interface="index"`
- Dynamic Evaluation: Mask is recomputed every step
Basic Action Masking
import gymnasium as gym
from smdpfier import SMDPfier, Option
from smdpfier.defaults import ConstantOptionDuration
# CartPole options; the list position is the option index that masks refer to.
# (CartPole primitive actions: 0 = push left, 1 = push right.)
options = [
    Option([0, 0], "strong-left"),   # index 0: left, left
    Option([1, 1], "strong-right"),  # index 1: right, right
    Option([0, 1], "left-right"),    # index 2: left, then right
    Option([1, 0], "right-left"),    # index 3: right, then left
]
def cart_availability(obs):
    """Restrict the strong options based on the cart position ``obs[0]``.

    Args:
        obs: CartPole observation; only ``obs[0]`` (cart position) is read.

    Returns:
        List of available option indices. The balanced options 2 and 3 are
        always included. Strong-right (1) is added while the cart is left of
        x = 0.3, and strong-left (0) while it is right of x = -0.3.
    """
    position = obs[0]
    available = [2, 3]  # balanced options: left-right, right-left
    if position < 0.3:  # cart not too far right -> strong-right is safe
        available.append(1)
    if position > -0.3:  # cart not too far left -> strong-left is safe
        available.append(0)
    return available
# Wrap CartPole so the agent chooses option indices, filtered by cart_availability
cartpole = gym.make("CartPole-v1")
env = SMDPfier(
    cartpole,
    options_provider=options,
    duration_fn=ConstantOptionDuration(5),
    action_interface="index",
    max_options=4,
    availability_fn=cart_availability,
)

obs, info = env.reset()
print(f"Available options: {info['smdp']['action_mask']}")
# Might show [1, 1, 1, 1] (all available), [0, 1, 1, 1] (strong-left masked,
# cart far left) or [1, 0, 1, 1] (strong-right masked, cart far right)
Action Mask Structure
The action mask is a binary list where 1 means available and 0 means masked:
action_mask = [1, 0, 1, 0]  # 1 = available, 0 = masked: options 0 and 2 may be chosen, 1 and 3 may not
Usage in RL algorithms:
import numpy as np

obs, info = env.reset()
action_mask = info["smdp"].get("action_mask")
# In your RL algorithm (q_values holds one Q-value per option index):
if action_mask is not None:
    # Convert the mask to an array before comparing: `[1, 0] == 0` on a plain
    # list is just `False`, and `q[False] = -inf` would mask nothing.
    masked_q_values = np.array(q_values, dtype=float)  # float copy; -inf needs a float dtype
    masked_q_values[np.asarray(action_mask) == 0] = -np.inf
    action = int(np.argmax(masked_q_values))
else:
    action = int(np.argmax(q_values))
Complex Masking Examples
Environment-Specific Masking
# Create the base Taxi env first so the availability function can decode *its*
# states. Referring to a global `env` here would pick up whatever environment
# was last bound to that name (e.g. the CartPole example above).
taxi_base = gym.make("Taxi-v3")


def taxi_availability(obs):
    """Mask Taxi-v3 options so pickup/dropoff are offered only where they apply.

    Option indices equal Taxi primitive actions (see ``taxi_options`` below).

    Args:
        obs: Encoded Taxi-v3 state (decoded via ``taxi_base.unwrapped.decode``).

    Returns:
        List of available option indices. Movement (0-3) is always included.
        Pickup (4) is added when the waiting passenger is on the taxi's cell.
        Dropoff (5) is added when the passenger is in the taxi at the
        destination.
    """
    taxi = taxi_base.unwrapped
    taxi_row, taxi_col, passenger_loc, destination = taxi.decode(obs)
    taxi_pos = (taxi_row, taxi_col)

    available = [0, 1, 2, 3]  # south, north, east, west (always available)

    # Pickup: passenger still waiting at a landmark (< 4) that is the taxi's cell
    if passenger_loc < 4 and taxi_pos == tuple(taxi.locs[passenger_loc]):
        available.append(4)

    # Dropoff: passenger in the taxi (== 4) and the taxi is at the destination
    if passenger_loc == 4 and taxi_pos == tuple(taxi.locs[destination]):
        available.append(5)

    return available


# Taxi options: one primitive action each, so option index == Taxi action
taxi_options = [
    Option([0], "south"),    # Index 0
    Option([1], "north"),    # Index 1
    Option([2], "east"),     # Index 2
    Option([3], "west"),     # Index 3
    Option([4], "pickup"),   # Index 4
    Option([5], "dropoff"),  # Index 5
]

taxi_env = SMDPfier(
    taxi_base,
    options_provider=taxi_options,
    duration_fn=ConstantOptionDuration(1),
    action_interface="index",
    max_options=6,
    availability_fn=taxi_availability,
)
State-Dependent Option Length
def adaptive_availability(obs):
    """Pick which option lengths are allowed from the cart velocity ``obs[1]``.

    When ``|obs[1]| > 0.5`` only the single-action options (0, 1) are returned,
    so the agent can correct quickly. Otherwise all five options are returned.
    """
    moving_fast = abs(obs[1]) > 0.5
    if not moving_fast:
        return [0, 1, 2, 3, 4]  # low velocity: longer sequences are fine
    return [0, 1]  # high velocity: quick corrections only
# Options of increasing length; indices match adaptive_availability above
options = [
    Option([0], "quick-left"),        # Index 0 - one step
    Option([1], "quick-right"),       # Index 1 - one step
    Option([0, 0], "double-left"),    # Index 2 - two steps
    Option([1, 1], "double-right"),   # Index 3 - two steps
    Option([0, 1, 0], "zigzag"),      # Index 4 - three steps
]

cartpole = gym.make("CartPole-v1")
env = SMDPfier(
    cartpole,
    options_provider=options,
    duration_fn=ConstantOptionDuration(3),
    action_interface="index",
    max_options=5,
    availability_fn=adaptive_availability,
)
Dynamic Options with Masking
When using dynamic option generators, the availability_fn is automatically passed to restrict generated options:
from smdpfier.defaults.options import RandomStaticLen
import random


def state_aware_generator(obs, info):
    """Generate up to five random 3-action options from the allowed actions.

    Args:
        obs: Current observation (unused; kept for the generator signature).
        info: Dict that may carry ``"action_mask"``. If that key is missing
            or ``None``, every action in ``info["action_space"]`` is allowed.

    Returns:
        Five ``Option`` objects named ``dynamic_0`` .. ``dynamic_4``, or an
        empty list when no action is allowed.

    NOTE(review): the availability functions in this guide return *option*
    indices, while this generator treats mask positions as *primitive*
    actions; confirm which one SMDPfier passes here.
    """
    action_mask = info.get("action_mask")
    if action_mask is not None:
        # Keep only the action indices whose mask entry is truthy
        available_actions = [i for i, avail in enumerate(action_mask) if avail]
    else:
        # No masking: use every action of the (discrete) action space
        available_actions = list(range(info["action_space"].n))

    if not available_actions:
        return []  # nothing to sample from; random.choices would raise

    return [
        Option(random.choices(available_actions, k=3), f"dynamic_{i}")
        for i in range(5)
    ]
def base_availability(obs):
    """Allow only option 0 while the cart is right of centre; otherwise 0 and 1."""
    return [0] if obs[0] > 0 else [0, 1]
cartpole = gym.make("CartPole-v1")
env = SMDPfier(
    cartpole,
    options_provider=state_aware_generator,  # options are regenerated from state
    duration_fn=ConstantOptionDuration(2),
    action_interface="index",
    max_options=5,
    availability_fn=base_availability,  # Passed to generator automatically
)
Precheck Validation
Precheck validation attempts to validate options before execution by testing their actions in the current environment state.
Enabling Precheck
# `base_env` and `options` are assumed to be defined as in the earlier examples
env = SMDPfier(
    base_env,
    options_provider=options,
    duration_fn=ConstantOptionDuration(5),
    action_interface="index",
    precheck=True,  # validate each option before executing it
)
How Precheck Works
- Before executing an option, SMDPfier saves the environment state
- Tests each action in the option sequence
- Restores the environment state after testing
- Raises SMDPOptionValidationError if any action fails
- Proceeds with execution if all actions are valid
Precheck Example
from smdpfier.errors import SMDPOptionValidationError
# A long option whose validity can depend on the current state.
# NOTE(review): this passes an Option object to step(); with
# action_interface="index" an integer index may be expected instead — confirm.
risky_option = Option([0, 1, 0, 1, 0], "risky-sequence")

try:
    obs, reward, term, trunc, info = env.step(risky_option)
except SMDPOptionValidationError as err:
    # Report which option failed, where, and in what state
    print(f"Option '{err.option_name}' failed precheck!")
    print(f"Failed at step {err.failing_step_index}")
    print(f"Action {err.action_repr} is invalid")
    print(f"Environment state: {err.short_obs_summary}")

    # Recover by running a minimal one-action option instead
    fallback_option = Option([0], "safe-fallback")
    obs, reward, term, trunc, info = env.step(fallback_option)
Precheck Limitations
⚠️ Important Limitations:
- Environment must support state save/restore (not all environments do)
- Performance overhead from testing each option
- May not catch all edge cases (e.g., stochastic environments)
- False positives possible in complex environments
Recommendation: Use precheck during development and debugging; consider disabling it in production to avoid its performance overhead.
Precheck vs Masking
| Approach | When to Use | Pros | Cons |
|---|---|---|---|
| Action Masking | Known invalid patterns | Fast, reliable | Requires domain knowledge |
| Precheck | Unknown failure modes | Automatic detection | Slower, may have false positives |
| Both | Maximum safety | Comprehensive validation | Higher complexity |
Best Practices
Masking Strategy
def robust_availability(obs):
    """Grow the set of allowed options as state confidence increases.

    Options 0 and 1 are always allowed. Options 2-3 are added when confidence
    exceeds 0.8, and options 4-6 when it also exceeds 0.9.
    ``compute_state_confidence`` is a user-supplied helper (not shown).
    """
    confidence = compute_state_confidence(obs)

    available = [0, 1]  # conservative base set: basic actions, always safe
    if confidence > 0.8:
        available += [2, 3]  # medium-complexity options
    if confidence > 0.9:
        available += [4, 5, 6]  # high-complexity options
    return available
Error-Resilient Option Design
# Prefer short options: fewer steps means fewer chances for a step to fail
safe_options = [
    Option([0], "single-left"),   # minimal option - rarely fails
    Option([1], "single-right"),  # minimal option - rarely fails
    Option([0, 1], "balanced"),   # balanced - self-correcting
]

# Avoid overly long or one-sided options
risky_options = [
    Option([0] * 10, "extreme-left"),   # ten pushes left - high failure risk
    Option([1] * 10, "extreme-right"),  # ten pushes right - high failure risk
]
Integration with RL Algorithms
import numpy as np


class MaskedDQNAgent:
    """Greedy DQN-style agent that respects SMDPfier action masks.

    NOTE(review): ``self.q_network`` is not defined in this example. It is
    assumed to map an observation to one Q-value per option index (anything
    ``np.array`` accepts) — confirm against your network.
    """

    def select_action(self, obs, info, exclude=()):
        """Return the greedy option index among the available options.

        Args:
            obs: Current observation, passed to ``self.q_network``.
            info: Reset/step info dict. ``info["smdp"]["action_mask"]`` is
                applied when present and not ``None`` (0 = masked).
            exclude: Extra option indices to treat as masked, e.g. options
                that just failed precheck in this state.

        Returns:
            Index of the highest unmasked Q-value. If every entry is masked,
            ``np.argmax`` returns 0.
        """
        # Float copy: the network's output is never modified and -inf fits.
        masked_q_values = np.array(self.q_network(obs), dtype=float)
        action_mask = info.get("smdp", {}).get("action_mask")
        if action_mask is not None:
            # np.asarray so a plain-list mask compares element-wise
            masked_q_values[np.asarray(action_mask) == 0] = -np.inf
        for index in exclude:
            masked_q_values[index] = -np.inf
        return int(np.argmax(masked_q_values))

    def train(self, env):
        """Run one episode, retrying with another option when precheck fails.

        After a validation failure, obs/info are unchanged, so the greedy
        choice would repeat forever. Failed options are therefore excluded
        until a step succeeds.

        Raises:
            SMDPOptionValidationError: The last precheck error, re-raised
                when no untried option is left in the current state.
        """
        obs, info = env.reset()
        while True:
            rejected = set()
            last_error = None
            while True:
                action = self.select_action(obs, info, exclude=rejected)
                if action in rejected:
                    # argmax only lands on a rejected index when every
                    # Q-value is -inf, i.e. nothing is left to try.
                    raise last_error
                try:
                    obs, reward, term, trunc, info = env.step(action)
                except SMDPOptionValidationError as err:
                    rejected.add(action)
                    last_error = err
                    continue
                break
            # ... training logic ...
            if term or trunc:
                break
Debugging Masking Issues
Inspect Masking Behavior
import random


def debug_masking(env, num_steps=10):
    """Step ``env`` with random allowed options, printing the mask each step.

    Args:
        env: SMDPfier-wrapped environment using ``action_interface="index"``.
        num_steps: Maximum number of options to execute.

    Stops early when the episode ends or when the mask allows nothing.
    """
    obs, info = env.reset()
    for step in range(num_steps):
        mask = info.get("smdp", {}).get("action_mask")

        # Scalar observations (e.g. Taxi-v3 state ints) cannot be sliced
        try:
            preview = obs[:3]  # first 3 elements
        except TypeError:
            preview = obs

        print(f"Step {step}:")
        print(f" Observation: {preview}...")
        print(f" Action mask: {mask}")

        if mask is None:
            # Masking disabled: any index is allowed. Assumes the wrapper
            # exposes a gym-style action_space — confirm for your setup.
            print(" Available actions: all (masking disabled)")
            action = env.action_space.sample()
        else:
            available_actions = [i for i, avail in enumerate(mask) if avail == 1]
            print(f" Available actions: {available_actions}")
            if not available_actions:
                print(" No actions available!")
                break
            action = random.choice(available_actions)

        obs, reward, term, trunc, info = env.step(action)
        if term or trunc:
            break
debug_masking(env, num_steps=10)  # trace the first ten option executions
Common Masking Issues
Issue: All actions masked
def overly_restrictive_availability(obs):
    """Anti-example: returns an empty mask once the cart passes x = 2.0."""
    # Reachable: CartPole only terminates once |x| > 2.4, so 2.0 < x <= 2.4
    # occurs in normal episodes and leaves the agent with nothing to choose.
    if obs[0] > 2.0:
        return []  # No actions available!
    return [0, 1]
# Fix: Ensure at least one action is always available
def better_availability(obs):
    """Option 0 is always allowed; option 1 is added while ``obs[0] < 2.0``."""
    return [0, 1] if obs[0] < 2.0 else [0]
Issue: Inconsistent masking
def inconsistent_availability(obs):
    # Anti-example: the mask ignores obs and changes randomly between calls
    # (and may even repeat an index).
    count = random.randint(1, 3)
    return random.choices([0, 1, 2], k=count)
# Fix: Deterministic masking based on state
def consistent_availability(obs):
    """Same state, same mask: [0, 2] right of x = 0.5, otherwise [1, 2]."""
    return [0, 2] if obs[0] > 0.5 else [1, 2]
Summary
| Feature | Purpose | Best For |
|---|---|---|
| Action Masking | Restrict invalid options | State-dependent validity |
| Precheck | Test options before execution | Unknown failure modes |
| Combined | Maximum robustness | Critical applications |
Key Takeaways:
- Use masking for known patterns of invalid actions
- Use precheck for unknown failure modes during development
- Design options to minimize failure probability
- Always ensure at least one action remains available
Next: Error Handling | See Also: API Reference