Skip to content

Commit c5717b0

Browse files
committed
Working gradients
Signed-off-by: Ryan Nett <[email protected]>
1 parent ee96979 commit c5717b0

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ public void close() {
122122
}
123123
delete(nativeHandle);
124124
nativeHandle = null;
125+
allGraphs.remove(this);
125126
}
126127
}
127128

@@ -1320,7 +1321,7 @@ private static SaverDef addVariableSaver(Graph graph) {
13201321
*/
13211322
public static Graph findGraphForPointer(NativeGraphPointer pointer) {
13221323
for (Graph g : allGraphs) {
1323-
if (g.nativeHandle.graph().equals(pointer)) {
1324+
if (g.nativeHandle != null && !g.nativeHandle.isNull() && g.nativeHandle.graph().equals(pointer)) {
13241325
return g;
13251326
}
13261327
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperationBuilder.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import static org.tensorflow.internal.c_api.global.tensorflow.TF_FinishOperationLocked;
2323
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewOperation;
2424
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewOperationLocked;
25+
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationName;
2526
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrBool;
2627
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrBoolList;
2728
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetAttrFloat;
@@ -457,7 +458,7 @@ private static TF_Operation finishDangerousGradient(TF_Graph g, TF_OperationDesc
457458
TF_Status status = TF_Status.newStatus();
458459
TF_Operation op = TF_FinishOperationLocked(handle, status);
459460
status.throwExceptionIfNotOK();
460-
// g.name_map().put(TF_OperationName(op), null);
461+
g.name_map().erase(TF_OperationName(op));
461462
return op;
462463
}
463464
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/presets/tensorflow.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ public void init(ClassProperties properties) {
287287
.put(new Info("TF_Graph::refiner", "TF_Graph::mu",
288288
"TF_Graph::sessions", "TF_Graph::delete_requested").skip())
289289
.put(new Info("std::unordered_map<tensorflow::string,tensorflow::Node*>")
290-
.pointerTypes("NameMap").define())
290+
.pointerTypes("NameMap").define().javaText("public native long erase(@StdString BytePointer key);"))
291291
.put(new Info("TF_Function")
292292
.pointerTypes("TF_Function")
293293
.base("org.tensorflow.internal.c_api.AbstractTF_Function"))

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/CustomGradientTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public void addGradientsToGraph() {
3838
Operand<?> out = gradInputs.get(0);
3939
Operand<?> a = tf.stridedSlice(out, Indices.slice(0, 1));
4040
Operand<?> b = tf.stridedSlice(out, Indices.slice(1, 2));
41-
return Arrays.asList(a, b);
41+
return Arrays.asList(a, b, tf.constant(0f));
4242
});
4343
Ops tf = Ops.create(g);
4444

@@ -64,8 +64,8 @@ public void addGradientsToGraph() {
6464
.run())) {
6565

6666
assertEquals(2, outputs.size());
67-
assertEquals(3.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
68-
assertEquals(2.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
67+
assertEquals(6.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f);
68+
assertEquals(4.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f);
6969
}
7070
}
7171
}

0 commit comments

Comments
 (0)