Mitigating The Distribution Shift of Diffusion-based Dataset Distillation
Abstract
Dataset Distillation (DD) seeks to create small, synthetic datasets for efficient model training. While diffusion models are powerful generators, their use in DD is hampered by distribution shifts between the synthetic data and the ideal distilled data, leading to suboptimal performance. We identify two critical shifts. First, given the small capacity of the synthetic dataset, an optimal synthetic distribution for DD should be a simplification of the real data distribution rather than a replication of the original data's full complexity. Second, the data sampling process introduces a harmful empirical deviation of the synthetic dataset from this learned distribution. To address these, we introduce a two-stage approach. During diffusion training, we mitigate the first shift by employing an L1 sparsity regularizer that compels the diffusion model to learn a compact, semantically sparse manifold. Then, at sampling time, we abandon the flawed sequential sampling paradigm and instead synchronously denoise the entire synthetic dataset under distribution regularizers. This framework systematically mitigates both identified distribution shifts. Experiments show that our method achieves state-of-the-art performance with superior computational efficiency.
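To make the training-time idea concrete, the sketch below shows one common way an L1 sparsity penalty can be attached to a denoising objective. This is a minimal, hypothetical illustration only: the function name, the `features` tensor standing in for the model's intermediate representation, and the weight `lam` are assumptions for exposition, not the paper's exact formulation.

```python
import numpy as np

def diffusion_loss_with_l1(pred_noise, true_noise, features, lam=1e-3):
    """Illustrative denoising loss plus an L1 sparsity regularizer.

    pred_noise / true_noise: the model's noise prediction and the actual
    noise added during the forward diffusion process.
    features: a stand-in for intermediate activations whose sparsity we
    want to encourage (hypothetical; the paper may regularize a different
    quantity).
    """
    # Standard denoising objective: mean squared error on the noise.
    mse = np.mean((pred_noise - true_noise) ** 2)
    # L1 penalty pushing the representation toward a compact,
    # semantically sparse manifold.
    l1 = lam * np.sum(np.abs(features))
    return mse + l1

# Toy usage with random tensors in place of real model outputs.
rng = np.random.default_rng(0)
pred = rng.normal(size=(4, 8))
true = rng.normal(size=(4, 8))
feats = rng.normal(size=(4, 16))
print(diffusion_loss_with_l1(pred, true, feats))
```

In practice the MSE term would come from the usual diffusion training loop, and `lam` would trade off reconstruction fidelity against how aggressively the learned manifold is simplified.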