001/*-
002 *******************************************************************************
003 * Copyright (c) 2015, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015/**
016 * <p>Class to provide slice iteration through a dataset</p>
017 * <p>It allows a number of axes to be omitted and iterates over
018 * the axes left over.</p>
019 */
020public class SliceNDIterator extends IndexIterator {
021        final private int[] shape;
022        final private int[] start;
023        final private int[] stop;
024        final private int[] step;
025        final private int endrank;
026
027        final private boolean[] omit; // axes to miss out
028
029        /**
030         * position in source dataset
031         */
032        final private int[] pos;
033        final private int[] end;
034        private boolean once;
035
036        private SliceND cSlice; // current slice
037        
038        private int sRank; // number of dimensions used (i.e. not missing)
039        final private SliceND oSlice; // omitted source slice
040
041        final private SliceND sSlice; // shortened slice
042        final private int[] sStart; // shortened position
043        final private int[] sStop; // shortened end
044
045        private SliceND dSlice; // destination slice
046        private int[] dStart;
047        private int[] dStop;
048
049        /**
050         * Constructor for an iterator that misses out several axes
051         * @param slice
052         * @param axes missing axes
053         */
054        public SliceNDIterator(SliceND slice, int... axes) {
055                cSlice = slice.clone();
056                int[] sShape = cSlice.getSourceShape();
057                shape = cSlice.getShape().clone();
058                start = cSlice.getStart();
059                stop  = cSlice.getStop();
060                step  = cSlice.getStep();
061                for (int s : step) {
062                        if (s < 0) {
063                                throw new UnsupportedOperationException("Negative steps not implemented");
064                        }
065                }
066                int rank = shape.length;
067                endrank = rank - 1;
068
069                omit = new boolean[rank];
070                dSlice = new SliceND(shape);
071                dStart = dSlice.getStart();
072                dStop  = dSlice.getStop();
073                sRank = rank;
074                if (axes != null) {
075                        for (int a : axes) {
076                                if (a < 0) {
077                                        a += rank;
078                                }
079                                if (a >= 0 && a <= endrank) {
080                                        sRank--;
081                                        omit[a] = true;
082                                        shape[a] = 1;
083                                } else if (a > endrank) {
084                                        throw new IllegalArgumentException("Specified axis exceeds dataset rank");
085                                }
086                        }
087                }
088
089                cSlice = cSlice.clone();
090                pos = cSlice.getStart();
091                end = cSlice.getStop();
092                if (sRank == rank) {
093                        sStart = pos;
094                        sStop = null;
095                        oSlice = null;
096                        sSlice = cSlice;
097                } else {
098                        int[] dShape = dSlice.getShape();
099                        int[] lShape = new int[sRank];
100                        int[] oShape = new int[rank - sRank];
101                        for (int i = 0, j = 0, k = 0; i < rank; i++) {
102                                if (omit[i]) {
103                                        oShape[j++] = sShape[i];
104                                } else {
105                                        lShape[k++] = sShape[i];
106                                        dShape[i] = 1;
107                                }
108                        }
109                        sSlice = new SliceND(lShape);
110                        sStart = sSlice.getStart();
111                        sStop = sSlice.getStop();
112//                      lShape = sSlice.getShape();
113//                      for (int k = 0; k < sRank; k++) {
114//                              lShape[k] = 1;
115//                      }
116                        oSlice = new SliceND(oShape);
117                        for (int i = 0, j = 0; i < rank; i++) {
118                                if (omit[i]) {
119                                        oSlice.setSlice(j++, start[i], stop[i], step[i]);
120                                }
121                        }
122                }
123                
124                reset();
125        }
126
127        @Override
128        public boolean hasNext() {
129                // now move on one position
130                if (once) {
131                        once = false;
132                        return true;
133                }
134                int k = sRank - 1;
135                for (int j = endrank; j >= 0; j--) {
136                        if (omit[j]) {
137                                continue;
138                        }
139                        pos[j] += step[j];
140                        end[j] = pos[j] + step[j];
141                        dStart[j]++;
142                        dStop[j]++;
143                        if (pos[j] >= stop[j]) {
144                                pos[j] = start[j];
145                                end[j] = pos[j] + step[j];
146                                dStart[j] = 0;
147                                dStop[j] = 1;
148                                if (sStop != null) {
149                                        sStart[k] = pos[j];
150                                        sStop[k] = end[j];
151                                        k--;
152                                }
153                        } else {
154                                if (sStop != null) {
155                                        sStart[k] = pos[j];
156                                        sStop[k] = end[j];
157                                        k--;
158                                }
159                                return true;
160                        }
161                }
162                return false;
163        }
164
165        @Override
166        public int[] getPos() {
167                return pos;
168        }
169
170        /**
171         * Get omitted part of source slice which never changes
172         * @return slice (can be null)
173         */
174        public SliceND getOmittedSlice() {
175                return oSlice;
176        }
177
178        /**
179         * Get output or destination slice
180         * @return slice
181         */
182        public SliceND getOutputSlice() {
183                return dSlice;
184        }
185
186        /**
187         * Get current slice
188         * @return slice
189         */
190        public SliceND getCurrentSlice() {
191                return cSlice;
192        }
193
194        /**
195         * Shortened position where axes are omitted
196         * @return used position
197         */
198        public int[] getUsedPos() {
199                return sStart;
200        }
201
202        /**
203         * Shortened slice where axes are omitted
204         * @return used slice
205         */
206        public SliceND getUsedSlice() {
207                return sSlice;
208        }
209
210        /**
211         * @return omit array - array where true means miss out
212         */
213        public boolean[] getOmit() {
214                return omit;
215        }
216
217        @Override
218        public void reset() {
219                for (int i = 0, k = 0; i <= endrank; i++) {
220                        int b = start[i];
221                        int d = step[i];
222                        if (!omit[i]) {
223                                cSlice.setSlice(i, b, b + d, d);
224                                dStart[i] = 0;
225                                dStop[i] = 1;
226                                if (sStop != null) {
227                                        sSlice.setSlice(k++, b, b + d, d);
228                                }
229                        } else {
230                                cSlice.setSlice(i, b, end[i], d);
231                        }
232                }
233
234                int j = 0;
235                for (; j <= endrank; j++) {
236                        if (!omit[j])
237                                break;
238                }
239                if (j > endrank) {
240                        once = true;
241                        return;
242                }
243
244                if (omit[endrank]) {
245                        pos[endrank] = start[endrank];
246                        for (int i = endrank - 1; i >= 0; i--) {
247                                if (!omit[i]) {
248                                        end[i] = pos[i];
249                                        pos[i] -= step[i];
250                                        dStart[i]--;
251                                        dStop[i]--;
252                                        break;
253                                }
254                        }
255                } else {
256                        end[endrank] = pos[endrank];
257                        pos[endrank] -= step[endrank];
258                        dStart[endrank]--;
259                        dStop[endrank]--;
260                }
261
262                if (sStart != pos) {
263                        for (int i = 0, k = 0; i <= endrank; i++) {
264                                if (!omit[i]) {
265                                        sStart[k++] = pos[i];
266                                }
267                        }
268                }
269        }
270
271        @Override
272        public int[] getShape() {
273                return shape;
274        }
275}