arXiv:2103.15954v1 [cs.CV] 29 Mar 2021
DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation

Yufan He¹, Dong Yang², Holger Roth², Can Zhao², Daguang Xu²
¹Johns Hopkins University, ²NVIDIA

Abstract

Recently, neural architecture search (NAS) has been applied to automatically search high-performance networks for medical image segmentation. The NAS search space usually contains a network topology level (controlling connections among cells with different spatial scales) and a cell level (operations within each cell). Existing methods either require long searching time for large-scale 3D image datasets, or are limited to pre-defined topologies (such as U-shaped or single-path). In this work, we focus on three important aspects of NAS in 3D medical image segmentation: flexible multi-path network topology, high search efficiency, and budgeted GPU memory usage. A novel differentiable search framework is proposed to support fast gradient-based search within a highly flexible network topology search space. The discretization of the searched optimal continuous model in a differentiable scheme may produce a sub-optimal final discrete model (discretization gap). Therefore, we propose a topology loss to alleviate this problem. In addition, the GPU memory usage for the searched 3D model is limited with budget constraints during search. Our Differentiable Network Topology Search scheme (DiNTS) is evaluated on the Medical Segmentation Decathlon (MSD) challenge, which contains ten challenging segmentation tasks. Our method achieves state-of-the-art performance and the top ranking on the MSD challenge leaderboard.

[Figure 1 graphic: a "Continuous Model" produced by "Search" and a "Discrete Model" produced by "Discretization", with the "Gap" between them; the input-edge probabilities of a node satisfy β₁ + β₂ + β₃ = 1.]

Figure 1. Limitations of the existing differentiable topology search formulation. E.g. in Auto-DeepLab [21], each edge in the topology search space is given a probability β. The probabilities of input edges to a node sum to one, which means only one input edge for each node would be selected. A single-path discrete model (red path) is extracted from the continuous searched model. This can result in a large “discretization gap” between the feature flow of the searched continuous model and the final discrete model.

1. Introduction

Automated medical image segmentation is essential for many clinical applications like finding new biomarkers and monitoring disease progression. The recent developments in deep neural network architectures have achieved great performance improvements in image segmentation. Manually designed networks, like U-Net [34], have been widely used in different tasks. However, the diversity of medical image segmentation tasks could be extremely high since the image characteristics and appearances can be completely distinct for different modalities and the presentation of diseases can vary considerably. This makes the direct application of even a successful network like U-Net [34] to a new task less likely to be optimal.

Neural architecture search (NAS) algorithms [49] have been proposed to automatically discover the optimal architectures within a search space. The NAS search space for segmentation usually contains two levels: network topology level and cell level. The network topology controls the connections among cells and decides the flow of the feature maps across different spatial scales. The cell level decides the specific operations on the feature maps. A more flexible search space has more potential to contain better performing architectures.

In terms of the search methods for finding the optimal architecture from the search space, evolutionary or reinforcement learning-based [49, 33] algorithms are usually time consuming. C2FNAS [45] takes 333 GPU days to search one 3D segmentation network using evolutionary-based methods, which is too computationally expensive for common use cases. Differentiable architecture search [23] is much more efficient, and Auto-DeepLab [21] is the first work to apply differentiable search to segmentation network topology. However, Auto-DeepLab’s differentiable formulation limits the searched network topology. As shown in Fig. 1, this formulation assumes that only one input edge would be kept for each node. Its final searched model only has a single path from input to output, which limits its complexity. Our first goal is to propose a new differentiable scheme to support more complex topologies in order to find novel architectures with better performance.

Meanwhile, differentiable architecture search suffers from the “discretization gap” problem [4, 38]. The discretization of the searched optimal continuous model may produce a sub-optimal discrete final architecture and cause a large performance gap. As shown in Fig. 1, the gap comes from two sides: 1) the searched continuous model is not binary, thus some operations/edges with small but non-zero probabilities are discarded during the discretization step; 2) the discretization algorithm has topology constraints (e.g. single-path), thus edges causing infeasible topology are not allowed even if they have large probabilities in the continuous model. Alleviating the first problem by encouraging a binarized model during search has been explored [5, 38, 27]. However, alleviating the second problem requires the search to be aware of the discretization algorithm and topology constraints. In this paper, we propose a topology loss in the search stage and a topology guaranteed discretization algorithm to mitigate this problem.

In medical image analysis, especially for some longitudinal analysis tasks, high input image resolution and large patch size are usually desired to capture minuscule longitudinal changes. Thus, large GPU memory usage is a major challenge for training with large high resolution 3D images. Most NAS algorithms with computational constraints focus on latency [1, 3, 18, 36] for real-time applications. However, real-time inference often is not a major concern compared to the problem caused by huge GPU memory usage in 3D medical image analysis. In this paper, we propose additional GPU memory constraints in the search stage to limit the GPU usage needed for retraining the searched model.

We validate our method on the Medical Segmentation Decathlon (MSD) dataset [37], which contains 10 representative 3D medical segmentation tasks covering different anatomies and imaging modalities. We achieve state-of-the-art results while taking only 5.8 GPU days (the recent C2FNAS [45] takes 333 GPU days on the same dataset). Our contributions can be summarized as:

• We propose a novel Differentiable Network Topology Search scheme DiNTS, which supports more flexible topologies and joint two-level search.

• We propose a topology guaranteed discretization algorithm and a discretization aware topology loss for the search stage to minimize the discretization gap.

• We develop a memory usage aware search method which is able to search 3D networks with different GPU memory requirements.

• We achieve new state-of-the-art results and the top ranking in the MSD challenge leaderboard while only taking 1.7% of the search time compared to the NAS-based C2FNAS [45].

2. Related Work

2.1. Medical Image Segmentation

Medical image segmentation faces some unique challenges like a lack of manual labels and vast memory usage for processing 3D high resolution images. Compared to networks used in natural images like DeepLab [2] and PSPNet [46], 2D/3D UNet [34, 6] is better at preserving fine details and is memory friendly when applied to 3D images. VNet [26] improves 3D UNet with residual blocks. UNet++ [47] uses dense blocks [13] to redesign skip connections. H-DenseUNet [17] combines 2D and 3D UNet to save memory. nnUNet [14] ensembles 2D, 3D, and cascaded 3D UNet and achieves state-of-the-art results on a variety of medical image segmentation benchmarks.

2.2. Neural Architecture Search

Neural architecture search (NAS) focuses on designing networks automatically. The work in NAS can be categorized into three dimensions: search space, search method and performance estimation [8]. The search space defines what architecture can be searched, which can be further divided into network topology level and cell level. For image classification, [23, 50, 22, 33, 31, 11] focus on searching optimal cells and apply a pre-defined network topology, while [9, 42] perform search on the network topology. In segmentation, Auto-DeepLab [21] uses a highly flexible search space while FasterSeg [3] proposes a low latency two-level search space. Both perform a joint two-level search. In medical image segmentation, NAS-UNet [40], V-NAS [48] and Kim et al. [15] search cells and apply them to a U-Net-like topology. C2FNAS [45] searches 3D network topology in a U-shaped space and then searches the operation for each cell. MS-NAS [44] applies PC-Darts [43] and Auto-DeepLab’s formulation to 2D medical images.

Search method and performance estimation focus on finding the optimal architecture from the search space. Evolutionary and reinforcement learning methods have been used in [49, 33], but those methods require extremely long search time. Differentiable methods [23, 21] relax the discrete architecture into continuous representations and allow direct gradient based search. This is orders of magnitude faster and has been applied in various NAS works [23, 21, 43, 48, 44]. However, converting the continuous representation back to the discrete architecture causes the “discretization gap”. To solve this problem, FairDARTS [5] and Tian et al. [38] proposed zero-one loss and entropy loss respectively to push the continuous representation close to binary. Some works [27, 12]

[Figure 2 graphic: two panels, "Topology Search Space" and "Cell Search Space". Topology panel: scales 1, 1/2, 1/4, 1/8, 1/16 on the vertical axis; layers 0 to 12 between IN and OUT; legend entries: feature nodes, 2x downsample (d2x), 2x upsample (u2x), direct pass, 3x3x3 conv, 1x1x1 conv, 3x3x3 conv with stride 2. Cell panel: skip, 3D: 3x3x3, P3D: 3x3x1, P3D: 3x1x3, P3D: 1x3x3.]

Figure 2.