A NeurIPS 2021 paper

Konpat Preechakul, Chawan Piansaddhayanon, Burin Naowarat, Tirasan Khandhawit, Sira Sriswasdi, Ekapol Chuangsuwanich

Abstract

Set prediction tasks require matching between the predicted set and the ground-truth set in order to propagate the gradient signal. Recent works have performed this matching in the original feature space, thus requiring predefined distance functions. We propose a method for learning the distance function by performing the matching in a latent space learned by encoding networks. This method enables the use of teacher forcing, which was not previously possible because matching in the feature space must be computed after the entire output sequence is generated. Nonetheless, a naive implementation of latent set prediction might not converge due to permutation instability. To address this problem, we provide sufficient conditions for permutation stability, which beget an algorithm that improves overall model convergence. Experiments on several set prediction tasks, including image captioning and object detection, demonstrate the effectiveness of our method.

Paper


Video Presentation


Code


A short tour

Set prediction with neural networks can be used to solve many real-world problems, such as object detection.

Traditionally, object detection has been approached with region-of-interest (ROI) pooling plus non-maximum suppression, and their derivatives.


Set prediction is arguably a more straightforward way to deal with object detection. However, training neural networks for set prediction is not straightforward.

Because set prediction relies on a “minimum matching” between the neural net’s predictions and the ground truths, one first needs to define a proper distance function, and it is not at all clear how to do so. Once the matching is found, the loss is calculated between matched pairs, and optimization can proceed.
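To make the matching step concrete, here is a minimal sketch, assuming each prediction and ground truth is a fixed-size vector (e.g., a box as (x, y, w, h)) and using a hand-picked L1 distance as the predefined distance. The helper names `min_matching`, `feature_space_set_loss`, and `l1` are hypothetical. The sketch brute-forces the minimum-cost permutation, which is only practical for tiny sets (the Hungarian algorithm, e.g. `scipy.optimize.linear_sum_assignment`, is the usual choice at scale), and it computes loss values only, without gradients.

```python
import itertools

import numpy as np


def min_matching(cost):
    """Return (perm, total) minimizing sum_i cost[i, perm[i]].

    Brute force over all n! permutations: fine for tiny sets only.
    """
    n = cost.shape[0]
    best_perm, best_cost = None, float("inf")
    for perm in itertools.permutations(range(n)):
        total = sum(cost[i, perm[i]] for i in range(n))
        if total < best_cost:
            best_perm, best_cost = perm, total
    return list(best_perm), best_cost


def feature_space_set_loss(preds, targets, dist):
    """Match predictions to targets with a predefined `dist`, then average matched costs."""
    n = len(preds)
    cost = np.array([[dist(preds[i], targets[j]) for j in range(n)] for i in range(n)])
    perm, total = min_matching(cost)
    return total / n, perm


def l1(a, b):
    # Hypothetical hand-designed distance for boxes (x, y, w, h).
    return float(np.abs(a - b).sum())


preds = np.array([[0.9, 0.1, 0.2, 0.2], [0.1, 0.8, 0.3, 0.3]])
targets = np.array([[0.1, 0.9, 0.3, 0.3], [1.0, 0.0, 0.2, 0.2]])
loss, perm = feature_space_set_loss(preds, targets, l1)
print(perm)  # [1, 0]
```

The crux is the `dist` argument: for boxes, L1 is a plausible choice, but for a set of captions there is no obvious hand-designed distance to plug in.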


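Applying set prediction in the text domain is another key limitation. Combining autoregressive decoding, teacher forcing, and set prediction is inefficient: because the matching is not known in advance, each of the $n$ predicted elements has to be teacher-forced against each of the $n$ ground truths, i.e., up to $O(n^2)$ teacher-forced predictions before matching, which is expensive in both computation and memory.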
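The quadratic cost can be seen in a toy example. The sketch below assumes a hypothetical, untrained numpy RNN decoder (`teacher_forced_nll`, with random parameters `E`, `W_h`, `W_out`) that scores a ground-truth caption given one predicted slot's latent vector, feeding the ground-truth previous token at each step. Filling the n × n cost matrix needed for feature-space matching takes one teacher-forced decode per (slot, caption) pair.

```python
import numpy as np

rng = np.random.default_rng(0)
VOCAB, HIDDEN = 10, 8
# Hypothetical toy decoder parameters (random, untrained).
E = rng.normal(size=(VOCAB, HIDDEN))         # token embeddings
W_h = rng.normal(size=(HIDDEN, HIDDEN)) * 0.1
W_out = rng.normal(size=(HIDDEN, VOCAB))

calls = 0  # counts teacher-forced decodes


def teacher_forced_nll(latent, target):
    """Negative log-likelihood of `target` given `latent`, using teacher forcing."""
    global calls
    calls += 1
    h, nll, prev = latent, 0.0, 0  # token 0 plays the role of <bos>
    for tok in target:
        h = np.tanh(W_h @ h + E[prev])
        logits = h @ W_out
        m = logits.max()
        log_probs = logits - (m + np.log(np.exp(logits - m).sum()))  # log-softmax
        nll -= log_probs[tok]
        prev = tok  # teacher forcing: condition on the ground-truth token
    return nll


n = 4
latents = rng.normal(size=(n, HIDDEN))                          # one latent per predicted slot
captions = [rng.integers(1, VOCAB, size=5) for _ in range(n)]   # ground-truth set

# Feature-space matching needs the cost of every (slot, caption) pair,
# and each entry requires its own teacher-forced decode.
cost = np.array([[teacher_forced_nll(latents[i], captions[j]) for j in range(n)]
                 for i in range(n)])
print(cost.shape, calls)  # (4, 4) 16
```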


These problems go away if we perform set prediction in the latent space!

Here, we propose Latent Set Prediction (LSP), a technique that performs set prediction in the latent space and thereby seamlessly handles both the distance-function and the teacher-forcing problems!


Architecturally, we add another model, called the “Encoder”, which maps the target outputs back into the latent space, where the minimum assignment is computed.

In the latent space, we can use a simple Euclidean distance, while the Encoder is trained end-to-end to learn a natural latent space without any further specification.
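A minimal sketch of the latent matching step follows. It assumes a hypothetical linear Encoder (`W_enc`, `encode`) and pretends the predictor's latents are exactly a shuffled copy of the encoded targets, so the correct assignment is known in advance. The Euclidean cost is computed between predicted and encoded latents and the minimum-cost permutation is brute-forced for this tiny set. This illustrates the idea under those assumptions and is not the authors' implementation; in a real model, both the predictor and the Encoder would be neural networks trained end-to-end.

```python
import itertools

import numpy as np

rng = np.random.default_rng(0)
D_LATENT, D_TARGET = 3, 5

# Hypothetical linear "Encoder" mapping each ground-truth element into the latent space.
W_enc = rng.normal(size=(D_TARGET, D_LATENT))


def encode(targets):
    return targets @ W_enc  # (n, D_TARGET) -> (n, D_LATENT)


def euclidean_cost(a, b):
    # cost[i, j] = ||a_i - b_j||_2 between predicted latents and encoded targets
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)


def min_matching(cost):
    # Brute force over permutations; fine for tiny sets only.
    n = cost.shape[0]
    return list(min(itertools.permutations(range(n)),
                    key=lambda p: sum(cost[i, p[i]] for i in range(n))))


targets = rng.normal(size=(3, D_TARGET))         # ground-truth set
pred_latents = encode(targets)[[2, 0, 1]]        # pretend predictor output (shuffled)

cost = euclidean_cost(pred_latents, encode(targets))
perm = min_matching(cost)
latent_loss = np.mean([cost[i, perm[i]] ** 2 for i in range(3)])

# After matching, each target perm[i] would be decoded (teacher-forced) once,
# from its matched latent pred_latents[i]: n decodes instead of n * n.
pairs = [(i, perm[i]) for i in range(3)]
print(perm)  # [2, 0, 1]
```

Because the assignment is decided on latents, no hand-designed distance on captions or boxes is needed, and decoding happens only after matching.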

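However, doing this naively may lead to non-convergent optimization due to permutation instability: the matching between predictions and encoded targets may keep changing from one training step to the next.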
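As a rough, hypothetical illustration of an assignment flip (our own toy, not the paper's analysis or its proposed algorithm), the sketch below places two predicted latents almost halfway between two encoded targets. A tiny update then flips the minimum-cost assignment, so both predictions suddenly chase different targets. `num_reassigned` is a hypothetical diagnostic one could log between consecutive training steps to see how often the matching changes.

```python
import itertools

import numpy as np


def euclidean_cost(a, b):
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)


def min_matching(cost):
    # Brute force over permutations; fine for tiny sets only.
    n = cost.shape[0]
    return list(min(itertools.permutations(range(n)),
                    key=lambda p: sum(cost[i, p[i]] for i in range(n))))


def num_reassigned(perm_prev, perm_now):
    # Number of predictions whose matched target changed between two steps.
    return sum(a != b for a, b in zip(perm_prev, perm_now))


encoded_targets = np.array([[0.0, 0.0], [1.0, 0.0]])
latents_step_t = np.array([[0.49, 0.0], [0.51, 0.0]])
latents_step_t1 = np.array([[0.51, 0.0], [0.49, 0.0]])  # after a tiny update

perm_t = min_matching(euclidean_cost(latents_step_t, encoded_targets))
perm_t1 = min_matching(euclidean_cost(latents_step_t1, encoded_targets))
print(perm_t, perm_t1, num_reassigned(perm_t, perm_t1))  # [0, 1] [1, 0] 2
```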
