microsoft/vscode-languagedetection

Public

mirrored from https://github.com/microsoft/vscode-languagedetection Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d9c01ba7f55ea6994ee3edc9e5d98c4b11bbd24a

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

lib/index.ts

214 lines · mode: code

1import { Rank, tensor, Tensor, io, setBackend, env } from '@tensorflow/tfjs-core';
2import { GraphModel, loadGraphModel } from '@tensorflow/tfjs-converter';
3import '@tensorflow/tfjs-backend-cpu';
4
/**
 * One language guess produced by `ModelOperations.runModel`.
 */
export interface ModelResult {
  /** Language label the model emitted for this entry. */
  languageId: string;
  /** Score the model assigned to `languageId`; `runModel` orders results by this value, highest first. */
  confidence: number;
}
9
/**
 * An `io.IOHandler` that hands TensorFlow.js a model which is already in memory
 * (a parsed `model.json` plus its binary weights) instead of fetching it from a
 * URL or the file system.
 */
class InMemoryIOHandler implements io.IOHandler {

  /**
   * @param modelJSON Parsed contents of the model's `model.json`.
   * @param weights Binary weight data for the whole model; it is returned as-is
   *   as the weight data for all entries of `modelJSON.weightsManifest`.
   */
  constructor(private readonly modelJSON: io.ModelJSON,
              private readonly weights: ArrayBuffer) {
  }

  /**
   * Assembles the model artifacts from the in-memory JSON and weights.
   *
   * @returns The artifacts consumed by `loadGraphModel`.
   * @throws Error if the JSON has neither a model topology nor a weights manifest.
   */
  async load(): Promise<io.ModelArtifacts> {
    // We do not allow both modelTopology and weightsManifest to be missing.
    // `== null` matches both `null` and `undefined`: a key that is simply absent
    // from the parsed JSON reads as `undefined`, which `=== null` would miss.
    const modelTopology = this.modelJSON.modelTopology;
    const weightsManifest = this.modelJSON.weightsManifest;
    if (modelTopology == null && weightsManifest == null) {
      throw new Error(
        `The model contains neither model topology or manifest for weights.`);
    }

    return this.getModelArtifactsForJSON(
      this.modelJSON, (manifest) => this.loadWeights(manifest));
  }

  /**
   * Copies the fields present in `modelJSON` into a `ModelArtifacts` object,
   * resolving the weights through `loadWeights` when a manifest is present.
   *
   * @param modelJSON Parsed model JSON.
   * @param loadWeights Resolves a weights manifest to `[weightSpecs, weightData]`.
   * @returns The assembled model artifacts.
   */
  private async getModelArtifactsForJSON(
    modelJSON: io.ModelJSON,
    loadWeights: (weightsManifest: io.WeightsManifestConfig) => Promise<[
      /* weightSpecs */ io.WeightsManifestEntry[], /* weightData */ ArrayBuffer
    ]>): Promise<io.ModelArtifacts> {
    const modelArtifacts: io.ModelArtifacts = {
      modelTopology: modelJSON.modelTopology,
      format: modelJSON.format,
      generatedBy: modelJSON.generatedBy,
      convertedBy: modelJSON.convertedBy
    };

    // Optional fields use `!= null` so keys missing from the JSON (`undefined`)
    // are skipped instead of being copied or, for the manifest, iterated.
    if (modelJSON.trainingConfig != null) {
      modelArtifacts.trainingConfig = modelJSON.trainingConfig;
    }
    if (modelJSON.weightsManifest != null) {
      const [weightSpecs, weightData] =
        await loadWeights(modelJSON.weightsManifest);
      modelArtifacts.weightSpecs = weightSpecs;
      modelArtifacts.weightData = weightData;
    }
    if (modelJSON.signature != null) {
      modelArtifacts.signature = modelJSON.signature;
    }
    if (modelJSON.userDefinedMetadata != null) {
      modelArtifacts.userDefinedMetadata = modelJSON.userDefinedMetadata;
    }
    if (modelJSON.modelInitializer != null) {
      modelArtifacts.modelInitializer = modelJSON.modelInitializer;
    }

    return modelArtifacts;
  }

  /**
   * Flattens every weight group in the manifest into a single list of weight
   * specs and pairs it with the in-memory weight buffer.
   *
   * @param weightsManifest The manifest from `model.json`.
   * @returns `[weightSpecs, weightData]`, where `weightData` is the buffer given to the constructor.
   */
  private async loadWeights(weightsManifest: io.WeightsManifestConfig): Promise<[io.WeightsManifestEntry[], ArrayBuffer]> {
    const weightSpecs: io.WeightsManifestEntry[] = [];
    for (const group of weightsManifest) {
      weightSpecs.push(...group.weights);
    }

    return [weightSpecs, this.weights];
  }
}
72
/**
 * Configuration for the `ModelOperations` constructor. Every field is optional;
 * any field left out falls back to the default noted on it.
 */
export interface ModelOperationsOptions {
  /** Produces the parsed `model.json` object. Default: read `model/model.json` from disk with Node's `fs`. */
  modelJsonLoaderFunc?: () => Promise<{ [key: string]: any }>;
  /** Produces the binary model weights. Default: read `model/group1-shard1of1.bin` from disk with Node's `fs`. */
  weightsLoaderFunc?: () => Promise<ArrayBuffer>;
  /** Inputs shorter than this many characters are not classified and yield `[]`. Default: 20. */
  minContentSize?: number;
  /** Inputs longer than this many characters are truncated before classification. Default: 100000. */
  maxContentSize?: number;
  /** When true, `\r\n` sequences are rewritten to `\n` before classification. Default: true. */
  normalizeNewline?: boolean;
}
80
/**
 * Guesses the language of a piece of text by running a TensorFlow.js graph
 * model on the CPU backend.
 *
 * The model is loaded lazily by the first `runModel` call whose input passes the
 * minimum-size check, cached on the instance, and released by `dispose`.
 */
export class ModelOperations {
  private static readonly DEFAULT_MAX_CONTENT_SIZE = 100000;
  private static readonly DEFAULT_MIN_CONTENT_SIZE = 20;

  /** Default model JSON loader: reads and parses `model/model.json`, two directories above `__dirname`. */
  private static readonly NODE_MODEL_JSON_FUNC: () => Promise<{ [key: string]: any }> = async () => {
    const fs = await import('fs');
    const path = await import('path');

    return new Promise<any>((resolve, reject) => {
      fs.readFile(path.join(__dirname, '..', '..', 'model', 'model.json'), (err, data) => {
        if (err) {
          reject(err);
          return;
        }
        try {
          resolve(JSON.parse(data.toString()));
        } catch (parseError) {
          // Thrown inside this async callback, a parse error would escape the
          // Promise executor as an uncaught exception and leave the promise
          // pending forever; surface it as a rejection instead.
          reject(parseError);
        }
      });
    });
  };

  /** Default weights loader: reads `model/group1-shard1of1.bin`, two directories above `__dirname`. */
  private static readonly NODE_WEIGHTS_FUNC: () => Promise<ArrayBuffer> = async () => {
    const fs = await import('fs');
    const path = await import('path');

    return new Promise<ArrayBuffer>((resolve, reject) => {
      fs.readFile(path.join(__dirname, '..', '..', 'model', 'group1-shard1of1.bin'), (err, data) => {
        if (err) {
          reject(err);
          return;
        }
        // A Buffer is a view that need not start at offset 0 of, or span all of,
        // its backing ArrayBuffer. Hand out `data.buffer` only when it holds
        // exactly the file's bytes; otherwise copy them into a fresh buffer.
        const exact = data.byteOffset === 0 && data.byteLength === data.buffer.byteLength;
        resolve(exact ? data.buffer : new Uint8Array(data).buffer);
      });
    });
  };

  private _model: GraphModel | undefined;
  private _modelJson: io.ModelJSON | undefined;
  private _weights: ArrayBuffer | undefined;
  private readonly _minContentSize: number;
  private readonly _maxContentSize: number;
  private readonly _modelJsonLoaderFunc: () => Promise<{ [key: string]: any }>;
  private readonly _weightsLoaderFunc: () => Promise<ArrayBuffer>;
  private readonly _normalizeNewline: boolean;

  /**
   * @param modelOptions Optional loaders and content-size settings; omitted
   *   fields use the Node file-system loaders, a 20-character minimum, a
   *   100000-character maximum, and newline normalization enabled.
   */
  constructor(modelOptions?: ModelOperationsOptions) {
    this._modelJsonLoaderFunc = modelOptions?.modelJsonLoaderFunc ?? ModelOperations.NODE_MODEL_JSON_FUNC;
    this._weightsLoaderFunc = modelOptions?.weightsLoaderFunc ?? ModelOperations.NODE_WEIGHTS_FUNC;
    this._minContentSize = modelOptions?.minContentSize ?? ModelOperations.DEFAULT_MIN_CONTENT_SIZE;
    this._maxContentSize = modelOptions?.maxContentSize ?? ModelOperations.DEFAULT_MAX_CONTENT_SIZE;
    this._normalizeNewline = modelOptions?.normalizeNewline ?? true;
  }

  /** Returns the model JSON, invoking the configured loader only on first use. */
  private async getModelJSON(): Promise<io.ModelJSON> {
    if (this._modelJson) {
      return this._modelJson;
    }

    // TODO: validate model.json
    this._modelJson = await this._modelJsonLoaderFunc() as io.ModelJSON;
    return this._modelJson;
  }

  /** Returns the weight buffer, invoking the configured loader only on first use. */
  private async getWeights(): Promise<ArrayBuffer> {
    if (this._weights) {
      return this._weights;
    }

    // TODO: validate weights
    this._weights = await this._weightsLoaderFunc();
    return this._weights;
  }

  /**
   * Returns the cached model, loading it first if necessary.
   *
   * @throws Error if the CPU backend cannot be selected.
   */
  private async loadModel(): Promise<GraphModel> {
    if (this._model) {
      return this._model;
    }

    env().set('IS_NODE', false);
    if (!(await setBackend('cpu'))) {
      throw new Error('Unable to set backend to CPU.');
    }

    const resolvedModelJSON = await this.getModelJSON();
    const resolvedWeights = await this.getWeights();
    this._model = await loadGraphModel(new InMemoryIOHandler(resolvedModelJSON, resolvedWeights));
    return this._model;
  }

  /** Truncates `content` to the configured maximum and, if enabled, rewrites `\r\n` to `\n`. */
  private prepareContent(content: string): string {
    // Larger inputs cause a "RangeError: Maximum call stack size exceeded" in
    // tfjs, so only the first `_maxContentSize` characters are used; that is
    // enough for guessing.
    let prepared = content.length > this._maxContentSize
      ? content.substring(0, this._maxContentSize)
      : content;
    if (this._normalizeNewline) {
      prepared = prepared.replace(/\r\n/g, '\n');
    }
    return prepared;
  }

  /**
   * Runs `model` on a one-element string tensor built from `content`. The input
   * tensor is disposed before returning; the caller owns the returned outputs.
   */
  private static async predict(model: GraphModel, content: string): Promise<Tensor<Rank> | Tensor<Rank>[]> {
    const input = tensor([content]);
    try {
      return await model.executeAsync(input);
    } finally {
      input.dispose();
    }
  }

  /**
   * Classifies `content` and returns every language label the model produced,
   * sorted by descending confidence.
   *
   * Returns `[]` without loading the model when `content` is empty or shorter
   * than the configured minimum size.
   *
   * @param content Text to classify.
   * @returns One `ModelResult` per language label, highest confidence first.
   * @throws Error if the CPU backend cannot be selected, or if a loader fails.
   */
  public async runModel(content: string): Promise<Array<ModelResult>> {
    if (!content || content.length < this._minContentSize) {
      return [];
    }

    const model = await this.loadModel();
    const predicted = await ModelOperations.predict(model, this.prepareContent(content));

    try {
      // Output 0 is read as the probabilities and output 1 as the language
      // labels; a single-tensor result is used for both.
      const probabilitiesTensor: Tensor<Rank> = Array.isArray(predicted) ? predicted[0]! : predicted;
      const languageTensor: Tensor<Rank> = Array.isArray(predicted) ? predicted[1]! : predicted;
      const probabilities = probabilitiesTensor.dataSync<'float32'>();
      const langs = languageTensor.dataSync<'string'>();

      return langs
        .map((languageId, i): ModelResult => ({ languageId, confidence: probabilities[i]! }))
        .sort((a, b) => b.confidence - a.confidence);
    } finally {
      // Tensors returned by executeAsync are not freed automatically; release
      // them so repeated calls do not accumulate backend memory.
      for (const output of Array.isArray(predicted) ? predicted : [predicted]) {
        output.dispose();
      }
    }
  }

  /** Releases the loaded model, if any. A later `runModel` call reloads it. */
  public dispose() {
    this._model?.dispose();
    // Forget the disposed model so runModel() does not execute a model whose
    // weights have already been released.
    this._model = undefined;
  }
}
215