Mojo function

advanced_indexing_getitem

advanced_indexing_getitem[input_rank: Int, index_rank: Int, input_type: DType, index_type: DType, //, start_axis: Int, num_index_tensors: Int, target: StringSlice[StaticConstantOrigin], single_thread_blocking_override: Bool, trace_description: StringSlice[StaticConstantOrigin], input_tensor_fn: fn[Int](Index[input_rank]) capturing -> SIMD[input_type, $0], indices_fn: fn[Int](Index[index_rank]) capturing -> SIMD[index_type, 1]](out_tensor: NDBuffer[input_type, ((num_index_tensors * -1) + index_rank + input_rank), origin], in_tensor_strides: Index[input_rank], ctx: DeviceContextPtr)

Implements basic NumPy-style advanced indexing.

This is designed to be fused with other view-producing operations to implement full numpy-indexing semantics.

This assumes that every dimension of input_tensor not indexed by an index tensor is ":", i.e. all indices along that dimension are selected. For example, in NumPy:

# rank(indices1) == 3
# rank(indices2) == 3
out_tensor = input_tensor[:, :, :, indices1, indices2, :, :]
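
The NumPy expression above can be run directly to see the resulting shape. This is a minimal sketch: the concrete sizes below are arbitrary and chosen only so that the ranks match the example (a rank-7 input and two rank-3 index tensors). Because the two index tensors sit in adjacent dimensions, NumPy replaces dimensions 3 and 4 with the index tensors' shape in place.

```python
import numpy as np

# Arbitrary sizes for illustration; only the ranks match the example above.
input_tensor = np.arange(2 * 3 * 2 * 5 * 6 * 2 * 3).reshape(2, 3, 2, 5, 6, 2, 3)

# Two rank-3 index tensors: indices1 selects along axis 3 (size 5),
# indices2 selects along axis 4 (size 6).
rng = np.random.default_rng(0)
indices1 = rng.integers(0, 5, size=(4, 2, 3))
indices2 = rng.integers(0, 6, size=(4, 2, 3))

out_tensor = input_tensor[:, :, :, indices1, indices2, :, :]
print(out_tensor.shape)  # (2, 3, 2, 4, 2, 3, 2, 3)
```

The output has rank 8: the three leading ":" dimensions, the three dimensions of the index tensors, and the two trailing ":" dimensions.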

We compute the following for all valid values of the indexing variables:

out_tensor[a, b, c, i, j, k, d, e] = input_tensor[
a, b, c,
indices1[i, j, k],
indices2[i, j, k],
d, e
]
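
The element-wise definition above can be written as an explicit loop and checked against NumPy's own advanced indexing. This is a reference sketch only, not the Mojo kernel; the function name advanced_index_reference is hypothetical, and it assumes a rank-7 input with the two index tensors applied to axes 3 and 4, as in the example.

```python
import itertools

import numpy as np


def advanced_index_reference(input_tensor, indices1, indices2):
    """Loop-based reference for input_tensor[:, :, :, indices1, indices2, :, :].

    Assumes input_tensor has rank 7 and indices1/indices2 share one rank-3 shape.
    """
    A, B, C, _, _, D, E = input_tensor.shape
    I, J, K = indices1.shape
    out = np.empty((A, B, C, I, J, K, D, E), dtype=input_tensor.dtype)
    # Visit every output coordinate and apply the formula directly.
    for a, b, c, i, j, k, d, e in itertools.product(
        range(A), range(B), range(C), range(I), range(J), range(K), range(D), range(E)
    ):
        out[a, b, c, i, j, k, d, e] = input_tensor[
            a, b, c, indices1[i, j, k], indices2[i, j, k], d, e
        ]
    return out


x = np.arange(2 * 1 * 1 * 4 * 5 * 1 * 2).reshape(2, 1, 1, 4, 5, 1, 2)
idx1 = np.array([[[0, 3, 1], [2, 2, 0]], [[1, 1, 3], [0, 2, 3]]])  # values < 4
idx2 = np.array([[[4, 0, 1], [2, 3, 4]], [[0, 0, 1], [4, 4, 2]]])  # values < 5
ref = advanced_index_reference(x, idx1, idx2)
print(np.array_equal(ref, x[:, :, :, idx1, idx2, :, :]))  # True
```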

In this example start_axis = 3 and num_index_tensors = 2.
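
In general, start_axis and num_index_tensors determine which input dimensions are replaced by the index tensors' shape. The sketch below uses a hypothetical helper, output_shape, to compute the resulting shape; its rank is input_rank - num_index_tensors + index_rank, which matches the rank of out_tensor in the signature.

```python
def output_shape(input_shape, index_shape, start_axis, num_index_tensors):
    """Hypothetical helper: shape produced by the indexing described above.

    Input axes [start_axis, start_axis + num_index_tensors) are replaced by the
    (common) shape of the index tensors; every other axis is kept as ":".
    """
    if start_axis < 0 or start_axis + num_index_tensors > len(input_shape):
        raise ValueError("indexed axes must lie inside the input")
    return (
        tuple(input_shape[:start_axis])
        + tuple(index_shape)
        + tuple(input_shape[start_axis + num_index_tensors:])
    )


shape = output_shape((2, 3, 2, 5, 6, 2, 3), (4, 2, 3), start_axis=3, num_index_tensors=2)
print(shape, len(shape))  # (2, 3, 2, 4, 2, 3, 2, 3) 8
```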

TODO(GEX-1951): Support boolean tensor masks.
TODO(GEX-1952): Support non-contiguous indexing tensors.
TODO(GEX-1953): Support fusion (especially view-fusion).

Parameters:

  • input_rank (Int): The rank of the input tensor.
  • index_rank (Int): The rank of the indexing tensors.
  • input_type (DType): The dtype of the input tensor.
  • index_type (DType): The dtype of the indexing tensors.
  • start_axis (Int): The first dimension of the input to which the indexing tensors are applied. The indexing tensors are assumed to apply to consecutive dimensions.
  • num_index_tensors (Int): The number of indexing tensors.
  • target (StringSlice[StaticConstantOrigin]): The target architecture to operate on.
  • single_thread_blocking_override (Bool): If True, then the operation is run synchronously using a single thread.
  • trace_description (StringSlice[StaticConstantOrigin]): For profiling, the trace name the operation will appear under.
  • input_tensor_fn (fn[Int](Index[input_rank]) capturing -> SIMD[input_type, $0]): Fusion lambda for the input tensor.
  • indices_fn (fn[Int](Index[index_rank]) capturing -> SIMD[index_type, 1]): Fusion lambda for the indexing tensors.
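
Putting the parameters together, each output coordinate maps to one input coordinate: the first start_axis coordinates pass through, the next index_rank coordinates are used to read one value from each of the num_index_tensors index tensors, and the remaining coordinates pass through. The Python sketch below illustrates that mapping. It is an assumption-laden illustration, not the kernel: the callable indices_fn(t, idx) is a hypothetical stand-in for reading index tensor t at idx, and the exact meaning of the Int parameter on the Mojo indices_fn is not stated here.

```python
from typing import Callable, Sequence, Tuple


def input_coords_for_output(
    out_coords: Sequence[int],
    start_axis: int,
    num_index_tensors: int,
    index_rank: int,
    indices_fn: Callable[[int, Tuple[int, ...]], int],
) -> Tuple[int, ...]:
    """Map an output coordinate to the input coordinate it reads from (sketch).

    indices_fn(t, idx) is a hypothetical stand-in for reading index tensor t at idx.
    """
    before = tuple(out_coords[:start_axis])
    idx = tuple(out_coords[start_axis:start_axis + index_rank])
    after = tuple(out_coords[start_axis + index_rank:])
    gathered = tuple(indices_fn(t, idx) for t in range(num_index_tensors))
    return before + gathered + after


# Two rank-3 index tensors of shape (1, 1, 3), stored as nested lists.
tables = [[[[0, 4, 2]]], [[[5, 1, 3]]]]


def lookup(t, idx):
    i, j, k = idx
    return tables[t][i][j][k]


coords = input_coords_for_output((1, 2, 0, 0, 0, 2, 1, 2), 3, 2, 3, lookup)
print(coords)  # (1, 2, 0, 2, 3, 1, 2)
```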

Args:

  • out_tensor (NDBuffer[input_type, ((num_index_tensors * -1) + index_rank + input_rank), origin]): The output tensor to write to. Its rank is input_rank - num_index_tensors + index_rank.
  • in_tensor_strides (Index[input_rank]): The strides of the input tensor.
  • ctx (DeviceContextPtr): The DeviceContextPtr as prepared by the graph compiler.
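
This page does not say how in_tensor_strides is used. A plausible reading, stated here as an assumption, is that the strides turn a gathered input coordinate into a linear element offset, offset = sum(coord * stride). The sketch below shows that arithmetic with hypothetical helpers and row-major element strides (not byte strides).

```python
def row_major_strides(shape):
    """Element strides of a contiguous row-major tensor with the given shape."""
    strides = [1] * len(shape)
    for axis in range(len(shape) - 2, -1, -1):
        strides[axis] = strides[axis + 1] * shape[axis + 1]
    return tuple(strides)


def flat_offset(coords, strides):
    """Linear element offset of coords in a strided buffer."""
    if len(coords) != len(strides):
        raise ValueError("coords and strides must have the same rank")
    return sum(c * s for c, s in zip(coords, strides))


strides = row_major_strides((2, 3, 2, 5, 6, 2, 3))
print(strides)  # (1080, 360, 180, 36, 6, 3, 1)
print(flat_offset((1, 2, 0, 2, 3, 1, 2), strides))  # 1895
```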