Mojo function

batched_matmul_shape

batched_matmul_shape[rank: Int, a_type: DType, b_type: DType, single_thread_blocking_override: Bool](a_buff: NDBuffer[a_type, rank, origin], b_buff: NDBuffer[b_type, rank, origin]) -> Index[rank]

Computes the output shape of a batch_matmul operation and asserts that the inputs are compatible.

Parameters:

  • rank (Int): Rank of the input and output tensors.
  • a_type (DType): Type of the lhs input tensor.
  • b_type (DType): Type of the rhs input tensor.
  • single_thread_blocking_override (Bool): If True, then the operation is run synchronously using a single thread.

Args:

  • a_buff (NDBuffer[a_type, rank, origin]): The lhs input tensor.
  • b_buff (NDBuffer[b_type, rank, origin]): The rhs input tensor.

Returns:

The output shape, as an Index[rank].
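
Example (illustrative sketch):

The page does not spell out which compatibility checks are performed, so the following is only a rough Python model of the conventional batched-matmul shape rule, not the Mojo implementation. The helper name batched_matmul_shape_sketch is hypothetical. It assumes that both inputs have the same rank, that there is at least one leading batch dimension, that the batch dimensions match exactly (no broadcasting), and that the last dimension of the lhs equals the second-to-last dimension of the rhs.

```python
from typing import Sequence, Tuple


def batched_matmul_shape_sketch(
    a_shape: Sequence[int], b_shape: Sequence[int]
) -> Tuple[int, ...]:
    """Hypothetical model of a batched matmul shape check.

    Assumed layout: a is [batch..., M, K] and b is [batch..., K, N].
    """
    rank = len(a_shape)

    # Both inputs share a single rank parameter in the documented signature.
    if len(b_shape) != rank:
        raise ValueError("lhs and rhs must have the same rank")

    # Assumption: at least one batch dimension plus the two matrix dimensions.
    if rank < 3:
        raise ValueError("batched matmul expects rank >= 3")

    # Assumption: batch dimensions must match exactly (no broadcasting).
    if tuple(a_shape[:-2]) != tuple(b_shape[:-2]):
        raise ValueError("batch dimensions of lhs and rhs differ")

    # The contraction dimension K must agree between lhs and rhs.
    if a_shape[-1] != b_shape[-2]:
        raise ValueError("lhs last dim must equal rhs second-to-last dim")

    # Output is [batch..., M, N].
    return tuple(a_shape[:-2]) + (a_shape[-2], b_shape[-1])


if __name__ == "__main__":
    print(batched_matmul_shape_sketch((4, 2, 3), (4, 3, 5)))
```

Under these assumptions, an lhs of shape (4, 2, 3) and an rhs of shape (4, 3, 5) produce an output shape of (4, 2, 5), while mismatched batch or contraction dimensions raise an error instead of returning a shape.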