Mojo function

get_accum_type

get_accum_type[dtype: DType, *, preferred_accum_type: DType = float32]() -> DType

Returns the recommended dtype for accumulation operations.

Half-precision and float8 types can introduce numerical error when they are used in reduction or accumulation operations. This function returns a higher-precision dtype to use for accumulation if a half-precision or float8 type is provided; otherwise, it returns the original dtype.
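To illustrate why a wider accumulator helps, here is a minimal Python sketch (not Mojo, and not part of this API). It emulates float16 rounding with the standard `struct` module's `e` format and uses a Python float as a stand-in for a wider accumulator. The helper names `to_float16`, `sum_in_float16`, and `sum_in_wide` are hypothetical.

```python
import struct


def to_float16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def sum_in_float16(values):
    """Accumulate with the running total rounded to float16 after every add."""
    total = 0.0
    for v in values:
        total = to_float16(total + to_float16(v))
    return total


def sum_in_wide(values):
    """Accumulate float16 inputs in a wider type (a Python float here)."""
    total = 0.0
    for v in values:
        total += to_float16(v)
    return total


ones = [1.0] * 4096
half_total = sum_in_float16(ones)
wide_total = sum_in_wide(ones)
print(half_total, wide_total)
```

In this sketch the float16 total stops growing at 2048.0, because float16 cannot represent 2049 and the sum rounds back down, while the wider accumulator reaches 4096.0.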

The rules are as follows:

  • If the dtype is a float8 type, return a float16 type.
  • If the dtype is a bfloat16 precision type, return a float32 type.
  • If the dtype is a float16 precision type, return a float32 dtype if the preferred_accum_type is float32, otherwise return a float16 type.
  • Otherwise, return the original type.
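The rules above can be modeled as a small decision function. The following Python sketch is only an illustration of the documented rules, not the Mojo implementation. It represents dtypes as plain strings, and the function name `accum_type` and the convention that any name starting with `float8` counts as a float8 type are assumptions made for this example.

```python
def accum_type(dtype: str, preferred_accum_type: str = "float32") -> str:
    """Model of the documented rules, with dtypes written as strings."""
    # Assumption: every float8 variant has a name beginning with "float8".
    if dtype.startswith("float8"):
        return "float16"
    if dtype == "bfloat16":
        return "float32"
    if dtype == "float16":
        # float16 widens only when float32 is the preferred accumulator.
        return "float32" if preferred_accum_type == "float32" else "float16"
    # All other dtypes are returned unchanged.
    return dtype


for name in ("float8_e5m2", "bfloat16", "float16", "float64", "int32"):
    print(name, "->", accum_type(name))
```

Under this model, `accum_type("float16", preferred_accum_type="float16")` stays at `"float16"`, which matches the third rule.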

Parameters:

  • dtype (DType): The dtype of some accumulation operation.
  • preferred_accum_type (DType): The preferred dtype for accumulation.

Returns:

The recommended accumulation dtype: a higher-precision type chosen by the rules above if dtype is a float8 or half-precision float, dtype otherwise.