• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef MetalNeuronType_h
2 #define MetalNeuronType_h
3 
4 #import <ATen/native/metal/mpscnn/MPSCNNNeuronOp.h>
5 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
6 
7 #include <ATen/ATen.h>
8 
9 namespace at::native::metal {
10 
/// Activation ("neuron") kinds understood by the Metal helpers in this header.
///
/// The numeric values are spelled out explicitly; they match the implicit
/// declaration order (None first, Tanh last).
enum class NeuronType {
  /// No activation. neuron() yields nil and neuronDescriptor() yields a
  /// descriptor built with MPSCNNNeuronTypeNone for this value.
  None = 0,
  /// Range clamp. Selected by neuronType() when the upper bound is below
  /// +inf and the lower bound is above -inf. Neither neuron() nor
  /// neuronDescriptor() has a dedicated branch for it.
  Clamp = 1,
  /// Selected by neuronType() when the lower bound is 0 and the upper bound
  /// is +inf.
  Relu = 2,
  Sigmoid = 3,
  HardSigmoid = 4,
  Tanh = 5,
};
19 
/// Derives the activation implied by an optional output range.
///
/// A bound that is not provided is treated as unbounded: +inf for
/// `output_max`, -inf for `output_min`. Both bounds are compared as `float`.
///
/// @param output_min Optional lower bound of the output range.
/// @param output_max Optional upper bound of the output range.
/// @return NeuronType::Relu  when the upper bound is +inf and the lower bound
///                           equals 0;
///         NeuronType::Clamp when the upper bound is < +inf and the lower
///                           bound is > -inf;
///         NeuronType::None  in every other case (e.g. only one bound given
///                           and it is not the Relu pattern).
static inline NeuronType neuronType(
    std::optional<c10::Scalar> output_min,
    std::optional<c10::Scalar> output_max) {
  constexpr float kPosInf = std::numeric_limits<float>::infinity();
  constexpr float kNegInf = -kPosInf;

  // Resolve the upper bound first, then the lower one (same order as the
  // conversions were always performed in).
  float upper = kPosInf;
  if (output_max.has_value()) {
    upper = output_max.value().toFloat();
  }
  float lower = kNegInf;
  if (output_min.has_value()) {
    lower = output_min.value().toFloat();
  }

  if (upper == kPosInf && lower == 0) {
    return NeuronType::Relu;
  }
  if (upper < kPosInf && lower > kNegInf) {
    return NeuronType::Clamp;
  }
  return NeuronType::None;
}
37 
/// Maps a NeuronType to the matching MPSCNNNeuron object.
///
/// @param type Activation kind to look up.
/// @return The result of the corresponding MPSCNNNeuronOp class method for
///         Relu, Sigmoid, Tanh and HardSigmoid; nil for every other value
///         (including None and Clamp).
static inline MPSCNNNeuron* neuron(NeuronType type) {
  switch (type) {
    case NeuronType::Relu:
      return [MPSCNNNeuronOp relu];
    case NeuronType::Sigmoid:
      return [MPSCNNNeuronOp sigmoid];
    case NeuronType::Tanh:
      return [MPSCNNNeuronOp tanh];
    case NeuronType::HardSigmoid:
      return [MPSCNNNeuronOp hardSigmoid];
    case NeuronType::None:
    case NeuronType::Clamp:
      break;
  }
  // No dedicated neuron object for this kind.
  return nil;
}
51 
/// Maps a NeuronType to the matching MPSNNNeuronDescriptor.
///
/// @param type Activation kind to look up.
/// @return The result of the corresponding MPSCNNNeuronOpDescriptor class
///         method for Relu, Sigmoid, Tanh and HardSigmoid; for every other
///         value (including None and Clamp) a descriptor created with
///         MPSCNNNeuronTypeNone. Unlike neuron(), this never returns nil on
///         the fallback path.
API_AVAILABLE(ios(11.3), macos(10.13), macCatalyst(13.0))
static inline MPSNNNeuronDescriptor* neuronDescriptor(NeuronType type) {
  switch (type) {
    case NeuronType::Relu:
      return [MPSCNNNeuronOpDescriptor reluDescriptor];
    case NeuronType::Sigmoid:
      return [MPSCNNNeuronOpDescriptor sigmoidDescriptor];
    case NeuronType::Tanh:
      return [MPSCNNNeuronOpDescriptor tanhDescriptor];
    case NeuronType::HardSigmoid:
      return [MPSCNNNeuronOpDescriptor hardSigmoidDescriptor];
    case NeuronType::None:
    case NeuronType::Clamp:
      break;
  }
  // Fallback: an explicit "no activation" descriptor.
  return [MPSNNNeuronDescriptor cnnNeuronDescriptorWithType:MPSCNNNeuronTypeNone];
}
66 
67 } // namespace at::native::metal
68 
69 #endif /* MetalNeuronType_h */
70