001/******************************************************************************* 002 * Copyright (c) 2016 Diamond Light Source Ltd. and others. 003 * All rights reserved. This program and the accompanying materials 004 * are made available under the terms of the Eclipse Public License v1.0 005 * which accompanies this distribution, and is available at 006 * http://www.eclipse.org/legal/epl-v10.html 007 * 008 * Contributors: 009 * Diamond Light Source Ltd - initial API and implementation 010 *******************************************************************************/ 011package org.eclipse.january.dataset; 012 013import java.lang.reflect.Array; 014import java.util.ArrayList; 015import java.util.Arrays; 016import java.util.Collection; 017import java.util.List; 018import java.util.SortedSet; 019import java.util.TreeSet; 020 021public class ShapeUtils { 022 023 private ShapeUtils() { 024 } 025 026 /** 027 * Calculate total number of items in given shape 028 * @param shape 029 * @return size 030 */ 031 public static long calcLongSize(final int[] shape) { 032 if (shape == null) { // special case of null-shaped 033 return 0; 034 } 035 036 final int rank = shape.length; 037 if (rank == 0) { // special case of zero-rank shape 038 return 1; 039 } 040 041 double dsize = 1.0; 042 for (int i = 0; i < rank; i++) { 043 // make sure the indexes isn't zero or negative 044 if (shape[i] == 0) { 045 return 0; 046 } else if (shape[i] < 0) { 047 throw new IllegalArgumentException(String.format( 048 "The %d-th is %d which is not allowed as it is negative", i, shape[i])); 049 } 050 051 dsize *= shape[i]; 052 } 053 054 // check to see if the size is larger than an integer, i.e. we can't allocate it 055 if (dsize > Long.MAX_VALUE) { 056 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 057 } 058 return (long) dsize; 059 } 060 061 /** 062 * Calculate total number of items in given shape 063 * @param shape 064 * @return size 065 */ 066 public static int calcSize(final int[] shape) { 067 long lsize = calcLongSize(shape); 068 069 // check to see if the size is larger than an integer, i.e. we can't allocate it 070 if (lsize > Integer.MAX_VALUE) { 071 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 072 } 073 return (int) lsize; 074 } 075 076 /** 077 * Check if shapes are broadcast compatible 078 * 079 * @param ashape 080 * @param bshape 081 * @return true if they are compatible 082 */ 083 public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) { 084 if (ashape == null || bshape == null) { 085 return ashape == bshape; 086 } 087 088 if (ashape.length < bshape.length) { 089 return areShapesBroadcastCompatible(bshape, ashape); 090 } 091 092 for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) { 093 if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) { 094 return false; 095 } 096 } 097 098 return true; 099 } 100 101 /** 102 * Check if shapes are compatible, ignoring extra axes of length 1 103 * 104 * @param ashape 105 * @param bshape 106 * @return true if they are compatible 107 */ 108 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) { 109 if (ashape == null || bshape == null) { 110 return ashape == bshape; 111 } 112 113 List<Integer> alist = new ArrayList<Integer>(); 114 115 for (int a : ashape) { 116 if (a > 1) alist.add(a); 117 } 118 119 final int imax = alist.size(); 120 int i = 0; 121 for (int b : bshape) { 122 if (b == 1) 123 continue; 124 if (i >= imax || b != alist.get(i++)) 125 return false; 126 } 127 128 return i == imax; 129 } 130 131 /** 132 * Check if shapes are compatible but skip axis 133 * 134 * @param ashape 135 * @param bshape 136 * @param axis 137 * @return true if they are compatible 138 */ 139 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) { 140 if (ashape == null || bshape == null) { 141 return ashape == bshape; 142 } 143 144 if (ashape.length != bshape.length) { 145 return false; 146 } 147 148 final int rank = ashape.length; 149 for (int i = 0; i < rank; i++) { 150 if (i != axis && ashape[i] != bshape[i]) { 151 return false; 152 } 153 } 154 return true; 155 } 156 157 /** 158 * Remove dimensions of 1 in given shape - from both ends only, if true 159 * 160 * @param oshape 161 * @param onlyFromEnds 162 * @return newly squeezed shape (or original if unsqueezed) 163 */ 164 public static int[] squeezeShape(final int[] oshape, boolean onlyFromEnds) { 165 int unitDims = 0; 166 int rank = oshape.length; 167 int start = 0; 168 169 if (onlyFromEnds) { 170 int i = rank - 1; 171 for (; i >= 0; i--) { 172 if (oshape[i] == 1) { 173 unitDims++; 174 } else { 175 break; 176 } 177 } 178 for (int j = 0; j <= i; j++) { 179 if (oshape[j] == 1) { 180 unitDims++; 181 } else { 182 start = j; 183 break; 184 } 185 } 186 } else { 187 for (int i = 0; i < rank; i++) { 188 if (oshape[i] == 1) { 189 unitDims++; 190 } 191 } 192 } 193 194 if (unitDims == 0) { 195 return oshape; 196 } 197 198 int[] newDims = new int[rank - unitDims]; 199 if (unitDims == rank) 200 return newDims; // zero-rank dataset 201 202 if (onlyFromEnds) { 203 rank = newDims.length; 204 for (int i = 0; i < rank; i++) { 205 newDims[i] = oshape[i+start]; 206 } 207 } else { 208 int j = 0; 209 for (int i = 0; i < rank; i++) { 210 if (oshape[i] > 1) { 211 newDims[j++] = oshape[i]; 212 if (j >= newDims.length) 213 break; 214 } 215 } 216 } 217 218 return newDims; 219 } 220 221 /** 222 * Remove dimension of 1 in given shape 223 * 224 * @param oshape 225 * @param axis 226 * @return newly squeezed shape 227 */ 228 public static int[] squeezeShape(final int[] oshape, int axis) { 229 if (oshape == null) { 230 return null; 231 } 232 233 final int rank = oshape.length; 234 if (rank == 0) { 235 return new int[0]; 236 } 237 if (axis < 0) { 238 axis += rank; 239 } 240 if (axis < 0 || axis >= rank) { 241 throw new IllegalArgumentException("Axis argument is outside allowed range"); 242 } 243 int[] nshape = new int[rank-1]; 244 for (int i = 0; i < axis; i++) { 245 nshape[i] = oshape[i]; 246 } 247 for (int i = axis+1; i < rank; i++) { 248 nshape[i-1] = oshape[i]; 249 } 250 return nshape; 251 } 252 253 /** 254 * Get shape from object (array or list supported) 255 * @param obj 256 * @return shape can be null if obj is null 257 */ 258 public static int[] getShapeFromObject(final Object obj) { 259 if (obj == null) { 260 return null; 261 } 262 263 ArrayList<Integer> lshape = new ArrayList<Integer>(); 264 getShapeFromObj(lshape, obj, 0); 265 266 final int rank = lshape.size(); 267 final int[] shape = new int[rank]; 268 for (int i = 0; i < rank; i++) { 269 shape[i] = lshape.get(i); 270 } 271 272 return shape; 273 } 274 275 /** 276 * Get shape from object 277 * @param ldims 278 * @param obj 279 * @param depth 280 * @return true if there is a possibility of differing lengths 281 */ 282 private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) { 283 if (obj == null) 284 return true; 285 286 if (obj instanceof List<?>) { 287 List<?> jl = (List<?>) obj; 288 int l = jl.size(); 289 updateShape(ldims, depth, l); 290 for (int i = 0; i < l; i++) { 291 Object lo = jl.get(i); 292 if (!getShapeFromObj(ldims, lo, depth + 1)) { 293 break; 294 } 295 } 296 return true; 297 } 298 Class<? extends Object> ca = obj.getClass().getComponentType(); 299 if (ca != null) { 300 final int l = Array.getLength(obj); 301 updateShape(ldims, depth, l); 302 if (DTypeUtils.isClassSupportedAsElement(ca)) { 303 return true; 304 } 305 for (int i = 0; i < l; i++) { 306 Object lo = Array.get(obj, i); 307 if (!getShapeFromObj(ldims, lo, depth + 1)) { 308 break; 309 } 310 } 311 return true; 312 } else if (obj instanceof IDataset) { 313 int[] s = ((IDataset) obj).getShape(); 314 for (int i = 0; i < s.length; i++) { 315 updateShape(ldims, depth++, s[i]); 316 } 317 return true; 318 } else { 319 return false; // not an array of any type 320 } 321 } 322 323 private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) { 324 if (depth >= ldims.size()) { 325 ldims.add(l); 326 } else if (l > ldims.get(depth)) { 327 ldims.set(depth, l); 328 } 329 } 330 331 /** 332 * Get n-D position from given index 333 * @param n index 334 * @param shape 335 * @return n-D position 336 */ 337 public static int[] getNDPositionFromShape(int n, int[] shape) { 338 if (shape == null) { 339 return null; 340 } 341 342 int rank = shape.length; 343 if (rank == 0) { 344 return new int[0]; 345 } 346 347 if (rank == 1) { 348 return new int[] { n }; 349 } 350 351 int[] output = new int[rank]; 352 for (rank--; rank > 0; rank--) { 353 output[rank] = n % shape[rank]; 354 n /= shape[rank]; 355 } 356 output[0] = n; 357 358 return output; 359 } 360 361 /** 362 * Get flattened view index of given position 363 * @param shape 364 * @param pos 365 * the integer array specifying the n-D position 366 * @return the index on the flattened dataset 367 */ 368 public static int getFlat1DIndex(final int[] shape, final int[] pos) { 369 final int imax = pos.length; 370 if (imax == 0) { 371 return 0; 372 } 373 374 return AbstractDataset.get1DIndexFromShape(shape, pos); 375 } 376 377 /** 378 * This function takes a dataset and checks its shape against another dataset. If they are both of the same size, 379 * then this returns with no error, if there is a problem, then an error is thrown. 380 * 381 * @param g 382 * The first dataset to be compared 383 * @param h 384 * The second dataset to be compared 385 * @throws IllegalArgumentException 386 * This will be thrown if there is a problem with the compatibility 387 */ 388 public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException { 389 if (!areShapesCompatible(g.getShape(), h.getShape())) { 390 throw new IllegalArgumentException("Shapes do not match"); 391 } 392 } 393 394 /** 395 * Check that axis is in range [-rank,rank) 396 * 397 * @param rank 398 * @param axis 399 * @return sanitized axis in range [0, rank) 400 * @since 2.1 401 */ 402 public static int checkAxis(int rank, int axis) { 403 if (axis < 0) { 404 axis += rank; 405 } 406 407 if (axis < 0 || axis >= rank) { 408 throw new IllegalArgumentException("Axis " + axis + " given is out of range [0, " + rank + ")"); 409 } 410 return axis; 411 } 412 413 private static int[] convert(Collection<Integer> list) { 414 int[] array = new int[list.size()]; 415 int i = 0; 416 for (Integer l : list) { 417 array[i++] = l; 418 } 419 return array; 420 } 421 422 /** 423 * Check that all axes are in range [-rank,rank) 424 * @param rank 425 * @param axes 426 * @return sanitized axes in range [0, rank) and sorted in increasing order 427 * @since 2.2 428 */ 429 public static int[] checkAxes(int rank, int... axes) { 430 return convert(sanitizeAxes(rank, axes)); 431 } 432 433 /** 434 * Check that all axes are in range [-rank,rank) 435 * @param rank 436 * @param axes 437 * @return sanitized axes in range [0, rank) and sorted in increasing order 438 * @since 2.2 439 */ 440 private static SortedSet<Integer> sanitizeAxes(int rank, int... axes) { 441 SortedSet<Integer> nAxes = new TreeSet<>(); 442 for (int i = 0; i < axes.length; i++) { 443 nAxes.add(checkAxis(rank, axes[i])); 444 } 445 446 return nAxes; 447 } 448 449 /** 450 * @param rank 451 * @param axes 452 * @return remaining axes not given by input 453 * @since 2.2 454 */ 455 public static int[] getRemainingAxes(int rank, int... axes) { 456 SortedSet<Integer> nAxes = sanitizeAxes(rank, axes); 457 458 int[] remains = new int[rank - axes.length]; 459 int j = 0; 460 for (int i = 0; i < rank; i++) { 461 if (!nAxes.contains(i)) { 462 remains[j++] = i; 463 } 464 } 465 return remains; 466 } 467 468 /** 469 * Remove axes from shape 470 * @param shape 471 * @param axes 472 * @return reduced shape 473 * @since 2.2 474 */ 475 public static int[] reduceShape(int[] shape, int... axes) { 476 int[] remain = getRemainingAxes(shape.length, axes); 477 for (int i = 0; i < remain.length; i++) { 478 int a = remain[i]; 479 remain[i] = shape[a]; 480 } 481 return remain; 482 } 483 484 /** 485 * Set reduced axes to 1 486 * @param shape 487 * @param axes 488 * @return shape with same rank 489 * @since 2.2 490 */ 491 public static int[] getReducedShapeKeepRank(int[] shape, int... axes) { 492 int[] keep = shape.clone(); 493 axes = checkAxes(shape.length, axes); 494 for (int i : axes) { 495 keep[i] = 1; 496 } 497 return keep; 498 } 499 500 /** 501 * @param a 502 * @param b 503 * @return true if arrays only differs by unit entries 504 * @since 2.2 505 */ 506 public static boolean differsByOnes(int[] a, int[] b) { 507 int aRank = a.length; 508 int bRank = b.length; 509 int ai = 0; 510 int bi = 0; 511 int al = 1; 512 int bl = 1; 513 do { 514 while (ai < aRank && (al = a[ai++]) == 1) { // next non-unit dimension 515 } 516 while (bi < bRank && (bl = b[bi++]) == 1) { 517 } 518 if (al != bl) { 519 return false; 520 } 521 } while (ai < aRank && bi < bRank); 522 523 if (ai == aRank) { 524 while (bi < bRank) { 525 if (b[bi++] != 1) { 526 return false; 527 } 528 } 529 } 530 if (bi == bRank) { 531 while (ai < aRank) { 532 if (a[ai++] != 1) { 533 return false; 534 } 535 } 536 } 537 return true; 538 } 539 540 /** 541 * Calculate the padding difference between two shapes. Padding can be positive (negative) 542 * for added (removed) dimensions. NB positive or negative padding is given after matched 543 * dimensions 544 * @param aShape 545 * @param bShape 546 * @return padding can be null if shapes are equal 547 * @throws IllegalArgumentException if one shape is null but not the other, or if shapes do 548 * not possess common non-unit lengths 549 * @since 2.2 550 */ 551 public static int[] calcShapePadding(int[] aShape, int[] bShape) { 552 if (Arrays.equals(aShape, bShape)) { 553 return null; 554 } 555 556 if (aShape == null || bShape == null) { 557 throw new IllegalArgumentException("If one shape is null then the other must be null too"); 558 } 559 560 if (!differsByOnes(aShape, bShape)) { 561 throw new IllegalArgumentException("Non-unit lengths in shapes must be equal"); 562 } 563 int aRank = aShape.length; 564 int bRank = bShape.length; 565 566 int[] padding; 567 if (aRank == 0 || bRank == 0) { 568 padding = new int[1]; 569 padding[0] = aRank == 0 ? bRank : -aRank; 570 return padding; 571 } 572 573 padding = new int[Math.max(aRank, bRank) + 2]; 574 int ai = 0; 575 int bi = 0; 576 int al = 0; 577 int bl = 0; 578 int pi = 0; 579 int p; 580 boolean aLeft = ai < aRank; 581 boolean bLeft = bi < bRank; 582 while (aLeft && bLeft) { 583 if (aLeft) { 584 al = aShape[ai++]; 585 aLeft = ai < aRank; 586 } 587 if (bLeft) { 588 bl = bShape[bi++]; 589 bLeft = bi < bRank; 590 } 591 if (al != bl) { 592 p = 0; 593 while (al == 1 && aLeft) { 594 al = aShape[ai++]; 595 aLeft = ai < aRank; 596 p--; 597 } 598 while (bl == 1 && bLeft) { 599 bl = bShape[bi++]; 600 bLeft = bi < bRank; 601 p++; 602 } 603 padding[pi++] = p; 604 } 605 if (al == bl) { 606 pi++; 607 } 608 } 609 if (aLeft || bLeft) { 610 p = 0; 611 while (ai < aRank && aShape[ai++] == 1) { 612 p--; 613 } 614 while (bi < bRank && bShape[bi++] == 1) { 615 p++; 616 } 617 padding[pi++] = p; 618 } 619 620 return Arrays.copyOf(padding, pi); 621 } 622 623 static int[] padShape(int[] padding, int nr, int[] oldShape) { 624 if (padding == null) { 625 return oldShape.clone(); 626 } 627 int or = oldShape.length; 628 int[] newShape = new int[nr]; 629 int di = 0; 630 for (int i = 0, si = 0; i < (or+1) && si <= or && di < nr; i++) { 631 int c = padding[i]; 632 if (c == 0) { 633 newShape[di++] = oldShape[si++]; 634 } else if (c > 0) { 635 int dim = di + c; 636 while (di < dim) { 637 newShape[di++] = 1; 638 } 639 } else if (c < 0) { 640 si -= c; // remove dimensions by skipping forward in source array (should check that they are unit entries) 641 } 642 } 643 while (di < nr) { 644 newShape[di++] = 1; 645 } 646 return newShape; 647 } 648}