File: softmax_focal_loss_op.cu

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites:
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (256 lines) | stat: -rw-r--r-- 8,349 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
/**
 * Copyright (c) 2016-present, Facebook, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <cfloat>

#include "caffe2/core/context_gpu.h"
#include "modules/detectron/softmax_focal_loss_op.h"

namespace caffe2 {

namespace {

__global__ void SpatialSoftmaxKernel(const int N, const int A,
    const int H, const int W, const float* Xdata, float* Pdata,
    const int num_classes) {
  CUDA_1D_KERNEL_LOOP(index, N * A * H * W) {
    int D = num_classes * A;
    int x = index % W;
    int y = (index / W) % H;
    int a = (index / (W * H)) % A;
    int i = index / W / H / A;

    // Subtract max on each cell for numerical reasons
    float max_val = -FLT_MAX;
    for(int c = a * num_classes; c < (a + 1) * num_classes; ++c) {
      int idx = i * (H * W * D) +  c * (H * W) + y * W + x;
      max_val = max(max_val, Xdata[idx]);
    }
    // Exponentiate
    float expsum = 0.0f;
    for(int c = a * num_classes; c < (a + 1) * num_classes; ++c) {
      int idx = i * (H * W * D) + c * (H * W) + y * W + x;
      float expx = exp(Xdata[idx] - max_val);
      Pdata[idx] = expx;
      expsum += expx;
    }
    // Normalize
    for(int c = a * num_classes; c < (a + 1) * num_classes; ++c) {
      int idx = i * (H * W * D) + c * (H * W) + y * W + x;
      Pdata[idx] /= expsum;
    }
  }
}


__global__ void SoftmaxFocalLossKernel(
    const int N, const int A, const int H, const int W,
    const float* Pdata, const int* targets, float* losses,
    const float* weight_pos, const float gamma, const float alpha,
    const int num_classes) {
  CUDA_1D_KERNEL_LOOP(i, N * A * H * W) {
    int D = A * num_classes;
    int x = i % W;
    int y = (i / W) % H;
    int a = (i / (W * H)) % A;
    int n = i / (W * H * A);
    const int label = static_cast<int>(targets[i]);

    float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
    float z = (label == 0) * (1 - alpha) / Np +
              (label >= 1) * alpha / Np;

    losses[i] = 0.0;
    if (label >= 0) {
      int offset = a * num_classes;
      int idx = n * (H * W * D) + (offset + label) * (H * W) + y * W + x;
      losses[i] =
          -(pow(1.0f - Pdata[idx], gamma) *
          log(c10::cuda::compat::max(Pdata[idx], FLT_MIN))) * z;
    }
  }
}


__global__ void SoftmaxFocalLossGradientWeightKernel(
    const int N, const int A, const int H, const int W,
    const float* Pdata, const int* targets, float* buff,
    const float* weight_pos, const float gamma, const float alpha,
    const int num_classes) {
  CUDA_1D_KERNEL_LOOP(i, N * A * H * W) {
    int D = A * num_classes;
    int x = i % W;
    int y = (i / W) % H;
    int a = (i / (W * H)) % A;
    int n = i / (W * H * A);
    const int label = static_cast<int>(targets[i]);
    float Np = c10::cuda::compat::max(weight_pos[0], static_cast<float>(1.0));
    float z =  (label == 0) * (1 - alpha) / Np +
               (label >= 1) * alpha / Np;

    buff[i] = 0.0;
    if (label >= 0) {
      int offset = a * num_classes;
      int idx = n * (H * W * D) + (offset + label) * (H * W) + y * W + x;
      float onemp = 1. - Pdata[idx];
      float p = Pdata[idx];
      buff[i] =
          (-pow(onemp, gamma) +
          gamma * pow(onemp, gamma - 1) * p * log(c10::cuda::compat::max(p, FLT_MIN))) * z;
    }
  }
}


__global__ void SoftmaxFocalLossGradientKernel(
    const int N, const int D, const int H, const int W,
    const float* Pdata, const int* targets, const float* buff,
    const float* d_loss_data, float* dX, const int num_classes) {
  CUDA_1D_KERNEL_LOOP(i, N * D * H * W) {
    int A = D / num_classes;
    int x = i % W;
    int y = (i / W) % H;
    int d = (i / (W * H)) % D;
    int a = d / num_classes;
    int c = d % num_classes;
    int n = i / (W * H * D);
    float d_loss = *d_loss_data;

    int ind = n * (H * W * A) + a * (H * W) + y * W + x;
    const int label = static_cast<int>(targets[ind]);

    float c1 = (label >= 0) * 1.0;
    float c2 = (label == c) * 1.0;
    dX[i] = 0.0;
    dX[i] = c1 * d_loss * buff[ind] * (c2 - Pdata[i]);
  }
}

} // namespace


template <>
bool SoftmaxFocalLossOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);         // Logits
  auto& T = Input(1);         // Labels
  auto& wp = Input(2);        // num of foreground
   // average loss as output
          // softmax probability, going to be re-used in gradient

  int N = X.dim32(0);
  int D = X.dim32(1);
  int H = X.dim32(2);
  int W = X.dim32(3);
  int A = D / num_classes_;

  ReinitializeTensor(&losses_, {N * A * H * W}, at::dtype<float>().device(CUDA));
  auto* P = Output(1, {N * D * H * W}, at::dtype<float>());
  auto* avg_loss = Output(0, vector<int64_t>(), at::dtype<float>());
  math::Set<float, CUDAContext>(
      avg_loss->size(), 0.f, avg_loss->mutable_data<float>(), &context_);
  math::Set<float, CUDAContext>(
      P->size(), 0.f, P->mutable_data<float>(), &context_);
  math::Set<float, CUDAContext>(
      losses_.size(), 0.f, losses_.mutable_data<float>(), &context_);
  TORCH_DCHECK_EQ(X.ndim(), 4);

  const float* Xdata = X.data<float>();
  const float* Wdata = wp.data<float>();


  // Spatial Softmax Kernel
  SpatialSoftmaxKernel
      <<<CAFFE_GET_BLOCKS(N * A * H * W), CAFFE_CUDA_NUM_THREADS,
         0, context_.cuda_stream()>>>(
    N, A, H, W, Xdata, P->mutable_data<float>(), num_classes_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // Compute loss for each x,y location
  const int* Tdata = T.data<int>();
  SoftmaxFocalLossKernel
  <<<CAFFE_GET_BLOCKS(N * A * H * W), CAFFE_CUDA_NUM_THREADS,
      0, context_.cuda_stream()>>>(
    N, A, H, W, P->data<float>(), Tdata, losses_.mutable_data<float>(),
    Wdata, gamma_, alpha_, num_classes_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // sum the losses
  float* avg_loss_data = avg_loss->mutable_data<float>();
  math::Sum<float, CUDAContext>(
      losses_.size(), losses_.data<float>(), avg_loss_data, &context_);
  math::Scale<float, float, CUDAContext>(
      1, scale_, avg_loss_data, avg_loss_data, &context_);

  return true;
}


template<>
bool SoftmaxFocalLossGradientOp<float, CUDAContext>::RunOnDevice() {
  auto& X = Input(0);    // Logits
  auto& T = Input(1);    // Label
  auto& wp = Input(2);   // num of foreground example
  auto& P = Input(3);    // Softmax Probability
  auto& d_avg_loss = Input(4);


  int N = X.dim32(0);
  int D = X.dim32(1);
  int H = X.dim32(2);
  int W = X.dim32(3);
  int A = D / num_classes_;

  ReinitializeTensor(&buff_, {N * A * H * W}, at::dtype<float>().device(CUDA));

  auto* dX = Output(0, X.sizes(), at::dtype<float>()); // gradient wrt logits

  const float* Xdata = X.data<float>();
  const int* Tdata = T.data<int>();
  const float* Pdata = P.data<float>();
  const float* Wdata = wp.data<float>();


  // Compute the weight for gradients
  SoftmaxFocalLossGradientWeightKernel
      <<<CAFFE_GET_BLOCKS(N * A * H * W), CAFFE_CUDA_NUM_THREADS,
         0, context_.cuda_stream()>>>(
    N, A, H, W, Pdata, Tdata, buff_.mutable_data<float>(),
    Wdata, gamma_, alpha_, num_classes_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  // Compute the gradient with the weights
  const float* Bdata = buff_.data<float>();
  SoftmaxFocalLossGradientKernel
      <<<CAFFE_GET_BLOCKS(N * D * H * W), CAFFE_CUDA_NUM_THREADS,
         0, context_.cuda_stream()>>>(
    N, D, H, W, Pdata, Tdata, Bdata, d_avg_loss.data<float>(),
    dX->mutable_data<float>(), num_classes_);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  math::Scale<float, float, CUDAContext>(
      dX->size(),
      scale_,
      dX->data<float>(),
      dX->mutable_data<float>(),
      &context_);
  return true;
}


REGISTER_CUDA_OPERATOR(SoftmaxFocalLoss,
                       SoftmaxFocalLossOp<float, CUDAContext>);
REGISTER_CUDA_OPERATOR(SoftmaxFocalLossGradient,
                       SoftmaxFocalLossGradientOp<float, CUDAContext>);
} // namespace caffe2