
Mojo function

swishGLU

swishGLU[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, target: StringSlice[StaticConstantOrigin] = __init__[__mlir_type.!kgen.string]("cpu")](a: NDBuffer[a_type, 2, MutableAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutableAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutableAnyOrigin, b_shape], c: NDBuffer[c_type, 2, MutableAnyOrigin, c_shape], ctx: DeviceContextPtr)

Reference: "GLU Variants Improve Transformer" by Noam Shazeer, https://arxiv.org/pdf/2002.05202v1

The implementation follows CUTLASS, using one kernel invocation and writing to the destination once.
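For orientation, here is a minimal pure-Python sketch of the SwiGLU formula from the referenced paper, Swish(xW) ⊗ (xV). It is not the Mojo kernel. The name `swish_glu_reference`, the mapping of `a`, `b0`, `b1` onto x, W, V, and the shapes (`a` as M x K, `b0` and `b1` as K x N, with no transposition) are assumptions made for illustration; the page does not state the operand layout. Both products are accumulated in the same loop and each output element is written once, which loosely mirrors the single-invocation, single-write behavior described above.

```python
import math


def swish(x: float) -> float:
    """Swish with beta = 1 (also known as SiLU): x * sigmoid(x)."""
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    # For negative x, use the equivalent form to avoid overflow in exp(-x).
    e = math.exp(x)
    return x * e / (1.0 + e)


def swish_glu_reference(a, b0, b1):
    """Hypothetical reference: c = swish(a @ b0) * (a @ b1), elementwise.

    Assumed shapes (not confirmed by the source):
    a is M x K, b0 and b1 are K x N, result is M x N.
    """
    m, k, n = len(a), len(b0), len(b0[0])
    c = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc0 = 0.0  # running dot product for a @ b0
            acc1 = 0.0  # running dot product for a @ b1
            for p in range(k):
                acc0 += a[i][p] * b0[p][j]
                acc1 += a[i][p] * b1[p][j]
            # Apply the gate and write each output element exactly once.
            c[i][j] = swish(acc0) * acc1
    return c
```

Calling `swish_glu_reference([[1.0, 2.0]], [[1.0], [0.0]], [[0.0], [3.0]])` gives a 1 x 1 result equal to `swish(1.0) * 6.0`. A reference like this is mainly useful for checking a fused kernel's output on small inputs.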