Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add adjust_contrast, adjust_hue, combined_non_max_suppression, crop_and_resize image oprs #1157

Merged
merged 2 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 119 additions & 12 deletions src/TensorFlowNET.Core/APIs/tf.image.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ You may obtain a copy of the License at
limitations under the License.
******************************************************************************/

using OneOf.Types;
using System;
using System.Buffers.Text;
using Tensorflow.Contexts;
using static Tensorflow.Binding;

namespace Tensorflow
Expand Down Expand Up @@ -162,17 +166,108 @@ public Tensor ssim_multiscale(Tensor img1, Tensor img2, float max_val, float[] p
public Tensor sobel_edges(Tensor image)
{
    // Pure pass-through: the tensor is handed unchanged to image_ops_impl.sobel_edges
    // and whatever that returns is returned to the caller.
    return image_ops_impl.sobel_edges(image);
}

/// <summary>
/// Decodes JPEG-encoded data by forwarding every argument to <c>gen_image_ops.decode_jpeg</c>.
/// </summary>
/// <param name="contents">Tensor holding the encoded image data.</param>
/// <param name="channels">Forwarded as the op's <c>channels</c> argument. Defaults to 0.</param>
/// <param name="ratio">Forwarded as the op's <c>ratio</c> argument. Defaults to 1.</param>
/// <param name="fancy_upscaling">Forwarded as the op's <c>fancy_upscaling</c> argument. Defaults to true.</param>
/// <param name="try_recover_truncated">Forwarded as the op's <c>try_recover_truncated</c> argument. Defaults to false.</param>
/// <param name="acceptable_fraction">Forwarded as the op's <c>acceptable_fraction</c> argument. Defaults to 1.</param>
/// <param name="dct_method">Forwarded as the op's <c>dct_method</c> argument. Defaults to an empty string.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>The tensor produced by <c>gen_image_ops.decode_jpeg</c>.</returns>
public Tensor decode_jpeg(Tensor contents,
    int channels = 0,
    int ratio = 1,
    bool fancy_upscaling = true,
    bool try_recover_truncated = false,
    int acceptable_fraction = 1,
    string dct_method = "",
    string name = null)
    // `name` is forwarded as well; previously it was accepted but never reached the op.
    => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
        fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
        acceptable_fraction: acceptable_fraction, dct_method: dct_method, name: name);
/// <summary>
/// Adjusts the contrast of RGB or grayscale images.
/// </summary>
/// <param name="images">Images to adjust. At least 3-D.</param>
/// <param name="contrast_factor">A float multiplier for adjusting contrast.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>The contrast-adjusted image or images.</returns>
public Tensor adjust_contrast(Tensor images, float contrast_factor, string name = null)
{
    // Arguments are passed through untouched to the adjust_contrastv2 op wrapper.
    return gen_image_ops.adjust_contrastv2(images, contrast_factor, name);
}

/// <summary>
/// Adjusts the hue of RGB images.
/// </summary>
/// <param name="images">RGB image or images. The size of the last dimension must be 3.</param>
/// <param name="delta">How much to add to the hue channel.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>Adjusted image(s), same shape and DType as <paramref name="images"/>.</returns>
/// <exception cref="ValueError">
/// Thrown when executing eagerly and <paramref name="delta"/> is outside the interval <c>[-1, 1]</c>.
/// </exception>
public Tensor adjust_hue(Tensor images, float delta, string name = null)
{
    // The range check is only enforced in eager mode; outside eager mode the value
    // is handed to the op without validation here.
    bool deltaOutOfRange = delta < -1f || delta > 1f;
    if (tf.Context.executing_eagerly() && deltaOutOfRange)
    {
        throw new ValueError("delta must be in the interval [-1, 1]");
    }

    return gen_image_ops.adjust_hue(images, delta, name: name);
}

/// <summary>
/// Adjusts the saturation of RGB images.
/// </summary>
/// <param name="image">RGB image or images. The size of the last dimension must be 3.</param>
/// <param name="saturation_factor">Factor to multiply the saturation by.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>Adjusted image(s), same shape and DType as <paramref name="image"/>.</returns>
public Tensor adjust_saturation(Tensor image, float saturation_factor, string name = null)
{
    // Thin wrapper: all three arguments go straight to the adjust_saturation op wrapper.
    return gen_image_ops.adjust_saturation(image, saturation_factor, name);
}

/// <summary>
/// Greedily selects a subset of bounding boxes in descending order of score.
/// </summary>
/// <param name="boxes">
/// A 4-D float tensor of shape <c>[batch_size, num_boxes, q, 4]</c>. If <c>q</c> is 1 the
/// same boxes are used for all classes; if <c>q</c> equals the number of classes,
/// class-specific boxes are used.
/// </param>
/// <param name="scores">
/// A 3-D float tensor of shape <c>[batch_size, num_boxes, num_classes]</c> holding a single
/// score for each box (each row of boxes).
/// </param>
/// <param name="max_output_size_per_class">
/// Maximum number of boxes to be selected by non-max suppression per class.
/// Converted to a scalar tensor before the op is invoked.
/// </param>
/// <param name="max_total_size">
/// Maximum number of boxes retained over all classes. Note that setting this value to a
/// large number may result in an OOM error depending on the system workload.
/// Converted to a scalar tensor before the op is invoked.
/// </param>
/// <param name="iou_threshold">
/// Threshold for deciding whether boxes overlap too much with respect to IOU.
/// Converted to a float tensor named "iou_threshold".
/// </param>
/// <param name="score_threshold">
/// Threshold for deciding when to remove boxes based on score.
/// Converted to a float tensor named "score_threshold".
/// </param>
/// <param name="pad_per_class">
/// If false, the output nmsed boxes, scores and classes are padded/clipped to
/// <c>max_total_size</c>. If true, they are padded to be of length
/// <c>max_size_per_class * num_classes</c>, unless that exceeds <c>max_total_size</c>, in
/// which case they are clipped to <c>max_total_size</c>. Defaults to false.
/// </param>
/// <param name="clip_boxes">
/// If true, the coordinates of the output nmsed boxes are clipped to <c>[0, 1]</c>.
/// If false, the box coordinates are output as they are. Defaults to true.
/// </param>
/// <returns>
/// Four tensors:
/// 'nmsed_boxes': a [batch_size, max_detections, 4] float32 tensor containing the non-max suppressed boxes;
/// 'nmsed_scores': a [batch_size, max_detections] float32 tensor containing the scores for the boxes;
/// 'nmsed_classes': a [batch_size, max_detections] float32 tensor containing the class for boxes;
/// 'valid_detections': a [batch_size] int32 tensor indicating the number of valid detections
/// per batch item. Only the top valid_detections[i] entries in nms_boxes[i], nms_scores[i]
/// and nms_class[i] are valid; the rest of the entries are zero paddings.
/// </returns>
public (Tensor, Tensor, Tensor, Tensor) combined_non_max_suppression(
    Tensor boxes,
    Tensor scores,
    int max_output_size_per_class,
    int max_total_size,
    float iou_threshold,
    float score_threshold,
    bool pad_per_class = false,
    bool clip_boxes = true)
{
    // The op takes its scalar arguments as tensors, so wrap each plain value first.
    // The conversion order is kept as-is since each call may create an op.
    var iouTensor = ops.convert_to_tensor(
        iou_threshold, TF_DataType.TF_FLOAT, name: "iou_threshold");
    var scoreTensor = ops.convert_to_tensor(
        score_threshold, TF_DataType.TF_FLOAT, name: "score_threshold");
    var totalSizeTensor = ops.convert_to_tensor(max_total_size);
    var perClassSizeTensor = ops.convert_to_tensor(max_output_size_per_class);

    return gen_image_ops.combined_non_max_suppression(
        boxes,
        scores,
        perClassSizeTensor,
        totalSizeTensor,
        iouTensor,
        scoreTensor,
        pad_per_class,
        clip_boxes);
}

/// <summary>
/// Extracts crops from the input image tensor and resizes them using bilinear sampling or nearest neighbor sampling (possibly with aspect ratio change) to a common output size specified by crop_size. This is more general than the crop_to_bounding_box op which extracts a fixed size slice from the input image and does not allow resizing or aspect ratio change.
Expand All @@ -187,7 +282,19 @@ public Tensor decode_jpeg(Tensor contents,
/// <param name="name">A name for the operation (optional).</param>
/// <returns>A 4-D tensor of shape [num_boxes, crop_height, crop_width, depth].</returns>
/// <remarks>Thin wrapper: every argument is forwarded unchanged to the underlying <c>crop_and_resize</c> implementation.</remarks>
public Tensor crop_and_resize(Tensor image, Tensor boxes, Tensor box_ind, Tensor crop_size, string method = "bilinear", float extrapolation_value = 0f, string name = null) =>
image_ops_impl.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name); // NOTE(review): two candidate bodies appear here (image_ops_impl, then gen_image_ops) — looks like the old and new line of a diff; confirm which one is current.
gen_image_ops.crop_and_resize(image, boxes, box_ind, crop_size, method, extrapolation_value, name);

/// <summary>
/// Decodes JPEG-encoded data by forwarding every argument to <c>gen_image_ops.decode_jpeg</c>.
/// </summary>
/// <param name="contents">Tensor holding the encoded image data.</param>
/// <param name="channels">Forwarded as the op's <c>channels</c> argument. Defaults to 0.</param>
/// <param name="ratio">Forwarded as the op's <c>ratio</c> argument. Defaults to 1.</param>
/// <param name="fancy_upscaling">Forwarded as the op's <c>fancy_upscaling</c> argument. Defaults to true.</param>
/// <param name="try_recover_truncated">Forwarded as the op's <c>try_recover_truncated</c> argument. Defaults to false.</param>
/// <param name="acceptable_fraction">Forwarded as the op's <c>acceptable_fraction</c> argument. Defaults to 1.</param>
/// <param name="dct_method">Forwarded as the op's <c>dct_method</c> argument. Defaults to an empty string.</param>
/// <param name="name">A name for this operation (optional).</param>
/// <returns>The tensor produced by <c>gen_image_ops.decode_jpeg</c>.</returns>
public Tensor decode_jpeg(Tensor contents,
    int channels = 0,
    int ratio = 1,
    bool fancy_upscaling = true,
    bool try_recover_truncated = false,
    int acceptable_fraction = 1,
    string dct_method = "",
    string name = null)
    // `name` is forwarded as well; previously it was accepted but never reached the op.
    => gen_image_ops.decode_jpeg(contents, channels: channels, ratio: ratio,
        fancy_upscaling: fancy_upscaling, try_recover_truncated: try_recover_truncated,
        acceptable_fraction: acceptable_fraction, dct_method: dct_method, name: name);

public Tensor extract_glimpse(Tensor input, Tensor size, Tensor offsets, bool centered = true, bool normalized = true,
bool uniform_noise = true, string name = null)
Expand Down
Loading
Loading