Skip to content

Commit cd43400

Browse files
hcurOceania2018
authored andcommitted
fix rot90_3d
1 parent d33927a commit cd43400

File tree

2 files changed

+48
-9
lines changed

2 files changed

+48
-9
lines changed

src/TensorFlowNET.Core/Operations/control_flow_ops.cs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
/*****************************************************************************
1+
/*****************************************************************************
22
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
@@ -226,6 +226,47 @@ public static Tensor[] tuple(Tensor[] tensors, string name = null, Operation[] c
226226
});
227227
}
228228

229+
internal static Tensor _case_helper(Func<Tensor, Tensor> cond_fn, Tensor[] pred_fn_pairs, Func<Tensor[]> callable_default, bool exclusive, string name,
230+
bool allow_python_preds = false)
231+
{
232+
/*
233+
(Tensor[] predicates, Tensor[] actions) = _case_verify_and_canonicalize_args(
234+
pred_fn_pairs, exclusive, name, allow_python_preds);
235+
return tf_with(ops.name_scope(name, "case", new [] {predicates}), delegate
236+
{
237+
if (callable_default == null)
238+
{
239+
(callable_default, predicates, actions) = _case_create_default_action(
240+
predicates, actions);
241+
}
242+
var fn = callable_default;
243+
});
244+
*/
245+
246+
throw new NotImplementedException("_case_helper");
247+
}
248+
249+
internal static (Func<Tensor[]>, Tensor[], Tensor[]) _case_create_default_action(Tensor[] predicates, Tensor[] actions)
250+
{
251+
throw new NotImplementedException("_case_create_default_action");
252+
}
253+
254+
internal static (Tensor[], Tensor[]) _case_verify_and_canonicalize_args(Tensor[] pred_fn_pairs, bool exclusive, string name, bool allow_python_preds)
255+
{
256+
throw new NotImplementedException("_case_verify_and_canonicalize_args");
257+
}
258+
259+
public static Tensor case_v2(Tensor[] pred_fn_pairs, Func<Tensor[]> callable_default = null, bool exclusive = false, bool strict = false, string name = "case")
260+
=> _case_helper(
261+
cond_fn: (Tensor x) => cond(x),
262+
pred_fn_pairs,
263+
default,
264+
exclusive,
265+
name,
266+
allow_python_preds: false//,
267+
//strict: strict
268+
);
269+
229270
/// <summary>
230271
/// Produces the content of `output_tensor` only after `dependencies`.
231272
///

src/TensorFlowNET.Core/Operations/image_ops_impl.cs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,13 @@ Tensor _rot270() {
265265
return gen_array_ops.reverse(array_ops.transpose(image, new [] {1, 0, 2}), new [] {1});
266266
};
267267

268-
var cases = new [] {new [] {math_ops.equal(k, 1), _rot90()},
269-
new [] {math_ops.equal(k, 2), _rot180()},
270-
new [] {math_ops.equal(k, 3), _rot270()}};
268+
var cases = new [] {math_ops.equal(k, 1), _rot90(),
269+
math_ops.equal(k, 2), _rot180(),
270+
math_ops.equal(k, 3), _rot270()};
271271

272-
// ! control_flow_ops doesn't have an implementation for case yet !
273-
// var result = control_flow_ops.case(cases, default: () => image, exclusive: true, name: name_scope);
274-
// result.set_shape(new [] {null, null, image.shape.dims[2]})
275-
// return result
276-
throw new NotImplementedException();
272+
var result = control_flow_ops.case_v2(cases, callable_default: () => new Tensor[] {image}, exclusive: true, name: name_scope);
273+
result.set_shape(new [] {-1, -1, image.TensorShape.dims[2]});
274+
return result;
277275
}
278276

279277
public static Tensor transpose(Tensor image, string name = null)

0 commit comments

Comments
 (0)