/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

/*
 * Scalar hard swish.
 *
 * Piecewise: x itself once x reaches 3, zero at or below -3, and the
 * quadratic ramp x * (x + 3) / 6 in between. The two bounds are mutually
 * exclusive, so the order of the guards does not matter; a NaN input fails
 * both guards and flows through the ramp.
 */
float hardswish(float x) {
  if (x >= 3.0) {
    return x;
  }
  if (x <= -3.0) {
    return 0.0;
  }
  return x * (x + 3.0) / 6.0;
}

/*
 * Component-wise hard swish for a texel, evaluated as whole-vector selects.
 *
 * The ramp x * (x + 3) / 6 is computed for every lane. mix() with a bvec4
 * selector then picks, per lane, the identity for lanes >= 3 and zero for
 * lanes <= -3. Lanes that are not selected do not affect the result, so an
 * overflowing ramp on large inputs is harmless.
 */
vec4 hardswish(vec4 tex) {
  vec4 ramp = tex * (tex + 3.0) / 6.0;
  vec4 capped = mix(ramp, tex, greaterThanEqual(tex, vec4(3.0)));
  return mix(capped, vec4(0.0), lessThanEqual(tex, vec4(-3.0)));
}

/*
 * Scalar hard shrink: returns 0 when neg_lambda <= x <= lambda, otherwise x.
 *
 * NOTE(review): assumes the caller passes neg_lambda == -lambda; confirm at
 * the dispatch site.
 *
 * Band membership is tested directly instead of summing the masks
 * float(x > lambda) + float(x < neg_lambda). When lambda is negative, both
 * masks can be 1 at once, and the sum returned 2 * x instead of x. A NaN
 * input fails the band test and is returned unchanged.
 */
float hardshrink(float x, float lambda, float neg_lambda) {
  return (x >= neg_lambda && x <= lambda) ? 0.0 : x;
}

/*
 * Component-wise hard shrink for a texel: each lane becomes 0 when
 * neg_lambda <= lane <= lambda, otherwise it keeps its value.
 *
 * NOTE(review): assumes neg_lambda == -lambda; confirm at the dispatch site.
 *
 * GLSL has no && for bvec4, so the two band bounds are combined by
 * multiplying their 0/1 masks, which gives a component-wise AND. The product
 * is then used as a mix() selector. The previous form summed
 * greaterThan/lessThan masks, which doubled a lane when lambda < 0 made both
 * masks true. A NaN lane fails both bounds and passes through unchanged.
 */
vec4 hardshrink(vec4 tex, float lambda, float neg_lambda) {
  vec4 in_band = vec4(greaterThanEqual(tex, vec4(neg_lambda))) *
      vec4(lessThanEqual(tex, vec4(lambda)));
  return mix(tex, vec4(0.0), bvec4(in_band));
}

/*
 * Scalar hard sigmoid, relu6(x + 3) / 6: the line x / 6 + 0.5 limited to
 * [0, 1].
 *
 * The previous mix(step, line, weight) form always evaluated line * weight.
 * For x = +/-inf that is inf * 0 = NaN, so infinite inputs returned NaN
 * instead of 1 / 0. Comparing against the limits selects a value without any
 * multiplication by zero. A NaN input fails both comparisons and is returned
 * as-is.
 */
float hardsigmoid(float x) {
  float y = x / 6.0 + 0.5;
  if (y <= 0.0) {
    return 0.0;
  }
  if (y >= 1.0) {
    return 1.0;
  }
  return y;
}

/*
 * Component-wise hard sigmoid for a texel: applies the scalar hardsigmoid
 * overload to each of the four lanes in turn.
 */
vec4 hardsigmoid(vec4 tex) {
  vec4 result;
  for (int lane = 0; lane < 4; ++lane) {
    result[lane] = hardsigmoid(tex[lane]);
  }
  return result;
}
