Skip to content

Commit dfaac8d

Browse files
Ryan Nett committed
Nicely handle pre-existing gradients
Signed-off-by: Ryan Nett <[email protected]>
1 parent 89690ab commit dfaac8d

File tree

2 files changed

+71
-31
lines changed

2 files changed

+71
-31
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorFlow.java

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@
2323
import static org.tensorflow.internal.c_api.global.tensorflow.TF_Version;
2424

2525
import com.google.protobuf.InvalidProtocolBufferException;
26+
import java.lang.reflect.Field;
27+
import java.lang.reflect.Modifier;
2628
import java.util.Collections;
2729
import java.util.IdentityHashMap;
2830
import java.util.Set;
2931
import java.util.stream.Collectors;
32+
import org.bytedeco.javacpp.PointerPointer;
3033
import org.bytedeco.javacpp.PointerScope;
3134
import org.tensorflow.exceptions.TensorFlowException;
3235
import org.tensorflow.internal.c_api.GradFunc;
3336
import org.tensorflow.internal.c_api.GradOpRegistry;
37+
import org.tensorflow.internal.c_api.NativeStatus;
3438
import org.tensorflow.internal.c_api.TF_Buffer;
3539
import org.tensorflow.internal.c_api.TF_Library;
3640
import org.tensorflow.internal.c_api.TF_Status;
@@ -150,7 +154,16 @@ private TensorFlow() {}
150154
}
151155

152156
// to keep them from getting GC'd
153-
private static Set<GradFunc> gradientFuncs = Collections.newSetFromMap(new IdentityHashMap<>());
157+
private static final Set<GradFunc> gradientFuncs =
158+
Collections.newSetFromMap(new IdentityHashMap<>());
159+
160+
private static synchronized boolean hasGradient(String opType) {
161+
try (PointerScope scope = new PointerScope()) {
162+
NativeStatus status =
163+
GradOpRegistry.Global().Lookup(opType, new GradFunc(new PointerPointer<>(1)));
164+
return status.ok();
165+
}
166+
}
154167

155168
/**
156169
* Register a custom gradient function for ops of {@code opType} type.
@@ -161,12 +174,18 @@ private TensorFlow() {}
161174
* @param opType the type of op to register the gradient for. Should usually be an {@code OP_NAME}
162175
* field, i.e. {@link Add#OP_NAME}.
163176
* @param gradient the gradient function to use
177+
* @return {@code true} if the gradient was registered, {@code false} if there was already a
178+
* gradient registered for this op
164179
*/
165-
public static synchronized void registerCustomGradient(
180+
public static synchronized boolean registerCustomGradient(
166181
String opType, RawCustomGradient gradient) {
182+
if (hasGradient(opType)) {
183+
return false;
184+
}
167185
GradFunc g = new RawGradientAdapter(gradient);
168186
GradOpRegistry.Global().Register(opType, g);
169187
gradientFuncs.add(g);
188+
return true;
170189
}
171190

172191
/**
@@ -175,13 +194,29 @@ public static synchronized void registerCustomGradient(
175194
*
176195
* @param opClass the class of op to register the gradient for.
177196
* @param gradient the gradient function to use
197+
* @return {@code true} if the gradient was registered, {@code false} if there was already a
198+
* gradient registered for this op
199+
* @throws IllegalArgumentException if {@code opClass} does not have a static {@code OP_NAME}
200+
* field.
178201
*/
179-
public static synchronized <T extends RawOp> void registerCustomGradient(
202+
public static synchronized <T extends RawOp> boolean registerCustomGradient(
180203
Class<T> opClass, CustomGradient<T> gradient) {
181204
try {
182-
String opName = (String) opClass.getDeclaredField("OP_NAME").get(null);
205+
Field nameField = opClass.getDeclaredField("OP_NAME");
206+
207+
if (!Modifier.isStatic(nameField.getModifiers())) {
208+
throw new IllegalArgumentException(
209+
"Class " + opClass + " has an OP_NAME field, but it is not static.");
210+
}
211+
212+
String opType = (String) nameField.get(null);
213+
214+
if (hasGradient(opType)) {
215+
return false;
216+
}
217+
183218
GradFunc g = new TypedGradientAdapter<>(gradient, opClass);
184-
GradOpRegistry.Global().Register(opName, g);
219+
GradOpRegistry.Global().Register(opType, g);
185220
gradientFuncs.add(g);
186221
} catch (IllegalAccessException | NoSuchFieldException e) {
187222
throw new IllegalArgumentException(
@@ -190,5 +225,6 @@ public static synchronized <T extends RawOp> void registerCustomGradient(
190225
+ ", ensure it is a generated op class",
191226
e);
192227
}
228+
return true;
193229
}
194230
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,54 +17,58 @@
1717
package org.tensorflow;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertFalse;
2021
import static org.junit.jupiter.api.Assertions.assertNotNull;
22+
import static org.junit.jupiter.api.Assertions.assertTrue;
2123

2224
import java.util.Arrays;
2325
import org.junit.jupiter.api.Test;
2426
import org.tensorflow.ndarray.index.Indices;
2527
import org.tensorflow.op.Ops;
26-
import org.tensorflow.op.core.Concat;
28+
import org.tensorflow.op.dtypes.Cast;
29+
import org.tensorflow.op.nn.NthElement;
2730
import org.tensorflow.proto.framework.DataType;
2831
import org.tensorflow.types.TFloat32;
2932

3033
public class CustomGradientTest {
3134

3235
@Test
33-
public void testCustomConcat() {
36+
public void testAlreadyExisting() {
37+
assertFalse(
38+
TensorFlow.registerCustomGradient(
39+
Cast.class,
40+
(tf, op, gradInputs) -> {
41+
Operand<?> out = gradInputs.get(0);
42+
Operand<?> a = tf.stridedSlice(out, Indices.slice(0, 1));
43+
Operand<?> b = tf.stridedSlice(out, Indices.slice(1, 2));
44+
return Arrays.asList(a, b, tf.constant(0f));
45+
}));
46+
}
47+
48+
@Test
49+
public void testCustomGradient() {
3450
try (Graph g = new Graph();
3551
Session s = new Session(g)) {
36-
37-
TensorFlow.registerCustomGradient(
38-
Concat.class,
39-
(tf, op, gradInputs) -> {
40-
Operand<?> out = gradInputs.get(0);
41-
Operand<?> a = tf.stridedSlice(out, Indices.slice(0, 1));
42-
Operand<?> b = tf.stridedSlice(out, Indices.slice(1, 2));
43-
return Arrays.asList(a, b, tf.constant(0f));
44-
});
52+
assertTrue(
53+
TensorFlow.registerCustomGradient(
54+
NthElement.class,
55+
(tf, op, gradInputs) -> Arrays.asList(tf.constant(0f), tf.constant(0f))));
4556

4657
Ops tf = Ops.create(g);
58+
Output<TFloat32> x = tf.placeholder(TFloat32.class).output();
59+
Output<TFloat32> y = tf.nn.nthElement(x, tf.constant(2)).asOutput();
4760

48-
Output<TFloat32> x1 = tf.placeholder(TFloat32.class).output();
49-
Output<TFloat32> x2 = tf.placeholder(TFloat32.class).output();
50-
Operand<TFloat32> x = tf.concat(Arrays.asList(x1, x2), tf.constant(0));
51-
Output<TFloat32> y = tf.math.square(x).y();
52-
53-
Output<?>[] grads0 = g.addGradients(y, toArray(x1, x2));
61+
Output<?>[] grads0 = g.addGradients(y, toArray(x));
5462
assertNotNull(grads0);
55-
assertEquals(2, grads0.length);
63+
assertEquals(1, grads0.length);
5664
assertEquals(DataType.DT_FLOAT, grads0[0].dataType());
57-
assertEquals(DataType.DT_FLOAT, grads0[1].dataType());
5865

59-
try (TFloat32 c1 = TFloat32.scalarOf(3.0f);
60-
TFloat32 c2 = TFloat32.scalarOf(2.0f);
66+
try (TFloat32 c1 = TFloat32.vectorOf(3.0f, 2.0f, 1.0f, 0.0f);
6167
AutoCloseableList<Tensor> outputs =
62-
new AutoCloseableList<>(
63-
s.runner().feed(x1, c1).feed(x2, c2).fetch(grads0[0]).fetch(grads0[1]).run())) {
68+
new AutoCloseableList<>(s.runner().feed(x, c1).fetch(grads0[0]).run())) {
6469

65-
assertEquals(2, outputs.size());
66-
assertEquals(6.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
67-
assertEquals(4.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
70+
assertEquals(1, outputs.size());
71+
assertEquals(0.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
6872
}
6973
}
7074
}

0 commit comments

Comments (0)