diff --git a/libavfilter/vf_dnn_detect.c b/libavfilter/vf_dnn_detect.c index 4c537cf255..fcc64118b6 100644 --- a/libavfilter/vf_dnn_detect.c +++ b/libavfilter/vf_dnn_detect.c @@ -35,7 +35,8 @@ typedef enum { DDMT_SSD, DDMT_YOLOV1V2, - DDMT_YOLOV3 + DDMT_YOLOV3, + DDMT_YOLOV4 } DNNDetectionModelType; typedef struct DnnDetectContext { @@ -75,6 +76,7 @@ static const AVOption dnn_detect_options[] = { { "ssd", "output shape [1, 1, N, 7]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_SSD }, 0, 0, FLAGS, "model_type" }, { "yolo", "output shape [1, N*Cx*Cy*DetectionBox]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV1V2 }, 0, 0, FLAGS, "model_type" }, { "yolov3", "outputs shape [1, N*D, Cx, Cy]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV3 }, 0, 0, FLAGS, "model_type" }, + { "yolov4", "outputs shape [1, N*D, Cx, Cy]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV4 }, 0, 0, FLAGS, "model_type" }, { "cell_w", "cell width", OFFSET2(cell_w), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS }, { "cell_h", "cell height", OFFSET2(cell_h), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS }, { "nb_classes", "The number of class", OFFSET2(nb_classes), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS }, @@ -84,6 +86,14 @@ static const AVOption dnn_detect_options[] = { AVFILTER_DEFINE_CLASS(dnn_detect); +static inline float sigmoid(float x) { + return 1.f / (1.f + exp(-x)); +} + +static inline float linear(float x) { + return x; +} + static int dnn_detect_get_label_id(int nb_classes, int cell_size, float *label_data) { float max_prob = 0; @@ -147,6 +157,8 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out float *output_data = output[output_index].data; float *anchors = ctx->anchors; AVDetectionBBox *bbox; + float (*post_process_raw_data)(float x); + int is_NHWC = 0; if (ctx->model_type == DDMT_YOLOV1V2) { cell_w = ctx->cell_w; @@ -154,13 +166,30 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out scale_w = cell_w; scale_h = cell_h; } else { - cell_w = output[output_index].width; - cell_h = output[output_index].height; + if (output[output_index].height != output[output_index].width && + output[output_index].height == output[output_index].channels) { + is_NHWC = 1; + cell_w = output[output_index].height; + cell_h = output[output_index].channels; + } else { + cell_w = output[output_index].width; + cell_h = output[output_index].height; + } scale_w = ctx->scale_width; scale_h = ctx->scale_height; } box_size = nb_classes + 5; + switch (ctx->model_type) { + case DDMT_YOLOV1V2: + case DDMT_YOLOV3: + post_process_raw_data = linear; + break; + case DDMT_YOLOV4: + post_process_raw_data = sigmoid; + break; + } + if (!cell_h || !cell_w) { av_log(filter_ctx, AV_LOG_ERROR, "cell_w and cell_h are detected\n"); return AVERROR(EINVAL); @@ -198,19 +227,36 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out float *detection_boxes_data; int label_id; - detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h; - conf = detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h]; + if (is_NHWC) { + detection_boxes_data = output_data + + ((cy * cell_w + cx) * detection_boxes + box_id) * box_size; + conf = post_process_raw_data(detection_boxes_data[4]); + } else { + detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h; + conf = post_process_raw_data( + detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h]); + } if (conf < conf_threshold) { continue; } - x = detection_boxes_data[cy * cell_w + cx]; - y = detection_boxes_data[cy * cell_w + cx + cell_w * cell_h]; - w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h]; - h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h]; - label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h, - detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h); - conf = conf * detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h]; + if (is_NHWC) { + x = post_process_raw_data(detection_boxes_data[0]); + y = post_process_raw_data(detection_boxes_data[1]); + w = detection_boxes_data[2]; + h = detection_boxes_data[3]; + label_id = dnn_detect_get_label_id(ctx->nb_classes, 1, detection_boxes_data + 5); + conf = conf * post_process_raw_data(detection_boxes_data[label_id + 5]); + } else { + x = post_process_raw_data(detection_boxes_data[cy * cell_w + cx]); + y = post_process_raw_data(detection_boxes_data[cy * cell_w + cx + cell_w * cell_h]); + w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h]; + h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h]; + label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h, + detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h); + conf = conf * post_process_raw_data( + detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h]); + } bbox = av_mallocz(sizeof(*bbox)); if (!bbox) @@ -410,6 +456,7 @@ static int dnn_detect_post_proc_ov(AVFrame *frame, DNNData *output, int nb_outpu if (ret < 0) return ret; case DDMT_YOLOV3: + case DDMT_YOLOV4: ret = dnn_detect_post_proc_yolov3(frame, output, filter_ctx, nb_outputs); if (ret < 0) return ret;