1 file changed: +24 −0 lines changed
import TensorFlow
#endif

/// Computes the mean squared error between predicted and expected values.
///
/// - Parameters:
///   - predicted: Outputs from a neural network.
///   - expected: Expected values, i.e. targets, that correspond to the correct output.
@differentiable
public func meanSquaredError<Scalar: TensorFlowFloatingPoint>(
    predicted: Tensor<Scalar>, expected: Tensor<Scalar>) -> Tensor<Scalar> {
    return (expected - predicted).squared().mean()
}

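A minimal usage sketch for `meanSquaredError` (the tensor values are illustrative, not part of the diff):

let predicted: Tensor<Float> = [0.9, 0.2, 0.1]
let expected: Tensor<Float> = [1.0, 0.0, 0.0]
// ((1.0 - 0.9)^2 + (0.0 - 0.2)^2 + (0.0 - 0.1)^2) / 3 = 0.02
let loss = meanSquaredError(predicted: predicted, expected: expected)
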
/// Computes the softmax cross entropy (categorical cross entropy) between logits and labels.
///
/// - Parameters:
///   - logits: Unnormalized outputs from a neural network, one score per class.
///   - labels: One-hot encoded values that correspond to the correct output.
@differentiable
public func softmaxCrossEntropy<Scalar: TensorFlowFloatingPoint>(
    logits: Tensor<Scalar>, labels: Tensor<Scalar>) -> Tensor<Scalar> {
    return -(labels * logSoftmax(logits)).mean(alongAxes: 0).sum()
}

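A hypothetical call for a batch of two examples over three classes; since the implementation averages along axis 0 (the batch axis) and then sums, the result is the cross entropy averaged over the batch:

let logits: Tensor<Float> = [[2.0, 1.0, 0.1],
                             [0.5, 2.5, 0.3]]
// One-hot targets: class 0 for the first example, class 1 for the second.
let labels: Tensor<Float> = [[1, 0, 0],
                             [0, 1, 0]]
let loss = softmaxCrossEntropy(logits: logits, labels: labels)
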
/// Computes the sigmoid cross entropy (binary cross entropy) between logits and labels.
///
/// - Parameters:
///   - logits: Sigmoid outputs of a neural network; single continuous values strictly between `0` and `1`.
///   - labels: Values of `0` or `1` that correspond to the correct output.
@differentiable
public func sigmoidCrossEntropy<Scalar: TensorFlowFloatingPoint>(
    logits: Tensor<Scalar>, labels: Tensor<Scalar>
) -> Tensor<Scalar> {
    let loss = labels * log(logits) +
        (Tensor<Scalar>(1) - labels) * log(Tensor<Scalar>(1) - logits)
    return -loss.mean(alongAxes: 0).sum()
}
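
A hypothetical call for three binary predictions. The implementation applies `log` directly to `logits`, so the inputs must already be sigmoid outputs lying strictly between 0 and 1 (values of exactly 0 or 1 would produce infinities):

// Probabilities already passed through a sigmoid.
let probabilities: Tensor<Float> = [[0.8], [0.3], [0.6]]
let labels: Tensor<Float> = [[1], [0], [1]]
let loss = sigmoidCrossEntropy(logits: probabilities, labels: labels)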