Mojo function
get_accum_type
get_accum_type[dtype: DType, *, preferred_accum_type: DType = float32]() -> DType
Returns the recommended dtype for accumulation operations.
Half-precision and float8 types can introduce numerical error when they are used in reduction or accumulation operations. This function returns a higher-precision dtype to use for accumulation if a half-precision or float8 dtype is provided; otherwise it returns the original dtype.
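To illustrate why a narrow accumulator is a problem, the following Python sketch emulates float16 storage with the standard `struct` module (format code `e`) and sums 4096 ones. The helper names `to_f16` and `sum_with_f16_accumulator` are hypothetical and exist only for this illustration. This is not Mojo code, and it makes no claim about how Mojo implements reductions.

```python
import struct


def to_f16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def sum_with_f16_accumulator(values):
    """Sum values, rounding the running total to float16 after every add."""
    total = 0.0
    for v in values:
        total = to_f16(total + to_f16(v))
    return total


ones = [1.0] * 4096
narrow = sum_with_f16_accumulator(ones)  # running total kept in float16
wide = sum(ones)                         # running total kept in a wider float
print(narrow, wide)  # 2048.0 4096.0
```

Once the float16 total reaches 2048, adding 1.0 no longer changes it, because adjacent float16 values in that range are 2 apart. Keeping the running total in a wider type avoids this stall.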
The rules are as follows:
- If the dtype is a float8 type, return a float16 type.
- If the dtype is a bfloat16 precision type, return a float32 type.
- If the dtype is a float16 precision type, return a float32 dtype if the preferred_accum_type is float32, otherwise return a float16 type.
- Otherwise, return the original type.
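The rules above can be modelled as a small decision function. The sketch below is written in Python, with dtypes represented as plain strings. `accum_type_sketch` and `FLOAT8_TYPES` are hypothetical names used only for illustration, and the float8 variants listed are an assumption about which dtypes count as float8. It mirrors the documented rules only and is not the Mojo implementation.

```python
# Assumed float8 dtype names, for illustration only.
FLOAT8_TYPES = {"float8_e4m3fn", "float8_e5m2"}


def accum_type_sketch(dtype: str, preferred_accum_type: str = "float32") -> str:
    """Model of the documented rules, using dtype names as strings."""
    if dtype in FLOAT8_TYPES:
        return "float16"
    if dtype == "bfloat16":
        return "float32"
    if dtype == "float16":
        # float16 widens only when float32 is the preferred accumulator.
        return "float32" if preferred_accum_type == "float32" else "float16"
    # Every other dtype is returned unchanged.
    return dtype


for name in ("float8_e5m2", "bfloat16", "float16", "float64", "int32"):
    print(name, "->", accum_type_sketch(name))
```

Under these rules, preferred_accum_type affects the result only when dtype is float16; every other input dtype maps to a fixed output.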
Parameters:
- dtype (DType): The dtype of some accumulation operation.
- preferred_accum_type (DType): The preferred dtype for accumulation.
Returns:
The recommended dtype for accumulation operations based on the input dtype and the preferred accumulation type.