TSTM: Temporal Segmentation for Task-relevant Mask in Visual Reinforcement Learning Generalization
Abstract
Achieving strong policy generalization to unseen environments remains a core challenge in visual reinforcement learning (RL), and segmenting task-relevant regions to mitigate the influence of irrelevant visual cues has emerged as a promising direction. However, existing methods rely solely on the current observation and ignore the temporal information carried by preceding observations, leaving learned policies susceptible to task-irrelevant background variations and ultimately degrading performance. In this paper, we propose Temporal Segmentation for Task-relevant Masks in visual reinforcement learning, named TSTM, which extracts task-relevant regions from sequential observations by exploiting temporal information, thereby producing more reliable masks and improving policy generalization. TSTM introduces a temporal segmentation network with an encoder-temporal-decoder architecture, in which a convolutional LSTM module captures temporal dependencies across observations. To reduce inference overhead, we further develop a lightweight student network as an efficient substitute for the temporal segmentation network, which serves as the teacher. The resulting task-relevant masks are encoded by a CNN-based encoder, and invariant representation learning improves robustness by enforcing consistency between the representations of the original and augmented observation sequences. With these task-relevant representations, we train an actor-critic agent to learn a policy with strong generalization capability. Experimental results demonstrate that TSTM achieves superior generalization performance over existing state-of-the-art methods on most visual RL tasks.
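The invariant representation learning step mentioned above can be summarized as a consistency penalty between the encodings of an original and an augmented observation sequence. A minimal pure-Python sketch follows; `encode` and `augment` are placeholders for the paper's CNN encoder and data augmentation, and all shapes and names here are illustrative assumptions, not the authors' implementation.

```python
def mse(a, b):
    """Mean squared error between two equal-length feature vectors."""
    assert len(a) == len(b)
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def invariance_loss(encode, obs_seq, augment):
    """Consistency term: the representation of the original sequence and
    that of its augmented counterpart should match.

    encode  -- hypothetical stand-in for the CNN-based encoder,
               mapping an observation sequence to a feature vector
    augment -- hypothetical per-observation augmentation (e.g. a
               background perturbation)
    """
    z = encode(obs_seq)                              # original sequence
    z_aug = encode([augment(o) for o in obs_seq])    # augmented sequence
    return mse(z, z_aug)
```

With an identity augmentation the loss is zero by construction; a background-altering augmentation yields a positive penalty that drives the encoder toward task-relevant, augmentation-invariant features.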