Mojo function
get_accum_type
get_accum_type[dtype: DType, *, preferred_accum_type: DType = float32]() -> DType
Returns the recommended dtype for accumulation operations.
Half-precision and float8 types can introduce numerical error when they are used in reduction or accumulation operations. This function returns a higher-precision dtype to use for accumulation when a half-precision or float8 type is provided; otherwise it returns the original dtype.
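To see why a narrow accumulator is a problem, the following Python sketch emulates a float16 and a float32 accumulator. It rounds the running sum after every addition using the standard `struct` module's half-precision (`'e'`) and single-precision (`'f'`) formats. This is an illustration of the rounding behavior only, not Mojo code. The helper names `round_to` and `running_sum` are hypothetical.

```python
import struct

def round_to(fmt: str, x: float) -> float:
    """Round a Python float to the precision of a struct format ('e' = float16, 'f' = float32)."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]

def running_sum(values, fmt: str) -> float:
    """Sum values, rounding the accumulator to the given precision after every add."""
    acc = 0.0
    for v in values:
        acc = round_to(fmt, acc + v)
    return acc

ones = [1.0] * 4096
print(running_sum(ones, "e"))  # float16 accumulator stalls at 2048.0
print(running_sum(ones, "f"))  # float32 accumulator reaches 4096.0
```

Above 2048, adjacent float16 values are 2 apart. Each `2048 + 1` therefore rounds back to 2048 under round-half-to-even, and the sum stops growing. A float32 accumulator has enough precision to represent every intermediate sum exactly.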
The rules are as follows:
- If dtype is a float8 type, return float16.
- If dtype is bfloat16, return float32.
- If dtype is float16, return float32 if preferred_accum_type is float32; otherwise, return float16.
- Otherwise, return the original dtype.
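The selection rules can be modeled as a plain Python function. This is a hypothetical sketch for illustration, not the Mojo implementation. Dtypes are represented as strings, and the set of float8 names is an assumption based on common float8 variants.

```python
# Hypothetical Python model of the documented get_accum_type rules.
# Dtypes are plain strings; the float8 names below are assumptions.
FLOAT8_TYPES = {"float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz"}

def get_accum_type(dtype: str, preferred_accum_type: str = "float32") -> str:
    """Return the dtype to accumulate in, following the documented rules."""
    if dtype in FLOAT8_TYPES:
        # float8 inputs accumulate in float16.
        return "float16"
    if dtype == "bfloat16":
        # bfloat16 always widens to float32.
        return "float32"
    if dtype == "float16":
        # float16 widens only when float32 is the preferred accumulator.
        return "float32" if preferred_accum_type == "float32" else "float16"
    # Every other dtype is returned unchanged.
    return dtype

print(get_accum_type("bfloat16"))                                 # float32
print(get_accum_type("float16", preferred_accum_type="float16"))  # float16
print(get_accum_type("int32"))                                    # int32
```

In this model, preferred_accum_type only matters for float16 inputs. The bfloat16 and float8 rules ignore it.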
Parameters:
- dtype (DType): The dtype of the values being accumulated.
- preferred_accum_type (DType): The preferred dtype for accumulation. Under the rules above, it only affects float16 inputs.
Returns:
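The dtype to use for accumulation. This is float16 for float8 inputs and float32 for bfloat16 inputs. For float16 inputs, it is float32 or float16, depending on preferred_accum_type. For all other inputs, dtype is returned unchanged.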