Mojo function
advanced_indexing_setitem_inplace
advanced_indexing_setitem_inplace[input_rank: Int, index_rank: Int, updates_rank: Int, input_type: DType, index_type: DType, //, start_axis: Int, num_index_tensors: Int, target: StringSlice[StaticConstantOrigin], single_thread_blocking_override: Bool, trace_description: StringSlice[StaticConstantOrigin], updates_tensor_fn: fn[Int](Index[updates_rank]) capturing -> SIMD[input_type, $0], indices_fn: fn[Int](Index[index_rank]) capturing -> SIMD[index_type, 1]](input_tensor: NDBuffer[input_type, input_rank, origin], index_tensor_shape: Index[index_rank, element_type=element_type], updates_tensor_strides: Index[updates_rank], ctx: DeviceContextPtr)
Implements basic NumPy-style advanced indexing with assignment, writing the updates into input_tensor in place.
This is designed to be fused with other view-producing operations to implement full numpy-indexing semantics.
This assumes that the dimensions in input_tensor not indexed by index tensors are ":", i.e. they select all indices along that dimension. For example, in NumPy:
```python
# rank(input_tensor) == 7
# rank(indices1) == 2
# rank(indices2) == 2
# rank(updates) == 7
input_tensor[:, :, :, indices1, indices2, :, :] = updates
```
We compute the following for all valid values of the indexing variables:
```python
input_tensor[
    a, b, c,
    indices1[i, j],
    indices2[i, j],
    d, e
] = updates[a, b, c, i, j, d, e]
```
In this example, start_axis = 3 and num_index_tensors = 2.
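The following NumPy sketch checks the semantics above on small, arbitrarily chosen shapes. Under NumPy rules, the assigned region of a rank-7 input indexed by two rank-2 tensors on adjacent axes 3 and 4 has shape input.shape[:3] + indices.shape + input.shape[5:], so updates is rank 7 here. The sizes and the use of distinct (indices1, indices2) pairs are assumptions made for the demo; with repeated pairs, the surviving value would depend on write order.

```python
import numpy as np

# Illustrative sizes; any small positive values work.
A, B, C, D, E = 2, 1, 2, 2, 1   # non-indexed dims of input_tensor
X, Y = 3, 4                     # sizes of the indexed dims (axes 3 and 4)
M, N = 2, 3                     # shape of each rank-2 index tensor

rng = np.random.default_rng(0)
# Distinct (indices1, indices2) pairs, so no location is written twice.
flat = rng.permutation(X * Y)[: M * N].reshape(M, N)
indices1, indices2 = np.unravel_index(flat, (X, Y))
updates = rng.standard_normal((A, B, C, M, N, D, E))  # rank 7

# NumPy advanced-indexing assignment.
expected = np.zeros((A, B, C, X, Y, D, E))
expected[:, :, :, indices1, indices2, :, :] = updates

# Element-by-element form of the same assignment.
result = np.zeros((A, B, C, X, Y, D, E))
for a, b, c, i, j, d, e in np.ndindex(A, B, C, M, N, D, E):
    result[a, b, c, indices1[i, j], indices2[i, j], d, e] = updates[a, b, c, i, j, d, e]

print(np.array_equal(result, expected))  # True
```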
In terms of implementation, our strategy is to iterate over all indices in a common iteration range. The idea is that we can map each index in this range both to the write location in input_tensor and to the data location in updates. An example illustrates this best:
Imagine the input_tensor shape is [A, B, C, D] and we have indexing tensors I1 and I2 with shape [M, N, K]. Assume I1 and I2 are applied to dimensions 1 and 2. I claim an appropriate common iteration range is then (A, M, N, K, D). Note that we expect updates to have the shape [A, M, N, K, D]. We show this by providing the mappings into updates and input_tensor:
Consider an arbitrary set of indices in this range (a, m, n, k, d):
- The index into updates is (a, m, n, k, d).
- The index into input_tensor is (a, I1[m, n, k], I2[m, n, k], d).
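The mapping above can be sketched in pure NumPy. `setitem_inplace_reference` is a hypothetical reference function, not part of the Mojo API: it walks the common iteration range (which equals updates.shape) and, for each point, reads from updates at that point and writes into input_tensor at the gathered location. It assumes all index tensors share one shape and are applied to consecutive axes starting at start_axis. The demo uses the [A, B, C, D] example with start_axis = 1 and compares against NumPy's own assignment.

```python
import numpy as np

def setitem_inplace_reference(input_tensor, index_tensors, updates, start_axis):
    """Hypothetical reference for the common-iteration-range strategy."""
    index_rank = index_tensors[0].ndim
    for point in np.ndindex(*updates.shape):
        head = point[:start_axis]                           # e.g. (a,)
        idx = point[start_axis:start_axis + index_rank]     # e.g. (m, n, k)
        tail = point[start_axis + index_rank:]              # e.g. (d,)
        # One coordinate per index tensor, e.g. (I1[m, n, k], I2[m, n, k]).
        gathered = tuple(int(t[idx]) for t in index_tensors)
        input_tensor[head + gathered + tail] = updates[point]

rng = np.random.default_rng(0)
A, B, C, D = 2, 3, 4, 2
M, N, K = 2, 1, 3
# Distinct (I1, I2) pairs so the result does not depend on write order.
flat = rng.permutation(B * C)[: M * N * K].reshape(M, N, K)
I1, I2 = np.unravel_index(flat, (B, C))
updates = rng.standard_normal((A, M, N, K, D))

x = np.zeros((A, B, C, D))
setitem_inplace_reference(x, [I1, I2], updates, start_axis=1)

y = np.zeros((A, B, C, D))
y[:, I1, I2, :] = updates
print(np.array_equal(x, y))  # True
```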
TODO(GEX-1951): Support boolean tensor masks.
TODO(GEX-1952): Support the non-contiguous indexing tensor case.
TODO(GEX-1953): Support fusion (especially view-fusion).
TODO(GEX-1954): Unify getitem and setitem using generic views (requires non-strided view functions).
Parameters:
- input_rank (Int): The rank of the input tensor.
- index_rank (Int): The rank of the indexing tensors.
- updates_rank (Int): The rank of the updates tensor.
- input_type (DType): The dtype of the input tensor.
- index_type (DType): The dtype of the indexing tensors.
- start_axis (Int): The first dimension in input where the indexing tensors are applied. The indexing tensors are assumed to be applied in consecutive dimensions.
- num_index_tensors (Int): The number of indexing tensors.
- target (StringSlice[StaticConstantOrigin]): The target architecture to operate on.
- single_thread_blocking_override (Bool): If True, the operation runs synchronously using a single thread.
- trace_description (StringSlice[StaticConstantOrigin]): For profiling, the trace name the operation appears under.
- updates_tensor_fn (fn[Int](Index[updates_rank]) capturing -> SIMD[input_type, $0]): Fusion lambda for the updates tensor.
- indices_fn (fn[Int](Index[index_rank]) capturing -> SIMD[index_type, 1]): Fusion lambda for the indexing tensors.
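The rank and shape parameters are linked. Following the [A, B, C, D] example, the updates shape is the input shape with the num_index_tensors indexed dimensions (starting at start_axis) replaced by the index tensor shape, so updates_rank == input_rank - num_index_tensors + index_rank. The helper below is a hypothetical Python sketch of that relationship under NumPy semantics with consecutive index tensors; it is not part of the Mojo API.

```python
def expected_updates_shape(input_shape, index_tensor_shape, start_axis, num_index_tensors):
    """Hypothetical helper: the shape `updates` must have."""
    # The indexed dims must lie entirely inside the input.
    if not 0 <= start_axis <= len(input_shape) - num_index_tensors:
        raise ValueError("index tensors must fit inside input_shape")
    return (
        tuple(input_shape[:start_axis])                        # leading dims, unchanged
        + tuple(index_tensor_shape)                            # replaces the indexed dims
        + tuple(input_shape[start_axis + num_index_tensors:])  # trailing dims, unchanged
    )

# [A, B, C, D] = [2, 3, 4, 5], two index tensors of shape [6, 7, 8] on dims 1 and 2.
print(expected_updates_shape((2, 3, 4, 5), (6, 7, 8), start_axis=1, num_index_tensors=2))
# (2, 6, 7, 8, 5)
```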
Args:
- input_tensor (NDBuffer[input_type, input_rank, origin]): The input tensor being indexed into and modified in place.
- index_tensor_shape (Index[index_rank, element_type=element_type]): The shape of each index tensor.
- updates_tensor_strides (Index[updates_rank]): The strides of the updates tensor.
- ctx (DeviceContextPtr): The DeviceContextPtr as prepared by the graph compiler.
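Because the updates and index tensors are supplied as fusion lambdas rather than buffers, the kernel cannot read their shapes directly; this is presumably why index_tensor_shape is passed as an argument. The Python sketch below is a hypothetical analogue, not the Mojo implementation: `updates_fn(point)` returns one updates element, and `indices_fn(k, idx)` returns element idx of the k-th index tensor. Treating the first argument of indices_fn as the index-tensor number is an assumption; the Mojo signature only shows an Int parameter. The sketch also reads one element at a time, whereas the Int parameter of updates_tensor_fn appears to be a SIMD width.

```python
import numpy as np

def setitem_with_lambdas(input_tensor, index_tensor_shape, updates_fn,
                         indices_fn, start_axis, num_index_tensors):
    """Hypothetical Python analogue of the fused setitem kernel."""
    index_rank = len(index_tensor_shape)
    shape = input_tensor.shape
    # Common iteration range: input dims with the indexed ones replaced by
    # index_tensor_shape. No index tensor needs to be materialized for this.
    iteration_range = (shape[:start_axis] + tuple(index_tensor_shape)
                       + shape[start_axis + num_index_tensors:])
    for point in np.ndindex(*iteration_range):
        idx = point[start_axis:start_axis + index_rank]
        gathered = tuple(int(indices_fn(k, idx)) for k in range(num_index_tensors))
        dst = point[:start_axis] + gathered + point[start_axis + index_rank:]
        input_tensor[dst] = updates_fn(point)

# Usage: the elementwise op "10 * base" is fused into the updates read.
base = np.arange(2 * 3 * 2, dtype=np.float64).reshape(2, 3, 2)  # shape (A, M, D)
I = np.array([4, 0, 2])                                         # one rank-1 index tensor
x = np.zeros((2, 5, 2))
setitem_with_lambdas(
    x, I.shape,
    updates_fn=lambda p: 10.0 * base[p],
    indices_fn=lambda k, idx: I[idx],  # k is always 0 here
    start_axis=1, num_index_tensors=1,
)
ref = np.zeros((2, 5, 2))
ref[:, I, :] = 10.0 * base
print(np.array_equal(x, ref))  # True
```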